//! py3langid_rs 0.1.0
//!
//! A high-performance, pure Rust port of py3langid.
//!
//! Documentation
use std::{
    fs::File,
    io::{self, BufReader, Cursor, Read},
    path::Path,
};

// Helper functions for reading primitives
/// Reads exactly four bytes from `r` and decodes them as a little-endian `u32`.
///
/// # Errors
/// Returns whatever `read_exact` reports, including the error raised when the
/// input ends before four bytes are available.
fn read_u32<R: Read>(r: &mut R) -> io::Result<u32> {
    let mut word = [0u8; 4];
    r.read_exact(&mut word).map(|_| u32::from_le_bytes(word))
}

/// Reads a length-prefixed UTF-8 string from `r`.
///
/// The prefix is a `u32` (decoded by `read_u32`) giving the payload size in
/// bytes; that many bytes are then read and validated as UTF-8.
///
/// # Errors
/// Propagates read errors, and returns `InvalidData` wrapping the UTF-8 error
/// when the payload is not valid UTF-8.
fn read_string<R: Read>(r: &mut R) -> io::Result<String> {
    let byte_len = read_u32(r)? as usize;
    let mut raw = vec![0u8; byte_len];
    r.read_exact(&mut raw)?;
    match String::from_utf8(raw) {
        Ok(text) => Ok(text),
        Err(err) => Err(io::Error::new(io::ErrorKind::InvalidData, err)),
    }
}

// Bulk reads for vectors of primitive types: the bytes are read straight into
// the vector's own storage, without an intermediate buffer.
/// Reads `count` little-endian `f32` values from `r`.
///
/// The bytes are read directly into the vector's own storage through
/// `bytemuck::cast_slice_mut`, so no intermediate byte buffer is allocated.
///
/// # Errors
/// Propagates any error from `read_exact`, including the input ending before
/// `count` values were read.
fn read_f32_vec<R: Read>(r: &mut R, count: usize) -> io::Result<Vec<f32>> {
    let mut vec = vec![0.0f32; count];
    r.read_exact(bytemuck::cast_slice_mut(&mut vec))?;
    // The raw read leaves every element in file byte order. `read_u32` decodes
    // the same format as little-endian, so big-endian hosts must swap here;
    // little-endian hosts skip the loop entirely.
    if cfg!(target_endian = "big") {
        for v in &mut vec {
            *v = f32::from_bits(u32::from_le(v.to_bits()));
        }
    }
    Ok(vec)
}

/// Reads `count` little-endian `u16` values from `r`.
///
/// The bytes are read directly into the vector's own storage through
/// `bytemuck::cast_slice_mut`, so no intermediate byte buffer is allocated.
///
/// # Errors
/// Propagates any error from `read_exact`, including the input ending before
/// `count` values were read.
fn read_u16_vec<R: Read>(r: &mut R, count: usize) -> io::Result<Vec<u16>> {
    let mut vec = vec![0u16; count];
    r.read_exact(bytemuck::cast_slice_mut(&mut vec))?;
    // The raw read leaves every element in file byte order. `read_u32` decodes
    // the same format as little-endian, so big-endian hosts must swap here;
    // little-endian hosts skip the loop entirely.
    if cfg!(target_endian = "big") {
        for v in &mut vec {
            *v = u16::from_le(*v);
        }
    }
    Ok(vec)
}

/// Variable-length lists of `u16` values packed into one flat buffer.
///
/// Entry `i` of `offsets` locates list `i` inside `values`.
#[derive(Debug)]
struct TkOutput {
    offsets: Vec<(u32, u32)>, // (offset, length) pairs
    values: Vec<u16>,
}

impl TkOutput {
    /// Iterates over the values stored for entry `i`.
    ///
    /// # Panics
    /// Panics if `i` has no entry in `offsets`, or if that entry points outside
    /// `values`.
    fn get(&self, i: u16) -> impl Iterator<Item = &u16> {
        let (offset, length) = self.offsets[usize::from(i)];
        // Slice in two steps instead of computing `offset + length`: that sum
        // was done in `u32` and could overflow for a corrupt entry.
        self.values[offset as usize..][..length as usize].iter()
    }
}

/// A two-dimensional array stored row by row in a single `Vec`.
///
/// Only the row width is kept; the row count is implied by `data.len()`.
#[derive(Debug)]
struct FlatArray<T> {
    cols: usize,
    data: Vec<T>,
}

impl<T> FlatArray<T> {
    /// Wraps `data` as a `rows` x `cols` array.
    ///
    /// # Panics
    /// Panics if `data` does not hold exactly `rows * cols` elements.
    fn new(rows: usize, cols: usize, data: Vec<T>) -> Self {
        let expected_len = rows * cols;
        assert_eq!(expected_len, data.len());
        FlatArray { cols, data }
    }

    /// Returns the element at row `i`, column `j`.
    ///
    /// # Panics
    /// Panics if the flattened index `i * cols + j` is past the end of the data.
    fn get(&self, i: usize, j: usize) -> &T {
        let flat_index = i * self.cols + j;
        &self.data[flat_index]
    }
}

/// A language-identification model held entirely in memory.
///
/// All fields are filled in by `LanguageIdentifier::from_reader`; `classify`
/// uses them to score a text against every class.
#[derive(Debug)]
pub struct LanguageIdentifier {
    /// Class labels; `classify` returns the label at the winning index.
    nb_classes: Vec<String>,
    /// One value per class, added to that class's score after the feature
    /// terms have been summed (the loader calls these class probabilities).
    nb_pc: Vec<f32>,
    /// Weight matrix, looked up as `(feature index, class index)`.
    nb_ptc: FlatArray<f32>,
    /// State-transition table, indexed by `(state << 8) + byte`.
    tk_nextmove: Vec<u16>,
    /// For each state, the feature indices to count when that state is entered.
    tk_output: TkOutput,
    /// Length of the per-text feature count vector.
    num_feats: u32,
}

impl LanguageIdentifier {
    /// Parses an uncompressed model from `r`.
    ///
    /// Fields are read in this order (integers are little-endian `u32`s unless
    /// noted):
    ///
    /// 1. the magic bytes `LANG`
    /// 2. the class count, then the feature count
    /// 3. one length-prefixed UTF-8 label per class
    /// 4. one `f32` per class (`nb_pc`)
    /// 5. a row count and a column count, then `rows * cols` `f32`s (`nb_ptc`)
    /// 6. an element count, then that many `u16`s (`tk_nextmove`)
    /// 7. an entry count, then that many `(offset, length)` pairs
    /// 8. an element count, then that many `u16`s (the `tk_output` values)
    ///
    /// # Errors
    /// Returns `InvalidData` when the magic bytes are wrong, a label is not
    /// valid UTF-8, or the matrix dimensions do not fit in `usize`; any read
    /// error (including truncated input) is propagated unchanged.
    ///
    /// The tables are not cross-checked against each other, so a file that
    /// parses but is internally inconsistent can still make `classify` panic.
    // NOTE: this fn is generated by claude
    pub fn from_reader<R: Read>(mut r: R) -> io::Result<Self> {
        // Magic bytes check
        let mut magic = [0u8; 4];
        r.read_exact(&mut magic)?;
        if &magic != b"LANG" {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Invalid magic bytes",
            ));
        }

        // Read counts
        let num_classes = read_u32(&mut r)?;
        let num_feats = read_u32(&mut r)?;

        // Read classes
        let nb_classes = (0..num_classes)
            .map(|_| read_string(&mut r))
            .collect::<io::Result<Vec<String>>>()?;

        // Read nb_pc (class probabilities)
        let nb_pc = read_f32_vec(&mut r, num_classes as usize)?;

        // Read nb_ptc (feature matrix). The element count is checked because
        // two `u32` factors can overflow `usize` on 32-bit targets.
        let rows = read_u32(&mut r)? as usize;
        let cols = read_u32(&mut r)? as usize;
        let cells = rows.checked_mul(cols).ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "nb_ptc dimensions overflow usize",
            )
        })?;
        let nb_ptc_data = read_f32_vec(&mut r, cells)?;
        let nb_ptc = FlatArray::new(rows, cols, nb_ptc_data);

        // Read tk_nextmove
        let n_elem = read_u32(&mut r)? as usize;
        let tk_nextmove = read_u16_vec(&mut r, n_elem)?;

        // Read tk_output: (offset, length) pairs first, then the shared values.
        let num_offsets = read_u32(&mut r)? as usize;
        let offsets = (0..num_offsets)
            .map(|_| -> io::Result<(u32, u32)> {
                let offset = read_u32(&mut r)?;
                let length = read_u32(&mut r)?;
                Ok((offset, length))
            })
            .collect::<io::Result<Vec<(u32, u32)>>>()?;

        let n_elem = read_u32(&mut r)? as usize;
        let values = read_u16_vec(&mut r, n_elem)?;
        let tk_output = TkOutput { offsets, values };

        Ok(Self {
            nb_classes,
            nb_pc,
            nb_ptc,
            tk_nextmove,
            tk_output,
            num_feats,
        })
    }

    /// Decompresses an XZ stream from `bytes` and parses the result with
    /// [`from_reader`](Self::from_reader).
    ///
    /// Despite the name, any `Read` implementation is accepted. The whole
    /// decompressed model is buffered in memory before parsing.
    ///
    /// # Errors
    /// Returns `InvalidData` wrapping the decompressor's error when the stream
    /// cannot be decompressed, plus every error `from_reader` can return.
    pub fn from_lzma_bytes<R: Read>(bytes: R) -> io::Result<Self> {
        let mut decompressed = vec![];
        let mut compressed_file = BufReader::new(bytes);
        lzma_rs::xz_decompress(&mut compressed_file, &mut decompressed)
            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
        Self::from_reader(Cursor::new(decompressed))
    }

    /// Opens the file at `path` and loads it with
    /// [`from_lzma_bytes`](Self::from_lzma_bytes).
    ///
    /// # Errors
    /// Returns the error from opening the file, or any error from
    /// `from_lzma_bytes`.
    pub fn from_lzma_file<P: AsRef<Path>>(path: P) -> io::Result<Self> {
        Self::from_lzma_bytes(File::open(path)?)
    }

    /// Loads the model that is compiled into the binary.
    ///
    /// # Panics
    /// Panics if the embedded model cannot be decompressed or parsed.
    pub fn new() -> Self {
        let model_bytes = include_bytes!("../resource/model.bin");
        Self::from_lzma_bytes(Cursor::new(model_bytes))
            .expect("embedded model must decompress and parse")
    }

    /// Runs `text` through the state machine one byte at a time and counts how
    /// often each feature is emitted.
    ///
    /// The result has `num_feats` entries. Counts are `u32` because a feature
    /// can fire once per input byte, which overflows a `u16` on any text with
    /// more than 65,535 hits for one feature.
    fn instance2fv<I: Iterator<Item = u8>>(&self, text: I) -> Vec<u32> {
        let mut state: u16 = 0;
        let mut counts = vec![0u32; self.num_feats as usize];
        for byte in text {
            // Each state owns 256 consecutive slots, one per possible byte.
            state = self.tk_nextmove[(usize::from(state) << 8) + usize::from(byte)];
            for &feature in self.tk_output.get(state) {
                counts[usize::from(feature)] += 1;
            }
        }
        counts
    }

    /// Scores every class for the feature counts `fv`.
    ///
    /// Each score is the sum over features of `count * nb_ptc[feature][class]`,
    /// plus that class's `nb_pc` entry.
    fn nb_classprobs(&self, fv: &[u32]) -> Vec<f32> {
        let mut pdc = vec![0.0f32; self.nb_classes.len()];
        // manual dot product, skipping features that never fired
        for (i, &count) in fv.iter().enumerate() {
            if count > 0 {
                for (j, score) in pdc.iter_mut().enumerate() {
                    *score += count as f32 * self.nb_ptc.get(i, j);
                }
            }
        }
        for (score, prior) in pdc.iter_mut().zip(self.nb_pc.iter()) {
            *score += prior;
        }
        pdc
    }

    /// Returns the best-scoring class label for `text` together with its score.
    ///
    /// `text` is processed as raw UTF-8 bytes. The score is the raw,
    /// unnormalised value computed by `nb_classprobs`.
    ///
    /// # Panics
    /// Panics if the model defines no classes, or if its tables are
    /// inconsistent enough to index out of bounds.
    pub fn classify(&self, text: &str) -> (String, f32) {
        let fv = self.instance2fv(text.bytes());
        let probs = self.nb_classprobs(&fv);
        let (best, score) = probs
            .iter()
            .copied()
            .enumerate()
            // Pairs that cannot be ordered (NaN) are treated as `Less`.
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Less))
            .expect("model must define at least one class");
        (self.nb_classes[best].clone(), score)
    }
}

#[cfg(test)]
mod tests {
    use super::LanguageIdentifier;

    /// True when `a` and `b` differ by less than `epsilon` (absolute difference).
    fn f32_is_close(a: f32, b: f32, epsilon: f32) -> bool {
        let gap = (a - b).abs();
        gap < epsilon
    }

    /// The embedded model must give each sample the expected label and score.
    #[test]
    fn test_classify() {
        let identifier = LanguageIdentifier::new();
        let cases: [(&str, &str, f32); 3] = [
            ("你是我万水千山的冒险要找的标记点", "zh", -256.80695),
            ("あなたの体育の先生は誰ですか?", "ja", -376.09363),
            ("This text is in English.", "en", -56.77429),
        ];
        for &(sample, want_lang, want_score) in cases.iter() {
            let (got_lang, got_score) = identifier.classify(sample);
            assert_eq!(got_lang, want_lang);
            assert!(f32_is_close(got_score, want_score, 1e-4));
        }
    }
}