1use std::collections::BTreeMap;
9use std::hash::BuildHasher;
10
11use serde::Deserialize;
12
13use crate::error::FetchError;
14
/// A single hit returned by [`search_models`].
#[derive(Clone, Debug)]
pub struct SearchResult {
    /// Hub model identifier (the API's `modelId` field).
    pub model_id: String,
    /// Download count reported by the API (`0` when the API omitted it).
    pub downloads: u64,
}
23
/// A model family (`config.model_type`) seen on the Hub but absent locally,
/// as produced by [`discover_new_families`].
#[derive(Clone, Debug)]
pub struct DiscoveredFamily {
    /// The `config.model_type` value identifying the family.
    pub model_type: String,
    /// First model of this family encountered in the download-sorted listing.
    pub top_model: String,
    /// Download count of `top_model`.
    pub downloads: u64,
}
34
35#[derive(Debug, Deserialize)]
37struct ApiModelEntry {
38 #[serde(rename = "modelId")]
39 model_id: String,
40 #[serde(default)]
41 downloads: u64,
42 #[serde(default)]
43 config: Option<ApiConfig>,
44}
45
46#[derive(Debug, Deserialize)]
48struct ApiConfig {
49 model_type: Option<String>,
50}
51
/// Access-gating mode of a model, derived from the API's `gated` field.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum GateStatus {
    /// Not gated (the API reported `gated: false` or omitted the field).
    Open,
    /// Gated: the API reported `true` or a mode string other than `"manual"`
    /// (for example `"auto"`).
    Auto,
    /// Gated with the `"manual"` mode (matched ASCII case-insensitively).
    Manual,
}
66
67impl GateStatus {
68 #[must_use]
70 pub const fn is_gated(&self) -> bool {
71 matches!(self, Self::Auto | Self::Manual)
72 }
73}
74
75impl std::fmt::Display for GateStatus {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 match self {
78 Self::Open => write!(f, "open"),
79 Self::Auto => write!(f, "auto"),
80 Self::Manual => write!(f, "manual"),
81 }
82 }
83}
84
85#[derive(Debug, Clone)]
91pub struct ModelCardMetadata {
92 pub license: Option<String>,
94 pub pipeline_tag: Option<String>,
96 pub tags: Vec<String>,
98 pub library_name: Option<String>,
100 pub languages: Vec<String>,
102 pub gated: GateStatus,
104}
105
106#[derive(Debug, Deserialize)]
108struct ApiModelDetail {
109 #[serde(default)]
110 pipeline_tag: Option<String>,
111 #[serde(default)]
112 tags: Vec<String>,
113 #[serde(default)]
114 library_name: Option<String>,
115 #[serde(default)]
116 gated: ApiGated,
117 #[serde(default, rename = "cardData")]
118 card_data: Option<ApiCardData>,
119}
120
121#[derive(Debug, Deserialize)]
123struct ApiCardData {
124 #[serde(default)]
125 license: Option<String>,
126 #[serde(default)]
127 language: Option<ApiLanguage>,
128}
129
130#[derive(Debug, Deserialize)]
132#[serde(untagged)]
133enum ApiLanguage {
134 Single(String),
135 Multiple(Vec<String>),
136}
137
138#[derive(Debug, Deserialize)]
140#[serde(untagged)]
141enum ApiGated {
142 Bool(bool),
143 Mode(String),
144}
145
146impl Default for ApiGated {
147 fn default() -> Self {
148 Self::Bool(false)
149 }
150}
151
/// Root of the Hugging Face Hub models API; both the listing and the
/// per-model detail requests are built from it.
const HF_API_BASE: &str = "https://huggingface.co/api/models";
/// Upper bound on the number of models requested per listing page.
const PAGE_SIZE: usize = 100;
154
155pub async fn discover_new_families<S: BuildHasher>(
167 local_families: &std::collections::HashSet<String, S>,
168 max_models: usize,
169) -> Result<Vec<DiscoveredFamily>, FetchError> {
170 let client = reqwest::Client::new();
171 let mut remote_families: BTreeMap<String, (String, u64)> = BTreeMap::new();
172 let mut offset: usize = 0;
173
174 while offset < max_models {
175 let limit = PAGE_SIZE.min(max_models.saturating_sub(offset));
176 let url = format!(
177 "{HF_API_BASE}?config=true&sort=downloads&direction=-1&limit={limit}&offset={offset}"
178 );
179
180 let response = client
181 .get(url.as_str()) .send()
183 .await
184 .map_err(|e| FetchError::Http(e.to_string()))?;
185
186 if !response.status().is_success() {
187 return Err(FetchError::Http(format!(
188 "HF API returned status {}",
189 response.status()
190 )));
191 }
192
193 let models: Vec<ApiModelEntry> = response
194 .json()
195 .await
196 .map_err(|e| FetchError::Http(e.to_string()))?;
197
198 if models.is_empty() {
199 break;
200 }
201
202 for model in &models {
203 let model_type = model.config.as_ref().and_then(|c| c.model_type.as_deref());
205
206 if let Some(mt) = model_type {
207 remote_families
208 .entry(mt.to_owned())
209 .or_insert_with(|| (model.model_id.clone(), model.downloads));
210 }
211 }
212
213 offset = offset.saturating_add(models.len());
214 }
215
216 let discovered: Vec<DiscoveredFamily> = remote_families
219 .into_iter()
220 .filter(|(mt, _)| !local_families.contains(mt.as_str()))
221 .map(|(model_type, (top_model, downloads))| DiscoveredFamily {
222 model_type,
223 top_model,
224 downloads,
225 })
226 .collect();
227
228 Ok(discovered)
229}
230
231pub async fn search_models(query: &str, limit: usize) -> Result<Vec<SearchResult>, FetchError> {
244 let client = reqwest::Client::new();
245
246 let response = client
247 .get(HF_API_BASE)
248 .query(&[
249 ("search", query),
250 ("sort", "downloads"),
251 ("direction", "-1"),
252 ])
253 .query(&[("limit", limit)])
254 .send()
255 .await
256 .map_err(|e| FetchError::Http(e.to_string()))?;
257
258 if !response.status().is_success() {
259 return Err(FetchError::Http(format!(
260 "HF API returned status {}",
261 response.status()
262 )));
263 }
264
265 let models: Vec<ApiModelEntry> = response
266 .json()
267 .await
268 .map_err(|e| FetchError::Http(e.to_string()))?;
269
270 let results = models
271 .into_iter()
272 .map(|m| SearchResult {
273 model_id: m.model_id,
274 downloads: m.downloads,
275 })
276 .collect();
277
278 Ok(results)
279}
280
281pub async fn fetch_model_card(model_id: &str) -> Result<ModelCardMetadata, FetchError> {
294 let client = reqwest::Client::new();
295 let url = format!("{HF_API_BASE}/{model_id}");
296
297 let response = client
298 .get(url.as_str()) .send()
300 .await
301 .map_err(|e| FetchError::Http(e.to_string()))?;
302
303 if !response.status().is_success() {
304 return Err(FetchError::Http(format!(
305 "HF API returned status {} for model {model_id}",
306 response.status()
307 )));
308 }
309
310 let detail: ApiModelDetail = response
311 .json()
312 .await
313 .map_err(|e| FetchError::Http(e.to_string()))?;
314
315 let (license, languages) = if let Some(card) = detail.card_data {
316 let langs = match card.language {
317 Some(ApiLanguage::Single(s)) => vec![s],
318 Some(ApiLanguage::Multiple(v)) => v,
319 None => Vec::new(),
320 };
321 (card.license, langs)
322 } else {
323 (None, Vec::new())
324 };
325
326 let gated = match detail.gated {
327 ApiGated::Bool(false) => GateStatus::Open,
328 ApiGated::Mode(ref mode) if mode.eq_ignore_ascii_case("manual") => GateStatus::Manual,
329 ApiGated::Bool(true) | ApiGated::Mode(_) => GateStatus::Auto,
330 };
331
332 Ok(ModelCardMetadata {
333 license,
334 pipeline_tag: detail.pipeline_tag,
335 tags: detail.tags,
336 library_name: detail.library_name,
337 languages,
338 gated,
339 })
340}