compression_text_classification 0.2.0

Text classification using a compression algorithm
Documentation
#![feature(binary_heap_into_iter_sorted, type_alias_impl_trait)]

mod caching;

use caching::CachedTrainItem;
use itertools::Itertools;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use std::{
    cmp::{max, min, Reverse},
    collections::BinaryHeap,
    io::Write,
    vec::Vec,
};
use thiserror::Error;

/// A labelled training sample: a class label paired with the raw text it
/// belongs to.
#[derive(Debug, Clone)]
pub struct TrainItem {
    class: String,
    text: String,
}

impl TrainItem {
    /// Builds a training sample from a class label and its text.
    pub fn new(class: String, text: String) -> Self {
        TrainItem { text, class }
    }
}

/// An in-memory training corpus holding pre-compressed samples.
#[derive(Default)]
pub struct Dataset {
    items: Vec<CachedTrainItem>,
}

impl FromIterator<TrainItem> for Dataset {
    /// Builds a dataset by converting every incoming sample into its
    /// cached (pre-compressed) form.
    fn from_iter<T: IntoIterator<Item = TrainItem>>(iter: T) -> Self {
        let items = iter.into_iter().map(Into::into).collect();
        Self { items }
    }
}
/// Ordering adapter pairing a class label with its distance score; ordered
/// by the score alone so instances can be ranked inside a `BinaryHeap`.
struct Wraper<'a>(&'a str, f64);

impl PartialEq for Wraper<'_> {
    fn eq(&self, other: &Self) -> bool {
        // Delegate to `total_cmp` so equality agrees with `Ord` below
        // (the previous `==` disagreed with `cmp` for NaN inputs).
        self.1.total_cmp(&other.1).is_eq()
    }
}
impl Eq for Wraper<'_> {}
impl PartialOrd for Wraper<'_> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Canonical form: delegate to `Ord` (resolves the clippy
        // `non_canonical_partial_ord_impl` the original had to allow).
        Some(self.cmp(other))
    }
}
impl Ord for Wraper<'_> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `f64::total_cmp` is a total order over all f64 values, so this
        // cannot panic — the original `partial_cmp(..).unwrap()` would
        // panic if a distance ever came out as NaN.
        self.1.total_cmp(&other.1)
    }
}

/// Errors returned by the classification API.
///
/// Marked `#[non_exhaustive]` so new variants can be added without a
/// breaking change; downstream matches must keep a catch-all arm.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ClassificationError {
    /// No neighbours were available to vote on a class: the dataset holds
    /// no items (also produced when `k_nearest_neighbours` is 0).
    #[error("dataset is empty")]
    DatasetIsEmpty,
}

impl Dataset {
    pub fn insert(&mut self, value: TrainItem) {
        self.items.push(value.into());
    }
    pub fn classify(
        &self,
        text: String,
        k_nearest_neighbours: usize,
    ) -> Result<String, ClassificationError> {
        let cached_text = text.into();
        Ok(self
            .items
            .par_iter()
            .map(|item| {
                Reverse(Wraper(
                    item.class(),
                    caching::normalized_compression_distance(item.into(), &cached_text),
                ))
            })
            .collect::<BinaryHeap<_>>()
            .into_iter_sorted()
            .take(k_nearest_neighbours)
            .map(|Reverse(wraper)| (wraper.0, wraper.1))
            .into_group_map()
            .into_iter()
            .map(|(class, texts)| (class, texts.len()))
            .max_by_key(|x| x.1)
            .ok_or(ClassificationError::DatasetIsEmpty {})?
            .0
            .to_string())
    }
    pub fn check_classification(
        &self,
        item: TrainItem,
        k_nearest_neighbours: usize,
    ) -> Result<bool, ClassificationError> {
        Ok(item.class == self.classify(item.text, k_nearest_neighbours)?)
    }
}
/// Computes the normalized compression distance (NCD) between two strings:
/// `(C(xy) - min(C(x), C(y))) / max(C(x), C(y))`, where `C` is the
/// deflate-compressed size. Values near 0 mean "similar"; values near 1
/// (or slightly above/below, since real compressors are imperfect) mean
/// "unrelated".
pub fn normalized_compression_distance(x: &str, y: &str) -> f64 {
    let c_x = compressed_size(x);
    let c_y = compressed_size(y);
    let c_xy = compressed_size((x.to_owned() + y).as_str());
    // Cast to f64 BEFORE subtracting: when the concatenation compresses
    // better than the smaller input alone, `c_xy < min(c_x, c_y)` and the
    // original `usize` subtraction underflowed (panic in debug builds, a
    // huge wrapped value in release builds).
    (c_xy as f64 - min(c_x, c_y) as f64) / max(c_x, c_y) as f64
}
/// Returns the number of bytes `text` occupies after raw-deflate
/// compression at the fastest compression level.
fn compressed_size(text: &str) -> usize {
    use flate2::{write::DeflateEncoder, Compression};
    let mut encoder = DeflateEncoder::new(Vec::new(), Compression::fast());
    // Writing into an in-memory Vec sink; failures here are bugs.
    encoder.write_all(text.as_bytes()).unwrap();
    encoder.finish().unwrap().len()
}