//! `llm_tokenizer/factory.rs` — factory functions that build tokenizers from
//! local files, directories, or model names.

1use std::{fs::File, io::Read, path::Path, sync::Arc};
2
3use anyhow::{Error, Result};
4use tracing::debug;
5
6use crate::{
7    hub::download_tokenizer_from_hf, huggingface::HuggingFaceTokenizer,
8    tiktoken::TiktokenTokenizer, traits,
9};
10
/// Identifies which tokenizer backend a tokenizer source maps to.
///
/// `PartialEq`/`Eq` are derived so values can be compared directly, for
/// example with `assert_eq!` in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenizerType {
    /// HuggingFace tokenizer. `get_tokenizer_info` stores the path of the
    /// tokenizer file in the payload.
    HuggingFace(String),
    /// Mock tokenizer used for testing.
    Mock,
    /// Tiktoken tokenizer.
    /// NOTE(review): the payload is presumably the model name; nothing in this
    /// file constructs the variant, so confirm against callers.
    Tiktoken(String),
    // Future: SentencePiece, GGUF
}
19
20/// Create a tokenizer from a file path to a tokenizer file.
21/// The file extension is used to determine the tokenizer type.
22/// Supported file types are:
23/// - json: HuggingFace tokenizer
24/// - For testing: can return mock tokenizer
25pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
26    create_tokenizer_with_chat_template(file_path, None)
27}
28
29/// Create a tokenizer from a file path with an optional chat template
30pub fn create_tokenizer_with_chat_template(
31    file_path: &str,
32    chat_template_path: Option<&str>,
33) -> Result<Arc<dyn traits::Tokenizer>> {
34    // Special case for testing
35    if file_path == "mock" || file_path == "test" {
36        return Ok(Arc::new(super::mock::MockTokenizer::new()));
37    }
38
39    let path = Path::new(file_path);
40
41    // Check if file exists
42    if !path.exists() {
43        return Err(Error::msg(format!("File not found: {}", file_path)));
44    }
45
46    // If path is a directory, search for tokenizer files
47    if path.is_dir() {
48        let tokenizer_json = path.join("tokenizer.json");
49        if tokenizer_json.exists() {
50            // Resolve chat template: provided path takes precedence over auto-discovery
51            let final_chat_template =
52                resolve_and_log_chat_template(chat_template_path, path, file_path);
53            let tokenizer_path_str = tokenizer_json.to_str().ok_or_else(|| {
54                Error::msg(format!(
55                    "Tokenizer path is not valid UTF-8: {:?}",
56                    tokenizer_json
57                ))
58            })?;
59            return create_tokenizer_with_chat_template(
60                tokenizer_path_str,
61                final_chat_template.as_deref(),
62            );
63        }
64
65        return Err(Error::msg(format!(
66            "Directory '{}' does not contain a valid tokenizer file (tokenizer.json, tokenizer_config.json, or vocab.json)",
67            file_path
68        )));
69    }
70
71    // Try to determine tokenizer type from extension
72    let extension = path
73        .extension()
74        .and_then(std::ffi::OsStr::to_str)
75        .map(|s| s.to_lowercase());
76
77    let result = match extension.as_deref() {
78        Some("json") => {
79            let tokenizer =
80                HuggingFaceTokenizer::from_file_with_chat_template(file_path, chat_template_path)?;
81
82            Ok(Arc::new(tokenizer) as Arc<dyn traits::Tokenizer>)
83        }
84        Some("model") => {
85            // SentencePiece model file
86            Err(Error::msg("SentencePiece models not yet supported"))
87        }
88        Some("gguf") => {
89            // GGUF format
90            Err(Error::msg("GGUF format not yet supported"))
91        }
92        _ => {
93            // Try to auto-detect by reading file content
94            auto_detect_tokenizer(file_path)
95        }
96    };
97
98    result
99}
100
101/// Auto-detect tokenizer type by examining file content
102fn auto_detect_tokenizer(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
103    let mut file = File::open(file_path)?;
104    let mut buffer = vec![0u8; 512]; // Read first 512 bytes for detection
105    let bytes_read = file.read(&mut buffer)?;
106    buffer.truncate(bytes_read);
107
108    // Check for JSON (HuggingFace format)
109    if is_likely_json(&buffer) {
110        let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
111        return Ok(Arc::new(tokenizer));
112    }
113
114    // Check for GGUF magic number
115    if buffer.len() >= 4 && &buffer[0..4] == b"GGUF" {
116        return Err(Error::msg("GGUF format detected but not yet supported"));
117    }
118
119    // Check for SentencePiece model
120    if is_likely_sentencepiece(&buffer) {
121        return Err(Error::msg(
122            "SentencePiece model detected but not yet supported",
123        ));
124    }
125
126    Err(Error::msg(format!(
127        "Unable to determine tokenizer type for file: {}",
128        file_path
129    )))
130}
131
/// Heuristically decide whether `buffer` holds the beginning of a JSON document.
///
/// A leading UTF-8 byte-order mark is ignored, ASCII whitespace is skipped, and
/// the first remaining byte must open an object (`{`) or an array (`[`).
/// Empty or whitespace-only input yields `false`.
fn is_likely_json(buffer: &[u8]) -> bool {
    const UTF8_BOM: &[u8] = &[0xEF, 0xBB, 0xBF];

    let payload = buffer.strip_prefix(UTF8_BOM).unwrap_or(buffer);
    let first_significant = payload
        .iter()
        .copied()
        .find(|byte| !byte.is_ascii_whitespace());

    matches!(first_significant, Some(b'{') | Some(b'['))
}
148
/// Heuristically decide whether `buffer` looks like a SentencePiece model.
///
/// This is a simplified check:
/// - buffers shorter than 12 bytes are never accepted;
/// - a buffer starting with one of two known header byte patterns is accepted;
/// - otherwise the buffer is accepted if it contains any of the special-token
///   markers `<unk`, `<s>` or `</s>`.
///
/// Each marker is searched with a window of its own length, so a 3-byte `<s>`
/// in the final bytes of the buffer is found too.
fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
    const MIN_LEN: usize = 12;
    const HEADER_PATTERNS: [&[u8]; 2] = [b"\x0a\x09", b"\x08\x00"];
    const TOKEN_MARKERS: [&[u8]; 3] = [b"<unk", b"<s>", b"</s>"];

    if buffer.len() < MIN_LEN {
        return false;
    }

    // Check header patterns first (cheap)
    if HEADER_PATTERNS
        .iter()
        .any(|header| buffer.starts_with(header))
    {
        return true;
    }

    // The buffer is at most a few hundred bytes, so one scan per marker is cheap.
    TOKEN_MARKERS.iter().any(|marker| {
        buffer
            .windows(marker.len())
            .any(|window| window == *marker)
    })
}
174
/// Discover a chat template file inside `dir`.
///
/// Candidates are tried in priority order:
/// 1. `chat_template.json` (Jinja template wrapped in JSON),
/// 2. `chat_template.jinja` (standard Jinja file),
/// 3. any other entry whose name ends with `.jinja` (non-standard naming).
///
/// When several entries qualify for step 3, the lexicographically smallest
/// file name is chosen. `read_dir` does not guarantee an iteration order, so
/// taking the first entry would make the result vary between platforms and
/// file systems.
///
/// Returns the path as a `String`, or `None` when nothing is found, the
/// directory cannot be read, or the selected path is not valid UTF-8.
pub fn discover_chat_template_in_dir(dir: &Path) -> Option<String> {
    // Priorities 1 and 2: well-known file names.
    for well_known in ["chat_template.json", "chat_template.jinja"] {
        let candidate = dir.join(well_known);
        if candidate.exists() {
            return candidate.to_str().map(str::to_owned);
        }
    }

    // Priority 3: any other *.jinja entry, selected deterministically.
    // Entries with unreadable metadata or non-UTF-8 names are skipped.
    let chosen_name = std::fs::read_dir(dir)
        .ok()?
        .flatten()
        .filter_map(|entry| entry.file_name().into_string().ok())
        .filter(|name| name.ends_with(".jinja") && name.as_str() != "chat_template.jinja")
        .min()?;

    let chosen_path = dir.join(chosen_name);
    chosen_path.to_str().map(str::to_owned)
}
204
205/// Helper function to resolve and log chat template selection
206///
207/// Resolves the final chat template to use by prioritizing provided path over auto-discovery,
208/// and logs the source for debugging purposes.
209fn resolve_and_log_chat_template(
210    provided_path: Option<&str>,
211    discovery_dir: &Path,
212    model_name: &str,
213) -> Option<String> {
214    let final_chat_template = provided_path
215        .map(|s| s.to_string())
216        .or_else(|| discover_chat_template_in_dir(discovery_dir));
217
218    match (&provided_path, &final_chat_template) {
219        (Some(provided), _) => {
220            debug!("Using provided chat template: {}", provided);
221        }
222        (None, Some(discovered)) => {
223            debug!(
224                "Auto-discovered chat template in '{}': {}",
225                discovery_dir.display(),
226                discovered
227            );
228        }
229        (None, None) => {
230            debug!(
231                "No chat template provided or discovered for model: {}",
232                model_name
233            );
234        }
235    }
236
237    final_chat_template
238}
239
240/// Factory function to create tokenizer from a model name or path (async version)
241pub async fn create_tokenizer_async(
242    model_name_or_path: &str,
243) -> Result<Arc<dyn traits::Tokenizer>> {
244    create_tokenizer_async_with_chat_template(model_name_or_path, None).await
245}
246
247/// Factory function to create tokenizer with optional chat template (async version)
248pub async fn create_tokenizer_async_with_chat_template(
249    model_name_or_path: &str,
250    chat_template_path: Option<&str>,
251) -> Result<Arc<dyn traits::Tokenizer>> {
252    // Check if it's a file path
253    let path = Path::new(model_name_or_path);
254    if path.exists() {
255        return create_tokenizer_with_chat_template(model_name_or_path, chat_template_path);
256    }
257
258    // Check if it's a GPT model name that should use Tiktoken
259    // Only match specific OpenAI model patterns to avoid catching HuggingFace models like "openai/gpt-oss-20b"
260    if model_name_or_path.contains("gpt-4")
261        || model_name_or_path.contains("gpt-3.5")
262        || model_name_or_path.contains("gpt-3")
263        || model_name_or_path.contains("turbo")
264        || model_name_or_path.contains("davinci")
265        || model_name_or_path.contains("curie")
266        || model_name_or_path.contains("babbage")
267        || model_name_or_path.contains("ada")
268        || model_name_or_path.contains("codex")
269    {
270        // Try tiktoken first, but fall back to HuggingFace if it fails
271        match TiktokenTokenizer::from_model_name(model_name_or_path) {
272            Ok(tokenizer) => return Ok(Arc::new(tokenizer)),
273            Err(e) => {
274                debug!(
275                    "Tiktoken failed for '{}': {}, falling back to HuggingFace",
276                    model_name_or_path, e
277                );
278            }
279        }
280    }
281
282    // Try to download tokenizer files from HuggingFace
283    match download_tokenizer_from_hf(model_name_or_path).await {
284        Ok(cache_dir) => {
285            // Look for tokenizer.json in the cache directory
286            let tokenizer_path = cache_dir.join("tokenizer.json");
287            if tokenizer_path.exists() {
288                // Resolve chat template: provided path takes precedence over auto-discovery
289                let final_chat_template = resolve_and_log_chat_template(
290                    chat_template_path,
291                    &cache_dir,
292                    model_name_or_path,
293                );
294
295                let tokenizer_path_str = tokenizer_path.to_str().ok_or_else(|| {
296                    Error::msg(format!(
297                        "Tokenizer path is not valid UTF-8: {:?}",
298                        tokenizer_path
299                    ))
300                })?;
301                create_tokenizer_with_chat_template(
302                    tokenizer_path_str,
303                    final_chat_template.as_deref(),
304                )
305            } else {
306                // Try other common tokenizer file names
307                let possible_files = ["tokenizer_config.json", "vocab.json"];
308                for file_name in &possible_files {
309                    let file_path = cache_dir.join(file_name);
310                    if file_path.exists() {
311                        // Resolve chat template: provided path takes precedence over auto-discovery
312                        let final_chat_template = resolve_and_log_chat_template(
313                            chat_template_path,
314                            &cache_dir,
315                            model_name_or_path,
316                        );
317
318                        let file_path_str = file_path.to_str().ok_or_else(|| {
319                            Error::msg(format!("File path is not valid UTF-8: {:?}", file_path))
320                        })?;
321                        return create_tokenizer_with_chat_template(
322                            file_path_str,
323                            final_chat_template.as_deref(),
324                        );
325                    }
326                }
327                Err(Error::msg(format!(
328                    "Downloaded model '{}' but couldn't find a suitable tokenizer file",
329                    model_name_or_path
330                )))
331            }
332        }
333        Err(e) => Err(Error::msg(format!(
334            "Failed to download tokenizer from HuggingFace: {}",
335            e
336        ))),
337    }
338}
339
340/// Factory function to create tokenizer from a model name or path (blocking version)
341///
342/// This delegates to `create_tokenizer_with_chat_template_blocking` with no chat template,
343/// which handles both local files and HuggingFace Hub downloads uniformly.
344pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
345    create_tokenizer_with_chat_template_blocking(model_name_or_path, None)
346}
347
348/// Factory function to create tokenizer with optional chat template (blocking version)
349pub fn create_tokenizer_with_chat_template_blocking(
350    model_name_or_path: &str,
351    chat_template_path: Option<&str>,
352) -> Result<Arc<dyn traits::Tokenizer>> {
353    // Check if it's a file path
354    let path = Path::new(model_name_or_path);
355    if path.exists() {
356        return create_tokenizer_with_chat_template(model_name_or_path, chat_template_path);
357    }
358
359    // Check if it's a GPT model name that should use Tiktoken
360    if model_name_or_path.contains("gpt-")
361        || model_name_or_path.contains("davinci")
362        || model_name_or_path.contains("curie")
363        || model_name_or_path.contains("babbage")
364        || model_name_or_path.contains("ada")
365    {
366        let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
367        return Ok(Arc::new(tokenizer));
368    }
369
370    // Only use tokio for HuggingFace downloads
371    // Check if we're already in a tokio runtime
372    if let Ok(handle) = tokio::runtime::Handle::try_current() {
373        // We're in a runtime, use block_in_place
374        tokio::task::block_in_place(|| {
375            handle.block_on(create_tokenizer_async_with_chat_template(
376                model_name_or_path,
377                chat_template_path,
378            ))
379        })
380    } else {
381        // No runtime, create a temporary one
382        let rt = tokio::runtime::Runtime::new()?;
383        rt.block_on(create_tokenizer_async_with_chat_template(
384            model_name_or_path,
385            chat_template_path,
386        ))
387    }
388}
389
390/// Get information about a tokenizer file
391pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> {
392    let path = Path::new(file_path);
393
394    if !path.exists() {
395        return Err(Error::msg(format!("File not found: {}", file_path)));
396    }
397
398    let extension = path
399        .extension()
400        .and_then(std::ffi::OsStr::to_str)
401        .map(|s| s.to_lowercase());
402
403    match extension.as_deref() {
404        Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())),
405        _ => {
406            // Try auto-detection
407            use std::{fs::File, io::Read};
408
409            let mut file = File::open(file_path)?;
410            let mut buffer = vec![0u8; 512];
411            let bytes_read = file.read(&mut buffer)?;
412            buffer.truncate(bytes_read);
413
414            if is_likely_json(&buffer) {
415                Ok(TokenizerType::HuggingFace(file_path.to_string()))
416            } else {
417                Err(Error::msg("Unknown tokenizer type"))
418            }
419        }
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::{
        create_tokenizer, create_tokenizer_async, create_tokenizer_from_file, is_likely_json,
    };

    /// JSON sniffing accepts objects and arrays (with leading whitespace) and
    /// rejects plain text and empty input.
    #[test]
    fn test_json_detection() {
        let json_like: [&[u8]; 3] = [
            b"{\"test\": \"value\"}",
            b"  \n\t{\"test\": \"value\"}",
            b"[1, 2, 3]",
        ];
        for sample in json_like.iter() {
            assert!(is_likely_json(sample));
        }

        let not_json: [&[u8]; 2] = [b"not json", b""];
        for sample in not_json.iter() {
            assert!(!is_likely_json(sample));
        }
    }

    /// The special name "mock" yields the mock tokenizer.
    #[test]
    fn test_mock_tokenizer_creation() {
        let mock = create_tokenizer_from_file("mock").unwrap();
        // Mock tokenizer has 14 tokens
        assert_eq!(mock.vocab_size(), 14);
    }

    /// A path that does not exist is rejected with a "File not found" error.
    #[test]
    fn test_file_not_found() {
        match create_tokenizer_from_file("/nonexistent/file.json") {
            Ok(_) => panic!("expected an error for a nonexistent path"),
            Err(e) => assert!(e.to_string().contains("File not found")),
        }
    }

    /// "gpt-4" resolves to a tokenizer that round-trips plain text.
    #[test]
    fn test_create_tiktoken_tokenizer() {
        let tokenizer = create_tokenizer("gpt-4").unwrap();
        assert!(tokenizer.vocab_size() > 0);

        let original = "Hello, world!";
        let encoded = tokenizer.encode(original, false).unwrap();
        let round_tripped = tokenizer.decode(encoded.token_ids(), false).unwrap();
        assert_eq!(round_tripped, original);
    }

    /// Best-effort download test: it must not panic, but a failed download is
    /// tolerated because network issues shouldn't break CI.
    #[tokio::test]
    async fn test_download_tokenizer_from_hf() {
        // Skip this test if HF_TOKEN is not set and we're in CI
        let in_ci = std::env::var("CI").is_ok();
        let has_token = std::env::var("HF_TOKEN").is_ok();
        if in_ci && !has_token {
            println!("Skipping HF download test in CI without HF_TOKEN");
            return;
        }

        // Try to create tokenizer for a known small model
        match create_tokenizer_async("bert-base-uncased").await {
            Ok(tokenizer) => {
                assert!(tokenizer.vocab_size() > 0);
                println!("Successfully downloaded and created tokenizer");
            }
            Err(e) => {
                // The download might fail due to network issues or rate
                // limiting, so the failure is only reported.
                println!("Download failed (this might be expected): {}", e);
            }
        }
    }
}