// ferrolearn_preprocess/tfidf.rs
//! TF-IDF transformer: weight a term-count matrix by inverse document frequency.
//!
//! Applies TF-IDF weighting to a term-count matrix produced by
//! [`CountVectorizer`](crate::count_vectorizer::CountVectorizer).

use ferrolearn_core::error::FerroError;
use ndarray::{Array1, Array2};
use num_traits::Float;

// ---------------------------------------------------------------------------
// TfidfNorm
// ---------------------------------------------------------------------------

/// Row-normalization mode applied to each document (row) of the output matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TfidfNorm {
    /// Normalize rows to unit L1 norm (absolute values of a row sum to 1).
    L1,
    /// Normalize rows to unit L2 norm (default; Euclidean length of a row is 1).
    #[default]
    L2,
    /// No row normalization.
    None,
}
25
// ---------------------------------------------------------------------------
// TfidfTransformer (unfitted)
// ---------------------------------------------------------------------------

/// An unfitted TF-IDF transformer.
///
/// Fits IDF weights from a term-count matrix and transforms new count
/// matrices into TF-IDF weighted matrices.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::tfidf::{TfidfTransformer, TfidfNorm};
/// use ndarray::array;
///
/// let counts = array![
///     [3.0_f64, 0.0, 1.0],
///     [2.0, 0.0, 0.0],
///     [3.0, 0.0, 0.0],
///     [4.0, 0.0, 0.0],
///     [3.0, 2.0, 0.0],
///     [3.0, 0.0, 2.0],
/// ];
/// let tfidf = TfidfTransformer::<f64>::new();
/// let fitted = tfidf.fit(&counts).unwrap();
/// let result = fitted.transform(&counts).unwrap();
/// assert_eq!(result.shape(), counts.shape());
/// ```
#[derive(Debug, Clone)]
pub struct TfidfTransformer<F> {
    /// Row normalization mode applied after weighting (see [`TfidfNorm`]).
    pub norm: TfidfNorm,
    /// Whether to multiply term frequencies by learned IDF weights.
    pub use_idf: bool,
    /// Whether to smooth IDF: `ln((1+n)/(1+df)) + 1` instead of `ln(n/df) + 1`.
    pub smooth_idf: bool,
    /// Whether to apply sublinear TF scaling: `1 + ln(tf)` for positive counts.
    pub sublinear_tf: bool,
    // Ties the element type `F` to the transformer without storing any data.
    _marker: std::marker::PhantomData<F>,
}
66
67impl<F: Float + Send + Sync + 'static> TfidfTransformer<F> {
68    /// Create a new `TfidfTransformer` with default settings.
69    #[must_use]
70    pub fn new() -> Self {
71        Self {
72            norm: TfidfNorm::L2,
73            use_idf: true,
74            smooth_idf: true,
75            sublinear_tf: false,
76            _marker: std::marker::PhantomData,
77        }
78    }
79
80    /// Set the row normalization mode.
81    #[must_use]
82    pub fn norm(mut self, norm: TfidfNorm) -> Self {
83        self.norm = norm;
84        self
85    }
86
87    /// Set whether to use IDF weighting.
88    #[must_use]
89    pub fn use_idf(mut self, use_idf: bool) -> Self {
90        self.use_idf = use_idf;
91        self
92    }
93
94    /// Set whether to smooth IDF.
95    #[must_use]
96    pub fn smooth_idf(mut self, smooth: bool) -> Self {
97        self.smooth_idf = smooth;
98        self
99    }
100
101    /// Set whether to apply sublinear TF scaling.
102    #[must_use]
103    pub fn sublinear_tf(mut self, sublinear: bool) -> Self {
104        self.sublinear_tf = sublinear;
105        self
106    }
107
108    /// Fit the transformer by computing IDF from a term-count matrix.
109    ///
110    /// # Errors
111    ///
112    /// Returns [`FerroError::InsufficientSamples`] if the matrix has zero rows.
113    pub fn fit(&self, counts: &Array2<F>) -> Result<FittedTfidfTransformer<F>, FerroError> {
114        let n_docs = counts.nrows();
115        if n_docs == 0 {
116            return Err(FerroError::InsufficientSamples {
117                required: 1,
118                actual: 0,
119                context: "TfidfTransformer::fit".into(),
120            });
121        }
122
123        let n_features = counts.ncols();
124        let n_f = F::from(n_docs).unwrap();
125
126        let idf = if self.use_idf {
127            let mut idf_vec = Array1::zeros(n_features);
128            for j in 0..n_features {
129                // df = number of documents where feature j is non-zero
130                let df = counts.column(j).iter().filter(|&&v| v > F::zero()).count();
131                let df_f = F::from(df).unwrap();
132
133                if self.smooth_idf {
134                    // idf = ln((1 + n) / (1 + df)) + 1
135                    idf_vec[j] = ((F::one() + n_f) / (F::one() + df_f)).ln() + F::one();
136                } else {
137                    // idf = ln(n / df) + 1
138                    if df > 0 {
139                        idf_vec[j] = (n_f / df_f).ln() + F::one();
140                    } else {
141                        idf_vec[j] = F::one();
142                    }
143                }
144            }
145            Some(idf_vec)
146        } else {
147            None
148        };
149
150        Ok(FittedTfidfTransformer {
151            idf,
152            norm: self.norm,
153            sublinear_tf: self.sublinear_tf,
154        })
155    }
156}
157
// The default transformer is identical to `TfidfTransformer::new()`:
// L2 norm, smoothed IDF, no sublinear TF.
impl<F: Float + Send + Sync + 'static> Default for TfidfTransformer<F> {
    fn default() -> Self {
        Self::new()
    }
}
163
// ---------------------------------------------------------------------------
// FittedTfidfTransformer
// ---------------------------------------------------------------------------

/// A fitted TF-IDF transformer holding learned IDF weights.
///
/// Created by calling [`TfidfTransformer::fit`].
#[derive(Debug, Clone)]
pub struct FittedTfidfTransformer<F> {
    /// Per-feature IDF weights; `None` when fitted with `use_idf == false`.
    idf: Option<Array1<F>>,
    /// Row normalization mode applied in `transform`.
    norm: TfidfNorm,
    /// Whether `transform` applies sublinear TF scaling before weighting.
    sublinear_tf: bool,
}
180
181impl<F: Float + Send + Sync + 'static> FittedTfidfTransformer<F> {
182    /// Return the IDF weights, if computed.
183    #[must_use]
184    pub fn idf(&self) -> Option<&Array1<F>> {
185        self.idf.as_ref()
186    }
187
188    /// Transform a term-count matrix into a TF-IDF matrix.
189    ///
190    /// # Errors
191    ///
192    /// Returns [`FerroError::ShapeMismatch`] if the number of columns does not
193    /// match the fitted vocabulary size.
194    /// Returns [`FerroError::InsufficientSamples`] if the matrix has zero rows.
195    pub fn transform(&self, counts: &Array2<F>) -> Result<Array2<F>, FerroError> {
196        if counts.nrows() == 0 {
197            return Err(FerroError::InsufficientSamples {
198                required: 1,
199                actual: 0,
200                context: "FittedTfidfTransformer::transform".into(),
201            });
202        }
203
204        if let Some(ref idf) = self.idf {
205            if counts.ncols() != idf.len() {
206                return Err(FerroError::ShapeMismatch {
207                    expected: vec![counts.nrows(), idf.len()],
208                    actual: vec![counts.nrows(), counts.ncols()],
209                    context: "FittedTfidfTransformer::transform".into(),
210                });
211            }
212        }
213
214        let mut result = counts.to_owned();
215
216        // Sublinear TF: replace tf with 1 + ln(tf) for tf > 0.
217        if self.sublinear_tf {
218            result.mapv_inplace(|v| if v > F::zero() { F::one() + v.ln() } else { v });
219        }
220
221        // Multiply by IDF.
222        if let Some(ref idf) = self.idf {
223            for mut row in result.rows_mut() {
224                for (j, v) in row.iter_mut().enumerate() {
225                    *v = *v * idf[j];
226                }
227            }
228        }
229
230        // Row normalization.
231        match self.norm {
232            TfidfNorm::L1 => {
233                for mut row in result.rows_mut() {
234                    let norm: F = row.iter().map(|v| v.abs()).fold(F::zero(), |a, b| a + b);
235                    if norm > F::zero() {
236                        for v in &mut row {
237                            *v = *v / norm;
238                        }
239                    }
240                }
241            }
242            TfidfNorm::L2 => {
243                for mut row in result.rows_mut() {
244                    let norm_sq: F = row.iter().map(|v| *v * *v).fold(F::zero(), |a, b| a + b);
245                    let norm = norm_sq.sqrt();
246                    if norm > F::zero() {
247                        for v in &mut row {
248                            *v = *v / norm;
249                        }
250                    }
251                }
252            }
253            TfidfNorm::None => {}
254        }
255
256        Ok(result)
257    }
258}
259
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // Default settings: smoothed IDF + L2 row normalization.
    #[test]
    fn test_tfidf_basic() {
        // 3 docs, 3 features
        let counts = array![[1.0_f64, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 0.0, 0.0],];
        let transformer = TfidfTransformer::<f64>::new();
        let fitted = transformer.fit(&counts).unwrap();
        let result = fitted.transform(&counts).unwrap();
        assert_eq!(result.shape(), &[3, 3]);

        // Each row should have L2 norm ≈ 1
        for i in 0..3 {
            let row_norm: f64 = result.row(i).iter().map(|v| v * v).sum::<f64>().sqrt();
            assert_abs_diff_eq!(row_norm, 1.0, epsilon = 1e-10);
        }
    }

    // With IDF disabled, transform reduces to row normalization alone.
    #[test]
    fn test_tfidf_no_idf() {
        let counts = array![[3.0_f64, 1.0], [0.0, 2.0]];
        let transformer = TfidfTransformer::<f64>::new().use_idf(false);
        let fitted = transformer.fit(&counts).unwrap();
        let result = fitted.transform(&counts).unwrap();
        // Should just normalize rows (L2)
        for i in 0..2 {
            let row_norm: f64 = result.row(i).iter().map(|v| v * v).sum::<f64>().sqrt();
            assert_abs_diff_eq!(row_norm, 1.0, epsilon = 1e-10);
        }
    }

    // L1 normalization: absolute values of each row sum to 1.
    #[test]
    fn test_tfidf_l1_norm() {
        let counts = array![[3.0_f64, 1.0], [0.0, 2.0]];
        let transformer = TfidfTransformer::<f64>::new()
            .use_idf(false)
            .norm(TfidfNorm::L1);
        let fitted = transformer.fit(&counts).unwrap();
        let result = fitted.transform(&counts).unwrap();
        for i in 0..2 {
            let row_l1: f64 = result.row(i).iter().map(|v| v.abs()).sum();
            assert_abs_diff_eq!(row_l1, 1.0, epsilon = 1e-10);
        }
    }

    // No IDF and no normalization: transform is the identity.
    #[test]
    fn test_tfidf_no_norm() {
        let counts = array![[1.0_f64, 0.0], [1.0, 1.0]];
        let transformer = TfidfTransformer::<f64>::new()
            .use_idf(false)
            .norm(TfidfNorm::None);
        let fitted = transformer.fit(&counts).unwrap();
        let result = fitted.transform(&counts).unwrap();
        // Should be unchanged
        for (a, b) in counts.iter().zip(result.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    // Sublinear TF maps tf -> 1 + ln(tf) for positive counts.
    #[test]
    fn test_tfidf_sublinear_tf() {
        let counts = array![[4.0_f64, 1.0]];
        let transformer = TfidfTransformer::<f64>::new()
            .use_idf(false)
            .sublinear_tf(true)
            .norm(TfidfNorm::None);
        let fitted = transformer.fit(&counts).unwrap();
        let result = fitted.transform(&counts).unwrap();
        // tf=4 -> 1+ln(4), tf=1 -> 1+ln(1) = 1
        assert_abs_diff_eq!(result[[0, 0]], 1.0 + 4.0_f64.ln(), epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 1]], 1.0, epsilon = 1e-10);
    }

    // Smoothed IDF formula: ln((1+n)/(1+df)) + 1.
    #[test]
    fn test_tfidf_smooth_idf() {
        // 3 docs, feature 0 in all docs, feature 1 in 1 doc
        let counts = array![[1.0_f64, 1.0], [1.0, 0.0], [1.0, 0.0]];
        let transformer = TfidfTransformer::<f64>::new().norm(TfidfNorm::None);
        let fitted = transformer.fit(&counts).unwrap();
        let idf = fitted.idf().unwrap();

        // idf[0]: ln((1+3)/(1+3)) + 1 = ln(1) + 1 = 1.0
        assert_abs_diff_eq!(idf[0], 1.0, epsilon = 1e-10);
        // idf[1]: ln((1+3)/(1+1)) + 1 = ln(2) + 1
        assert_abs_diff_eq!(idf[1], 2.0_f64.ln() + 1.0, epsilon = 1e-10);
    }

    // Unsmoothed IDF formula: ln(n/df) + 1.
    #[test]
    fn test_tfidf_no_smooth_idf() {
        let counts = array![[1.0_f64, 1.0], [1.0, 0.0], [1.0, 0.0]];
        let transformer = TfidfTransformer::<f64>::new()
            .smooth_idf(false)
            .norm(TfidfNorm::None);
        let fitted = transformer.fit(&counts).unwrap();
        let idf = fitted.idf().unwrap();

        // idf[0]: ln(3/3) + 1 = 1.0
        assert_abs_diff_eq!(idf[0], 1.0, epsilon = 1e-10);
        // idf[1]: ln(3/1) + 1
        assert_abs_diff_eq!(idf[1], 3.0_f64.ln() + 1.0, epsilon = 1e-10);
    }

    // Zero-row input is rejected at fit time.
    #[test]
    fn test_tfidf_empty() {
        let counts = Array2::<f64>::zeros((0, 3));
        let transformer = TfidfTransformer::<f64>::new();
        assert!(transformer.fit(&counts).is_err());
    }

    // Column count at transform time must match the fitted vocabulary size.
    #[test]
    fn test_tfidf_shape_mismatch() {
        let train = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let fitted = TfidfTransformer::<f64>::new().fit(&train).unwrap();
        let bad = array![[1.0_f64, 0.0, 0.0]];
        assert!(fitted.transform(&bad).is_err());
    }

    // The transformer is generic over the float type; f32 works too.
    #[test]
    fn test_tfidf_f32() {
        let counts = array![[1.0_f32, 0.0], [0.0, 1.0]];
        let transformer = TfidfTransformer::<f32>::new();
        let fitted = transformer.fit(&counts).unwrap();
        let result = fitted.transform(&counts).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
    }
}
393}