use std::collections::HashMap;
use std::fs::File;
use std::io;
use serde::{Deserialize, Serialize};
use serde_json::{from_reader, to_writer, to_writer_pretty};
use unicode_segmentation::UnicodeSegmentation;
/// Model file opened by the free functions `score` and `identify`.
const DEFAULT_FILE_PATH: &str = "model.json";
/// Rating assigned to a word that gives no usable spam/ham evidence.
const INITIAL_RATING: f32 = 0.5;
/// A message whose score is strictly above this value is classified as spam.
const SPAM_PROB_THRESHOLD: f32 = 0.8;
/// Occurrence counts of a single token, split by training label.
#[derive(Default, Debug, Deserialize, Serialize)]
struct Counter {
    /// Number of times the token appeared in text passed to `train_ham`.
    ham: u32,
    /// Number of times the token appeared in text passed to `train_spam`.
    spam: u32,
}
/// Spam classifier that rates messages from per-token spam/ham counts.
///
/// Only `token_table` is serialized. Deserialization goes through
/// `ClassifierSerialized`, and the two totals are rebuilt from the table.
#[derive(Debug, Default, Serialize, Deserialize)]
#[serde(from = "ClassifierSerialized")]
pub struct Classifier {
    /// Spam/ham counts for every token seen during training, keyed by token text.
    token_table: HashMap<String, Counter>,
    /// Total number of spam token occurrences seen in training (not serialized).
    #[serde(skip)]
    spam_total_count: u32,
    /// Total number of ham token occurrences seen in training (not serialized).
    #[serde(skip)]
    ham_total_count: u32,
}
/// Persisted form of a `Classifier`: just its token table, without the derived totals.
#[derive(Serialize, Deserialize)]
struct ClassifierSerialized {
    /// Per-token spam/ham counts, keyed by token text.
    token_table: HashMap<String, Counter>,
}
impl From<ClassifierSerialized> for Classifier {
    /// Rebuilds an in-memory `Classifier` from its persisted token table.
    ///
    /// The spam and ham totals are not stored on disk. They are recomputed here as the
    /// sums of the per-token counts. The table comes from an external file and may hold
    /// arbitrarily large counts, so the sums saturate at `u32::MAX` instead of overflowing.
    /// Plain `sum()` would panic in debug builds and silently wrap in release builds.
    fn from(c: ClassifierSerialized) -> Self {
        let (spam_total_count, ham_total_count) =
            c.token_table
                .values()
                .fold((0u32, 0u32), |(spam_acc, ham_acc), counter| {
                    (
                        spam_acc.saturating_add(counter.spam),
                        ham_acc.saturating_add(counter.ham),
                    )
                });
        Self {
            token_table: c.token_table,
            spam_total_count,
            ham_total_count,
        }
    }
}
impl Classifier {
    /// Creates an empty, untrained classifier.
    pub fn new() -> Self {
        Default::default()
    }

    /// Loads a classifier previously written by [`Classifier::save`].
    ///
    /// `serde_json` does not buffer its input, so the file is wrapped in an
    /// [`io::BufReader`]. Without it, parsing would issue a `read` syscall per tiny chunk.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read or its contents are not a valid
    /// serialized classifier. JSON errors are converted into [`io::Error`].
    pub fn new_from_pre_trained(file: &mut File) -> Result<Self, io::Error> {
        let pre_trained_model: Self = from_reader(io::BufReader::new(file))?;
        Ok(pre_trained_model)
    }

    /// Writes the classifier to `file` as JSON, pretty-printed when `pretty` is true.
    ///
    /// Only the token table is written; the totals are recomputed on load. Output goes
    /// through an [`io::BufWriter`], which is flushed explicitly so a failed final write
    /// is reported rather than silently dropped with the buffer.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization or any write to `file` fails.
    pub fn save(&self, file: &mut File, pretty: bool) -> Result<(), io::Error> {
        let mut writer = io::BufWriter::new(file);
        if pretty {
            to_writer_pretty(&mut writer, self)?;
        } else {
            to_writer(&mut writer, self)?;
        }
        io::Write::flush(&mut writer)
    }

    /// Splits `msg` into its Unicode words as owned strings.
    ///
    /// Word boundaries come from `unicode_segmentation`. Case is preserved, so
    /// "Special" and "special" are distinct tokens.
    fn load_word_list(msg: &str) -> Vec<String> {
        msg.unicode_words().map(str::to_owned).collect()
    }

    /// Records every word of `msg` as one spam occurrence.
    ///
    /// Counts saturate at `u32::MAX` rather than overflowing.
    pub fn train_spam(&mut self, msg: &str) {
        for word in Self::load_word_list(msg) {
            let counter = self.token_table.entry(word).or_default();
            counter.spam = counter.spam.saturating_add(1);
            self.spam_total_count = self.spam_total_count.saturating_add(1);
        }
    }

    /// Records every word of `msg` as one ham occurrence.
    ///
    /// Counts saturate at `u32::MAX` rather than overflowing.
    pub fn train_ham(&mut self, msg: &str) {
        for word in Self::load_word_list(msg) {
            let counter = self.token_table.entry(word).or_default();
            counter.ham = counter.ham.saturating_add(1);
            self.ham_total_count = self.ham_total_count.saturating_add(1);
        }
    }

    /// Total number of spam token occurrences seen in training.
    fn spam_total_count(&self) -> u32 {
        self.spam_total_count
    }

    /// Total number of ham token occurrences seen in training.
    fn ham_total_count(&self) -> u32 {
        self.ham_total_count
    }

    /// Rates each word of `msg` (in message order) with `rate_word`.
    fn rate_words(&self, msg: &str) -> Vec<f32> {
        Self::load_word_list(msg)
            .iter()
            .map(|word| self.rate_word(word))
            .collect()
    }

    /// Returns the spam rating of a single word, in `[0.01, 0.99]`.
    ///
    /// The rating is chosen as follows:
    /// - Unknown words, or words with zero spam and zero ham counts, get `INITIAL_RATING`.
    ///   A zero/zero entry can only come from a loaded model, and previously evaluated
    ///   0/0 = NaN, which `max` silently turned into 0.01.
    /// - Words seen only in spam get 0.99.
    /// - Words seen only in ham get 0.01.
    /// - Words seen in both get `P(word|spam) / (P(word|ham) + P(word|spam))`, where each
    ///   conditional frequency is the word's count divided by that label's total. The
    ///   result is clamped to the same `[0.01, 0.99]` range as the exclusive cases.
    fn rate_word(&self, word: &str) -> f32 {
        let counter = match self.token_table.get(word) {
            Some(counter) => counter,
            None => return INITIAL_RATING,
        };
        match (counter.spam, counter.ham) {
            (0, 0) => INITIAL_RATING,
            (_, 0) => 0.99,
            (0, _) => 0.01,
            (spam, ham) => {
                let spam_total = self.spam_total_count();
                let ham_total = self.ham_total_count();
                // Defensive guard against division by zero. The totals are at least the
                // per-word counts, so this should not trigger for a consistent model.
                if spam_total == 0 || ham_total == 0 {
                    return INITIAL_RATING;
                }
                let ham_prob = ham as f32 / ham_total as f32;
                let spam_prob = spam as f32 / spam_total as f32;
                (spam_prob / (ham_prob + spam_prob)).max(0.01).min(0.99)
            }
        }
    }

    /// Returns a spam score between 0 and 1 for `msg`.
    ///
    /// Word ratings are combined as `Πp / (Πp + Π(1 − p))`. When the message has more
    /// than 20 words, only the 10 lowest and the 10 highest ratings are combined. A
    /// message with no words scores `0.0`.
    pub fn score(&self, msg: &str) -> f32 {
        // Ratings kept from each end of the sorted list for long messages.
        const KEPT_PER_SIDE: usize = 10;

        let mut ratings = self.rate_words(msg);
        if ratings.is_empty() {
            return 0.0;
        }
        if ratings.len() > 2 * KEPT_PER_SIDE {
            ratings.sort_unstable_by(|a, b| {
                a.partial_cmp(b)
                    .expect("word ratings are always finite, never NaN")
            });
            let len = ratings.len();
            // Remove the middle in place, leaving the lowest and highest ratings in order.
            ratings.drain(KEPT_PER_SIDE..len - KEPT_PER_SIDE);
        }
        let product: f32 = ratings.iter().product();
        let alt_product: f32 = ratings.iter().map(|p| 1.0 - p).product();
        product / (product + alt_product)
    }

    /// Returns `true` when `msg` scores strictly above `SPAM_PROB_THRESHOLD`.
    pub fn identify(&self, msg: &str) -> bool {
        self.score(msg) > SPAM_PROB_THRESHOLD
    }
}
/// Scores `msg` with the pre-trained model stored at `DEFAULT_FILE_PATH`.
///
/// The model file is opened and parsed on every call.
///
/// # Errors
///
/// Returns an error if the model file cannot be opened or loaded.
pub fn score(msg: &str) -> Result<f32, io::Error> {
    let classifier = {
        let mut model_file = File::open(DEFAULT_FILE_PATH)?;
        Classifier::new_from_pre_trained(&mut model_file)?
    };
    Ok(classifier.score(msg))
}
/// Classifies `msg` with the pre-trained model at `DEFAULT_FILE_PATH`.
///
/// Returns `true` when the free function `score` is strictly above `SPAM_PROB_THRESHOLD`.
///
/// # Errors
///
/// Propagates any error from loading the model.
pub fn identify(msg: &str) -> Result<bool, io::Error> {
    score(msg).map(|spam_score| spam_score > SPAM_PROB_THRESHOLD)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a fresh classifier trained on exactly one spam and then one ham message.
    fn classifier_trained_on(spam_example: &str, ham_example: &str) -> Classifier {
        let mut trained = Classifier::new();
        trained.train_spam(spam_example);
        trained.train_ham(ham_example);
        trained
    }

    #[test]
    fn test_new() {
        let trained = classifier_trained_on(
            "Don't forget our special promotion: -30% on men shoes, only today!",
            "Hi Bob, don't forget our meeting today at 4pm.",
        );
        let spam_candidate = "Lose up to 19% weight. Special promotion on our new weightloss.";
        assert!(trained.identify(spam_candidate));
        let ham_candidate = "Hi Bob, can you send me your machine learning homework?";
        assert!(!trained.identify(ham_candidate));
    }

    #[test]
    fn test_new_unicode() {
        let trained = classifier_trained_on(
            "Bon plan pour Nöel: profitez de -50% sur le 2ème article.",
            "Vous êtes tous cordialement invités à notre repas de Noël.",
        );
        let spam_candidate = "Préparez les fêtes de Nöel: 1 article offert!";
        assert!(trained.identify(spam_candidate));
        let ham_candidate = "Pourras-tu être des nôtres pour le repas de Noël?";
        assert!(!trained.identify(ham_candidate));
    }

    // Depends on a trained model file at DEFAULT_FILE_PATH in the working directory.
    #[test]
    fn test_new_from_pre_trained() -> Result<(), io::Error> {
        let spam_candidate = "Lose up to 19% weight. Special promotion on our new weightloss.";
        assert!(identify(spam_candidate)?);
        let ham_candidate = "Hi Bob, can you send me your machine learning homework?";
        assert!(!identify(ham_candidate)?);
        Ok(())
    }
}