1use async_trait::async_trait;
8use ferrum_types::{Result, SpecialTokens, TokenId};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12pub trait Tokenizer: Send + Sync {
14 fn encode(&self, text: &str, add_special: bool) -> Result<Vec<TokenId>>;
16
17 fn decode(&self, tokens: &[TokenId], skip_special: bool) -> Result<String>;
19
20 fn decode_incremental(&self, prev: &[TokenId], next: TokenId) -> Result<String>;
23
24 fn vocab_size(&self) -> usize;
26
27 fn special_tokens(&self) -> &SpecialTokens;
29
30 fn token_id(&self, text: &str) -> Option<TokenId>;
32
33 fn token_text(&self, token_id: TokenId) -> Option<&str>;
35
36 fn is_special_token(&self, token_id: TokenId) -> bool {
38 let special = self.special_tokens();
39 let fallback = TokenId::MAX;
40 token_id == special.bos_token.unwrap_or(fallback)
41 || token_id == special.eos_token.unwrap_or(fallback)
42 || token_id == special.unk_token.unwrap_or(fallback)
43 || token_id == special.pad_token.unwrap_or(fallback)
44 }
45
46 fn apply_chat_template(&self, messages: &[ChatMessage]) -> Result<String> {
48 let mut result = String::new();
50 for msg in messages {
51 result.push_str(&format!("{}: {}\n", msg.role, msg.content));
52 }
53 Ok(result.trim_end().to_string())
54 }
55
56 fn info(&self) -> TokenizerInfo;
58}
59
/// Asynchronous counterpart to [`Tokenizer`], for backends that perform
/// tokenization off the calling task (e.g. on a blocking thread pool).
#[async_trait]
pub trait AsyncTokenizer: Tokenizer {
    /// Async version of [`Tokenizer::encode`].
    async fn encode_async(&self, text: &str, add_special: bool) -> Result<Vec<TokenId>>;

    /// Async version of [`Tokenizer::decode`].
    async fn decode_async(&self, tokens: &[TokenId], skip_special: bool) -> Result<String>;

    /// Encodes several texts in one call; returns one token vector per input.
    async fn encode_batch(&self, texts: &[&str], add_special: bool) -> Result<Vec<Vec<TokenId>>>;

    /// Decodes several token sequences in one call; returns one string per input.
    async fn decode_batch(
        &self,
        token_sequences: &[&[TokenId]],
        skip_special: bool,
    ) -> Result<Vec<String>>;
}
79
/// Optional introspection hooks layered on top of [`Tokenizer`].
pub trait TokenizerCapabilities: Tokenizer {
    /// Score for `token_id` in the context of `text`, if the backend can
    /// provide one. NOTE(review): exact semantics (true probability vs.
    /// merge score) are implementation-defined — not visible here.
    fn token_probability(&self, text: &str, token_id: TokenId) -> Option<f32>;

    /// Tokens associated with `prefix` — presumably vocab entries whose
    /// text begins with it; confirm against implementations.
    fn get_prefix_tokens(&self, prefix: &str) -> Result<Vec<TokenId>>;

    /// Whether `next_token` is a valid continuation of `tokens`.
    fn can_extend(&self, tokens: &[TokenId], next_token: TokenId) -> bool;

    /// Coarse classification of a token (word, subword, punctuation, ...).
    fn token_type(&self, token_id: TokenId) -> TokenType;

    /// Applies the tokenizer's text normalization step.
    fn normalize_text(&self, text: &str) -> String;

    /// Splits `text` into pre-token chunks ahead of model tokenization.
    fn pre_tokenize(&self, text: &str) -> Vec<String>;
}
100
/// Factory for constructing [`Tokenizer`] instances from various sources.
#[async_trait]
pub trait TokenizerFactory: Send + Sync {
    /// Loads a tokenizer from a file at `path`.
    async fn load_from_file(&self, path: &str) -> Result<Box<dyn Tokenizer>>;

    /// Loads a tokenizer from raw serialized bytes.
    async fn load_from_bytes(&self, data: &[u8]) -> Result<Box<dyn Tokenizer>>;

    /// Loads a tokenizer from a model-hub repository; `revision`
    /// optionally pins a specific revision of `repo_id`.
    async fn load_from_hub(
        &self,
        repo_id: &str,
        revision: Option<&str>,
    ) -> Result<Box<dyn Tokenizer>>;

    /// Builds a tokenizer from an explicit [`TokenizerConfig`].
    async fn create_from_config(&self, config: &TokenizerConfig) -> Result<Box<dyn Tokenizer>>;

    /// The tokenizer types this factory is able to construct.
    fn supported_types(&self) -> Vec<TokenizerType>;
}
123
/// Static metadata describing a tokenizer instance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerInfo {
    /// Underlying algorithm family (BPE, WordPiece, ...).
    pub tokenizer_type: TokenizerType,
    /// Number of entries in the vocabulary.
    pub vocab_size: usize,
    /// Configured special tokens (BOS/EOS/UNK/PAD).
    pub special_tokens: SpecialTokens,
    /// Whether incremental decoding is natively supported.
    pub supports_incremental: bool,
    /// Whether a real chat template is available (vs. the trait fallback).
    pub supports_chat_template: bool,
    /// Maximum token text length, if the backend reports one.
    pub max_token_length: Option<usize>,
    /// Name of the model this tokenizer belongs to, if known.
    pub model_name: Option<String>,
}
142
/// Families of tokenization algorithms.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TokenizerType {
    /// Byte-pair encoding.
    BPE,
    /// WordPiece (BERT-style subword tokenization).
    WordPiece,
    /// SentencePiece.
    SentencePiece,
    /// OpenAI tiktoken encodings.
    Tiktoken,
    /// Any other, user-provided scheme.
    Custom,
}
157
/// Coarse classification of an individual token.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TokenType {
    /// A complete word.
    Word,
    /// A fragment of a word (continuation piece).
    Subword,
    /// Punctuation.
    Punctuation,
    /// Numeric content.
    Number,
    /// A special/control token (BOS, EOS, ...).
    Special,
    /// Could not be classified.
    Unknown,
}
174
/// A single message in a chat conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Speaker role, e.g. "user", "assistant", or "system".
    pub role: String,
    /// Message body text.
    pub content: String,
    /// Arbitrary extra data attached to the message.
    pub metadata: HashMap<String, serde_json::Value>,
}
185
186impl ChatMessage {
187 pub fn user(content: impl Into<String>) -> Self {
189 Self {
190 role: "user".to_string(),
191 content: content.into(),
192 metadata: HashMap::new(),
193 }
194 }
195
196 pub fn assistant(content: impl Into<String>) -> Self {
198 Self {
199 role: "assistant".to_string(),
200 content: content.into(),
201 metadata: HashMap::new(),
202 }
203 }
204
205 pub fn system(content: impl Into<String>) -> Self {
207 Self {
208 role: "system".to_string(),
209 content: content.into(),
210 metadata: HashMap::new(),
211 }
212 }
213}
214
/// Configuration used to construct a tokenizer via [`TokenizerFactory`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerConfig {
    /// Which algorithm family to instantiate.
    pub tokenizer_type: TokenizerType,
    /// Location of the tokenizer data.
    pub path: String,
    /// Default for inserting special tokens during encoding.
    pub add_special_tokens: bool,
    /// Prefer a "fast" implementation when one exists.
    pub use_fast: bool,
    /// Optional truncation behavior.
    pub truncation: Option<TruncationConfig>,
    /// Optional padding behavior.
    pub padding: Option<PaddingConfig>,
    /// Optional chat template overriding the trait-level fallback.
    pub chat_template: Option<String>,
    /// Backend-specific options not covered by the fields above.
    pub extra_options: HashMap<String, serde_json::Value>,
}
235
/// How to shorten token sequences that exceed a length limit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TruncationConfig {
    /// Maximum sequence length in tokens.
    pub max_length: usize,
    /// Which side(s) of the sequence tokens are removed from.
    pub strategy: TruncationStrategy,
    /// Overlap between consecutive windows — presumably only meaningful
    /// for `TruncationStrategy::SlidingWindow`; confirm with users.
    pub stride: Option<usize>,
}
246
/// Where tokens are removed when a sequence exceeds the maximum length.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum TruncationStrategy {
    /// Drop tokens from the end of the sequence.
    TruncateEnd,
    /// Drop tokens from the start of the sequence.
    TruncateStart,
    /// Drop tokens from both ends.
    TruncateBoth,
    /// Split into overlapping windows instead of dropping tokens.
    SlidingWindow,
}
259
/// How to pad token sequences to a common length.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PaddingConfig {
    /// When, and to what length, padding is applied.
    pub strategy: PaddingStrategy,
    /// Token id used as the padding token.
    pub token_id: TokenId,
    /// Target length — presumably used with `PaddingStrategy::Fixed`;
    /// confirm with implementations.
    pub length: Option<usize>,
    /// Which side of the sequence receives the padding tokens.
    pub direction: PaddingDirection,
}
272
/// When sequences are padded.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum PaddingStrategy {
    /// No padding.
    None,
    /// Pad every sequence to the longest one in the batch.
    Longest,
    /// Pad up to the next multiple of the given value.
    MultipleOf(usize),
    /// Pad to a fixed length — presumably `PaddingConfig::length`.
    Fixed,
}
285
/// Which end of a sequence padding tokens are added to.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum PaddingDirection {
    /// Append padding after the sequence.
    Right,
    /// Prepend padding before the sequence.
    Left,
}
294
/// Streaming decode support with explicit, reusable decoder state.
pub trait IncrementalTokenizer: Tokenizer {
    /// Backend-specific decoder state carried across tokens.
    type State: Send + Sync;

    /// Creates a fresh, empty decode state.
    fn create_state(&self) -> Self::State;

    /// Decodes `token` in the context of `state`, updating the state in
    /// place and returning the newly produced text.
    fn decode_incremental_with_state(
        &self,
        state: &mut Self::State,
        token: TokenId,
    ) -> Result<String>;

    /// Resets `state` so it can be reused for a new sequence.
    fn reset_state(&self, state: &mut Self::State);

    /// Full text decoded so far according to `state`.
    fn get_decoded_text(&self, state: &Self::State) -> String;
}
316
/// Text pre/post-processing utilities surrounding tokenization.
pub trait TextProcessor: Send + Sync {
    /// Transforms raw text before it is encoded.
    fn preprocess(&self, text: &str) -> String;

    /// Transforms decoded text before it is returned to the caller.
    fn postprocess(&self, text: &str) -> String;

    /// Best-effort language detection; `None` when undetermined.
    fn detect_language(&self, text: &str) -> Option<String>;

    /// Splits `text` into sentences.
    fn sentence_split(&self, text: &str) -> Vec<String>;

    /// Cheap estimate of how many tokens `text` will encode to.
    fn estimate_token_count(&self, text: &str) -> usize;
}
334
/// Aggregate runtime statistics for a tokenizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerStats {
    /// Total number of encode operations performed.
    pub encode_operations: u64,
    /// Total number of decode operations performed.
    pub decode_operations: u64,
    /// Total tokens handled across all operations.
    pub tokens_processed: u64,
    /// Mean encode time per input character, in microseconds.
    pub avg_encode_time_per_char_us: f64,
    /// Mean decode time per token, in microseconds.
    pub avg_decode_time_per_token_us: f64,
    /// Hit rate of the incremental-decode cache — presumably in [0, 1];
    /// confirm with the producer of these stats.
    pub incremental_cache_hit_rate: f32,
}
351
/// A named collection of tokenizers.
pub trait TokenizerRegistry: Send + Sync {
    /// Registers `tokenizer` under `name`. Error conditions (e.g. a
    /// duplicate name) are implementation-defined — not visible here.
    fn register(&mut self, name: &str, tokenizer: Box<dyn Tokenizer>) -> Result<()>;

    /// Looks up a registered tokenizer by name.
    fn get(&self, name: &str) -> Option<&dyn Tokenizer>;

    /// Removes and returns the tokenizer registered under `name`.
    fn remove(&mut self, name: &str) -> Option<Box<dyn Tokenizer>>;

    /// Names of all registered tokenizers.
    fn list_names(&self) -> Vec<String>;

    /// Whether a tokenizer is registered under `name`.
    fn contains(&self, name: &str) -> bool;
}