1use ferrolearn_core::error::FerroError;
7use ndarray::{Array1, Array2};
8use num_traits::Float;
9
/// Row-normalisation scheme applied to the weighted matrix as the last step
/// of [`FittedTfidfTransformer::transform`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TfidfNorm {
    /// Divide each row by the sum of the absolute values of its entries.
    L1,
    /// Divide each row by its Euclidean length. This is the default variant.
    #[default]
    L2,
    /// Leave every row as it is.
    None,
}
25
26#[derive(Debug, Clone)]
55pub struct TfidfTransformer<F> {
56 pub norm: TfidfNorm,
58 pub use_idf: bool,
60 pub smooth_idf: bool,
62 pub sublinear_tf: bool,
64 _marker: std::marker::PhantomData<F>,
65}
66
67impl<F: Float + Send + Sync + 'static> TfidfTransformer<F> {
68 #[must_use]
70 pub fn new() -> Self {
71 Self {
72 norm: TfidfNorm::L2,
73 use_idf: true,
74 smooth_idf: true,
75 sublinear_tf: false,
76 _marker: std::marker::PhantomData,
77 }
78 }
79
80 #[must_use]
82 pub fn norm(mut self, norm: TfidfNorm) -> Self {
83 self.norm = norm;
84 self
85 }
86
87 #[must_use]
89 pub fn use_idf(mut self, use_idf: bool) -> Self {
90 self.use_idf = use_idf;
91 self
92 }
93
94 #[must_use]
96 pub fn smooth_idf(mut self, smooth: bool) -> Self {
97 self.smooth_idf = smooth;
98 self
99 }
100
101 #[must_use]
103 pub fn sublinear_tf(mut self, sublinear: bool) -> Self {
104 self.sublinear_tf = sublinear;
105 self
106 }
107
108 pub fn fit(&self, counts: &Array2<F>) -> Result<FittedTfidfTransformer<F>, FerroError> {
114 let n_docs = counts.nrows();
115 if n_docs == 0 {
116 return Err(FerroError::InsufficientSamples {
117 required: 1,
118 actual: 0,
119 context: "TfidfTransformer::fit".into(),
120 });
121 }
122
123 let n_features = counts.ncols();
124 let n_f = F::from(n_docs).unwrap();
125
126 let idf = if self.use_idf {
127 let mut idf_vec = Array1::zeros(n_features);
128 for j in 0..n_features {
129 let df = counts.column(j).iter().filter(|&&v| v > F::zero()).count();
131 let df_f = F::from(df).unwrap();
132
133 if self.smooth_idf {
134 idf_vec[j] = ((F::one() + n_f) / (F::one() + df_f)).ln() + F::one();
136 } else {
137 if df > 0 {
139 idf_vec[j] = (n_f / df_f).ln() + F::one();
140 } else {
141 idf_vec[j] = F::one();
142 }
143 }
144 }
145 Some(idf_vec)
146 } else {
147 None
148 };
149
150 Ok(FittedTfidfTransformer {
151 idf,
152 norm: self.norm,
153 sublinear_tf: self.sublinear_tf,
154 })
155 }
156}
157
158impl<F: Float + Send + Sync + 'static> Default for TfidfTransformer<F> {
159 fn default() -> Self {
160 Self::new()
161 }
162}
163
164#[derive(Debug, Clone)]
172pub struct FittedTfidfTransformer<F> {
173 idf: Option<Array1<F>>,
175 norm: TfidfNorm,
177 sublinear_tf: bool,
179}
180
181impl<F: Float + Send + Sync + 'static> FittedTfidfTransformer<F> {
182 #[must_use]
184 pub fn idf(&self) -> Option<&Array1<F>> {
185 self.idf.as_ref()
186 }
187
188 pub fn transform(&self, counts: &Array2<F>) -> Result<Array2<F>, FerroError> {
196 if counts.nrows() == 0 {
197 return Err(FerroError::InsufficientSamples {
198 required: 1,
199 actual: 0,
200 context: "FittedTfidfTransformer::transform".into(),
201 });
202 }
203
204 if let Some(ref idf) = self.idf {
205 if counts.ncols() != idf.len() {
206 return Err(FerroError::ShapeMismatch {
207 expected: vec![counts.nrows(), idf.len()],
208 actual: vec![counts.nrows(), counts.ncols()],
209 context: "FittedTfidfTransformer::transform".into(),
210 });
211 }
212 }
213
214 let mut result = counts.to_owned();
215
216 if self.sublinear_tf {
218 result.mapv_inplace(|v| if v > F::zero() { F::one() + v.ln() } else { v });
219 }
220
221 if let Some(ref idf) = self.idf {
223 for mut row in result.rows_mut() {
224 for (j, v) in row.iter_mut().enumerate() {
225 *v = *v * idf[j];
226 }
227 }
228 }
229
230 match self.norm {
232 TfidfNorm::L1 => {
233 for mut row in result.rows_mut() {
234 let norm: F = row.iter().map(|v| v.abs()).fold(F::zero(), |a, b| a + b);
235 if norm > F::zero() {
236 for v in &mut row {
237 *v = *v / norm;
238 }
239 }
240 }
241 }
242 TfidfNorm::L2 => {
243 for mut row in result.rows_mut() {
244 let norm_sq: F = row.iter().map(|v| *v * *v).fold(F::zero(), |a, b| a + b);
245 let norm = norm_sq.sqrt();
246 if norm > F::zero() {
247 for v in &mut row {
248 *v = *v / norm;
249 }
250 }
251 }
252 }
253 TfidfNorm::None => {}
254 }
255
256 Ok(result)
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Absolute tolerance used by every floating-point comparison below.
    const EPS: f64 = 1e-10;

    /// Fits `transformer` on `counts` and transforms the same matrix.
    fn fit_transform(transformer: &TfidfTransformer<f64>, counts: &Array2<f64>) -> Array2<f64> {
        transformer
            .fit(counts)
            .unwrap()
            .transform(counts)
            .unwrap()
    }

    /// Asserts that every row of `matrix` has Euclidean length one.
    fn assert_unit_l2_rows(matrix: &Array2<f64>) {
        for row in matrix.rows() {
            let length = row.iter().map(|x| x * x).sum::<f64>().sqrt();
            assert_abs_diff_eq!(length, 1.0, epsilon = EPS);
        }
    }

    #[test]
    fn test_tfidf_basic() {
        let counts = array![[1.0_f64, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 0.0, 0.0]];
        let out = fit_transform(&TfidfTransformer::<f64>::new(), &counts);

        assert_eq!(out.shape(), &[3, 3]);
        // The default configuration normalises rows with L2.
        assert_unit_l2_rows(&out);
    }

    #[test]
    fn test_tfidf_no_idf() {
        let counts = array![[3.0_f64, 1.0], [0.0, 2.0]];
        let out = fit_transform(&TfidfTransformer::<f64>::new().use_idf(false), &counts);

        // L2 normalisation still applies without IDF weights.
        assert_unit_l2_rows(&out);
    }

    #[test]
    fn test_tfidf_l1_norm() {
        let counts = array![[3.0_f64, 1.0], [0.0, 2.0]];
        let config = TfidfTransformer::<f64>::new()
            .use_idf(false)
            .norm(TfidfNorm::L1);
        let out = fit_transform(&config, &counts);

        for row in out.rows() {
            let abs_sum: f64 = row.iter().map(|x| x.abs()).sum();
            assert_abs_diff_eq!(abs_sum, 1.0, epsilon = EPS);
        }
    }

    #[test]
    fn test_tfidf_no_norm() {
        let counts = array![[1.0_f64, 0.0], [1.0, 1.0]];
        let config = TfidfTransformer::<f64>::new()
            .use_idf(false)
            .norm(TfidfNorm::None);
        let out = fit_transform(&config, &counts);

        // With no IDF, no sublinear scaling and no normalisation the output
        // equals the input.
        for (expected, actual) in counts.iter().zip(out.iter()) {
            assert_abs_diff_eq!(expected, actual, epsilon = EPS);
        }
    }

    #[test]
    fn test_tfidf_sublinear_tf() {
        let counts = array![[4.0_f64, 1.0]];
        let config = TfidfTransformer::<f64>::new()
            .use_idf(false)
            .sublinear_tf(true)
            .norm(TfidfNorm::None);
        let out = fit_transform(&config, &counts);

        // tf = 4 -> 1 + ln(4); tf = 1 -> 1 + ln(1) = 1.
        assert_abs_diff_eq!(out[[0, 0]], 1.0 + 4.0_f64.ln(), epsilon = EPS);
        assert_abs_diff_eq!(out[[0, 1]], 1.0, epsilon = EPS);
    }

    #[test]
    fn test_tfidf_smooth_idf() {
        let counts = array![[1.0_f64, 1.0], [1.0, 0.0], [1.0, 0.0]];
        let fitted = TfidfTransformer::<f64>::new()
            .norm(TfidfNorm::None)
            .fit(&counts)
            .unwrap();
        let weights = fitted.idf().unwrap();

        // n = 3, df = 3: ln(4 / 4) + 1 = 1.
        assert_abs_diff_eq!(weights[0], 1.0, epsilon = EPS);
        // n = 3, df = 1: ln(4 / 2) + 1 = ln(2) + 1.
        assert_abs_diff_eq!(weights[1], 2.0_f64.ln() + 1.0, epsilon = EPS);
    }

    #[test]
    fn test_tfidf_no_smooth_idf() {
        let counts = array![[1.0_f64, 1.0], [1.0, 0.0], [1.0, 0.0]];
        let fitted = TfidfTransformer::<f64>::new()
            .smooth_idf(false)
            .norm(TfidfNorm::None)
            .fit(&counts)
            .unwrap();
        let weights = fitted.idf().unwrap();

        // n = 3, df = 3: ln(3 / 3) + 1 = 1.
        assert_abs_diff_eq!(weights[0], 1.0, epsilon = EPS);
        // n = 3, df = 1: ln(3 / 1) + 1 = ln(3) + 1.
        assert_abs_diff_eq!(weights[1], 3.0_f64.ln() + 1.0, epsilon = EPS);
    }

    #[test]
    fn test_tfidf_empty() {
        let no_documents = Array2::<f64>::zeros((0, 3));
        let outcome = TfidfTransformer::<f64>::new().fit(&no_documents);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_tfidf_shape_mismatch() {
        let train = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let fitted = TfidfTransformer::<f64>::new().fit(&train).unwrap();

        // Three columns against two learned IDF weights.
        let too_wide = array![[1.0_f64, 0.0, 0.0]];
        assert!(fitted.transform(&too_wide).is_err());
    }

    #[test]
    fn test_tfidf_f32() {
        let counts = array![[1.0_f32, 0.0], [0.0, 1.0]];
        let fitted = TfidfTransformer::<f32>::new().fit(&counts).unwrap();
        let out = fitted.transform(&counts).unwrap();
        assert_eq!(out.shape(), &[2, 2]);
    }
}