use crate::PreprocessingError;
use linfa::ParamGuard;
use regex::Regex;
use std::cell::{Ref, RefCell};
use std::collections::HashSet;
/// Hyper-parameters of the count vectorizer in their validated form.
///
/// This is the `Checked` type of the `ParamGuard` implementation for
/// `CountVectorizerParams`. All fields are private and are read through the
/// accessor methods of this type.
#[derive(Clone, Debug)]
pub struct CountVectorizerValidParams {
/// Lowercase-conversion flag. Presumably applied to documents before
/// tokenization; the consumer is not visible in this part of the file.
convert_to_lowercase: bool,
/// Source text of the regular expression; compiled into `split_regex` by
/// `ParamGuard::check_ref`.
split_regex_expr: String,
/// Cache for the compiled form of `split_regex_expr`. Starts out as `None` and is
/// filled in by `ParamGuard::check_ref`; interior mutability is needed because that
/// method only receives `&self`.
split_regex: RefCell<Option<Regex>>,
/// `(min, max)` n-gram sizes; validated by `ParamGuard::check_ref`.
n_gram_range: (usize, usize),
/// Normalization flag. NOTE(review): the exact normalization performed is decided by
/// the consumer of these parameters, which is not visible here.
normalize: bool,
/// `(min, max)` document-frequency bounds; validated by `ParamGuard::check_ref`.
document_frequency: (f32, f32),
/// Optional set of stopwords; `None` when no stopword list was supplied.
stopwords: Option<HashSet<String>>,
}
/// Read-only accessors for the validated hyper-parameters.
impl CountVectorizerValidParams {
    /// Returns the lowercase-conversion flag.
    pub fn convert_to_lowercase(&self) -> bool {
        let Self {
            convert_to_lowercase: flag,
            ..
        } = self;
        *flag
    }

    /// Returns a shared borrow of the compiled split regex.
    ///
    /// # Panics
    ///
    /// Panics if the regex cache is still empty (i.e. the parameters were never run
    /// through `ParamGuard::check_ref`), or if the cache is mutably borrowed at the
    /// time of the call.
    pub fn split_regex(&self) -> Ref<'_, Regex> {
        let cache = self.split_regex.borrow();
        // Project the `Ref<Option<Regex>>` onto the contained `Regex`.
        Ref::map(cache, |slot| slot.as_ref().unwrap())
    }

    /// Returns the `(min, max)` n-gram sizes.
    pub fn n_gram_range(&self) -> (usize, usize) {
        let (min_n, max_n) = self.n_gram_range;
        (min_n, max_n)
    }

    /// Returns the normalization flag.
    pub fn normalize(&self) -> bool {
        let Self {
            normalize: flag, ..
        } = self;
        *flag
    }

    /// Returns the `(min, max)` document-frequency bounds.
    pub fn document_frequency(&self) -> (f32, f32) {
        let (min_freq, max_freq) = self.document_frequency;
        (min_freq, max_freq)
    }

    /// Returns the optional stopword set (`None` when no list was supplied).
    pub fn stopwords(&self) -> &Option<HashSet<String>> {
        let Self { stopwords, .. } = self;
        stopwords
    }
}
/// Unvalidated count-vectorizer hyper-parameters.
///
/// Thin wrapper around `CountVectorizerValidParams`: values are set through the
/// builder-style methods of this type and only become accessible as validated
/// parameters after passing `ParamGuard::check_ref` / `ParamGuard::check`.
#[derive(Clone, Debug)]
pub struct CountVectorizerParams(CountVectorizerValidParams);
impl Default for CountVectorizerParams {
    /// Builds the default parameter set:
    ///
    /// * lowercase conversion enabled,
    /// * split regex `\b\w\w+\b` (runs of at least two word characters between word
    ///   boundaries), not yet compiled,
    /// * n-gram range `(1, 1)`,
    /// * normalization enabled,
    /// * document-frequency bounds `(0., 1.)`,
    /// * no stopwords.
    fn default() -> Self {
        let unchecked = CountVectorizerValidParams {
            convert_to_lowercase: true,
            split_regex_expr: String::from(r"\b\w\w+\b"),
            // The regex is compiled lazily, during parameter validation.
            split_regex: RefCell::new(None),
            n_gram_range: (1, 1),
            normalize: true,
            document_frequency: (0., 1.),
            stopwords: None,
        };
        Self(unchecked)
    }
}
impl CountVectorizerParams {
pub fn convert_to_lowercase(mut self, convert_to_lowercase: bool) -> Self {
self.0.convert_to_lowercase = convert_to_lowercase;
self
}
pub fn split_regex(mut self, regex_str: &str) -> Self {
self.0.split_regex_expr = regex_str.to_string();
self
}
pub fn n_gram_range(mut self, min_n: usize, max_n: usize) -> Self {
self.0.n_gram_range = (min_n, max_n);
self
}
pub fn normalize(mut self, normalize: bool) -> Self {
self.0.normalize = normalize;
self
}
pub fn document_frequency(mut self, min_freq: f32, max_freq: f32) -> Self {
self.0.document_frequency = (min_freq, max_freq);
self
}
pub fn stopwords<T: ToString>(mut self, stopwords: &[T]) -> Self {
self.0.stopwords = Some(stopwords.iter().map(|t| t.to_string()).collect());
self
}
}
/// Validation of [`CountVectorizerParams`] into [`CountVectorizerValidParams`].
impl ParamGuard for CountVectorizerParams {
    type Checked = CountVectorizerValidParams;
    type Error = PreprocessingError;

    /// Validates the parameters and, on success, compiles the split regex and stores
    /// it in the parameter set so that `CountVectorizerValidParams::split_regex` can
    /// hand it out.
    ///
    /// # Errors
    ///
    /// Checks are performed in this order; the first failing one is reported:
    ///
    /// * `InvalidNGramBoundaries` if either n-gram bound is zero,
    /// * `FlippedNGramBoundaries` if the minimum n-gram size exceeds the maximum,
    /// * `InvalidDocumentFrequencies` if either frequency bound lies outside `[0, 1]`
    ///   (this includes NaN),
    /// * `FlippedDocumentFrequencies` if the maximum frequency is below the minimum,
    /// * the error converted from `Regex::new` if the split expression does not compile.
    ///
    /// # Panics
    ///
    /// Panics if the regex cache is currently borrowed, e.g. while a guard returned by
    /// `CountVectorizerValidParams::split_regex` is still alive.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        let (n_gram_min, n_gram_max) = self.0.n_gram_range;
        let (min_freq, max_freq) = self.0.document_frequency;

        // `contains` is false for NaN, so a NaN bound is rejected here rather than
        // slipping through plain `<` comparisons (which are all false for NaN).
        let is_fraction = |freq: f32| (0.0..=1.0).contains(&freq);

        if n_gram_min == 0 || n_gram_max == 0 {
            return Err(PreprocessingError::InvalidNGramBoundaries(
                n_gram_min, n_gram_max,
            ));
        }
        if n_gram_min > n_gram_max {
            return Err(PreprocessingError::FlippedNGramBoundaries(
                n_gram_min, n_gram_max,
            ));
        }
        if !is_fraction(min_freq) || !is_fraction(max_freq) {
            return Err(PreprocessingError::InvalidDocumentFrequencies(
                min_freq, max_freq,
            ));
        }
        if max_freq < min_freq {
            return Err(PreprocessingError::FlippedDocumentFrequencies(
                min_freq, max_freq,
            ));
        }

        // Compile first so that a bad expression leaves any previously cached regex
        // untouched, then publish the result through the `RefCell`.
        let compiled = Regex::new(&self.0.split_regex_expr)?;
        *self.0.split_regex.borrow_mut() = Some(compiled);
        Ok(&self.0)
    }

    /// Consuming variant of [`check_ref`](ParamGuard::check_ref): validates the
    /// parameters and returns the inner validated set by value.
    ///
    /// # Errors
    ///
    /// Fails under exactly the same conditions as `check_ref`.
    fn check(self) -> Result<Self::Checked, Self::Error> {
        self.check_ref()?;
        Ok(self.0)
    }
}