use super::Tokenizer;
/// Tokenizer for text that already consists of numeric token ids
/// (e.g. `"[1, 2, 3]"`): encoding parses the numbers instead of looking
/// them up in a learned vocabulary, and decoding renders them back as text.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PretokenizedTokenizer {
    // Size of the id space; `new` clamps this to at least 1.
    vocab_size: usize,
    // Optional special-token ids; `None` disables the corresponding marker.
    bos: Option<u32>,
    eos: Option<u32>,
    pad: Option<u32>,
    unk: Option<u32>,
}
impl PretokenizedTokenizer {
    /// Creates a tokenizer over `vocab_size` ids with optional special
    /// tokens (begin/end-of-sequence, padding, unknown).
    ///
    /// The vocabulary size is clamped to a minimum of 1.
    pub fn new(
        vocab_size: usize,
        bos: Option<u32>,
        eos: Option<u32>,
        pad: Option<u32>,
        unk: Option<u32>,
    ) -> Self {
        let vocab_size = vocab_size.max(1);
        Self { vocab_size, bos, eos, pad, unk }
    }

    /// Parses a whitespace-trimmed chunk as a decimal `u32` id.
    /// Returns `None` for chunks that are empty or not valid numbers.
    fn parse_token(text: &str) -> Option<u32> {
        let candidate = text.trim();
        match candidate.is_empty() {
            true => None,
            false => candidate.parse().ok(),
        }
    }
}
impl Tokenizer for PretokenizedTokenizer {
    /// Splits `text` on whitespace, commas, and square brackets, parses
    /// every piece as a numeric id, and optionally wraps the sequence in
    /// BOS/EOS markers. Non-numeric pieces become the `unk` id when one is
    /// configured and are dropped otherwise.
    fn encode(&self, text: &str, add_bos: bool, add_eos: bool) -> Vec<u32> {
        let mut ids = Vec::new();
        if add_bos {
            // Extending from an Option pushes only when a BOS id exists.
            ids.extend(self.bos);
        }
        let is_separator = |ch: char| ch.is_whitespace() || matches!(ch, ',' | '[' | ']');
        for piece in text.split(is_separator) {
            if let Some(id) = Self::parse_token(piece) {
                ids.push(id);
            } else if !piece.trim().is_empty() {
                // Non-numeric text: substitute unk if configured, else skip.
                ids.extend(self.unk);
            }
        }
        if add_eos {
            ids.extend(self.eos);
        }
        ids
    }

    /// Decodes `ids`, stopping at the first EOS marker.
    fn decode(&self, ids: &[u32]) -> String {
        self.decode_with_options(ids, true)
    }

    /// Renders `ids` as space-separated decimal numbers. PAD and BOS ids
    /// are always skipped; an EOS id either terminates decoding
    /// (`stop_at_eos == true`) or is skipped like the other markers.
    fn decode_with_options(&self, ids: &[u32], stop_at_eos: bool) -> String {
        let mut parts: Vec<String> = Vec::with_capacity(ids.len());
        for &id in ids {
            let matches_slot = |slot: Option<u32>| slot == Some(id);
            // PAD/BOS are checked before EOS so their skip behavior wins
            // if the same id is configured for more than one role.
            if matches_slot(self.pad) || matches_slot(self.bos) {
                continue;
            }
            if matches_slot(self.eos) {
                if stop_at_eos {
                    break;
                }
                continue;
            }
            parts.push(id.to_string());
        }
        parts.join(" ")
    }

    /// Number of ids in the vocabulary.
    fn len(&self) -> usize {
        self.vocab_size
    }

    /// True only for a zero-sized vocabulary (unreachable via `new`,
    /// which clamps the size to at least 1).
    fn is_empty(&self) -> bool {
        self.vocab_size == 0
    }

    fn bos_id(&self) -> Option<u32> {
        self.bos
    }

    fn eos_id(&self) -> Option<u32> {
        self.eos
    }

    fn pad_id(&self) -> Option<u32> {
        self.pad
    }

    fn unk_id(&self) -> Option<u32> {
        self.unk
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding a bracketed id list wraps it with BOS/EOS, and decoding
    /// strips the markers back out.
    #[test]
    fn pretokenized_round_trip_preserves_numeric_ids() {
        let tok = PretokenizedTokenizer::new(16, Some(14), Some(15), Some(13), Some(12));
        let encoded = tok.encode("[1, 2, 3]", true, true);
        assert_eq!(encoded, [14, 1, 2, 3, 15]);
        assert_eq!(tok.decode(&encoded), "1 2 3");
    }

    /// Non-numeric chunks map to the configured unk id.
    #[test]
    fn pretokenized_unknown_falls_back_to_unk_when_available() {
        let tok = PretokenizedTokenizer::new(16, None, None, None, Some(7));
        assert_eq!(tok.encode("10 nope 11", false, false), [10, 7, 11]);
    }
}