use crate::{error::TokenizerResult, tokenizer::OxiTokenizer};
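/// A configurable pre-tokenization text normalizer.
///
/// Each flag enables one independent pass; see [`TextNormalizer::normalize`]
/// for the order in which the passes are applied.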
#[derive(Debug, Clone)]
pub struct TextNormalizer {
/// Lowercase ASCII letters.
pub lowercase: bool,
/// Remove Unicode combining marks (precomposed characters are not decomposed).
pub strip_accents: bool,
/// Trim leading and trailing whitespace.
pub strip_whitespace: bool,
/// Collapse each run of whitespace into a single space.
pub collapse_whitespace: bool,
}
impl Default for TextNormalizer {
fn default() -> Self {
Self::new()
}
}
impl TextNormalizer {
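/// Creates a normalizer with every pass disabled (the identity transform).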
pub fn new() -> Self {
Self {
lowercase: false,
strip_accents: false,
strip_whitespace: false,
collapse_whitespace: false,
}
}
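/// Creates a normalizer that only lowercases ASCII letters.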
pub fn lowercase_only() -> Self {
Self {
lowercase: true,
..Self::new()
}
}
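/// Creates a normalizer that collapses whitespace runs and trims the ends.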
pub fn whitespace_only() -> Self {
Self {
strip_whitespace: true,
collapse_whitespace: true,
..Self::new()
}
}
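/// Applies the enabled passes in order: lowercasing, accent stripping,
/// whitespace collapsing, then trimming.
///
/// Lowercasing affects only ASCII letters, and accent stripping removes
/// combining marks without decomposing precomposed characters (a single
/// code point such as `é` is left as-is).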
pub fn normalize(&self, text: &str) -> String {
let mut result: String = if self.lowercase {
text.chars()
.map(|c| {
if c.is_ascii_alphabetic() {
c.to_ascii_lowercase()
} else {
c
}
})
.collect()
} else {
text.to_owned()
};
if self.strip_accents {
result = result.chars().filter(|&c| !is_combining(c)).collect();
}
if self.collapse_whitespace {
let mut collapsed = String::with_capacity(result.len());
let mut prev_was_space = false;
for c in result.chars() {
if c.is_whitespace() {
if !prev_was_space {
collapsed.push(' ');
}
prev_was_space = true;
} else {
collapsed.push(c);
prev_was_space = false;
}
}
result = collapsed;
}
if self.strip_whitespace {
result = result.trim().to_owned();
}
result
}
}
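/// Returns true if `c` falls in one of the Unicode combining-mark blocks listed below.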
fn is_combining(c: char) -> bool {
let cp = c as u32;
matches!(
cp,
// Combining Diacritical Marks
0x0300..=0x036F
// Combining Diacritical Marks Extended
| 0x1AB0..=0x1AFF
// Combining Diacritical Marks Supplement
| 0x1DC0..=0x1DFF
// Combining Half Marks
| 0xFE20..=0xFE2F
)
}
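/// A chat prompt template rendered by the restricted Jinja-like engine in
/// [`render_template`]: a single `{% for message in messages %}` loop,
/// `{{ role }}` / `{{ content }}` variables, and `==` conditionals.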
#[derive(Debug, Clone)]
pub struct ChatTemplate {
template: String,
}
impl ChatTemplate {
pub fn new(template: &str) -> Self {
Self {
template: template.to_owned(),
}
}
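/// Built-in ChatML-style template using `<|im_start|>` / `<|im_end|>` markers.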
pub fn chatml() -> Self {
Self::new(
"{% for message in messages %}<|im_start|>{{ role }}\n{{ content }}<|im_end|>\n{% endfor %}",
)
}
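/// Built-in Llama 3 style template using header and `<|eot_id|>` tokens,
/// ending with an open assistant header.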
pub fn llama3() -> Self {
Self::new(concat!(
"<|begin_of_text|>",
"{% for message in messages %}",
"<|start_header_id|>{{ role }}<|end_header_id|>\n\n",
"{{ content }}<|eot_id|>",
"{% endfor %}",
"<|start_header_id|>assistant<|end_header_id|>\n\n",
))
}
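/// Renders the template over a slice of `(role, content)` message pairs.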
pub fn format(&self, messages: &[(&str, &str)]) -> String {
render_template(&self.template, messages)
}
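/// Extracts the content of the last `<|im_start|>user` ... `<|im_end|>`
/// block from a ChatML-formatted prompt, if present.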
pub fn extract_user_message(prompt: &str) -> Option<String> {
let mut search = prompt;
let marker = "<|im_start|>user\n";
let end_marker = "<|im_end|>";
let mut last_user: Option<String> = None;
while let Some(start_pos) = search.find(marker) {
let after_marker = &search[start_pos + marker.len()..];
if let Some(end_pos) = after_marker.find(end_marker) {
last_user = Some(after_marker[..end_pos].to_owned());
}
search = &search[start_pos + marker.len()..];
}
last_user
}
}
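/// Renders a restricted Jinja-like template. Literal text outside the
/// `{% for message in messages %}` ... `{% endfor %}` block is emitted once
/// (variables and tags outside the loop are ignored), and the loop body is
/// rendered for each `(role, content)` pair. If the template has no message
/// loop, the whole body is rendered against the first message only.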
pub(crate) fn render_template(template: &str, messages: &[(&str, &str)]) -> String {
let tokens = tokenize_template(template);
let mut output = String::new();
let for_start = tokens.iter().position(|t| {
matches!(t, TemplateToken::Tag(s) if s.trim().starts_with("for ") && s.contains("messages"))
});
let for_end = tokens
.iter()
.position(|t| matches!(t, TemplateToken::Tag(s) if s.trim() == "endfor"));
match (for_start, for_end) {
(Some(fs), Some(fe)) if fs < fe => {
for tok in &tokens[..fs] {
if let TemplateToken::Literal(lit) = tok {
output.push_str(lit);
}
}
let body_tokens = &tokens[fs + 1..fe];
for (role, content) in messages {
output.push_str(&render_body(body_tokens, role, content));
}
for tok in &tokens[fe + 1..] {
if let TemplateToken::Literal(lit) = tok {
output.push_str(lit);
}
}
}
_ => {
if let Some((role, content)) = messages.first() {
output.push_str(&render_body(&tokens, role, content));
}
}
}
output
}
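/// Renders a slice of template tokens against a single `(role, content)`
/// pair, substituting variables and evaluating `{% if %}` / `{% else %}` /
/// `{% endif %}` blocks (recursively for nested conditionals).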
fn render_body(tokens: &[TemplateToken], role: &str, content: &str) -> String {
let mut output = String::new();
let mut i = 0;
while i < tokens.len() {
match &tokens[i] {
TemplateToken::Literal(lit) => {
output.push_str(lit);
i += 1;
}
TemplateToken::Variable(var) => {
let val = resolve_variable(var.trim(), role, content);
output.push_str(&val);
i += 1;
}
TemplateToken::Tag(tag) => {
let tag_trimmed = tag.trim();
if tag_trimmed.starts_with("if ") {
let condition_met = evaluate_condition(tag_trimmed, role, content);
let (if_body, else_body, skip) = collect_if_bodies(&tokens[i + 1..]);
if condition_met {
output.push_str(&render_body(&if_body, role, content));
} else {
output.push_str(&render_body(&else_body, role, content));
}
i += 1 + skip;
} else {
i += 1;
}
}
}
}
output
}
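/// Given the tokens that follow an `{% if %}` tag, collects the if-branch
/// and else-branch bodies and returns how many tokens were consumed up to
/// and including the matching `{% endif %}`.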
fn collect_if_bodies(tokens: &[TemplateToken]) -> (Vec<TemplateToken>, Vec<TemplateToken>, usize) {
let mut if_body = Vec::new();
let mut else_body = Vec::new();
let mut in_else = false;
let mut depth = 1usize;
let mut consumed = 0;
for (idx, tok) in tokens.iter().enumerate() {
consumed = idx + 1;
match tok {
TemplateToken::Tag(tag) => {
let t = tag.trim();
if t.starts_with("if ") {
depth += 1;
if in_else {
else_body.push(tok.clone());
} else {
if_body.push(tok.clone());
}
} else if t == "endif" {
depth -= 1;
if depth == 0 {
break;
}
if in_else {
else_body.push(tok.clone());
} else {
if_body.push(tok.clone());
}
} else if t == "else" && depth == 1 {
in_else = true;
} else if in_else {
else_body.push(tok.clone());
} else {
if_body.push(tok.clone());
}
}
other => {
if in_else {
else_body.push(other.clone());
} else {
if_body.push(other.clone());
}
}
}
}
(if_body, else_body, consumed)
}
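/// Evaluates an `{% if lhs == "rhs" %}` style condition. Only string
/// equality against `role`/`content` is supported; anything else is false.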
fn evaluate_condition(tag: &str, role: &str, content: &str) -> bool {
let rest = tag.trim_start_matches("if ").trim();
if let Some((lhs, rhs)) = rest.split_once("==") {
let lhs = lhs.trim();
let rhs = rhs.trim().trim_matches('"').trim_matches('\'');
let lhs_val = resolve_variable(lhs, role, content);
return lhs_val == rhs;
}
false
}
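/// Maps a template variable name (`role`, `content`, or the
/// `message.`-prefixed forms) to the current message's value.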
fn resolve_variable(var: &str, role: &str, content: &str) -> String {
match var {
"role" | "message.role" => role.to_owned(),
"content" | "message.content" => content.to_owned(),
_ => String::new(),
}
}
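/// A lexed template fragment: literal text, a `{{ ... }}` variable, or a
/// `{% ... %}` tag.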
#[derive(Debug, Clone)]
enum TemplateToken {
Literal(String),
Variable(String),
Tag(String),
}
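/// Splits a template string into literal, variable, and tag tokens.
/// An unterminated `{{` or `{%` marker turns the rest of the input into a literal.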
fn tokenize_template(template: &str) -> Vec<TemplateToken> {
let mut tokens = Vec::new();
let mut rest = template;
while !rest.is_empty() {
let var_pos = rest.find("{{");
let tag_pos = rest.find("{%");
let next = match (var_pos, tag_pos) {
(None, None) => {
tokens.push(TemplateToken::Literal(rest.to_owned()));
break;
}
(Some(vp), None) => Some(('v', vp)),
(None, Some(tp)) => Some(('t', tp)),
(Some(vp), Some(tp)) => {
if vp <= tp {
Some(('v', vp))
} else {
Some(('t', tp))
}
}
};
match next {
None => break,
Some(('v', vp)) => {
if vp > 0 {
tokens.push(TemplateToken::Literal(rest[..vp].to_owned()));
}
let after_open = &rest[vp + 2..];
if let Some(close) = after_open.find("}}") {
tokens.push(TemplateToken::Variable(after_open[..close].to_owned()));
rest = &after_open[close + 2..];
} else {
tokens.push(TemplateToken::Literal(rest.to_owned()));
break;
}
}
Some(('t', tp)) => {
if tp > 0 {
tokens.push(TemplateToken::Literal(rest[..tp].to_owned()));
}
let after_open = &rest[tp + 2..];
if let Some(close) = after_open.find("%}") {
tokens.push(TemplateToken::Tag(after_open[..close].to_owned()));
rest = &after_open[close + 2..];
} else {
tokens.push(TemplateToken::Literal(rest.to_owned()));
break;
}
}
Some(_) => break,
}
}
tokens
}
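/// Which end of a sequence to drop tokens from when it exceeds the
/// configured maximum length.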
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TruncationSide {
Left,
Right,
}
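/// How sequences are padded: to a fixed length, or to the longest sequence
/// in the batch.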
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PaddingStrategy {
Fixed(usize),
Longest,
}
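/// Builder-style batch encoder that wraps an [`OxiTokenizer`] and applies
/// optional truncation and padding to each encoded sequence.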
pub struct BatchEncoder<'a> {
tokenizer: &'a OxiTokenizer,
/// Optional maximum number of tokens per sequence.
pub max_length: Option<usize>,
/// Truncation side; [`TruncationSide::Right`] is used when unset but `max_length` is exceeded.
pub truncation: Option<TruncationSide>,
/// Optional padding strategy applied after truncation.
pub padding: Option<PaddingStrategy>,
/// Token id used for padding positions (defaults to 3).
pub pad_token_id: u32,
}
impl<'a> BatchEncoder<'a> {
pub fn new(tokenizer: &'a OxiTokenizer) -> Self {
Self {
tokenizer,
max_length: None,
truncation: None,
padding: None,
pad_token_id: 3,
}
}
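/// Sets the maximum number of tokens per sequence.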
pub fn with_max_length(mut self, n: usize) -> Self {
self.max_length = Some(n);
self
}
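/// Sets the truncation side used when a sequence exceeds `max_length`.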
pub fn with_truncation(mut self, side: TruncationSide) -> Self {
self.truncation = Some(side);
self
}
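/// Sets the padding strategy for encoded batches.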
pub fn with_padding(mut self, strategy: PaddingStrategy) -> Self {
self.padding = Some(strategy);
self
}
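/// Encodes each text with the wrapped tokenizer, then applies truncation
/// (if `max_length` is set) and padding (if a strategy is set).
///
/// The attention mask is 1 for real tokens and 0 for padding; `lengths`
/// records the unpadded length of each sequence.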
pub fn encode_batch(&self, texts: &[&str]) -> TokenizerResult<BatchEncoding> {
if texts.is_empty() {
return Ok(BatchEncoding {
input_ids: Vec::new(),
attention_mask: Vec::new(),
lengths: Vec::new(),
});
}
let mut encoded: Vec<Vec<u32>> = texts
.iter()
.map(|t| self.tokenizer.encode(t))
.collect::<TokenizerResult<_>>()?;
if let Some(max) = self.max_length {
for seq in &mut encoded {
if seq.len() > max {
match self.truncation.unwrap_or(TruncationSide::Right) {
TruncationSide::Right => {
seq.truncate(max);
}
TruncationSide::Left => {
let excess = seq.len() - max;
seq.drain(..excess);
}
}
}
}
}
let lengths: Vec<usize> = encoded.iter().map(Vec::len).collect();
let target_len = match self.padding {
None => None,
Some(PaddingStrategy::Longest) => lengths.iter().copied().max(),
Some(PaddingStrategy::Fixed(n)) => Some(n),
};
let mut input_ids: Vec<Vec<u32>> = Vec::with_capacity(encoded.len());
let mut attention_mask: Vec<Vec<u32>> = Vec::with_capacity(encoded.len());
for (seq, &len) in encoded.iter().zip(lengths.iter()) {
match target_len {
None => {
let mask = vec![1u32; len];
input_ids.push(seq.clone());
attention_mask.push(mask);
}
Some(pad_to) => {
let mut ids = seq.clone();
let pad_count = pad_to.saturating_sub(len);
ids.extend(std::iter::repeat_n(self.pad_token_id, pad_count));
let mut mask = vec![1u32; len];
mask.extend(std::iter::repeat_n(0u32, pad_count));
input_ids.push(ids);
attention_mask.push(mask);
}
}
}
Ok(BatchEncoding {
input_ids,
attention_mask,
lengths,
})
}
}
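/// The output of [`BatchEncoder::encode_batch`]: token ids, attention
/// masks, and per-sequence unpadded lengths.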
pub struct BatchEncoding {
/// Token ids per sequence, padded when a padding strategy was set.
pub input_ids: Vec<Vec<u32>>,
/// Per-sequence masks: 1 for real tokens, 0 for padding.
pub attention_mask: Vec<Vec<u32>>,
/// Unpadded (post-truncation) length of each sequence.
pub lengths: Vec<usize>,
}
impl BatchEncoding {
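/// Number of sequences in the batch.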
pub fn batch_size(&self) -> usize {
self.input_ids.len()
}
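/// Length of the longest (possibly padded) sequence, or 0 for an empty batch.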
pub fn max_seq_len(&self) -> usize {
self.input_ids.iter().map(Vec::len).max().unwrap_or(0)
}
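/// Returns true if any sequence's unpadded length is shorter than the
/// longest stored sequence, i.e. padding positions are present.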
pub fn is_padded(&self) -> bool {
let max = self.max_seq_len();
self.lengths.iter().any(|&l| l < max)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::OxiTokenizer;
#[test]
fn test_text_normalizer_lowercase() {
let n = TextNormalizer::lowercase_only();
assert_eq!(n.normalize("Hello World"), "hello world");
assert_eq!(n.normalize("ABC123"), "abc123");
assert_eq!(n.normalize("already lower"), "already lower");
}
#[test]
fn test_text_normalizer_collapse_whitespace() {
let n = TextNormalizer::whitespace_only();
assert_eq!(n.normalize(" hello world "), "hello world");
assert_eq!(n.normalize("a b c"), "a b c");
assert_eq!(n.normalize("no extra"), "no extra");
}
#[test]
fn test_text_normalizer_strip_accents() {
let n = TextNormalizer {
strip_accents: true,
..TextNormalizer::new()
};
let input = "caf\u{0065}\u{0301}";
let result = n.normalize(input);
assert!(
!result.contains('\u{0301}'),
"combining accent should be removed"
);
}
#[test]
fn test_text_normalizer_combined() {
let n = TextNormalizer {
lowercase: true,
strip_whitespace: true,
collapse_whitespace: true,
..TextNormalizer::new()
};
assert_eq!(n.normalize(" HELLO WORLD "), "hello world");
}
#[test]
fn test_chat_template_chatml_format() {
let tmpl = ChatTemplate::chatml();
let messages = [("user", "Hello!")];
let out = tmpl.format(&messages);
assert!(
out.contains("<|im_start|>user"),
"should contain user start token"
);
assert!(out.contains("Hello!"), "should contain message content");
assert!(out.contains("<|im_end|>"), "should contain end token");
}
#[test]
fn test_chat_template_multi_turn() {
let tmpl = ChatTemplate::chatml();
let messages = [
("system", "You are helpful."),
("user", "What is 2+2?"),
("assistant", "4"),
];
let out = tmpl.format(&messages);
assert!(out.contains("<|im_start|>system"), "system role present");
assert!(out.contains("<|im_start|>user"), "user role present");
assert!(
out.contains("<|im_start|>assistant"),
"assistant role present"
);
assert!(out.contains("You are helpful."), "system content present");
assert!(out.contains("What is 2+2?"), "user content present");
assert!(out.contains('4'), "assistant content present");
}
#[test]
fn test_chat_template_extract_user_message() {
let tmpl = ChatTemplate::chatml();
let messages = [("user", "Find me a recipe.")];
let formatted = tmpl.format(&messages);
let extracted = ChatTemplate::extract_user_message(&formatted);
assert_eq!(extracted, Some("Find me a recipe.".to_owned()));
}
#[test]
fn test_chat_template_extract_user_message_multi_turn() {
let tmpl = ChatTemplate::chatml();
let messages = [
("user", "First question"),
("assistant", "First answer"),
("user", "Second question"),
];
let formatted = tmpl.format(&messages);
let extracted = ChatTemplate::extract_user_message(&formatted);
assert_eq!(extracted, Some("Second question".to_owned()));
}
fn make_tokenizer() -> OxiTokenizer {
OxiTokenizer::char_level_stub(256)
}
#[test]
fn test_batch_encoder_basic() {
let tok = make_tokenizer();
let enc = BatchEncoder::new(&tok);
let result = enc
.encode_batch(&["hello", "world"])
.expect("batch encode should succeed");
assert_eq!(result.batch_size(), 2);
assert!(!result.input_ids[0].is_empty());
assert!(!result.input_ids[1].is_empty());
}
#[test]
fn test_batch_encoder_truncation_right() {
let tok = make_tokenizer();
let enc = BatchEncoder::new(&tok)
.with_max_length(3)
.with_truncation(TruncationSide::Right);
let result = enc
.encode_batch(&["hello world"])
.expect("encode should succeed");
assert_eq!(result.lengths[0], 3, "should be truncated to 3 tokens");
assert_eq!(result.input_ids[0].len(), 3);
}
#[test]
fn test_batch_encoder_truncation_left() {
let tok = make_tokenizer();
let full = tok.encode("hello").expect("encode");
let full_len = full.len();
let enc = BatchEncoder::new(&tok)
.with_max_length(2)
.with_truncation(TruncationSide::Left);
let result = enc.encode_batch(&["hello"]).expect("encode should succeed");
if full_len >= 2 {
assert_eq!(result.input_ids[0], full[full_len - 2..]);
}
assert!(result.lengths[0] <= 2);
}
#[test]
fn test_batch_encoder_padding_fixed() {
let tok = make_tokenizer();
let enc = BatchEncoder::new(&tok).with_padding(PaddingStrategy::Fixed(10));
let result = enc
.encode_batch(&["hi", "hello"])
.expect("encode should succeed");
for ids in &result.input_ids {
assert_eq!(ids.len(), 10);
}
}
#[test]
fn test_batch_encoding_attention_mask() {
let tok = make_tokenizer();
let enc = BatchEncoder::new(&tok).with_padding(PaddingStrategy::Longest);
let result = enc
.encode_batch(&["hi", "hello world"])
.expect("encode should succeed");
let max_len = result.max_seq_len();
for (i, mask) in result.attention_mask.iter().enumerate() {
assert_eq!(mask.len(), max_len, "mask length matches padded seq len");
let real_len = result.lengths[i];
for &m in &mask[..real_len] {
assert_eq!(m, 1u32, "real token position should have mask=1");
}
for &m in &mask[real_len..] {
assert_eq!(m, 0u32, "padding position should have mask=0");
}
}
}
#[test]
fn test_batch_encoding_empty() {
let tok = make_tokenizer();
let enc = BatchEncoder::new(&tok);
let result = enc.encode_batch(&[]).expect("empty batch should succeed");
assert_eq!(result.batch_size(), 0);
assert_eq!(result.max_seq_len(), 0);
assert!(!result.is_padded());
}
}