llm-tokenizer 1.0.0

LLM tokenizer library with caching and chat template support
Documentation
use std::path::{Path, PathBuf};

use hf_hub::api::tokio::ApiBuilder;

/// Repository file names that are never downloaded.
///
/// Entries are compared against a sibling's full `rfilename` with an exact,
/// case-sensitive match, so only files at the repository root with precisely
/// these names are skipped.
const IGNORED: [&str; 5] = [
    ".gitattributes",
    "LICENSE",
    "LICENSE.txt",
    "README.md",
    "USE_POLICY.md",
];

/// Reports whether `filename` names a model weight file.
///
/// A name qualifies when it ends with one of the suffixes `.bin`,
/// `.safetensors`, `.h5`, `.msgpack` or `.ckpt.index`. The comparison is
/// case-sensitive.
fn is_weight_file(filename: &str) -> bool {
    const WEIGHT_SUFFIXES: [&str; 5] = [
        ".bin",
        ".safetensors",
        ".h5",
        ".msgpack",
        ".ckpt.index",
    ];

    WEIGHT_SUFFIXES
        .iter()
        .any(|suffix| filename.ends_with(*suffix))
}

/// Checks if a file is an image file.
///
/// A file counts as an image when the text after its last `.` is `png`,
/// `jpg` or `jpeg`, compared ASCII case-insensitively (so `.PNG`, `.Jpg` and
/// `.jpeg` all match). Names without a `.` are never images; in particular a
/// name that merely ends in the letters `PNG`/`JPG`/`JPEG` without a
/// preceding dot is not treated as an image.
fn is_image(filename: &str) -> bool {
    const IMAGE_EXTENSIONS: [&str; 3] = ["png", "jpg", "jpeg"];

    // Split on the last dot so that only the real extension is inspected.
    match filename.rsplit_once('.') {
        Some((_, extension)) => IMAGE_EXTENSIONS
            .iter()
            .any(|known| extension.eq_ignore_ascii_case(known)),
        None => false,
    }
}

/// Reports whether `filename` names a tokenizer-related file.
///
/// A name qualifies when it ends with one of the known tokenizer suffixes
/// (JSON tokenizer/config files, `vocab.json`, `merges.txt`, SentencePiece
/// `.model` files, `.tiktoken` files) or when it is a chat template file as
/// decided by `is_chat_template_file`.
fn is_tokenizer_file(filename: &str) -> bool {
    const TOKENIZER_SUFFIXES: [&str; 7] = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
        ".model", // SentencePiece models
        ".tiktoken",
    ];

    let has_tokenizer_suffix = TOKENIZER_SUFFIXES
        .iter()
        .any(|suffix| filename.ends_with(*suffix));

    // Chat templates ship alongside the tokenizer, so they count as well.
    has_tokenizer_suffix || is_chat_template_file(filename)
}

/// Reports whether `filename` names a chat template file.
///
/// Two forms are recognised: any name ending in `.jinja` (a direct Jinja
/// template), and the exact name `chat_template.json` (a JSON file containing
/// a Jinja template).
fn is_chat_template_file(filename: &str) -> bool {
    // The JSON form must match the whole name, not just a suffix.
    if filename == "chat_template.json" {
        return true;
    }

    filename.ends_with(".jinja")
}

/// Attempt to download tokenizer files from Hugging Face.
///
/// Lists the files of the model repository `model_id`, keeps those that are
/// tokenizer files (see `is_tokenizer_file`) and are not ignored, images or
/// weight files, and downloads each of them through the Hugging Face cache.
///
/// Returns the directory derived from the first downloaded file's parent
/// directory, passed through `resolve_model_cache_dir`.
///
/// # Errors
///
/// Returns an error when:
/// - the API client cannot be built,
/// - the repository information cannot be fetched,
/// - the repository lists no files, or none of them is a tokenizer file,
/// - any tokenizer file fails to download,
/// - no downloaded file path has a parent directory.
pub async fn download_tokenizer_from_hf(model_id: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
    let model_name = model_id.as_ref().display().to_string();
    let api = ApiBuilder::from_env().with_progress(true).build()?;
    let repo = api.model(model_name.clone());

    let info = repo.info().await.map_err(|e| {
        anyhow::anyhow!(
            "Failed to fetch model '{}' from HuggingFace: {}. Is this a valid HuggingFace ID?",
            model_name,
            e
        )
    })?;

    if info.siblings.is_empty() {
        return Err(anyhow::anyhow!(
            "Model '{}' exists but contains no downloadable files.",
            model_name
        ));
    }

    // First, identify all tokenizer files to download
    let tokenizer_files: Vec<_> = info
        .siblings
        .iter()
        .filter(|sib| {
            let file = sib.rfilename.as_str();
            !IGNORED.contains(&file)
                && !is_image(file)
                && !is_weight_file(file)
                && is_tokenizer_file(file)
        })
        .collect();

    if tokenizer_files.is_empty() {
        return Err(anyhow::anyhow!(
            "No tokenizer files found for model '{}'.",
            model_name
        ));
    }

    // Download every tokenizer file; the first failure aborts the whole call.
    // Because the list is non-empty and failures return early, reaching the end
    // of the loop means at least one file was downloaded.
    let mut cache_dir: Option<PathBuf> = None;
    for sib in tokenizer_files {
        let path = repo.get(&sib.rfilename).await.map_err(|e| {
            anyhow::anyhow!(
                "Failed to download tokenizer file '{}' from model '{}': {}",
                sib.rfilename,
                model_name,
                e
            )
        })?;

        // Remember the directory of the first file whose path has a parent.
        if cache_dir.is_none() {
            cache_dir = path.parent().map(|p| p.to_path_buf());
        }
    }

    let dir = cache_dir
        .ok_or_else(|| anyhow::anyhow!("Invalid HF cache path for model '{}'", model_name))?;

    // Ensure we return the correct model directory, not a subfolder
    // (some models keep files in an "original" subfolder, for example).
    Ok(resolve_model_cache_dir(&dir, &model_name))
}

/// Download the files of a Hugging Face model repository (weights included
/// unless `ignore_weights` is set) and return the directory they live in.
///
/// Files listed in `IGNORED` and image files are always skipped. When
/// `ignore_weights` is `true`, files recognised by `is_weight_file` are
/// skipped as well. The returned directory is the parent of the last
/// downloaded file, passed through `resolve_model_cache_dir`.
///
/// # Errors
///
/// Fails when the API client cannot be built, the repository cannot be
/// fetched or lists no files, any download fails, every file was skipped, or
/// the last downloaded path has no parent directory.
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
    let model_name = name.as_ref().display().to_string();
    let api = ApiBuilder::from_env().with_progress(true).build()?;
    let repo = api.model(model_name.clone());

    let info = repo.info().await.map_err(|e| {
        anyhow::anyhow!(
            "Failed to fetch model '{}' from HuggingFace: {}. Is this a valid HuggingFace ID?",
            model_name,
            e
        )
    })?;

    if info.siblings.is_empty() {
        return Err(anyhow::anyhow!(
            "Model '{}' exists but contains no downloadable files.",
            model_name
        ));
    }

    // Path of the most recently downloaded file; `None` until one succeeds.
    let mut last_downloaded: Option<PathBuf> = None;

    for sib in info.siblings {
        let file = sib.rfilename.as_str();

        let skipped = IGNORED.contains(&file)
            || is_image(file)
            || (ignore_weights && is_weight_file(file));
        if skipped {
            continue;
        }

        let downloaded = repo.get(file).await.map_err(|e| {
            anyhow::anyhow!(
                "Failed to download file '{}' from model '{}': {}",
                sib.rfilename,
                model_name,
                e
            )
        })?;
        last_downloaded = Some(downloaded);
    }

    let last_downloaded = match last_downloaded {
        Some(path) => path,
        None => {
            let file_type = if ignore_weights { "non-weight" } else { "valid" };
            return Err(anyhow::anyhow!(
                "No {} files found for model '{}'.",
                file_type,
                model_name
            ));
        }
    };

    match last_downloaded.parent() {
        Some(dir) => Ok(resolve_model_cache_dir(dir, &model_name)),
        None => Err(anyhow::anyhow!(
            "Invalid HF cache path: {}",
            last_downloaded.display()
        )),
    }
}

/// Resolve the correct model cache directory.
///
/// Handles cases where `path` points into a subfolder of the model's cache
/// directory (e.g. an "original" folder) instead of the model directory
/// itself. The rules are applied in order:
///
/// 1. If the last component of `path` is `original`, its parent is returned.
/// 2. If `path` itself contains `tokenizer.json` or `tokenizer_config.json`,
///    `path` is returned unchanged.
/// 3. If `path` lies at or below `models--<id>/snapshots/<revision>`, where
///    `<id>` is `model_name` with every `/` replaced by `--`, that
///    `<revision>` directory is returned.
/// 4. Otherwise `path` is returned unchanged.
///
/// Rule 2 touches the file system; the other rules only inspect path
/// components.
fn resolve_model_cache_dir(path: &Path, model_name: &str) -> PathBuf {
    // Check if we're in a subfolder like "original"
    if let (Some(parent), Some(folder_name)) = (path.parent(), path.file_name()) {
        if folder_name == "original" {
            // We're in the "original" subfolder, go up one level
            return parent.to_path_buf();
        }
    }

    // A directory that already holds tokenizer files is usable as-is, even if
    // it is a subfolder of the snapshot.
    if path.join("tokenizer.json").exists() || path.join("tokenizer_config.json").exists() {
        return path.to_path_buf();
    }

    // The cache stores a model under "models--<id>", where only the '/' of the
    // repository id is replaced by "--"; hyphens inside the id are kept as-is.
    let model_folder = format!("models--{}", model_name.replace('/', "--"));

    // Walk up from `path` looking for "<model_folder>/snapshots/<revision>".
    // Components are compared exactly, so a path that merely contains the
    // folder name as a substring does not match.
    for candidate in path.ancestors() {
        let snapshots_dir = match candidate.parent() {
            Some(dir) => dir,
            None => break,
        };

        let under_snapshots =
            matches!(snapshots_dir.file_name(), Some(name) if name == "snapshots");
        let under_model = matches!(
            snapshots_dir.parent().and_then(|dir| dir.file_name()),
            Some(name) if name == model_folder.as_str()
        );

        if under_snapshots && under_model {
            return candidate.to_path_buf();
        }
    }

    // Not inside a recognisable cache layout for this model: leave it alone.
    path.to_path_buf()
}

#[cfg(test)]
mod tests {
    use super::{is_chat_template_file, is_tokenizer_file, is_weight_file};

    /// Asserts that `classify` accepts every name in `accepted` and rejects
    /// every name in `rejected`, naming the offending file on failure.
    fn assert_classification(classify: fn(&str) -> bool, accepted: &[&str], rejected: &[&str]) {
        for &name in accepted {
            assert!(classify(name), "expected {:?} to be accepted", name);
        }
        for &name in rejected {
            assert!(!classify(name), "expected {:?} to be rejected", name);
        }
    }

    #[test]
    fn test_is_tokenizer_file() {
        assert_classification(
            is_tokenizer_file,
            &[
                "tokenizer.json",
                "tokenizer_config.json",
                "special_tokens_map.json",
                "vocab.json",
                "merges.txt",
                "spiece.model",
                "chat_template.jinja",
                "template.jinja",
            ],
            &["model.bin", "README.md"],
        );
    }

    #[test]
    fn test_is_chat_template_file() {
        assert_classification(
            is_chat_template_file,
            &[
                "chat_template.jinja",
                "template.jinja",
                "any_file.jinja",
                "chat_template.json",
            ],
            &[
                "tokenizer.json",
                "other_file.json",
                "chat_template",
                "README.md",
            ],
        );
    }

    #[test]
    fn test_is_weight_file() {
        assert_classification(
            is_weight_file,
            &["model.bin", "model.safetensors", "pytorch_model.bin"],
            &["tokenizer.json", "config.json"],
        );
    }
}