1use std::hash::Hash;
3
4use hashbrown::{HashMap, HashSet};
5
6use crate::errors::{FindSimdocError, Result};
7use crate::feature::{FeatureConfig, FeatureExtractor};
8
/// Weighter of inverse document frequency (IDF) over terms of type `T`.
///
/// Documents are registered one at a time with [`Idf::add`]; afterwards
/// [`Idf::idf`] reports the weight of any term that was seen.
pub struct Idf<T> {
    /// For each term, the number of added documents that contain it.
    counter: HashMap<T, usize>,
    /// Scratch set reused by `add` so a term is counted once per document.
    dedup: HashSet<T>,
    /// Number of documents added so far.
    num_docs: usize,
    /// When set, `idf` adds one to both the document count and the term count.
    smooth: bool,
}

// Written by hand instead of derived: `#[derive(Default)]` would only provide
// the impl for `T: Default`, although nothing here needs a default term.
impl<T> Default for Idf<T> {
    fn default() -> Self {
        Self {
            counter: HashMap::default(),
            dedup: HashSet::default(),
            num_docs: 0,
            smooth: false,
        }
    }
}

impl<T> Idf<T>
where
    T: Hash + Eq + Copy,
{
    /// Creates an empty instance with smoothing disabled.
    pub fn new() -> Self {
        Self::default()
    }

    /// Enables (`true`) or disables (`false`) smoothing and returns the
    /// updated instance.
    #[must_use]
    pub const fn smooth(mut self, yes: bool) -> Self {
        self.smooth = yes;
        self
    }

    /// Registers one document given as a slice of its terms.
    ///
    /// A term occurring several times in `terms` contributes only once to its
    /// document count. The number of documents grows by one on every call,
    /// including calls with an empty slice.
    pub fn add(&mut self, terms: &[T]) {
        self.dedup.clear();
        for &term in terms {
            // `insert` returns `true` only the first time the term is met in
            // this document.
            if self.dedup.insert(term) {
                *self.counter.entry(term).or_insert(0) += 1;
            }
        }
        self.num_docs += 1;
    }

    /// Returns the number of documents added so far.
    pub const fn num_docs(&self) -> usize {
        self.num_docs
    }

    /// Computes the IDF of `term` as `log10((N + c) / (df + c)) + 1`, where
    /// `N` is the number of added documents, `df` is the number of those
    /// documents containing `term`, and `c` is 1 with smoothing and 0 without.
    ///
    /// # Panics
    ///
    /// Panics if `term` did not occur in any document passed to
    /// [`Idf::add`].
    pub fn idf(&self, term: T) -> f64 {
        let c = usize::from(self.smooth);
        let df = *self
            .counter
            .get(&term)
            .expect("IDF requested for a term that was never added");
        let n = (self.num_docs + c) as f64;
        let m = (df + c) as f64;
        (n / m).log10() + 1.
    }
}
60
61impl Idf<u64> {
62 pub fn build<I, D>(mut self, documents: I, config: &FeatureConfig) -> Result<Self>
69 where
70 I: IntoIterator<Item = D>,
71 D: AsRef<str>,
72 {
73 let extractor = FeatureExtractor::new(config);
74 let mut feature = vec![];
75 for doc in documents {
76 let doc = doc.as_ref();
77 if doc.is_empty() {
78 return Err(FindSimdocError::input("Input document must not be empty."));
79 }
80 extractor.extract(doc, &mut feature);
81 self.add(&feature);
82 }
83 Ok(self)
84 }
85}
86
/// Weighter of term frequency (TF).
#[derive(Default)]
pub struct Tf {
    /// When set, `tf` uses `log10(count) + 1` instead of `count / total`.
    sublinear: bool,
}

impl Tf {
    /// Creates an instance with sublinear scaling disabled.
    pub fn new() -> Self {
        Self::default()
    }

    /// Enables (`true`) or disables (`false`) sublinear scaling and returns
    /// the updated instance.
    #[must_use]
    pub const fn sublinear(mut self, yes: bool) -> Self {
        self.sublinear = yes;
        self
    }

    /// Overwrites the weight of every `(term, weight)` pair in `terms` with
    /// the term frequency of `term` within the slice.
    ///
    /// With `count` the number of pairs sharing the term and `total` the
    /// slice length, the weight is `count / total`, or `log10(count) + 1`
    /// when sublinear scaling is enabled. An empty slice is left untouched.
    pub fn tf<T>(&self, terms: &mut [(T, f64)])
    where
        T: Hash + Eq + Copy,
    {
        let counter = self.count(terms);
        let total = terms.len() as f64;
        for (term, weight) in terms {
            let cnt = *counter
                .get(term)
                .expect("every term of the slice was counted just above")
                as f64;
            *weight = if self.sublinear {
                cnt.log10() + 1.
            } else {
                cnt / total
            };
        }
    }

    /// Counts how many pairs of `terms` carry each distinct term.
    fn count<T>(&self, terms: &[(T, f64)]) -> HashMap<T, usize>
    where
        T: Hash + Eq + Copy,
    {
        let mut counter = HashMap::new();
        for &(term, _) in terms {
            *counter.entry(term).or_insert(0) += 1;
        }
        counter
    }
}
133
#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;

    /// Reference value `log10(n / m) + 1` that the IDF assertions compare with.
    fn reference_idf(n: f64, m: f64) -> f64 {
        (n / m).log10() + 1.
    }

    #[test]
    fn test_idf() {
        let mut idf = Idf::new();
        idf.add(&['A', 'A', 'C']);
        idf.add(&['A', 'C']);
        idf.add(&['B', 'A']);

        assert_eq!(idf.num_docs(), 3);

        // Without smoothing: three documents, raw document frequencies.
        let idf = idf.smooth(false);
        for (term, df) in vec![('A', 3.), ('B', 1.), ('C', 2.)] {
            assert_eq!(idf.idf(term), reference_idf(3., df));
        }

        // With smoothing: both counts are incremented by one.
        let idf = idf.smooth(true);
        for (term, df) in vec![('A', 4.), ('B', 2.), ('C', 3.)] {
            assert_eq!(idf.idf(term), reference_idf(4., df));
        }
    }

    #[test]
    fn test_tf() {
        let mut terms = vec![('A', 0.), ('B', 0.), ('A', 0.)];

        // Relative frequency: count divided by the number of pairs.
        let tf = Tf::new().sublinear(false);
        tf.tf(&mut terms);
        let relative = vec![('A', 2. / 3.), ('B', 1. / 3.), ('A', 2. / 3.)];
        assert_eq!(terms, relative);

        // Sublinear scaling: log10(count) + 1.
        let tf = tf.sublinear(true);
        tf.tf(&mut terms);
        let (once, twice) = (1f64.log10() + 1., 2f64.log10() + 1.);
        assert_eq!(terms, vec![('A', twice), ('B', once), ('A', twice)]);
    }
}