use async_trait::async_trait;
use ferrum_types::{Result, SpecialTokens, TokenId};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Core synchronous tokenizer interface.
///
/// Implementations must be thread-safe (`Send + Sync`) so a single
/// tokenizer instance can be shared across request-handling threads.
pub trait Tokenizer: Send + Sync {
    /// Encodes `text` into token ids; `add_special` controls whether the
    /// implementation inserts special tokens (e.g. BOS/EOS).
    fn encode(&self, text: &str, add_special: bool) -> Result<Vec<TokenId>>;

    /// Decodes `tokens` back into text; `skip_special` drops special
    /// tokens from the output.
    fn decode(&self, tokens: &[TokenId], skip_special: bool) -> Result<String>;

    /// Decodes only the new text produced by appending `next` to `prev`
    /// (streaming-friendly decoding).
    fn decode_incremental(&self, prev: &[TokenId], next: TokenId) -> Result<String>;

    /// Total number of entries in the vocabulary.
    fn vocab_size(&self) -> usize;

    /// Special-token configuration for this tokenizer.
    fn special_tokens(&self) -> &SpecialTokens;

    /// Looks up the id for an exact token string, if present.
    fn token_id(&self, text: &str) -> Option<TokenId>;

    /// Looks up the text for a token id, if it is in the vocabulary.
    fn token_text(&self, token_id: TokenId) -> Option<&str>;

    /// Returns `true` if `token_id` matches any *configured* special token.
    fn is_special_token(&self, token_id: TokenId) -> bool {
        let special = self.special_tokens();
        // Compare against `Some(token_id)` directly. The previous
        // `unwrap_or(TokenId::MAX)` sentinel wrongly reported
        // `TokenId::MAX` as special whenever any slot was `None`,
        // because the probe then matched its own fallback value.
        special.bos_token == Some(token_id)
            || special.eos_token == Some(token_id)
            || special.unk_token == Some(token_id)
            || special.pad_token == Some(token_id)
    }

    /// Renders `messages` as a single prompt string.
    ///
    /// The default implementation formats each message as
    /// `"{role}: {content}\n"` and trims trailing whitespace.
    /// Model-specific tokenizers should override this with their real
    /// chat template.
    fn apply_chat_template(&self, messages: &[ChatMessage]) -> Result<String> {
        let mut result = String::new();
        for msg in messages {
            // push_str avoids the intermediate String `format!` allocated
            // per message.
            result.push_str(&msg.role);
            result.push_str(": ");
            result.push_str(&msg.content);
            result.push('\n');
        }
        Ok(result.trim_end().to_string())
    }

    /// Static metadata describing this tokenizer.
    fn info(&self) -> TokenizerInfo;
}
/// Asynchronous variants of the core tokenizer operations, plus batch
/// helpers. Requires a blanket synchronous [`Tokenizer`] implementation.
#[async_trait]
pub trait AsyncTokenizer: Tokenizer {
/// Async counterpart of [`Tokenizer::encode`].
async fn encode_async(&self, text: &str, add_special: bool) -> Result<Vec<TokenId>>;
/// Async counterpart of [`Tokenizer::decode`].
async fn decode_async(&self, tokens: &[TokenId], skip_special: bool) -> Result<String>;
/// Encodes many texts in one call; result order matches `texts`.
async fn encode_batch(&self, texts: &[&str], add_special: bool) -> Result<Vec<Vec<TokenId>>>;
/// Decodes many token sequences in one call; result order matches input.
async fn decode_batch(
&self,
token_sequences: &[&[TokenId]],
skip_special: bool,
) -> Result<Vec<String>>;
}
/// Optional, model-dependent tokenizer capabilities (introspection and
/// text preprocessing) beyond the core [`Tokenizer`] contract.
pub trait TokenizerCapabilities: Tokenizer {
/// Probability of `token_id` in the context of `text`, if the
/// implementation can provide one. NOTE(review): exact semantics
/// (prior vs. conditional) are implementation-defined — confirm with
/// concrete implementors.
fn token_probability(&self, text: &str, token_id: TokenId) -> Option<f32>;
/// Token ids whose text starts with (or could extend) `prefix` —
/// presumably used for constrained decoding; verify against callers.
fn get_prefix_tokens(&self, prefix: &str) -> Result<Vec<TokenId>>;
/// Whether `next_token` is a valid continuation of `tokens`.
fn can_extend(&self, tokens: &[TokenId], next_token: TokenId) -> bool;
/// Coarse classification of a token (word, subword, punctuation, ...).
fn token_type(&self, token_id: TokenId) -> TokenType;
/// Applies the tokenizer's text normalization (e.g. NFC, lowercasing).
fn normalize_text(&self, text: &str) -> String;
/// Splits `text` into pre-tokenization units prior to subword encoding.
fn pre_tokenize(&self, text: &str) -> Vec<String>;
}
/// Factory for constructing boxed [`Tokenizer`] instances from various
/// sources (local files, raw bytes, a model hub, or an explicit config).
#[async_trait]
pub trait TokenizerFactory: Send + Sync {
/// Loads a tokenizer from a file at `path`.
async fn load_from_file(&self, path: &str) -> Result<Box<dyn Tokenizer>>;
/// Loads a tokenizer from an in-memory serialized representation.
async fn load_from_bytes(&self, data: &[u8]) -> Result<Box<dyn Tokenizer>>;
/// Loads a tokenizer from a model hub repository; `revision` of `None`
/// presumably selects the default branch — confirm with implementors.
async fn load_from_hub(
&self,
repo_id: &str,
revision: Option<&str>,
) -> Result<Box<dyn Tokenizer>>;
/// Builds a tokenizer from an explicit [`TokenizerConfig`].
async fn create_from_config(&self, config: &TokenizerConfig) -> Result<Box<dyn Tokenizer>>;
/// Tokenizer types this factory knows how to construct.
fn supported_types(&self) -> Vec<TokenizerType>;
}
/// Static metadata describing a tokenizer, as returned by
/// [`Tokenizer::info`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerInfo {
/// Underlying algorithm family (BPE, WordPiece, ...).
pub tokenizer_type: TokenizerType,
/// Total vocabulary size.
pub vocab_size: usize,
/// Configured special tokens.
pub special_tokens: SpecialTokens,
/// Whether incremental (streaming) decoding is supported.
pub supports_incremental: bool,
/// Whether a chat template beyond the generic default is available.
pub supports_chat_template: bool,
/// Maximum length of a single token's text, if known.
pub max_token_length: Option<usize>,
/// Name of the associated model, if known.
pub model_name: Option<String>,
}
/// Tokenization algorithm family.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TokenizerType {
/// Byte-pair encoding.
BPE,
/// WordPiece (BERT-style).
WordPiece,
/// SentencePiece (unigram/BPE over raw text).
SentencePiece,
/// OpenAI tiktoken-style BPE.
Tiktoken,
/// Any other, user-provided implementation.
Custom,
}
/// Coarse classification of a single token, as reported by
/// [`TokenizerCapabilities::token_type`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TokenType {
/// A complete word.
Word,
/// A subword fragment (continuation piece).
Subword,
/// Punctuation.
Punctuation,
/// Numeric content.
Number,
/// A special/control token (BOS, EOS, ...).
Special,
/// Could not be classified.
Unknown,
}
/// One message in a chat conversation, consumed by
/// [`Tokenizer::apply_chat_template`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
/// Speaker role, e.g. `"user"`, `"assistant"`, `"system"`.
pub role: String,
/// Message text.
pub content: String,
/// Free-form per-message metadata.
pub metadata: HashMap<String, serde_json::Value>,
}
impl ChatMessage {
    /// Builds a message with the given role, the given content, and
    /// empty metadata — shared by the public convenience constructors.
    fn with_role(role: &str, content: impl Into<String>) -> Self {
        Self {
            role: role.to_string(),
            content: content.into(),
            metadata: HashMap::new(),
        }
    }

    /// Creates a `"user"` message with empty metadata.
    pub fn user(content: impl Into<String>) -> Self {
        Self::with_role("user", content)
    }

    /// Creates an `"assistant"` message with empty metadata.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::with_role("assistant", content)
    }

    /// Creates a `"system"` message with empty metadata.
    pub fn system(content: impl Into<String>) -> Self {
        Self::with_role("system", content)
    }
}
/// Declarative configuration for constructing a tokenizer via
/// [`TokenizerFactory::create_from_config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerConfig {
/// Algorithm family to instantiate.
pub tokenizer_type: TokenizerType,
/// Path to the tokenizer definition file.
pub path: String,
/// Default for whether special tokens are added on encode.
pub add_special_tokens: bool,
/// Prefer a fast (native) implementation when available.
pub use_fast: bool,
/// Optional truncation behavior; `None` disables truncation.
pub truncation: Option<TruncationConfig>,
/// Optional padding behavior; `None` disables padding.
pub padding: Option<PaddingConfig>,
/// Optional chat template string overriding the default.
pub chat_template: Option<String>,
/// Implementation-specific extra options.
pub extra_options: HashMap<String, serde_json::Value>,
}
/// How encoded sequences are truncated to fit a maximum length.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TruncationConfig {
/// Maximum sequence length in tokens.
pub max_length: usize,
/// Which side(s) to drop tokens from.
pub strategy: TruncationStrategy,
/// Overlap between windows — presumably only meaningful with
/// [`TruncationStrategy::SlidingWindow`]; confirm with implementors.
pub stride: Option<usize>,
}
/// Which portion of an over-long sequence is removed.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum TruncationStrategy {
/// Drop tokens from the end.
TruncateEnd,
/// Drop tokens from the start.
TruncateStart,
/// Drop tokens from both ends.
TruncateBoth,
/// Split into overlapping windows instead of dropping.
SlidingWindow,
}
/// How encoded sequences are padded to a target length.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PaddingConfig {
/// Target-length policy.
pub strategy: PaddingStrategy,
/// Token id used as filler.
pub token_id: TokenId,
/// Fixed length — presumably used with [`PaddingStrategy::Fixed`];
/// confirm with implementors.
pub length: Option<usize>,
/// Side of the sequence that receives the padding.
pub direction: PaddingDirection,
}
/// Policy for choosing the padded length of a batch.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum PaddingStrategy {
/// No padding.
None,
/// Pad every sequence to the longest in the batch.
Longest,
/// Pad up to the next multiple of the given value.
MultipleOf(usize),
/// Pad to a fixed length (see [`PaddingConfig::length`]).
Fixed,
}
/// Side of the sequence on which padding tokens are inserted.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum PaddingDirection {
/// Append padding after the content.
Right,
/// Prepend padding before the content.
Left,
}
/// Stateful streaming decoder: callers feed tokens one at a time and
/// receive the newly produced text, with decode state owned by the caller.
pub trait IncrementalTokenizer: Tokenizer {
/// Per-stream decode state; `Send + Sync` so streams can move across
/// threads.
type State: Send + Sync;
/// Creates a fresh decode state for a new stream.
fn create_state(&self) -> Self::State;
/// Feeds one token into `state`, returning only the newly decoded text.
fn decode_incremental_with_state(
&self,
state: &mut Self::State,
token: TokenId,
) -> Result<String>;
/// Resets `state` so it can be reused for a new stream.
fn reset_state(&self, state: &mut Self::State);
/// Returns the full text decoded into `state` so far.
fn get_decoded_text(&self, state: &Self::State) -> String;
}
/// Text utilities that operate alongside (but independently of) a
/// tokenizer: normalization, language detection, and size estimation.
pub trait TextProcessor: Send + Sync {
/// Transforms raw input text before tokenization.
fn preprocess(&self, text: &str) -> String;
/// Transforms decoded text before it is returned to the caller.
fn postprocess(&self, text: &str) -> String;
/// Detects the language of `text`, if the implementation can.
/// NOTE(review): the return format (e.g. ISO 639-1 code) is not
/// specified here — confirm with implementors.
fn detect_language(&self, text: &str) -> Option<String>;
/// Splits `text` into sentences.
fn sentence_split(&self, text: &str) -> Vec<String>;
/// Cheap estimate of how many tokens `text` will encode to,
/// without running the full tokenizer.
fn estimate_token_count(&self, text: &str) -> usize;
}
/// Aggregated runtime metrics for a tokenizer instance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerStats {
/// Number of encode calls performed.
pub encode_operations: u64,
/// Number of decode calls performed.
pub decode_operations: u64,
/// Total tokens processed across all operations.
pub tokens_processed: u64,
/// Mean encode time per input character, in microseconds.
pub avg_encode_time_per_char_us: f64,
/// Mean decode time per token, in microseconds.
pub avg_decode_time_per_token_us: f64,
/// Hit rate of the incremental-decode cache, in `[0, 1]` —
/// presumably; confirm units with the code that populates it.
pub incremental_cache_hit_rate: f32,
}
/// Named registry of tokenizer instances, keyed by string name.
pub trait TokenizerRegistry: Send + Sync {
/// Registers `tokenizer` under `name`. NOTE(review): behavior when
/// `name` already exists (replace vs. error) is implementation-defined
/// — confirm with implementors.
fn register(&mut self, name: &str, tokenizer: Box<dyn Tokenizer>) -> Result<()>;
/// Borrows the tokenizer registered under `name`, if any.
fn get(&self, name: &str) -> Option<&dyn Tokenizer>;
/// Removes and returns the tokenizer registered under `name`, if any.
fn remove(&mut self, name: &str) -> Option<Box<dyn Tokenizer>>;
/// Lists all registered names.
fn list_names(&self) -> Vec<String>;
/// Returns `true` if a tokenizer is registered under `name`.
fn contains(&self, name: &str) -> bool;
}