hanzo_engine/speculative/
config.rs1use std::path::{Path, PathBuf};
2
3use hf_hub::{
4 api::sync::{ApiBuilder, ApiRepo},
5 Cache, Repo, RepoType,
6};
7
8use crate::{
9 pipeline::{
10 hf::{get_file, hf_hub_cache_dir, list_repo_files, try_get_file},
11 TokenSource,
12 },
13 utils::tokens::get_token,
14 GLOBAL_HF_CACHE,
15};
16
17#[derive(Clone, Debug)]
18pub enum SpeculativeConfig {
19 Off,
20 Mtp(MtpConfig),
21}
22
/// Location and settings for an MTP draft model.
#[derive(Debug, Clone)]
pub struct MtpConfig {
    /// Either a local filesystem path or a Hugging Face repo id;
    /// [`MtpConfig::resolve_path`] decides which and turns it into a directory.
    pub model: String,
    /// Optional prediction count carried through unchanged; presumably the number of
    /// draft tokens proposed per step — TODO confirm against the consumer of this config.
    pub n_predict: Option<usize>,
}

29impl MtpConfig {
30 pub fn new(model: impl Into<String>, n_predict: Option<usize>) -> Self {
31 Self {
32 model: model.into(),
33 n_predict,
34 }
35 }
36
37 pub fn resolve_path(&self) -> hanzo_ml::Result<PathBuf> {
38 let path = PathBuf::from(&self.model);
39 if path.exists() || self.model.starts_with('.') || self.model.starts_with('/') {
40 Ok(path)
41 } else {
42 resolve_hf_mtp_path(&self.model)
43 }
44 }
45}
46
47fn build_hf_api(id: &str, revision: &str) -> hanzo_ml::Result<ApiRepo> {
48 let cache = GLOBAL_HF_CACHE
49 .get()
50 .cloned()
51 .unwrap_or_else(|| hf_hub_cache_dir().map(Cache::new).unwrap_or_default());
52 let mut api = ApiBuilder::from_cache(cache)
53 .with_progress(true)
54 .with_token(get_token(&TokenSource::CacheToken).map_err(hanzo_ml::Error::msg)?);
55 if let Some(cache_dir) = hf_hub_cache_dir() {
56 api = api.with_cache_dir(cache_dir);
57 }
58 Ok(api
59 .build()
60 .map_err(hanzo_ml::Error::msg)?
61 .repo(Repo::with_revision(
62 id.to_string(),
63 RepoType::Model,
64 revision.to_string(),
65 )))
66}
67
68fn resolve_hf_mtp_path(id: &str) -> hanzo_ml::Result<PathBuf> {
69 let revision = "main";
70 let api = build_hf_api(id, revision)?;
71 let model_id = Path::new(id);
72
73 let config_path =
74 get_file(&api, model_id, "config.json", revision).map_err(hanzo_ml::Error::msg)?;
75 let files = list_repo_files(&api, model_id, true, revision).map_err(hanzo_ml::Error::msg)?;
76 let mut weight_files = files
77 .iter()
78 .filter(|file| file.ends_with(".safetensors"))
79 .cloned()
80 .collect::<Vec<_>>();
81 weight_files.sort();
82 if weight_files.is_empty() {
83 hanzo_ml::bail!("MTP model `{id}` does not contain safetensors weights");
84 }
85 for file in weight_files {
86 get_file(&api, model_id, &file, revision).map_err(hanzo_ml::Error::msg)?;
87 }
88
89 try_get_file(&api, model_id, "generation_config.json", revision)
90 .map_err(|err| hanzo_ml::Error::Msg(err.to_string()))?;
91
92 config_path
93 .parent()
94 .map(Path::to_path_buf)
95 .ok_or_else(|| hanzo_ml::Error::Msg(format!("config path has no parent: {config_path:?}")))
96}