use rten_tensor::Tensor;
use rten_tensor::prelude::*;
/// Greedy CTC decoder that maps model class indices to characters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Tokenizer {
    /// Characters emitted for non-blank classes.
    chars: Vec<char>,
    /// Class index treated as the CTC blank (never emitted as a character).
    blank_id: usize,
}
impl Default for Tokenizer {
fn default() -> Self {
Self::new()
}
}
impl Tokenizer {
    /// Creates a tokenizer over the lowercase alphanumeric charset `0-9a-z`.
    ///
    /// Class index 0 is the CTC blank. Class `i >= 1` maps to the `(i - 1)`-th
    /// character of the charset. A model output covering every character
    /// therefore has `charset_len() + 1` classes.
    #[must_use]
    pub fn new() -> Self {
        Self {
            chars: "0123456789abcdefghijklmnopqrstuvwxyz".chars().collect(),
            blank_id: 0,
        }
    }

    /// Number of printable characters (excludes the blank class).
    #[must_use]
    pub const fn charset_len(&self) -> usize {
        self.chars.len()
    }

    /// Greedy CTC decoding of a score tensor (logits or probabilities).
    ///
    /// Accepted shapes are `[seq_len, num_classes]` and
    /// `[batch, seq_len, num_classes]`. For 3-D input only the first batch item
    /// is decoded. Any other rank, or an input with no time steps or no
    /// classes, yields an empty string.
    ///
    /// For each time step the highest-scoring class is selected. NaN scores are
    /// ignored, and ties resolve to the lowest class index, so a row with no
    /// usable signal decodes to blank. Consecutive duplicates are then
    /// collapsed and blanks removed. A blank between two identical classes
    /// therefore keeps both characters. Class indices beyond the charset
    /// produce no character.
    #[must_use]
    pub fn decode_rten(&self, logits: &Tensor<f32>) -> String {
        let shape = logits.shape();
        let (seq_len, num_classes) = match shape.len() {
            3 => (shape[1], shape[2]),
            2 => (shape[0], shape[1]),
            _ => return String::new(),
        };
        // `chunks_exact` panics on a zero chunk size. With no steps or no
        // classes there is nothing to decode anyway.
        if seq_len == 0 || num_classes == 0 {
            return String::new();
        }
        // Elements belonging to the first (or only) batch item, in logical order.
        let needed = seq_len * num_classes;

        let owned;
        let data: &[f32] = if let Some(slice) = logits.data() {
            slice
        } else {
            // Non-contiguous layout: materialize only the elements we read.
            owned = logits.iter().take(needed).copied().collect::<Vec<_>>();
            &owned
        };
        // Clamp so an empty batch dimension (e.g. `[0, T, C]`) decodes to "".
        let data = &data[..needed.min(data.len())];

        let mut result = String::new();
        let mut prev = self.blank_id;
        for row in data.chunks_exact(num_classes) {
            let token = Self::argmax(row).unwrap_or(self.blank_id);
            if token != self.blank_id && token != prev {
                // `blank_id` is 0, so class `token` maps to `chars[token - 1]`.
                if let Some(&c) = self.chars.get(token.saturating_sub(1)) {
                    result.push(c);
                }
            }
            prev = token;
        }
        result
    }

    /// Alias for [`Self::decode_rten`].
    #[must_use]
    pub fn decode(&self, logits: &Tensor<f32>) -> String {
        self.decode_rten(logits)
    }

    /// Index of the largest non-NaN value in `row`, preferring the lowest index
    /// on ties. Returns `None` if `row` is empty or entirely NaN.
    fn argmax(row: &[f32]) -> Option<usize> {
        row.iter()
            .enumerate()
            .filter(|(_, v)| !v.is_nan())
            .fold(None, |best: Option<(usize, f32)>, (i, &v)| match best {
                // Keep the earlier index unless strictly beaten.
                Some((_, best_val)) if best_val >= v => best,
                _ => Some((i, v)),
            })
            .map(|(i, _)| i)
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Classes emitted by a model for the default tokenizer: blank + 36 characters.
    const NUM_CLASSES: usize = 37;

    /// Builds a `[classes.len(), NUM_CLASSES]` tensor whose row `t` is one-hot at `classes[t]`.
    fn one_hot(classes: &[usize]) -> Tensor<f32> {
        let mut data = vec![0.0f32; classes.len() * NUM_CLASSES];
        for (t, &class) in classes.iter().enumerate() {
            data[t * NUM_CLASSES + class] = 1.0;
        }
        let shape = [classes.len(), NUM_CLASSES];
        Tensor::from_data(&shape, data)
    }

    #[test]
    fn test_tokenizer_charset_length() {
        let tokenizer = Tokenizer::new();
        assert_eq!(
            tokenizer.charset_len(),
            36,
            "Charset should have 36 alphanumeric characters (0-9, a-z)"
        );
    }

    #[test]
    fn test_tokenizer_default() {
        assert_eq!(Tokenizer::default().charset_len(), 36);
    }

    #[test]
    fn test_decode_repeated_chars() {
        // a, a, b -> repeats collapse -> "ab"
        let probs = one_hot(&[11, 11, 12]);
        assert_eq!(Tokenizer::new().decode_rten(&probs), "ab");
    }

    #[test]
    fn test_decode_with_blanks() {
        // a, blank, a -> the blank separates the repeats -> "aa"
        let probs = one_hot(&[11, 0, 11]);
        assert_eq!(Tokenizer::new().decode_rten(&probs), "aa");
    }

    #[test]
    fn test_decode_empty() {
        let probs = one_hot(&[]);
        assert_eq!(Tokenizer::new().decode_rten(&probs), "");
    }

    #[test]
    fn test_decode_all_blanks() {
        let probs = one_hot(&[0; 5]);
        assert_eq!(Tokenizer::new().decode_rten(&probs), "");
    }

    #[test]
    fn test_decode_complex_pattern() {
        let probs = one_hot(&[11, 11, 0, 12, 12, 12, 0, 13]);
        assert_eq!(Tokenizer::new().decode_rten(&probs), "abc");
    }

    #[test]
    fn test_decode_3d_uses_first_batch() {
        // Shape [batch=2, seq=2, classes]; batch 0 spells "ab", batch 1 "cd".
        let mut data = vec![0.0f32; 2 * 2 * NUM_CLASSES];
        for (row, &class) in [11usize, 12, 13, 14].iter().enumerate() {
            data[row * NUM_CLASSES + class] = 1.0;
        }
        let shape = [2, 2, NUM_CLASSES];
        let probs = Tensor::from_data(&shape, data);
        assert_eq!(Tokenizer::new().decode_rten(&probs), "ab");
    }

    #[test]
    fn test_decode_unsupported_rank() {
        let shape = [NUM_CLASSES];
        let probs = Tensor::from_data(&shape, vec![1.0f32; NUM_CLASSES]);
        assert_eq!(Tokenizer::new().decode_rten(&probs), "");
    }

    #[test]
    fn test_decode_out_of_range_class_is_skipped() {
        // One extra class (index 37) beyond the charset produces no character.
        let classes = NUM_CLASSES + 1;
        let mut data = vec![0.0f32; 2 * classes];
        data[37] = 1.0;
        data[classes + 11] = 1.0;
        let shape = [2, classes];
        let probs = Tensor::from_data(&shape, data);
        assert_eq!(Tokenizer::new().decode_rten(&probs), "a");
    }

    #[test]
    fn test_decode_alias_matches_decode_rten() {
        let tokenizer = Tokenizer::new();
        let probs = one_hot(&[11, 0, 12, 13, 13]);
        assert_eq!(tokenizer.decode(&probs), tokenizer.decode_rten(&probs));
    }
}