use crate::ngram::Ngrams;
use glob::{glob, Paths};
use serde::{Deserialize, Serialize};
use std::{
fmt::Display,
fs::File,
io::{BufReader, Error, ErrorKind, Read, Write},
};
/// Threshold assigned to a freshly created or deserialized [`Categories`].
///
/// It is used as a relative margin: the acceptance cutoff for candidates is
/// `(1.0 + threshold) * best_distance`.
const DEFAULT_THRESHOLD: f32 = 0.03;

/// Hands out [`DEFAULT_THRESHOLD`]; referenced by name from the
/// `#[serde(default = "...")]` attribute on the `threshold` field.
fn default_threshold() -> f32 {
    DEFAULT_THRESHOLD
}

/// Result alias for the I/O-flavoured operations of this module; the error
/// type is always [`std::io::Error`].
pub type IoResult<T> = std::result::Result<T, Error>;
/// A single labelled entry: a caller-chosen name paired with the n-gram
/// profile that represents it.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound = "T: Serialize, for<'a> T: Deserialize<'a>")]
struct Category<T>
where
    T: PartialEq + Serialize + for<'a> Deserialize<'a> + Clone,
{
    /// Label reported back to the caller when a sample matches this entry.
    name: T,
    /// N-gram profile that incoming samples are measured against.
    ngrams: Ngrams,
}
impl<T> From<(T, Vec<&str>)> for Category<T>
where
for<'a> T: PartialEq<T> + Serialize + Deserialize<'a> + Clone,
{
fn from(value: (T, Vec<&str>)) -> Category<T> {
Self {
name: value.0,
ngrams: value.1.into(),
}
}
}
impl<T> Category<T>
where
    for<'a> T: PartialEq<T> + Serialize + Deserialize<'a> + Clone,
{
    /// Measures `ngrams` against this category's stored profile by delegating
    /// to `Ngrams::distance`.
    ///
    /// Callers in this file treat a smaller value as a closer match.
    pub fn distance(&self, ngrams: &Ngrams) -> u64 {
        let profile = &self.ngrams;
        profile.distance(ngrams)
    }

    /// Exposes the stored profile as string slices, exactly as produced by
    /// `Ngrams::to_vec`.
    pub fn to_vec(&self) -> Vec<&str> {
        let profile = &self.ngrams;
        profile.to_vec()
    }
}
/// A set of labelled n-gram profiles that samples can be classified against.
///
/// The `version` and `categories` fields are (de)serialized; `threshold` is
/// runtime-only state.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound = "T: Serialize, for<'a> T: Deserialize<'a>")]
pub struct Categories<T>
where
    T: PartialEq + Serialize + for<'a> Deserialize<'a> + Clone,
{
    /// Crate version recorded when the collection was created.
    version: String,
    /// Known categories, in insertion order.
    categories: Vec<Category<T>>,
    /// Relative margin used when selecting candidates. Never written to or
    /// read from serialized data; deserialized values start from
    /// `default_threshold()`.
    #[serde(skip, default = "default_threshold")]
    threshold: f32,
}
impl<T> From<Vec<(T, Vec<&str>)>> for Categories<T>
where
for<'a> T: PartialEq<T> + Serialize + Deserialize<'a> + Clone,
{
fn from(categories: Vec<(T, Vec<&str>)>) -> Self {
let mut new = Self::new();
new.categories = categories
.iter()
.map(|m| m.to_owned().into())
.collect::<Vec<Category<T>>>();
new
}
}
// `new()` takes no arguments, so clippy would ask for a `Default` impl; the
// lint is silenced here instead of adding one.
#[allow(clippy::new_without_default)]
impl<T> Categories<T>
where
    for<'a> T: PartialEq<T> + Serialize + Deserialize<'a> + Clone,
{
    /// Creates an empty collection stamped with the crate version known at
    /// compile time and using `DEFAULT_THRESHOLD`.
    pub fn new() -> Categories<T> {
        Categories {
            categories: Vec::new(),
            version: env!("CARGO_PKG_VERSION").to_string(),
            threshold: DEFAULT_THRESHOLD,
        }
    }

    /// Returns every category as a `(name, n-grams)` pair, in insertion order.
    ///
    /// Names are cloned; the n-gram slices borrow from `self`.
    pub fn to_vec(&self) -> Vec<(T, Vec<&str>)> {
        self.categories
            .iter()
            .map(|category| (category.name.clone(), category.to_vec()))
            .collect()
    }

    /// Sets the relative margin used by [`Categories::get_categories`].
    ///
    /// # Errors
    ///
    /// Returns an error message, leaving the current threshold untouched,
    /// unless `threshold` lies strictly between 0 and 1. NaN is rejected too.
    pub fn set_threshold(&mut self, threshold: f32) -> Result<(), &'static str> {
        // Expressed as a positive range test: every comparison with NaN is
        // false, so NaN falls through to the error branch as well.
        let in_range = threshold > 0.0 && threshold < 1.0;
        if !in_range {
            return Err("The value has to between 0 and 1");
        }
        self.threshold = threshold;
        Ok(())
    }

    /// Classifies `sample` and returns the matching category name, but only
    /// when the match is unambiguous.
    ///
    /// Returns `None` when there are no categories, or when
    /// [`Categories::get_categories`] yields anything other than exactly one
    /// candidate.
    pub fn get_category(&self, sample: &str) -> Option<T> {
        let mut candidates = self.get_categories(sample)?.into_iter();
        match (candidates.next(), candidates.next()) {
            (Some((name, _)), None) => Some(name),
            _ => None,
        }
    }

    /// Ranks the categories against `sample` and returns the plausible ones
    /// as `(name, distance)` pairs, closest first.
    ///
    /// A category is kept when its distance is strictly below
    /// `(1 + threshold) * best_distance` (truncated to an integer), or when it
    /// ties with the best distance. Returns `None` when the collection holds
    /// no categories at all.
    pub fn get_categories(&self, sample: &str) -> Option<Vec<(T, u64)>> {
        // Same n-gram size argument as the one used in `add_category`.
        let ngrams = Ngrams::new(sample, 5);
        let mut scored: Vec<(u64, &Category<T>)> = self
            .categories
            .iter()
            .map(|category| (category.distance(&ngrams), category))
            .collect();
        // Stable sort: categories with equal distance keep insertion order.
        scored.sort_by_key(|&(distance, _)| distance);
        let best = scored.first()?.0;
        let cutoff = ((1.0 + self.threshold) * best as f32) as u64;
        Some(
            scored
                .iter()
                // For small distances (including 0) the truncated cutoff can
                // equal `best`, which would otherwise discard the best match
                // itself; ties with the best are therefore always accepted.
                .filter(|entry| entry.0 == best || entry.0 < cutoff)
                .map(|&(distance, category)| (category.name.clone(), distance))
                .collect(),
        )
    }

    /// Serializes the collection as JSON and writes it to the file at
    /// `output`, creating or truncating that file.
    ///
    /// The threshold is not part of the serialized form.
    ///
    /// # Errors
    ///
    /// Propagates serialization failures (converted to `std::io::Error`) and
    /// any error from creating or writing the file.
    pub fn persist(&self, output: &str) -> IoResult<()> {
        let json = serde_json::to_string(self)?;
        File::create(output)?.write_all(json.as_bytes())?;
        Ok(())
    }

    /// Appends a category called `name` whose profile is built from `sample`.
    ///
    /// No check for an existing category with the same name is performed.
    pub fn add_category(&mut self, name: T, sample: &str) {
        self.categories.push(Category {
            name,
            ngrams: Ngrams::new(sample, 5),
        });
    }

    /// Returns a clone of every category name, in insertion order.
    pub fn categories(&self) -> Vec<T> {
        self.categories
            .iter()
            .map(|category| category.name.clone())
            .collect()
    }
}
pub fn load<T>(path: &str) -> IoResult<Categories<T>>
where
for<'a> T: PartialEq<T> + Serialize + Deserialize<'a> + Clone,
{
let file = File::open(path)?;
let reader = BufReader::new(file);
let u = serde_json::from_reader(reader)?;
Ok(u)
}
/// Builds a collection from every `*.sample` file directly inside `path`.
///
/// Each file becomes one category, named after the file stem and trained on
/// the file's content (decoded as UTF-8, lossily). A file whose stem is not
/// valid UTF-8 is still read but contributes no category.
///
/// # Errors
///
/// Fails when the glob pattern is rejected, when a directory entry cannot be
/// read, or when a sample file cannot be opened or read.
pub fn learn_from_directory(path: &str) -> IoResult<Categories<String>> {
    let entries = get_files_from_directory(path)?;
    let mut learned = Categories::new();
    for entry in entries {
        let sample_path = entry.map_err(|_| {
            Error::new(ErrorKind::InvalidData, "failed reading glob path")
        })?;
        let mut raw = Vec::new();
        File::open(&sample_path)?.read_to_end(&mut raw)?;
        let stem = sample_path.file_stem().and_then(|stem| stem.to_str());
        if let Some(name) = stem {
            let text = String::from_utf8_lossy(&raw);
            learned.add_category(name.to_string(), &text);
        }
    }
    Ok(learned)
}
/// Returns the glob iterator over `<path>/*.sample`.
///
/// # Errors
///
/// A rejected glob pattern is reported as `ErrorKind::InvalidData`; the
/// underlying pattern error is discarded.
fn get_files_from_directory(path: &str) -> IoResult<Paths> {
    let pattern = format!("{}/*.sample", path);
    match glob(&pattern) {
        Ok(paths) => Ok(paths),
        Err(_) => Err(Error::new(ErrorKind::InvalidData, "invalid data")),
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// The `tests` directory is expected to hold exactly the English and
    /// Spanish sample files, listed in that order.
    #[test]
    fn test_files_listing_in_path() {
        let paths = get_files_from_directory("tests").expect("Some went wrong");
        let mut r: Vec<String> = Vec::new();
        for entry in paths {
            let path = entry.expect("read name");
            r.push(path.to_str().expect("to string").to_string());
        }
        assert_eq!(vec!["tests/english.sample", "tests/spanish.sample",], r);
    }

    /// Learning from the bundled samples must succeed without an I/O error.
    #[test]
    fn test_learn_from_directory() {
        let _learned =
            learn_from_directory("tests").expect("failed to read file");
    }
}