extern crate regex;
extern crate unicode_segmentation;
#[cfg(feature = "python")]
use dict_derive::{FromPyObject, IntoPyObject};
use serde::{Deserialize, Serialize};
use unicode_segmentation::UnicodeSegmentation;
use crate::errors::EstimatorErr;
use crate::tokenize::Tokenizer;
use std::fmt;
#[cfg(test)]
mod tests;
/// Tokenizer that splits text into sentence spans using Unicode sentence
/// boundary rules (via the `unicode_segmentation` crate); see the `Tokenizer`
/// impl below. Configured through `UnicodeSentenceTokenizerParams`.
#[derive(Debug, Clone)]
pub struct UnicodeSentenceTokenizer {
    // Parameter set used to build this tokenizer (currently empty).
    pub params: UnicodeSentenceTokenizerParams,
}
/// Parameter set for `UnicodeSentenceTokenizer`. Currently has no fields;
/// kept as a struct so the builder / serde / optional Python interfaces stay
/// uniform with the other tokenizers in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "python", derive(FromPyObject, IntoPyObject))]
pub struct UnicodeSentenceTokenizerParams {}
impl UnicodeSentenceTokenizerParams {
    /// Snapshot the current (empty) parameter set into a tokenizer.
    ///
    /// Infallible today; the `Result` return keeps the builder signature
    /// uniform with the other tokenizer builders in this module.
    pub fn build(&mut self) -> Result<UnicodeSentenceTokenizer, EstimatorErr> {
        let params = self.clone();
        Ok(UnicodeSentenceTokenizer { params })
    }
}
impl Default for UnicodeSentenceTokenizerParams {
fn default() -> UnicodeSentenceTokenizerParams {
UnicodeSentenceTokenizerParams {}
}
}
impl Default for UnicodeSentenceTokenizer {
    /// Build a tokenizer from default parameters.
    fn default() -> Self {
        UnicodeSentenceTokenizerParams::default()
            .build()
            .expect("building from empty params never fails")
    }
}
impl Tokenizer for UnicodeSentenceTokenizer {
    /// Split `text` into sentence spans at Unicode sentence boundaries
    /// (`split_sentence_bounds` from `unicode_segmentation`); yields borrowed
    /// subslices of `text`, so no allocation beyond the boxed iterator.
    fn tokenize<'a>(&self, text: &'a str) -> Box<dyn Iterator<Item = &'a str> + 'a> {
        Box::new(text.split_sentence_bounds())
    }
}
/// Tokenizer that splits text into sentence spans at configurable punctuation
/// characters (see `punctuation_sentence_iterator`): a span ends after a
/// punctuation mark plus any trailing whitespace.
#[derive(Clone)]
pub struct PunctuationTokenizer {
    // Parameter set holding the punctuation list.
    pub params: PunctuationTokenizerParams,
}
/// Parameter set for `PunctuationTokenizer`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "python", derive(FromPyObject, IntoPyObject))]
pub struct PunctuationTokenizerParams {
    // Sentence-ending marks. Note: `punctuation_sentence_iterator` only uses
    // the first char of each entry, so multi-char strings are truncated.
    punctuation: Vec<String>,
}
impl PunctuationTokenizerParams {
    /// Replace the punctuation list, returning an owned copy of the params
    /// so builder-style calls can be chained.
    pub fn punctuation(&mut self, punctuation: Vec<String>) -> PunctuationTokenizerParams {
        self.punctuation = punctuation;
        self.clone()
    }

    /// Snapshot the current parameters into a ready-to-use tokenizer.
    /// Infallible today; `Result` keeps builder signatures uniform.
    pub fn build(&mut self) -> Result<PunctuationTokenizer, EstimatorErr> {
        let params = self.clone();
        Ok(PunctuationTokenizer { params })
    }
}
/// Builds a `Vec<String>` from a comma-separated list of expressions by
/// calling `.to_string()` on each (any `Display` type works).
///
/// Now also accepts an optional trailing comma, matching `vec![]` ergonomics.
/// The camelCase name is non-idiomatic but is public API (`#[macro_export]`),
/// so it is kept for backward compatibility.
#[macro_export]
macro_rules! vecString {
    ($( $char:expr ),* $(,)?) => {{
        vec![
            $( $char.to_string(), )*
        ]
    }}
}
impl Default for PunctuationTokenizerParams {
    /// Default sentence-final marks: '.', '!' and '?'.
    fn default() -> Self {
        let punctuation = vecString![".", "!", "?"];
        PunctuationTokenizerParams { punctuation }
    }
}
impl Default for PunctuationTokenizer {
    /// Tokenizer using the default punctuation set.
    fn default() -> Self {
        PunctuationTokenizerParams::default()
            .build()
            .expect("building from default params never fails")
    }
}
impl fmt::Debug for PunctuationTokenizer {
    /// Manual `Debug` that surfaces the punctuation list directly instead of
    /// printing the nested `params` struct a derive would show.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let punctuation = &self.params.punctuation;
        write!(f, "PunctuationTokenizer {{ punctuation: {:#?} }}", punctuation)
    }
}
impl Tokenizer for PunctuationTokenizer {
    /// Split `text` into sentence spans at the configured punctuation marks;
    /// yields borrowed subslices of `text`.
    // NOTE(review): clones the punctuation Vec on every call. Fine for short
    // lists, but worth revisiting if tokenize runs in a tight loop.
    fn tokenize<'a>(&'a self, text: &'a str) -> Box<dyn Iterator<Item = &'a str> + 'a> {
        Box::new(punctuation_sentence_iterator(
            text,
            self.params.punctuation.clone(),
        ))
    }
}
/// Build a `PunctuationTokenizerIterator` over `text`.
///
/// Only the first character of each punctuation string is significant.
/// Empty strings are now skipped; the previous implementation panicked on
/// them via `chars().next().unwrap()`.
fn punctuation_sentence_iterator<'a>(
    text: &'a str,
    punctuation: Vec<String>,
) -> PunctuationTokenizerIterator<'a> {
    let punctuation_chars: Vec<char> = punctuation
        .iter()
        .filter_map(|p| p.chars().next())
        .collect();
    PunctuationTokenizerIterator {
        text,
        punctuation: punctuation_chars,
        seen_punct: false,
        i: 0,
        span_end: 0,
        // `str::len` is already the length in bytes.
        bytes_len: text.len(),
    }
}
/// Streaming iterator state for punctuation-based sentence splitting.
struct PunctuationTokenizerIterator<'a> {
    // Full input text; yielded spans borrow from it.
    text: &'a str,
    // Single-character punctuation marks that can end a sentence.
    punctuation: Vec<char>,
    // True once punctuation was seen and we are scanning for the first
    // non-whitespace char, which starts the next sentence.
    seen_punct: bool,
    // Byte offset of the most recently visited char.
    // NOTE(review): written but never read within this file — presumably
    // leftover bookkeeping; confirm before removing.
    i: usize,
    // Byte offset where the previous span ended / the next span starts.
    span_end: usize,
    // Total length of `text` in bytes.
    bytes_len: usize,
}
impl<'a> PunctuationTokenizerIterator<'a> {
    /// Slice `self.text` between optional byte offsets (`None` means the
    /// corresponding end of the text).
    ///
    /// Offsets must fall on UTF-8 character boundaries — callers derive them
    /// from `char_indices`, so they do. Panics on an out-of-range or
    /// non-boundary offset, just like the original byte-slice +
    /// `from_utf8().unwrap()` round trip, but without re-validating the
    /// UTF-8 of the slice on every call.
    fn bytes_slice(&self, start: Option<usize>, end: Option<usize>) -> &'a str {
        match (start, end) {
            (Some(start_idx), Some(end_idx)) => &self.text[start_idx..end_idx],
            (Some(start_idx), None) => &self.text[start_idx..],
            (None, Some(end_idx)) => &self.text[..end_idx],
            (None, None) => self.text,
        }
    }
}
impl<'a> Iterator for PunctuationTokenizerIterator<'a> {
    type Item = &'a str;
    // Yields the next sentence span: everything from the end of the previous
    // span through a punctuation character and any whitespace that follows,
    // ending just before the first non-whitespace character (which begins the
    // next sentence). A final yield returns whatever remains of the input.
    fn next(&mut self) -> Option<&'a str> {
        // All input was consumed by a previous call.
        if self.span_end >= self.bytes_len {
            return None;
        }
        // Resume scanning at the end of the previously yielded span.
        let remaining_text = self.bytes_slice(Some(self.span_end), None);
        let idx_offset = self.span_end;
        for (i, character) in remaining_text.char_indices() {
            // Absolute byte offset of the current char within `self.text`.
            // NOTE(review): `self.i` is never read in this impl — confirm it
            // is needed before relying on it.
            self.i = i + idx_offset;
            let is_punct = self.punctuation.contains(&character);
            if self.seen_punct {
                // Punctuation already seen: the span extends over trailing
                // whitespace and ends at the first non-whitespace char.
                let is_whitespace = character.is_whitespace();
                if !is_whitespace {
                    let span_start = self.span_end;
                    self.span_end = idx_offset + i;
                    self.seen_punct = false;
                    let span = self.bytes_slice(Some(span_start), Some(self.span_end));
                    // Guard against a zero-length span (e.g. punct at offset 0).
                    if !span.is_empty() {
                        return Some(span);
                    }
                }
            } else if is_punct {
                self.seen_punct = true;
            }
        }
        // No further boundary found: emit the tail of the text, if any.
        if self.span_end < self.bytes_len {
            let span_start = self.span_end;
            self.span_end = self.bytes_len;
            let span = self.bytes_slice(Some(span_start), None);
            if !span.is_empty() {
                return Some(span);
            }
        }
        None
    }
}