compression_text_classification/
lib.rs1#![feature(binary_heap_into_iter_sorted, type_alias_impl_trait)]
2
3mod caching;
4
5use caching::CachedTrainItem;
6use itertools::Itertools;
7use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
8use std::{
9 cmp::{max, min, Reverse},
10 collections::BinaryHeap,
11 io::Write,
12 vec::Vec,
13};
14use thiserror::Error;
15
/// A labelled training sample: a piece of text together with the class it belongs to.
#[derive(Debug, Clone)]
pub struct TrainItem {
    /// Label assigned to `text`.
    class: String,
    /// Raw text of the sample.
    text: String,
}

impl TrainItem {
    /// Builds a training item from its class label and its text.
    pub fn new(class: String, text: String) -> Self {
        TrainItem { class, text }
    }
}
26
27#[derive(Default)]
28pub struct Dataset {
29 items: Vec<CachedTrainItem>,
30}
31impl FromIterator<TrainItem> for Dataset {
32 fn from_iter<T: IntoIterator<Item = TrainItem>>(iter: T) -> Self {
33 Self {
34 items: iter.into_iter().map(|x| x.into()).collect(),
35 }
36 }
37}
/// Pairs a class label with a distance so that candidates can be ordered by distance alone.
///
/// Only the distance (field `1`) takes part in comparisons; the label (field `0`) is ignored.
/// Every comparison goes through [`f64::total_cmp`], which is a total order, so `Eq`/`Ord`
/// hold for every value and heap or sort operations cannot panic, even on a NaN distance.
struct Wraper<'a>(&'a str, f64);

impl PartialEq for Wraper<'_> {
    /// Two wrappers are equal when their distances are equal under the total order.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Wraper<'_> {}

impl PartialOrd for Wraper<'_> {
    /// Always `Some`: delegates to [`Ord::cmp`] so both orderings agree.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Wraper<'_> {
    /// Orders by distance using the IEEE 754 total order.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.1.total_cmp(&other.1)
    }
}
56
/// Errors that classification can report.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ClassificationError {
    /// No training item was available to vote on the class.
    #[error("dataset is empty")]
    DatasetIsEmpty,
}
63
64impl Dataset {
65 pub fn insert(&mut self, value: TrainItem) {
66 self.items.push(value.into());
67 }
68 pub fn classify(
69 &self,
70 text: String,
71 k_nearest_neighbours: usize,
72 ) -> Result<String, ClassificationError> {
73 let cached_text = text.into();
74 Ok(self
75 .items
76 .par_iter()
77 .map(|item| {
78 Reverse(Wraper(
79 item.class(),
80 caching::normalized_compression_distance(item.into(), &cached_text),
81 ))
82 })
83 .collect::<BinaryHeap<_>>()
84 .into_iter_sorted()
85 .take(k_nearest_neighbours)
86 .map(|Reverse(wraper)| (wraper.0, wraper.1))
87 .into_group_map()
88 .into_iter()
89 .map(|(class, texts)| (class, texts.len()))
90 .max_by_key(|x| x.1)
91 .ok_or(ClassificationError::DatasetIsEmpty {})?
92 .0
93 .to_string())
94 }
95 pub fn check_classification(
96 &self,
97 item: TrainItem,
98 k_nearest_neighbours: usize,
99 ) -> Result<bool, ClassificationError> {
100 Ok(item.class == self.classify(item.text, k_nearest_neighbours)?)
101 }
102}
103pub fn normalized_compression_distance(x: &str, y: &str) -> f64 {
104 let c_x = compressed_size(x);
105 let c_y = compressed_size(y);
106 let c_xy = compressed_size((x.to_owned() + y).as_str());
107 (c_xy - min(c_x, c_y)) as f64 / max(c_x, c_y) as f64
108}
109fn compressed_size(text: &str) -> usize {
110 use flate2::{write::DeflateEncoder, Compression};
111 let mut encoder = DeflateEncoder::new(Vec::new(), Compression::fast());
112 encoder.write_all(text.as_bytes()).unwrap();
113 let compressed = encoder.finish().unwrap();
114 compressed.len()
115}