#![warn(missing_docs)]
#![warn(clippy::missing_docs_in_private_items)]
use flate2::write::{DeflateEncoder, GzEncoder, ZlibEncoder};
use flate2::Compression;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::cmp::{max, min};
use std::collections::HashMap;
use std::io::Write;
use std::string::String;
use zstd::bulk::compress;
/// A labelled training sample, optionally carrying its already-known compressed size.
#[derive(Serialize, Deserialize, Debug)]
pub struct TrainingData<'a> {
/// Class label attached to this sample.
pub label: &'a str,
/// Text of the sample.
pub content: &'a str,
/// Compressed size of `content` in bytes, if already known. `ncd` uses this value as-is
/// when it is `Some` and compresses `content` itself when it is `None`.
pub compressed_length: Option<usize>,
}
/// Distance between a query and one training sample, as produced by `ncd`.
#[derive(Serialize, Deserialize, Debug)]
pub struct NCD<'a> {
/// Label of the training sample the query was compared against.
pub label: &'a str,
/// Normalized compression distance for that sample; `classify` sorts ascending on this
/// value, so smaller means nearer.
pub ncd: f64,
}
/// Compressor used by `compressed_length` to measure a text.
#[derive(Serialize, Deserialize, Debug)]
pub enum CompressionAlgorithm {
/// Zstandard, via `zstd::bulk::compress`.
Zstd,
/// Gzip, via `flate2::write::GzEncoder`.
Gzip,
/// Zlib, via `flate2::write::ZlibEncoder`.
Zlib,
/// Raw deflate, via `flate2::write::DeflateEncoder`.
Deflate,
}
/// Returns the size in bytes of `training` after compressing it with `algorithm`.
///
/// `level` is passed to zstd unchanged. The flate2-based algorithms (gzip, zlib and
/// deflate) document their level on a 0-9 scale, so for those three the level is clamped
/// into that range: negative levels become 0 and anything above 9 becomes 9.
///
/// # Panics
///
/// Panics if the selected compressor reports an error.
pub fn compressed_length(training: &str, level: i32, algorithm: &CompressionAlgorithm) -> usize {
    let bytes = training.as_bytes();
    // A plain `level as u32` would wrap a negative level into a huge positive one and
    // would forward zstd-only levels (10 and up) to flate2 unchanged. Clamp first;
    // `unsigned_abs` is then a lossless i32 -> u32 conversion.
    let flate_level = Compression::new(level.clamp(0, 9).unsigned_abs());
    let compressed = match *algorithm {
        CompressionAlgorithm::Zstd => compress(bytes, level).expect("zstd compression failed"),
        CompressionAlgorithm::Gzip => {
            let mut encoder = GzEncoder::new(Vec::new(), flate_level);
            encoder
                .write_all(bytes)
                .expect("gzip encoder failed to write to an in-memory buffer");
            encoder
                .finish()
                .expect("gzip encoder failed to finish the stream")
        }
        CompressionAlgorithm::Zlib => {
            let mut encoder = ZlibEncoder::new(Vec::new(), flate_level);
            encoder
                .write_all(bytes)
                .expect("zlib encoder failed to write to an in-memory buffer");
            encoder
                .finish()
                .expect("zlib encoder failed to finish the stream")
        }
        CompressionAlgorithm::Deflate => {
            let mut encoder = DeflateEncoder::new(Vec::new(), flate_level);
            encoder
                .write_all(bytes)
                .expect("deflate encoder failed to write to an in-memory buffer");
            encoder
                .finish()
                .expect("deflate encoder failed to finish the stream")
        }
    };
    compressed.len()
}
/// Computes the normalized compression distance between `query` and every training sample.
///
/// With `C(s)` the compressed length of `s` (see [`compressed_length()`]), the distance
/// for a sample with text `x` is
/// `(C(x + " " + query) - min(C(x), C(query))) / max(C(x), C(query))`.
/// A sample's cached `compressed_length` is used for `C(x)` when present; otherwise the
/// text is compressed here.
///
/// Returns one [`NCD`] per training sample, in the same order as `training_data`.
///
/// # Panics
///
/// Panics if the underlying compressor reports an error.
pub fn ncd<'a>(
    training_data: &'a Vec<TrainingData<'a>>,
    query: &'a str,
    level: i32,
    algorithm: &'a CompressionAlgorithm,
) -> Vec<NCD<'a>> {
    // The query's own compressed length is the same for every sample: compute it once.
    let len_query = compressed_length(query, level, algorithm);
    training_data
        .par_iter()
        .map(|td| {
            let len_training = td
                .compressed_length
                .unwrap_or_else(|| compressed_length(td.content, level, algorithm));
            let len_combo =
                compressed_length(&format!("{} {}", td.content, query), level, algorithm);
            let smaller = min(len_training, len_query);
            let larger = max(len_training, len_query);
            // Subtract in floating point: a cached length that does not match the current
            // level/algorithm can exceed the combined length, and an unsigned subtraction
            // would then underflow instead of yielding a (negative) distance.
            let distance = (len_combo as f64 - smaller as f64) / larger as f64;
            NCD {
                label: td.label,
                ncd: distance,
            }
        })
        .collect()
}
/// Labels each query with a k-nearest-neighbour vote over normalized compression distances.
///
/// `training` and `training_labels` are paired positionally; if their lengths differ,
/// the surplus entries of the longer slice are ignored. Every training text is compressed
/// once up front so its length can be reused for all queries.
///
/// For each query the `k` training samples with the smallest distance vote with their
/// labels. `k` is capped at the number of training samples. When several labels share the
/// highest vote count, the one owning the nearest neighbour wins, so the outcome is
/// deterministic.
///
/// Returns one predicted label per query, in the order of `queries`.
///
/// # Panics
///
/// Panics if `k` is 0 or there are no training samples while `queries` is non-empty, or if
/// the underlying compressor reports an error.
pub fn classify(
    training: &[String],
    training_labels: &[String],
    queries: &[String],
    level: i32,
    algorithm: CompressionAlgorithm,
    k: usize,
) -> Vec<String> {
    let training_data = training
        .par_iter()
        .zip(training_labels.par_iter())
        .map(|(content, label)| TrainingData {
            label,
            content,
            compressed_length: Some(compressed_length(content, level, &algorithm)),
        })
        .collect::<Vec<TrainingData>>();
    queries
        .par_iter()
        .map(|query| {
            let mut ncds = ncd(&training_data, query, level, &algorithm);
            // Stable sort: samples at equal distance keep their training order.
            ncds.sort_by(|a, b| a.ncd.total_cmp(&b.ncd));
            // Never ask for more neighbours than there are training samples.
            let neighbours = &ncds[..k.min(ncds.len())];
            let mut votes: HashMap<&str, usize> = HashMap::new();
            for neighbour in neighbours {
                *votes.entry(neighbour.label).or_default() += 1;
            }
            let top = votes.values().copied().max();
            // Walk the neighbours nearest-first so that a tie in vote count goes to the
            // label of the closest sample rather than to hash-map iteration order.
            neighbours
                .iter()
                .map(|neighbour| neighbour.label)
                .find(|label| votes.get(label).copied() == top)
                .expect("classify needs k >= 1 and at least one training sample")
                .to_string()
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use csv::Reader;
    use std::fs::File;

    /// Two hand-written samples: each query must receive the label of the training
    /// sentence it resembles.
    #[test]
    fn test_classification() {
        let training = [
            String::from("some normal sentence"),
            String::from("godzilla ate mars in June"),
        ];
        let training_labels = [String::from("a"), String::from("b")];
        let queries = [
            String::from("another normal sentence"),
            String::from("godzilla eats marshes in August"),
        ];
        let predicted = classify(
            &training,
            &training_labels,
            &queries,
            3,
            CompressionAlgorithm::Gzip,
            1,
        );
        let expected = vec![String::from("a"), String::from("b")];
        assert_eq!(predicted, expected);
    }

    /// Regression check on `./data/imdb.csv`: column 0 is the text, column 1 the label.
    /// Rows 0..5000 are the training set, rows 5000..6000 are classified, and the number
    /// of correct predictions is pinned.
    #[test]
    fn csv_classifications() {
        let imdb = File::open("./data/imdb.csv").unwrap();
        let mut reader = Reader::from_reader(imdb);
        let mut content = Vec::with_capacity(50000);
        let mut label = Vec::with_capacity(50000);
        for record in reader.records() {
            let row = record.unwrap();
            content.push(row[0].to_string());
            label.push(row[1].to_string());
        }
        let (train_text, test_text) = (&content[..5000], &content[5000..6000]);
        let (train_labels, test_labels) = (&label[..5000], &label[5000..6000]);
        let predictions = classify(
            train_text,
            train_labels,
            test_text,
            3,
            CompressionAlgorithm::Zstd,
            1,
        );
        let correct = predictions
            .iter()
            .zip(test_labels)
            .filter(|(predicted, actual)| predicted == actual)
            .count();
        assert_eq!(correct, 685);
    }
}