1use ferrolearn_core::error::FerroError;
7use ndarray::{Array1, Array2};
8use num_traits::Float;
9
/// Row normalization applied to each document vector after TF-IDF weighting.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TfidfNorm {
    /// Divide every row by the sum of the absolute values of its entries.
    L1,
    /// Divide every row by its Euclidean length (the default).
    #[default]
    L2,
    /// Leave rows unscaled.
    None,
}
25
26#[derive(Debug, Clone)]
55pub struct TfidfTransformer<F> {
56 pub norm: TfidfNorm,
58 pub use_idf: bool,
60 pub smooth_idf: bool,
62 pub sublinear_tf: bool,
64 _marker: std::marker::PhantomData<F>,
65}
66
67impl<F: Float + Send + Sync + 'static> TfidfTransformer<F> {
68 #[must_use]
70 pub fn new() -> Self {
71 Self {
72 norm: TfidfNorm::L2,
73 use_idf: true,
74 smooth_idf: true,
75 sublinear_tf: false,
76 _marker: std::marker::PhantomData,
77 }
78 }
79
80 #[must_use]
82 pub fn norm(mut self, norm: TfidfNorm) -> Self {
83 self.norm = norm;
84 self
85 }
86
87 #[must_use]
89 pub fn use_idf(mut self, use_idf: bool) -> Self {
90 self.use_idf = use_idf;
91 self
92 }
93
94 #[must_use]
96 pub fn smooth_idf(mut self, smooth: bool) -> Self {
97 self.smooth_idf = smooth;
98 self
99 }
100
101 #[must_use]
103 pub fn sublinear_tf(mut self, sublinear: bool) -> Self {
104 self.sublinear_tf = sublinear;
105 self
106 }
107
108 pub fn fit(&self, counts: &Array2<F>) -> Result<FittedTfidfTransformer<F>, FerroError> {
114 let n_docs = counts.nrows();
115 if n_docs == 0 {
116 return Err(FerroError::InsufficientSamples {
117 required: 1,
118 actual: 0,
119 context: "TfidfTransformer::fit".into(),
120 });
121 }
122
123 let n_features = counts.ncols();
124 let n_f = F::from(n_docs).unwrap();
125
126 let idf = if self.use_idf {
127 let mut idf_vec = Array1::zeros(n_features);
128 for j in 0..n_features {
129 let df = counts
131 .column(j)
132 .iter()
133 .filter(|&&v| v > F::zero())
134 .count();
135 let df_f = F::from(df).unwrap();
136
137 if self.smooth_idf {
138 idf_vec[j] =
140 ((F::one() + n_f) / (F::one() + df_f)).ln() + F::one();
141 } else {
142 if df > 0 {
144 idf_vec[j] = (n_f / df_f).ln() + F::one();
145 } else {
146 idf_vec[j] = F::one();
147 }
148 }
149 }
150 Some(idf_vec)
151 } else {
152 None
153 };
154
155 Ok(FittedTfidfTransformer {
156 idf,
157 norm: self.norm,
158 sublinear_tf: self.sublinear_tf,
159 })
160 }
161}
162
163impl<F: Float + Send + Sync + 'static> Default for TfidfTransformer<F> {
164 fn default() -> Self {
165 Self::new()
166 }
167}
168
169#[derive(Debug, Clone)]
177pub struct FittedTfidfTransformer<F> {
178 idf: Option<Array1<F>>,
180 norm: TfidfNorm,
182 sublinear_tf: bool,
184}
185
186impl<F: Float + Send + Sync + 'static> FittedTfidfTransformer<F> {
187 #[must_use]
189 pub fn idf(&self) -> Option<&Array1<F>> {
190 self.idf.as_ref()
191 }
192
193 pub fn transform(&self, counts: &Array2<F>) -> Result<Array2<F>, FerroError> {
201 if counts.nrows() == 0 {
202 return Err(FerroError::InsufficientSamples {
203 required: 1,
204 actual: 0,
205 context: "FittedTfidfTransformer::transform".into(),
206 });
207 }
208
209 if let Some(ref idf) = self.idf {
210 if counts.ncols() != idf.len() {
211 return Err(FerroError::ShapeMismatch {
212 expected: vec![counts.nrows(), idf.len()],
213 actual: vec![counts.nrows(), counts.ncols()],
214 context: "FittedTfidfTransformer::transform".into(),
215 });
216 }
217 }
218
219 let mut result = counts.to_owned();
220
221 if self.sublinear_tf {
223 result.mapv_inplace(|v| {
224 if v > F::zero() {
225 F::one() + v.ln()
226 } else {
227 v
228 }
229 });
230 }
231
232 if let Some(ref idf) = self.idf {
234 for mut row in result.rows_mut() {
235 for (j, v) in row.iter_mut().enumerate() {
236 *v = *v * idf[j];
237 }
238 }
239 }
240
241 match self.norm {
243 TfidfNorm::L1 => {
244 for mut row in result.rows_mut() {
245 let norm: F = row.iter().map(|v| v.abs()).fold(F::zero(), |a, b| a + b);
246 if norm > F::zero() {
247 for v in row.iter_mut() {
248 *v = *v / norm;
249 }
250 }
251 }
252 }
253 TfidfNorm::L2 => {
254 for mut row in result.rows_mut() {
255 let norm_sq: F = row.iter().map(|v| *v * *v).fold(F::zero(), |a, b| a + b);
256 let norm = norm_sq.sqrt();
257 if norm > F::zero() {
258 for v in row.iter_mut() {
259 *v = *v / norm;
260 }
261 }
262 }
263 }
264 TfidfNorm::None => {}
265 }
266
267 Ok(result)
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Absolute tolerance used by every `f64` comparison in this module.
    const TOL: f64 = 1e-10;

    /// Fits `transformer` on `counts` and transforms that same matrix.
    fn fit_transform(transformer: &TfidfTransformer<f64>, counts: &Array2<f64>) -> Array2<f64> {
        transformer.fit(counts).unwrap().transform(counts).unwrap()
    }

    /// Asserts that every row of `matrix` has Euclidean length one.
    fn assert_unit_l2_rows(matrix: &Array2<f64>) {
        for row in matrix.rows() {
            let length = row.iter().map(|v| v * v).sum::<f64>().sqrt();
            assert_abs_diff_eq!(length, 1.0, epsilon = TOL);
        }
    }

    #[test]
    fn test_tfidf_basic() {
        let counts = array![
            [1.0_f64, 1.0, 0.0],
            [1.0, 0.0, 1.0],
            [1.0, 0.0, 0.0],
        ];
        let tfidf = fit_transform(&TfidfTransformer::<f64>::new(), &counts);

        assert_eq!(tfidf.dim(), (3, 3));
        assert_unit_l2_rows(&tfidf);
    }

    #[test]
    fn test_tfidf_no_idf() {
        let counts = array![[3.0_f64, 1.0], [0.0, 2.0]];
        let config = TfidfTransformer::<f64>::new().use_idf(false);

        assert_unit_l2_rows(&fit_transform(&config, &counts));
    }

    #[test]
    fn test_tfidf_l1_norm() {
        let counts = array![[3.0_f64, 1.0], [0.0, 2.0]];
        let config = TfidfTransformer::<f64>::new()
            .use_idf(false)
            .norm(TfidfNorm::L1);
        let tfidf = fit_transform(&config, &counts);

        for row in tfidf.rows() {
            let abs_sum: f64 = row.iter().map(|v| v.abs()).sum();
            assert_abs_diff_eq!(abs_sum, 1.0, epsilon = TOL);
        }
    }

    #[test]
    fn test_tfidf_no_norm() {
        let counts = array![[1.0_f64, 0.0], [1.0, 1.0]];
        let config = TfidfTransformer::<f64>::new()
            .use_idf(false)
            .norm(TfidfNorm::None);
        let tfidf = fit_transform(&config, &counts);

        // With neither IDF nor normalization the counts pass through as-is.
        for (expected, actual) in counts.iter().zip(tfidf.iter()) {
            assert_abs_diff_eq!(*expected, *actual, epsilon = TOL);
        }
    }

    #[test]
    fn test_tfidf_sublinear_tf() {
        let counts = array![[4.0_f64, 1.0]];
        let config = TfidfTransformer::<f64>::new()
            .use_idf(false)
            .sublinear_tf(true)
            .norm(TfidfNorm::None);
        let tfidf = fit_transform(&config, &counts);

        assert_abs_diff_eq!(tfidf[[0, 0]], 1.0 + 4.0_f64.ln(), epsilon = TOL);
        // ln(1) = 0, so a count of one is unchanged.
        assert_abs_diff_eq!(tfidf[[0, 1]], 1.0, epsilon = TOL);
    }

    #[test]
    fn test_tfidf_smooth_idf() {
        let counts = array![[1.0_f64, 1.0], [1.0, 0.0], [1.0, 0.0]];
        let fitted = TfidfTransformer::<f64>::new()
            .norm(TfidfNorm::None)
            .fit(&counts)
            .unwrap();
        let weights = fitted.idf().unwrap();

        // Term 0: ln((1 + 3) / (1 + 3)) + 1 = 1.
        assert_abs_diff_eq!(weights[0], 1.0, epsilon = TOL);
        // Term 1: ln((1 + 3) / (1 + 1)) + 1 = ln(2) + 1.
        assert_abs_diff_eq!(weights[1], 2.0_f64.ln() + 1.0, epsilon = TOL);
    }

    #[test]
    fn test_tfidf_no_smooth_idf() {
        let counts = array![[1.0_f64, 1.0], [1.0, 0.0], [1.0, 0.0]];
        let fitted = TfidfTransformer::<f64>::new()
            .smooth_idf(false)
            .norm(TfidfNorm::None)
            .fit(&counts)
            .unwrap();
        let weights = fitted.idf().unwrap();

        // Term 0: ln(3 / 3) + 1 = 1.
        assert_abs_diff_eq!(weights[0], 1.0, epsilon = TOL);
        // Term 1: ln(3 / 1) + 1 = ln(3) + 1.
        assert_abs_diff_eq!(weights[1], 3.0_f64.ln() + 1.0, epsilon = TOL);
    }

    #[test]
    fn test_tfidf_empty() {
        let no_docs = Array2::<f64>::zeros((0, 3));
        let outcome = TfidfTransformer::<f64>::new().fit(&no_docs);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_tfidf_shape_mismatch() {
        let train = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let too_wide = array![[1.0_f64, 0.0, 0.0]];
        let fitted = TfidfTransformer::<f64>::new().fit(&train).unwrap();

        assert!(fitted.transform(&too_wide).is_err());
    }

    #[test]
    fn test_tfidf_f32() {
        let counts = array![[1.0_f32, 0.0], [0.0, 1.0]];
        let tfidf = TfidfTransformer::<f32>::new()
            .fit(&counts)
            .unwrap()
            .transform(&counts)
            .unwrap();

        assert_eq!(tfidf.dim(), (2, 2));
    }
}