//! Word frequency lookup.
//!
//! [`WordFreq`] holds a table of word frequencies built from weighted word lists
//! (see [`word_weights_from_text`]) and answers queries through
//! [`WordFreq::word_frequency`] and [`WordFreq::zipf_frequency`].
#![deny(missing_docs)]
mod chinese;
pub mod language;
mod numbers;
pub mod preprocessers;
mod transliterate;
use std::io::BufRead;
use anyhow::{anyhow, Result};
use hashbrown::HashMap;
pub use preprocessers::Standardizer;
/// Floating-point type used for all weights and frequencies in this crate.
pub type Float = f32;
/// The constant `10` as a [`Float`]: the base used for frequency <-> Zipf conversions and for
/// decimal rounding.
const FLOAT_10: Float = 10.;
/// A word-frequency model backed by a table mapping words to frequencies.
///
/// Build one with [`WordFreq::new`] (from weighted words) or [`WordFreq::deserialize`]
/// (from bytes produced by [`WordFreq::serialize`]), then optionally configure it with
/// [`WordFreq::minimum`] and [`WordFreq::standardizer`].
#[derive(Clone)]
pub struct WordFreq {
    /// Optional normalization applied to every queried word before it is looked up.
    standardizer: Option<Standardizer>,
    /// Maps a word to its "smashed" lookup key, and supplies the scaling factor applied when
    /// smashing changed the word.
    num_handler: numbers::NumberHandler,
    /// Lower bound for returned values. `word_frequency` uses it as a frequency floor;
    /// `zipf_frequency` interprets it on the Zipf scale.
    minimum: Float,
    /// Word -> frequency table (normalized to sum to one when built by `WordFreq::new`).
    map: HashMap<String, Float>,
}
impl WordFreq {
    /// Creates a model from `(word, weight)` pairs.
    ///
    /// The weights are divided by their total, so the stored frequencies sum to one.
    ///
    /// The pairs are collected into a map before normalizing. If a word occurs more than once,
    /// only its last weight is kept.
    ///
    /// If the weights sum to zero (e.g. every weight is zero), they are stored as given instead
    /// of being divided by zero.
    ///
    /// The new model has a minimum of `0` and no standardizer.
    pub fn new<I, W>(word_weights: I) -> Self
    where
        I: IntoIterator<Item = (W, Float)>,
        W: AsRef<str>,
    {
        let mut map: HashMap<_, _> = word_weights
            .into_iter()
            .map(|(word, weight)| (word.as_ref().to_string(), weight))
            .collect();
        let sum_weight: Float = map.values().sum();
        // Dividing by a zero total would turn every entry into NaN (0 / 0).
        if sum_weight != 0. {
            map.values_mut().for_each(|w| *w /= sum_weight);
        }
        Self {
            map,
            minimum: 0.,
            num_handler: numbers::NumberHandler::new(),
            standardizer: None,
        }
    }

    /// Sets the lower bound applied to returned values (the default is `0`).
    ///
    /// [`WordFreq::word_frequency`] uses it directly as a frequency floor.
    /// [`WordFreq::zipf_frequency`] interprets it as a Zipf value.
    ///
    /// # Errors
    ///
    /// Returns an error if `minimum` is negative or NaN.
    pub fn minimum(mut self, minimum: Float) -> Result<Self> {
        // `minimum < 0.` alone is false for NaN, which would let NaN through. A NaN minimum
        // makes `zipf_frequency` return -inf for unknown words.
        if minimum.is_nan() || minimum < 0. {
            return Err(anyhow!("minimum must be non-negative"));
        }
        self.minimum = minimum;
        Ok(self)
    }

    /// Sets the standardizer applied to every queried word before lookup.
    // Clippy suggests `const fn` here. The assignment below drops the previous
    // `Option<Standardizer>`, which presumably cannot happen in a const context, so the
    // suggestion is silenced. TODO(review): confirm against the lint's output.
    #[allow(clippy::missing_const_for_fn)]
    pub fn standardizer(mut self, standardizer: Standardizer) -> Self {
        self.standardizer = Some(standardizer);
        self
    }

    /// Returns the frequency of `word`.
    ///
    /// Words not in the model count as `0`. The result is never below the configured minimum.
    pub fn word_frequency<W>(&self, word: W) -> Float
    where
        W: AsRef<str>,
    {
        self.word_frequency_in(word).unwrap_or(0.).max(self.minimum)
    }

    /// Returns the Zipf frequency of `word`: `log10(frequency) + 9`, rounded to two decimals.
    ///
    /// Here the configured minimum is read as a Zipf value. The frequency is floored at
    /// `10^(minimum - 9)`, so with the default minimum of `0` the result is never below `0`.
    pub fn zipf_frequency<W>(&self, word: W) -> Float
    where
        W: AsRef<str>,
    {
        let freq_min = Self::zipf_to_freq(self.minimum);
        let freq = self.word_frequency_in(word).unwrap_or(0.).max(freq_min);
        let zipf = Self::freq_to_zipf(freq);
        Self::round(zipf, 2)
    }

    /// Looks up the raw frequency of `word`, without applying any minimum.
    ///
    /// The word is first passed through the standardizer (if any), then through the number
    /// handler's `smash_numbers`. If smashing changed the word, the stored frequency is
    /// multiplied by `digit_freq` of the standardized word.
    ///
    /// Returns `None` when the lookup key is absent from the table.
    fn word_frequency_in<W>(&self, word: W) -> Option<Float>
    where
        W: AsRef<str>,
    {
        let word = self.standardizer.as_ref().map_or_else(
            || word.as_ref().to_string(),
            |standardizer| standardizer.apply(word.as_ref()),
        );
        let smashed = self.num_handler.smash_numbers(&word);
        let mut freq = self.map.get(&smashed).copied()?;
        if smashed != word {
            freq *= self.num_handler.digit_freq(&word);
        }
        Some(freq)
    }

    /// Converts a Zipf value to a frequency: `10^(zipf - 9)`.
    fn zipf_to_freq(zipf: Float) -> Float {
        FLOAT_10.powf(zipf - 9.)
    }

    /// Converts a frequency to a Zipf value: `log10(freq) + 9`.
    fn freq_to_zipf(freq: Float) -> Float {
        freq.log10() + 9.
    }

    /// Rounds `x` to `places` decimal places.
    fn round(x: Float, places: i32) -> Float {
        let multiplier = FLOAT_10.powi(places);
        (x * multiplier).round() / multiplier
    }

    /// Serializes the word table with bincode.
    ///
    /// Each entry is written as the word's bytes followed by its frequency. Entries appear in
    /// the map's iteration order, which is unspecified. The minimum and the standardizer are
    /// not included.
    ///
    /// # Errors
    ///
    /// Returns an error if bincode fails to encode an entry.
    pub fn serialize(&self) -> Result<Vec<u8>> {
        let mut bytes = vec![];
        for (k, v) in &self.map {
            bincode::serialize_into(&mut bytes, k.as_bytes())?;
            bincode::serialize_into(&mut bytes, v)?;
        }
        Ok(bytes)
    }

    /// Rebuilds a model from bytes produced by [`WordFreq::serialize`].
    ///
    /// The returned model has a minimum of `0` and no standardizer, since those settings are
    /// not part of the serialized data.
    ///
    /// # Errors
    ///
    /// Returns an error if the bytes are truncated or cannot be decoded as
    /// `(String, Float)` pairs.
    pub fn deserialize(mut bytes: &[u8]) -> Result<Self> {
        let mut map = HashMap::new();
        // Each `deserialize_from` call advances the slice past what it consumed.
        while !bytes.is_empty() {
            let k: String = bincode::deserialize_from(&mut bytes)?;
            let v: Float = bincode::deserialize_from(&mut bytes)?;
            map.insert(k, v);
        }
        Ok(Self {
            map,
            minimum: 0.,
            num_handler: numbers::NumberHandler::new(),
            standardizer: None,
        })
    }
}
/// Parses a weighted word list: one `<word> <weight>` pair per line.
///
/// Columns are separated by ASCII whitespace. Lines containing only whitespace are skipped.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - reading a line fails;
/// - a non-blank line does not have exactly two columns;
/// - a weight cannot be parsed as a [`Float`].
///
/// Line numbers in error messages are 1-based.
pub fn word_weights_from_text<R: BufRead>(rdr: R) -> Result<Vec<(String, Float)>> {
    let mut word_weights = vec![];
    for (i, line) in rdr.lines().enumerate() {
        let line = line?;
        // Humans (and editors) count lines from 1.
        let lineno = i + 1;
        let cols: Vec<&str> = line.split_ascii_whitespace().collect();
        let (word, weight) = match cols[..] {
            // Blank lines (e.g. a trailing empty line) carry no entry.
            [] => continue,
            [word, weight] => (word, weight),
            _ => {
                return Err(anyhow!(
                    "Line {lineno}: a line should be <word><SPACE><weight>, but got {line}."
                ))
            }
        };
        let weight: Float = weight
            .parse()
            .map_err(|e| anyhow!("Line {lineno}: invalid weight {weight:?}: {e}"))?;
        word_weights.push((word.to_string(), weight));
    }
    Ok(word_weights)
}
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    /// A model built from no words reports zero frequency for every query.
    #[test]
    fn test_empty() {
        let wf = WordFreq::new(Vec::<(&str, Float)>::new());
        for word in ["las", "vegas"] {
            assert_relative_eq!(wf.word_frequency(word), 0.00);
        }
    }

    /// Serializing and then deserializing a model preserves its word table.
    #[test]
    fn test_io() {
        let wf = WordFreq::new([("las", 10.), ("vegas", 30.)]);
        let bytes = wf.serialize().unwrap();
        let restored = WordFreq::deserialize(&bytes).unwrap();
        assert_eq!(wf.map, restored.map);
    }
}