use crate::countgrams::{CountVectorizer, CountVectorizerParams};
use crate::error::Result;
use encoding::types::EncodingRef;
use encoding::DecoderTrap;
use ndarray::{Array1, ArrayBase, Data, Ix1};
use sprs::CsMat;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
/// Formula used to derive the inverse document frequency (idf) weight from
/// the total number of documents `n` and a token's document frequency `df`.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum TfIdfMethod {
    /// `ln((1 + n) / (1 + df)) + 1` — smoothed variant, defined for `df = 0`.
    Smooth,
    /// `ln(n / df) + 1` — unsmoothed variant.
    NonSmooth,
    /// `ln(n / (1 + df))` — textbook definition, without the `+ 1` offset.
    Textbook,
}

impl TfIdfMethod {
    /// Computes the idf weight for one vocabulary entry.
    ///
    /// * `n`  - total number of documents in the corpus
    /// * `df` - number of documents that contain the entry
    pub fn compute_idf(&self, n: usize, df: usize) -> f64 {
        // Shadow with float versions once, instead of casting inline per arm.
        let (n, df) = (n as f64, df as f64);
        match self {
            TfIdfMethod::Smooth => ((1. + n) / (1. + df)).ln() + 1.,
            TfIdfMethod::NonSmooth => (n / df).ln() + 1.,
            TfIdfMethod::Textbook => (n / (1. + df)).ln(),
        }
    }
}
/// Builder-style configuration for tf-idf vectorization: wraps the parameters
/// of the underlying count-vectorization step plus the idf formula to apply.
/// Call one of the `fit*` methods to obtain a [`FittedTfIdfVectorizer`].
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
pub struct TfIdfVectorizer {
// Configuration forwarded to the count-vectorization step.
count_vectorizer: CountVectorizerParams,
// Which idf formula is applied to the raw term counts.
method: TfIdfMethod,
}
impl std::default::Default for TfIdfVectorizer {
fn default() -> Self {
Self {
count_vectorizer: CountVectorizerParams::default(),
method: TfIdfMethod::Smooth,
}
}
}
impl TfIdfVectorizer {
pub fn convert_to_lowercase(self, convert_to_lowercase: bool) -> Self {
Self {
count_vectorizer: self
.count_vectorizer
.convert_to_lowercase(convert_to_lowercase),
method: self.method,
}
}
pub fn split_regex(self, regex_str: &str) -> Self {
Self {
count_vectorizer: self.count_vectorizer.split_regex(regex_str),
method: self.method,
}
}
pub fn n_gram_range(self, min_n: usize, max_n: usize) -> Self {
Self {
count_vectorizer: self.count_vectorizer.n_gram_range(min_n, max_n),
method: self.method,
}
}
pub fn normalize(self, normalize: bool) -> Self {
Self {
count_vectorizer: self.count_vectorizer.normalize(normalize),
method: self.method,
}
}
pub fn document_frequency(self, min_freq: f32, max_freq: f32) -> Self {
Self {
count_vectorizer: self.count_vectorizer.document_frequency(min_freq, max_freq),
method: self.method,
}
}
pub fn stopwords<T: ToString>(self, stopwords: &[T]) -> Self {
Self {
count_vectorizer: self.count_vectorizer.stopwords(stopwords),
method: self.method,
}
}
pub fn fit<T: ToString + Clone, D: Data<Elem = T>>(
&self,
x: &ArrayBase<D, Ix1>,
) -> Result<FittedTfIdfVectorizer> {
let fitted_vectorizer = self.count_vectorizer.fit(x)?;
Ok(FittedTfIdfVectorizer {
fitted_vectorizer,
method: self.method.clone(),
})
}
pub fn fit_vocabulary<T: ToString>(&self, words: &[T]) -> Result<FittedTfIdfVectorizer> {
let fitted_vectorizer = self.count_vectorizer.fit_vocabulary(words)?;
Ok(FittedTfIdfVectorizer {
fitted_vectorizer,
method: self.method.clone(),
})
}
pub fn fit_files<P: AsRef<std::path::Path>>(
&self,
input: &[P],
encoding: EncodingRef,
trap: DecoderTrap,
) -> Result<FittedTfIdfVectorizer> {
let fitted_vectorizer = self.count_vectorizer.fit_files(input, encoding, trap)?;
Ok(FittedTfIdfVectorizer {
fitted_vectorizer,
method: self.method.clone(),
})
}
}
/// A tf-idf vectorizer that has been fitted on a corpus (or vocabulary) and is
/// ready to transform documents into tf-idf weighted sparse matrices.
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
pub struct FittedTfIdfVectorizer {
// The fitted count vectorizer providing term and document frequencies.
fitted_vectorizer: CountVectorizer,
// Which idf formula is applied to the raw term counts.
method: TfIdfMethod,
}
impl FittedTfIdfVectorizer {
    /// Number of entries in the fitted vocabulary.
    pub fn nentries(&self) -> usize {
        self.fitted_vectorizer.vocabulary.len()
    }

    /// The fitted vocabulary, one string per entry.
    pub fn vocabulary(&self) -> &Vec<String> {
        self.fitted_vectorizer.vocabulary()
    }

    /// The idf formula this vectorizer applies.
    pub fn method(&self) -> &TfIdfMethod {
        &self.method
    }

    /// Transforms the documents in `x` into a sparse tf-idf matrix with one
    /// row per document and one column per vocabulary entry.
    pub fn transform<T: ToString, D: Data<Elem = T>>(&self, x: &ArrayBase<D, Ix1>) -> CsMat<f64> {
        let (tf, df) = self.fitted_vectorizer.get_term_and_document_frequencies(x);
        self.apply_tf_idf(tf, df)
    }

    /// Like `transform`, but reads the documents from `input` files decoded
    /// with `encoding` and the given decoder `trap`.
    pub fn transform_files<P: AsRef<std::path::Path>>(
        &self,
        input: &[P],
        encoding: EncodingRef,
        trap: DecoderTrap,
    ) -> CsMat<f64> {
        let (tf, df) = self
            .fitted_vectorizer
            .get_term_and_document_frequencies_files(input, encoding, trap);
        self.apply_tf_idf(tf, df)
    }

    /// Scales every raw term count by the inverse document frequency of its
    /// vocabulary entry, as given by `self.method`.
    fn apply_tf_idf(&self, term_freqs: CsMat<usize>, doc_freqs: Array1<usize>) -> CsMat<f64> {
        // One matrix row per document, so `rows()` is the corpus size `n`.
        let n_documents = term_freqs.rows();
        let inv_doc_freqs =
            doc_freqs.mapv(|df| self.method.compute_idf(n_documents, df));
        let mut weighted: CsMat<f64> = term_freqs.map(|&count| count as f64);
        for mut doc_row in weighted.outer_iterator_mut() {
            for (term_index, weight) in doc_row.iter_mut() {
                *weight *= inv_doc_freqs[term_index];
            }
        }
        weighted
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::column_for_word;
    use approx::assert_abs_diff_eq;
    use ndarray::array;
    use std::fs::File;
    use std::io::Write;

    /// Asserts, for each `(word, counts)` pair, that the tf-idf column for
    /// `word` in the transformed matrix matches `counts` within `1e-3`.
    macro_rules! assert_tf_idfs_for_word {
        ($voc:expr, $transf:expr, $(($word:expr, $counts:expr)),*) => {
            $(
                assert_abs_diff_eq!(column_for_word!($voc, $transf, $word), $counts, epsilon=1e-3);
            )*
        }
    }

    /// Removes the fixture files when dropped, so they are cleaned up even if
    /// a test assertion panics midway. (Previously cleanup only ran when the
    /// test reached its final statement, leaking files on failure.)
    struct FixtureGuard(Vec<&'static str>);

    impl Drop for FixtureGuard {
        fn drop(&mut self) {
            for f_name in &self.0 {
                // Best-effort removal: the file may already be gone.
                let _ = std::fs::remove_file(f_name);
            }
        }
    }

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<TfIdfMethod>();
    }

    #[test]
    fn test_tf_idf() {
        let texts = array![
            "one and two and three",
            "three and four and five",
            "seven and eight",
            "maybe ten and eleven",
            "avoid singletons: one two four five seven eight ten eleven and an and"
        ];
        let vectorizer = TfIdfVectorizer::default().fit(&texts).unwrap();
        let vocabulary = vectorizer.vocabulary();
        let transformed = vectorizer.transform(&texts).to_dense();
        assert_eq!(transformed.dim(), (texts.len(), vocabulary.len()));
        assert_tf_idfs_for_word!(
            vocabulary,
            transformed,
            ("one", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("two", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("three", array![1.693, 1.693, 0.0, 0.0, 0.0]),
            ("four", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("and", array![2.0, 2.0, 1.0, 1.0, 2.0]),
            ("five", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("seven", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("eight", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("ten", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("eleven", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("an", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("avoid", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("singletons", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("maybe", array![0.0, 0.0, 0.0, 2.098, 0.0])
        );
    }

    #[test]
    fn test_tf_idf_files() {
        let text_files = create_test_files();
        // The guard deletes the fixtures on drop, even if an assertion below panics.
        let _guard = FixtureGuard(text_files.clone());
        let vectorizer = TfIdfVectorizer::default()
            .fit_files(
                &text_files,
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let transformed = vectorizer
            .transform_files(
                &text_files,
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .to_dense();
        assert_eq!(transformed.dim(), (text_files.len(), vocabulary.len()));
        assert_tf_idfs_for_word!(
            vocabulary,
            transformed,
            ("one", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("two", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("three", array![1.693, 1.693, 0.0, 0.0, 0.0]),
            ("four", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("and", array![2.0, 2.0, 1.0, 1.0, 2.0]),
            ("five", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("seven", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("eight", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("ten", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("eleven", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("an", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("avoid", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("singletons", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("maybe", array![0.0, 0.0, 0.0, 2.098, 0.0])
        );
        // Cleanup is handled by `_guard` when it goes out of scope.
    }

    /// Writes five small fixture files into the working directory and returns
    /// their paths. Callers are responsible for cleanup (see [`FixtureGuard`]).
    fn create_test_files() -> Vec<&'static str> {
        let file_names = vec![
            "./tf_idf_vectorization_test_file_1",
            "./tf_idf_vectorization_test_file_2",
            "./tf_idf_vectorization_test_file_3",
            "./tf_idf_vectorization_test_file_4",
            "./tf_idf_vectorization_test_file_5",
        ];
        let contents = vec![
            "one and two and three",
            "three and four and five",
            "seven and eight",
            "maybe ten and eleven",
            "avoid singletons: one two four five seven eight ten eleven and an and",
        ];
        for (f_name, f_content) in file_names.iter().zip(contents.iter()) {
            let mut file = File::create(f_name).unwrap();
            file.write_all(f_content.as_bytes()).unwrap();
        }
        file_names
    }
}