tfidf/
idf.rs

1// Copyright 2016 rust-tfidf Developers
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::borrow::Borrow;
9use std::collections::HashMap;
10use std::hash::Hash;
11
12use prelude::{ExpandableDocument, Idf, NaiveDocument, ProcessedDocument, SmoothingFactor};
13
14/// Unary weighting scheme for IDF. If the corpus contains a document with the
15/// term, returns 1, otherwise returns 0.
16#[derive(Copy, Clone)]
17pub struct UnaryIdf;
18
19impl<T> Idf<T> for UnaryIdf
20where
21  T: NaiveDocument,
22{
23  #[inline]
24  fn idf<'a, I, K>(term: K, docs: I) -> f64
25  where
26    I: Iterator<Item = &'a T>,
27    K: Borrow<T::Term>,
28    T: 'a,
29  {
30    docs.fold(0f64, |_, d| {
31      if d.term_exists(term.borrow()) {
32        1f64
33      } else {
34        0f64
35      }
36    })
37  }
38}
39
40/// Inverse frequency weighting scheme for IDF with a smoothing factor. Used
41/// internally as a marker trait.
42pub trait InverseFrequencySmoothedIdfStrategy: SmoothingFactor {}
43
44impl<S, T> Idf<T> for S
45where
46  S: InverseFrequencySmoothedIdfStrategy,
47  T: NaiveDocument,
48{
49  #[inline]
50  fn idf<'a, I, K>(term: K, docs: I) -> f64
51  where
52    I: Iterator<Item = &'a T>,
53    K: Borrow<T::Term>,
54    T: 'a,
55  {
56    let (num_docs, ttl_docs) = docs.fold((0f64, 0f64), |(n, t), d| {
57      (
58        if d.term_exists(term.borrow()) {
59          n + 1f64
60        } else {
61          n
62        },
63        t + 1f64,
64      )
65    });
66    (S::factor() + (ttl_docs as f64 / num_docs as f64)).ln()
67  }
68}
69
70/// Inverse frequency weighting scheme for IDF. Computes `log (N / nt)` where `N`
71/// is the number of documents, and `nt` is the number of times a term appears in
72/// the corpus of documents.
73#[derive(Copy, Clone)]
74pub struct InverseFrequencyIdf;
75
76impl SmoothingFactor for InverseFrequencyIdf {
77  fn factor() -> f64 {
78    0f64
79  }
80}
81
82impl InverseFrequencySmoothedIdfStrategy for InverseFrequencyIdf {}
83
84/// Inverse frequency weighting scheme for IDF. Computes `log (1 + (N / nt))`.
85#[derive(Copy, Clone)]
86pub struct InverseFrequencySmoothIdf;
87
88impl SmoothingFactor for InverseFrequencySmoothIdf {
89  fn factor() -> f64 {
90    1f64
91  }
92}
93
94impl InverseFrequencySmoothedIdfStrategy for InverseFrequencySmoothIdf {}
95
96/// Inverse frequency weighting scheme for IDF. Compute `log (1 + (max nt / nt))`
97/// where `nt` is the number of times a term appears in the corpus, and `max nt`
98/// returns the most number of times any term appears in the corpus.
99#[derive(Copy, Clone)]
100pub struct InverseFrequencyMaxIdf;
101
102impl<'l, T, E> Idf<T> for InverseFrequencyMaxIdf
103where
104  T: ProcessedDocument<Term = E> + ExpandableDocument<'l>,
105  E: Hash + Eq + 'l,
106{
107  #[inline]
108  fn idf<'a, I, K>(term: K, docs: I) -> f64
109  where
110    I: Iterator<Item = &'a T>,
111    K: Borrow<T::Term>,
112    T: 'a,
113  {
114    let mut counts: HashMap<&T::Term, usize> = HashMap::new();
115    let num_docs = docs.fold(0, |n, d| {
116      for t in d.terms() {
117        counts.insert(t, 0);
118      }
119
120      if d.term_exists(term.borrow()) {
121        n + 1
122      } else {
123        n
124      }
125    });
126    let max = *counts.values().max().unwrap_or(&1);
127
128    (1f64 + (max as f64 / num_docs as f64)).ln()
129  }
130}
131
#[test]
fn idf_wiki_example_tests() {
  // Two-document corpus from the Wikipedia tf-idf example; "this" occurs in
  // both documents.
  let docs = vec![
    vec![("this", 1), ("is", 1), ("a", 2), ("sample", 1)],
    vec![("this", 1), ("is", 1), ("another", 2), ("example", 3)],
  ];

  assert_eq!(UnaryIdf::idf("this", docs.iter()), 1f64);
  assert_eq!(InverseFrequencyIdf::idf("this", docs.iter()), 0f64);
}
142
#[test]
fn idf_wiki_example_tests_hashmap() {
  // Same corpus as `idf_wiki_example_tests`, stored as hash maps.
  let docs: Vec<std::collections::HashMap<&'static str, usize>> = vec![
    vec![("this", 1), ("is", 1), ("a", 2), ("sample", 1)],
    vec![("this", 1), ("is", 1), ("another", 2), ("example", 3)],
  ]
  .into_iter()
  .map(|pairs| pairs.into_iter().collect::<std::collections::HashMap<_, _>>())
  .collect();

  assert_eq!(UnaryIdf::idf("this", docs.iter()), 1f64);
  assert_eq!(InverseFrequencyIdf::idf("this", docs.iter()), 0f64);
}
161
#[test]
fn idf_wiki_example_tests_btreemap() {
  // Same corpus as `idf_wiki_example_tests`, stored as ordered maps.
  let docs: Vec<std::collections::BTreeMap<&'static str, usize>> = vec![
    vec![("this", 1), ("is", 1), ("a", 2), ("sample", 1)],
    vec![("this", 1), ("is", 1), ("another", 2), ("example", 3)],
  ]
  .into_iter()
  .map(|pairs| pairs.into_iter().collect::<std::collections::BTreeMap<_, _>>())
  .collect();

  assert_eq!(UnaryIdf::idf("this", docs.iter()), 1f64);
  assert_eq!(InverseFrequencyIdf::idf("this", docs.iter()), 0f64);
}