wordfreq/
lib.rs

1// Copyright 2022 Robyn Speer
2// Copyright 2023 Shunsuke Kanda
3//
4// The code is a port from https://github.com/rspeer/wordfreq/tree/v3.0.2
5// together with the comments, following the MIT-license.
6//! # wordfreq
7//!
8//! This crate is yet another Rust port of [Python's wordfreq](https://github.com/rspeer/wordfreq),
9//! allowing you to look up the frequencies of words in many languages.
10//!
11//! Note that **this crate provides only the algorithms (including hardcoded standardization) and does not contain the models**.
12//! Use [wordfreq-model](https://docs.rs/wordfreq-model/) to easily load distributed models.
13//! We recommend reading its [documentation](https://docs.rs/wordfreq-model/) for a quick start.
14//!
15//! ## How to create instances without wordfreq-model
16//!
17//! If you do not desire automatic model downloads using [wordfreq-model](https://docs.rs/wordfreq-model/),
18//! you can create instances directly from the actual model files placed [here](https://github.com/kampersanda/wordfreq-rs/releases/tag/models-v1).
19//! The model files describe words and their frequencies in the text format:
20//!
21//! ```text
22//! <word1> <freq1>
23//! <word2> <freq2>
24//! <word3> <freq3>
25//! ...
26//! ```
27//!
28//! You can create instances as follows:
29//!
30//! ```
31//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
32//! use approx::assert_relative_eq;
33//! use wordfreq::WordFreq;
34//!
35//! let word_weight_text = "las 10\nvegas 30\n";
36//! let word_weights = wordfreq::word_weights_from_text(word_weight_text.as_bytes())?;
37//!
38//! let wf = WordFreq::new(word_weights);
39//! assert_relative_eq!(wf.word_frequency("las"), 0.25);
40//! assert_relative_eq!(wf.word_frequency("vegas"), 0.75);
41//! assert_relative_eq!(wf.word_frequency("Las"), 0.00);
42//! # Ok(())
43//! # }
44//! ```
45//!
46//! ## Standardization
47//!
48//! If you want to standardize words before looking up their frequencies,
49//! set up an instance of [`Standardizer`].
50//!
51//! ```
52//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
53//! use approx::assert_relative_eq;
54//! use wordfreq::WordFreq;
55//! use wordfreq::Standardizer;
56//!
57//! let word_weight_text = "las 10\nvegas 30\n";
58//! let word_weights = wordfreq::word_weights_from_text(word_weight_text.as_bytes())?;
59//!
60//! let standardizer = Standardizer::new("en")?;
61//! let wf = WordFreq::new(word_weights).standardizer(standardizer);
62//! assert_relative_eq!(wf.word_frequency("Las"), 0.25); // Standardized
63//! # Ok(())
64//! # }
65//! ```
66//!
67//! ## Precision errors
68//!
69//! Even if the algorithms are the same, the results may differ slightly from the original implementation
70//! due to floating point precision.
71//!
72//! ## Unprovided features
73//!
74//! This crate is a straightforward port of Python's wordfreq,
75//! although some features are not provided:
76//!
77//! - [Tokenization](https://github.com/rspeer/wordfreq/tree/v3.0.2#tokenization)
78//! - [Additional functions](https://github.com/rspeer/wordfreq/tree/v3.0.2#other-functions)
79#![deny(missing_docs)]
80
81mod chinese;
82pub mod language;
83mod numbers;
84pub mod preprocessers;
85mod transliterate;
86
87use std::io::BufRead;
88
89use anyhow::{anyhow, Result};
90use hashbrown::HashMap;
91
92pub use preprocessers::Standardizer;
93
/// Floating-point type used for all weights and frequencies in this crate.
pub type Float = f32;

// The base 10 used by the Zipf conversions and by rounding, kept as a named
// constant to avoid annoying clippy errors on the bare literal.
const FLOAT_10: Float = 10.0;
99
100/// Implementation of wordfreq.
101#[derive(Clone)]
102pub struct WordFreq {
103    map: HashMap<String, Float>,
104    minimum: Float,
105    num_handler: numbers::NumberHandler,
106    standardizer: Option<Standardizer>,
107}
108
109impl WordFreq {
110    /// Creates an instance from frequencies.
111    ///
112    /// # Arguments
113    ///
114    /// - `word_weights`: Pairs of words and their frequencies (or probabilities) from a corpus.
115    ///
116    /// # Notes
117    ///
118    /// If the input contains duplicate words, the last occurrence is used.
119    pub fn new<I, W>(word_weights: I) -> Self
120    where
121        I: IntoIterator<Item = (W, Float)>,
122        W: AsRef<str>,
123    {
124        let mut map: HashMap<_, _> = word_weights
125            .into_iter()
126            .map(|(word, weight)| (word.as_ref().to_string(), weight))
127            .collect();
128        let sum_weight = map.values().fold(0., |acc, w| acc + w);
129        map.values_mut().for_each(|w| *w /= sum_weight);
130        Self {
131            map,
132            minimum: 0.,
133            num_handler: numbers::NumberHandler::new(),
134            standardizer: None,
135        }
136    }
137
138    /// Sets the lower bound of returned frequencies (default is 0.0).
139    ///
140    /// An error is returned if the input is negative.
141    pub fn minimum(mut self, minimum: Float) -> Result<Self> {
142        if minimum < 0. {
143            return Err(anyhow!("minimum must be non-negative"));
144        }
145        self.minimum = minimum;
146        Ok(self)
147    }
148
149    /// Sets the standardizer for preprocessing words.
150    ///
151    /// If set, the standardizer is always applied to words before looking up their frequencies.
152    #[allow(clippy::missing_const_for_fn)]
153    pub fn standardizer(mut self, standardizer: Standardizer) -> Self {
154        self.standardizer = Some(standardizer);
155        self
156    }
157
158    /// Returns the word's frequency, normalized between 0.0 and 1.0.
159    ///
160    /// # Examples
161    ///
162    /// ```
163    /// use approx::assert_relative_eq;
164    /// use wordfreq::WordFreq;
165    ///
166    /// let word_weights = [("las", 10.), ("vegas", 30.)];
167    /// let wf = WordFreq::new(word_weights);
168    ///
169    /// assert_relative_eq!(wf.word_frequency("las"), 0.25);
170    /// assert_relative_eq!(wf.word_frequency("vegas"), 0.75);
171    /// assert_relative_eq!(wf.word_frequency("Las"), 0.00);
172    /// ```
173    pub fn word_frequency<W>(&self, word: W) -> Float
174    where
175        W: AsRef<str>,
176    {
177        self.word_frequency_in(word).unwrap_or(0.).max(self.minimum)
178    }
179
180    /// Returns the Zipf frequency of a word as a human-friendly logarithmic scale.
181    ///
182    /// # Examples
183    ///
184    /// ```
185    /// use approx::assert_relative_eq;
186    /// use wordfreq::WordFreq;
187    ///
188    /// let word_weights = [("las", 10.), ("vegas", 30.)];
189    /// let wf = WordFreq::new(word_weights);
190    ///
191    /// assert_relative_eq!(wf.zipf_frequency("las"), 8.4);
192    /// assert_relative_eq!(wf.zipf_frequency("vegas"), 8.88);
193    /// assert_relative_eq!(wf.zipf_frequency("Las"), 0.00);
194    /// ```
195    pub fn zipf_frequency<W>(&self, word: W) -> Float
196    where
197        W: AsRef<str>,
198    {
199        let freq_min = Self::zipf_to_freq(self.minimum);
200        let freq = self.word_frequency_in(word).unwrap_or(0.).max(freq_min);
201        let zipf = Self::freq_to_zipf(freq);
202        Self::round(zipf, 2)
203    }
204
205    fn word_frequency_in<W>(&self, word: W) -> Option<Float>
206    where
207        W: AsRef<str>,
208    {
209        let word = self.standardizer.as_ref().map_or_else(
210            || word.as_ref().to_string(),
211            |standardizer| standardizer.apply(word.as_ref()),
212        );
213
214        let smashed = self.num_handler.smash_numbers(&word);
215        let mut freq = self.map.get(&smashed).cloned()?;
216
217        if smashed != word {
218            // If there is a digit sequence in the token, the digits are
219            // internally replaced by 0s to aggregate their probabilities
220            // together. We then assign a specific frequency to the digit
221            // sequence using the `digit_freq` distribution.
222            freq *= self.num_handler.digit_freq(&word);
223        }
224
225        // All our frequency data is only precise to within 1% anyway, so round
226        // it to 3 significant digits
227        // let leading_zeroes = (-freq.log10()).floor() as i32;
228        // Some(Self::round(freq, leading_zeroes + 3))
229
230        // NOTE(kampersanda): Rounding would not always be necessary.
231        Some(freq)
232    }
233
234    fn zipf_to_freq(zipf: Float) -> Float {
235        FLOAT_10.powf(zipf - 9.)
236    }
237
238    fn freq_to_zipf(freq: Float) -> Float {
239        freq.log10() + 9.
240    }
241
242    fn round(x: Float, places: i32) -> Float {
243        let multiplier = FLOAT_10.powi(places);
244        (x * multiplier).round() / multiplier
245    }
246
247    /// Exports the model data.
248    ///
249    /// Note that the format is distinct from the one used in the oritinal Python package.
250    pub fn serialize(&self) -> Result<Vec<u8>> {
251        let mut bytes = vec![];
252        for (k, v) in &self.map {
253            bincode::serialize_into(&mut bytes, k.as_bytes())?;
254            bincode::serialize_into(&mut bytes, v)?;
255        }
256        Ok(bytes)
257    }
258
259    /// Deserializes the model, which is exported by [`WordFreq::serialize()`].
260    pub fn deserialize(mut bytes: &[u8]) -> Result<Self> {
261        let mut map = HashMap::new();
262        while !bytes.is_empty() {
263            let k: String = bincode::deserialize_from(&mut bytes)?;
264            let v: Float = bincode::deserialize_from(&mut bytes)?;
265            map.insert(k, v);
266        }
267        Ok(Self {
268            map,
269            minimum: 0.,
270            num_handler: numbers::NumberHandler::new(),
271            standardizer: None,
272        })
273    }
274
275    /// Returns the reference to the internal word-frequency map.
276    pub const fn word_frequency_map(&self) -> &HashMap<String, Float> {
277        &self.map
278    }
279}
280
281/// Parses pairs of a word and its weight from a text file,
282/// where each line has a word and its weight sparated by the ASCII whitespace.
283///
284/// ```text
285/// <word> <weight>
286/// ```
287///
288/// # Examples
289///
290/// ```
291/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
292/// let word_weight_text = "las 10\nvegas 30\n";
293/// let word_weights = wordfreq::word_weights_from_text(word_weight_text.as_bytes())?;
294///
295/// assert_eq!(
296///     word_weights,
297///     vec![("las".to_string(), 10.), ("vegas".to_string(), 30.)]
298/// );
299/// # Ok(())
300/// # }
301/// ```
302pub fn word_weights_from_text<R: BufRead>(rdr: R) -> Result<Vec<(String, Float)>> {
303    let mut word_weights = vec![];
304    for (i, line) in rdr.lines().enumerate() {
305        let line = line?;
306        let cols: Vec<_> = line.split_ascii_whitespace().collect();
307        if cols.len() != 2 {
308            return Err(anyhow!(
309                "Line {i}: a line should be <word><SPACE><weight>, but got {line}."
310            ));
311        }
312        word_weights.push((cols[0].to_string(), cols[1].parse()?));
313    }
314    Ok(word_weights)
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_relative_eq;

    // A model without any entries reports zero for every word.
    #[test]
    fn test_empty() {
        let wf = WordFreq::new(Vec::<(&str, Float)>::new());
        for word in ["las", "vegas"] {
            assert_relative_eq!(wf.word_frequency(word), 0.00);
        }
    }

    // Serializing and deserializing restores the same frequency map.
    #[test]
    fn test_io() {
        let wf = WordFreq::new([("las", 10.), ("vegas", 30.)]);

        let bytes = wf.serialize().unwrap();
        let restored = WordFreq::deserialize(bytes.as_slice()).unwrap();

        assert_eq!(wf.map, restored.map);
    }
}
341}