Skip to main content

llm_tokenizer/
factory.rs

1use std::{fs::File, io::Read, path::Path, sync::Arc};
2
3use anyhow::{Error, Result};
4use tracing::debug;
5
6use crate::{
7    hub::download_tokenizer_from_hf,
8    huggingface::HuggingFaceTokenizer,
9    tiktoken::{has_tiktoken_file, is_tiktoken_file, TiktokenTokenizer},
10    traits,
11};
12
/// Represents the type of tokenizer detected for a file.
///
/// Returned by `get_tokenizer_info`. `HuggingFace` carries the path of the
/// file it was detected from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenizerType {
    /// A HuggingFace `tokenizers` JSON file (payload: file path).
    HuggingFace(String),
    /// The in-crate mock tokenizer used for testing.
    Mock,
    /// A tiktoken BPE file such as `tiktoken.model` / `*.tiktoken`
    /// (payload: presumably the file path — confirm at construction sites).
    Tiktoken(String),
    // Future: SentencePiece, GGUF
}
21
22/// Create a tokenizer from a file path to a tokenizer file.
23/// The file extension is used to determine the tokenizer type.
24/// Supported file types are:
25/// - json: HuggingFace tokenizer
26/// - For testing: can return mock tokenizer
27pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
28    create_tokenizer_with_chat_template(file_path, None)
29}
30
31/// Create a tokenizer from a file path with an optional chat template
32pub fn create_tokenizer_with_chat_template(
33    file_path: &str,
34    chat_template_path: Option<&str>,
35) -> Result<Arc<dyn traits::Tokenizer>> {
36    // Special case for testing
37    if file_path == "mock" || file_path == "test" {
38        return Ok(Arc::new(super::mock::MockTokenizer::new()));
39    }
40
41    let path = Path::new(file_path);
42
43    // Check if file exists
44    if !path.exists() {
45        return Err(Error::msg(format!("File not found: {file_path}")));
46    }
47
48    // If path is a directory, search for tokenizer files
49    if path.is_dir() {
50        let tokenizer_json = path.join("tokenizer.json");
51        if tokenizer_json.exists() {
52            // Resolve chat template: provided path takes precedence over auto-discovery
53            let final_chat_template =
54                resolve_and_log_chat_template(chat_template_path, path, file_path);
55            let tokenizer_path_str = tokenizer_json.to_str().ok_or_else(|| {
56                Error::msg(format!(
57                    "Tokenizer path is not valid UTF-8: {tokenizer_json:?}"
58                ))
59            })?;
60            return create_tokenizer_with_chat_template(
61                tokenizer_path_str,
62                final_chat_template.as_deref(),
63            );
64        }
65
66        // Priority 2: tiktoken.model / *.tiktoken
67        // Only forward the user's explicit chat_template_path — tiktoken handles
68        // its own config/discovery (tokenizer_config.json → directory discovery).
69        if has_tiktoken_file(path) {
70            return Ok(Arc::new(TiktokenTokenizer::from_dir_with_chat_template(
71                path,
72                chat_template_path,
73            )?));
74        }
75
76        return Err(Error::msg(format!(
77            "Directory '{file_path}' does not contain a valid tokenizer file (tokenizer.json, tiktoken.model, *.tiktoken, or vocab.json)"
78        )));
79    }
80
81    // Try to determine tokenizer type from extension
82    let extension = path
83        .extension()
84        .and_then(std::ffi::OsStr::to_str)
85        .map(|s| s.to_lowercase());
86
87    let result = match extension.as_deref() {
88        Some("json") => {
89            let tokenizer =
90                HuggingFaceTokenizer::from_file_with_chat_template(file_path, chat_template_path)?;
91
92            Ok(Arc::new(tokenizer) as Arc<dyn traits::Tokenizer>)
93        }
94        Some("model") | Some("tiktoken") => {
95            // Check if it's a tiktoken file (tiktoken.model / *.tiktoken) before assuming SentencePiece
96            if is_tiktoken_file(path) {
97                Ok(Arc::new(TiktokenTokenizer::from_file_with_chat_template(
98                    path,
99                    chat_template_path,
100                )?) as Arc<dyn traits::Tokenizer>)
101            } else {
102                Err(Error::msg("SentencePiece models not yet supported"))
103            }
104        }
105        Some("gguf") => {
106            // GGUF format
107            Err(Error::msg("GGUF format not yet supported"))
108        }
109        _ => {
110            // Try to auto-detect by reading file content
111            auto_detect_tokenizer(file_path)
112        }
113    };
114
115    result
116}
117
118/// Auto-detect tokenizer type by examining file content
119fn auto_detect_tokenizer(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
120    let mut file = File::open(file_path)?;
121    let mut buffer = vec![0u8; 512]; // Read first 512 bytes for detection
122    let bytes_read = file.read(&mut buffer)?;
123    buffer.truncate(bytes_read);
124
125    // Check for JSON (HuggingFace format)
126    if is_likely_json(&buffer) {
127        let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
128        return Ok(Arc::new(tokenizer));
129    }
130
131    // Check for GGUF magic number
132    if buffer.len() >= 4 && &buffer[0..4] == b"GGUF" {
133        return Err(Error::msg("GGUF format detected but not yet supported"));
134    }
135
136    // Check for SentencePiece model
137    if is_likely_sentencepiece(&buffer) {
138        return Err(Error::msg(
139            "SentencePiece model detected but not yet supported",
140        ));
141    }
142
143    Err(Error::msg(format!(
144        "Unable to determine tokenizer type for file: {file_path}"
145    )))
146}
147
/// Heuristically decide whether `buffer` holds JSON text.
///
/// A leading UTF-8 byte-order mark is ignored. The answer is `true` when the
/// first non-ASCII-whitespace byte opens an object (`{`) or an array (`[`).
/// Empty or all-whitespace input yields `false`.
fn is_likely_json(buffer: &[u8]) -> bool {
    const UTF8_BOM: &[u8] = &[0xEF, 0xBB, 0xBF];

    let body = buffer.strip_prefix(UTF8_BOM).unwrap_or(buffer);
    let first_significant = body
        .iter()
        .copied()
        .find(|byte| !byte.is_ascii_whitespace());
    matches!(first_significant, Some(b'{' | b'['))
}
164
/// Check if the buffer likely contains a SentencePiece model.
///
/// This is a cheap heuristic, not a protobuf parse. Buffers shorter than 12
/// bytes are rejected. Otherwise the buffer matches if it starts with one of two
/// known header byte sequences, or if it contains a special-token marker
/// (`<unk`, `<s>`, `</s>`) anywhere.
fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
    const MIN_LEN: usize = 12;
    const HEADER_PREFIXES: [&[u8]; 2] = [b"\x0a\x09", b"\x08\x00"];
    const TOKEN_MARKERS: [&[u8]; 3] = [b"<unk", b"<s>", b"</s>"];

    if buffer.len() < MIN_LEN {
        return false;
    }

    // Check header patterns first (cheap)
    if HEADER_PREFIXES
        .iter()
        .any(|prefix| buffer.starts_with(prefix))
    {
        return true;
    }

    // Single pass over every start offset. Fixed 4-byte windows would miss the
    // 3-byte `<s>` marker when it ends exactly at the end of the buffer.
    (0..buffer.len()).any(|start| {
        let rest = &buffer[start..];
        TOKEN_MARKERS.iter().any(|marker| rest.starts_with(marker))
    })
}
190
/// Discover a chat template file inside `dir`.
///
/// Lookup order:
/// 1. `chat_template.json` (Jinja template embedded in JSON)
/// 2. `chat_template.jinja`
/// 3. any other regular `*.jinja` file. If several exist, the smallest path
///    (by file name) wins, so the result does not depend on the unspecified
///    order of directory listings.
///
/// Returns the template path as a `String`, or `None` if nothing is found, the
/// directory cannot be read, or the selected path is not valid UTF-8.
pub fn discover_chat_template_in_dir(dir: &Path) -> Option<String> {
    use std::fs;

    // Priority 1: Look for chat_template.json (contains Jinja in JSON format)
    let json_template_path = dir.join("chat_template.json");
    if json_template_path.exists() {
        return json_template_path.to_str().map(str::to_string);
    }

    // Priority 2: Look for chat_template.jinja (standard Jinja file)
    let jinja_path = dir.join("chat_template.jinja");
    if jinja_path.exists() {
        return jinja_path.to_str().map(str::to_string);
    }

    // Priority 3: any other .jinja file (for models with non-standard naming).
    // `read_dir` order is platform/filesystem dependent, so choose deterministically.
    let entries = fs::read_dir(dir).ok()?;
    entries
        .flatten()
        .filter(|entry| {
            entry.file_name().to_str().is_some_and(|name| {
                name.ends_with(".jinja") && name != "chat_template.jinja"
            })
        })
        .map(|entry| entry.path())
        // Ignore directories that happen to be named `*.jinja`.
        .filter(|candidate| candidate.is_file())
        .min()
        .and_then(|chosen| chosen.to_str().map(str::to_string))
}
220
221/// Helper function to resolve and log chat template selection
222///
223/// Resolves the final chat template to use by prioritizing provided path over auto-discovery,
224/// and logs the source for debugging purposes.
225fn resolve_and_log_chat_template(
226    provided_path: Option<&str>,
227    discovery_dir: &Path,
228    model_name: &str,
229) -> Option<String> {
230    let final_chat_template = provided_path
231        .map(|s| s.to_string())
232        .or_else(|| discover_chat_template_in_dir(discovery_dir));
233
234    match (&provided_path, &final_chat_template) {
235        (Some(provided), _) => {
236            debug!("Using provided chat template: {}", provided);
237        }
238        (None, Some(discovered)) => {
239            debug!(
240                "Auto-discovered chat template in '{}': {}",
241                discovery_dir.display(),
242                discovered
243            );
244        }
245        (None, None) => {
246            debug!(
247                "No chat template provided or discovered for model: {}",
248                model_name
249            );
250        }
251    }
252
253    final_chat_template
254}
255
256/// Factory function to create tokenizer from a model name or path (async version)
257pub async fn create_tokenizer_async(
258    model_name_or_path: &str,
259) -> Result<Arc<dyn traits::Tokenizer>> {
260    create_tokenizer_async_with_chat_template(model_name_or_path, None).await
261}
262
/// Heuristically decide whether `name` refers to an OpenAI model whose
/// tokenizer should come from tiktoken.
///
/// Any `org/` prefix is ignored; only the text after the last `/` is examined.
/// The patterns are deliberately narrow. A false negative only costs a fallback
/// to the HuggingFace Hub, which handles non-OpenAI models correctly. A false
/// positive wastes time on a tiktoken attempt for an unsupported model.
///
/// Recognised names:
/// - `gpt-` followed by a digit: gpt-4, gpt-4o, gpt-4o-mini, gpt-4-turbo,
///   gpt-4-32k, gpt-4.5-preview, gpt-3.5-turbo, gpt-3.5-turbo-16k,
///   gpt-3.5-turbo-instruct (but not e.g. `gpt-oss-20b`)
/// - `chatgpt-*`: chatgpt-4o-latest
/// - `o` + one digit, optionally followed by `-suffix`: o1, o1-mini,
///   o1-preview, o3, o3-mini, o3-pro, o4-mini
/// - legacy completion / embedding models: exactly `davinci`, `curie`,
///   `babbage` or `ada`, or names starting with `text-davinci`, `code-davinci`,
///   `text-curie`, `text-babbage`, `text-ada`, `text-embedding-ada`,
///   `code-cushman`
fn is_likely_openai_model(name: &str) -> bool {
    // Bare legacy names are matched exactly so e.g. "adapter-v2" is not taken for "ada".
    const LEGACY_EXACT: [&str; 4] = ["davinci", "curie", "babbage", "ada"];
    const LEGACY_PREFIXES: [&str; 7] = [
        "text-davinci",
        "code-davinci",
        "text-curie",
        "text-babbage",
        "text-ada",
        "text-embedding-ada",
        "code-cushman",
    ];

    let bare = match name.rfind('/') {
        Some(slash) => &name[slash + 1..],
        None => name,
    };

    // GPT family: "gpt-" must be followed by a digit.
    let gpt_versioned = bare
        .strip_prefix("gpt-")
        .and_then(|rest| rest.bytes().next())
        .is_some_and(|first| first.is_ascii_digit());
    if gpt_versioned || bare.starts_with("chatgpt-") {
        return true;
    }

    // Reasoning family: 'o', a single digit, then end-of-name or '-'.
    let reasoning = match bare.as_bytes() {
        [b'o', digit, tail @ ..] => {
            digit.is_ascii_digit() && matches!(tail.first(), None | Some(b'-'))
        }
        _ => false,
    };
    if reasoning {
        return true;
    }

    LEGACY_EXACT.contains(&bare)
        || LEGACY_PREFIXES
            .iter()
            .any(|prefix| bare.starts_with(prefix))
}
312
313/// Factory function to create tokenizer with optional chat template (async version)
314pub async fn create_tokenizer_async_with_chat_template(
315    model_name_or_path: &str,
316    chat_template_path: Option<&str>,
317) -> Result<Arc<dyn traits::Tokenizer>> {
318    // Check if it's a file path
319    let path = Path::new(model_name_or_path);
320    if path.exists() {
321        return create_tokenizer_with_chat_template(model_name_or_path, chat_template_path);
322    }
323
324    // Check if it's a GPT model name that should use Tiktoken
325    if is_likely_openai_model(model_name_or_path) {
326        // Try tiktoken first, but fall back to HuggingFace if it fails
327        match TiktokenTokenizer::from_model_name(model_name_or_path) {
328            Ok(tokenizer) => return Ok(Arc::new(tokenizer)),
329            Err(e) => {
330                debug!(
331                    "Tiktoken failed for '{}': {}, falling back to HuggingFace",
332                    model_name_or_path, e
333                );
334            }
335        }
336    }
337
338    // Try to download tokenizer files from HuggingFace
339    match download_tokenizer_from_hf(model_name_or_path).await {
340        Ok(cache_dir) => {
341            // Look for tokenizer.json in the cache directory
342            let tokenizer_path = cache_dir.join("tokenizer.json");
343            if tokenizer_path.exists() {
344                // Resolve chat template: provided path takes precedence over auto-discovery
345                let final_chat_template = resolve_and_log_chat_template(
346                    chat_template_path,
347                    &cache_dir,
348                    model_name_or_path,
349                );
350
351                let tokenizer_path_str = tokenizer_path.to_str().ok_or_else(|| {
352                    Error::msg(format!(
353                        "Tokenizer path is not valid UTF-8: {tokenizer_path:?}"
354                    ))
355                })?;
356                create_tokenizer_with_chat_template(
357                    tokenizer_path_str,
358                    final_chat_template.as_deref(),
359                )
360            } else if has_tiktoken_file(&cache_dir) {
361                Ok(Arc::new(TiktokenTokenizer::from_dir_with_chat_template(
362                    &cache_dir,
363                    chat_template_path,
364                )?))
365            } else {
366                // Try other common tokenizer file names
367                let possible_files = ["tokenizer_config.json", "vocab.json"];
368                for file_name in &possible_files {
369                    let file_path = cache_dir.join(file_name);
370                    if file_path.exists() {
371                        // Resolve chat template: provided path takes precedence over auto-discovery
372                        let final_chat_template = resolve_and_log_chat_template(
373                            chat_template_path,
374                            &cache_dir,
375                            model_name_or_path,
376                        );
377
378                        let file_path_str = file_path.to_str().ok_or_else(|| {
379                            Error::msg(format!("File path is not valid UTF-8: {file_path:?}"))
380                        })?;
381                        return create_tokenizer_with_chat_template(
382                            file_path_str,
383                            final_chat_template.as_deref(),
384                        );
385                    }
386                }
387                Err(Error::msg(format!(
388                    "Downloaded model '{model_name_or_path}' but couldn't find a suitable tokenizer file"
389                )))
390            }
391        }
392        Err(e) => Err(Error::msg(format!(
393            "Failed to download tokenizer from HuggingFace: {e}"
394        ))),
395    }
396}
397
398/// Factory function to create tokenizer from a model name or path (blocking version)
399///
400/// This delegates to `create_tokenizer_with_chat_template_blocking` with no chat template,
401/// which handles both local files and HuggingFace Hub downloads uniformly.
402pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
403    create_tokenizer_with_chat_template_blocking(model_name_or_path, None)
404}
405
406/// Factory function to create tokenizer with optional chat template (blocking version)
407pub fn create_tokenizer_with_chat_template_blocking(
408    model_name_or_path: &str,
409    chat_template_path: Option<&str>,
410) -> Result<Arc<dyn traits::Tokenizer>> {
411    // Check if it's a file path
412    let path = Path::new(model_name_or_path);
413    if path.exists() {
414        return create_tokenizer_with_chat_template(model_name_or_path, chat_template_path);
415    }
416
417    // Check if it's a GPT model name that should use Tiktoken
418    // Try tiktoken first, but fall back to HuggingFace if it fails
419    if is_likely_openai_model(model_name_or_path) {
420        match TiktokenTokenizer::from_model_name(model_name_or_path) {
421            Ok(tokenizer) => return Ok(Arc::new(tokenizer)),
422            Err(e) => {
423                debug!(
424                    "Tiktoken failed for '{}': {}, falling back to HuggingFace",
425                    model_name_or_path, e
426                );
427            }
428        }
429    }
430
431    // Fall back to HuggingFace Hub download (requires tokio runtime)
432    if let Ok(handle) = tokio::runtime::Handle::try_current() {
433        tokio::task::block_in_place(|| {
434            handle.block_on(create_tokenizer_async_with_chat_template(
435                model_name_or_path,
436                chat_template_path,
437            ))
438        })
439    } else {
440        let rt = tokio::runtime::Runtime::new()?;
441        rt.block_on(create_tokenizer_async_with_chat_template(
442            model_name_or_path,
443            chat_template_path,
444        ))
445    }
446}
447
448/// Get information about a tokenizer file
449pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> {
450    let path = Path::new(file_path);
451
452    if !path.exists() {
453        return Err(Error::msg(format!("File not found: {file_path}")));
454    }
455
456    let extension = path
457        .extension()
458        .and_then(std::ffi::OsStr::to_str)
459        .map(|s| s.to_lowercase());
460
461    match extension.as_deref() {
462        Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())),
463        _ => {
464            // Try auto-detection
465            use std::{fs::File, io::Read};
466
467            let mut file = File::open(file_path)?;
468            let mut buffer = vec![0u8; 512];
469            let bytes_read = file.read(&mut buffer)?;
470            buffer.truncate(bytes_read);
471
472            if is_likely_json(&buffer) {
473                Ok(TokenizerType::HuggingFace(file_path.to_string()))
474            } else {
475                Err(Error::msg("Unknown tokenizer type"))
476            }
477        }
478    }
479}
480
#[cfg(test)]
#[expect(
    clippy::print_stdout,
    reason = "diagnostic output in tests for CI skip messages and download results"
)]
mod tests {
    use super::{
        create_tokenizer, create_tokenizer_async, create_tokenizer_from_file, is_likely_json,
        is_likely_openai_model,
    };

    #[test]
    fn test_json_detection() {
        let json_inputs: [&[u8]; 3] = [
            b"{\"test\": \"value\"}",
            b"  \n\t{\"test\": \"value\"}",
            b"[1, 2, 3]",
        ];
        for input in json_inputs {
            assert!(is_likely_json(input));
        }

        let non_json_inputs: [&[u8]; 2] = [b"not json", b""];
        for input in non_json_inputs {
            assert!(!is_likely_json(input));
        }
    }

    #[test]
    fn test_mock_tokenizer_creation() {
        // The mock tokenizer ships a fixed 14-token vocabulary.
        let vocab_size = create_tokenizer_from_file("mock").unwrap().vocab_size();
        assert_eq!(vocab_size, 14);
    }

    #[test]
    fn test_file_not_found() {
        let Err(err) = create_tokenizer_from_file("/nonexistent/file.json") else {
            panic!("expected an error for a missing file");
        };
        assert!(err.to_string().contains("File not found"));
    }

    #[test]
    fn test_create_tiktoken_tokenizer() {
        let tokenizer = create_tokenizer("gpt-4").unwrap();
        assert!(tokenizer.vocab_size() > 0);

        // Encoding then decoding must round-trip the original text.
        let text = "Hello, world!";
        let encoded = tokenizer.encode(text, false).unwrap();
        let round_trip = tokenizer.decode(encoded.token_ids(), false).unwrap();
        assert_eq!(round_trip, text);
    }

    #[tokio::test]
    async fn test_download_tokenizer_from_hf() {
        // Skip this test if HF_TOKEN is not set and we're in CI
        let in_ci = std::env::var("CI").is_ok();
        let has_hf_token = std::env::var("HF_TOKEN").is_ok();
        if in_ci && !has_hf_token {
            println!("Skipping HF download test in CI without HF_TOKEN");
            return;
        }

        // Network issues or rate limiting may make the download fail; only a
        // successful download is checked, so the test never fails for those.
        match create_tokenizer_async("bert-base-uncased").await {
            Ok(tokenizer) => {
                assert!(tokenizer.vocab_size() > 0);
                println!("Successfully downloaded and created tokenizer");
            }
            Err(e) => {
                println!("Download failed (this might be expected): {e}");
            }
        }
    }

    #[test]
    fn test_is_likely_openai_model_positives() {
        let positives = [
            // GPT-4 family
            "gpt-4",
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4-turbo",
            "gpt-4-32k",
            "gpt-4.5-preview",
            // GPT-3.5 family
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-16k",
            "gpt-3.5-turbo-instruct",
            // ChatGPT
            "chatgpt-4o-latest",
            // Reasoning models
            "o1",
            "o1-mini",
            "o1-preview",
            "o3",
            "o3-mini",
            "o3-pro",
            "o4-mini",
            // Legacy models
            "davinci",
            "text-davinci-003",
            "code-davinci-002",
            "curie",
            "text-curie-001",
            "babbage",
            "text-babbage-001",
            "ada",
            "text-ada-001",
            "text-embedding-ada-002",
            "code-cushman-001",
            // With org prefix
            "openai/gpt-4",
            "openai/o1-mini",
            "openai/davinci",
        ];
        for name in positives {
            assert!(
                is_likely_openai_model(name),
                "expected '{name}' to be recognised as an OpenAI model"
            );
        }
    }

    #[test]
    fn test_is_likely_openai_model_negatives() {
        let negatives = [
            // HuggingFace models that should NOT match
            "openai/gpt-oss-20b",
            "meta-llama/Llama-3-8B",
            "mistralai/Mistral-7B",
            "bert-base-uncased",
            // Names that previously caused false positives
            "turbo-llama",
            "adapter-v2",
            "oracle-7b",
            "open-llama",
        ];
        for name in negatives {
            assert!(
                !is_likely_openai_model(name),
                "expected '{name}' not to be recognised as an OpenAI model"
            );
        }
    }
}