1use std::collections::BTreeMap;
9use std::hash::BuildHasher;
10
11use serde::Deserialize;
12
13use crate::error::FetchError;
14
/// One model returned by [`search_models`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SearchResult {
    /// Hub repository id of the model (the API's `modelId` field).
    pub model_id: String,
    /// Download count reported by the Hub (0 when the API omitted it).
    pub downloads: u64,
    /// Library the model declares, if any.
    pub library_name: Option<String>,
    /// Pipeline/task tag of the model, if any.
    pub pipeline_tag: Option<String>,
}
27
/// A model family (`config.model_type`) found on the Hub but absent locally.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DiscoveredFamily {
    /// The family identifier, taken from the model config's `model_type`.
    pub model_type: String,
    /// Repository id of the first (highest-ranked by downloads) model seen for this family.
    pub top_model: String,
    /// Download count of `top_model`.
    pub downloads: u64,
}
38
39#[derive(Debug, Deserialize)]
41struct ApiModelEntry {
42 #[serde(rename = "modelId")]
43 model_id: String,
44 #[serde(default)]
45 downloads: u64,
46 #[serde(default)]
47 config: Option<ApiConfig>,
48 #[serde(default)]
49 library_name: Option<String>,
50 #[serde(default)]
51 pipeline_tag: Option<String>,
52}
53
54#[derive(Debug, Deserialize)]
56struct ApiConfig {
57 model_type: Option<String>,
58}
59
/// Access-gating mode of a Hub model repository.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum GateStatus {
    /// Not gated (the API reported `gated: false` or omitted the field).
    Open,
    /// Gated with a mode other than manual (`gated: true`, or any mode string except
    /// `"manual"`); presumably automatic approval — confirm against the Hub docs.
    Auto,
    /// Gated with the `"manual"` mode.
    Manual,
}

impl GateStatus {
    /// Returns `true` for every gated mode, `false` for [`GateStatus::Open`].
    #[must_use]
    pub const fn is_gated(&self) -> bool {
        // Exhaustive on purpose: a new variant must decide its gating explicitly.
        match self {
            Self::Open => false,
            Self::Auto | Self::Manual => true,
        }
    }
}

impl std::fmt::Display for GateStatus {
    /// Writes the lowercase mode name: `open`, `auto` or `manual`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Open => "open",
            Self::Auto => "auto",
            Self::Manual => "manual",
        };
        f.write_str(label)
    }
}
92
93#[derive(Debug, Clone)]
99pub struct ModelCardMetadata {
100 pub license: Option<String>,
102 pub pipeline_tag: Option<String>,
104 pub tags: Vec<String>,
106 pub library_name: Option<String>,
108 pub languages: Vec<String>,
110 pub gated: GateStatus,
112}
113
114#[derive(Debug, Deserialize)]
116struct ApiModelDetail {
117 #[serde(default)]
118 pipeline_tag: Option<String>,
119 #[serde(default)]
120 tags: Vec<String>,
121 #[serde(default)]
122 library_name: Option<String>,
123 #[serde(default)]
124 gated: ApiGated,
125 #[serde(default, rename = "cardData")]
126 card_data: Option<ApiCardData>,
127}
128
129#[derive(Debug, Deserialize)]
131struct ApiCardData {
132 #[serde(default)]
133 license: Option<String>,
134 #[serde(default)]
135 language: Option<ApiLanguage>,
136}
137
138#[derive(Debug, Deserialize)]
140#[serde(untagged)]
141enum ApiLanguage {
142 Single(String),
143 Multiple(Vec<String>),
144}
145
146#[derive(Debug, Deserialize)]
148#[serde(untagged)]
149enum ApiGated {
150 Bool(bool),
151 Mode(String),
152}
153
154impl Default for ApiGated {
155 fn default() -> Self {
156 Self::Bool(false)
157 }
158}
159
/// Base URL of the Hugging Face Hub model API.
const HF_API_BASE: &str = "https://huggingface.co/api/models";
/// Maximum number of models requested per page by `discover_new_families`.
const PAGE_SIZE: usize = 100;
162
163pub async fn discover_new_families<S: BuildHasher>(
175 local_families: &std::collections::HashSet<String, S>,
176 max_models: usize,
177) -> Result<Vec<DiscoveredFamily>, FetchError> {
178 let client = reqwest::Client::new();
179 let mut remote_families: BTreeMap<String, (String, u64)> = BTreeMap::new();
180 let mut offset: usize = 0;
181
182 while offset < max_models {
183 let limit = PAGE_SIZE.min(max_models.saturating_sub(offset));
184 let url = format!(
185 "{HF_API_BASE}?config=true&sort=downloads&direction=-1&limit={limit}&offset={offset}"
186 );
187
188 let response = client
189 .get(url.as_str()) .send()
191 .await
192 .map_err(|e| FetchError::Http(e.to_string()))?;
193
194 if !response.status().is_success() {
195 return Err(FetchError::Http(format!(
196 "HF API returned status {}",
197 response.status()
198 )));
199 }
200
201 let models: Vec<ApiModelEntry> = response
202 .json()
203 .await
204 .map_err(|e| FetchError::Http(e.to_string()))?;
205
206 if models.is_empty() {
207 break;
208 }
209
210 for model in &models {
211 let model_type = model.config.as_ref().and_then(|c| c.model_type.as_deref());
213
214 if let Some(mt) = model_type {
215 remote_families
216 .entry(mt.to_owned())
217 .or_insert_with(|| (model.model_id.clone(), model.downloads));
218 }
219 }
220
221 offset = offset.saturating_add(models.len());
222 }
223
224 let discovered: Vec<DiscoveredFamily> = remote_families
227 .into_iter()
228 .filter(|(mt, _)| !local_families.contains(mt.as_str()))
229 .map(|(model_type, (top_model, downloads))| DiscoveredFamily {
230 model_type,
231 top_model,
232 downloads,
233 })
234 .collect();
235
236 Ok(discovered)
237}
238
/// Rewrites common spellings of quantization terms into one canonical form.
///
/// The query is split on whitespace. A token whose lowercase form is a known variant
/// becomes that variant's canonical spelling:
/// - `8bit` / `8-bit` / `int8` become `8-bit`
/// - `4bit` / `4-bit` / `int4` become `4-bit`
/// - `fp8` / `float8` become `fp8`
///
/// All other tokens are kept verbatim. Tokens are re-joined with single spaces, so runs
/// of whitespace collapse and leading/trailing whitespace disappears.
#[must_use]
fn normalize_quantization_terms(query: &str) -> String {
    const SYNONYMS: &[(&[&str], &str)] = &[
        (&["8bit", "8-bit", "int8"], "8-bit"),
        (&["4bit", "4-bit", "int4"], "4-bit"),
        (&["fp8", "float8"], "fp8"),
    ];

    // Maps one whitespace-free token to its canonical spelling (or itself).
    let canonicalize = |token: &str| -> String {
        let folded = token.to_lowercase();
        SYNONYMS
            .iter()
            .find(|(variants, _)| variants.contains(&folded.as_str()))
            .map_or_else(|| token.to_owned(), |&(_, canonical)| canonical.to_owned())
    };

    let words: Vec<String> = query.split_whitespace().map(canonicalize).collect();
    words.join(" ")
}
269
270pub async fn search_models(
293 query: &str,
294 limit: usize,
295 library: Option<&str>,
296 pipeline: Option<&str>,
297) -> Result<Vec<SearchResult>, FetchError> {
298 let normalized = normalize_quantization_terms(query);
299 let client = reqwest::Client::new();
300
301 let mut query_params: Vec<(&str, &str)> = vec![
303 ("search", normalized.as_str()),
304 ("sort", "downloads"),
305 ("direction", "-1"),
306 ];
307 if let Some(lib) = library {
308 query_params.push(("library", lib));
309 }
310 if let Some(pipe) = pipeline {
311 query_params.push(("pipeline_tag", pipe));
312 }
313
314 let response = client
315 .get(HF_API_BASE)
316 .query(&query_params)
317 .query(&[("limit", limit)])
318 .send()
319 .await
320 .map_err(|e| FetchError::Http(e.to_string()))?;
321
322 if !response.status().is_success() {
323 return Err(FetchError::Http(format!(
324 "HF API returned status {}",
325 response.status()
326 )));
327 }
328
329 let models: Vec<ApiModelEntry> = response
330 .json()
331 .await
332 .map_err(|e| FetchError::Http(e.to_string()))?;
333
334 let results = models
338 .into_iter()
339 .filter(|m| {
340 if let Some(lib) = library {
341 match m.library_name {
342 Some(ref name) if name.as_str().eq_ignore_ascii_case(lib) => {}
344 _ => return false,
345 }
346 }
347 if let Some(pipe) = pipeline {
348 match m.pipeline_tag {
349 Some(ref tag) if tag.as_str().eq_ignore_ascii_case(pipe) => {}
351 _ => return false,
352 }
353 }
354 true
355 })
356 .map(|m| SearchResult {
357 model_id: m.model_id,
358 downloads: m.downloads,
359 library_name: m.library_name,
360 pipeline_tag: m.pipeline_tag,
361 })
362 .collect();
363
364 Ok(results)
365}
366
367pub async fn fetch_model_card(model_id: &str) -> Result<ModelCardMetadata, FetchError> {
380 let client = reqwest::Client::new();
381 let url = format!("{HF_API_BASE}/{model_id}");
382
383 let response = client
384 .get(url.as_str()) .send()
386 .await
387 .map_err(|e| FetchError::Http(e.to_string()))?;
388
389 if !response.status().is_success() {
390 return Err(FetchError::Http(format!(
391 "HF API returned status {} for model {model_id}",
392 response.status()
393 )));
394 }
395
396 let detail: ApiModelDetail = response
397 .json()
398 .await
399 .map_err(|e| FetchError::Http(e.to_string()))?;
400
401 let (license, languages) = if let Some(card) = detail.card_data {
402 let langs = match card.language {
403 Some(ApiLanguage::Single(s)) => vec![s],
404 Some(ApiLanguage::Multiple(v)) => v,
405 None => Vec::new(),
406 };
407 (card.license, langs)
408 } else {
409 (None, Vec::new())
410 };
411
412 let gated = match detail.gated {
413 ApiGated::Bool(false) => GateStatus::Open,
414 ApiGated::Mode(ref mode) if mode.eq_ignore_ascii_case("manual") => GateStatus::Manual,
415 ApiGated::Bool(true) | ApiGated::Mode(_) => GateStatus::Auto,
416 };
417
418 Ok(ModelCardMetadata {
419 license,
420 pipeline_tag: detail.pipeline_tag,
421 tags: detail.tags,
422 library_name: detail.library_name,
423 languages,
424 gated,
425 })
426}
427
428pub async fn fetch_readme(
443 model_id: &str,
444 revision: Option<&str>,
445 token: Option<&str>,
446) -> Result<Option<String>, FetchError> {
447 let rev = revision.unwrap_or("main");
448 let url = crate::chunked::build_download_url(model_id, rev, "README.md");
449 let client = crate::chunked::build_client(token)?;
450
451 let response = client
452 .get(url.as_str()) .send()
454 .await
455 .map_err(|e| FetchError::Http(format!("failed to fetch README for {model_id}: {e}")))?;
456
457 if response.status() == reqwest::StatusCode::NOT_FOUND {
458 return Ok(None);
459 }
460
461 if !response.status().is_success() {
462 return Err(FetchError::Http(format!(
463 "README request for {model_id} returned status {}",
464 response.status()
465 )));
466 }
467
468 let text = response
469 .text()
470 .await
471 .map_err(|e| FetchError::Http(format!("failed to read README for {model_id}: {e}")))?;
472
473 Ok(Some(text))
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalize_8bit_variants() {
        assert_eq!(normalize_quantization_terms("AWQ 8bit"), "AWQ 8-bit");
        assert_eq!(normalize_quantization_terms("AWQ 8-bit"), "AWQ 8-bit");
        assert_eq!(normalize_quantization_terms("AWQ int8"), "AWQ 8-bit");
        assert_eq!(normalize_quantization_terms("AWQ INT8"), "AWQ 8-bit");
    }

    #[test]
    fn normalize_4bit_variants() {
        assert_eq!(normalize_quantization_terms("GPTQ 4bit"), "GPTQ 4-bit");
        assert_eq!(normalize_quantization_terms("GPTQ INT4"), "GPTQ 4-bit");
        assert_eq!(normalize_quantization_terms("GPTQ 4-bit"), "GPTQ 4-bit");
    }

    #[test]
    fn normalize_fp8_variants() {
        assert_eq!(normalize_quantization_terms("FP8"), "fp8");
        assert_eq!(normalize_quantization_terms("float8"), "fp8");
        assert_eq!(normalize_quantization_terms("fp8"), "fp8");
    }

    #[test]
    fn normalize_passthrough() {
        assert_eq!(normalize_quantization_terms("llama 3"), "llama 3");
        assert_eq!(normalize_quantization_terms("RWKV-7"), "RWKV-7");
    }

    /// Tokens are re-joined with single spaces, so extra whitespace collapses.
    #[test]
    fn normalize_collapses_whitespace() {
        assert_eq!(normalize_quantization_terms("  llama   3\t8bit "), "llama 3 8-bit");
        assert_eq!(normalize_quantization_terms(""), "");
        assert_eq!(normalize_quantization_terms("   "), "");
    }

    /// Only whole whitespace-separated tokens are rewritten, never substrings.
    #[test]
    fn normalize_does_not_match_substrings() {
        assert_eq!(normalize_quantization_terms("AWQ-8bit"), "AWQ-8bit");
        assert_eq!(normalize_quantization_terms("int8x"), "int8x");
    }

    #[test]
    fn gate_status_is_gated() {
        assert!(!GateStatus::Open.is_gated());
        assert!(GateStatus::Auto.is_gated());
        assert!(GateStatus::Manual.is_gated());
    }

    #[test]
    fn gate_status_display() {
        assert_eq!(GateStatus::Open.to_string(), "open");
        assert_eq!(GateStatus::Auto.to_string(), "auto");
        assert_eq!(GateStatus::Manual.to_string(), "manual");
    }

    /// A missing `gated` field must deserialize as "not gated".
    #[test]
    fn api_gated_defaults_to_not_gated() {
        assert!(matches!(ApiGated::default(), ApiGated::Bool(false)));
    }
}