Skip to main content

hanzo_engine/speculative/
config.rs

1use std::path::{Path, PathBuf};
2
3use hf_hub::{
4    api::sync::{ApiBuilder, ApiRepo},
5    Cache, Repo, RepoType,
6};
7
8use crate::{
9    pipeline::{
10        hf::{get_file, hf_hub_cache_dir, list_repo_files, try_get_file},
11        TokenSource,
12    },
13    utils::tokens::get_token,
14    GLOBAL_HF_CACHE,
15};
16
17#[derive(Clone, Debug)]
18pub enum SpeculativeConfig {
19    Off,
20    Mtp(MtpConfig),
21}
22
/// Settings for MTP-based speculative decoding.
#[derive(Debug, Clone)]
pub struct MtpConfig {
    /// Local filesystem path or Hugging Face model id of the MTP model
    /// (resolved by `MtpConfig::resolve_path`).
    pub model: String,
    /// Optional prediction count. How `None` is interpreted is not visible
    /// in this file — TODO confirm against the consumer.
    pub n_predict: Option<usize>,
}
28
29impl MtpConfig {
30    pub fn new(model: impl Into<String>, n_predict: Option<usize>) -> Self {
31        Self {
32            model: model.into(),
33            n_predict,
34        }
35    }
36
37    pub fn resolve_path(&self) -> hanzo_ml::Result<PathBuf> {
38        let path = PathBuf::from(&self.model);
39        if path.exists() || self.model.starts_with('.') || self.model.starts_with('/') {
40            Ok(path)
41        } else {
42            resolve_hf_mtp_path(&self.model)
43        }
44    }
45}
46
47fn build_hf_api(id: &str, revision: &str) -> hanzo_ml::Result<ApiRepo> {
48    let cache = GLOBAL_HF_CACHE
49        .get()
50        .cloned()
51        .unwrap_or_else(|| hf_hub_cache_dir().map(Cache::new).unwrap_or_default());
52    let mut api = ApiBuilder::from_cache(cache)
53        .with_progress(true)
54        .with_token(get_token(&TokenSource::CacheToken).map_err(hanzo_ml::Error::msg)?);
55    if let Some(cache_dir) = hf_hub_cache_dir() {
56        api = api.with_cache_dir(cache_dir);
57    }
58    Ok(api
59        .build()
60        .map_err(hanzo_ml::Error::msg)?
61        .repo(Repo::with_revision(
62            id.to_string(),
63            RepoType::Model,
64            revision.to_string(),
65        )))
66}
67
68fn resolve_hf_mtp_path(id: &str) -> hanzo_ml::Result<PathBuf> {
69    let revision = "main";
70    let api = build_hf_api(id, revision)?;
71    let model_id = Path::new(id);
72
73    let config_path =
74        get_file(&api, model_id, "config.json", revision).map_err(hanzo_ml::Error::msg)?;
75    let files = list_repo_files(&api, model_id, true, revision).map_err(hanzo_ml::Error::msg)?;
76    let mut weight_files = files
77        .iter()
78        .filter(|file| file.ends_with(".safetensors"))
79        .cloned()
80        .collect::<Vec<_>>();
81    weight_files.sort();
82    if weight_files.is_empty() {
83        hanzo_ml::bail!("MTP model `{id}` does not contain safetensors weights");
84    }
85    for file in weight_files {
86        get_file(&api, model_id, &file, revision).map_err(hanzo_ml::Error::msg)?;
87    }
88
89    try_get_file(&api, model_id, "generation_config.json", revision)
90        .map_err(|err| hanzo_ml::Error::Msg(err.to_string()))?;
91
92    config_path
93        .parent()
94        .map(Path::to_path_buf)
95        .ok_or_else(|| hanzo_ml::Error::Msg(format!("config path has no parent: {config_path:?}")))
96}