1use std::collections::{HashMap, HashSet};
2use std::io::{self, Cursor, Read, Seek, SeekFrom};
3
4#[derive(Debug)]
5pub struct Model {
6 tk_output: HashMap<u16, Vec<i32>>, nb_numfeats: usize,
8 tk_nextmove: Vec<u16>,
9 norm_probs: bool,
10 data: ModelData,
11 used_data: Option<ModelData>,
12}
13
/// Per-class scoring parameters: either the full model or a subset of its languages.
#[derive(Debug)]
pub struct ModelData {
    /// Class labels (language codes), in model order.
    pub nb_classes: Vec<String>,
    /// Feature weights, indexed as `nb_ptc[feature][class]`.
    pub nb_ptc: Vec<Vec<f32>>,
    /// Per-class term added to every class score, indexed by class.
    pub nb_pc: Vec<f32>,
}
20
/// Errors returned when configuring which languages a [`Model`] may report.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// A requested language code is not one of the model's classes.
    UnknownLanguageCode(String),
    /// Fewer than two languages were requested.
    NoLanguage,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::UnknownLanguageCode(code) => write!(f, "unknown language code: {code}"),
            Error::NoLanguage => f.write_str("at least two languages must be selected"),
        }
    }
}

impl std::error::Error for Error {}
25
26impl Model {
27 fn data(&self) -> &ModelData {
28 if let Some(data) = &self.used_data {
29 data
30 } else {
31 &self.data
32 }
33 }
34}
35
36impl Model {
37 pub fn set_langs(&mut self, langs: Option<HashSet<String>>) -> Result<(), Error> {
38 if let Some(langs) = langs {
39 if langs.len() < 2 {
40 return Err(Error::NoLanguage);
41 }
42 let unknown = langs
43 .iter()
44 .find(|lang| !self.data.nb_classes.contains(&lang));
45 if let Some(lang) = unknown {
46 return Err(Error::UnknownLanguageCode(lang.to_owned()));
47 }
48 let subset_mask = self
49 .data
50 .nb_classes
51 .iter()
52 .map(|s| langs.contains(s))
53 .collect::<Vec<_>>();
54 let nb_classes = self
55 .data
56 .nb_classes
57 .iter()
58 .filter(|s| langs.contains(*s))
59 .cloned()
60 .collect::<Vec<_>>();
61 let nb_pc = self
62 .data
63 .nb_pc
64 .iter()
65 .zip(&subset_mask)
66 .filter(|(_, m)| **m)
67 .map(|v| *v.0)
68 .collect::<Vec<_>>();
69
70 let nb_ptc = self
71 .data
72 .nb_ptc
73 .iter()
74 .map(|v| {
75 v.iter()
76 .zip(&subset_mask)
77 .filter(|(_, m)| **m)
78 .map(|v| *v.0)
79 .collect::<Vec<_>>()
80 })
81 .collect::<Vec<_>>();
82 self.used_data = Some(ModelData {
83 nb_classes,
84 nb_ptc,
85 nb_pc,
86 });
87 } else {
88 self.used_data = None;
89 }
90 Ok(())
91 }
92
93 fn compute_softmax(&self, pd: &[f32]) -> Vec<f32> {
94 pd.iter()
95 .map(|vi| 1.0 / pd.iter().map(|vj| (vj - vi).exp()).sum::<f32>())
96 .collect()
97 }
98
99 fn apply_norm_probs(&self, pd: Vec<f32>) -> Vec<f32> {
100 if self.norm_probs {
101 self.compute_softmax(&pd)
102 } else {
103 pd
104 }
105 }
106
107 pub fn rank(&self, text: &str) -> Vec<(&str, f32)> {
109 let fv: Vec<u16> = self.instance2fv(text);
110 let probs: Vec<f32> = self.apply_norm_probs(self.nb_classprobs(fv));
111 let mut class_probs: Vec<(&str, f32)> = self
112 .data()
113 .nb_classes
114 .iter()
115 .map(|class| class.as_str())
116 .zip(probs.into_iter())
117 .collect();
118
119 class_probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
120 class_probs
121 }
122
123 pub fn classify(&self, text: &str) -> Option<(&str, f32)> {
126 let fv = self.instance2fv(text);
127 let probs = self.apply_norm_probs(self.nb_classprobs(fv));
128 let (max_index, max_prob) = probs
129 .iter()
130 .enumerate()
131 .fold(None, |max, (idx, &prob)| match max {
132 None => Some((idx, prob)),
133 Some((max_idx, max_val)) => {
134 if prob > max_val {
135 Some((idx, prob))
136 } else {
137 Some((max_idx, max_val))
138 }
139 }
140 })
141 .unzip();
142 let (max_index, max_prob) = (max_index?, max_prob?);
143
144 self.data()
145 .nb_classes
146 .get(max_index)
147 .map(|class| (class.as_str(), max_prob))
148 }
149
150 fn nb_classprobs(&self, fv: Vec<u16>) -> Vec<f32> {
151 let n = self.data().nb_pc.len();
156 let mut pdc = vec![0f32; n];
157
158 for (i, fv_val) in fv.into_iter().enumerate() {
159 let fv_val = fv_val as f32;
160 for j in 0..n {
161 pdc[j] += fv_val * self.data().nb_ptc[i][j];
162 }
163 }
164
165 for j in 0..n {
167 pdc[j] += self.data().nb_pc[j];
168 }
169
170 pdc
171 }
172
173 fn instance2fv(&self, text: &str) -> Vec<u16> {
174 let indexes = text
175 .as_bytes()
176 .iter()
177 .fold((0usize, Vec::new()), |(state, mut acc), &letter| {
178 let new_state = self.tk_nextmove[(state << 8) + letter as usize];
179 let new_state_u = new_state as usize;
180 let output = self.tk_output.get(&new_state).cloned().unwrap_or_default();
181 acc.extend(output);
182 (new_state_u, acc)
183 })
184 .1;
185 let mut arr = vec![0; self.nb_numfeats];
186
187 let counts = counter(&indexes);
188
189 for (index, value) in counts {
190 arr[index as usize] = value;
191 }
192 arr
193 }
194}
195
/// Counts how many times each feature index occurs in `indexes`.
///
/// Counts saturate at `u16::MAX` (the element type of the feature vector) instead of
/// overflowing. Overflow would panic in debug builds or wrap to a wrong, small count in
/// release builds when a long text repeats one feature more than 65535 times.
fn counter(indexes: &[i32]) -> HashMap<i32, u16> {
    let mut counts: HashMap<i32, u16> = HashMap::new();
    for &index in indexes {
        let count = counts.entry(index).or_insert(0);
        *count = count.saturating_add(1);
    }
    counts
}
/// Reads one little-endian `u32` from `reader`, failing if fewer than 4 bytes are available.
fn read_u32(reader: &mut impl Read) -> io::Result<u32> {
    let mut raw = [0u8; 4];
    reader
        .read_exact(&mut raw)
        .map(|()| u32::from_le_bytes(raw))
}
210
/// Reads `len` consecutive little-endian `f32` values from `reader`.
///
/// # Errors
///
/// Returns [`io::ErrorKind::InvalidData`] if `len * 4` overflows `usize`, and propagates any
/// error from [`Read::read_exact`] (e.g. `UnexpectedEof` when the data is too short).
fn read_f32_vec(reader: &mut impl Read, len: usize) -> io::Result<Vec<f32>> {
    let byte_len = len.checked_mul(4).ok_or_else(|| {
        io::Error::new(io::ErrorKind::InvalidData, "f32 array length overflows usize")
    })?;
    let mut buf = vec![0u8; byte_len];
    reader.read_exact(&mut buf)?;
    Ok(buf
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect())
}
220
/// Reads `len` consecutive little-endian `u16` values from `reader`.
///
/// # Errors
///
/// Returns [`io::ErrorKind::InvalidData`] if `len * 2` overflows `usize`, and propagates any
/// error from [`Read::read_exact`] (e.g. `UnexpectedEof` when the data is too short).
fn read_u16_vec(reader: &mut impl Read, len: usize) -> io::Result<Vec<u16>> {
    let byte_len = len.checked_mul(2).ok_or_else(|| {
        io::Error::new(io::ErrorKind::InvalidData, "u16 array length overflows usize")
    })?;
    let mut buf = vec![0u8; byte_len];
    reader.read_exact(&mut buf)?;
    let values = buf
        .chunks_exact(2)
        .map(|b| u16::from_le_bytes([b[0], b[1]]))
        .collect();
    Ok(values)
}
230
231fn read_string(reader: &mut impl Read) -> io::Result<String> {
232 let len = read_u32(reader)? as usize;
233 let mut buf = vec![0u8; len];
234 reader.read_exact(&mut buf)?;
235 Ok(String::from_utf8(buf).expect("Invalid UTF-8"))
236}
237
/// Reads `len` consecutive little-endian `i32` values from `reader`.
///
/// # Errors
///
/// Returns [`io::ErrorKind::InvalidData`] if `len * 4` overflows `usize`, and propagates any
/// error from [`Read::read_exact`] (e.g. `UnexpectedEof` when the data is too short).
fn read_i32_vec(reader: &mut impl Read, len: usize) -> io::Result<Vec<i32>> {
    let byte_len = len.checked_mul(4).ok_or_else(|| {
        io::Error::new(io::ErrorKind::InvalidData, "i32 array length overflows usize")
    })?;
    let mut buf = vec![0u8; byte_len];
    reader.read_exact(&mut buf)?;
    Ok(buf
        .chunks_exact(4)
        .map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect())
}
247
248impl Model {
249 pub fn load(norm_probs: bool) -> io::Result<Self> {
250 let mut reader = Cursor::new(include_bytes!("model.bin"));
251
252 let rows = read_u32(&mut reader)? as usize;
253 let cols = read_u32(&mut reader)? as usize;
254 let nb_ptc_flat = read_f32_vec(&mut reader, rows * cols)?;
255 let nb_ptc: Vec<Vec<f32>> = nb_ptc_flat
256 .chunks_exact(cols)
257 .map(|row| row.to_vec())
258 .collect();
259
260 let nb_pc_len = read_u32(&mut reader)? as usize;
261 let nb_pc = read_f32_vec(&mut reader, nb_pc_len)?;
262
263 let tk_nextmove_len = read_u32(&mut reader)? as usize;
264 let tk_nextmove = read_u16_vec(&mut reader, tk_nextmove_len)?;
265
266 let nb_class_count = read_u32(&mut reader)? as usize;
267 let mut nb_classes = Vec::with_capacity(nb_class_count);
268 for _ in 0..nb_class_count {
269 nb_classes.push(read_string(&mut reader)?);
270 }
271
272 let tk_output_count = read_u32(&mut reader)? as usize;
273 let mut tk_output = HashMap::with_capacity(tk_output_count);
274 for _ in 0..tk_output_count {
275 let key = read_u32(&mut reader)?;
276 let key = match u16::try_from(key) {
277 Ok(v) => v,
278 Err(_) => unreachable!("Key does not fit in u16"),
279 };
280 let val_len = read_u32(&mut reader)? as usize;
281 let val = read_i32_vec(&mut reader, val_len)?;
282 tk_output.insert(key, val);
283 }
284
285 let nb_numfeats = nb_ptc.iter().map(|v| v.len()).sum::<usize>() / nb_pc.len();
286 assert_eq!(bytes_remaining(&mut reader)?, 0);
287 Ok(Self {
288 norm_probs,
289 used_data: None,
290 nb_numfeats,
291 data: ModelData {
292 nb_classes,
293 nb_ptc,
294 nb_pc,
295 },
296 tk_nextmove,
297 tk_output,
298 })
299 }
300}
301
/// Returns how many bytes lie between the reader's current position and the end of the
/// stream, restoring the original position before returning.
///
/// A position past the end (allowed by seekable readers such as [`Cursor`]) yields `0`
/// rather than underflowing.
///
/// # Errors
///
/// Propagates any error from the underlying seeks.
fn bytes_remaining<R: Read + Seek>(reader: &mut R) -> std::io::Result<u64> {
    let current = reader.stream_position()?;
    let end = reader.seek(SeekFrom::End(0))?;
    reader.seek(SeekFrom::Start(current))?;
    Ok(end.saturating_sub(current))
}