use alloc::vec;
use alloc::vec::Vec;
use crate::dict::{builtin_dict, Dict, BUILTIN_WORDS};
use crate::error::KhamError;
use crate::freq::FreqMap;
use crate::normalizer;
use crate::pre_tokenizer::pre_tokenize;
use crate::tcc::tcc_boundaries;
use crate::token::{Token, TokenKind};
/// Thai word segmenter.
///
/// Input is first split into pre-tokens by `pre_tokenize`. Thai pre-tokens are
/// then broken into words using the dictionary and frequency table, and every
/// other pre-token is passed through unchanged.
pub struct Tokenizer {
    /// Dictionary whose prefixes form the candidate words for Thai runs.
    dict: Dict,
    /// Word frequencies consulted when scoring Thai segmentations and confidences.
    freq: FreqMap,
    /// Emit `TokenKind::Whitespace` tokens when `true`; drop them when `false`.
    keep_whitespace: bool,
}

impl Tokenizer {
    /// Creates a tokenizer backed by the built-in dictionary and frequency
    /// table. Whitespace tokens are dropped.
    pub fn new() -> Self {
        let dict = builtin_dict();
        let freq = FreqMap::builtin();
        Self {
            dict,
            freq,
            keep_whitespace: false,
        }
    }

    /// Runs `normalizer::normalize` over `text`.
    ///
    /// `segment` does not normalize on its own; call this first when the input
    /// may need it.
    pub fn normalize(&self, text: &str) -> alloc::string::String {
        normalizer::normalize(text)
    }

    /// Returns a default-configured `TokenizerBuilder`.
    pub fn builder() -> TokenizerBuilder {
        TokenizerBuilder::default()
    }

    /// Splits `text` into tokens, in input order.
    ///
    /// Each Thai pre-token is replaced by its word segmentation. Whitespace
    /// pre-tokens are emitted only when `keep_whitespace` is set. Every other
    /// pre-token is emitted as is. Empty input yields an empty vector.
    pub fn segment<'t>(&self, text: &'t str) -> Vec<Token<'t>> {
        if text.is_empty() {
            return Vec::new();
        }
        let drop_whitespace = !self.keep_whitespace;
        let pieces = pre_tokenize(text);
        let mut tokens: Vec<Token<'t>> = Vec::with_capacity(pieces.len() * 2);
        for piece in pieces {
            match piece.kind {
                TokenKind::Thai => {
                    segment_thai(&self.dict, &self.freq, text, piece.span, &mut tokens)
                }
                TokenKind::Whitespace if drop_whitespace => {}
                _ => tokens.push(piece),
            }
        }
        tokens
    }

    /// Segments `text` up front and returns a `TokenStream` over the result.
    ///
    /// All segmentation work happens here (via `segment`); the stream only
    /// walks the buffered tokens.
    pub fn segment_stream<'t>(&self, text: &'t str) -> TokenStream<'t> {
        let tokens = self.segment(text);
        TokenStream {
            inner: tokens.into_iter(),
        }
    }
}
/// Iterator over the tokens produced by `Tokenizer::segment_stream`.
///
/// The input has already been fully segmented when this is created. The
/// `next_*` helpers consume (and discard) every token that fails their filter
/// before returning the first one that passes.
pub struct TokenStream<'t> {
    inner: alloc::vec::IntoIter<Token<'t>>,
}

impl<'t> TokenStream<'t> {
    /// Returns the next token whose kind is not `Whitespace`.
    pub fn next_word(&mut self) -> Option<Token<'t>> {
        self.inner.find(|tok| tok.kind != TokenKind::Whitespace)
    }

    /// Returns the next token whose kind is neither `Whitespace` nor `Unknown`.
    pub fn next_known(&mut self) -> Option<Token<'t>> {
        self.inner.find(|tok| {
            let skipped = tok.kind == TokenKind::Whitespace || tok.kind == TokenKind::Unknown;
            !skipped
        })
    }

    /// Returns the next token whose confidence is at least `min`.
    pub fn next_above_confidence(&mut self, min: f32) -> Option<Token<'t>> {
        self.inner.find(|tok| tok.confidence >= min)
    }
}

impl<'t> Iterator for TokenStream<'t> {
    type Item = Token<'t>;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next()
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// Score of a partial segmentation; larger is better.
///
/// The derived `Ord` compares fields lexicographically in declaration order, so
/// the field order below IS the ranking:
/// 1. fewer unknown chunks (`neg_unknowns` closer to zero),
/// 2. then fewer tokens overall (`neg_tokens` closer to zero),
/// 3. then more dictionary words,
/// 4. then a larger summed word frequency.
///
/// Counts that should be minimised are stored negated so that "larger is
/// better" holds for every field. For scores built only through `dict_edge` /
/// `unknown_edge`, tokens = dictionary words + unknowns, so criterion 3 never
/// decides between such scores once 1 and 2 tie.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct DpScore {
    neg_unknowns: i32,
    neg_tokens: i32,
    dict_words: i32,
    freq_score: u64,
}

impl DpScore {
    /// Score of the empty segmentation.
    const ZERO: Self = Self {
        neg_unknowns: 0,
        neg_tokens: 0,
        dict_words: 0,
        freq_score: 0,
    };

    /// Extends the segmentation by one dictionary word of frequency `freq`.
    fn dict_edge(self, freq: u32) -> Self {
        let mut next = self;
        next.neg_tokens -= 1;
        next.dict_words += 1;
        next.freq_score += u64::from(freq);
        next
    }

    /// Extends the segmentation by one unknown (non-dictionary) chunk.
    fn unknown_edge(self) -> Self {
        let mut next = self;
        next.neg_unknowns -= 1;
        next.neg_tokens -= 1;
        next
    }
}
/// Per-node output of `forward_dp`. Every vector has one entry per TCC
/// boundary and is indexed by boundary index.
struct DpTable {
    /// Predecessor node on the best path into each node (0 for node 0).
    from: Vec<usize>,
    /// `true` when the best edge into the node is a dictionary word.
    is_dict: Vec<bool>,
    /// Frequency of that dictionary word; 0 when the best edge is unknown.
    edge_freq: Vec<u32>,
    /// Number of candidate edges (dictionary and fallback) that reached the
    /// node, saturating at `u8::MAX`.
    competing: Vec<u8>,
}
/// Forward dynamic-programming pass over the TCC boundaries of one Thai run.
///
/// `bounds` holds byte offsets into `slice`; node `i` is position `bounds[i]`.
/// From every reachable node `i` two kinds of edge are relaxed:
/// * one dictionary edge per prefix yielded by `dict.prefixes` whose end lands
///   exactly on some boundary `j` (prefixes ending mid-cluster are ignored);
/// * one "unknown" fallback edge to the next boundary `i + 1`.
///
/// Each node keeps the best `DpScore` seen so far. A node is only overwritten
/// by a strictly better score, so on ties the earliest candidate wins: smaller
/// `i` first, and within one `i`, dictionary edges (in `prefixes` order) before
/// the fallback edge.
///
/// `bounds` must be non-empty (`best[0]` is indexed unconditionally) and sorted
/// ascending (required by `binary_search`).
///
/// NOTE(review): `competing[j]` counts every edge into `j`, including the
/// fallback edge from `j - 1`, which always exists. A dictionary token
/// therefore always has `competing >= 2`, and the `0 | 1` arm of
/// `compute_confidence` is never reached for dictionary words. Confirm this
/// penalty is intended.
fn forward_dp(dict: &Dict, freqs: &FreqMap, slice: &str, bounds: &[usize]) -> DpTable {
let nb = bounds.len();
let mut best: Vec<Option<DpScore>> = vec![None; nb];
let mut from = vec![0usize; nb];
let mut is_dict = vec![false; nb];
let mut edge_freq = vec![0u32; nb];
let mut competing = vec![0u8; nb];
// Node 0 (start of the run) is the only node reachable at the outset.
best[0] = Some(DpScore::ZERO);
for i in 0..nb - 1 {
let score = match best[i] {
Some(s) => s,
// Defensive: the fallback chain from node 0 reaches every node, so this
// arm is not expected to fire.
None => continue,
};
let pos = bounds[i];
let remaining = &slice[pos..];
for prefix in dict.prefixes(remaining) {
let end_pos = pos + prefix.len();
// Only prefixes ending exactly on a TCC boundary become edges.
if let Ok(j) = bounds.binary_search(&end_pos) {
// Counted before the comparison, so losing candidates count too.
competing[j] = competing[j].saturating_add(1);
let freq = freqs.get(prefix);
let candidate = Some(score.dict_edge(freq));
// Strict `>`: an equal-scoring later candidate does not replace the earlier one.
if candidate > best[j] {
best[j] = candidate;
from[j] = i;
is_dict[j] = true;
edge_freq[j] = freq;
}
}
}
// Fallback edge: treat the single cluster bounds[i]..bounds[i + 1] as unknown.
let j = i + 1;
competing[j] = competing[j].saturating_add(1);
let candidate = Some(score.unknown_edge());
if candidate > best[j] {
best[j] = candidate;
from[j] = i;
is_dict[j] = false;
edge_freq[j] = 0;
}
}
DpTable {
from,
is_dict,
edge_freq,
competing,
}
}
/// Recovers the best path from the predecessor table produced by `forward_dp`.
///
/// Follows `from` back from the last node until node 0 is reached and returns
/// the visited node indices in ascending (start-to-end) order. `from` must be
/// non-empty, and following it from the last node must reach 0.
fn backtrack_path(from: &[usize]) -> Vec<usize> {
    let last = from.len() - 1;
    let mut nodes: Vec<usize> =
        core::iter::successors(Some(last), |&node| (node != 0).then(|| from[node])).collect();
    nodes.reverse();
    nodes
}
/// Confidence in `[0.0, 1.0]` for one segmented Thai token.
///
/// Unknown (non-dictionary) tokens get 0.0. Dictionary tokens start at 1.0,
/// or 0.7 if the word has no recorded frequency. That value is then scaled by
/// an ambiguity factor based on how many edges competed for the token's end
/// boundary: 1.0 for at most one, 0.9 for two, 0.8 for three and 0.7 beyond.
fn compute_confidence(is_dict: bool, freq: u32, competing: u8) -> f32 {
    if is_dict {
        let known_weight: f32 = if freq == 0 { 0.7 } else { 1.0 };
        let ambiguity_factor: f32 = match competing {
            0..=1 => 1.0,
            2 => 0.9,
            3 => 0.8,
            _ => 0.7,
        };
        known_weight * ambiguity_factor
    } else {
        0.0
    }
}
/// Segments one Thai pre-token, `text[span]`, into dictionary words and unknown
/// chunks, appending the tokens to `out`.
///
/// The run is split at TCC boundaries, scored with `forward_dp`, and the best
/// path is recovered with `backtrack_path`. Each edge on that path becomes one
/// token: `TokenKind::Thai` for dictionary edges, `TokenKind::Unknown`
/// otherwise. Byte spans and char spans are absolute offsets into `text`. If
/// `tcc_boundaries` yields fewer than two boundaries, nothing is appended.
///
/// NOTE(review): the starting char offset is recomputed by counting every char
/// of `text` before `span.start`. Input with many Thai runs therefore costs
/// O(len x runs). Passing the pre-token's char offset in from the caller would
/// avoid this.
fn segment_thai<'t>(
    dict: &Dict,
    freqs: &FreqMap,
    text: &'t str,
    span: core::ops::Range<usize>,
    out: &mut Vec<Token<'t>>,
) {
    let run = &text[span.start..span.end];
    let bounds = tcc_boundaries(run);
    if bounds.len() < 2 {
        return;
    }
    let table = forward_dp(dict, freqs, run, &bounds);
    let path = backtrack_path(&table.from);

    let mut chars_before = text[..span.start].chars().count();
    for pair in path.windows(2) {
        let (prev, node) = (pair[0], pair[1]);
        let byte_range = span.start + bounds[prev]..span.start + bounds[node];
        let word = &text[byte_range.clone()];
        let char_range = chars_before..chars_before + word.chars().count();
        chars_before = char_range.end;

        let is_dict = table.is_dict[node];
        let kind = if is_dict {
            TokenKind::Thai
        } else {
            TokenKind::Unknown
        };
        let confidence =
            compute_confidence(is_dict, table.edge_freq[node], table.competing[node]);
        out.push(Token::new(word, byte_range, char_range, kind, confidence));
    }
}
impl Default for Tokenizer {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Default)]
pub struct TokenizerBuilder {
dict_words: Option<alloc::string::String>,
dict_merge: Option<alloc::string::String>,
keep_whitespace: bool,
}
impl TokenizerBuilder {
pub fn dict_words(mut self, words: &str) -> Self {
self.dict_words = Some(alloc::string::String::from(words));
self
}
pub fn dict_merge(mut self, words: &str) -> Self {
self.dict_merge = Some(alloc::string::String::from(words));
self
}
pub fn keep_whitespace(mut self, keep: bool) -> Self {
self.keep_whitespace = keep;
self
}
pub fn build(self) -> Tokenizer {
let dict = if let Some(extra) = &self.dict_words {
let mut combined = alloc::string::String::from(BUILTIN_WORDS);
combined.push('\n');
combined.push_str(extra);
Dict::from_word_list(&combined)
} else if let Some(overlay) = &self.dict_merge {
builtin_dict().with_overlay(overlay)
} else {
builtin_dict()
};
Tokenizer {
dict,
freq: FreqMap::builtin(),
keep_whitespace: self.keep_whitespace,
}
}
#[cfg(feature = "std")]
pub fn dict_file(self, path: &str) -> Result<Self, KhamError> {
extern crate std;
let content = std::fs::read_to_string(path)
.map_err(|e| KhamError::DictLoadError(alloc::format!("{path}: {e}")))?;
Ok(self.dict_words(&content))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokenizer with the default configuration (built-in dictionary,
    /// whitespace dropped).
    fn tok() -> Tokenizer {
        Tokenizer::new()
    }

    #[test]
    fn empty_input() {
        assert!(tok().segment("").is_empty());
    }

    #[test]
    fn pure_latin_passthrough() {
        let tokens = tok().segment("hello");
        assert_eq!(tokens.len(), 1);
        assert_eq!(tokens[0].text, "hello");
        assert_eq!(tokens[0].kind, TokenKind::Latin);
    }

    #[test]
    fn pure_number_passthrough() {
        let tokens = tok().segment("12345");
        assert_eq!(tokens.len(), 1);
        assert_eq!(tokens[0].text, "12345");
        assert_eq!(tokens[0].kind, TokenKind::Number);
    }

    #[test]
    fn whitespace_dropped_by_default() {
        let tokens = tok().segment("กิน ข้าว");
        for t in &tokens {
            assert_ne!(t.kind, TokenKind::Whitespace);
        }
    }

    #[test]
    fn whitespace_kept_when_requested() {
        let tokens = Tokenizer::builder()
            .keep_whitespace(true)
            .build()
            .segment("กิน ข้าว");
        assert!(tokens.iter().any(|t| t.kind == TokenKind::Whitespace));
    }

    #[test]
    fn gin_khao_gap_pla() {
        let tokens = tok().segment("กินข้าวกับปลา");
        let words: Vec<&str> = tokens.iter().map(|t| t.text).collect();
        assert!(words.len() >= 2, "expected multiple words, got {words:?}");
        assert_eq!(words.join(""), "กินข้าวกับปลา");
    }

    #[test]
    fn mixed_thai_number_thai() {
        let tokens = tok().segment("ธนาคาร100แห่ง");
        let rebuilt: alloc::string::String = tokens.iter().map(|t| t.text).collect();
        assert_eq!(rebuilt, "ธนาคาร100แห่ง");
        let num = tokens.iter().find(|t| t.kind == TokenKind::Number);
        assert!(num.is_some());
        assert_eq!(num.unwrap().text, "100");
    }

    #[test]
    fn mixed_thai_latin() {
        let tokens = tok().segment("สวัสดี hello");
        let rebuilt: alloc::string::String = tokens.iter().map(|t| t.text).collect();
        assert_eq!(rebuilt, "สวัสดีhello");
        assert!(tokens
            .iter()
            .any(|t| t.kind == TokenKind::Latin && t.text == "hello"));
    }

    #[test]
    fn spans_cover_input_excluding_whitespace() {
        let text = "กินข้าว123hello";
        let tokens = tok().segment(text);
        for t in &tokens {
            assert_eq!(&text[t.span.clone()], t.text);
            assert!(text.is_char_boundary(t.span.start));
            assert!(text.is_char_boundary(t.span.end));
        }
    }

    #[test]
    fn adjacent_spans_are_contiguous() {
        let text = "กินข้าวกับปลา";
        let tokens = Tokenizer::builder()
            .keep_whitespace(true)
            .build()
            .segment(text);
        for w in tokens.windows(2) {
            assert_eq!(
                w[0].span.end, w[1].span.start,
                "gap between {:?} and {:?}",
                w[0], w[1]
            );
        }
    }

    #[test]
    fn no_empty_tokens() {
        let tokens = tok().segment("กินข้าวกับปลา 100 hello!");
        for t in &tokens {
            assert!(!t.text.is_empty());
        }
    }

    #[test]
    fn custom_dict_word_is_matched() {
        let tok = Tokenizer::builder().dict_words("กขคงจฉ\n").build();
        let tokens = tok.segment("กขคงจฉ");
        let thai: Vec<&str> = tokens
            .iter()
            .filter(|t| t.kind == TokenKind::Thai)
            .map(|t| t.text)
            .collect();
        assert!(thai.contains(&"กขคงจฉ"), "got: {thai:?}");
    }

    #[test]
    fn normalize_deduplicates_tone_before_segment() {
        let t = tok();
        // "กิน" followed by KHO KHAI (U+0E02), a doubled MAI THO (U+0E49), then "าว".
        let raw = "กิน\u{0E02}\u{0E49}\u{0E49}าว";
        let normalized = t.normalize(raw);
        let tokens = t.segment(&normalized);
        assert!(!tokens.is_empty());
        let rebuilt: alloc::string::String = tokens.iter().map(|t| t.text).collect();
        assert_eq!(rebuilt, normalized);
    }

    #[test]
    fn normalize_clean_input_is_identity() {
        let t = tok();
        let clean = "กินข้าวกับปลา";
        assert_eq!(t.normalize(clean), clean);
    }

    #[test]
    fn segment_without_normalize_on_clean_input() {
        let tokens = tok().segment("กินข้าวกับปลา");
        let rebuilt: alloc::string::String = tokens.iter().map(|t| t.text).collect();
        assert_eq!(rebuilt, "กินข้าวกับปลา");
    }

    #[test]
    fn dp_score_fewer_unknowns_is_primary() {
        let no_unknown = DpScore::ZERO;
        let one_unknown = DpScore::ZERO.unknown_edge();
        assert!(no_unknown > one_unknown);
    }

    #[test]
    fn dp_score_fewer_tokens_beats_more_dict_words() {
        // One compound dictionary word vs. the same span split into two words.
        let compound = DpScore::ZERO.dict_edge(0);
        let split = DpScore::ZERO.dict_edge(0).dict_edge(0);
        assert!(compound > split);
    }

    #[test]
    fn dp_score_higher_freq_breaks_token_tie() {
        let low_freq = DpScore::ZERO.dict_edge(10);
        let high_freq = DpScore::ZERO.dict_edge(100);
        assert!(high_freq > low_freq);
    }

    #[test]
    fn dp_score_fewer_tokens_beats_higher_freq() {
        let high_freq_more_tokens = DpScore {
            neg_unknowns: 0,
            neg_tokens: -2,
            dict_words: 1,
            freq_score: 200,
        };
        let low_freq_fewer_tokens = DpScore {
            neg_unknowns: 0,
            neg_tokens: -1,
            dict_words: 1,
            freq_score: 100,
        };
        assert!(low_freq_fewer_tokens > high_freq_more_tokens);
    }

    #[test]
    fn dp_score_more_dict_words_breaks_token_tie() {
        let fewer_dict = DpScore {
            neg_unknowns: 0,
            neg_tokens: -2,
            dict_words: 1,
            freq_score: 0,
        };
        let more_dict = DpScore {
            neg_unknowns: 0,
            neg_tokens: -2,
            dict_words: 2,
            freq_score: 0,
        };
        assert!(more_dict > fewer_dict);
    }

    #[test]
    fn dict_edge_accumulates_freq_score() {
        let after_one = DpScore::ZERO.dict_edge(50);
        let after_two = after_one.dict_edge(30);
        assert_eq!(after_one.freq_score, 50);
        assert_eq!(after_two.freq_score, 80);
    }

    #[test]
    fn dict_edge_adds_dict_word_and_token() {
        let s = DpScore::ZERO.dict_edge(0);
        assert_eq!(s.dict_words, 1);
        assert_eq!(s.neg_tokens, -1);
        assert_eq!(s.neg_unknowns, 0);
    }

    #[test]
    fn unknown_edge_adds_unknown_and_token() {
        let s = DpScore::ZERO.unknown_edge();
        assert_eq!(s.neg_unknowns, -1);
        assert_eq!(s.neg_tokens, -1);
        assert_eq!(s.dict_words, 0);
        assert_eq!(s.freq_score, 0);
    }

    #[test]
    fn unknown_edge_does_not_contribute_freq() {
        let s = DpScore::ZERO.unknown_edge().unknown_edge();
        assert_eq!(s.freq_score, 0);
    }

    #[test]
    fn char_span_len_equals_char_count() {
        let tokens = tok().segment("กินข้าวกับปลา");
        for t in &tokens {
            assert_eq!(
                t.char_span.end - t.char_span.start,
                t.text.chars().count(),
                "char_span length mismatch for {:?}",
                t.text
            );
        }
    }

    #[test]
    fn char_spans_are_contiguous() {
        let tokens = Tokenizer::builder()
            .keep_whitespace(true)
            .build()
            .segment("กินข้าว 100 hello");
        for w in tokens.windows(2) {
            assert_eq!(
                w[0].char_span.end, w[1].char_span.start,
                "char_span gap between {:?} and {:?}",
                w[0].text, w[1].text
            );
        }
    }

    #[test]
    fn char_span_for_mixed_script() {
        let tokens = tok().segment("ธนาคาร100แห่ง");
        assert_eq!(tokens[0].char_span, 0..6);
        assert_eq!(tokens[1].char_span, 6..9);
        assert_eq!(tokens[2].char_span, 9..13);
    }

    #[test]
    fn char_span_accounts_for_multibyte_chars() {
        let tokens = tok().segment("กิน");
        assert_eq!(tokens[0].span, 0..9);
        assert_eq!(tokens[0].char_span, 0..3);
    }

    #[test]
    fn char_span_emoji_is_single_char() {
        let tokens = tok().segment("😀");
        assert_eq!(tokens[0].char_len(), 1);
        assert_eq!(tokens[0].byte_len(), 4);
    }

    #[test]
    fn single_thai_char() {
        let tokens = tok().segment("ก");
        assert_eq!(tokens.len(), 1);
        assert_eq!(tokens[0].text, "ก");
    }

    #[test]
    fn sawasdee_chao_lok() {
        let tokens = tok().segment("สวัสดีชาวโลก");
        let rebuilt: alloc::string::String = tokens.iter().map(|t| t.text).collect();
        assert_eq!(rebuilt, "สวัสดีชาวโลก");
    }

    #[test]
    fn confidence_unknown_token_is_zero() {
        // Nonsense consonant string, so parts of it are likely to fall back to
        // Unknown; every such token must carry zero confidence.
        let tokens = tok().segment("กขคงจฉ");
        for u in tokens.iter().filter(|t| t.kind == TokenKind::Unknown) {
            assert_eq!(u.confidence, 0.0, "Unknown token must have confidence 0.0");
        }
    }

    #[test]
    fn confidence_dict_word_is_positive() {
        let tokens = tok().segment("กินข้าวกับปลา");
        for t in &tokens {
            if t.kind == TokenKind::Thai {
                assert!(
                    t.confidence > 0.0,
                    "dict Thai token {:?} must have confidence > 0",
                    t.text
                );
            }
        }
    }

    #[test]
    fn confidence_non_thai_tokens_are_1() {
        let tokens = tok().segment("hello 123 😀");
        for t in &tokens {
            assert_eq!(
                t.confidence, 1.0,
                "non-Thai token {:?} must have confidence 1.0",
                t.text
            );
        }
    }

    #[test]
    fn confidence_range_valid() {
        let texts = &["กินข้าวกับปลา", "สวัสดีครับ", "hello กรุงเทพ 2024 😀", "กขคง"];
        for text in texts {
            for t in tok().segment(text) {
                assert!(
                    (0.0..=1.0).contains(&t.confidence),
                    "token {:?} confidence {} out of range",
                    t.text,
                    t.confidence
                );
            }
        }
    }

    #[test]
    fn segment_stream_yields_same_as_segment() {
        let t = tok();
        let text = "กินข้าวกับปลา";
        let direct: alloc::vec::Vec<_> = t.segment(text);
        let streamed: alloc::vec::Vec<_> = t.segment_stream(text).collect();
        assert_eq!(direct.len(), streamed.len());
        for (a, b) in direct.iter().zip(streamed.iter()) {
            assert_eq!(a.text, b.text);
            assert_eq!(a.kind, b.kind);
            assert_eq!(a.span, b.span);
        }
    }

    #[test]
    fn next_word_skips_whitespace() {
        let t = Tokenizer::builder().keep_whitespace(true).build();
        let mut stream = t.segment_stream("กิน ข้าว ปลา");
        while let Some(word) = stream.next_word() {
            assert_ne!(
                word.kind,
                TokenKind::Whitespace,
                "next_word() must not return a whitespace token"
            );
        }
    }

    #[test]
    fn next_known_skips_unknown() {
        let t = tok();
        let mut stream = t.segment_stream("กขค");
        while let Some(known) = stream.next_known() {
            assert_ne!(
                known.kind,
                TokenKind::Unknown,
                "next_known() must not return an Unknown token"
            );
            assert_ne!(
                known.kind,
                TokenKind::Whitespace,
                "next_known() must not return a Whitespace token"
            );
        }
    }

    #[test]
    fn next_above_confidence_filters_low() {
        let t = tok();
        let text = "กินข้าวกับปลา";
        let threshold = 0.8_f32;
        let mut stream = t.segment_stream(text);
        while let Some(found) = stream.next_above_confidence(threshold) {
            assert!(
                found.confidence >= threshold,
                "next_above_confidence({threshold}) returned token {:?} with confidence {}",
                found.text,
                found.confidence
            );
        }
    }
}