use pyo3::exceptions::{PyIOError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyDict;
use rustc_hash::FxHashMap;
use crate::core::SentencePieceTokenizer;
use crate::core::pretrained::{
cl100k_base_special_tokens, deepseek_v3_special_tokens, llama3_special_tokens,
mistral_v1_special_tokens, mistral_v2_special_tokens, mistral_v3_special_tokens,
o200k_base_special_tokens, CL100K_BASE_VOCAB, DEEPSEEK_V3_VOCAB, LLAMA3_VOCAB,
MISTRAL_V2_VOCAB, MISTRAL_V3_VOCAB, MISTRAL_VOCAB, O200K_BASE_VOCAB,
};
use crate::core::{
byte_level_decode_bytes, Tokenizer, CL100K_BASE_PATTERN, LLAMA3_PATTERN, MISTRAL_V3_PATTERN,
O200K_BASE_PATTERN, SENTENCEPIECE_PATTERN,
};
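
/// Python-facing wrapper around the core `Tokenizer`.
/// Construct one from a vocabulary file, raw vocabulary bytes, or a bundled
/// pretrained model.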
#[pyclass(name = "Tokenizer")]
pub struct PyTokenizer {
inner: Tokenizer,
}
#[pymethods]
impl PyTokenizer {
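    /// Create a tokenizer from a vocabulary file, a regex split pattern, and an
    /// optional `{token: id}` dict of special tokens.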
#[new]
#[pyo3(signature = (vocab_path, pattern, special_tokens=None))]
fn new(
vocab_path: &str,
pattern: &str,
special_tokens: Option<&Bound<'_, PyDict>>,
) -> PyResult<Self> {
let special = parse_special_tokens(special_tokens)?;
let inner = Tokenizer::from_file(vocab_path, pattern, special)
.map_err(|e| PyIOError::new_err(e.to_string()))?;
Ok(Self { inner })
}
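    /// Load a bundled pretrained tokenizer by name.
    ///
    /// Supported names: `cl100k_base`, `o200k_base`, `llama3` (also
    /// `llama3.1`, `llama3.2`, `llama3.3`), `deepseek_v3` (or `deepseek-v3`),
    /// `mistral` (or `mistral_v1`), `mistral_v2`, and `mistral_v3`.
    /// Any other name raises `ValueError`.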
#[staticmethod]
fn from_pretrained(name: &str) -> PyResult<Self> {
match name {
"cl100k_base" => {
let special = cl100k_base_special_tokens();
let inner = Tokenizer::from_bytes(CL100K_BASE_VOCAB, CL100K_BASE_PATTERN, special)
.map_err(|e| PyValueError::new_err(e.to_string()))?;
Ok(Self { inner })
}
"o200k_base" => {
let special = o200k_base_special_tokens();
let inner = Tokenizer::from_bytes(O200K_BASE_VOCAB, O200K_BASE_PATTERN, special)
.map_err(|e| PyValueError::new_err(e.to_string()))?;
Ok(Self { inner })
}
"llama3" | "llama3.1" | "llama3.2" | "llama3.3" => {
let special = llama3_special_tokens();
let inner = Tokenizer::from_bytes(LLAMA3_VOCAB, LLAMA3_PATTERN, special)
.map_err(|e| PyValueError::new_err(e.to_string()))?;
Ok(Self { inner })
}
"deepseek_v3" | "deepseek-v3" => {
let special = deepseek_v3_special_tokens();
let inner =
Tokenizer::from_bytes_byte_level(DEEPSEEK_V3_VOCAB, LLAMA3_PATTERN, special)
.map_err(|e| PyValueError::new_err(e.to_string()))?;
Ok(Self { inner })
}
"mistral" | "mistral_v1" => {
let special = mistral_v1_special_tokens();
let inner = Tokenizer::from_bytes_sentencepiece(
MISTRAL_VOCAB,
SENTENCEPIECE_PATTERN,
special,
)
.map_err(|e| PyValueError::new_err(e.to_string()))?;
Ok(Self { inner })
}
"mistral_v2" => {
let special = mistral_v2_special_tokens();
let inner = Tokenizer::from_bytes_sentencepiece(
MISTRAL_V2_VOCAB,
SENTENCEPIECE_PATTERN,
special,
)
.map_err(|e| PyValueError::new_err(e.to_string()))?;
Ok(Self { inner })
}
"mistral_v3" => {
let special = mistral_v3_special_tokens();
let inner =
Tokenizer::from_bytes_byte_level(MISTRAL_V3_VOCAB, MISTRAL_V3_PATTERN, special)
.map_err(|e| PyValueError::new_err(e.to_string()))?;
Ok(Self { inner })
}
_ => Err(PyValueError::new_err(format!(
"Unknown pretrained model: {}. See from_pretrained docstring for supported models.",
name
))),
}
}
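    /// Build a tokenizer from in-memory vocabulary bytes instead of a file path.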
#[staticmethod]
#[pyo3(signature = (vocab_data, pattern, special_tokens=None))]
fn from_bytes(
vocab_data: &[u8],
pattern: &str,
special_tokens: Option<&Bound<'_, PyDict>>,
) -> PyResult<Self> {
let special = parse_special_tokens(special_tokens)?;
let inner = Tokenizer::from_bytes(vocab_data, pattern, special)
.map_err(|e| PyValueError::new_err(e.to_string()))?;
Ok(Self { inner })
}
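    /// Return a copy of this tokenizer with the PCRE2 regex backend toggled
    /// on or off.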
    #[pyo3(signature = (use_pcre2=true))]
    fn pcre2(&self, use_pcre2: bool) -> PyResult<Self> {
        let inner = self
            .inner
            .clone()
            .pcre2(use_pcre2)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(Self { inner })
    }
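    /// Return a copy of this tokenizer with regex JIT compilation toggled
    /// on or off.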
    #[pyo3(signature = (use_jit=true))]
    fn jit(&self, use_jit: bool) -> PyResult<Self> {
        let inner = self
            .inner
            .clone()
            .jit(use_jit)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(Self { inner })
    }
fn encode(&self, text: &str) -> Vec<u32> {
self.inner.encode(text)
}
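    /// Encode `text`, parallelizing the work across threads with Rayon.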
fn encode_rayon(&self, text: &str) -> Vec<u32> {
self.inner.encode_rayon(text)
}
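    /// Encode `text`, recognizing registered special tokens and emitting their
    /// ids instead of splitting them.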
fn encode_with_special(&self, text: &str) -> Vec<u32> {
self.inner.encode_with_special(text)
}
fn decode(&self, tokens: Vec<u32>) -> PyResult<String> {
self.inner
.decode(&tokens)
.map_err(|e| PyValueError::new_err(e.to_string()))
}
fn decode_bytes(&self, tokens: Vec<u32>) -> Vec<u8> {
self.inner.decode_bytes(&tokens)
}
fn decode_lossy(&self, tokens: Vec<u32>) -> String {
self.inner.decode_lossy(&tokens)
}
fn encode_batch(&self, texts: Vec<String>) -> Vec<Vec<u32>> {
self.inner.encode_batch(&texts)
}
fn encode_batch_with_special(&self, texts: Vec<String>) -> Vec<Vec<u32>> {
self.inner.encode_batch_with_special(&texts)
}
fn decode_batch(&self, token_lists: Vec<Vec<u32>>) -> PyResult<Vec<String>> {
self.inner
.decode_batch(&token_lists)
.map_err(|e| PyValueError::new_err(e.to_string()))
}
fn decode_batch_lossy(&self, token_lists: Vec<Vec<u32>>) -> Vec<String> {
self.inner.decode_batch_lossy(&token_lists)
}
#[getter]
fn vocab_size(&self) -> usize {
self.inner.vocab_size()
}
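    /// Create a `StreamingDecoder` that shares this tokenizer's vocabulary.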
fn streaming_decoder(&self) -> PyStreamingDecoder {
PyStreamingDecoder::new(
self.inner.decoder().clone(),
self.inner.special_tokens_decoder().clone(),
)
}
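    /// Create a `ByteLevelStreamingDecoder` for vocabularies whose token bytes
    /// are stored byte-level encoded (see `byte_level_decode_bytes`).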
fn byte_level_streaming_decoder(&self) -> PyByteLevelStreamingDecoder {
PyByteLevelStreamingDecoder::new(
self.inner.decoder().clone(),
self.inner.special_tokens_decoder().clone(),
)
}
fn clear_cache(&self) {
self.inner.clear_cache();
}
#[getter]
fn cache_len(&self) -> usize {
self.inner.cache_len()
}
fn __repr__(&self) -> String {
format!("Tokenizer(vocab_size={})", self.inner.vocab_size())
}
}
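
/// Python-facing wrapper around the core `SentencePieceTokenizer`, built from
/// token pieces, their scores, and EOS/BOS ids.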
#[pyclass(name = "SentencePieceTokenizer")]
pub struct PySentencePieceTokenizer {
inner: SentencePieceTokenizer,
}
#[pymethods]
impl PySentencePieceTokenizer {
#[new]
#[pyo3(signature = (tokens, scores, eos_token_id, bos_token_id=None))]
fn new(
tokens: Vec<String>,
scores: Vec<f32>,
eos_token_id: u32,
bos_token_id: Option<u32>,
) -> PyResult<Self> {
let inner = SentencePieceTokenizer::new(tokens, scores, bos_token_id, eos_token_id)
.map_err(|e| PyValueError::new_err(e.to_string()))?;
Ok(Self { inner })
}
fn encode(&self, text: &str) -> Vec<u32> {
self.inner.encode(text)
}
fn decode(&self, ids: Vec<u32>) -> PyResult<String> {
self.inner
.decode(&ids)
.map_err(|e| PyValueError::new_err(e.to_string()))
}
fn decode_lossy(&self, ids: Vec<u32>) -> String {
self.inner.decode_lossy(&ids)
}
#[getter]
fn vocab_size(&self) -> usize {
self.inner.vocab_size()
}
fn is_eos(&self, token_id: u32) -> bool {
self.inner.is_eos(token_id)
}
#[getter]
fn eos_token_id(&self) -> u32 {
self.inner.eos_token_id()
}
#[getter]
fn bos_token_id(&self) -> Option<u32> {
self.inner.bos_token_id()
}
fn __repr__(&self) -> String {
format!(
"SentencePieceTokenizer(vocab_size={})",
self.inner.vocab_size()
)
}
}
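
/// Convert an optional Python `{token: id}` dict into the map the core
/// tokenizer expects; `None` yields an empty map.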
fn parse_special_tokens(
special_tokens: Option<&Bound<'_, PyDict>>,
) -> PyResult<FxHashMap<String, u32>> {
let mut result = FxHashMap::default();
if let Some(dict) = special_tokens {
for (key, value) in dict.iter() {
let k: String = key.extract()?;
let v: u32 = value.extract()?;
result.insert(k, v);
}
}
Ok(result)
}
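
/// Incremental decoder for streamed token ids. Decoded bytes are buffered and
/// only emitted once they form complete UTF-8, so characters split across
/// tokens are never returned half-finished.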
#[pyclass(name = "StreamingDecoder")]
pub struct PyStreamingDecoder {
decoder: FxHashMap<u32, Vec<u8>>,
special_decoder: FxHashMap<u32, String>,
buffer: Vec<u8>,
}
#[pymethods]
impl PyStreamingDecoder {
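    /// Feed a single token id. Returns any newly completed UTF-8 text, or
    /// `None` if the id is unknown or the buffer still ends mid-character.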
fn add_token(&mut self, token_id: u32) -> Option<String> {
let bytes = if let Some(b) = self.decoder.get(&token_id) {
b.as_slice()
} else if let Some(s) = self.special_decoder.get(&token_id) {
s.as_bytes()
} else {
return None;
};
self.buffer.extend_from_slice(bytes);
self.extract_complete_utf8()
}
fn add_tokens(&mut self, token_ids: Vec<u32>) -> Option<String> {
for token_id in token_ids {
let bytes = if let Some(b) = self.decoder.get(&token_id) {
b.as_slice()
} else if let Some(s) = self.special_decoder.get(&token_id) {
s.as_bytes()
} else {
continue;
};
self.buffer.extend_from_slice(bytes);
}
self.extract_complete_utf8()
}
fn flush(&mut self) -> String {
if self.buffer.is_empty() {
return String::new();
}
let result = String::from_utf8_lossy(&self.buffer).into_owned();
self.buffer.clear();
result
}
fn reset(&mut self) {
self.buffer.clear();
}
#[getter]
fn has_pending(&self) -> bool {
!self.buffer.is_empty()
}
#[getter]
fn pending_bytes(&self) -> usize {
self.buffer.len()
}
fn __repr__(&self) -> String {
format!("StreamingDecoder(pending_bytes={})", self.buffer.len())
}
}
impl PyStreamingDecoder {
fn new(decoder: FxHashMap<u32, Vec<u8>>, special_decoder: FxHashMap<u32, String>) -> Self {
Self {
decoder,
special_decoder,
buffer: Vec::with_capacity(16),
}
}
fn extract_complete_utf8(&mut self) -> Option<String> {
if self.buffer.is_empty() {
return None;
}
let valid_len = self.find_valid_utf8_len();
if valid_len == 0 {
return None;
}
let valid_bytes: Vec<u8> = self.buffer.drain(..valid_len).collect();
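        // SAFETY: `find_valid_utf8_len` only returns lengths whose prefix has
        // already been validated as UTF-8.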
let result = unsafe { String::from_utf8_unchecked(valid_bytes) };
Some(result)
}
    fn find_valid_utf8_len(&self) -> usize {
        let bytes = &self.buffer;
        let len = bytes.len();
        if len == 0 {
            return 0;
        }
        // Fast path: the whole buffer already decodes cleanly.
        if std::str::from_utf8(bytes).is_ok() {
            return len;
        }
        // A UTF-8 sequence is at most 4 bytes, so at most 3 trailing bytes can
        // belong to an incomplete character. Hold those back if they look like
        // the start of a multi-byte sequence and everything before them is valid.
        for incomplete_len in 1..=3.min(len) {
            let check_len = len - incomplete_len;
            if check_len == 0 {
                continue;
            }
            if std::str::from_utf8(&bytes[..check_len]).is_ok()
                && Self::could_be_incomplete_sequence(&bytes[check_len..])
            {
                return check_len;
            }
        }
        // Fallback: emit the longest valid prefix, however short.
        for i in (0..len).rev() {
            if std::str::from_utf8(&bytes[..=i]).is_ok() {
                return i + 1;
            }
        }
        0
    }
    fn could_be_incomplete_sequence(bytes: &[u8]) -> bool {
        if bytes.is_empty() {
            return false;
        }
        // Classify by the lead byte: a multi-byte sequence is "incomplete" if
        // fewer bytes are buffered than its lead byte announces.
        let first = bytes[0];
        match first {
            0xC0..=0xDF => bytes.len() < 2, // 2-byte sequence lead
            0xE0..=0xEF => bytes.len() < 3, // 3-byte sequence lead
            0xF0..=0xF7 => bytes.len() < 4, // 4-byte sequence lead
            _ => false, // ASCII or continuation byte: not a lead
        }
    }
}
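
/// Streaming decoder for byte-level vocabularies: each token's stored bytes
/// are first mapped back to raw bytes with `byte_level_decode_bytes`, then
/// buffered until they form complete UTF-8, like `StreamingDecoder`.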
#[pyclass(name = "ByteLevelStreamingDecoder")]
pub struct PyByteLevelStreamingDecoder {
decoder: FxHashMap<u32, Vec<u8>>,
special_decoder: FxHashMap<u32, String>,
buffer: Vec<u8>,
}
#[pymethods]
impl PyByteLevelStreamingDecoder {
fn add_token(&mut self, token_id: u32) -> Option<String> {
if let Some(encoded_bytes) = self.decoder.get(&token_id) {
if let Some(raw_bytes) = byte_level_decode_bytes(encoded_bytes) {
self.buffer.extend_from_slice(&raw_bytes);
} else {
self.buffer.extend_from_slice(encoded_bytes);
}
} else if let Some(special) = self.special_decoder.get(&token_id) {
self.buffer.extend_from_slice(special.as_bytes());
} else {
return None;
}
self.extract_complete_utf8()
}
fn add_tokens(&mut self, token_ids: Vec<u32>) -> Option<String> {
for token_id in token_ids {
if let Some(encoded_bytes) = self.decoder.get(&token_id) {
if let Some(raw_bytes) = byte_level_decode_bytes(encoded_bytes) {
self.buffer.extend_from_slice(&raw_bytes);
} else {
self.buffer.extend_from_slice(encoded_bytes);
}
} else if let Some(special) = self.special_decoder.get(&token_id) {
self.buffer.extend_from_slice(special.as_bytes());
}
}
self.extract_complete_utf8()
}
fn flush(&mut self) -> String {
if self.buffer.is_empty() {
return String::new();
}
let result = String::from_utf8_lossy(&self.buffer).into_owned();
self.buffer.clear();
result
}
fn reset(&mut self) {
self.buffer.clear();
}
#[getter]
fn has_pending(&self) -> bool {
!self.buffer.is_empty()
}
#[getter]
fn pending_bytes(&self) -> usize {
self.buffer.len()
}
fn __repr__(&self) -> String {
format!(
"ByteLevelStreamingDecoder(pending_bytes={})",
self.buffer.len()
)
}
}
impl PyByteLevelStreamingDecoder {
fn new(decoder: FxHashMap<u32, Vec<u8>>, special_decoder: FxHashMap<u32, String>) -> Self {
Self {
decoder,
special_decoder,
buffer: Vec::with_capacity(16),
}
}
fn extract_complete_utf8(&mut self) -> Option<String> {
if self.buffer.is_empty() {
return None;
}
let valid_len = self.find_valid_utf8_len();
if valid_len == 0 {
return None;
}
let valid_bytes: Vec<u8> = self.buffer.drain(..valid_len).collect();
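        // SAFETY: `find_valid_utf8_len` only returns lengths whose prefix has
        // already been validated as UTF-8.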
let result = unsafe { String::from_utf8_unchecked(valid_bytes) };
Some(result)
}
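    // The UTF-8 boundary helpers below mirror PyStreamingDecoder; see the
    // comments there.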
fn find_valid_utf8_len(&self) -> usize {
let bytes = &self.buffer;
let len = bytes.len();
if len == 0 {
return 0;
}
if std::str::from_utf8(bytes).is_ok() {
return len;
}
for incomplete_len in 1..=3.min(len) {
let check_len = len - incomplete_len;
if check_len == 0 {
continue;
}
if std::str::from_utf8(&bytes[..check_len]).is_ok()
&& Self::could_be_incomplete_sequence(&bytes[check_len..])
{
return check_len;
}
}
for i in (0..len).rev() {
if std::str::from_utf8(&bytes[..=i]).is_ok() {
return i + 1;
}
}
0
}
fn could_be_incomplete_sequence(bytes: &[u8]) -> bool {
if bytes.is_empty() {
return false;
}
let first = bytes[0];
match first {
0xC0..=0xDF => bytes.len() < 2,
0xE0..=0xEF => bytes.len() < 3,
0xF0..=0xF7 => bytes.len() < 4,
_ => false,
}
}
}
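
// Compile-time include of the generated agent-token bindings.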
include!("agent_tokens_generated.rs");