use std::collections::HashMap;
use std::fs::File;
use std::io::{Read, Write};
use std::path::Path;
use serde::{Deserialize, Serialize};
use crate::errors::{ModelLoadingError, ModelSavingError};
use crate::tokenizer::{tokenize, Token};
use crate::utils::{join_with_spaces, uppercase_first_letter};
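/// Maps a normalized word or space-joined n-gram to the casing it should be
/// given in truecased output.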
pub(crate) type CaseMap = HashMap<String, String>;
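/// A truecasing model: learned casings for individual words (unigrams) and for
/// two- and three-word sequences (bigrams and trigrams).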
#[derive(Serialize, Deserialize, Debug)]
pub struct Model {
    /// Casing data for single words.
    pub(crate) unigrams: CaseMap,
    /// Casing data for pairs of adjacent words.
    pub(crate) bigrams: CaseMap,
    /// Casing data for triples of adjacent words.
    pub(crate) trigrams: CaseMap,
}
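/// How the first meaningful word of the input is treated: `Sentence` gives it
/// sentence-initial capitalization when the model does not know it, while
/// `Phrase` treats it like any other word.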
#[derive(Copy, Clone, Debug)]
enum Mode {
Sentence,
Phrase,
}
impl Model {
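    /// Deserialize a model from its JSON representation (see `serialize`).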
pub fn deserialize(bytes: &[u8]) -> Result<Self, ModelLoadingError> {
Ok(serde_json::from_slice(bytes)?)
}
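    /// Serialize the model to JSON bytes.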
pub fn serialize(&self) -> Result<Vec<u8>, ModelSavingError> {
Ok(serde_json::to_vec(self)?)
}
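    /// Load a serialized model from a file on disk.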
pub fn load_from_file(path: impl AsRef<Path>) -> Result<Self, ModelLoadingError> {
let mut vec = Vec::new();
File::open(path)?.read_to_end(&mut vec)?;
Model::deserialize(&vec)
}
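    /// Save the model to a file on disk so it can later be restored with
    /// `load_from_file`.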
pub fn save_to_file(&self, path: impl AsRef<Path>) -> Result<(), ModelSavingError> {
        let serialized = self.serialize()?;
File::create(path)?.write_all(&serialized)?;
Ok(())
}
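    /// Restore word case in a full sentence. Each meaningful word is looked up
    /// in the model; the first one also gets sentence-initial capitalization
    /// when the model has no entry for it.
    ///
    /// ```ignore
    /// // Minimal usage sketch. The crate name `truecase` and the model path
    /// // "model.json" are assumptions, not part of this file.
    /// use truecase::Model;
    ///
    /// let model = Model::load_from_file("model.json").expect("failed to load model");
    /// println!("{}", model.truecase("i saw mr. smith in new york"));
    /// ```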
pub fn truecase(&self, sentence: &str) -> String {
self.truecase_tokens(tokenize(sentence), Mode::Sentence)
}
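    /// Restore word case in a phrase that is not a full sentence: unlike
    /// `truecase`, the first word is not given any special capitalization.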
pub fn truecase_phrase(&self, phrase: &str) -> String {
self.truecase_tokens(tokenize(phrase), Mode::Phrase)
}
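    // Shared implementation behind `truecase` and `truecase_phrase`: a first
    // pass assigns case from the unigram map (non-meaningful tokens such as
    // whitespace and punctuation are copied through unchanged), then known
    // bigrams and trigrams overwrite those per-word guesses.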
fn truecase_tokens<'a, I>(&self, tokens: I, mode: Mode) -> String
where
I: Iterator<Item = Token<'a>>,
{
let mut truecase_tokens = Vec::with_capacity(3);
let mut normalized_words_with_indexes = Vec::with_capacity(3);
let mut special_case_first_word = match mode {
Mode::Sentence => true,
Mode::Phrase => false,
};
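        // First pass: per-token unigram lookup. `special_case_first_word` is
        // only true for the first meaningful token, and only in Sentence mode.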
for (index, token) in tokens.enumerate() {
let truecased;
if token.is_meaningful() {
let from_unigrams = self
.unigrams
.get(token.normalized.as_ref())
.map(|string| string.as_str());
if special_case_first_word {
special_case_first_word = false;
truecased = match from_unigrams {
Some(s) => s.to_owned(),
None => uppercase_first_letter(&token.normalized),
};
} else {
truecased = match from_unigrams {
Some(s) => s.to_owned(),
None => token.normalized.as_ref().to_owned(),
};
normalized_words_with_indexes.push((index, token.normalized));
}
} else {
truecased = token.original.to_owned();
}
truecase_tokens.push(truecased);
}
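        // Second pass: wherever a known bigram or trigram of meaningful words
        // occurs, its stored casing overrides the per-word results above.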
let update_true_cases_from_ngrams = |size, source: &CaseMap, result: &mut Vec<String>| {
for slice in normalized_words_with_indexes.windows(size) {
let indexes = slice.iter().map(|x| x.0);
let ngram = join_with_spaces(slice.iter().map(get_second_item));
if let Some(truecased_ngram) = source.get(&ngram) {
for (word, index) in truecased_ngram.split(' ').zip(indexes) {
result[index] = word.to_owned();
}
}
}
};
update_true_cases_from_ngrams(2, &self.bigrams, &mut truecase_tokens);
update_true_cases_from_ngrams(3, &self.trigrams, &mut truecase_tokens);
truecase_tokens.join("")
}
}
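/// Pulls the word out of an `(index, normalized_word)` pair when building
/// n-gram lookup keys.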
fn get_second_item<A, B>(tuple: &(A, B)) -> &B {
&tuple.1
}