use alloc::vec;
use alloc::vec::Vec;
use crate::dict::{builtin_dict, Dict, BUILTIN_WORDS};
#[cfg(feature = "std")]
use crate::error::KhamError;
use crate::freq::FreqMap;
use crate::normalizer;
use crate::pre_tokenizer::pre_tokenize;
use crate::tcc::tcc_boundaries;
use crate::token::{Token, TokenKind};
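/// Dictionary-based Thai word tokenizer.
///
/// Thai runs are segmented with a dynamic program over TCC (Thai
/// Character Cluster) boundaries; Latin, number, and other non-Thai
/// runs pass through from the pre-tokenizer unchanged. A minimal usage
/// sketch (exact Thai splits depend on the built-in dictionary, so the
/// token boundaries are illustrative):
///
/// ```ignore
/// let tokens = Tokenizer::new().segment("กินข้าวกับปลา");
/// let words: Vec<&str> = tokens.iter().map(|t| t.text).collect();
/// assert_eq!(words.concat(), "กินข้าวกับปลา");
/// ```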
pub struct Tokenizer {
dict: Dict,
freq: FreqMap,
keep_whitespace: bool,
}
impl Tokenizer {
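    /// Creates a tokenizer with the built-in dictionary and frequency
    /// table; whitespace tokens are dropped by default.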
pub fn new() -> Self {
Self {
dict: builtin_dict(),
freq: FreqMap::builtin(),
keep_whitespace: false,
}
}
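    /// Normalizes `text` (for example, collapsing duplicated tone
    /// marks) before segmentation; already-clean input passes through
    /// unchanged.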
pub fn normalize(&self, text: &str) -> alloc::string::String {
normalizer::normalize(text)
}
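    /// Returns a builder for customizing the dictionary and whitespace
    /// handling.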
pub fn builder() -> TokenizerBuilder {
TokenizerBuilder::default()
}
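    /// Splits `text` into tokens, segmenting Thai runs against the
    /// dictionary and passing other runs through whole. A sketch
    /// (exact Thai splits depend on the dictionary contents):
    ///
    /// ```ignore
    /// let tokens = Tokenizer::new().segment("ธนาคาร100แห่ง");
    /// assert!(tokens.iter().any(|t| t.kind == TokenKind::Number));
    /// ```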
pub fn segment<'t>(&self, text: &'t str) -> Vec<Token<'t>> {
if text.is_empty() {
return Vec::new();
}
let pre_tokens = pre_tokenize(text);
let mut result: Vec<Token<'t>> = Vec::with_capacity(pre_tokens.len() * 2);
for token in pre_tokens {
match token.kind {
TokenKind::Thai => {
segment_thai(&self.dict, &self.freq, text, token.span, &mut result);
}
TokenKind::Whitespace if !self.keep_whitespace => {
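                        // Dropped: whitespace is omitted unless
                        // `keep_whitespace` was set on the builder.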
}
_ => {
result.push(token);
}
}
}
result
}
}
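/// Segmentation score compared lexicographically: the derived `Ord`
/// ranks candidates by fewest unknown tokens, then fewest tokens
/// overall, then most dictionary words, then highest cumulative word
/// frequency. The first two fields are stored negated so that "fewer"
/// compares as "greater".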
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct DpScore {
neg_unknowns: i32,
neg_tokens: i32,
dict_words: i32,
freq_score: u64,
}
impl DpScore {
    const ZERO: Self = Self {
        neg_unknowns: 0,
        neg_tokens: 0,
        dict_words: 0,
        freq_score: 0,
    };
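    /// Extends this score across a dictionary-word edge with the given
    /// corpus frequency.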
fn dict_edge(self, freq: u32) -> Self {
Self {
dict_words: self.dict_words + 1,
freq_score: self.freq_score + freq as u64,
neg_tokens: self.neg_tokens - 1,
..self
}
}
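    /// Extends this score across a one-cluster unknown edge.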
fn unknown_edge(self) -> Self {
Self {
neg_unknowns: self.neg_unknowns - 1,
neg_tokens: self.neg_tokens - 1,
..self
}
}
}
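/// Backpointers from the forward pass: for each boundary node, the
/// best predecessor and whether the edge into the node was a
/// dictionary word.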
struct DpTable {
from: Vec<usize>,
is_dict: Vec<bool>,
}
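/// Forward best-path pass over the lattice of TCC boundaries. Each
/// boundary is a node; dictionary prefixes of the remaining slice add
/// word edges (accepted only when they end exactly on a boundary), and
/// a one-cluster unknown edge keeps every node reachable.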
fn forward_dp(dict: &Dict, freqs: &FreqMap, slice: &str, bounds: &[usize]) -> DpTable {
let nb = bounds.len();
let mut best: Vec<Option<DpScore>> = vec![None; nb];
let mut from = vec![0usize; nb];
let mut is_dict = vec![false; nb];
best[0] = Some(DpScore::ZERO);
for i in 0..nb - 1 {
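        // Skip boundary nodes that no segmentation has reached.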
let score = match best[i] {
Some(s) => s,
None => continue,
};
let pos = bounds[i];
let remaining = &slice[pos..];
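        // Dictionary edges: every dict word prefixing the remaining
        // slice is a candidate token, provided it ends on a boundary.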
for prefix in dict.prefixes(remaining) {
let end_pos = pos + prefix.len();
if let Ok(j) = bounds.binary_search(&end_pos) {
let freq = freqs.get(prefix);
let candidate = Some(score.dict_edge(freq));
if candidate > best[j] {
best[j] = candidate;
from[j] = i;
is_dict[j] = true;
}
}
}
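        // Fallback edge: consume a single TCC cluster as an unknown
        // token so the DP always makes progress.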
let j = i + 1;
let candidate = Some(score.unknown_edge());
if candidate > best[j] {
best[j] = candidate;
from[j] = i;
is_dict[j] = false;
}
}
DpTable { from, is_dict }
}
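/// Walks the `from` backpointers from the final boundary back to the
/// start and returns the chosen node indices in forward order.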
fn backtrack_path(from: &[usize]) -> Vec<usize> {
let nb = from.len();
let mut path = Vec::with_capacity(nb);
let mut cur = nb - 1;
loop {
path.push(cur);
if cur == 0 {
break;
}
cur = from[cur];
}
path.reverse();
path
}
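/// Segments the contiguous Thai run of `text` given by `span` and
/// appends the resulting tokens to `out`.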
fn segment_thai<'t>(
dict: &Dict,
freqs: &FreqMap,
text: &'t str,
span: core::ops::Range<usize>,
out: &mut Vec<Token<'t>>,
) {
let slice = &text[span.start..span.end];
let bounds = tcc_boundaries(slice);
if bounds.len() <= 1 {
return;
}
let dp = forward_dp(dict, freqs, slice, &bounds);
let path = backtrack_path(&dp.from);
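    // `char_span` offsets are global to `text`, so count the characters
    // preceding this run once and advance from there.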
let mut char_cursor = text[..span.start].chars().count();
for w in path.windows(2) {
let start_byte = span.start + bounds[w[0]];
let end_byte = span.start + bounds[w[1]];
let token_text = &text[start_byte..end_byte];
let char_start = char_cursor;
char_cursor += token_text.chars().count();
let kind = if dp.is_dict[w[1]] {
TokenKind::Thai
} else {
TokenKind::Unknown
};
out.push(Token::new(
token_text,
start_byte..end_byte,
char_start..char_cursor,
kind,
));
}
}
impl Default for Tokenizer {
fn default() -> Self {
Self::new()
}
}
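/// Builder for `Tokenizer`. A usage sketch (the extra word is
/// arbitrary and purely illustrative):
///
/// ```ignore
/// let tok = Tokenizer::builder()
///     .dict_words("คำใหม่\n")
///     .keep_whitespace(true)
///     .build();
/// ```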
#[derive(Debug, Default)]
pub struct TokenizerBuilder {
dict_words: Option<alloc::string::String>,
keep_whitespace: bool,
}
impl TokenizerBuilder {
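    /// Adds extra dictionary entries from a newline-separated word
    /// list, merged with the built-in dictionary at build time.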
pub fn dict_words(mut self, words: &str) -> Self {
self.dict_words = Some(alloc::string::String::from(words));
self
}
pub fn keep_whitespace(mut self, keep: bool) -> Self {
self.keep_whitespace = keep;
self
}
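    /// Builds the tokenizer, combining any extra words with the
    /// built-in dictionary.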
pub fn build(self) -> Tokenizer {
let dict = if let Some(extra) = &self.dict_words {
let mut combined = alloc::string::String::from(BUILTIN_WORDS);
combined.push('\n');
combined.push_str(extra);
Dict::from_word_list(&combined)
} else {
builtin_dict()
};
Tokenizer {
dict,
freq: FreqMap::builtin(),
keep_whitespace: self.keep_whitespace,
}
}
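    /// Loads extra dictionary words from a newline-separated file.
    /// Available only with the `std` feature.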
#[cfg(feature = "std")]
pub fn dict_file(self, path: &str) -> Result<Self, KhamError> {
extern crate std;
let content = std::fs::read_to_string(path)
.map_err(|e| KhamError::DictLoadError(alloc::format!("{path}: {e}")))?;
Ok(self.dict_words(&content))
}
}
#[cfg(test)]
mod tests {
use super::*;
fn tok() -> Tokenizer {
Tokenizer::new()
}
#[test]
fn empty_input() {
assert!(tok().segment("").is_empty());
}
#[test]
fn pure_latin_passthrough() {
let tokens = tok().segment("hello");
assert_eq!(tokens.len(), 1);
assert_eq!(tokens[0].text, "hello");
assert_eq!(tokens[0].kind, TokenKind::Latin);
}
#[test]
fn pure_number_passthrough() {
let tokens = tok().segment("12345");
assert_eq!(tokens.len(), 1);
assert_eq!(tokens[0].text, "12345");
assert_eq!(tokens[0].kind, TokenKind::Number);
}
#[test]
fn whitespace_dropped_by_default() {
let tokens = tok().segment("กิน ข้าว");
for t in &tokens {
assert_ne!(t.kind, TokenKind::Whitespace);
}
}
#[test]
fn whitespace_kept_when_requested() {
let tokens = Tokenizer::builder()
.keep_whitespace(true)
.build()
.segment("กิน ข้าว");
assert!(tokens.iter().any(|t| t.kind == TokenKind::Whitespace));
}
#[test]
fn gin_khao_gap_pla() {
let tokens = tok().segment("กินข้าวกับปลา");
let words: Vec<&str> = tokens.iter().map(|t| t.text).collect();
assert!(words.len() >= 2, "expected multiple words, got {words:?}");
assert_eq!(words.join(""), "กินข้าวกับปลา");
}
#[test]
fn mixed_thai_number_thai() {
let tokens = tok().segment("ธนาคาร100แห่ง");
let rebuilt: alloc::string::String = tokens.iter().map(|t| t.text).collect();
assert_eq!(rebuilt, "ธนาคาร100แห่ง");
let num = tokens.iter().find(|t| t.kind == TokenKind::Number);
assert!(num.is_some());
assert_eq!(num.unwrap().text, "100");
}
#[test]
fn mixed_thai_latin() {
let tokens = tok().segment("สวัสดี hello");
let rebuilt: alloc::string::String = tokens.iter().map(|t| t.text).collect();
assert_eq!(rebuilt, "สวัสดีhello");
assert!(tokens
.iter()
.any(|t| t.kind == TokenKind::Latin && t.text == "hello"));
}
#[test]
fn spans_cover_input_excluding_whitespace() {
let text = "กินข้าว123hello";
let tokens = tok().segment(text);
for t in &tokens {
assert_eq!(&text[t.span.clone()], t.text);
assert!(text.is_char_boundary(t.span.start));
assert!(text.is_char_boundary(t.span.end));
}
}
#[test]
fn adjacent_spans_are_contiguous() {
let text = "กินข้าวกับปลา";
let tokens = Tokenizer::builder()
.keep_whitespace(true)
.build()
.segment(text);
for w in tokens.windows(2) {
assert_eq!(
w[0].span.end, w[1].span.start,
"gap between {:?} and {:?}",
w[0], w[1]
);
}
}
#[test]
fn no_empty_tokens() {
let tokens = tok().segment("กินข้าวกับปลา 100 hello!");
for t in &tokens {
assert!(!t.text.is_empty());
}
}
#[test]
fn custom_dict_word_is_matched() {
let tok = Tokenizer::builder().dict_words("กขคงจฉ\n").build();
let tokens = tok.segment("กขคงจฉ");
let thai: Vec<&str> = tokens
.iter()
.filter(|t| t.kind == TokenKind::Thai)
.map(|t| t.text)
.collect();
assert!(thai.contains(&"กขคงจฉ"), "got: {thai:?}");
}
#[test]
fn normalize_deduplicates_tone_before_segment() {
let t = tok();
        let raw = "กิน\u{0E02}\u{0E49}\u{0E49}าว";
        let normalized = t.normalize(raw);
let tokens = t.segment(&normalized);
assert!(!tokens.is_empty());
let rebuilt: alloc::string::String = tokens.iter().map(|t| t.text).collect();
assert_eq!(rebuilt, normalized);
}
#[test]
fn normalize_clean_input_is_identity() {
let t = tok();
let clean = "กินข้าวกับปลา";
assert_eq!(t.normalize(clean), clean);
}
#[test]
fn segment_without_normalize_on_clean_input() {
let tokens = tok().segment("กินข้าวกับปลา");
let rebuilt: alloc::string::String = tokens.iter().map(|t| t.text).collect();
assert_eq!(rebuilt, "กินข้าวกับปลา");
}
#[test]
fn dp_score_fewer_unknowns_is_primary() {
let no_unknown = DpScore::ZERO;
let one_unknown = DpScore::ZERO.unknown_edge();
assert!(no_unknown > one_unknown);
}
#[test]
fn dp_score_fewer_tokens_beats_more_dict_words() {
        let compound = DpScore::ZERO.dict_edge(0);
        let split = DpScore::ZERO.dict_edge(0).dict_edge(0);
        assert!(compound > split);
}
#[test]
fn dp_score_higher_freq_breaks_token_tie() {
let low_freq = DpScore::ZERO.dict_edge(10);
let high_freq = DpScore::ZERO.dict_edge(100);
assert!(high_freq > low_freq);
}
#[test]
fn dp_score_fewer_tokens_beats_higher_freq() {
let high_freq_more_tokens = DpScore {
neg_unknowns: 0,
neg_tokens: -2,
dict_words: 1,
freq_score: 200,
};
let low_freq_fewer_tokens = DpScore {
neg_unknowns: 0,
neg_tokens: -1,
dict_words: 1,
freq_score: 100,
};
assert!(low_freq_fewer_tokens > high_freq_more_tokens);
}
#[test]
fn dp_score_more_dict_words_breaks_token_tie() {
let fewer_dict = DpScore {
neg_unknowns: 0,
neg_tokens: -2,
dict_words: 1,
freq_score: 0,
};
let more_dict = DpScore {
neg_unknowns: 0,
neg_tokens: -2,
dict_words: 2,
freq_score: 0,
};
assert!(more_dict > fewer_dict);
}
#[test]
fn dict_edge_accumulates_freq_score() {
let after_one = DpScore::ZERO.dict_edge(50);
let after_two = after_one.dict_edge(30);
assert_eq!(after_one.freq_score, 50);
assert_eq!(after_two.freq_score, 80);
}
#[test]
    fn dict_edge_increments_dict_words_and_decrements_neg_tokens() {
let s = DpScore::ZERO.dict_edge(0);
assert_eq!(s.dict_words, 1);
assert_eq!(s.neg_tokens, -1);
assert_eq!(s.neg_unknowns, 0);
}
#[test]
    fn unknown_edge_decrements_neg_unknowns_and_neg_tokens() {
let s = DpScore::ZERO.unknown_edge();
assert_eq!(s.neg_unknowns, -1);
assert_eq!(s.neg_tokens, -1);
assert_eq!(s.dict_words, 0);
assert_eq!(s.freq_score, 0);
}
#[test]
fn unknown_edge_does_not_contribute_freq() {
let s = DpScore::ZERO.unknown_edge().unknown_edge();
assert_eq!(s.freq_score, 0);
}
#[test]
fn char_span_len_equals_char_count() {
let tokens = tok().segment("กินข้าวกับปลา");
for t in &tokens {
assert_eq!(
t.char_span.end - t.char_span.start,
t.text.chars().count(),
"char_span length mismatch for {:?}",
t.text
);
}
}
#[test]
fn char_spans_are_contiguous() {
let tokens = Tokenizer::builder()
.keep_whitespace(true)
.build()
.segment("กินข้าว 100 hello");
for w in tokens.windows(2) {
assert_eq!(
w[0].char_span.end, w[1].char_span.start,
"char_span gap between {:?} and {:?}",
w[0].text, w[1].text
);
}
}
#[test]
fn char_span_for_mixed_script() {
let tokens = tok().segment("ธนาคาร100แห่ง");
assert_eq!(tokens[0].char_span, 0..6);
assert_eq!(tokens[1].char_span, 6..9);
assert_eq!(tokens[2].char_span, 9..13);
}
#[test]
fn char_span_accounts_for_multibyte_chars() {
let tokens = tok().segment("กิน");
assert_eq!(tokens[0].span, 0..9);
assert_eq!(tokens[0].char_span, 0..3);
}
#[test]
fn char_span_emoji_is_single_char() {
let tokens = tok().segment("😀");
assert_eq!(tokens[0].char_len(), 1);
assert_eq!(tokens[0].byte_len(), 4);
}
#[test]
fn single_thai_char() {
let tokens = tok().segment("ก");
assert_eq!(tokens.len(), 1);
assert_eq!(tokens[0].text, "ก");
}
#[test]
fn sawasdee_khao_lok() {
let tokens = tok().segment("สวัสดีชาวโลก");
let rebuilt: alloc::string::String = tokens.iter().map(|t| t.text).collect();
assert_eq!(rebuilt, "สวัสดีชาวโลก");
}
}