1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
use std::collections::HashMap;
use itertools::Itertools;
use rustc_serialize::json;
use ngrams::ngrams;
use errors::DeserializeError;
/// N-gram model of a text: maps each retained n-gram to its frequency rank.
pub struct Model {
/// Rank of each n-gram, keyed by the n-gram text. `build_from_text` assigns
/// rank 0 to the most frequent n-gram; larger ranks mean less frequent ones.
pub ngram_ranks: HashMap<String, usize>,
}
impl Model {
pub fn build_from_text(text: &str) -> Self {
let mut ngram_counts = HashMap::new();
let words = text.split(|ch: char| !ch.is_alphabetic()).filter(|s| !s.is_empty());
for word in words {
for n in 1..6 {
for ngram in ngrams(word, n) {
if let Some(count) = ngram_counts.get_mut(ngram) {
*count += 1;
continue;
}
ngram_counts.insert(ngram.to_owned(), 1);
}
}
}
let ngrams = ngram_counts
.into_iter()
.sorted_by(|a, b| Ord::cmp(&b.1, &a.1))
.into_iter()
.take(300)
.map(|(ngram, _count)| ngram);
Model { ngram_ranks: ngrams.enumerate().map(|(a, b)| (b, a)).collect() }
}
pub fn deserialize(bytes: Vec<u8>) -> Result<Self, DeserializeError> {
let string = try!(String::from_utf8(bytes));
let ngram_ranks = try!(json::decode(string.as_str()));
let model = Model { ngram_ranks: ngram_ranks };
Ok(model)
}
pub fn serialize(&self) -> Vec<u8> {
json::encode(&self.ngram_ranks).unwrap().into_bytes()
}
pub fn compare(&self, other: &Model) -> usize {
let max_difference = other.ngram_ranks.len();
let mut difference = 0;
for (ngram, rank) in &self.ngram_ranks {
difference += match other.ngram_ranks.get(ngram) {
Some(other_rank) => get_difference(*rank, *other_rank),
None => max_difference,
}
}
difference
}
}
/// Returns the absolute difference between `a` and `b`.
fn get_difference(a: usize, b: usize) -> usize {
    // `checked_sub` yields `None` exactly when `b` is larger than `a`, in
    // which case the subtraction is done the other way round.
    match a.checked_sub(b) {
        Some(gap) => gap,
        None => b - a,
    }
}
#[cfg(test)]
mod tests {
    use super::Model;

    /// A model sent through `serialize` and then `deserialize` must come back
    /// with exactly the same n-gram ranks.
    #[test]
    fn serialization_and_deserialization() {
        let original = Model::build_from_text("Testing text for serialization");
        let bytes = original.serialize();
        let restored = Model::deserialize(bytes).unwrap();
        assert_eq!(original.ngram_ranks, restored.ngram_ranks);
    }
}