1use crate::error::{OutrigError, Result};
15
/// One file entry from a HuggingFace model repository listing.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct HfFile {
    /// Repository path of the file as reported by the tree listing
    /// (e.g. `qwen2.5-coder-1.5b-instruct-q4_k_m.gguf`).
    pub path: String,
    /// Size in bytes as reported by the listing; `None` when the listing omits it.
    pub size: Option<u64>,
}
25
/// Lists the files of a HuggingFace model repository.
///
/// This file provides an HTTP-backed implementation (behind the `local-llm` feature), one that
/// always fails, and a build-dependent dispatcher over the two.
// `async fn` in a public trait trips the `async_fn_in_trait` lint because callers cannot
// require the returned future to be `Send`. NOTE(review): allowed presumably because callers
// never need a `Send` future (e.g. they do not spawn it onto a multi-threaded runtime) — confirm.
#[allow(async_fn_in_trait)]
pub trait HfTreeFetcher {
    /// Returns the file entries of `model_id` at `revision`.
    ///
    /// How `revision == None` is resolved is up to the implementation; the HTTP-backed fetcher
    /// in this file uses `"main"`.
    ///
    /// # Errors
    ///
    /// Implementation-defined. The fetchers in this file build their errors from
    /// `OutrigError::Configuration`.
    async fn list_files(&mut self, model_id: &str, revision: Option<&str>) -> Result<Vec<HfFile>>;
}
30
#[cfg(feature = "local-llm")]
/// [`HfTreeFetcher`] that queries the public HuggingFace Hub HTTP API.
///
/// Stateless: a fresh HTTP client is built for every call.
#[derive(Debug, Clone, Copy, Default)]
pub struct ApiHfTreeFetcher;
33
#[cfg(feature = "local-llm")]
/// One element of the JSON array returned by the Hub tree endpoint.
///
/// Only the fields needed to build an [`HfFile`] are decoded; serde ignores any other keys
/// in the response object.
#[derive(serde::Deserialize)]
struct TreeEntry {
    /// Entry kind from the JSON `type` key; only `"file"` entries are kept by the caller.
    #[serde(rename = "type")]
    kind: String,
    /// Repository path of the entry.
    path: String,
    /// Size in bytes; a missing key decodes as `None`.
    #[serde(default)]
    size: Option<u64>,
}
43
#[cfg(feature = "local-llm")]
impl HfTreeFetcher for ApiHfTreeFetcher {
    /// Lists the file entries of `model_id` at `revision` (`"main"` when `None`) via
    /// `GET https://huggingface.co/api/models/{model_id}/tree/{revision}`.
    ///
    /// Entries whose `type` is not `"file"` are dropped.
    ///
    /// # Errors
    ///
    /// Returns an error built from `OutrigError::Configuration` when:
    /// - the HTTP client cannot be built,
    /// - the request fails (including the 10-second timeout),
    /// - the response status is not a success, or
    /// - the body cannot be decoded as a JSON array of tree entries.
    async fn list_files(&mut self, model_id: &str, revision: Option<&str>) -> Result<Vec<HfFile>> {
        /// Percent-encodes every byte outside the RFC 3986 unreserved set so `segment` occupies
        /// exactly one URL path segment. Revisions such as `refs/pr/1` contain `/`, which would
        /// otherwise be parsed as extra path components.
        fn encode_path_segment(segment: &str) -> String {
            let mut encoded = String::with_capacity(segment.len());
            for byte in segment.bytes() {
                if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                    encoded.push(char::from(byte));
                } else {
                    encoded.push_str(&format!("%{byte:02X}"));
                }
            }
            encoded
        }

        let revision = encode_path_segment(revision.unwrap_or("main"));
        // NOTE(review): the tree endpoint appears to list a single directory level and to be
        // paginated, so files in sub-directories or past the first page are not returned here.
        // Confirm whether callers need `?recursive=true` and cursor handling.
        let url = format!("https://huggingface.co/api/models/{model_id}/tree/{revision}");
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(10))
            .build()
            .map_err(|e| OutrigError::Configuration(format!("hf client: {e}")))?;
        let resp = client
            .get(&url)
            .send()
            .await
            .map_err(|e| OutrigError::Configuration(format!("hf list {model_id:?}: {e}")))?;
        let status = resp.status();
        if !status.is_success() {
            return Err(OutrigError::Configuration(format!(
                "hf list {model_id:?}: HTTP {status}"
            ))
            .into());
        }
        let entries: Vec<TreeEntry> = resp
            .json()
            .await
            .map_err(|e| OutrigError::Configuration(format!("hf list {model_id:?}: {e}")))?;
        Ok(entries
            .into_iter()
            .filter(|e| e.kind == "file")
            .map(|e| HfFile {
                path: e.path,
                size: e.size,
            })
            .collect())
    }
}
79
/// [`HfTreeFetcher`] whose every call fails with a "not available in this build" error.
#[derive(Debug, Clone, Copy, Default)]
pub struct UnavailableHfTreeFetcher;
85
86impl HfTreeFetcher for UnavailableHfTreeFetcher {
87 async fn list_files(
88 &mut self,
89 _model_id: &str,
90 _revision: Option<&str>,
91 ) -> Result<Vec<HfFile>> {
92 Err(OutrigError::Configuration(
93 "HuggingFace tree-listing not available in this build".to_string(),
94 )
95 .into())
96 }
97}
98
99pub fn auto() -> AutoHfTreeFetcher {
102 #[cfg(feature = "local-llm")]
103 {
104 AutoHfTreeFetcher::Api(ApiHfTreeFetcher)
105 }
106 #[cfg(not(feature = "local-llm"))]
107 {
108 AutoHfTreeFetcher::Unavailable(UnavailableHfTreeFetcher)
109 }
110}
111
/// Build-dependent [`HfTreeFetcher`], as returned by [`auto`].
///
/// The `Api` variant exists only when the `local-llm` feature is enabled; `Unavailable`
/// is present in every build.
pub enum AutoHfTreeFetcher {
    /// HTTP-backed fetcher (compiled only with the `local-llm` feature).
    #[cfg(feature = "local-llm")]
    Api(ApiHfTreeFetcher),
    /// Fetcher whose every call returns an error.
    Unavailable(UnavailableHfTreeFetcher),
}
117
118impl HfTreeFetcher for AutoHfTreeFetcher {
119 async fn list_files(&mut self, model_id: &str, revision: Option<&str>) -> Result<Vec<HfFile>> {
120 match self {
121 #[cfg(feature = "local-llm")]
122 Self::Api(f) => f.list_files(model_id, revision).await,
123 Self::Unavailable(f) => f.list_files(model_id, revision).await,
124 }
125 }
126}
127
128pub fn filter_gguf(files: Vec<HfFile>) -> Vec<HfFile> {
130 let mut out: Vec<HfFile> = files
131 .into_iter()
132 .filter(|f| f.path.to_ascii_lowercase().ends_with(".gguf"))
133 .collect();
134 out.sort_by(|a, b| a.path.cmp(&b.path));
135 out
136}
137
/// Renders a byte count as a human-readable string using binary (1024-based) units.
///
/// Counts below 1 KiB are printed exactly (`"512 B"`). Larger counts are scaled to the
/// largest unit up to TiB and printed with one decimal (`"1.4 GiB"`). A value that would
/// round up to `1024.0` in its unit is promoted to the next unit, so 1_048_575 bytes is
/// `"1.0 MiB"`, not `"1024.0 KiB"`. Counts of 1024 TiB or more stay in TiB.
pub fn format_size(bytes: u64) -> String {
    const UNITS: &[&str] = &["B", "KiB", "MiB", "GiB", "TiB"];
    const STEP: f64 = 1024.0;
    let last = UNITS.len() - 1;

    // Display-only: losing precision above 2^53 bytes is irrelevant at one decimal place.
    let mut size = bytes as f64;
    let mut unit = 0;
    while size >= STEP && unit < last {
        size /= STEP;
        unit += 1;
    }
    if unit == 0 {
        return format!("{bytes} {}", UNITS[0]);
    }

    let mut rendered = format!("{size:.1}");
    // Values in [1023.95, 1024) pass the loop test but `{:.1}` rounds them to "1024.0";
    // move them to the next unit so the displayed number stays below 1024.
    if rendered == "1024.0" && unit < last {
        size /= STEP;
        unit += 1;
        rendered = format!("{size:.1}");
    }
    format!("{rendered} {}", UNITS[unit])
}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor for an [`HfFile`] fixture.
    fn f(path: &str, size: Option<u64>) -> HfFile {
        HfFile {
            path: path.to_string(),
            size,
        }
    }

    /// Non-GGUF files are dropped, an upper-case `.GGUF` extension is kept, and the
    /// survivors come back sorted by path.
    #[test]
    fn filter_gguf_keeps_only_gguf_and_sorts() {
        let files = vec![
            f("README.md", Some(1_024)),
            f("config.json", Some(512)),
            f(
                "qwen2.5-coder-1.5b-instruct-q5_k_m.gguf",
                Some(1_500_000_000),
            ),
            f(
                "qwen2.5-coder-1.5b-instruct-q4_k_m.GGUF",
                Some(1_000_000_000),
            ),
            f("qwen2.5-coder-1.5b-instruct-q8_0.gguf", Some(2_000_000_000)),
            f("tokenizer.json", Some(512)),
        ];
        let out = filter_gguf(files);
        let names: Vec<&str> = out.iter().map(|x| x.path.as_str()).collect();
        assert_eq!(
            names,
            vec![
                "qwen2.5-coder-1.5b-instruct-q4_k_m.GGUF",
                "qwen2.5-coder-1.5b-instruct-q5_k_m.gguf",
                "qwen2.5-coder-1.5b-instruct-q8_0.gguf",
            ]
        );
    }

    /// Byte counts below 1 KiB print exactly; larger ones get one decimal in a binary unit.
    #[test]
    fn format_size_renders_units() {
        assert_eq!(format_size(0), "0 B");
        assert_eq!(format_size(512), "512 B");
        assert_eq!(format_size(2 * 1024), "2.0 KiB");
        assert_eq!(format_size(3 * 1024 * 1024), "3.0 MiB");
        // 1_500_000_000 / 1024^3 is about 1.397, which rounds to one decimal as 1.4.
        assert_eq!(format_size(1_500_000_000), "1.4 GiB");
    }

    /// The stub fetcher fails regardless of its arguments; `#[tokio::test]` provides the
    /// async runtime needed to await it.
    #[tokio::test]
    async fn unavailable_fetcher_always_errors() {
        let mut x = UnavailableHfTreeFetcher;
        assert!(x.list_files("anything", None).await.is_err());
    }
}