#[cfg(feature = "tokenizer-config")]
use serde_json::Value as JsonValue;
use std::{collections::HashMap, fs, path::Path};
use smol_str::format_smolstr;
use crate::error::{
ArithmeticOverflowPayload, EmptyInputPayload, Error, FileIoPayload, FileOp, MalformedDataPayload,
MissingFieldPayload, ParsePayload, Result, UnknownEnumValuePayload,
};
/// Kind of a vocabulary piece, using the ordinals of SentencePiece's
/// `ModelProto.SentencePiece.Type`.
///
/// Ordinals this crate does not recognise are kept in
/// [`SentencePiecePieceType::UnknownOrdinal`] rather than rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, derive_more::Display, derive_more::IsVariant)]
#[display("{}", self.as_str())]
#[non_exhaustive]
pub enum SentencePiecePieceType {
    /// Ordinary piece (ordinal 1).
    Normal,
    /// The unknown-token piece (ordinal 2).
    Unknown,
    /// Control symbol such as `<s>` (ordinal 3).
    Control,
    /// User-defined symbol (ordinal 4).
    UserDefined,
    /// Unused piece (ordinal 5).
    Unused,
    /// Byte-fallback piece of the form `<0xNN>` (ordinal 6).
    Byte,
    /// Any other ordinal read from a model, stored verbatim.
    UnknownOrdinal(i32),
}
impl SentencePiecePieceType {
    /// Lower-case name of the piece type; this is also its `Display` output.
    ///
    /// `UnknownOrdinal(_)` intentionally shares the name `"unknown"` with `Unknown`.
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Normal => "normal",
            Self::Unknown | Self::UnknownOrdinal(_) => "unknown",
            Self::Control => "control",
            Self::UserDefined => "user_defined",
            Self::Unused => "unused",
            Self::Byte => "byte",
        }
    }
    /// Protobuf ordinal of this piece type; `UnknownOrdinal` yields its stored value.
    pub fn as_raw(self) -> i32 {
        match self {
            Self::UnknownOrdinal(ordinal) => ordinal,
            Self::Normal => 1,
            Self::Unknown => 2,
            Self::Control => 3,
            Self::UserDefined => 4,
            Self::Unused => 5,
            Self::Byte => 6,
        }
    }
    /// Decode a wire ordinal into a piece type.
    ///
    /// Ordinals outside `1..=6` become `UnknownOrdinal`. The `as i32` cast keeps the low
    /// 32 bits, which recovers negative `int32` values that protobuf sign-extends to 64 bits.
    fn from_raw(raw: u64) -> Self {
        match raw {
            1 => Self::Normal,
            2 => Self::Unknown,
            3 => Self::Control,
            4 => Self::UserDefined,
            5 => Self::Unused,
            6 => Self::Byte,
            other => Self::UnknownOrdinal(other as i32),
        }
    }
}
/// Segmentation algorithm recorded in a model's `trainer_spec.model_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, derive_more::Display, derive_more::IsVariant)]
#[display("{}", self.as_str())]
pub enum SentencePieceModelType {
    /// Unigram language model (ordinal 1); also the default when the model does not say.
    Unigram,
    /// Byte-pair encoding (ordinal 2).
    Bpe,
}
impl SentencePieceModelType {
    /// Lower-case name of the model type; this is also its `Display` output.
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Unigram => "unigram",
            Self::Bpe => "bpe",
        }
    }
    /// Map a `trainer_spec.model_type` ordinal; unsupported ordinals yield `None`.
    fn from_raw(raw: u64) -> Option<Self> {
        let model = match raw {
            1 => Self::Unigram,
            2 => Self::Bpe,
            _ => return None,
        };
        Some(model)
    }
}
/// One vocabulary entry: its surface text, model score and piece type.
#[derive(Debug, Clone)]
pub struct SentencePieceToken {
    token: String,
    score: f32,
    piece_type: SentencePiecePieceType,
}
impl SentencePieceToken {
    /// Create a vocabulary entry from its text, score and type.
    pub fn new(token: impl Into<String>, score: f32, piece_type: SentencePiecePieceType) -> Self {
        let token: String = token.into();
        Self {
            token,
            score,
            piece_type,
        }
    }
    /// Surface text of the piece (spaces appear as `U+2581`).
    #[inline]
    pub fn token(&self) -> &str {
        self.token.as_str()
    }
    /// Model score of the piece; higher is preferred by the segmenters.
    #[inline]
    pub fn score(&self) -> f32 {
        self.score
    }
    /// Kind of the piece.
    #[inline]
    pub fn piece_type(&self) -> SentencePiecePieceType {
        self.piece_type
    }
}
struct SentencePieceProtobufReader<'a> {
data: &'a [u8],
index: usize,
}
impl<'a> SentencePieceProtobufReader<'a> {
fn new(data: &'a [u8]) -> Self {
Self { data, index: 0 }
}
fn is_at_end(&self) -> bool {
self.index >= self.data.len()
}
fn read_varint(&mut self) -> Result<u64> {
let mut value: u64 = 0;
let mut shift: u32 = 0;
while self.index < self.data.len() && shift < 64 {
let byte = self.data[self.index];
self.index += 1;
value |= u64::from(byte & 0x7f) << shift;
if byte & 0x80 == 0 {
return Ok(value);
}
shift += 7;
}
Err(Error::MalformedData(MalformedDataPayload::new(
"SentencePiece protobuf",
"malformed varint",
)))
}
fn read_length_delimited(&mut self) -> Result<&'a [u8]> {
let length = self.read_varint()? as usize;
let end = self
.index
.checked_add(length)
.ok_or(Error::ArithmeticOverflow(ArithmeticOverflowPayload::new(
"SentencePiece protobuf: length-delimited field",
"usize",
)))?;
if end > self.data.len() {
return Err(Error::MalformedData(MalformedDataPayload::new(
"SentencePiece protobuf",
"truncated length-delimited field",
)));
}
let slice = &self.data[self.index..end];
self.index = end;
Ok(slice)
}
fn read_fixed32(&mut self) -> Result<u32> {
let end = self.index.checked_add(4).ok_or(Error::ArithmeticOverflow(
ArithmeticOverflowPayload::new("SentencePiece protobuf: fixed32 offset", "usize"),
))?;
if end > self.data.len() {
return Err(Error::MalformedData(MalformedDataPayload::new(
"SentencePiece protobuf",
"truncated fixed32 field",
)));
}
let slice = &self.data[self.index..end];
self.index = end;
let mut value: u32 = 0;
for (i, &b) in slice.iter().enumerate() {
value |= u32::from(b) << (i * 8);
}
Ok(value)
}
fn skip_field(&mut self, wire_type: u64) -> Result<()> {
match wire_type {
0 => {
let _ = self.read_varint()?;
}
1 => {
let end = self.index.checked_add(8).ok_or(Error::ArithmeticOverflow(
ArithmeticOverflowPayload::new("SentencePiece protobuf: fixed64 offset", "usize"),
))?;
if end > self.data.len() {
return Err(Error::MalformedData(MalformedDataPayload::new(
"SentencePiece protobuf",
"truncated fixed64 field",
)));
}
self.index = end;
}
2 => {
let _ = self.read_length_delimited()?;
}
5 => {
let _ = self.read_fixed32()?;
}
other => {
return Err(Error::UnknownEnumValue(UnknownEnumValuePayload::new(
"SentencePiece protobuf: wire type",
format_smolstr!("{other}"),
&[
"0 (varint)",
"1 (fixed64)",
"2 (length-delimited)",
"5 (fixed32)",
],
)));
}
}
Ok(())
}
}
/// The parts of a serialized SentencePiece `ModelProto` the tokenizer needs.
struct ParsedModel {
    /// Vocabulary in file order; a piece's index is its token id.
    pieces: Vec<SentencePieceToken>,
    /// First `Unknown`-typed piece, else the first piece spelled `<unk>`, else 0.
    unknown_token_id: usize,
    /// Model type from `trainer_spec`, defaulting to unigram.
    model_type: SentencePieceModelType,
}
/// Parse a serialized `ModelProto`: repeated field 1 holds pieces, field 2 the trainer spec.
///
/// All other top-level fields are skipped.
///
/// # Errors
/// - Wire-format errors from the reader.
/// - `EmptyInput` when no piece carries a token text.
fn parse_pieces(data: &[u8]) -> Result<ParsedModel> {
    let mut reader = SentencePieceProtobufReader::new(data);
    let mut pieces = Vec::new();
    let mut first_unknown = None;
    let mut model_type = SentencePieceModelType::Unigram;
    while !reader.is_at_end() {
        let key = reader.read_varint()?;
        match (key >> 3, key & 0x7) {
            (1, 2) => {
                let piece_bytes = reader.read_length_delimited()?;
                // Pieces without a text field are dropped and do not consume an id.
                let Some(piece) = parse_piece(piece_bytes)? else {
                    continue;
                };
                if first_unknown.is_none() && piece.piece_type() == SentencePiecePieceType::Unknown
                {
                    first_unknown = Some(pieces.len());
                }
                pieces.push(piece);
            }
            (2, 2) => {
                let spec_bytes = reader.read_length_delimited()?;
                if let Some(kind) = parse_trainer_spec_model_type(spec_bytes)? {
                    model_type = kind;
                }
            }
            (_, wire_type) => reader.skip_field(wire_type)?,
        }
    }
    if pieces.is_empty() {
        return Err(Error::EmptyInput(EmptyInputPayload::new(
            "SentencePiece model: vocabulary pieces",
        )));
    }
    let unknown_token_id = match first_unknown {
        Some(id) => id,
        None => pieces
            .iter()
            .position(|p| p.token() == "<unk>")
            .unwrap_or(0),
    };
    Ok(ParsedModel {
        pieces,
        unknown_token_id,
        model_type,
    })
}
/// Parse one `ModelProto.SentencePiece` message.
///
/// Fields read: 1 = piece text (invalid UTF-8 is replaced lossily), 2 = score (float),
/// 3 = type (varint). Missing score defaults to 0.0 and missing type to `Normal`.
/// Returns `Ok(None)` when the message has no text field.
///
/// # Errors
/// Propagates wire-format errors from the reader.
fn parse_piece(data: &[u8]) -> Result<Option<SentencePieceToken>> {
    let mut reader = SentencePieceProtobufReader::new(data);
    let mut text = None;
    let mut score = 0.0_f32;
    let mut piece_type = SentencePiecePieceType::Normal;
    while !reader.is_at_end() {
        let key = reader.read_varint()?;
        let wire_type = key & 0x7;
        match key >> 3 {
            1 if wire_type == 2 => {
                let raw = reader.read_length_delimited()?;
                text = Some(String::from_utf8_lossy(raw).into_owned());
            }
            2 if wire_type == 5 => score = f32::from_bits(reader.read_fixed32()?),
            3 if wire_type == 0 => {
                piece_type = SentencePiecePieceType::from_raw(reader.read_varint()?)
            }
            _ => reader.skip_field(wire_type)?,
        }
    }
    Ok(text.map(|t: String| SentencePieceToken::new(t, score, piece_type)))
}
/// Extract `model_type` (field 3, varint) from a serialized `TrainerSpec` message.
///
/// Protobuf semantics say the last occurrence of a scalar field wins, so the whole message
/// is scanned rather than stopping at the first match.
///
/// Returns `Ok(None)` when the field is absent or its final value is neither unigram (1)
/// nor BPE (2).
///
/// # Errors
/// Propagates wire-format errors from the reader.
fn parse_trainer_spec_model_type(data: &[u8]) -> Result<Option<SentencePieceModelType>> {
    let mut reader = SentencePieceProtobufReader::new(data);
    let mut last_raw: Option<u64> = None;
    while !reader.is_at_end() {
        let key = reader.read_varint()?;
        let field_number = key >> 3;
        let wire_type = key & 0x7;
        if field_number == 3 && wire_type == 0 {
            last_raw = Some(reader.read_varint()?);
        } else {
            reader.skip_field(wire_type)?;
        }
    }
    Ok(last_raw.and_then(SentencePieceModelType::from_raw))
}
/// A node of the unigram segmentation lattice.
#[derive(Debug, Clone)]
struct TokenLatticeNode {
    /// Vocabulary id this node emits.
    token_id: usize,
    /// Index, in characters, of the first character covered.
    char_start: usize,
    /// Number of characters covered; 0 for the BOS/EOS sentinels.
    char_len: usize,
    /// Score of this piece on its own.
    score: f32,
    /// Index into `TokenLattice::nodes` of the best predecessor chosen by Viterbi.
    prev: Option<usize>,
    /// Best cumulative score of a path from BOS that ends with this node.
    backtrace_score: f32,
}
/// Segmentation lattice over the characters of one sentence.
///
/// `nodes[0]` is the BOS sentinel and `nodes[1]` the EOS sentinel. Candidate pieces are added
/// with `insert`, and the best segmentation is recovered with `viterbi`.
struct TokenLattice {
    /// The sentence split into characters; node offsets index into this.
    chars: Vec<char>,
    // Stored alongside the sentinels' own `token_id`s but never read back.
    #[allow(dead_code)]
    bos_token_id: usize,
    #[allow(dead_code)]
    eos_token_id: usize,
    /// Every node, sentinels first.
    nodes: Vec<TokenLatticeNode>,
    /// `begin_nodes[i]`: indices of nodes whose first character is `i`.
    begin_nodes: Vec<Vec<usize>>,
    /// `end_nodes[i]`: indices of nodes that end just before character `i`.
    end_nodes: Vec<Vec<usize>>,
}
impl TokenLattice {
    /// Index of the BOS sentinel in `nodes`.
    const BOS_NODE: usize = 0;
    /// Index of the EOS sentinel in `nodes`.
    const EOS_NODE: usize = 1;

    /// Create a lattice for `sentence` containing only the BOS and EOS sentinels.
    fn new(sentence: &str, bos_token_id: usize, eos_token_id: usize) -> Self {
        let chars: Vec<char> = sentence.chars().collect();
        let n = chars.len();
        let sentinel = |token_id: usize, char_start: usize| TokenLatticeNode {
            token_id,
            char_start,
            char_len: 0,
            score: 0.0,
            prev: None,
            backtrace_score: 0.0,
        };
        let mut nodes = Vec::with_capacity(n + 2);
        nodes.push(sentinel(bos_token_id, 0));
        nodes.push(sentinel(eos_token_id, n));
        let mut begin_nodes: Vec<Vec<usize>> = vec![Vec::new(); n + 1];
        let mut end_nodes: Vec<Vec<usize>> = vec![Vec::new(); n + 1];
        end_nodes[0].push(Self::BOS_NODE);
        begin_nodes[n].push(Self::EOS_NODE);
        Self {
            chars,
            bos_token_id,
            eos_token_id,
            nodes,
            begin_nodes,
            end_nodes,
        }
    }
    /// Number of characters in the sentence.
    fn char_count(&self) -> usize {
        self.chars.len()
    }
    /// Add a candidate piece covering `char_len` characters starting at `char_start`.
    ///
    /// # Panics
    /// Panics if `char_start + char_len` exceeds the sentence length.
    fn insert(&mut self, char_start: usize, char_len: usize, score: f32, token_id: usize) {
        let idx = self.nodes.len();
        self.nodes.push(TokenLatticeNode {
            token_id,
            char_start,
            char_len,
            score,
            prev: None,
            backtrace_score: 0.0,
        });
        self.begin_nodes[char_start].push(idx);
        self.end_nodes[char_start + char_len].push(idx);
    }
    /// Run Viterbi and return the highest-scoring path from BOS to EOS.
    ///
    /// The BOS and EOS sentinels are excluded from the result. Ties keep the earliest
    /// inserted predecessor.
    ///
    /// Returns an empty vector when some character position has no node starting there
    /// (the lattice is disconnected) or when EOS ends up without a predecessor.
    fn viterbi(&mut self) -> Vec<TokenLatticeNode> {
        let count = self.char_count();
        for offset in 0..=count {
            if self.begin_nodes[offset].is_empty() {
                return Vec::new();
            }
            // `begin_nodes`, `end_nodes` and `nodes` are disjoint fields, so they can be
            // borrowed side by side without cloning the index lists.
            for &rnode_idx in &self.begin_nodes[offset] {
                let rnode_score = self.nodes[rnode_idx].score;
                let mut best: Option<(usize, f32)> = None;
                for &lnode_idx in &self.end_nodes[offset] {
                    let candidate = self.nodes[lnode_idx].backtrace_score + rnode_score;
                    let improves = match best {
                        None => true,
                        Some((_, best_score)) => candidate > best_score,
                    };
                    if improves {
                        best = Some((lnode_idx, candidate));
                    }
                }
                let rnode = &mut self.nodes[rnode_idx];
                rnode.prev = None;
                if let Some((lnode_idx, score)) = best {
                    rnode.prev = Some(lnode_idx);
                    rnode.backtrace_score = score;
                }
            }
        }
        // Walk back from EOS, stopping at BOS so neither sentinel leaks into the path.
        let mut path = Vec::new();
        let mut cursor = self.nodes[Self::EOS_NODE].prev;
        while let Some(idx) = cursor {
            if idx == Self::BOS_NODE {
                break;
            }
            let node = &self.nodes[idx];
            cursor = node.prev;
            path.push(node.clone());
        }
        path.reverse();
        path
    }
    /// Text of the characters covered by `node`.
    fn piece(&self, node: &TokenLatticeNode) -> String {
        let end = node.char_start + node.char_len;
        self.chars[node.char_start..end].iter().collect()
    }
}
/// A node of a character trie.
#[derive(Debug, Default)]
struct TrieNode {
    /// Children keyed by the next character.
    children: HashMap<char, TrieNode>,
    /// Whether the path from the root to this node spells an inserted token.
    is_end: bool,
}
/// Character trie over piece texts, used to list every piece that prefixes a position.
#[derive(Debug, Default)]
struct Trie {
    root: TrieNode,
}
impl Trie {
    /// Insert every token yielded by `tokens`.
    fn append_all<I, S>(&mut self, tokens: I)
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        tokens
            .into_iter()
            .for_each(|token| self.insert(token.as_ref()));
    }
    /// Insert one token, creating intermediate nodes as needed.
    fn insert(&mut self, token: &str) {
        let leaf = token
            .chars()
            .fold(&mut self.root, |node, ch| node.children.entry(ch).or_default());
        leaf.is_end = true;
    }
    /// All inserted tokens that are prefixes of `chars`, shortest first.
    ///
    /// The empty string is never reported, even if it was inserted.
    fn common_prefix_search(&self, chars: &[char]) -> Vec<String> {
        let mut found: Vec<String> = Vec::new();
        let mut prefix = String::new();
        let mut cursor = &self.root;
        for ch in chars.iter().copied() {
            let Some(child) = cursor.children.get(&ch) else {
                break;
            };
            prefix.push(ch);
            if child.is_end {
                found.push(prefix.clone());
            }
            cursor = child;
        }
        found
    }
}
/// SentencePiece-compatible tokenizer for unigram and BPE models with byte fallback.
///
/// Build one from a serialized SentencePiece `ModelProto` with
/// [`SentencePieceTokenizer::from_model_bytes`] or [`SentencePieceTokenizer::from_model_file`].
/// With the `tokenizer-config` feature it can also be built from a Hugging Face
/// `tokenizer.json`.
#[derive(Debug)]
pub struct SentencePieceTokenizer {
    /// Vocabulary; the index of a piece is its token id.
    vocab: Vec<SentencePieceToken>,
    /// Id emitted for bytes with no byte-fallback piece; always a valid index into `vocab`.
    unknown_token_id: usize,
    /// Lattice score of synthesized unknown nodes: the lowest vocabulary score minus 10.
    unknown_token_score: f32,
    /// Segmentation algorithm used by [`SentencePieceTokenizer::encode_with_byte_fallback`].
    model_type: SentencePieceModelType,
    /// Piece text to id; when a text appears twice, the later id wins.
    tokens_to_ids: HashMap<String, usize>,
    /// Character trie over every piece text, used to enumerate unigram candidates.
    trie: Trie,
    /// Byte value to the id of its `<0xNN>` byte-fallback piece, when the vocabulary has one.
    byte_map: [Option<usize>; 256],
    /// `UserDefined` piece texts, longest first, that BPE never splits.
    bpe_atomic_pieces: Vec<String>,
}
impl SentencePieceTokenizer {
    /// Build a tokenizer from a vocabulary and precompute its lookup tables.
    ///
    /// `unknown_token_id` must index into `vocab`; every public constructor guarantees this.
    fn new(
        vocab: Vec<SentencePieceToken>,
        unknown_token_id: usize,
        model_type: SentencePieceModelType,
    ) -> Self {
        // Unknown nodes score 10 below the worst real piece, so segmentations that
        // route through them are heavily penalised (same penalty SentencePiece uses).
        let min_score = vocab
            .iter()
            .map(|t| t.score())
            .fold(f32::INFINITY, f32::min);
        let unknown_token_score = min_score - 10.0;
        let mut tokens_to_ids: HashMap<String, usize> = HashMap::with_capacity(vocab.len());
        for (i, tok) in vocab.iter().enumerate() {
            tokens_to_ids.insert(tok.token().to_owned(), i);
        }
        let mut trie = Trie::default();
        trie.append_all(vocab.iter().map(|t| t.token()));
        let mut byte_map: [Option<usize>; 256] = [None; 256];
        for (i, tok) in vocab.iter().enumerate() {
            if let Some(byte) = parse_byte_fallback_piece(tok.token()) {
                byte_map[byte as usize] = Some(i);
            }
        }
        let mut bpe_atomic_pieces: Vec<String> = vocab
            .iter()
            .filter(|t| t.piece_type() == SentencePiecePieceType::UserDefined)
            .map(|t| t.token().to_owned())
            .collect();
        // Longest first, so `initial_bpe_symbols` matches the longest atomic piece.
        bpe_atomic_pieces.sort_by_key(|piece| std::cmp::Reverse(piece.chars().count()));
        Self {
            vocab,
            unknown_token_id,
            unknown_token_score,
            model_type,
            tokens_to_ids,
            trie,
            byte_map,
            bpe_atomic_pieces,
        }
    }
    /// Build a tokenizer from the bytes of a serialized SentencePiece `ModelProto`.
    ///
    /// # Errors
    /// - `MalformedData`, `ArithmeticOverflow` or `UnknownEnumValue` for invalid protobuf.
    /// - `EmptyInput` when the model has no vocabulary pieces.
    pub fn from_model_bytes(data: &[u8]) -> Result<Self> {
        let parsed = parse_pieces(data)?;
        Ok(Self::new(
            parsed.pieces,
            parsed.unknown_token_id,
            parsed.model_type,
        ))
    }
    /// Read a SentencePiece `.model` file and build a tokenizer from it.
    ///
    /// # Errors
    /// - `FileIo` if the file cannot be read.
    /// - Any error from [`SentencePieceTokenizer::from_model_bytes`].
    pub fn from_model_file(path: &Path) -> Result<Self> {
        let bytes = fs::read(path).map_err(|e| {
            Error::FileIo(FileIoPayload::new(
                "SentencePieceTokenizer: failed to read model file",
                FileOp::Read,
                path.to_path_buf(),
                e,
            ))
        })?;
        Self::from_model_bytes(&bytes)
    }
    /// Build a tokenizer from a parsed Hugging Face `tokenizer.json` value.
    ///
    /// Reads `model.unk_id`, `model.vocab` (a list of `[token, score]` pairs), and
    /// `model.type` ("BPE", case-insensitive, selects BPE; anything else selects unigram).
    ///
    /// Piece types are inferred in this order, later rules overriding earlier ones:
    /// 1. `<0xNN>` texts become `Byte`, everything else `Normal`.
    /// 2. Non-byte pieces listed in `added_tokens` become `Control` when `special` is true,
    ///    otherwise `UserDefined`.
    /// 3. The piece at `unk_id` becomes `Unknown`.
    ///
    /// # Errors
    /// - `MissingField` if `model`, `model.unk_id` or `model.vocab` is absent or has the
    ///   wrong JSON type.
    /// - `MalformedData` for a vocab entry that is not a `[string, number]` pair, or for an
    ///   `unk_id` that does not index into the vocabulary.
    /// - `EmptyInput` if `model.vocab` is empty.
    #[cfg(feature = "tokenizer-config")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-config")))]
    pub fn from_tokenizer_json(tokenizer_json: &JsonValue) -> Result<Self> {
        let model = tokenizer_json.get("model").ok_or_else(|| {
            Error::MissingField(MissingFieldPayload::new("SentencePieceTokenizer", "model"))
        })?;
        let raw_unk_id = model
            .get("unk_id")
            .and_then(|v| v.as_u64())
            .ok_or_else(|| {
                Error::MissingField(MissingFieldPayload::new(
                    "SentencePieceTokenizer",
                    "model.unk_id",
                ))
            })?;
        let vocab_list = model
            .get("vocab")
            .and_then(|v| v.as_array())
            .ok_or_else(|| {
                Error::MissingField(MissingFieldPayload::new(
                    "SentencePieceTokenizer",
                    "model.vocab",
                ))
            })?;
        let mut pieces: Vec<SentencePieceToken> = Vec::with_capacity(vocab_list.len());
        for entry in vocab_list {
            let arr = entry.as_array().ok_or_else(|| {
                Error::MalformedData(MalformedDataPayload::new(
                    "SentencePieceTokenizer: `model.vocab`",
                    "entry is not an array",
                ))
            })?;
            if arr.len() != 2 {
                return Err(Error::MalformedData(MalformedDataPayload::new(
                    "SentencePieceTokenizer: `model.vocab`",
                    "entry must be a [token, score] pair",
                )));
            }
            let token = arr[0].as_str().ok_or_else(|| {
                Error::MalformedData(MalformedDataPayload::new(
                    "SentencePieceTokenizer: `model.vocab`",
                    "entry[0] is not a string",
                ))
            })?;
            let score = arr[1].as_f64().ok_or_else(|| {
                Error::MalformedData(MalformedDataPayload::new(
                    "SentencePieceTokenizer: `model.vocab`",
                    "entry[1] is not a number",
                ))
            })? as f32;
            let initial_type = if is_byte_fallback_piece(token) {
                SentencePiecePieceType::Byte
            } else {
                SentencePiecePieceType::Normal
            };
            pieces.push(SentencePieceToken::new(
                token.to_string(),
                score,
                initial_type,
            ));
        }
        // An empty vocabulary would make the unknown score +inf and leave no valid id to emit.
        if pieces.is_empty() {
            return Err(Error::EmptyInput(EmptyInputPayload::new(
                "SentencePieceTokenizer: `model.vocab`",
            )));
        }
        // An out-of-range unk_id would make encode emit ids that no piece resolves.
        let unk_id = usize::try_from(raw_unk_id)
            .ok()
            .filter(|&id| id < pieces.len())
            .ok_or_else(|| {
                Error::MalformedData(MalformedDataPayload::new(
                    "SentencePieceTokenizer: `model.unk_id`",
                    "unk_id is out of range for `model.vocab`",
                ))
            })?;
        if let Some(added) = tokenizer_json
            .get("added_tokens")
            .and_then(|v| v.as_array())
        {
            for at in added {
                let Some(content) = at.get("content").and_then(|v| v.as_str()) else {
                    continue;
                };
                let special = at.get("special").and_then(|v| v.as_bool()).unwrap_or(false);
                let target_type = if special {
                    SentencePiecePieceType::Control
                } else {
                    SentencePiecePieceType::UserDefined
                };
                // Byte-fallback pieces keep their type so decoding still reassembles bytes.
                for piece in pieces.iter_mut().filter(|p| {
                    p.token() == content && p.piece_type() != SentencePiecePieceType::Byte
                }) {
                    piece.piece_type = target_type;
                }
            }
        }
        pieces[unk_id].piece_type = SentencePiecePieceType::Unknown;
        let model_type = match model.get("type").and_then(|v| v.as_str()) {
            Some(t) if t.eq_ignore_ascii_case("BPE") => SentencePieceModelType::Bpe,
            _ => SentencePieceModelType::Unigram,
        };
        Ok(Self::new(pieces, unk_id, model_type))
    }
    /// Parse raw `tokenizer.json` bytes and build a tokenizer from them.
    ///
    /// # Errors
    /// - `Parse` if the bytes are not valid JSON.
    /// - Any error from [`SentencePieceTokenizer::from_tokenizer_json`].
    #[cfg(feature = "tokenizer-config")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-config")))]
    pub fn from_tokenizer_json_bytes(data: &[u8]) -> Result<Self> {
        let json: JsonValue = serde_json::from_slice(data).map_err(|e| {
            Error::Parse(ParsePayload::new(
                "SentencePieceTokenizer::from_tokenizer_json_bytes",
                "tokenizer.json",
                e,
            ))
        })?;
        Self::from_tokenizer_json(&json)
    }
    /// Number of pieces in the vocabulary.
    pub fn vocab_size(&self) -> usize {
        self.vocab.len()
    }
    /// Id of the unknown piece.
    pub fn unknown_token_id(&self) -> usize {
        self.unknown_token_id
    }
    /// Segmentation algorithm of this model.
    pub fn model_type(&self) -> SentencePieceModelType {
        self.model_type
    }
    /// Vocabulary entry for `id`, or `None` if out of range.
    pub fn piece(&self, id: usize) -> Option<&SentencePieceToken> {
        self.vocab.get(id)
    }
    /// Encode `text` into token ids.
    ///
    /// Spaces become `U+2581` and a leading `U+2581` is prepended. Text no piece covers is
    /// emitted as `<0xNN>` byte pieces; bytes without such a piece map to the unknown id.
    pub fn encode_with_byte_fallback(&self, text: &str) -> Vec<usize> {
        match self.model_type {
            SentencePieceModelType::Bpe => self.encode_bpe_with_byte_fallback(text),
            SentencePieceModelType::Unigram => self.encode_unigram_with_byte_fallback(text),
        }
    }
    /// Unigram encoding: best-scoring segmentation of the character lattice.
    fn encode_unigram_with_byte_fallback(&self, text: &str) -> Vec<usize> {
        let pre = apply_metaspace(text);
        let mut lattice = TokenLattice::new(&pre, self.unknown_token_id, self.unknown_token_id);
        let chars: Vec<char> = pre.chars().collect();
        for begin_pos in 0..chars.len() {
            let mut has_single_char_piece = false;
            for token in self.trie.common_prefix_search(&chars[begin_pos..]) {
                let Some(&token_id) = self.tokens_to_ids.get(&token) else {
                    continue;
                };
                let token_char_count = token.chars().count();
                let token_score = self.vocab[token_id].score();
                lattice.insert(begin_pos, token_char_count, token_score, token_id);
                has_single_char_piece |= token_char_count == 1;
            }
            // Guarantee every position is covered so the lattice stays connected.
            if !has_single_char_piece {
                lattice.insert(
                    begin_pos,
                    1,
                    self.unknown_token_score,
                    self.unknown_token_id,
                );
            }
        }
        let path = lattice.viterbi();
        let mut ids: Vec<usize> = Vec::with_capacity(path.len());
        for node in &path {
            if node.token_id == self.unknown_token_id {
                // Re-express unknown spans as byte-fallback pieces where available.
                let piece = lattice.piece(node);
                for &b in piece.as_bytes() {
                    ids.push(self.byte_map[b as usize].unwrap_or(self.unknown_token_id));
                }
            } else {
                ids.push(node.token_id);
            }
        }
        ids
    }
    /// BPE encoding: repeatedly merge the adjacent pair whose merged text has the highest score.
    fn encode_bpe_with_byte_fallback(&self, text: &str) -> Vec<usize> {
        let pre = apply_metaspace(text);
        let mut symbols = self.initial_bpe_symbols(&pre);
        while symbols.len() > 1 {
            let mut best_index: Option<usize> = None;
            let mut best_piece = String::new();
            let mut best_score = f32::NEG_INFINITY;
            for index in 0..symbols.len() - 1 {
                let mut candidate =
                    String::with_capacity(symbols[index].len() + symbols[index + 1].len());
                candidate.push_str(&symbols[index]);
                candidate.push_str(&symbols[index + 1]);
                let Some(&token_id) = self.tokens_to_ids.get(&candidate) else {
                    continue;
                };
                let tok = &self.vocab[token_id];
                // Only normal and user-defined pieces are valid merge results.
                if !matches!(
                    tok.piece_type(),
                    SentencePiecePieceType::Normal | SentencePiecePieceType::UserDefined
                ) {
                    continue;
                }
                if best_index.is_none() || tok.score() > best_score {
                    best_index = Some(index);
                    best_piece = candidate;
                    best_score = tok.score();
                }
            }
            let Some(index) = best_index else { break };
            symbols.splice(index..=index + 1, std::iter::once(best_piece));
        }
        let mut ids: Vec<usize> = Vec::new();
        for symbol in &symbols {
            if let Some(&token_id) = self.tokens_to_ids.get(symbol) {
                ids.push(token_id);
            } else {
                for &b in symbol.as_bytes() {
                    ids.push(self.byte_map[b as usize].unwrap_or(self.unknown_token_id));
                }
            }
        }
        ids
    }
    /// Split `text` into initial BPE symbols.
    ///
    /// At each position the longest matching `UserDefined` piece is taken whole; otherwise a
    /// single character is taken.
    fn initial_bpe_symbols(&self, text: &str) -> Vec<String> {
        let mut symbols: Vec<String> = Vec::new();
        let mut tail = text;
        while !tail.is_empty() {
            if let Some(atomic) = self
                .bpe_atomic_pieces
                .iter()
                .find(|piece| tail.starts_with(piece.as_str()))
            {
                symbols.push(atomic.clone());
                tail = &tail[atomic.len()..];
            } else {
                let mut iter = tail.char_indices();
                let _ = iter.next();
                let next_byte = iter.next().map(|(i, _)| i).unwrap_or(tail.len());
                symbols.push(tail[..next_byte].to_string());
                tail = &tail[next_byte..];
            }
        }
        symbols
    }
    /// Convert token ids back into text.
    ///
    /// - Out-of-range ids and `Control`/`Unused` pieces are skipped.
    /// - Runs of `<0xNN>` byte pieces are reassembled as UTF-8; invalid sequences become
    ///   U+FFFD instead of being dropped.
    /// - `U+2581` becomes a space, and only the single leading space added by the dummy
    ///   prefix is removed. Other leading or trailing whitespace is preserved.
    pub fn decode(&self, ids: &[usize]) -> String {
        let mut text = String::new();
        let mut pending_bytes: Vec<u8> = Vec::new();
        for &id in ids {
            let Some(token) = self.vocab.get(id) else {
                continue;
            };
            if matches!(
                token.piece_type(),
                SentencePiecePieceType::Control | SentencePiecePieceType::Unused
            ) {
                continue;
            }
            let piece = token.token();
            if let Some(byte) = parse_byte_fallback_piece(piece) {
                pending_bytes.push(byte);
                continue;
            }
            Self::flush_byte_run(&mut pending_bytes, &mut text);
            text.push_str(piece);
        }
        Self::flush_byte_run(&mut pending_bytes, &mut text);
        let restored = text.replace('\u{2581}', " ");
        match restored.strip_prefix(' ') {
            Some(rest) => rest.to_owned(),
            None => restored,
        }
    }
    /// Append a pending run of byte-fallback bytes to `out` and clear the run.
    ///
    /// The bytes are decoded as UTF-8, with invalid sequences replaced by U+FFFD.
    fn flush_byte_run(bytes: &mut Vec<u8>, out: &mut String) {
        if bytes.is_empty() {
            return;
        }
        out.push_str(&String::from_utf8_lossy(bytes.as_slice()));
        bytes.clear();
    }
}
/// Parse a byte-fallback piece of the form `<0xNN>` and return the byte value.
///
/// Exactly two hex digits are required (either case). `u8::from_str_radix` would otherwise
/// accept a leading `+`, so a text like `<0x+F>` was wrongly treated as byte 0x0F.
/// Returns `None` for any other text.
fn parse_byte_fallback_piece(piece: &str) -> Option<u8> {
    let hex = piece.strip_prefix("<0x")?.strip_suffix('>')?;
    if hex.len() != 2 || !hex.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }
    u8::from_str_radix(hex, 16).ok()
}
/// Whether `text` is a byte-fallback piece of the form `<0xNN>`.
#[cfg(feature = "tokenizer-config")]
fn is_byte_fallback_piece(text: &str) -> bool {
    let parsed = parse_byte_fallback_piece(text);
    parsed.is_some()
}
/// SentencePiece-style pre-tokenization.
///
/// Every ASCII space becomes `U+2581` (`▁`), and one extra `U+2581` is prepended as the
/// dummy prefix.
fn apply_metaspace(text: &str) -> String {
    const METASPACE: char = '\u{2581}';
    let mut out = String::with_capacity(text.len() + METASPACE.len_utf8());
    out.push(METASPACE);
    for ch in text.chars() {
        out.push(if ch == ' ' { METASPACE } else { ch });
    }
    out
}
#[cfg(test)]
#[allow(clippy::identity_op, clippy::vec_init_then_push)]
mod tests {
use super::*;
fn write_varint(out: &mut Vec<u8>, mut value: u64) {
while value > 0x7f {
out.push((value & 0x7f) as u8 | 0x80);
value >>= 7;
}
out.push(value as u8);
}
fn build_piece(token: &str, score: f32, piece_type: u8) -> Vec<u8> {
let mut piece = Vec::new();
piece.push((1 << 3) | 2);
write_varint(&mut piece, token.len() as u64);
piece.extend_from_slice(token.as_bytes());
piece.push((2 << 3) | 5);
piece.extend_from_slice(&score.to_bits().to_le_bytes());
piece.push((3 << 3) | 0);
write_varint(&mut piece, u64::from(piece_type));
piece
}
fn build_model_with_pieces(pieces: &[(&str, f32, u8)], model_type: u64) -> Vec<u8> {
let mut out = Vec::new();
for (token, score, piece_type) in pieces {
let piece_bytes = build_piece(token, *score, *piece_type);
out.push((1 << 3) | 2);
write_varint(&mut out, piece_bytes.len() as u64);
out.extend_from_slice(&piece_bytes);
}
let mut trainer = Vec::new();
trainer.push((3 << 3) | 0);
write_varint(&mut trainer, model_type);
out.push((2 << 3) | 2);
write_varint(&mut out, trainer.len() as u64);
out.extend_from_slice(&trainer);
out
}
#[test]
fn parse_minimal_unigram_protobuf_yields_vocab_and_model_type() {
let data = build_model_with_pieces(
&[
("<unk>", 0.0, SentencePiecePieceType::Unknown.as_raw() as u8),
(
"\u{2581}hello",
-1.0,
SentencePiecePieceType::Normal.as_raw() as u8,
),
(
"\u{2581}world",
-2.0,
SentencePiecePieceType::Normal.as_raw() as u8,
),
],
1, );
let tok = SentencePieceTokenizer::from_model_bytes(&data).unwrap();
assert_eq!(tok.vocab_size(), 3);
assert_eq!(tok.unknown_token_id(), 0);
assert_eq!(tok.model_type(), SentencePieceModelType::Unigram);
assert_eq!(tok.piece(1).map(|p| p.token()), Some("\u{2581}hello"));
}
#[test]
fn malformed_protobuf_errors_with_actionable_message() {
let mut bad = Vec::new();
bad.push((1 << 3) | 2);
bad.push(50); let err = SentencePieceTokenizer::from_model_bytes(&bad).unwrap_err();
let message = err.to_string();
assert!(message.contains("SentencePiece"), "message: {message}");
assert!(message.contains("truncated"), "message: {message}");
let Error::MalformedData(p) = err else {
panic!("expected Error::MalformedData, got {err:?}");
};
assert_eq!(p.context(), "SentencePiece protobuf");
assert!(p.detail().contains("truncated"), "detail: {}", p.detail());
}
#[test]
fn empty_vocab_protobuf_is_rejected() {
let mut data = Vec::new();
let mut trainer = Vec::new();
trainer.push((3 << 3) | 0);
write_varint(&mut trainer, 1);
data.push((2 << 3) | 2);
write_varint(&mut data, trainer.len() as u64);
data.extend_from_slice(&trainer);
let err = SentencePieceTokenizer::from_model_bytes(&data).unwrap_err();
let Error::EmptyInput(p) = err else {
panic!("expected Error::EmptyInput, got {err:?}");
};
assert!(
p.context().contains("vocabulary"),
"context: {}",
p.context()
);
}
fn toy_tokenizer() -> SentencePieceTokenizer {
let data = build_model_with_pieces(
&[
("<unk>", 0.0, SentencePiecePieceType::Unknown.as_raw() as u8),
(
"\u{2581}hello",
-1.0,
SentencePiecePieceType::Normal.as_raw() as u8,
),
(
"\u{2581}world",
-1.0,
SentencePiecePieceType::Normal.as_raw() as u8,
),
(
"\u{2581}",
-3.0,
SentencePiecePieceType::Normal.as_raw() as u8,
),
("<0x21>", -5.0, SentencePiecePieceType::Byte.as_raw() as u8),
("<0x3F>", -5.0, SentencePiecePieceType::Byte.as_raw() as u8),
],
1,
);
SentencePieceTokenizer::from_model_bytes(&data).unwrap()
}
#[test]
fn encode_unigram_known_input_yields_expected_piece_sequence() {
let tok = toy_tokenizer();
let ids = tok.encode_with_byte_fallback("hello world");
assert_eq!(ids, vec![1, 2], "ids={:?}", ids);
}
#[test]
fn encode_unigram_byte_fallback_for_out_of_vocab_chars() {
let tok = toy_tokenizer();
let ids = tok.encode_with_byte_fallback("hello?");
assert!(
ids.contains(&5),
"byte-fallback for `?` (id=5) missing in ids={ids:?}"
);
}
#[test]
fn encode_then_decode_is_lossless_round_trip_on_known_input() {
let tok = toy_tokenizer();
let original = "hello world";
let ids = tok.encode_with_byte_fallback(original);
let decoded = tok.decode(&ids);
assert_eq!(decoded, original, "round-trip mismatch: ids={ids:?}");
}
#[test]
fn decode_skips_control_and_unused_pieces() {
let data = build_model_with_pieces(
&[
("<unk>", 0.0, SentencePiecePieceType::Unknown.as_raw() as u8),
("<s>", 0.0, SentencePiecePieceType::Control.as_raw() as u8),
("<pad>", 0.0, SentencePiecePieceType::Unused.as_raw() as u8),
(
"\u{2581}hi",
-1.0,
SentencePiecePieceType::Normal.as_raw() as u8,
),
],
1,
);
let tok = SentencePieceTokenizer::from_model_bytes(&data).unwrap();
let decoded = tok.decode(&[1, 2, 3]); assert_eq!(decoded, "hi");
}
#[test]
fn decode_reassembles_byte_fallback_pieces_into_valid_utf8() {
let data = build_model_with_pieces(
&[
("<unk>", 0.0, SentencePiecePieceType::Unknown.as_raw() as u8),
("<0xC3>", -5.0, SentencePiecePieceType::Byte.as_raw() as u8),
("<0xA9>", -5.0, SentencePiecePieceType::Byte.as_raw() as u8),
(
"\u{2581}",
-1.0,
SentencePiecePieceType::Normal.as_raw() as u8,
),
],
1,
);
let tok = SentencePieceTokenizer::from_model_bytes(&data).unwrap();
let decoded = tok.decode(&[3, 1, 2]); assert_eq!(decoded, "é");
}
#[test]
fn from_model_file_propagates_io_error_for_missing_path() {
let err =
SentencePieceTokenizer::from_model_file(Path::new("/nonexistent/path.model")).unwrap_err();
match err {
Error::FileIo(p) => {
assert_eq!(p.op(), FileOp::Read);
assert_eq!(p.path(), Path::new("/nonexistent/path.model"));
assert_eq!(p.inner().kind(), std::io::ErrorKind::NotFound);
assert!(p.context().contains("failed to read"));
}
other => panic!("expected Error::FileIo, got {other:?}"),
}
}
#[cfg(feature = "tokenizer-config")]
#[test]
fn from_tokenizer_json_parses_unigram_vocab() {
let json: serde_json::Value = serde_json::json!({
"model": {
"type": "Unigram",
"unk_id": 0,
"vocab": [
["<unk>", 0.0],
["\u{2581}hello", -1.0],
["\u{2581}world", -1.0],
],
}
});
let tok = SentencePieceTokenizer::from_tokenizer_json(&json).unwrap();
assert_eq!(tok.vocab_size(), 3);
assert_eq!(tok.unknown_token_id(), 0);
let ids = tok.encode_with_byte_fallback("hello world");
assert_eq!(ids, vec![1, 2]);
}
#[cfg(feature = "tokenizer-config")]
#[test]
fn from_tokenizer_json_bytes_rejects_invalid_json() {
let err = SentencePieceTokenizer::from_tokenizer_json_bytes(b"not json").unwrap_err();
let Error::Parse(p) = err else {
panic!("expected Error::Parse, got {err:?}");
};
assert_eq!(p.input_kind(), "tokenizer.json");
assert_eq!(
p.context(),
"SentencePieceTokenizer::from_tokenizer_json_bytes"
);
}
#[cfg(feature = "tokenizer-config")]
#[test]
fn from_tokenizer_json_errors_on_missing_model_field() {
let json: serde_json::Value = serde_json::json!({"other": 1});
let err = SentencePieceTokenizer::from_tokenizer_json(&json).unwrap_err();
match err {
Error::MissingField(p) => {
assert_eq!(p.type_name(), "SentencePieceTokenizer");
assert_eq!(p.field(), "model");
}
other => panic!("expected Error::MissingField, got {other:?}"),
}
}
#[cfg(feature = "tokenizer-config")]
#[test]
fn from_tokenizer_json_infers_piece_types_from_unk_byte_and_added_tokens() {
let json: serde_json::Value = serde_json::json!({
"model": {
"type": "Unigram",
"unk_id": 0,
"vocab": [
["<unk>", 0.0], ["\u{2581}hello", -1.0], ["<0x41>", -2.0], ["<s>", -3.0], ["<custom>", -4.0], ["\u{2581}world", -5.0], ],
},
"added_tokens": [
{ "id": 3, "content": "<s>", "special": true },
{ "id": 4, "content": "<custom>", "special": false },
],
});
let tok = SentencePieceTokenizer::from_tokenizer_json(&json).unwrap();
assert_eq!(tok.vocab_size(), 6);
assert_eq!(tok.unknown_token_id(), 0);
assert_eq!(
tok.piece(0).unwrap().piece_type(),
SentencePiecePieceType::Unknown
);
assert_eq!(
tok.piece(1).unwrap().piece_type(),
SentencePiecePieceType::Normal
);
assert_eq!(
tok.piece(2).unwrap().piece_type(),
SentencePiecePieceType::Byte
);
assert_eq!(
tok.piece(3).unwrap().piece_type(),
SentencePiecePieceType::Control
);
assert_eq!(
tok.piece(4).unwrap().piece_type(),
SentencePiecePieceType::UserDefined
);
assert_eq!(
tok.piece(5).unwrap().piece_type(),
SentencePiecePieceType::Normal
);
}
#[cfg(feature = "tokenizer-config")]
#[test]
fn from_tokenizer_json_piece_type_precedence() {
let json: serde_json::Value = serde_json::json!({
"model": {
"type": "Unigram",
"unk_id": 1,
"vocab": [
["<0xFF>", 0.0], ["<0x00>", 0.0], ["\u{2581}x", 0.0],
],
},
"added_tokens": [
{ "id": 0, "content": "<0xFF>", "special": true },
],
});
let tok = SentencePieceTokenizer::from_tokenizer_json(&json).unwrap();
assert_eq!(
tok.piece(0).unwrap().piece_type(),
SentencePiecePieceType::Byte
);
assert_eq!(
tok.piece(1).unwrap().piece_type(),
SentencePiecePieceType::Unknown
);
}
#[test]
fn piece_type_as_str_covers_every_variant() {
assert_eq!(SentencePiecePieceType::Normal.as_str(), "normal");
assert_eq!(SentencePiecePieceType::Unknown.as_str(), "unknown");
assert_eq!(SentencePiecePieceType::Control.as_str(), "control");
assert_eq!(SentencePiecePieceType::UserDefined.as_str(), "user_defined");
assert_eq!(SentencePiecePieceType::Unused.as_str(), "unused");
assert_eq!(SentencePiecePieceType::Byte.as_str(), "byte");
assert_eq!(
SentencePiecePieceType::UnknownOrdinal(99).as_str(),
"unknown"
);
}
#[test]
fn piece_type_display_delegates_to_as_str() {
assert_eq!(
SentencePiecePieceType::UserDefined.to_string(),
"user_defined"
);
assert_eq!(SentencePiecePieceType::Byte.to_string(), "byte");
assert_eq!(
SentencePiecePieceType::UnknownOrdinal(7).to_string(),
"unknown"
);
}
#[test]
fn piece_type_as_raw_covers_every_variant() {
assert_eq!(SentencePiecePieceType::Normal.as_raw(), 1);
assert_eq!(SentencePiecePieceType::Unknown.as_raw(), 2);
assert_eq!(SentencePiecePieceType::Control.as_raw(), 3);
assert_eq!(SentencePiecePieceType::UserDefined.as_raw(), 4);
assert_eq!(SentencePiecePieceType::Unused.as_raw(), 5);
assert_eq!(SentencePiecePieceType::Byte.as_raw(), 6);
assert_eq!(SentencePiecePieceType::UnknownOrdinal(42).as_raw(), 42);
assert_eq!(SentencePiecePieceType::UnknownOrdinal(-1).as_raw(), -1);
}
#[test]
fn piece_type_from_raw_round_trips_known_ordinals() {
for v in [
SentencePiecePieceType::Normal,
SentencePiecePieceType::Unknown,
SentencePiecePieceType::Control,
SentencePiecePieceType::UserDefined,
SentencePiecePieceType::Unused,
SentencePiecePieceType::Byte,
] {
let raw = v.as_raw() as u64;
assert_eq!(SentencePiecePieceType::from_raw(raw), v, "raw={raw}");
}
}
#[test]
fn piece_type_from_raw_maps_each_ordinal_explicitly() {
assert_eq!(
SentencePiecePieceType::from_raw(1),
SentencePiecePieceType::Normal
);
assert_eq!(
SentencePiecePieceType::from_raw(2),
SentencePiecePieceType::Unknown
);
assert_eq!(
SentencePiecePieceType::from_raw(3),
SentencePiecePieceType::Control
);
assert_eq!(
SentencePiecePieceType::from_raw(4),
SentencePiecePieceType::UserDefined
);
assert_eq!(
SentencePiecePieceType::from_raw(5),
SentencePiecePieceType::Unused
);
assert_eq!(
SentencePiecePieceType::from_raw(6),
SentencePiecePieceType::Byte
);
}
#[test]
fn piece_type_from_raw_unknown_ordinal_is_captured() {
assert_eq!(
SentencePiecePieceType::from_raw(7),
SentencePiecePieceType::UnknownOrdinal(7)
);
assert_eq!(
SentencePiecePieceType::from_raw(255),
SentencePiecePieceType::UnknownOrdinal(255)
);
assert!(SentencePiecePieceType::from_raw(7).is_unknown_ordinal());
assert!(SentencePiecePieceType::Normal.is_normal());
}
/// `as_str` and `Display` agree for both model types; the derived `is_*`
/// predicates report the right variant.
#[test]
fn model_type_as_str_and_display() {
    let cases = [
        (SentencePieceModelType::Unigram, "unigram"),
        (SentencePieceModelType::Bpe, "bpe"),
    ];
    for (kind, name) in cases {
        assert_eq!(kind.as_str(), name);
        assert_eq!(kind.to_string(), name);
    }
    assert!(SentencePieceModelType::Unigram.is_unigram());
    assert!(SentencePieceModelType::Bpe.is_bpe());
}
/// Model-kind ordinal 1 parses as Unigram and 2 as BPE. Despite "rejects" in
/// the name, an unrecognised kind (3) does not make parsing fail: the loaded
/// tokenizer reports `Unigram`.
#[test]
fn model_type_from_raw_maps_1_and_2_and_rejects_others() {
    let unknown_raw = SentencePiecePieceType::Unknown.as_raw() as u8;
    let cases = [
        (1, SentencePieceModelType::Unigram),
        (2, SentencePieceModelType::Bpe),
        (3, SentencePieceModelType::Unigram),
    ];
    for (raw_kind, expected) in cases {
        let bytes = build_model_with_pieces(&[("<unk>", 0.0, unknown_raw)], raw_kind);
        let parsed = SentencePieceTokenizer::from_model_bytes(&bytes).unwrap();
        assert_eq!(parsed.model_type(), expected);
    }
}
/// A single byte with the continuation bit set and nothing after it is
/// reported as a malformed varint.
#[test]
fn read_varint_unterminated_is_malformed() {
    let bad: Vec<u8> = vec![0x80];
    let payload = match SentencePieceTokenizer::from_model_bytes(&bad).unwrap_err() {
        Error::MalformedData(p) => p,
        err => panic!("expected Error::MalformedData, got {err:?}"),
    };
    assert_eq!(payload.context(), "SentencePiece protobuf");
    assert_eq!(payload.detail(), "malformed varint");
}
/// Several continuation bytes followed by end of input are still a malformed varint.
#[test]
fn read_varint_multi_byte_continuation_then_eof_is_malformed() {
    let bad: Vec<u8> = vec![0x80; 3];
    let payload = match SentencePieceTokenizer::from_model_bytes(&bad).unwrap_err() {
        Error::MalformedData(p) => p,
        err => panic!("expected Error::MalformedData, got {err:?}"),
    };
    assert_eq!(payload.detail(), "malformed varint");
}
/// Field 9 tagged as fixed32 (wire type 5) but followed by only 2 of its 4
/// payload bytes.
#[test]
fn read_fixed32_truncated_is_malformed() {
    let bad: Vec<u8> = vec![(9 << 3) | 5, 0x01, 0x02];
    let payload = match SentencePieceTokenizer::from_model_bytes(&bad).unwrap_err() {
        Error::MalformedData(p) => p,
        err => panic!("expected Error::MalformedData, got {err:?}"),
    };
    assert_eq!(payload.context(), "SentencePiece protobuf");
    assert_eq!(payload.detail(), "truncated fixed32 field");
}
/// A complete fixed32 field with an unhandled number (9) is skipped, and the
/// piece that follows is still parsed.
#[test]
fn skip_field_fixed32_full_then_pieces_parse() {
    let unk_piece = build_piece("<unk>", 0.0, SentencePiecePieceType::Unknown.as_raw() as u8);
    let mut data: Vec<u8> = vec![(9 << 3) | 5];
    data.extend_from_slice(&0xDEAD_BEEFu32.to_le_bytes());
    // Field 1, wire type 2: length-prefixed piece message.
    data.push((1 << 3) | 2);
    write_varint(&mut data, unk_piece.len() as u64);
    data.extend_from_slice(&unk_piece);
    let parsed = SentencePieceTokenizer::from_model_bytes(&data).unwrap();
    assert_eq!(parsed.vocab_size(), 1);
}
/// A varint field with an unhandled number (9) is skipped before the piece.
#[test]
fn skip_field_varint_unknown_field_number() {
    let unk_piece = build_piece("<unk>", 0.0, SentencePiecePieceType::Unknown.as_raw() as u8);
    // Field 9, wire type 0 (varint): the low three tag bits are zero.
    let mut data: Vec<u8> = vec![9 << 3];
    write_varint(&mut data, 123_456);
    data.push((1 << 3) | 2);
    write_varint(&mut data, unk_piece.len() as u64);
    data.extend_from_slice(&unk_piece);
    let parsed = SentencePieceTokenizer::from_model_bytes(&data).unwrap();
    assert_eq!(parsed.vocab_size(), 1);
}
/// A complete fixed64 field (wire type 1) with an unhandled number is skipped.
#[test]
fn skip_field_fixed64_full_then_pieces_parse() {
    let unk_piece = build_piece("<unk>", 0.0, SentencePiecePieceType::Unknown.as_raw() as u8);
    let mut data: Vec<u8> = vec![(9 << 3) | 1];
    data.extend_from_slice(&0x0102_0304_0506_0708u64.to_le_bytes());
    data.push((1 << 3) | 2);
    write_varint(&mut data, unk_piece.len() as u64);
    data.extend_from_slice(&unk_piece);
    let parsed = SentencePieceTokenizer::from_model_bytes(&data).unwrap();
    assert_eq!(parsed.vocab_size(), 1);
}
/// A fixed64 tag followed by only 3 of 8 payload bytes is malformed.
#[test]
fn skip_field_fixed64_truncated_is_malformed() {
    let bad: Vec<u8> = vec![(9 << 3) | 1, 0x01, 0x02, 0x03];
    let payload = match SentencePieceTokenizer::from_model_bytes(&bad).unwrap_err() {
        Error::MalformedData(p) => p,
        err => panic!("expected Error::MalformedData, got {err:?}"),
    };
    assert_eq!(payload.detail(), "truncated fixed64 field");
}
/// A length-delimited field with an unhandled number has its payload skipped.
#[test]
fn skip_field_length_delimited_unknown_field_number() {
    let blob = b"ignored-bytes";
    let unk_piece = build_piece("<unk>", 0.0, SentencePiecePieceType::Unknown.as_raw() as u8);
    let mut data: Vec<u8> = vec![(9 << 3) | 2];
    write_varint(&mut data, blob.len() as u64);
    data.extend_from_slice(blob);
    data.push((1 << 3) | 2);
    write_varint(&mut data, unk_piece.len() as u64);
    data.extend_from_slice(&unk_piece);
    let parsed = SentencePieceTokenizer::from_model_bytes(&data).unwrap();
    assert_eq!(parsed.vocab_size(), 1);
}
/// Wire type 3 (group start) is not supported and surfaces as an
/// `UnknownEnumValue` error that lists the supported wire types.
#[test]
fn skip_field_unsupported_wire_type_errors() {
    let bad: Vec<u8> = vec![(1 << 3) | 3];
    let payload = match SentencePieceTokenizer::from_model_bytes(&bad).unwrap_err() {
        Error::UnknownEnumValue(p) => p,
        err => panic!("expected Error::UnknownEnumValue, got {err:?}"),
    };
    assert_eq!(payload.type_name(), "SentencePiece protobuf: wire type");
    assert_eq!(payload.value(), "3");
    let supported = payload.supported();
    assert!(supported.contains(&"0 (varint)"), "supported: {:?}", supported);
    let p = payload;
    assert!(p.to_string().contains("wire type"), "{p}");
}
/// An unrecognised varint subfield (7) inside a piece message is skipped,
/// while the token (field 1) and type (field 3) around it are still read.
#[test]
fn parse_piece_skips_unknown_subfields() {
    let token = "\u{2581}hi";
    let mut piece: Vec<u8> = vec![(1 << 3) | 2];
    write_varint(&mut piece, token.len() as u64);
    piece.extend_from_slice(token.as_bytes());
    // Field 7, wire type 0: not handled by the piece parser.
    piece.push(7 << 3);
    write_varint(&mut piece, 999);
    // Field 3, wire type 0: the piece type.
    piece.push(3 << 3);
    write_varint(
        &mut piece,
        u64::from(SentencePiecePieceType::Normal.as_raw() as u8),
    );
    let mut data: Vec<u8> = vec![(1 << 3) | 2];
    write_varint(&mut data, piece.len() as u64);
    data.extend_from_slice(&piece);
    let parsed = SentencePieceTokenizer::from_model_bytes(&data).unwrap();
    assert_eq!(parsed.vocab_size(), 1);
    assert_eq!(parsed.piece(0).map(|p| p.token()), Some("\u{2581}hi"));
    assert_eq!(
        parsed.piece(0).unwrap().piece_type(),
        SentencePiecePieceType::Normal
    );
}
/// A piece message with only a score (field 2, fixed32) and no token string
/// is dropped; the following real piece becomes id 0.
#[test]
fn parse_piece_with_no_token_field_is_dropped() {
    let mut tokenless: Vec<u8> = vec![(2 << 3) | 5];
    tokenless.extend_from_slice(&(-1.0f32).to_bits().to_le_bytes());
    let unk_piece = build_piece("<unk>", 0.0, SentencePiecePieceType::Unknown.as_raw() as u8);
    let bodies: [&[u8]; 2] = [&tokenless, &unk_piece];
    let mut data: Vec<u8> = Vec::new();
    for body in bodies {
        data.push((1 << 3) | 2);
        write_varint(&mut data, body.len() as u64);
        data.extend_from_slice(body);
    }
    let parsed = SentencePieceTokenizer::from_model_bytes(&data).unwrap();
    assert_eq!(parsed.vocab_size(), 1);
    assert_eq!(parsed.piece(0).map(|p| p.token()), Some("<unk>"));
}
/// A trainer spec (top-level field 2) that contains only a subfield the
/// parser does not act on (field 5) leaves the model type at `Unigram`.
#[test]
fn trainer_spec_skips_unknown_subfield_and_returns_none() {
    // Field 5, wire type 0, value 7.
    let mut trainer: Vec<u8> = vec![5 << 3];
    write_varint(&mut trainer, 7);
    let unk_piece = build_piece("<unk>", 0.0, SentencePiecePieceType::Unknown.as_raw() as u8);
    let fields: [(u8, &[u8]); 2] = [(1, &unk_piece), (2, &trainer)];
    let mut data: Vec<u8> = Vec::new();
    for (field, body) in fields {
        data.push((field << 3) | 2);
        write_varint(&mut data, body.len() as u64);
        data.extend_from_slice(body);
    }
    let parsed = SentencePieceTokenizer::from_model_bytes(&data).unwrap();
    assert_eq!(parsed.model_type(), SentencePieceModelType::Unigram);
}
/// With no nodes inserted at all, no path exists, so viterbi yields nothing.
#[test]
fn lattice_viterbi_returns_empty_on_gap_with_no_begin_node() {
    let mut lattice = TokenLattice::new("ab", 0, 0);
    assert_eq!(lattice.char_count(), 2);
    let best = lattice.viterbi();
    let count = best.len();
    assert!(best.is_empty(), "expected empty path, got {} nodes", count);
}
/// Only a zero-length node at position 0 is inserted (argument order assumed
/// to be start, length, score, id), so nothing ends at the final position and
/// EOS is unreachable.
#[test]
fn lattice_viterbi_returns_empty_when_eos_has_no_predecessor() {
    let mut lattice = TokenLattice::new("a", 5, 7);
    lattice.insert(0, 0, -1.0, 9);
    let best = lattice.viterbi();
    let count = best.len();
    assert!(
        best.is_empty(),
        "expected empty path (EOS unreachable), got {} nodes",
        count
    );
}
/// `piece` slices by character offsets, not bytes: the input mixes multibyte
/// chars, and a node covering chars 1..3 must yield "é!".
#[test]
fn lattice_piece_extracts_char_range_substring() {
    let mut lattice = TokenLattice::new("\u{2581}\u{00e9}!", 0, 0);
    assert_eq!(lattice.char_count(), 3);
    lattice.insert(1, 2, -1.0, 0);
    let inserted = lattice.nodes.last().unwrap().clone();
    let text = lattice.piece(&inserted);
    assert_eq!(text, "\u{00e9}!");
}
/// Round-trips a model through a real file: the serialized bytes are written
/// to a temp path, loaded with `from_model_file`, and the vocabulary is checked.
#[test]
fn from_model_file_reads_and_parses_a_real_file() {
    // Process id + thread id in the name make collisions between concurrently
    // running tests unlikely.
    let mut path = std::env::temp_dir();
    path.push(format!(
        "mlxrs_spm_model_{}_{:?}.model",
        std::process::id(),
        std::thread::current().id()
    ));
    let data = build_model_with_pieces(
        &[
            ("<unk>", 0.0, SentencePiecePieceType::Unknown.as_raw() as u8),
            (
                "\u{2581}hi",
                -1.0,
                SentencePiecePieceType::Normal.as_raw() as u8,
            ),
        ],
        1,
    );
    fs::write(&path, &data).expect("write temp model");
    // Delete the temp file before unwrapping: if parsing fails, `expect` panics
    // and a cleanup placed after it would never run, leaking the file.
    let parsed = SentencePieceTokenizer::from_model_file(&path);
    let _ = fs::remove_file(&path);
    let tok = parsed.expect("parse temp model");
    assert_eq!(tok.vocab_size(), 2);
    assert_eq!(tok.unknown_token_id(), 0);
    assert_eq!(tok.model_type(), SentencePieceModelType::Unigram);
    assert!(tok.piece(99).is_none());
}
/// Ids without a vocabulary entry are skipped during decoding. 999 is assumed
/// to be outside `toy_tokenizer()`'s vocabulary (defined elsewhere).
#[test]
fn decode_skips_out_of_range_ids() {
    let ids = [1, 999, 2];
    let text = toy_tokenizer().decode(&ids);
    assert_eq!(text, "hello world");
}
/// Bytes C3 A9 (UTF-8 "é") accumulated from byte pieces must be flushed as
/// text before the following normal piece "▁hi" is emitted.
#[test]
fn decode_flushes_byte_buffer_before_a_following_normal_piece() {
    let unknown = SentencePiecePieceType::Unknown.as_raw() as u8;
    let byte = SentencePiecePieceType::Byte.as_raw() as u8;
    let normal = SentencePiecePieceType::Normal.as_raw() as u8;
    let model = build_model_with_pieces(
        &[
            ("<unk>", 0.0, unknown),
            ("<0xC3>", -5.0, byte),
            ("<0xA9>", -5.0, byte),
            ("\u{2581}hi", -1.0, normal),
        ],
        1,
    );
    let tokenizer = SentencePieceTokenizer::from_model_bytes(&model).unwrap();
    assert_eq!(tokenizer.decode(&[1, 2, 3]), "\u{00e9} hi");
}
/// "xy" exists in the vocabulary with the best score but is a Control piece,
/// so BPE must not merge x + y into it; the output stays ▁, x, y.
#[test]
fn bpe_skips_merge_into_non_normal_piece_type() {
    let unknown = SentencePiecePieceType::Unknown.as_raw() as u8;
    let normal = SentencePiecePieceType::Normal.as_raw() as u8;
    let control = SentencePiecePieceType::Control.as_raw() as u8;
    let model = build_model_with_pieces(
        &[
            ("<unk>", 0.0, unknown),
            ("\u{2581}", -3.0, normal),
            ("x", -2.0, normal),
            ("y", -2.0, normal),
            ("xy", -0.1, control),
        ],
        2,
    );
    let tokenizer = SentencePieceTokenizer::from_model_bytes(&model).unwrap();
    assert_eq!(tokenizer.model_type(), SentencePieceModelType::Bpe);
    let ids = tokenizer.encode_with_byte_fallback("xy");
    assert_eq!(ids, vec![1, 2, 3], "ids={ids:?}");
}
/// "z" has no piece of its own, so BPE emits its byte piece `<0x7A>`, which
/// decodes back to "z".
#[test]
fn bpe_leftover_symbol_falls_back_to_bytes() {
    let unknown = SentencePiecePieceType::Unknown.as_raw() as u8;
    let normal = SentencePiecePieceType::Normal.as_raw() as u8;
    let byte = SentencePiecePieceType::Byte.as_raw() as u8;
    let model = build_model_with_pieces(
        &[
            ("<unk>", 0.0, unknown),
            ("\u{2581}", -3.0, normal),
            ("<0x7A>", -5.0, byte),
        ],
        2,
    );
    let tokenizer = SentencePieceTokenizer::from_model_bytes(&model).unwrap();
    let ids = tokenizer.encode_with_byte_fallback("z");
    assert_eq!(ids, vec![1, 2], "ids={ids:?}");
    assert_eq!(tokenizer.decode(&[2]), "z");
}
/// With neither a piece nor a byte piece for "z", the unknown id (0) is used.
#[test]
fn bpe_leftover_symbol_without_byte_piece_uses_unknown_id() {
    let unknown = SentencePiecePieceType::Unknown.as_raw() as u8;
    let normal = SentencePiecePieceType::Normal.as_raw() as u8;
    let model = build_model_with_pieces(
        &[("<unk>", 0.0, unknown), ("\u{2581}", -3.0, normal)],
        2,
    );
    let tokenizer = SentencePieceTokenizer::from_model_bytes(&model).unwrap();
    let ids = tokenizer.encode_with_byte_fallback("z");
    assert_eq!(ids, vec![1, 0], "ids={ids:?}");
}
/// The user-defined piece "▁tag" is taken whole when the initial BPE symbols
/// are built, yielding a single id that decodes back to "tag".
#[test]
fn bpe_initial_symbols_consume_user_defined_atomic_piece() {
    let unknown = SentencePiecePieceType::Unknown.as_raw() as u8;
    let normal = SentencePiecePieceType::Normal.as_raw() as u8;
    let user_defined = SentencePiecePieceType::UserDefined.as_raw() as u8;
    let model = build_model_with_pieces(
        &[
            ("<unk>", 0.0, unknown),
            ("\u{2581}", -3.0, normal),
            ("\u{2581}tag", -0.5, user_defined),
        ],
        2,
    );
    let tokenizer = SentencePieceTokenizer::from_model_bytes(&model).unwrap();
    let ids = tokenizer.encode_with_byte_fallback("tag");
    assert_eq!(ids, vec![2], "ids={ids:?}");
    assert_eq!(tokenizer.decode(&ids), "tag");
}
/// Encoding "" is pinned to exactly `[3]`. Id 3 is presumably the toy
/// vocabulary's bare metaspace piece (toy_tokenizer is defined elsewhere), and
/// it decodes back to "". The name also allows "no ids", but the assertion
/// requires the single id.
#[test]
fn encode_unigram_empty_input_yields_no_or_metaspace_only_ids() {
    let tokenizer = toy_tokenizer();
    let ids = tokenizer.encode_with_byte_fallback("");
    assert_eq!(ids, vec![3], "ids={ids:?}");
    let round_trip = tokenizer.decode(&ids);
    assert_eq!(round_trip, "");
}
/// Each accessor returns exactly what was passed to `SentencePieceToken::new`.
#[test]
fn token_accessors_round_trip_constructor_args() {
    let text = "\u{2581}hi";
    let kind = SentencePiecePieceType::Normal;
    let token = SentencePieceToken::new(text, -1.25, kind);
    assert_eq!(token.token(), text);
    assert_eq!(token.score(), -1.25);
    assert_eq!(token.piece_type(), kind);
}
/// Raw tokenizer.json bytes with a minimal Unigram model parse into a
/// two-piece vocabulary.
#[cfg(feature = "tokenizer-config")]
#[test]
fn from_tokenizer_json_bytes_parses_valid_json() {
    // The metaspace char is spliced in as a Rust escape; raw strings hold the JSON.
    let json: &str = concat!(
        r#"{"model":{"type":"Unigram","unk_id":0,"#,
        r#""vocab":[["<unk>",0.0],[""#,
        "\u{2581}",
        r#"hi",-1.0]]}}"#
    );
    let parsed = SentencePieceTokenizer::from_tokenizer_json_bytes(json.as_bytes()).unwrap();
    assert_eq!(parsed.vocab_size(), 2);
    assert_eq!(parsed.unknown_token_id(), 0);
    assert_eq!(parsed.piece(1).map(|p| p.token()), Some("\u{2581}hi"));
}
/// Omitting `model.unk_id` is a `MissingField` error that names the field.
#[cfg(feature = "tokenizer-config")]
#[test]
fn from_tokenizer_json_errors_on_missing_unk_id() {
    let json: serde_json::Value = serde_json::json!({
        "model": { "type": "Unigram", "vocab": [["<unk>", 0.0]] }
    });
    let payload = match SentencePieceTokenizer::from_tokenizer_json(&json).unwrap_err() {
        Error::MissingField(p) => p,
        err => panic!("expected Error::MissingField, got {err:?}"),
    };
    assert_eq!(payload.type_name(), "SentencePieceTokenizer");
    assert_eq!(payload.field(), "model.unk_id");
}
/// Omitting `model.vocab` is a `MissingField` error that names the field.
#[cfg(feature = "tokenizer-config")]
#[test]
fn from_tokenizer_json_errors_on_missing_vocab() {
    let json: serde_json::Value = serde_json::json!({
        "model": { "type": "Unigram", "unk_id": 0 }
    });
    let payload = match SentencePieceTokenizer::from_tokenizer_json(&json).unwrap_err() {
        Error::MissingField(p) => p,
        err => panic!("expected Error::MissingField, got {err:?}"),
    };
    assert_eq!(payload.field(), "model.vocab");
}
/// A vocab entry that is a bare number rather than an array is malformed.
#[cfg(feature = "tokenizer-config")]
#[test]
fn from_tokenizer_json_errors_on_non_array_vocab_entry() {
    let json: serde_json::Value = serde_json::json!({
        "model": { "type": "Unigram", "unk_id": 0, "vocab": [42] }
    });
    let payload = match SentencePieceTokenizer::from_tokenizer_json(&json).unwrap_err() {
        Error::MalformedData(p) => p,
        err => panic!("expected Error::MalformedData, got {err:?}"),
    };
    assert_eq!(payload.context(), "SentencePieceTokenizer: `model.vocab`");
    assert_eq!(payload.detail(), "entry is not an array");
}
/// A vocab entry with three elements instead of a `[token, score]` pair is rejected.
#[cfg(feature = "tokenizer-config")]
#[test]
fn from_tokenizer_json_errors_on_wrong_arity_vocab_entry() {
    let json: serde_json::Value = serde_json::json!({
        "model": { "type": "Unigram", "unk_id": 0, "vocab": [["<unk>", 0.0, 99]] }
    });
    let payload = match SentencePieceTokenizer::from_tokenizer_json(&json).unwrap_err() {
        Error::MalformedData(p) => p,
        err => panic!("expected Error::MalformedData, got {err:?}"),
    };
    assert_eq!(payload.detail(), "entry must be a [token, score] pair");
}
/// A numeric token in position 0 of a vocab entry is rejected.
#[cfg(feature = "tokenizer-config")]
#[test]
fn from_tokenizer_json_errors_on_non_string_token() {
    let json: serde_json::Value = serde_json::json!({
        "model": { "type": "Unigram", "unk_id": 0, "vocab": [[7, 0.0]] }
    });
    let payload = match SentencePieceTokenizer::from_tokenizer_json(&json).unwrap_err() {
        Error::MalformedData(p) => p,
        err => panic!("expected Error::MalformedData, got {err:?}"),
    };
    assert_eq!(payload.detail(), "entry[0] is not a string");
}
/// A string score in position 1 of a vocab entry is rejected.
#[cfg(feature = "tokenizer-config")]
#[test]
fn from_tokenizer_json_errors_on_non_numeric_score() {
    let json: serde_json::Value = serde_json::json!({
        "model": { "type": "Unigram", "unk_id": 0, "vocab": [["<unk>", "bad"]] }
    });
    let payload = match SentencePieceTokenizer::from_tokenizer_json(&json).unwrap_err() {
        Error::MalformedData(p) => p,
        err => panic!("expected Error::MalformedData, got {err:?}"),
    };
    assert_eq!(payload.detail(), "entry[1] is not a number");
}
/// An `added_tokens` entry with no `content` is ignored rather than rejected.
/// The special `<s>` entry marks vocab id 1 as Control, and "▁hi" stays Normal.
#[cfg(feature = "tokenizer-config")]
#[test]
fn from_tokenizer_json_added_token_without_content_is_skipped() {
    let json: serde_json::Value = serde_json::json!({
        "model": {
            "type": "Unigram",
            "unk_id": 0,
            "vocab": [["<unk>", 0.0], ["<s>", -1.0], ["\u{2581}hi", -2.0]],
        },
        "added_tokens": [
            { "id": 9, "special": true },
            { "id": 1, "content": "<s>", "special": true }
        ],
    });
    let parsed = SentencePieceTokenizer::from_tokenizer_json(&json).unwrap();
    assert_eq!(parsed.vocab_size(), 3);
    let kind_of = |id| parsed.piece(id).unwrap().piece_type();
    assert_eq!(kind_of(1), SentencePiecePieceType::Control);
    assert_eq!(kind_of(2), SentencePiecePieceType::Normal);
}
/// Without `model.type` the loader defaults to Unigram. A lowercase "bpe"
/// type still selects BPE.
#[cfg(feature = "tokenizer-config")]
#[test]
fn from_tokenizer_json_defaults_to_unigram_without_bpe_type() {
    let untyped: serde_json::Value = serde_json::json!({
        "model": { "unk_id": 0, "vocab": [["<unk>", 0.0]] }
    });
    let lowercase_bpe: serde_json::Value = serde_json::json!({
        "model": { "type": "bpe", "unk_id": 0, "vocab": [["<unk>", 0.0]] }
    });
    let cases = [
        (untyped, SentencePieceModelType::Unigram),
        (lowercase_bpe, SentencePieceModelType::Bpe),
    ];
    for (json, expected) in cases {
        let parsed = SentencePieceTokenizer::from_tokenizer_json(&json).unwrap();
        assert_eq!(parsed.model_type(), expected);
    }
}
}