use crate::text::Tokenizer;
use crate::AprenderError;
use std::collections::HashMap;
/// Tokenizer that splits text at Unicode whitespace boundaries.
#[derive(Debug, Clone, Default)]
pub struct WhitespaceTokenizer;

impl WhitespaceTokenizer {
    /// Creates a new `WhitespaceTokenizer`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
impl Tokenizer for WhitespaceTokenizer {
    /// Splits `text` on runs of Unicode whitespace; this never fails.
    fn tokenize(&self, text: &str) -> Result<Vec<String>, AprenderError> {
        Ok(text.split_whitespace().map(String::from).collect())
    }
}
/// Tokenizer that separates words from punctuation.
///
/// ASCII punctuation (except the apostrophe) is emitted as a standalone
/// token; apostrophes remain inside words so contractions stay intact.
#[derive(Debug, Clone, Default)]
pub struct WordTokenizer;

impl WordTokenizer {
    /// Creates a new `WordTokenizer`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }

    /// Returns `true` when `c` should become its own token: any ASCII
    /// punctuation character other than `'`.
    fn is_separator(c: char) -> bool {
        c != '\'' && c.is_ascii_punctuation()
    }
}
impl Tokenizer for WordTokenizer {
    /// Splits `text` into word tokens and single-character punctuation
    /// tokens; this never fails.
    fn tokenize(&self, text: &str) -> Result<Vec<String>, AprenderError> {
        let mut out: Vec<String> = Vec::new();
        let mut word = String::new();
        for c in text.chars() {
            if c.is_whitespace() || Self::is_separator(c) {
                // Both whitespace and punctuation end the current word.
                if !word.is_empty() {
                    out.push(std::mem::take(&mut word));
                }
                // Punctuation additionally becomes its own token.
                if Self::is_separator(c) {
                    out.push(c.to_string());
                }
            } else {
                word.push(c);
            }
        }
        // Flush a trailing word that was not followed by a separator.
        if !word.is_empty() {
            out.push(word);
        }
        Ok(out)
    }
}
/// Tokenizer that yields one token per Unicode scalar value.
#[derive(Debug, Clone, Default)]
pub struct CharTokenizer;

impl CharTokenizer {
    /// Creates a new `CharTokenizer`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
impl Tokenizer for CharTokenizer {
    /// Produces one single-character token per `char`; this never fails.
    fn tokenize(&self, text: &str) -> Result<Vec<String>, AprenderError> {
        Ok(text.chars().map(String::from).collect())
    }
}
/// Splits raw text into sentences, using a list of common English
/// abbreviations so that periods like those in "Dr." or "etc." do not
/// terminate a sentence.
#[derive(Debug, Clone, Default)]
pub struct SentenceTokenizer {
    // Lowercased abbreviations (without the trailing dot) whose period
    // must not end a sentence.
    abbreviations: Vec<&'static str>,
}

impl SentenceTokenizer {
    /// Creates a tokenizer preloaded with common English abbreviations.
    #[must_use]
    pub fn new() -> Self {
        Self {
            abbreviations: vec![
                "mr", "mrs", "ms", "dr", "prof", "sr", "jr", "vs", "etc", "inc", "ltd", "corp",
                "st", "ave", "blvd", "rd", "dept", "gov", "gen", "col", "lt", "sgt", "rev", "hon",
                "pres", "jan", "feb", "mar", "apr", "jun", "jul", "aug", "sep", "oct", "nov",
                "dec", "i.e", "e.g", "cf", "al", "vol", "no", "fig", "pp", "ph.d", "m.d", "b.a",
                "m.a", "d.d.s",
            ],
        }
    }

    /// Splits `text` into sentences.
    ///
    /// A `.`, `?` or `!` ends a sentence when it is the last character of
    /// the input, or is followed by whitespace and then either end of
    /// input or an uppercase letter. A `.` whose preceding word is a
    /// known abbreviation never ends a sentence. Returned sentences are
    /// trimmed and non-empty; an empty input yields an empty vector.
    #[must_use]
    pub fn split(&self, text: &str) -> Vec<String> {
        if text.is_empty() {
            return Vec::new();
        }
        let chars: Vec<char> = text.chars().collect();
        let len = chars.len();
        let mut sentences = Vec::new();
        let mut current = String::new();
        for (i, &c) in chars.iter().enumerate() {
            current.push(c);
            if c == '.' || c == '?' || c == '!' {
                let is_end = match chars.get(i + 1) {
                    // Terminator at end of input always closes the sentence.
                    None => true,
                    Some(next) if next.is_whitespace() => {
                        // Skip the whitespace run; a following uppercase
                        // letter (or end of input) confirms the boundary.
                        let mut j = i + 2;
                        while j < len && chars[j].is_whitespace() {
                            j += 1;
                        }
                        j >= len || chars[j].is_uppercase()
                    }
                    // Terminator embedded in a token (e.g. "3.14") is not
                    // a sentence boundary.
                    Some(_) => false,
                };
                // BUGFIX: was `self.is_abbreviation(¤t)` — the argument
                // `&current` had been mangled by an HTML-entity decode
                // (`&curren` -> `¤`), which does not compile.
                let is_abbrev = c == '.' && self.is_abbreviation(&current);
                if is_end && !is_abbrev {
                    let sentence = current.trim().to_string();
                    if !sentence.is_empty() {
                        sentences.push(sentence);
                    }
                    current.clear();
                }
            }
        }
        // Flush any trailing text that lacked a terminal punctuation mark.
        let tail = current.trim().to_string();
        if !tail.is_empty() {
            sentences.push(tail);
        }
        sentences
    }

    /// Returns `true` when the last whitespace-separated word of `text`
    /// (with trailing dots stripped, lowercased) is a known abbreviation.
    fn is_abbreviation(&self, text: &str) -> bool {
        let trimmed = text.trim_end_matches('.');
        let last_word = trimmed.split_whitespace().last().unwrap_or("");
        let lower = last_word.to_lowercase();
        self.abbreviations.contains(&lower.as_str())
    }
}
/// Special marker tokens used by subword tokenizers.
#[derive(Debug, Clone)]
pub struct SpecialTokens {
    /// Unknown-token marker (always present).
    pub unk: String,
    /// Optional beginning-of-sequence marker.
    pub bos: Option<String>,
    /// Optional end-of-sequence marker.
    pub eos: Option<String>,
    /// Optional padding marker.
    pub pad: Option<String>,
}

impl Default for SpecialTokens {
    /// Conventional defaults: `<unk>`, `<s>`, `</s>` and `<pad>`.
    fn default() -> Self {
        Self {
            unk: String::from("<unk>"),
            bos: Some(String::from("<s>")),
            eos: Some(String::from("</s>")),
            pad: Some(String::from("<pad>")),
        }
    }
}
/// Byte-pair-encoding tokenizer model state.
///
/// This struct only holds the learned model; construction, training and
/// the `Tokenizer` implementation live in the `bpe_impl` /
/// `bpe_tokenizer_impl` submodules declared below.
#[derive(Debug, Clone)]
pub struct BpeTokenizer {
    // Token string -> token id.
    vocab: HashMap<String, u32>,
    // Token id -> token string (inverse of `vocab`).
    inverse_vocab: HashMap<u32, String>,
    // Learned merge rules as (left, right) pairs; presumably applied in
    // list order during encoding — confirm in `bpe_impl`.
    merges: Vec<(String, String)>,
    // Unknown/BOS/EOS/PAD marker strings.
    special_tokens: SpecialTokens,
    // Word-boundary marker string; exact usage defined in the
    // submodule implementations — confirm there.
    end_of_word: String,
}
mod bpe_impl;
mod bpe_tokenizer_impl;
pub use bpe_tokenizer_impl::*;
mod unigram_training;