Skip to main content

ferrolearn_preprocess/
tfidf.rs

1//! TF-IDF transformer: weight a term-count matrix by inverse document frequency.
2//!
3//! Applies TF-IDF weighting to a term-count matrix produced by
4//! [`CountVectorizer`](crate::count_vectorizer::CountVectorizer).
5
6use ferrolearn_core::error::FerroError;
7use ndarray::{Array1, Array2};
8use num_traits::Float;
9
10// ---------------------------------------------------------------------------
11// TfidfNorm
12// ---------------------------------------------------------------------------
13
/// Per-row normalization applied as the last step of a TF-IDF transform.
///
/// [`TfidfNorm::L2`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TfidfNorm {
    /// Scale each row so the absolute values of its entries sum to one.
    L1,
    /// Scale each row to unit Euclidean length (the default).
    L2,
    /// Leave rows unscaled.
    None,
}

impl Default for TfidfNorm {
    /// Euclidean (L2) normalization.
    fn default() -> Self {
        Self::L2
    }
}
25
26// ---------------------------------------------------------------------------
27// TfidfTransformer (unfitted)
28// ---------------------------------------------------------------------------
29
30/// An unfitted TF-IDF transformer.
31///
32/// Fits IDF weights from a term-count matrix and transforms new count
33/// matrices into TF-IDF weighted matrices.
34///
35/// # Examples
36///
37/// ```
38/// use ferrolearn_preprocess::tfidf::{TfidfTransformer, TfidfNorm};
39/// use ndarray::array;
40///
41/// let counts = array![
42///     [3.0_f64, 0.0, 1.0],
43///     [2.0, 0.0, 0.0],
44///     [3.0, 0.0, 0.0],
45///     [4.0, 0.0, 0.0],
46///     [3.0, 2.0, 0.0],
47///     [3.0, 0.0, 2.0],
48/// ];
49/// let tfidf = TfidfTransformer::<f64>::new();
50/// let fitted = tfidf.fit(&counts).unwrap();
51/// let result = fitted.transform(&counts).unwrap();
52/// assert_eq!(result.shape(), counts.shape());
53/// ```
54#[derive(Debug, Clone)]
55pub struct TfidfTransformer<F> {
56    /// Row normalization mode.
57    pub norm: TfidfNorm,
58    /// Whether to use IDF weighting.
59    pub use_idf: bool,
60    /// Whether to smooth IDF: `ln((1+n)/(1+df)) + 1`.
61    pub smooth_idf: bool,
62    /// Whether to apply sublinear TF scaling: `1 + ln(tf)`.
63    pub sublinear_tf: bool,
64    _marker: std::marker::PhantomData<F>,
65}
66
67impl<F: Float + Send + Sync + 'static> TfidfTransformer<F> {
68    /// Create a new `TfidfTransformer` with default settings.
69    #[must_use]
70    pub fn new() -> Self {
71        Self {
72            norm: TfidfNorm::L2,
73            use_idf: true,
74            smooth_idf: true,
75            sublinear_tf: false,
76            _marker: std::marker::PhantomData,
77        }
78    }
79
80    /// Set the row normalization mode.
81    #[must_use]
82    pub fn norm(mut self, norm: TfidfNorm) -> Self {
83        self.norm = norm;
84        self
85    }
86
87    /// Set whether to use IDF weighting.
88    #[must_use]
89    pub fn use_idf(mut self, use_idf: bool) -> Self {
90        self.use_idf = use_idf;
91        self
92    }
93
94    /// Set whether to smooth IDF.
95    #[must_use]
96    pub fn smooth_idf(mut self, smooth: bool) -> Self {
97        self.smooth_idf = smooth;
98        self
99    }
100
101    /// Set whether to apply sublinear TF scaling.
102    #[must_use]
103    pub fn sublinear_tf(mut self, sublinear: bool) -> Self {
104        self.sublinear_tf = sublinear;
105        self
106    }
107
108    /// Fit the transformer by computing IDF from a term-count matrix.
109    ///
110    /// # Errors
111    ///
112    /// Returns [`FerroError::InsufficientSamples`] if the matrix has zero rows.
113    pub fn fit(&self, counts: &Array2<F>) -> Result<FittedTfidfTransformer<F>, FerroError> {
114        let n_docs = counts.nrows();
115        if n_docs == 0 {
116            return Err(FerroError::InsufficientSamples {
117                required: 1,
118                actual: 0,
119                context: "TfidfTransformer::fit".into(),
120            });
121        }
122
123        let n_features = counts.ncols();
124        let n_f = F::from(n_docs).unwrap();
125
126        let idf = if self.use_idf {
127            let mut idf_vec = Array1::zeros(n_features);
128            for j in 0..n_features {
129                // df = number of documents where feature j is non-zero
130                let df = counts
131                    .column(j)
132                    .iter()
133                    .filter(|&&v| v > F::zero())
134                    .count();
135                let df_f = F::from(df).unwrap();
136
137                if self.smooth_idf {
138                    // idf = ln((1 + n) / (1 + df)) + 1
139                    idf_vec[j] =
140                        ((F::one() + n_f) / (F::one() + df_f)).ln() + F::one();
141                } else {
142                    // idf = ln(n / df) + 1
143                    if df > 0 {
144                        idf_vec[j] = (n_f / df_f).ln() + F::one();
145                    } else {
146                        idf_vec[j] = F::one();
147                    }
148                }
149            }
150            Some(idf_vec)
151        } else {
152            None
153        };
154
155        Ok(FittedTfidfTransformer {
156            idf,
157            norm: self.norm,
158            sublinear_tf: self.sublinear_tf,
159        })
160    }
161}
162
163impl<F: Float + Send + Sync + 'static> Default for TfidfTransformer<F> {
164    fn default() -> Self {
165        Self::new()
166    }
167}
168
169// ---------------------------------------------------------------------------
170// FittedTfidfTransformer
171// ---------------------------------------------------------------------------
172
173/// A fitted TF-IDF transformer holding learned IDF weights.
174///
175/// Created by calling [`TfidfTransformer::fit`].
176#[derive(Debug, Clone)]
177pub struct FittedTfidfTransformer<F> {
178    /// Per-feature IDF weights, if `use_idf` was `true`.
179    idf: Option<Array1<F>>,
180    /// Row normalization mode.
181    norm: TfidfNorm,
182    /// Whether to apply sublinear TF.
183    sublinear_tf: bool,
184}
185
186impl<F: Float + Send + Sync + 'static> FittedTfidfTransformer<F> {
187    /// Return the IDF weights, if computed.
188    #[must_use]
189    pub fn idf(&self) -> Option<&Array1<F>> {
190        self.idf.as_ref()
191    }
192
193    /// Transform a term-count matrix into a TF-IDF matrix.
194    ///
195    /// # Errors
196    ///
197    /// Returns [`FerroError::ShapeMismatch`] if the number of columns does not
198    /// match the fitted vocabulary size.
199    /// Returns [`FerroError::InsufficientSamples`] if the matrix has zero rows.
200    pub fn transform(&self, counts: &Array2<F>) -> Result<Array2<F>, FerroError> {
201        if counts.nrows() == 0 {
202            return Err(FerroError::InsufficientSamples {
203                required: 1,
204                actual: 0,
205                context: "FittedTfidfTransformer::transform".into(),
206            });
207        }
208
209        if let Some(ref idf) = self.idf {
210            if counts.ncols() != idf.len() {
211                return Err(FerroError::ShapeMismatch {
212                    expected: vec![counts.nrows(), idf.len()],
213                    actual: vec![counts.nrows(), counts.ncols()],
214                    context: "FittedTfidfTransformer::transform".into(),
215                });
216            }
217        }
218
219        let mut result = counts.to_owned();
220
221        // Sublinear TF: replace tf with 1 + ln(tf) for tf > 0.
222        if self.sublinear_tf {
223            result.mapv_inplace(|v| {
224                if v > F::zero() {
225                    F::one() + v.ln()
226                } else {
227                    v
228                }
229            });
230        }
231
232        // Multiply by IDF.
233        if let Some(ref idf) = self.idf {
234            for mut row in result.rows_mut() {
235                for (j, v) in row.iter_mut().enumerate() {
236                    *v = *v * idf[j];
237                }
238            }
239        }
240
241        // Row normalization.
242        match self.norm {
243            TfidfNorm::L1 => {
244                for mut row in result.rows_mut() {
245                    let norm: F = row.iter().map(|v| v.abs()).fold(F::zero(), |a, b| a + b);
246                    if norm > F::zero() {
247                        for v in row.iter_mut() {
248                            *v = *v / norm;
249                        }
250                    }
251                }
252            }
253            TfidfNorm::L2 => {
254                for mut row in result.rows_mut() {
255                    let norm_sq: F = row.iter().map(|v| *v * *v).fold(F::zero(), |a, b| a + b);
256                    let norm = norm_sq.sqrt();
257                    if norm > F::zero() {
258                        for v in row.iter_mut() {
259                            *v = *v / norm;
260                        }
261                    }
262                }
263            }
264            TfidfNorm::None => {}
265        }
266
267        Ok(result)
268    }
269}
270
271// ---------------------------------------------------------------------------
272// Tests
273// ---------------------------------------------------------------------------
274
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{array, ArrayView1};

    /// Sum of the absolute values in a row.
    fn l1_norm(row: ArrayView1<'_, f64>) -> f64 {
        row.iter().map(|v| v.abs()).sum()
    }

    /// Euclidean length of a row.
    fn l2_norm(row: ArrayView1<'_, f64>) -> f64 {
        row.iter().map(|v| v * v).sum::<f64>().sqrt()
    }

    #[test]
    fn test_tfidf_basic() {
        // Three documents over a three-term vocabulary.
        let counts = array![[1.0_f64, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 0.0, 0.0]];
        let result = TfidfTransformer::<f64>::new()
            .fit(&counts)
            .unwrap()
            .transform(&counts)
            .unwrap();
        assert_eq!(result.shape(), &[3, 3]);

        // The default L2 norm gives every row unit length.
        for row in result.rows() {
            assert_abs_diff_eq!(l2_norm(row), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_tfidf_no_idf() {
        let counts = array![[3.0_f64, 1.0], [0.0, 2.0]];
        let fitted = TfidfTransformer::<f64>::new()
            .use_idf(false)
            .fit(&counts)
            .unwrap();
        let result = fitted.transform(&counts).unwrap();

        // Without IDF the transform reduces to plain L2 row normalization.
        for row in result.rows() {
            assert_abs_diff_eq!(l2_norm(row), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_tfidf_l1_norm() {
        let counts = array![[3.0_f64, 1.0], [0.0, 2.0]];
        let fitted = TfidfTransformer::<f64>::new()
            .norm(TfidfNorm::L1)
            .use_idf(false)
            .fit(&counts)
            .unwrap();
        let result = fitted.transform(&counts).unwrap();

        for row in result.rows() {
            assert_abs_diff_eq!(l1_norm(row), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_tfidf_no_norm() {
        let counts = array![[1.0_f64, 0.0], [1.0, 1.0]];
        let fitted = TfidfTransformer::<f64>::new()
            .norm(TfidfNorm::None)
            .use_idf(false)
            .fit(&counts)
            .unwrap();
        let result = fitted.transform(&counts).unwrap();

        // With neither IDF nor normalization the counts pass through untouched.
        for (&expected, &actual) in counts.iter().zip(result.iter()) {
            assert_abs_diff_eq!(expected, actual, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_tfidf_sublinear_tf() {
        let counts = array![[4.0_f64, 1.0]];
        let fitted = TfidfTransformer::<f64>::new()
            .sublinear_tf(true)
            .use_idf(false)
            .norm(TfidfNorm::None)
            .fit(&counts)
            .unwrap();
        let result = fitted.transform(&counts).unwrap();

        // tf = 4 maps to 1 + ln 4; tf = 1 maps to 1 + ln 1 = 1.
        let expected = [1.0 + 4.0_f64.ln(), 1.0];
        for (col, want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(result[[0, col]], *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_tfidf_smooth_idf() {
        // Term 0 appears in all three documents, term 1 in just one.
        let counts = array![[1.0_f64, 1.0], [1.0, 0.0], [1.0, 0.0]];
        let fitted = TfidfTransformer::<f64>::new()
            .norm(TfidfNorm::None)
            .fit(&counts)
            .unwrap();
        let idf = fitted.idf().expect("IDF is enabled by default");

        // ln((1 + 3) / (1 + 3)) + 1 = 1
        assert_abs_diff_eq!(idf[0], 1.0, epsilon = 1e-10);
        // ln((1 + 3) / (1 + 1)) + 1 = ln 2 + 1
        assert_abs_diff_eq!(idf[1], 1.0 + 2.0_f64.ln(), epsilon = 1e-10);
    }

    #[test]
    fn test_tfidf_no_smooth_idf() {
        let counts = array![[1.0_f64, 1.0], [1.0, 0.0], [1.0, 0.0]];
        let fitted = TfidfTransformer::<f64>::new()
            .norm(TfidfNorm::None)
            .smooth_idf(false)
            .fit(&counts)
            .unwrap();
        let idf = fitted.idf().expect("IDF is enabled by default");

        // ln(3 / 3) + 1 = 1
        assert_abs_diff_eq!(idf[0], 1.0, epsilon = 1e-10);
        // ln(3 / 1) + 1
        assert_abs_diff_eq!(idf[1], 1.0 + 3.0_f64.ln(), epsilon = 1e-10);
    }

    #[test]
    fn test_tfidf_empty() {
        let no_docs = Array2::<f64>::zeros((0, 3));
        let outcome = TfidfTransformer::<f64>::new().fit(&no_docs);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_tfidf_shape_mismatch() {
        let train = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let fitted = TfidfTransformer::<f64>::new().fit(&train).unwrap();

        // A three-column matrix cannot be scored against a two-term vocabulary.
        let too_wide = array![[1.0_f64, 0.0, 0.0]];
        assert!(fitted.transform(&too_wide).is_err());
    }

    #[test]
    fn test_tfidf_f32() {
        let counts = array![[1.0_f32, 0.0], [0.0, 1.0]];
        let result = TfidfTransformer::<f32>::new()
            .fit(&counts)
            .unwrap()
            .transform(&counts)
            .unwrap();
        assert_eq!(result.shape(), &[2, 2]);
    }
}