use std::collections::{HashMap, HashSet};
use std::f64;
use regex::Regex;
use rustc_serialize::json;
/// Additive smoothing term that a freshly constructed `Classifier` starts
/// with (a value of one corresponds to add-one / Laplace smoothing).
static DEFAULT_SMOOTHING: f64 = 1.0;
/// Naive Bayes text classifier.
///
/// Documents are added under a label, `train` turns the accumulated word
/// counts into probabilities, and `classify` returns the label whose
/// log-score for a document is highest. The whole model can be round-tripped
/// through JSON with `to_json` / `from_json`.
#[derive(Debug, Clone, RustcDecodable, RustcEncodable)]
pub struct Classifier {
/// Every distinct word seen in any added document.
vocab: HashSet<String>,
/// Number of non-empty documents added so far, across all labels.
num_examples: u32,
/// Additive smoothing term handed to each label when `train` runs.
smoothing: f64,
/// Per-label statistics, keyed by label.
classifications: HashMap<String, Classification>
}
/// Word statistics and trained probabilities for a single label.
#[derive(Debug, Clone, RustcDecodable, RustcEncodable)]
struct Classification {
/// The label these statistics belong to.
label: String,
/// Number of documents that were added under this label.
num_examples: u32,
/// Total number of word occurrences added under this label (repeats count).
num_words: u32,
/// Share of all examples that carry this label; 0.0 until `train` is called.
probability: f64,
/// Smoothed probability used for vocabulary words this label has never
/// seen; 0.0 until `train` is called.
default_word_probability: f64,
/// Per-word `(occurrence count, smoothed probability)`. The probability is
/// 0.0 for a word until `train` has processed it.
words: HashMap<String, (u32, f64)>,
}
impl Classifier {
    /// Creates an empty classifier that uses `DEFAULT_SMOOTHING`.
    pub fn new() -> Classifier {
        Classifier {
            vocab: HashSet::new(),
            num_examples: 0,
            smoothing: DEFAULT_SMOOTHING,
            classifications: HashMap::new(),
        }
    }

    /// Records an already tokenised `document` as one example of `label`.
    ///
    /// Every token is counted for the label and added to the shared
    /// vocabulary. An empty token list is ignored entirely. `train` must be
    /// called (again) afterwards for the new counts to affect scoring.
    pub fn add_document_tokenized(&mut self, document: &Vec<String>, label: &String) {
        if document.is_empty() {
            return;
        }
        // One lookup that creates the label's statistics on first use.
        let classification = self
            .classifications
            .entry(label.clone())
            .or_insert_with(|| Classification::new(label));
        for word in document {
            classification.add_word(word);
            // Only allocate a new String for words not yet in the vocabulary.
            if !self.vocab.contains(word) {
                self.vocab.insert(word.clone());
            }
        }
        self.num_examples += 1;
        classification.num_examples += 1;
    }

    /// Splits `document` on whitespace with `split_document` and records the
    /// resulting tokens as one example of `label`.
    pub fn add_document(&mut self, document: &String, label: &String) {
        self.add_document_tokenized(&split_document(document), label);
    }

    /// Adds every `(document, label)` pair via `add_document`.
    pub fn add_documents(&mut self, examples: &Vec<(String, String)>) {
        for &(ref document, ref label) in examples.iter() {
            self.add_document(document, label);
        }
    }

    /// Adds every `(tokens, label)` pair via `add_document_tokenized`.
    pub fn add_documents_tokenized(&mut self, examples: &Vec<(Vec<String>, String)>) {
        for &(ref document, ref label) in examples.iter() {
            self.add_document_tokenized(document, label);
        }
    }

    /// Returns a copy of every known label, in unspecified order.
    pub fn get_labels(&self) -> Vec<String> {
        self.classifications
            .values()
            .map(|classification| classification.label.clone())
            .collect()
    }

    /// Sets the additive smoothing term used by the next call to `train`.
    ///
    /// # Panics
    ///
    /// Panics unless `smoothing` is a finite number greater than zero. NaN and
    /// infinity are rejected as well, because either would turn every trained
    /// probability into NaN.
    pub fn set_smoothing(&mut self, smoothing: f64) {
        if !smoothing.is_finite() || smoothing <= 0.0 {
            panic!("smoothing value must be a positive number");
        }
        self.smoothing = smoothing;
    }

    /// Recomputes the prior and the per-word probabilities of every label
    /// from the counts gathered so far.
    pub fn train(&mut self) {
        for (_, classification) in self.classifications.iter_mut() {
            classification.train(&self.vocab, self.num_examples, self.smoothing);
        }
    }

    /// Returns the label with the highest log-score for the tokenised
    /// `document`. Tokens that are not in the vocabulary are ignored.
    ///
    /// # Panics
    ///
    /// Panics with "no classification found" when no label scores above
    /// negative infinity. That happens when no documents have been added, and
    /// also when `train` has not been run since the relevant counts were
    /// added (untrained probabilities are 0.0, whose logarithm is -inf).
    pub fn classify_tokenized(&self, document: &Vec<String>) -> String {
        let mut best_score = f64::NEG_INFINITY;
        let mut best = None;
        for candidate in self.classifications.values() {
            let score = candidate.score_document(document, &self.vocab);
            // Strict comparison: on ties the first label encountered wins.
            if score > best_score {
                best_score = score;
                best = Some(candidate);
            }
        }
        best.expect("no classification found").label.clone()
    }

    /// Splits `document` on whitespace with `split_document` and classifies
    /// the resulting tokens. See `classify_tokenized` for the panic condition.
    pub fn classify(&self, document: &String) -> String {
        self.classify_tokenized(&split_document(document))
    }

    /// Returns `(label, probability)` for every label, in unspecified order.
    ///
    /// The probabilities are the label log-scores normalised with a softmax,
    /// so they are non-negative and sum to one (a lone label gets 1.0). If
    /// every label scores negative infinity (for example because `train` has
    /// not been called) the distribution is undefined and each entry is NaN.
    /// With no labels at all the result is empty.
    pub fn get_document_probabilities_tokenized(&self, document: &Vec<String>) -> Vec<(String, f64)> {
        let scores: Vec<(String, f64)> = self
            .classifications
            .values()
            .map(|c| (c.label.clone(), c.score_document(document, &self.vocab)))
            .collect();
        // Shift by the largest score before exponentiating so that long
        // documents (very negative log-scores) do not underflow to zero.
        let max_score = scores
            .iter()
            .fold(f64::NEG_INFINITY, |acc: f64, &(_, s)| acc.max(s));
        let weights: Vec<(String, f64)> = scores
            .into_iter()
            .map(|(label, s)| (label, (s - max_score).exp()))
            .collect();
        let total: f64 = weights.iter().fold(0.0, |acc, &(_, w)| acc + w);
        weights
            .into_iter()
            .map(|(label, w)| (label, w / total))
            .collect()
    }

    /// Splits `document` on whitespace with `split_document` and returns the
    /// per-label probabilities of the resulting tokens.
    pub fn get_document_probabilities(&self, document: &String) -> Vec<(String, f64)> {
        self.get_document_probabilities_tokenized(&split_document(document))
    }

    /// Serialises the whole classifier, counts and trained probabilities
    /// included, to a JSON string.
    ///
    /// # Panics
    ///
    /// Panics with "encoding JSON failed" if the encoder reports an error.
    pub fn to_json(&self) -> String {
        json::encode(self).ok().expect("encoding JSON failed")
    }

    /// Rebuilds a classifier from JSON produced by `to_json`.
    ///
    /// # Panics
    ///
    /// Panics with "decoding JSON failed" if `encoded` cannot be decoded.
    pub fn from_json(encoded: &str) -> Classifier {
        json::decode(encoded).ok().expect("decoding JSON failed")
    }
}
impl Classification {
    /// Builds the empty statistics record for `label`: no examples, no
    /// words, and all probabilities at zero until `train` runs.
    fn new(label: &String) -> Classification {
        Classification {
            label: label.clone(),
            num_examples: 0,
            num_words: 0,
            probability: 0.0,
            default_word_probability: 0.0,
            words: HashMap::new(),
        }
    }

    /// Counts one occurrence of `word` for this label.
    ///
    /// A word seen for the first time starts with a count of one and a
    /// probability of 0.0; the probability is only filled in by `train`.
    fn add_word(&mut self, word: &String) {
        self.num_words += 1;
        if let Some(entry) = self.words.get_mut(word) {
            entry.0 += 1;
            return;
        }
        self.words.insert(word.clone(), (1, 0.0));
    }

    /// Recomputes this label's prior and smoothed word probabilities.
    ///
    /// * `vocab` - all words known to the classifier; its size enters the
    ///   smoothing denominator.
    /// * `total_examples` - number of examples across every label.
    /// * `smoothing` - additive smoothing term.
    fn train(&mut self, vocab: &HashSet<String>, total_examples: u32, smoothing: f64) {
        // Shared by the default probability and every per-word probability.
        let denominator = self.num_words as f64 + smoothing * vocab.len() as f64;

        self.probability = self.num_examples as f64 / total_examples as f64;
        self.default_word_probability = smoothing / denominator;

        for (word, entry) in self.words.iter_mut() {
            // Only words that are part of the vocabulary are updated.
            if vocab.contains(word) {
                entry.1 = (entry.0 as f64 + smoothing) / denominator;
            }
        }
    }

    /// Returns the log-score of `document` under this label: the log of the
    /// label's prior plus the log-probability of each token.
    ///
    /// Tokens outside `vocab` are skipped. Vocabulary words this label has
    /// never seen contribute `default_word_probability`.
    fn score_document(&self, document: &Vec<String>, vocab: &HashSet<String>) -> f64 {
        let log_likelihood = document
            .iter()
            .filter(|word| vocab.contains(*word))
            .map(|word| {
                self.words
                    .get(word)
                    .map(|entry| entry.1)
                    .unwrap_or(self.default_word_probability)
                    .ln()
            })
            .fold(0.0f64, |acc, log_p| acc + log_p);
        self.probability.ln() + log_likelihood
    }
}
/// Splits `document` into whitespace-delimited tokens.
///
/// A run of consecutive whitespace counts as a single separator, and leading
/// or trailing whitespace is ignored, so no empty tokens are ever produced.
/// An empty or all-whitespace document therefore yields an empty vector,
/// instead of a vector holding `""` that would be counted as a word.
fn split_document(document: &String) -> Vec<String> {
    // The standard-library splitter needs no per-call pattern compilation.
    document
        .split_whitespace()
        .map(|token| token.to_string())
        .collect()
}