next_plaid_cli/
model.rs

use anyhow::{Context, Result};
use hf_hub::api::sync::Api;
use std::path::PathBuf;
4
/// HuggingFace Hub repository id used by `ensure_model` when no model is given.
pub const DEFAULT_MODEL: &str = "lightonai/GTE-ModernColBERT-v1-onnx";
6
/// Files fetched from the model repository by `ensure_model`.
///
/// `model_int8.onnx`, `tokenizer.json` and `config_sentence_transformers.json`
/// are mandatory: failing to fetch any of them aborts the download.
/// `config.json` is fetched best-effort only, because not every model
/// repository ships one.
const REQUIRED_FILES: &[&str] = &[
    "model_int8.onnx",
    "tokenizer.json",
    "config_sentence_transformers.json",
    "config.json",
];
14
15/// Load model from cache or download from HuggingFace.
16/// Returns path to the model directory.
17pub fn ensure_model(model_id: Option<&str>) -> Result<PathBuf> {
18    let model_id = model_id.unwrap_or(DEFAULT_MODEL);
19
20    // Check if it's a local path
21    let local_path = PathBuf::from(model_id);
22    if local_path.exists() && local_path.is_dir() {
23        return Ok(local_path);
24    }
25
26    // Download from HuggingFace
27    eprintln!("🤖 Model: {}", model_id);
28    let api = Api::new()?;
29    let repo = api.model(model_id.to_string());
30
31    // Download all required files (cached if already present)
32    let mut model_dir = None;
33    for file in REQUIRED_FILES {
34        match repo.get(file) {
35            Ok(path) => {
36                if model_dir.is_none() {
37                    model_dir = path.parent().map(|p| p.to_path_buf());
38                }
39            }
40            Err(e) => {
41                // config.json may not exist in all models, that's ok
42                if *file != "config.json" {
43                    return Err(e.into());
44                }
45            }
46        }
47    }
48
49    model_dir.ok_or_else(|| anyhow::anyhow!("Failed to determine model directory"))
50}