langid_rs/
lib.rs

1use std::collections::{HashMap, HashSet};
2use std::io::{self, Cursor, Read, Seek, SeekFrom};
3
4#[derive(Debug)]
5pub struct Model {
6    tk_output: HashMap<u16, Vec<i32>>, // Transition table
7    nb_numfeats: usize,
8    tk_nextmove: Vec<u16>,
9    norm_probs: bool,
10    data: ModelData,
11    used_data: Option<ModelData>,
12}
13
/// Per-class parameters of a model, either for all classes or for a subset.
///
/// The three fields are parallel: position `j` in `nb_classes`, column `j` of
/// every `nb_ptc` row and entry `j` of `nb_pc` describe the same class.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct ModelData {
    /// Class names, in column order.
    pub nb_classes: Vec<String>,
    /// Weight matrix indexed as `nb_ptc[feature][class]`; each feature count
    /// is multiplied by its row and accumulated into the class scores.
    pub nb_ptc: Vec<Vec<f32>>,
    /// Per-class term added to every score after the feature weights
    /// (presumably the class log-prior — confirm against the model exporter).
    pub nb_pc: Vec<f32>,
}
20
/// Reasons why [`Model::set_langs`] can reject a requested language subset.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// A requested code is not among the model's classes; carries that code.
    UnknownLanguageCode(String),
    /// Fewer than two languages were requested.
    NoLanguage,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::UnknownLanguageCode(code) => write!(f, "unknown language code: {}", code),
            Error::NoLanguage => f.write_str("at least two languages are required"),
        }
    }
}

// Lets the error be used with `?` into `Box<dyn std::error::Error>` and similar.
impl std::error::Error for Error {}
25
26impl Model {
27    fn data(&self) -> &ModelData {
28        if let Some(data) = &self.used_data {
29            data
30        } else {
31            &self.data
32        }
33    }
34}
35
36impl Model {
37    pub fn set_langs(&mut self, langs: Option<HashSet<String>>) -> Result<(), Error> {
38        if let Some(langs) = langs {
39            if langs.len() < 2 {
40                return Err(Error::NoLanguage);
41            }
42            let unknown = langs
43                .iter()
44                .find(|lang| !self.data.nb_classes.contains(&lang));
45            if let Some(lang) = unknown {
46                return Err(Error::UnknownLanguageCode(lang.to_owned()));
47            }
48            let subset_mask = self
49                .data
50                .nb_classes
51                .iter()
52                .map(|s| langs.contains(s))
53                .collect::<Vec<_>>();
54            let nb_classes = self
55                .data
56                .nb_classes
57                .iter()
58                .filter(|s| langs.contains(*s))
59                .cloned()
60                .collect::<Vec<_>>();
61            let nb_pc = self
62                .data
63                .nb_pc
64                .iter()
65                .zip(&subset_mask)
66                .filter(|(_, m)| **m)
67                .map(|v| *v.0)
68                .collect::<Vec<_>>();
69
70            let nb_ptc = self
71                .data
72                .nb_ptc
73                .iter()
74                .map(|v| {
75                    v.iter()
76                        .zip(&subset_mask)
77                        .filter(|(_, m)| **m)
78                        .map(|v| *v.0)
79                        .collect::<Vec<_>>()
80                })
81                .collect::<Vec<_>>();
82            self.used_data = Some(ModelData {
83                nb_classes,
84                nb_ptc,
85                nb_pc,
86            });
87        } else {
88            self.used_data = None;
89        }
90        Ok(())
91    }
92
93    fn compute_softmax(&self, pd: &[f32]) -> Vec<f32> {
94        pd.iter()
95            .map(|vi| 1.0 / pd.iter().map(|vj| (vj - vi).exp()).sum::<f32>())
96            .collect()
97    }
98
99    fn apply_norm_probs(&self, pd: Vec<f32>) -> Vec<f32> {
100        if self.norm_probs {
101            self.compute_softmax(&pd)
102        } else {
103            pd
104        }
105    }
106
107    /// Return a list of languages in order of likelihood.
108    pub fn rank(&self, text: &str) -> Vec<(&str, f32)> {
109        let fv: Vec<u16> = self.instance2fv(text);
110        let probs: Vec<f32> = self.apply_norm_probs(self.nb_classprobs(fv));
111        let mut class_probs: Vec<(&str, f32)> = self
112            .data()
113            .nb_classes
114            .iter()
115            .map(|class| class.as_str())
116            .zip(probs.into_iter())
117            .collect();
118
119        class_probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
120        class_probs
121    }
122
123    /// Classify an instance.
124    /// Will return None if no language is detected.
125    pub fn classify(&self, text: &str) -> Option<(&str, f32)> {
126        let fv = self.instance2fv(text);
127        let probs = self.apply_norm_probs(self.nb_classprobs(fv));
128        let (max_index, max_prob) = probs
129            .iter()
130            .enumerate()
131            .fold(None, |max, (idx, &prob)| match max {
132                None => Some((idx, prob)),
133                Some((max_idx, max_val)) => {
134                    if prob > max_val {
135                        Some((idx, prob))
136                    } else {
137                        Some((max_idx, max_val))
138                    }
139                }
140            })
141            .unzip();
142        let (max_index, max_prob) = (max_index?, max_prob?);
143
144        self.data()
145            .nb_classes
146            .get(max_index)
147            .map(|class| (class.as_str(), max_prob))
148    }
149
150    fn nb_classprobs(&self, fv: Vec<u16>) -> Vec<f32> {
151        // let pdc = np.dot(fv, self.data.nb_ptc);
152        // pdc + self.data.nb_pc
153
154        // dot
155        let n = self.data().nb_pc.len();
156        let mut pdc = vec![0f32; n];
157
158        for (i, fv_val) in fv.into_iter().enumerate() {
159            let fv_val = fv_val as f32;
160            for j in 0..n {
161                pdc[j] += fv_val * self.data().nb_ptc[i][j];
162            }
163        }
164
165        // add
166        for j in 0..n {
167            pdc[j] += self.data().nb_pc[j];
168        }
169
170        pdc
171    }
172
173    fn instance2fv(&self, text: &str) -> Vec<u16> {
174        let indexes = text
175            .as_bytes()
176            .iter()
177            .fold((0usize, Vec::new()), |(state, mut acc), &letter| {
178                let new_state = self.tk_nextmove[(state << 8) + letter as usize];
179                let new_state_u = new_state as usize;
180                let output = self.tk_output.get(&new_state).cloned().unwrap_or_default();
181                acc.extend(output);
182                (new_state_u, acc)
183            })
184            .1;
185        let mut arr = vec![0; self.nb_numfeats];
186
187        let counts = counter(&indexes);
188
189        for (index, value) in counts {
190            arr[index as usize] = value;
191        }
192        arr
193    }
194}
195
/// Tallies how many times each value occurs in `indexes`.
///
/// Counts saturate at `u16::MAX` rather than overflowing, so a value that
/// occurs more than 65535 times is reported as 65535.
fn counter(indexes: &[i32]) -> HashMap<i32, u16> {
    let mut counts: HashMap<i32, u16> = HashMap::new();

    for &index in indexes {
        let slot = counts.entry(index).or_insert(0);
        *slot = slot.saturating_add(1);
    }

    counts
}
/// Reads four bytes from `reader` and decodes them as a little-endian `u32`.
///
/// Fails with the reader's error (for example `UnexpectedEof`) when fewer
/// than four bytes are available.
fn read_u32(reader: &mut impl Read) -> io::Result<u32> {
    let mut word = [0u8; 4];
    reader
        .read_exact(&mut word)
        .map(|()| u32::from_le_bytes(word))
}
210
/// Reads `len` little-endian `f32` values from `reader`.
///
/// # Errors
///
/// Returns `InvalidData` when `len` is so large that the byte count does not
/// fit in `usize`, and the reader's own error (for example `UnexpectedEof`)
/// when the input ends early.
fn read_f32_vec(reader: &mut impl Read, len: usize) -> io::Result<Vec<f32>> {
    const WIDTH: usize = std::mem::size_of::<f32>();

    let byte_len = len.checked_mul(WIDTH).ok_or_else(|| {
        io::Error::new(io::ErrorKind::InvalidData, "f32 vector length overflows usize")
    })?;

    let mut raw = vec![0u8; byte_len];
    reader.read_exact(&mut raw)?;

    Ok(raw
        .chunks_exact(WIDTH)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect())
}
220
/// Reads `len` little-endian `u16` values from `reader`.
///
/// # Errors
///
/// Returns `InvalidData` when `len` is so large that the byte count does not
/// fit in `usize`, and the reader's own error (for example `UnexpectedEof`)
/// when the input ends early.
fn read_u16_vec(reader: &mut impl Read, len: usize) -> io::Result<Vec<u16>> {
    const WIDTH: usize = std::mem::size_of::<u16>();

    let byte_len = len.checked_mul(WIDTH).ok_or_else(|| {
        io::Error::new(io::ErrorKind::InvalidData, "u16 vector length overflows usize")
    })?;

    let mut raw = vec![0u8; byte_len];
    reader.read_exact(&mut raw)?;

    Ok(raw
        .chunks_exact(WIDTH)
        .map(|b| u16::from_le_bytes([b[0], b[1]]))
        .collect())
}
230
231fn read_string(reader: &mut impl Read) -> io::Result<String> {
232    let len = read_u32(reader)? as usize;
233    let mut buf = vec![0u8; len];
234    reader.read_exact(&mut buf)?;
235    Ok(String::from_utf8(buf).expect("Invalid UTF-8"))
236}
237
/// Reads `len` little-endian `i32` values from `reader`.
///
/// # Errors
///
/// Returns `InvalidData` when `len` is so large that the byte count does not
/// fit in `usize`, and the reader's own error (for example `UnexpectedEof`)
/// when the input ends early.
fn read_i32_vec(reader: &mut impl Read, len: usize) -> io::Result<Vec<i32>> {
    const WIDTH: usize = std::mem::size_of::<i32>();

    let byte_len = len.checked_mul(WIDTH).ok_or_else(|| {
        io::Error::new(io::ErrorKind::InvalidData, "i32 vector length overflows usize")
    })?;

    let mut raw = vec![0u8; byte_len];
    reader.read_exact(&mut raw)?;

    Ok(raw
        .chunks_exact(WIDTH)
        .map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect())
}
247
248impl Model {
249    pub fn load(norm_probs: bool) -> io::Result<Self> {
250        let mut reader = Cursor::new(include_bytes!("model.bin"));
251
252        let rows = read_u32(&mut reader)? as usize;
253        let cols = read_u32(&mut reader)? as usize;
254        let nb_ptc_flat = read_f32_vec(&mut reader, rows * cols)?;
255        let nb_ptc: Vec<Vec<f32>> = nb_ptc_flat
256            .chunks_exact(cols)
257            .map(|row| row.to_vec())
258            .collect();
259
260        let nb_pc_len = read_u32(&mut reader)? as usize;
261        let nb_pc = read_f32_vec(&mut reader, nb_pc_len)?;
262
263        let tk_nextmove_len = read_u32(&mut reader)? as usize;
264        let tk_nextmove = read_u16_vec(&mut reader, tk_nextmove_len)?;
265
266        let nb_class_count = read_u32(&mut reader)? as usize;
267        let mut nb_classes = Vec::with_capacity(nb_class_count);
268        for _ in 0..nb_class_count {
269            nb_classes.push(read_string(&mut reader)?);
270        }
271
272        let tk_output_count = read_u32(&mut reader)? as usize;
273        let mut tk_output = HashMap::with_capacity(tk_output_count);
274        for _ in 0..tk_output_count {
275            let key = read_u32(&mut reader)?;
276            let key = match u16::try_from(key) {
277                Ok(v) => v,
278                Err(_) => unreachable!("Key does not fit in u16"),
279            };
280            let val_len = read_u32(&mut reader)? as usize;
281            let val = read_i32_vec(&mut reader, val_len)?;
282            tk_output.insert(key, val);
283        }
284
285        let nb_numfeats = nb_ptc.iter().map(|v| v.len()).sum::<usize>() / nb_pc.len();
286        assert_eq!(bytes_remaining(&mut reader)?, 0);
287        Ok(Self {
288            norm_probs,
289            used_data: None,
290            nb_numfeats,
291            data: ModelData {
292                nb_classes,
293                nb_ptc,
294                nb_pc,
295            },
296            tk_nextmove,
297            tk_output,
298        })
299    }
300}
301
/// Returns how many bytes lie between the reader's current position and its
/// end, leaving the position where it was.
///
/// A position at or beyond the end yields `0` rather than underflowing.
fn bytes_remaining<R: Read + Seek>(reader: &mut R) -> std::io::Result<u64> {
    let current = reader.stream_position()?;
    let end = reader.seek(SeekFrom::End(0))?;
    // Put the cursor back so the caller can keep reading where it left off.
    reader.seek(SeekFrom::Start(current))?;
    Ok(end.saturating_sub(current))
}