Skip to main content

outrig_cli/
hf.rs

1//! Minimal HuggingFace tree-listing helper used by `outrig config init`.
2//!
3//! Returns each file's path *and* size so the picker can display
4//! human-readable sizes alongside filenames -- mistralrs users tell
5//! quantizations apart by file size as much as by name. Hits HF's
6//! `/api/models/{id}/tree/{revision}` directly via `reqwest` (rather
7//! than `hf-hub::Api::info()`, which only exposes filenames).
8//!
9//! The `local-llm` feature pulls in `reqwest` for the real implementation.
10//! Builds without the feature still get the trait plus an `Unavailable`
11//! impl that always errors -- so the init flow can prompt for `model-file`
12//! as free-form text without compiling against `reqwest`.
13
14use crate::error::{OutrigError, Result};
15
/// One file in a HuggingFace repo, projected into the fields the picker
/// needs.
///
/// Derives `Eq`/`Hash` (all fields support them) so callers can dedupe or
/// key collections by file.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct HfFile {
    /// The file's path as reported by HF's tree endpoint.
    pub path: String,
    /// The file's byte count when known. HF's tree endpoint reports it for
    /// every regular file; the field stays `Option` so future API quirks
    /// don't break the picker.
    pub size: Option<u64>,
}
25
26#[allow(async_fn_in_trait)]
27pub trait HfTreeFetcher {
28    async fn list_files(&mut self, model_id: &str, revision: Option<&str>) -> Result<Vec<HfFile>>;
29}
30
/// Network-backed fetcher that lists files through HuggingFace's
/// `/api/models/{id}/tree/{revision}` HTTP endpoint.
///
/// Only compiled with the `local-llm` feature.
#[cfg(feature = "local-llm")]
#[derive(Debug, Clone, Copy, Default)]
pub struct ApiHfTreeFetcher;
33
/// One element of the JSON array returned by HF's tree endpoint.
///
/// Only the fields the picker needs are decoded; any other keys in the
/// response are ignored by serde.
#[cfg(feature = "local-llm")]
#[derive(serde::Deserialize)]
struct TreeEntry {
    /// Path of the entry as reported by HF.
    path: String,
    /// Entry type from the JSON `"type"` key; the fetcher keeps only
    /// entries whose type is `"file"`.
    #[serde(rename = "type")]
    kind: String,
    /// Byte count of the entry; decodes as `None` when the key is absent.
    #[serde(default)]
    size: Option<u64>,
}
43
#[cfg(feature = "local-llm")]
impl HfTreeFetcher for ApiHfTreeFetcher {
    /// Lists every regular file in `model_id` at `revision` (default `"main"`).
    ///
    /// Requests the tree with `recursive=true`, so files kept in
    /// subdirectories (e.g. split GGUF quantizations) are included. Follows
    /// the endpoint's `Link: <…>; rel="next"` pagination until no further
    /// page is advertised, so large repos are not silently truncated to
    /// their first page.
    ///
    /// # Errors
    /// Returns `OutrigError::Configuration` if the HTTP client cannot be
    /// built, a request fails or exceeds its 10 s timeout (applied per page),
    /// HF answers with a non-2xx status, or a page is not a JSON array of
    /// tree entries.
    async fn list_files(&mut self, model_id: &str, revision: Option<&str>) -> Result<Vec<HfFile>> {
        let revision = revision.unwrap_or("main");
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(10))
            .build()
            .map_err(|e| OutrigError::Configuration(format!("hf client: {e}")))?;

        let mut files = Vec::new();
        let mut next_url = Some(hf_tree_url(model_id, revision));
        while let Some(url) = next_url.take() {
            let resp = client
                .get(url)
                .send()
                .await
                .map_err(|e| OutrigError::Configuration(format!("hf list {model_id:?}: {e}")))?;
            if !resp.status().is_success() {
                return Err(OutrigError::Configuration(format!(
                    "hf list {model_id:?}: HTTP {}",
                    resp.status()
                ))
                .into());
            }
            // Read the continuation link before `json()` consumes the
            // response. `join` resolves a relative link against the page URL
            // and passes an absolute one through unchanged.
            next_url = resp
                .headers()
                .get(reqwest::header::LINK)
                .and_then(|value| value.to_str().ok())
                .and_then(next_page_url)
                .and_then(|link| resp.url().join(link).ok());
            let entries: Vec<TreeEntry> = resp
                .json()
                .await
                .map_err(|e| OutrigError::Configuration(format!("hf list {model_id:?}: {e}")))?;
            files.extend(
                entries
                    .into_iter()
                    .filter(|e| e.kind == "file")
                    .map(|e| HfFile {
                        path: e.path,
                        size: e.size,
                    }),
            );
        }
        Ok(files)
    }
}

/// Builds the recursive tree-listing URL for `model_id` at `revision`.
///
/// Uses `Url`'s path-segment API instead of string formatting so every piece
/// is percent-encoded. A `/` inside a revision such as `refs/pr/1` becomes
/// `%2F`; unencoded, the tree endpoint would read everything after the first
/// `/` as a path inside the repo. The `/` in an `org/name` model id is kept
/// as a real path separator.
#[cfg(feature = "local-llm")]
fn hf_tree_url(model_id: &str, revision: &str) -> reqwest::Url {
    let mut url = reqwest::Url::parse("https://huggingface.co/api/models")
        .expect("hard-coded HF API base is a valid absolute URL");
    url.path_segments_mut()
        .expect("https URLs always support path segments")
        .extend(model_id.split('/'))
        .push("tree")
        .push(revision);
    url.query_pairs_mut().append_pair("recursive", "true");
    url
}

/// Returns the target of the `rel="next"` entry in an RFC 8288 `Link`
/// header value, if there is one.
///
/// Handles the `<url>; rel="next"` form, optionally alongside other
/// comma-separated links. URLs that contain a literal `,` are not supported.
// Also compiled under `test` so the parser can be unit-tested in builds
// without the `local-llm` feature, where nothing else calls it.
#[cfg(any(feature = "local-llm", test))]
#[cfg_attr(not(feature = "local-llm"), allow(dead_code))]
fn next_page_url(link: &str) -> Option<&str> {
    link.split(',').find_map(|entry| {
        let (target, params) = entry.split_once(';')?;
        let is_next = params
            .split(';')
            .filter_map(|param| param.trim().strip_prefix("rel="))
            .any(|rel| {
                rel.trim_matches('"')
                    .split_whitespace()
                    .any(|r| r.eq_ignore_ascii_case("next"))
            });
        if !is_next {
            return None;
        }
        target.trim().strip_prefix('<')?.strip_suffix('>')
    })
}
79
/// Always-fails fetcher used when the `local-llm` feature is off (or by
/// callers that explicitly want to bypass the network). Returns a
/// configuration error the prompt flow recognizes as "fall back to the
/// free-form text prompt".
///
/// Stateless unit struct; `Debug`/`Clone`/`Copy`/`Default` are derived so it
/// composes with other public types.
#[derive(Debug, Clone, Copy, Default)]
pub struct UnavailableHfTreeFetcher;
85
86impl HfTreeFetcher for UnavailableHfTreeFetcher {
87    async fn list_files(
88        &mut self,
89        _model_id: &str,
90        _revision: Option<&str>,
91    ) -> Result<Vec<HfFile>> {
92        Err(OutrigError::Configuration(
93            "HuggingFace tree-listing not available in this build".to_string(),
94        )
95        .into())
96    }
97}
98
99/// Pick a fetcher appropriate for the current build. Mirrors the
100/// `init::prompt::auto` factory.
101pub fn auto() -> AutoHfTreeFetcher {
102    #[cfg(feature = "local-llm")]
103    {
104        AutoHfTreeFetcher::Api(ApiHfTreeFetcher)
105    }
106    #[cfg(not(feature = "local-llm"))]
107    {
108        AutoHfTreeFetcher::Unavailable(UnavailableHfTreeFetcher)
109    }
110}
111
112pub enum AutoHfTreeFetcher {
113    #[cfg(feature = "local-llm")]
114    Api(ApiHfTreeFetcher),
115    Unavailable(UnavailableHfTreeFetcher),
116}
117
118impl HfTreeFetcher for AutoHfTreeFetcher {
119    async fn list_files(&mut self, model_id: &str, revision: Option<&str>) -> Result<Vec<HfFile>> {
120        match self {
121            #[cfg(feature = "local-llm")]
122            Self::Api(f) => f.list_files(model_id, revision).await,
123            Self::Unavailable(f) => f.list_files(model_id, revision).await,
124        }
125    }
126}
127
128/// Filter to `.gguf` files only, sorted by path.
129pub fn filter_gguf(files: Vec<HfFile>) -> Vec<HfFile> {
130    let mut out: Vec<HfFile> = files
131        .into_iter()
132        .filter(|f| f.path.to_ascii_lowercase().ends_with(".gguf"))
133        .collect();
134    out.sort_by(|a, b| a.path.cmp(&b.path));
135    out
136}
137
/// Render a byte count as a human-readable, binary-prefixed string
/// (e.g. "512 B", "1.4 GiB").
///
/// Values below 1 KiB are printed as an exact integer number of bytes.
/// Larger values are divided by powers of 1024 and printed with one decimal
/// place. The unit is chosen from the value *as displayed*, so a size that
/// would round up to "1024.0" in one unit is shown as "1.0" of the next
/// unit instead (e.g. 1_048_575 bytes renders as "1.0 MiB", not
/// "1024.0 KiB"). The unit table extends to EiB, which covers all of `u64`.
pub fn format_size(bytes: u64) -> String {
    const UNITS: &[&str] = &["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB"];
    // `{:.1}` renders anything from 1023.95 upward as "1024.0", so promote at
    // that point rather than at 1024.0. Integer byte counts are unaffected:
    // they still leave "B" exactly at 1024.
    const PROMOTE_AT: f64 = 1024.0 - 0.05;
    let mut size = bytes as f64;
    let mut unit = 0;
    while size >= PROMOTE_AT && unit < UNITS.len() - 1 {
        size /= 1024.0;
        unit += 1;
    }
    if unit == 0 {
        format!("{bytes} {}", UNITS[0])
    } else {
        format!("{size:.1} {}", UNITS[unit])
    }
}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor for test fixtures.
    fn f(path: &str, size: Option<u64>) -> HfFile {
        HfFile {
            path: path.to_string(),
            size,
        }
    }

    #[test]
    fn filter_gguf_keeps_only_gguf_and_sorts() {
        let files = vec![
            f("README.md", Some(1_024)),
            f("config.json", Some(512)),
            f(
                "qwen2.5-coder-1.5b-instruct-q5_k_m.gguf",
                Some(1_500_000_000),
            ),
            f(
                "qwen2.5-coder-1.5b-instruct-q4_k_m.GGUF",
                Some(1_000_000_000),
            ),
            f("qwen2.5-coder-1.5b-instruct-q8_0.gguf", Some(2_000_000_000)),
            f("tokenizer.json", Some(512)),
        ];
        let out = filter_gguf(files);
        let names: Vec<&str> = out.iter().map(|x| x.path.as_str()).collect();
        assert_eq!(
            names,
            vec![
                "qwen2.5-coder-1.5b-instruct-q4_k_m.GGUF",
                "qwen2.5-coder-1.5b-instruct-q5_k_m.gguf",
                "qwen2.5-coder-1.5b-instruct-q8_0.gguf",
            ]
        );
    }

    #[test]
    fn filter_gguf_handles_empty_input() {
        assert!(filter_gguf(Vec::new()).is_empty());
    }

    #[test]
    fn filter_gguf_preserves_sizes_and_rejects_non_suffix_matches() {
        let out = filter_gguf(vec![
            f("b.gguf", None),
            f("a.gguf", Some(7)),
            f("a.gguf.part", Some(1)),
            f("notes-about.gguf.md", Some(2)),
        ]);
        assert_eq!(out, vec![f("a.gguf", Some(7)), f("b.gguf", None)]);
    }

    #[test]
    fn format_size_renders_units() {
        assert_eq!(format_size(0), "0 B");
        assert_eq!(format_size(512), "512 B");
        assert_eq!(format_size(2 * 1024), "2.0 KiB");
        assert_eq!(format_size(3 * 1024 * 1024), "3.0 MiB");
        assert_eq!(format_size(1_500_000_000), "1.4 GiB");
    }

    #[test]
    fn format_size_switches_unit_at_1024_bytes() {
        assert_eq!(format_size(1023), "1023 B");
        assert_eq!(format_size(1024), "1.0 KiB");
    }

    #[tokio::test]
    async fn unavailable_fetcher_always_errors() {
        let mut x = UnavailableHfTreeFetcher;
        assert!(x.list_files("anything", None).await.is_err());
    }
}