compression_text_classification/lib.rs

#![feature(binary_heap_into_iter_sorted, type_alias_impl_trait)]

mod caching;

use caching::CachedTrainItem;
use itertools::Itertools;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use std::{
    cmp::{max, min, Reverse},
    collections::BinaryHeap,
    io::Write,
    vec::Vec,
};
use thiserror::Error;

/// A labelled training example: the class it belongs to and its raw text.
#[derive(Debug, Clone)]
pub struct TrainItem {
    class: String,
    text: String,
}
impl TrainItem {
    pub fn new(class: String, text: String) -> Self {
        Self { class, text }
    }
}

/// Training items stored in the cached representation provided by the `caching` module.
#[derive(Default)]
pub struct Dataset {
    items: Vec<CachedTrainItem>,
}
impl FromIterator<TrainItem> for Dataset {
    fn from_iter<T: IntoIterator<Item = TrainItem>>(iter: T) -> Self {
        Self {
            items: iter.into_iter().map(|x| x.into()).collect(),
        }
    }
}
// Pairs a class label with an NCD score so scored items can live in a
// `BinaryHeap`; `f64` is not `Ord`, hence the manual trait impls below,
// which order by the score only.
struct Wrapper<'a>(&'a str, f64);
impl<'a> PartialEq for Wrapper<'a> {
    fn eq(&self, other: &Self) -> bool {
        self.1 == other.1
    }
}
impl<'a> Eq for Wrapper<'a> {}
impl<'a> PartialOrd for Wrapper<'a> {
    #[allow(clippy::non_canonical_partial_ord_impl)]
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        self.1.partial_cmp(&other.1)
    }
}
impl<'a> Ord for Wrapper<'a> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // NCD scores are ordinary finite floats, so the comparison never yields `None`.
        self.partial_cmp(other).unwrap()
    }
}

#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ClassificationError {
    #[error("dataset is empty")]
    DatasetIsEmpty,
}

impl Dataset {
    pub fn insert(&mut self, value: TrainItem) {
        self.items.push(value.into());
    }
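    /// Classifies `text` by k-nearest-neighbour vote: the normalized
    /// compression distance to every training item is computed in parallel,
    /// the `k_nearest_neighbours` closest items are kept, and the class that
    /// occurs most often among them is returned.
    ///
    /// Returns [`ClassificationError::DatasetIsEmpty`] when there are no
    /// neighbours to vote, i.e. the dataset is empty (or
    /// `k_nearest_neighbours` is `0`).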
    pub fn classify(
        &self,
        text: String,
        k_nearest_neighbours: usize,
    ) -> Result<String, ClassificationError> {
        let cached_text = text.into();
        Ok(self
            .items
            .par_iter()
            // Score every training item against the query text.
            .map(|item| {
                Reverse(Wrapper(
                    item.class(),
                    caching::normalized_compression_distance(item.into(), &cached_text),
                ))
            })
            // `Reverse` turns the max-heap into a min-heap, so the sorted
            // iterator yields the smallest distances first.
            .collect::<BinaryHeap<_>>()
            .into_iter_sorted()
            .take(k_nearest_neighbours)
            .map(|Reverse(wrapper)| (wrapper.0, wrapper.1))
            // Majority vote: pick the class with the most nearby neighbours.
            .into_group_map()
            .into_iter()
            .map(|(class, distances)| (class, distances.len()))
            .max_by_key(|x| x.1)
            .ok_or(ClassificationError::DatasetIsEmpty)?
            .0
            .to_string())
    }
    /// Returns `true` if classifying `item.text` reproduces `item.class`.
    pub fn check_classification(
        &self,
        item: TrainItem,
        k_nearest_neighbours: usize,
    ) -> Result<bool, ClassificationError> {
        Ok(item.class == self.classify(item.text, k_nearest_neighbours)?)
    }
}
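/// Normalized compression distance between two strings:
/// `NCD(x, y) = (C(xy) - min(C(x), C(y))) / max(C(x), C(y))`,
/// where `C` is the DEFLATE-compressed size in bytes. Values near 0 mean the
/// strings compress well together (similar); values near 1 mean they do not.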
pub fn normalized_compression_distance(x: &str, y: &str) -> f64 {
    let c_x = compressed_size(x);
    let c_y = compressed_size(y);
    let c_xy = compressed_size((x.to_owned() + y).as_str());
    // `saturating_sub` guards against the rare case where the concatenation
    // compresses to fewer bytes than the smaller input, which would otherwise
    // underflow the unsigned subtraction.
    c_xy.saturating_sub(min(c_x, c_y)) as f64 / max(c_x, c_y) as f64
}
fn compressed_size(text: &str) -> usize {
    use flate2::{write::DeflateEncoder, Compression};
    // DEFLATE at the fastest level; the compressed length is the complexity
    // estimate used by the NCD formula above.
    let mut encoder = DeflateEncoder::new(Vec::new(), Compression::fast());
    encoder.write_all(text.as_bytes()).unwrap();
    let compressed = encoder.finish().unwrap();
    compressed.len()
}
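
// A minimal usage sketch, not a definitive test suite: it only assumes the
// public items defined above (`Dataset`, `TrainItem`, `ClassificationError`,
// `normalized_compression_distance`); the class labels and texts are invented
// for illustration, and the exact winning class depends on the compressor.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    #[test]
    fn classify_returns_one_of_the_training_classes() {
        let dataset: Dataset = [
            TrainItem::new("rust".into(), "fn main() { println!(\"hello\"); }".into()),
            TrainItem::new("rust".into(), "let x: Vec<u32> = Vec::new();".into()),
            TrainItem::new("python".into(), "def main():\n    print('hello')".into()),
            TrainItem::new("python".into(), "x = [i * i for i in range(10)]".into()),
        ]
        .into_iter()
        .collect();

        let class = dataset
            .classify("fn add(a: u32, b: u32) -> u32 { a + b }".into(), 3)
            .unwrap();
        // The vote always picks one of the classes present in the dataset.
        assert!(class == "rust" || class == "python");
    }

    #[test]
    fn empty_dataset_is_an_error() {
        let empty = Dataset::default();
        assert!(matches!(
            empty.classify("anything".into(), 3),
            Err(ClassificationError::DatasetIsEmpty)
        ));
    }

    #[test]
    fn identical_strings_are_closer_than_unrelated_ones() {
        let text = "the quick brown fox jumps over the lazy dog";
        assert!(
            normalized_compression_distance(text, text)
                < normalized_compression_distance(text, "0123456789abcdef")
        );
    }
}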