find_simdoc/tfidf.rs

//! Weighters of TF-IDF.
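//!
//! This module provides [`Tf`] and [`Idf`], which weight extracted features by
//! term frequency and inverse document frequency, respectively.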
use std::hash::Hash;

use hashbrown::{HashMap, HashSet};

use crate::errors::{FindSimdocError, Result};
use crate::feature::{FeatureConfig, FeatureExtractor};

/// Weighter of inverse document frequency.
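///
/// # Examples
///
/// A minimal sketch of counting document frequencies and querying IDF weights
/// (the `find_simdoc::tfidf` path is assumed from the crate layout):
///
/// ```
/// use find_simdoc::tfidf::Idf;
///
/// let mut idf = Idf::new();
/// idf.add(&['A', 'A', 'C']);
/// idf.add(&['A', 'C']);
/// idf.add(&['B', 'A']);
///
/// assert_eq!(idf.num_docs(), 3);
/// assert_eq!(idf.idf('A'), (3f64 / 3f64).log10() + 1.);
/// ```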
#[derive(Default)]
pub struct Idf<T> {
    counter: HashMap<T, usize>,
    dedup: HashSet<T>,
    num_docs: usize,
    smooth: bool,
}

impl<T> Idf<T>
where
    T: Hash + Eq + Copy + Default,
{
    /// Creates an instance.
    pub fn new() -> Self {
        Self::default()
    }

    /// Enables smoothing.
    pub const fn smooth(mut self, yes: bool) -> Self {
        self.smooth = yes;
        self
    }

    /// Trains document frequencies from the terms of a single document.
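    ///
    /// Duplicate terms in `terms` are counted only once, so each stored count is
    /// the number of documents containing the term (its document frequency).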
    pub fn add(&mut self, terms: &[T]) {
        self.dedup.clear();
        for &term in terms {
            if self.dedup.insert(term) {
                self.counter
                    .entry(term)
                    .and_modify(|c| *c += 1)
                    .or_insert(1);
            }
        }
        self.num_docs += 1;
    }

    /// Gets the number of input documents.
    pub const fn num_docs(&self) -> usize {
        self.num_docs
    }

    /// Computes the IDF of an input term.
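    ///
    /// The weight is `log10((N + c) / (df + c)) + 1`, where `N` is the number of
    /// added documents, `df` is the document frequency of `term`, and `c` is 1 if
    /// smoothing is enabled and 0 otherwise.
    ///
    /// # Panics
    ///
    /// Panics if `term` has never been seen by [`Idf::add`].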
    pub fn idf(&self, term: T) -> f64 {
        let c = usize::from(self.smooth);
        let n = (self.num_docs + c) as f64;
        let m = (*self.counter.get(&term).unwrap() + c) as f64;
        (n / m).log10() + 1.
    }
}

impl Idf<u64> {
    /// Trains the document frequencies of terms over the input documents.
    ///
    /// # Arguments
    ///
    /// * `documents` - List of documents.
    /// * `config` - Configuration of feature extraction. Use the same configuration as in search.
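    ///
    /// # Errors
    ///
    /// An error is returned if any input document is empty.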
    pub fn build<I, D>(mut self, documents: I, config: &FeatureConfig) -> Result<Self>
    where
        I: IntoIterator<Item = D>,
        D: AsRef<str>,
    {
        let extractor = FeatureExtractor::new(config);
        let mut feature = vec![];
        for doc in documents {
            let doc = doc.as_ref();
            if doc.is_empty() {
                return Err(FindSimdocError::input("Input document must not be empty."));
            }
            extractor.extract(doc, &mut feature);
            self.add(&feature);
        }
        Ok(self)
    }
}

/// Weighter of term frequency.
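///
/// # Examples
///
/// A minimal sketch of computing relative term frequencies in place
/// (the `find_simdoc::tfidf` path is assumed from the crate layout):
///
/// ```
/// use find_simdoc::tfidf::Tf;
///
/// let tf = Tf::new();
/// let mut terms = vec![('A', 0.), ('B', 0.), ('A', 0.)];
/// tf.tf(&mut terms);
/// assert_eq!(terms, vec![('A', 2. / 3.), ('B', 1. / 3.), ('A', 2. / 3.)]);
/// ```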
#[derive(Default)]
pub struct Tf {
    sublinear: bool,
}

impl Tf {
    /// Creates an instance.
    pub fn new() -> Self {
        Self::default()
    }

    /// Enables sublinear normalization.
    pub const fn sublinear(mut self, yes: bool) -> Self {
        self.sublinear = yes;
        self
    }

    /// Computes the TF of input terms.
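    ///
    /// Each weight (the second element of a pair) is overwritten in place:
    /// a term occurring `count` times among the `terms.len()` entries receives
    /// `count / terms.len()`, or `log10(count) + 1` if sublinear scaling is enabled.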
    pub fn tf<T>(&self, terms: &mut [(T, f64)])
    where
        T: Hash + Eq + Copy + Default,
    {
        let counter = self.count(terms);
        let total = terms.len() as f64;
        for (term, weight) in terms {
            let cnt = *counter.get(term).unwrap() as f64;
            *weight = if self.sublinear {
                cnt.log10() + 1.
            } else {
                cnt / total
            };
        }
    }

    fn count<T>(&self, terms: &mut [(T, f64)]) -> HashMap<T, usize>
    where
        T: Hash + Eq + Copy + Default,
    {
        let mut counter = HashMap::new();
        for &(term, _) in terms.iter() {
            counter.entry(term).and_modify(|c| *c += 1).or_insert(1);
        }
        counter
    }
}

#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;

    #[test]
    fn test_idf() {
        let mut idf = Idf::new();
        idf.add(&['A', 'A', 'C']);
        idf.add(&['A', 'C']);
        idf.add(&['B', 'A']);

        assert_eq!(idf.num_docs(), 3);

        idf = idf.smooth(false);
        assert_eq!(idf.idf('A'), (3f64 / 3f64).log10() + 1.);
        assert_eq!(idf.idf('B'), (3f64 / 1f64).log10() + 1.);
        assert_eq!(idf.idf('C'), (3f64 / 2f64).log10() + 1.);

        idf = idf.smooth(true);
        assert_eq!(idf.idf('A'), (4f64 / 4f64).log10() + 1.);
        assert_eq!(idf.idf('B'), (4f64 / 2f64).log10() + 1.);
        assert_eq!(idf.idf('C'), (4f64 / 3f64).log10() + 1.);
    }

    #[test]
    fn test_tf() {
        let mut tf = Tf::new();
        let mut terms = vec![('A', 0.), ('B', 0.), ('A', 0.)];

        tf = tf.sublinear(false);
        tf.tf(&mut terms);
        assert_eq!(
            terms.clone(),
            vec![('A', 2. / 3.), ('B', 1. / 3.), ('A', 2. / 3.)]
        );

        tf = tf.sublinear(true);
        tf.tf(&mut terms);
        assert_eq!(
            terms.clone(),
            vec![
                ('A', 2f64.log10() + 1.),
                ('B', 1f64.log10() + 1.),
                ('A', 2f64.log10() + 1.)
            ]
        );
    }
}