use std::{
fs::File,
io::{self, BufReader, Cursor, Read},
path::Path,
};
/// Decodes a single little-endian `u32` from the next four bytes of `r`.
///
/// # Errors
///
/// Propagates any error from `read_exact`, e.g. `UnexpectedEof` when fewer than four
/// bytes remain.
fn read_u32<R: Read>(r: &mut R) -> io::Result<u32> {
    let mut raw = [0u8; 4];
    r.read_exact(&mut raw).map(|()| u32::from_le_bytes(raw))
}
/// Reads a length-prefixed UTF-8 string: a little-endian `u32` byte count followed by
/// that many bytes.
///
/// The payload is pulled through `Read::take`, so the buffer only grows as bytes
/// actually arrive. A corrupt length prefix therefore cannot force a multi-gigabyte
/// allocation up front.
///
/// # Errors
///
/// Returns `UnexpectedEof` if the stream ends before the prefix or the payload is
/// complete. Returns `InvalidData` if the payload is not valid UTF-8. Any other error
/// from `r` is propagated.
fn read_string<R: Read>(r: &mut R) -> io::Result<String> {
    let len = read_u32(r)?;
    let mut bytes = Vec::new();
    r.by_ref().take(u64::from(len)).read_to_end(&mut bytes)?;
    if bytes.len() as u64 != u64::from(len) {
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "stream ended inside a length-prefixed string",
        ));
    }
    String::from_utf8(bytes).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}
/// Reads `count` little-endian `f32` values from `r`.
///
/// Each element is decoded with `f32::from_le_bytes`, which matches the little-endian
/// header fields read by `read_u32` on every host. A raw byte cast would read the
/// values in native byte order instead. The byte buffer is filled through
/// `Read::take`, so it grows only with data that is actually present.
///
/// # Errors
///
/// Returns `InvalidData` if `count * 4` overflows `usize`, and `UnexpectedEof` if the
/// stream ends early. Any other error from `r` is propagated.
fn read_f32_vec<R: Read>(r: &mut R, count: usize) -> io::Result<Vec<f32>> {
    const WIDTH: usize = 4;
    let byte_len = count.checked_mul(WIDTH).ok_or_else(|| {
        io::Error::new(io::ErrorKind::InvalidData, "f32 array length overflows usize")
    })?;
    let mut bytes = Vec::new();
    r.by_ref().take(byte_len as u64).read_to_end(&mut bytes)?;
    if bytes.len() != byte_len {
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "stream ended inside an f32 array",
        ));
    }
    Ok(bytes
        .chunks_exact(WIDTH)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect())
}
/// Reads `count` little-endian `u16` values from `r`.
///
/// Each element is decoded with `u16::from_le_bytes`, which matches the little-endian
/// header fields read by `read_u32` on every host. The byte buffer is filled through
/// `Read::take`, so a corrupt `count` cannot trigger a huge up-front allocation.
///
/// # Errors
///
/// Returns `InvalidData` if `count * 2` overflows `usize`, and `UnexpectedEof` if the
/// stream ends early. Any other error from `r` is propagated.
fn read_u16_vec<R: Read>(r: &mut R, count: usize) -> io::Result<Vec<u16>> {
    const WIDTH: usize = 2;
    let byte_len = count.checked_mul(WIDTH).ok_or_else(|| {
        io::Error::new(io::ErrorKind::InvalidData, "u16 array length overflows usize")
    })?;
    let mut bytes = Vec::new();
    r.by_ref().take(byte_len as u64).read_to_end(&mut bytes)?;
    if bytes.len() != byte_len {
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "stream ended inside a u16 array",
        ));
    }
    Ok(bytes
        .chunks_exact(WIDTH)
        .map(|c| u16::from_le_bytes([c[0], c[1]]))
        .collect())
}
/// Per-state output lists of the tokenizer. All lists are concatenated into one flat
/// `values` buffer, and each state has an `(offset, length)` range into it.
#[derive(Debug)]
struct TkOutput {
    /// `(offset, length)` range into `values`, indexed by state.
    offsets: Vec<(u32, u32)>,
    /// Concatenated feature indices of every state's output list.
    values: Vec<u16>,
}
impl TkOutput {
    /// Iterates over the feature indices listed for state `i`.
    ///
    /// A state with no entry in `offsets` is treated as having an empty output list
    /// instead of panicking on the index. The range end is computed in `usize`, not
    /// `u32`, so `offset + length` does not overflow the narrower type.
    ///
    /// # Panics
    ///
    /// Panics if the state's range lies outside `values`, which only happens for a
    /// corrupt model.
    fn get(&self, i: u16) -> impl Iterator<Item = &u16> {
        let outputs: &[u16] = match self.offsets.get(usize::from(i)) {
            Some(&(offset, length)) => {
                let start = offset as usize;
                &self.values[start..start + length as usize]
            }
            None => &[],
        };
        outputs.iter()
    }
}
/// A dense two-dimensional array stored row-major in a single `Vec`.
#[derive(Debug)]
struct FlatArray<T> {
    /// Number of columns, i.e. the stride between consecutive rows in `data`.
    cols: usize,
    /// Row-major element storage of length `rows * cols`.
    data: Vec<T>,
}
impl<T> FlatArray<T> {
    /// Wraps `data` as a `rows x cols` row-major array.
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` overflows or differs from `data.len()`.
    fn new(rows: usize, cols: usize, data: Vec<T>) -> Self {
        assert_eq!(
            rows.checked_mul(cols),
            Some(data.len()),
            "data length does not match a {}x{} array",
            rows,
            cols
        );
        Self { cols, data }
    }

    /// Returns the element at row `i`, column `j`.
    ///
    /// # Panics
    ///
    /// Panics if row `i` does not exist. In debug builds it also panics if `j >= cols`.
    /// Without that check, an out-of-range column would silently read an element of
    /// the following row.
    fn get(&self, i: usize, j: usize) -> &T {
        debug_assert!(
            j < self.cols,
            "column {} out of range for {} columns",
            j,
            self.cols
        );
        &self.data[i * self.cols + j]
    }
}
/// Identifies the language of a text.
///
/// The text's bytes are run through a byte-level state machine (`tk_nextmove` /
/// `tk_output`) to count features. Each class is then scored as a count-weighted sum
/// of its `nb_ptc` column plus its `nb_pc` term.
#[derive(Debug)]
pub struct LanguageIdentifier {
    /// Class labels (e.g. "en", "zh"), indexed by class id.
    nb_classes: Vec<String>,
    /// Per-class term added to every class score after the feature sum.
    nb_pc: Vec<f32>,
    /// Feature-by-class weights: row = feature index, column = class index.
    nb_ptc: FlatArray<f32>,
    /// State transition table, indexed by `(state << 8) + byte`.
    tk_nextmove: Vec<u16>,
    /// Feature indices to increment whenever a given state is entered.
    tk_output: TkOutput,
    /// Number of features, i.e. the length of the count vector built per text.
    num_feats: u32,
}
impl LanguageIdentifier {
pub fn from_reader<R: Read>(mut r: R) -> io::Result<Self> {
let mut magic = [0u8; 4];
r.read_exact(&mut magic)?;
if &magic != b"LANG" {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Invalid magic bytes",
));
}
let num_classes = read_u32(&mut r)?;
let num_feats = read_u32(&mut r)?;
let nb_classes = (0..num_classes)
.map(|_| read_string(&mut r))
.collect::<io::Result<Vec<String>>>()?;
let nb_pc = read_f32_vec(&mut r, num_classes as usize)?;
let rows = read_u32(&mut r)? as usize;
let cols = read_u32(&mut r)? as usize;
let nb_ptc_data = read_f32_vec(&mut r, rows * cols)?;
let nb_ptc = FlatArray::new(rows, cols, nb_ptc_data);
let n_elem = read_u32(&mut r)? as usize;
let tk_nextmove = read_u16_vec(&mut r, n_elem)?;
let num_offsets = read_u32(&mut r)? as usize;
let offsets = (0..num_offsets)
.map(|_| {
let mut offset_buf = [0u8; 8];
r.read_exact(&mut offset_buf)?;
let offset = u32::from_le_bytes(offset_buf[0..4].try_into().unwrap());
let length = u32::from_le_bytes(offset_buf[4..8].try_into().unwrap());
Ok((offset, length))
})
.collect::<io::Result<Vec<(u32, u32)>>>()?;
let n_elem = read_u32(&mut r)? as usize;
let values = read_u16_vec(&mut r, n_elem)?;
let tk_output = TkOutput { offsets, values };
Ok(Self {
nb_classes,
nb_pc,
nb_ptc,
tk_nextmove,
tk_output,
num_feats,
})
}
pub fn from_lzma_bytes<R: Read>(bytes: R) -> io::Result<Self> {
let mut decompressed = vec![];
let mut compressed_file = BufReader::new(bytes);
lzma_rs::xz_decompress(&mut compressed_file, &mut decompressed)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
Self::from_reader(Cursor::new(decompressed))
}
pub fn from_lzma_file<P: AsRef<Path>>(path: P) -> io::Result<Self> {
Self::from_lzma_bytes(File::open(path)?)
}
pub fn new() -> Self {
let model_bytes = include_bytes!("../resource/model.bin");
Self::from_lzma_bytes(Cursor::new(model_bytes)).unwrap()
}
fn instance2fv<I: Iterator<Item = u8>>(&self, text: I) -> Vec<u16> {
let mut state = 0;
let mut count = vec![0; self.num_feats as usize];
for letter in text {
state = self.tk_nextmove[((state as usize) << 8) + letter as usize];
for i in self.tk_output.get(state) {
count[*i as usize] += 1;
}
}
count
}
fn nb_classprobs(&self, fv: Vec<u16>) -> Vec<f32> {
let mut pdc = vec![0.0; self.nb_classes.len()];
for (i, v) in fv.into_iter().enumerate() {
if v > 0 {
for (j, w) in pdc.iter_mut().enumerate() {
*w += v as f32 * self.nb_ptc.get(i, j);
}
}
}
for (v, w) in pdc.iter_mut().zip(self.nb_pc.iter()) {
*v += w;
}
pdc
}
pub fn classify(&self, text: &str) -> (String, f32) {
let fv = self.instance2fv(text.as_bytes().iter().copied());
let probs = self.nb_classprobs(fv);
let cl = probs
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Less))
.map(|(i, _)| i)
.unwrap();
(self.nb_classes[cl].clone(), probs[cl])
}
}
#[cfg(test)]
mod tests {
    use super::LanguageIdentifier;

    /// Returns whether `a` and `b` differ by less than `epsilon` (absolute tolerance).
    fn f32_is_close(a: f32, b: f32, epsilon: f32) -> bool {
        (a - b).abs() < epsilon
    }

    /// Classifies fixed samples with the embedded model and checks both the predicted
    /// label and its score. Failures name the offending sample and the actual values.
    #[test]
    fn test_classify() {
        let li = LanguageIdentifier::new();
        let cases = [
            ("你是我万水千山的冒险要找的标记点", "zh", -256.80695),
            ("あなたの体育の先生は誰ですか?", "ja", -376.09363),
            ("This text is in English.", "en", -56.77429),
        ];
        for (text, exp_lang, exp_prob) in cases {
            let (lang, prob) = li.classify(text);
            assert_eq!(lang, exp_lang, "wrong language for {:?}", text);
            assert!(
                f32_is_close(prob, exp_prob, 1e-4),
                "score for {:?} was {}, expected {}",
                text,
                prob,
                exp_prob
            );
        }
    }
}