use crate::PreprocessingError;
use linfa::ParamGuard;
use regex::Regex;
use std::cell::{Ref, RefCell};
use std::collections::HashSet;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use super::{Tokenizer, Tokenizerfp};
/// Compiled-regex wrapper used when the `serde` feature is disabled.
///
/// In this configuration the wrapper holds the `Regex` directly.
#[derive(Clone, Debug)]
#[cfg(not(feature = "serde"))]
struct SerdeRegex(Regex);
/// Compiled-regex wrapper used when the `serde` feature is enabled.
///
/// The `Regex` is stored inside `serde_regex::Serde`, which is the field type the derived
/// `Serialize`/`Deserialize` implementations of this wrapper go through.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(crate = "serde_crate")]
#[cfg(feature = "serde")]
struct SerdeRegex(serde_regex::Serde<Regex>);
#[cfg(not(feature = "serde"))]
impl SerdeRegex {
fn new(re: &str) -> Result<Self, regex::Error> {
Ok(Self(Regex::new(re)?))
}
fn as_re(&self) -> &Regex {
&self.0
}
}
#[cfg(feature = "serde")]
impl SerdeRegex {
    /// Compiles `re` and wraps the resulting `Regex` in `serde_regex::Serde`.
    ///
    /// # Errors
    ///
    /// Returns the `regex::Error` produced by `Regex::new` when `re` does not compile.
    fn new(re: &str) -> Result<Self, regex::Error> {
        Regex::new(re).map(|compiled| Self(serde_regex::Serde(compiled)))
    }

    /// Borrows the wrapped `Regex`.
    fn as_re(&self) -> &Regex {
        // `serde_regex::Serde<Regex>` dereferences to `Regex`, so the borrow of the
        // wrapper coerces to `&Regex` at the return site.
        let Self(wrapped) = self;
        wrapped
    }
}
/// Parameter set of the count vectorizer after validation.
///
/// Values of this type are handed out by the `ParamGuard` implementation of
/// `CountVectorizerParams` (`check_ref` / `check`), which verifies the n-gram range and the
/// document-frequency bounds and compiles `split_regex_expr` into `split_regex`.
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
pub struct CountVectorizerValidParams {
/// Lowercasing flag; `true` in `CountVectorizerParams::default()`.
convert_to_lowercase: bool,
/// Source text of the splitting regex; the default is `\b\w\w+\b`.
split_regex_expr: String,
/// Compiled form of `split_regex_expr`. `None` until `check_ref` stores the compiled
/// regex; kept in a `RefCell` so it can be filled in through a shared reference.
split_regex: RefCell<Option<SerdeRegex>>,
/// `(min_n, max_n)` n-gram bounds. `check_ref` rejects a zero bound and `min_n > max_n`.
n_gram_range: (usize, usize),
/// Normalization flag; `true` in `CountVectorizerParams::default()`.
normalize: bool,
/// `(min_freq, max_freq)` document-frequency bounds. `check_ref` rejects negative bounds
/// and `max_freq < min_freq`.
document_frequency: (f32, f32),
/// Optional set of stopwords; `None` by default.
stopwords: Option<HashSet<String>>,
/// Optional `max_features` setting; `None` by default.
max_features: Option<usize>,
/// Optional tokenizer function. Skipped by serde when the `serde` feature is enabled, so
/// it is not part of the serialized form.
#[cfg_attr(feature = "serde", serde(skip))]
pub(crate) tokenizer_function: Option<Tokenizerfp>,
/// Set to `true` by `CountVectorizerParams::tokenizer` when a function tokenizer is
/// installed and to `false` when a regex tokenizer is selected.
pub(crate) tokenizer_deserialization_guard: bool,
}
impl CountVectorizerValidParams {
    /// Returns the tokenizer function, or `None` when no function is set.
    pub fn tokenizer_function(&self) -> Option<Tokenizerfp> {
        let Self {
            tokenizer_function, ..
        } = self;
        *tokenizer_function
    }

    /// Returns the `max_features` setting.
    pub fn max_features(&self) -> Option<usize> {
        let Self { max_features, .. } = self;
        *max_features
    }

    /// Returns the lowercasing flag.
    pub fn convert_to_lowercase(&self) -> bool {
        let Self {
            convert_to_lowercase,
            ..
        } = self;
        *convert_to_lowercase
    }

    /// Borrows the compiled splitting regex out of its `RefCell`.
    ///
    /// # Panics
    ///
    /// Panics if no compiled regex is stored (the cell holds `None`), or if the cell is
    /// mutably borrowed at the time of the call.
    pub fn split_regex(&self) -> Ref<'_, Regex> {
        let slot = self.split_regex.borrow();
        Ref::map(slot, |compiled| compiled.as_ref().unwrap().as_re())
    }

    /// Returns the `(min_n, max_n)` n-gram bounds.
    pub fn n_gram_range(&self) -> (usize, usize) {
        let (min_n, max_n) = self.n_gram_range;
        (min_n, max_n)
    }

    /// Returns the normalization flag.
    pub fn normalize(&self) -> bool {
        let Self { normalize, .. } = self;
        *normalize
    }

    /// Returns the `(min_freq, max_freq)` document-frequency bounds.
    pub fn document_frequency(&self) -> (f32, f32) {
        let (min_freq, max_freq) = self.document_frequency;
        (min_freq, max_freq)
    }

    /// Borrows the optional stopword set.
    pub fn stopwords(&self) -> &Option<HashSet<String>> {
        let Self { stopwords, .. } = self;
        stopwords
    }
}
/// Unvalidated count-vectorizer parameters.
///
/// Thin wrapper around `CountVectorizerValidParams` whose builder methods set the individual
/// fields. The wrapped value is only exposed after the `ParamGuard` check (`check_ref` /
/// `check`) has accepted it.
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
pub struct CountVectorizerParams(CountVectorizerValidParams);
impl std::default::Default for CountVectorizerParams {
    /// Builds the default parameter set.
    ///
    /// Defaults: lowercasing on, splitting regex `\b\w\w+\b` (not yet compiled), n-gram range
    /// `(1, 1)`, normalization on, document-frequency bounds `(0, 1)`, no stopwords, no
    /// `max_features` setting, no tokenizer function and the deserialization guard cleared.
    fn default() -> Self {
        let defaults = CountVectorizerValidParams {
            convert_to_lowercase: true,
            split_regex_expr: String::from(r"\b\w\w+\b"),
            // The regex is compiled later, during the parameter check.
            split_regex: RefCell::new(None),
            n_gram_range: (1, 1),
            normalize: true,
            document_frequency: (0.0, 1.0),
            stopwords: None,
            max_features: None,
            tokenizer_function: None,
            tokenizer_deserialization_guard: false,
        };
        Self(defaults)
    }
}
impl CountVectorizerParams {
    /// Selects how documents are tokenized.
    ///
    /// * `Tokenizer::Function(fp)` stores `fp` as the tokenizer function and raises the
    ///   deserialization guard.
    /// * `Tokenizer::Regex(expr)` stores `expr` as the splitting regex source, clears the
    ///   deserialization guard and drops any tokenizer function installed by an earlier
    ///   call, so that the guard and the stored function never disagree.
    ///
    /// The regex source is not compiled here; that happens during the parameter check.
    pub fn tokenizer(mut self, tokenizer: Tokenizer) -> Self {
        match tokenizer {
            Tokenizer::Function(fp) => {
                self.0.tokenizer_function = Some(fp);
                self.0.tokenizer_deserialization_guard = true;
            }
            Tokenizer::Regex(regex_str) => {
                self.0.split_regex_expr = regex_str.to_string();
                // A function left over from a previous `Tokenizer::Function` call would
                // contradict the cleared guard, so it is removed together with it.
                self.0.tokenizer_function = None;
                self.0.tokenizer_deserialization_guard = false;
            }
        }
        self
    }

    /// Sets the `max_features` option; `None` removes the setting.
    pub fn max_features(mut self, max_features: Option<usize>) -> Self {
        self.0.max_features = max_features;
        self
    }

    /// Sets the lowercasing flag.
    pub fn convert_to_lowercase(mut self, convert_to_lowercase: bool) -> Self {
        self.0.convert_to_lowercase = convert_to_lowercase;
        self
    }

    /// Sets the n-gram bounds to `(min_n, max_n)`.
    ///
    /// The bounds are not validated here; the parameter check rejects a zero bound and
    /// `min_n > max_n`.
    pub fn n_gram_range(mut self, min_n: usize, max_n: usize) -> Self {
        self.0.n_gram_range = (min_n, max_n);
        self
    }

    /// Sets the normalization flag.
    pub fn normalize(mut self, normalize: bool) -> Self {
        self.0.normalize = normalize;
        self
    }

    /// Sets the document-frequency bounds to `(min_freq, max_freq)`.
    ///
    /// The bounds are not validated here; the parameter check rejects negative bounds and
    /// `max_freq < min_freq`.
    pub fn document_frequency(mut self, min_freq: f32, max_freq: f32) -> Self {
        self.0.document_frequency = (min_freq, max_freq);
        self
    }

    /// Replaces the stopword set with the string form of every element of `stopwords`.
    ///
    /// An empty slice yields `Some` of an empty set rather than `None`.
    pub fn stopwords<T: ToString>(mut self, stopwords: &[T]) -> Self {
        self.0.stopwords = Some(stopwords.iter().map(ToString::to_string).collect());
        self
    }
}
impl ParamGuard for CountVectorizerParams {
    type Checked = CountVectorizerValidParams;
    type Error = PreprocessingError;

    /// Validates the parameters and, on success, compiles the splitting regex.
    ///
    /// Checks are performed in this order and the first failure is returned:
    ///
    /// 1. both n-gram bounds are non-zero (`InvalidNGramBoundaries`);
    /// 2. the minimum n-gram bound does not exceed the maximum (`FlippedNGramBoundaries`);
    /// 3. neither document-frequency bound is NaN or negative
    ///    (`InvalidDocumentFrequencies`);
    /// 4. the maximum document frequency is not below the minimum
    ///    (`FlippedDocumentFrequencies`);
    /// 5. the splitting regex source compiles (the regex error is converted with `?`).
    ///
    /// # Panics
    ///
    /// Panics if the compiled-regex cell is currently borrowed, e.g. while a `Ref` returned
    /// by `CountVectorizerValidParams::split_regex` is still alive.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        let (n_gram_min, n_gram_max) = self.0.n_gram_range;
        let (min_freq, max_freq) = self.0.document_frequency;

        if n_gram_min == 0 || n_gram_max == 0 {
            return Err(PreprocessingError::InvalidNGramBoundaries(
                n_gram_min, n_gram_max,
            ));
        }
        if n_gram_min > n_gram_max {
            return Err(PreprocessingError::FlippedNGramBoundaries(
                n_gram_min, n_gram_max,
            ));
        }
        // Every ordered comparison with NaN is false, so a NaN bound would slip through the
        // range checks below unless it is rejected explicitly.
        // NOTE(review): bounds above 1 are still accepted here; confirm whether the
        // frequencies are meant to be fractions in `0..=1` before tightening this.
        if min_freq.is_nan() || max_freq.is_nan() || min_freq < 0. || max_freq < 0. {
            return Err(PreprocessingError::InvalidDocumentFrequencies(
                min_freq, max_freq,
            ));
        }
        if max_freq < min_freq {
            return Err(PreprocessingError::FlippedDocumentFrequencies(
                min_freq, max_freq,
            ));
        }

        // Compile first so the cell is only touched once the expression is known to be valid.
        let compiled = SerdeRegex::new(&self.0.split_regex_expr)?;
        *self.0.split_regex.borrow_mut() = Some(compiled);
        Ok(&self.0)
    }

    /// Validates the parameters like `check_ref` and returns them by value.
    fn check(self) -> Result<Self::Checked, Self::Error> {
        self.check_ref()?;
        Ok(self.0)
    }
}