use pyo3::prelude::*;
use pyo3::types::PyDict;
use oxillama_runtime::TokenizerBridge;
use crate::chat_template::apply_template;
use crate::error::runtime_to_py;
/// Python-visible tokenizer class, exported to Python under the name `Tokenizer`.
///
/// Thin wrapper around `oxillama_runtime::TokenizerBridge`. The encoding,
/// decoding and vocabulary methods in the `#[pymethods]` block delegate to
/// `inner`, converting runtime errors into Python exceptions.
#[pyclass(name = "Tokenizer")]
pub struct PyTokenizer {
    /// The runtime tokenizer that every encode/decode/vocab query is forwarded to.
    inner: TokenizerBridge,
}
#[pymethods]
impl PyTokenizer {
    /// Construct a tokenizer by loading the file at `path`.
    ///
    /// Any error reported by `TokenizerBridge::from_file` is converted into a
    /// Python exception via `runtime_to_py`.
    #[staticmethod]
    pub fn from_file(path: &str) -> PyResult<Self> {
        TokenizerBridge::from_file(path)
            .map(|inner| Self { inner })
            .map_err(runtime_to_py)
    }

    /// Construct a tokenizer from an in-memory JSON string.
    ///
    /// The string's UTF-8 bytes are handed to `TokenizerBridge::from_bytes`;
    /// parse failures surface as a Python exception.
    #[staticmethod]
    pub fn from_json(json: &str) -> PyResult<Self> {
        let raw = json.as_bytes();
        TokenizerBridge::from_bytes(raw)
            .map(|inner| Self { inner })
            .map_err(runtime_to_py)
    }

    /// Encode `text` into a list of token ids.
    pub fn encode(&self, text: &str) -> PyResult<Vec<u32>> {
        let encoded = self.inner.encode(text);
        encoded.map_err(runtime_to_py)
    }

    /// Decode a list of token ids back into a string.
    pub fn decode(&self, ids: Vec<u32>) -> PyResult<String> {
        let decoded = self.inner.decode(&ids);
        decoded.map_err(runtime_to_py)
    }

    /// Number of entries in the tokenizer's vocabulary (read-only property).
    #[getter]
    pub fn vocab_size(&self) -> usize {
        self.inner.vocab_size()
    }

    /// Look up the string form of token `id`.
    ///
    /// Returns `None` when the underlying bridge has no token for `id`
    /// (presumably an out-of-range id — confirm against `TokenizerBridge`).
    pub fn id_to_token(&self, id: u32) -> Option<String> {
        self.inner.id_to_token(id)
    }

    /// Python `repr()`, e.g. `Tokenizer(vocab_size=32000)`.
    fn __repr__(&self) -> String {
        let size = self.vocab_size();
        format!("Tokenizer(vocab_size={})", size)
    }

    /// Encode each string in `texts`, returning one id list per input.
    ///
    /// Stops at the first failing input and raises its error.
    pub fn encode_batch(&self, texts: Vec<String>) -> PyResult<Vec<Vec<u32>>> {
        let mut batch = Vec::with_capacity(texts.len());
        for text in &texts {
            batch.push(self.encode(text)?);
        }
        Ok(batch)
    }

    /// Render a list of chat `messages` (Python dicts) into a single prompt string.
    ///
    /// `template` selects the chat template by name and falls back to `"chatml"`
    /// when omitted. `add_generation_prompt` defaults to `True` and is forwarded
    /// unchanged to `apply_template`.
    #[pyo3(signature = (messages, template = None, add_generation_prompt = None))]
    pub fn apply_chat_template(
        &self,
        _py: Python<'_>,
        messages: Vec<Bound<'_, PyDict>>,
        template: Option<String>,
        add_generation_prompt: Option<bool>,
    ) -> PyResult<String> {
        let template_name: &str = match &template {
            Some(name) => name.as_str(),
            None => "chatml",
        };
        let with_generation_prompt = add_generation_prompt.unwrap_or(true);
        apply_template(template_name, &messages, with_generation_prompt)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loading from a path that does not exist must return an error rather than panic.
    #[test]
    fn test_from_file_nonexistent() {
        let path = std::env::temp_dir().join("oxillama_py_no_such_tokenizer_42.json");
        let path_str = path.to_string_lossy();
        let result = TokenizerBridge::from_file(&path_str);
        assert!(result.is_err(), "nonexistent tokenizer file should error");
    }

    /// Bytes that are not JSON at all must be rejected.
    #[test]
    fn test_from_json_invalid() {
        let result = TokenizerBridge::from_bytes(b"not valid json at all");
        assert!(result.is_err(), "invalid JSON should error");
    }

    /// Exercises the *file* loading path with a minimal `{}` document.
    ///
    /// Previously this test wrote the file but then called `from_bytes`, so the
    /// file was never read. Whether `{}` is accepted depends on the underlying
    /// tokenizer deserializer, so the test only checks that loading does not panic.
    #[test]
    fn test_from_file_with_empty_file() {
        // Include the PID so concurrent test processes don't race on one file.
        let tmp = std::env::temp_dir()
            .join(format!("oxillama_py_empty_tok_{}.json", std::process::id()));
        // A setup failure should fail the test loudly, not be silently ignored.
        std::fs::write(&tmp, "{}").expect("failed to write temporary tokenizer file");
        let tmp_str = tmp.to_string_lossy();
        let _result = TokenizerBridge::from_file(&tmp_str);
        // Best-effort cleanup; a leftover temp file is harmless.
        let _ = std::fs::remove_file(&tmp);
    }
}