#![feature(binary_heap_into_iter_sorted, type_alias_impl_trait)]
mod caching;
use caching::CachedTrainItem;
use itertools::Itertools;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use std::{
    cmp::{max, min, Reverse},
    collections::{BinaryHeap, HashMap},
    io::Write,
    vec::Vec,
};
use thiserror::Error;
/// A single labelled training example: a class label plus the raw text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrainItem {
    // The label this text belongs to.
    class: String,
    // The raw training text.
    text: String,
}

impl TrainItem {
    /// Creates a training example from its class label and text.
    pub fn new(class: String, text: String) -> Self {
        Self { class, text }
    }
}
/// Training set for the k-nearest-neighbour compression classifier.
///
/// Items are stored as [`CachedTrainItem`]s (from `mod caching`);
/// presumably that type precomputes per-item compression state so queries
/// avoid re-compressing training texts — confirm in `caching.rs`.
#[derive(Default)]
pub struct Dataset {
    items: Vec<CachedTrainItem>,
}
impl FromIterator<TrainItem> for Dataset {
    /// Collects training items into a dataset, converting each item into
    /// its cached representation on the way in.
    fn from_iter<T: IntoIterator<Item = TrainItem>>(iter: T) -> Self {
        let items = iter.into_iter().map(Into::into).collect();
        Self { items }
    }
}
/// Pairs a class label with its distance so heap ordering considers the
/// distance alone. Ordering is total via [`f64::total_cmp`], so a NaN
/// distance can no longer panic the comparison (the previous
/// `partial_cmp().unwrap()` in `Ord::cmp` would).
struct Wraper<'a>(&'a str, f64);

impl PartialEq for Wraper<'_> {
    fn eq(&self, other: &Self) -> bool {
        // Derived from the total order so `eq` and `cmp` cannot disagree.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Wraper<'_> {}

impl PartialOrd for Wraper<'_> {
    // Canonical form: delegate to `Ord` (the old code did the reverse and
    // needed a clippy allow for it).
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Wraper<'_> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `total_cmp` implements IEEE 754 totalOrder: a genuine total order
        // on f64 with no unwrap/panic path, even for NaN.
        self.1.total_cmp(&other.1)
    }
}
/// Errors returned by [`Dataset`] classification.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ClassificationError {
    /// There were no neighbours to vote: the dataset holds no items (or no
    /// votes were tallied).
    #[error("dataset is empty")]
    DatasetIsEmpty,
}
impl Dataset {
pub fn insert(&mut self, value: TrainItem) {
self.items.push(value.into());
}
pub fn classify(
&self,
text: String,
k_nearest_neighbours: usize,
) -> Result<String, ClassificationError> {
let cached_text = text.into();
Ok(self
.items
.par_iter()
.map(|item| {
Reverse(Wraper(
item.class(),
caching::normalized_compression_distance(item.into(), &cached_text),
))
})
.collect::<BinaryHeap<_>>()
.into_iter_sorted()
.take(k_nearest_neighbours)
.map(|Reverse(wraper)| (wraper.0, wraper.1))
.into_group_map()
.into_iter()
.map(|(class, texts)| (class, texts.len()))
.max_by_key(|x| x.1)
.ok_or(ClassificationError::DatasetIsEmpty {})?
.0
.to_string())
}
pub fn check_classification(
&self,
item: TrainItem,
k_nearest_neighbours: usize,
) -> Result<bool, ClassificationError> {
Ok(item.class == self.classify(item.text, k_nearest_neighbours)?)
}
}
/// Normalized compression distance (NCD) between `x` and `y`:
/// `(C(xy) - min(C(x), C(y))) / max(C(x), C(y))`, where `C` is the
/// deflate-compressed size. Smaller values mean "more similar".
///
/// The subtraction saturates at zero: a real compressor can occasionally
/// produce `C(xy) < min(C(x), C(y))`, and the previous unchecked `usize`
/// subtraction would wrap around to an enormous distance in that case.
pub fn normalized_compression_distance(x: &str, y: &str) -> f64 {
    let c_x = compressed_size(x);
    let c_y = compressed_size(y);
    let c_xy = compressed_size((x.to_owned() + y).as_str());
    c_xy.saturating_sub(min(c_x, c_y)) as f64 / max(c_x, c_y) as f64
}
/// Deflate-compressed size of `text` in bytes — the `C(..)` term of the
/// normalized compression distance.
fn compressed_size(text: &str) -> usize {
    use flate2::{write::DeflateEncoder, Compression};
    let mut encoder = DeflateEncoder::new(Vec::new(), Compression::fast());
    // The sink is an in-memory Vec, so these writes cannot fail at the I/O
    // level; state the invariant instead of a bare unwrap.
    encoder
        .write_all(text.as_bytes())
        .expect("writing to an in-memory buffer cannot fail");
    let compressed = encoder
        .finish()
        .expect("finishing deflate into an in-memory buffer cannot fail");
    compressed.len()
}