1use std::collections::BTreeMap;
9use std::hash::BuildHasher;
10
11use serde::Deserialize;
12
13use crate::error::FetchError;
14
/// A single model returned by [`search_models`].
///
/// Each field is copied unchanged from the matching field of the
/// deserialized API entry.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Model identifier, read from the API's `modelId` field.
    pub model_id: String,
    /// Download count reported by the API; `0` when the field is absent.
    pub downloads: u64,
    /// Library name reported by the API, if any.
    pub library_name: Option<String>,
    /// Pipeline tag reported by the API, if any.
    pub pipeline_tag: Option<String>,
    /// Tags reported by the API; empty when the field is absent.
    pub tags: Vec<String>,
}
29
/// A model family seen in the remote listing but missing from the caller's
/// local set, as produced by [`discover_new_families`].
#[derive(Debug, Clone)]
pub struct DiscoveredFamily {
    /// The `config.model_type` value that identifies the family.
    pub model_type: String,
    /// Identifier of the first listed model carrying this `model_type`.
    /// The listing is requested sorted by downloads (descending), so this is
    /// presumably the family's most-downloaded model.
    pub top_model: String,
    /// Download count of `top_model` as reported by the API.
    pub downloads: u64,
}
40
/// Wire format of one element of the model-listing JSON array.
///
/// Only `modelId` is required; every other field falls back to its default
/// value when it is missing from the response.
#[derive(Debug, Deserialize)]
struct ApiModelEntry {
    /// Model identifier (JSON key `modelId`).
    #[serde(rename = "modelId")]
    model_id: String,
    /// Download count; `0` when missing.
    #[serde(default)]
    downloads: u64,
    /// Model config object, if present. Read by `discover_new_families`,
    /// which requests the listing with `config=true`.
    #[serde(default)]
    config: Option<ApiConfig>,
    /// Library name, if present.
    #[serde(default)]
    library_name: Option<String>,
    /// Pipeline tag, if present.
    #[serde(default)]
    pipeline_tag: Option<String>,
    /// Tags; empty when missing.
    #[serde(default)]
    tags: Vec<String>,
}
57
/// The subset of a model's `config` object that this module reads.
#[derive(Debug, Deserialize)]
struct ApiConfig {
    /// Family identifier used as the grouping key by `discover_new_families`;
    /// `None` when the config does not provide one.
    model_type: Option<String>,
}
63
/// Access-gating state of a model, derived from the API's `gated` field.
///
/// `fetch_model_card` maps a `gated` value of `false` to [`GateStatus::Open`],
/// the string `"manual"` (any ASCII case) to [`GateStatus::Manual`], and
/// `true` or any other string to [`GateStatus::Auto`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum GateStatus {
    /// The model is not gated. Displayed as `open`.
    Open,
    /// The model is gated with a mode other than manual. Displayed as `auto`.
    Auto,
    /// The model is gated in manual mode. Displayed as `manual`.
    Manual,
}

impl GateStatus {
    /// Returns `true` for [`GateStatus::Auto`] and [`GateStatus::Manual`],
    /// and `false` for [`GateStatus::Open`].
    #[must_use]
    pub const fn is_gated(&self) -> bool {
        match self {
            Self::Open => false,
            Self::Auto | Self::Manual => true,
        }
    }
}

impl std::fmt::Display for GateStatus {
    /// Writes the lowercase variant name (`open`, `auto` or `manual`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Open => "open",
            Self::Auto => "auto",
            Self::Manual => "manual",
        };
        f.write_str(label)
    }
}
96
/// Metadata about a single model, as assembled by [`fetch_model_card`].
#[derive(Debug, Clone)]
pub struct ModelCardMetadata {
    /// License string taken from the response's `cardData.license`, if any.
    pub license: Option<String>,
    /// Pipeline tag from the top level of the response, if any.
    pub pipeline_tag: Option<String>,
    /// Tags from the top level of the response; empty when absent.
    pub tags: Vec<String>,
    /// Library name from the top level of the response, if any.
    pub library_name: Option<String>,
    /// Languages from `cardData.language`. A single string becomes a
    /// one-element vector; empty when the field or `cardData` is absent.
    pub languages: Vec<String>,
    /// Gating state derived from the response's `gated` field.
    pub gated: GateStatus,
}
117
/// Wire format of the single-model detail response read by `fetch_model_card`.
///
/// No field is required; each one falls back to its default when missing.
#[derive(Debug, Deserialize)]
struct ApiModelDetail {
    /// Pipeline tag, if the response carries one.
    #[serde(default)]
    pipeline_tag: Option<String>,
    /// Tags; empty when missing.
    #[serde(default)]
    tags: Vec<String>,
    /// Library name, if the response carries one.
    #[serde(default)]
    library_name: Option<String>,
    /// Gating flag or mode; `ApiGated::Bool(false)` when missing.
    #[serde(default)]
    gated: ApiGated,
    /// Contents of the `cardData` JSON object, if present.
    #[serde(default, rename = "cardData")]
    card_data: Option<ApiCardData>,
}
132
/// The subset of the response's `cardData` object that this module reads.
#[derive(Debug, Deserialize)]
struct ApiCardData {
    /// License string, if present.
    #[serde(default)]
    license: Option<String>,
    /// Language entry, given either as one string or as a list of strings.
    #[serde(default)]
    language: Option<ApiLanguage>,
}
141
/// The `language` value of `cardData`, which may be a bare string or an
/// array of strings. Being `untagged`, the variant is chosen by the JSON
/// shape rather than by a discriminator key.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiLanguage {
    /// A single language given as a plain string.
    Single(String),
    /// Several languages given as an array of strings.
    Multiple(Vec<String>),
}
149
/// The `gated` value of the detail response, which may be a JSON boolean or
/// a string. Being `untagged`, the variant is chosen by the JSON shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiGated {
    /// Boolean form: `false` is mapped to `GateStatus::Open`, `true` to
    /// `GateStatus::Auto` by `fetch_model_card`.
    Bool(bool),
    /// String form: `"manual"` (any ASCII case) is mapped to
    /// `GateStatus::Manual`, any other string to `GateStatus::Auto`.
    Mode(String),
}
157
impl Default for ApiGated {
    /// A missing `gated` field is treated as `Bool(false)`, i.e. not gated.
    fn default() -> Self {
        Self::Bool(false)
    }
}
163
/// Upper bound on the `limit` requested per page by `discover_new_families`.
const PAGE_SIZE: usize = 100;
/// Base URL of the models endpoint; used for listing, searching and, with a
/// `/{model_id}` suffix, for fetching a single model's details.
const HF_API_BASE: &str = "https://huggingface.co/api/models";
166
167pub async fn discover_new_families<S: BuildHasher>(
179 local_families: &std::collections::HashSet<String, S>,
180 max_models: usize,
181) -> Result<Vec<DiscoveredFamily>, FetchError> {
182 let client = reqwest::Client::new();
183 let mut remote_families: BTreeMap<String, (String, u64)> = BTreeMap::new();
184 let mut offset: usize = 0;
185
186 while offset < max_models {
187 let limit = PAGE_SIZE.min(max_models.saturating_sub(offset));
188 let url = format!(
189 "{HF_API_BASE}?config=true&sort=downloads&direction=-1&limit={limit}&offset={offset}"
190 );
191
192 let response = client
193 .get(url.as_str()) .send()
195 .await
196 .map_err(|e| FetchError::Http(e.to_string()))?;
197
198 if !response.status().is_success() {
199 return Err(FetchError::Http(format!(
200 "HF API returned status {}",
201 response.status()
202 )));
203 }
204
205 let models: Vec<ApiModelEntry> = response
206 .json()
207 .await
208 .map_err(|e| FetchError::Http(e.to_string()))?;
209
210 if models.is_empty() {
211 break;
212 }
213
214 for model in &models {
215 let model_type = model.config.as_ref().and_then(|c| c.model_type.as_deref());
217
218 if let Some(mt) = model_type {
219 remote_families
220 .entry(mt.to_owned())
221 .or_insert_with(|| (model.model_id.clone(), model.downloads));
222 }
223 }
224
225 offset = offset.saturating_add(models.len());
226 }
227
228 let discovered: Vec<DiscoveredFamily> = remote_families
231 .into_iter()
232 .filter(|(mt, _)| !local_families.contains(mt.as_str()))
233 .map(|(model_type, (top_model, downloads))| DiscoveredFamily {
234 model_type,
235 top_model,
236 downloads,
237 })
238 .collect();
239
240 Ok(discovered)
241}
242
/// Rewrites quantization spellings in a search query to one canonical form.
///
/// The query is split on whitespace. Each token that matches a known
/// spelling, ignoring ASCII case, is replaced by its canonical form
/// (`8bit`/`8-bit`/`int8` become `8-bit`, `4bit`/`4-bit`/`int4` become
/// `4-bit`, `fp8`/`float8` become `fp8`). Every other token is kept
/// unchanged. Tokens are re-joined with single spaces, so leading, trailing
/// and repeated whitespace is collapsed.
///
/// Matching is per whole token: `int8,` or `8bits` are left untouched.
#[must_use]
fn normalize_quantization_terms(query: &str) -> String {
    // (accepted spellings, canonical replacement). All spellings are lowercase
    // ASCII, which is what makes an ASCII-only case-insensitive comparison
    // sufficient below.
    const SYNONYMS: &[(&[&str], &str)] = &[
        (&["8bit", "8-bit", "int8"], "8-bit"),
        (&["4bit", "4-bit", "int4"], "4-bit"),
        (&["fp8", "float8"], "fp8"),
    ];

    // Build the result in place instead of allocating a lowercase copy and an
    // owned `String` per token plus an intermediate `Vec`.
    let mut normalized = String::with_capacity(query.len());
    for token in query.split_whitespace() {
        // Tokens are never empty, so a non-empty buffer means at least one
        // token has already been written and a separator is due.
        if !normalized.is_empty() {
            normalized.push(' ');
        }

        let replacement = SYNONYMS
            .iter()
            .find(|(variants, _)| {
                variants
                    .iter()
                    .any(|variant| variant.eq_ignore_ascii_case(token))
            })
            .map_or(token, |entry| entry.1);

        normalized.push_str(replacement);
    }

    normalized
}
273
274pub async fn search_models(
299 query: &str,
300 limit: usize,
301 library: Option<&str>,
302 pipeline: Option<&str>,
303 tag: Option<&str>,
304) -> Result<Vec<SearchResult>, FetchError> {
305 let normalized = normalize_quantization_terms(query);
306 let client = reqwest::Client::new();
307
308 let mut query_params: Vec<(&str, &str)> = vec![
310 ("search", normalized.as_str()),
311 ("sort", "downloads"),
312 ("direction", "-1"),
313 ];
314 if let Some(lib) = library {
315 query_params.push(("library", lib));
316 }
317 if let Some(pipe) = pipeline {
318 query_params.push(("pipeline_tag", pipe));
319 }
320 if let Some(t) = tag {
321 query_params.push(("filter", t));
322 }
323
324 let response = client
325 .get(HF_API_BASE)
326 .query(&query_params)
327 .query(&[("limit", limit)])
328 .send()
329 .await
330 .map_err(|e| FetchError::Http(e.to_string()))?;
331
332 if !response.status().is_success() {
333 return Err(FetchError::Http(format!(
334 "HF API returned status {}",
335 response.status()
336 )));
337 }
338
339 let models: Vec<ApiModelEntry> = response
340 .json()
341 .await
342 .map_err(|e| FetchError::Http(e.to_string()))?;
343
344 let results = models
348 .into_iter()
349 .filter(|m| {
350 if let Some(lib) = library {
351 match m.library_name {
352 Some(ref name) if name.as_str().eq_ignore_ascii_case(lib) => {}
354 _ => return false,
355 }
356 }
357 if let Some(pipe) = pipeline {
358 match m.pipeline_tag {
359 Some(ref t) if t.as_str().eq_ignore_ascii_case(pipe) => {}
361 _ => return false,
362 }
363 }
364 if let Some(t) = tag {
365 if !m.tags.iter().any(|model_tag| {
366 model_tag.as_str().eq_ignore_ascii_case(t)
368 }) {
369 return false;
370 }
371 }
372 true
373 })
374 .map(|m| SearchResult {
375 model_id: m.model_id,
376 downloads: m.downloads,
377 library_name: m.library_name,
378 pipeline_tag: m.pipeline_tag,
379 tags: m.tags,
380 })
381 .collect();
382
383 Ok(results)
384}
385
386pub async fn fetch_model_card(model_id: &str) -> Result<ModelCardMetadata, FetchError> {
399 let client = reqwest::Client::new();
400 let url = format!("{HF_API_BASE}/{model_id}");
401
402 let response = client
403 .get(url.as_str()) .send()
405 .await
406 .map_err(|e| FetchError::Http(e.to_string()))?;
407
408 if !response.status().is_success() {
409 return Err(FetchError::Http(format!(
410 "HF API returned status {} for model {model_id}",
411 response.status()
412 )));
413 }
414
415 let detail: ApiModelDetail = response
416 .json()
417 .await
418 .map_err(|e| FetchError::Http(e.to_string()))?;
419
420 let (license, languages) = if let Some(card) = detail.card_data {
421 let langs = match card.language {
422 Some(ApiLanguage::Single(s)) => vec![s],
423 Some(ApiLanguage::Multiple(v)) => v,
424 None => Vec::new(),
425 };
426 (card.license, langs)
427 } else {
428 (None, Vec::new())
429 };
430
431 let gated = match detail.gated {
432 ApiGated::Bool(false) => GateStatus::Open,
433 ApiGated::Mode(ref mode) if mode.eq_ignore_ascii_case("manual") => GateStatus::Manual,
434 ApiGated::Bool(true) | ApiGated::Mode(_) => GateStatus::Auto,
435 };
436
437 Ok(ModelCardMetadata {
438 license,
439 pipeline_tag: detail.pipeline_tag,
440 tags: detail.tags,
441 library_name: detail.library_name,
442 languages,
443 gated,
444 })
445}
446
447pub async fn fetch_readme(
462 model_id: &str,
463 revision: Option<&str>,
464 token: Option<&str>,
465) -> Result<Option<String>, FetchError> {
466 let rev = revision.unwrap_or("main");
467 let url = crate::chunked::build_download_url(model_id, rev, "README.md");
468 let client = crate::chunked::build_client(token)?;
469
470 let response = client
471 .get(url.as_str()) .send()
473 .await
474 .map_err(|e| FetchError::Http(format!("failed to fetch README for {model_id}: {e}")))?;
475
476 if response.status() == reqwest::StatusCode::NOT_FOUND {
477 return Ok(None);
478 }
479
480 if !response.status().is_success() {
481 return Err(FetchError::Http(format!(
482 "README request for {model_id} returned status {}",
483 response.status()
484 )));
485 }
486
487 let text = response
488 .text()
489 .await
490 .map_err(|e| FetchError::Http(format!("failed to read README for {model_id}: {e}")))?;
491
492 Ok(Some(text))
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every `(input, expected)` pair normalizes as expected.
    fn check(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(
                normalize_quantization_terms(input),
                expected,
                "input: {input:?}"
            );
        }
    }

    #[test]
    fn normalize_8bit_variants() {
        check(&[
            ("AWQ 8bit", "AWQ 8-bit"),
            ("AWQ 8-bit", "AWQ 8-bit"),
            ("AWQ int8", "AWQ 8-bit"),
            ("AWQ INT8", "AWQ 8-bit"),
        ]);
    }

    #[test]
    fn normalize_4bit_variants() {
        check(&[
            ("GPTQ 4bit", "GPTQ 4-bit"),
            ("GPTQ INT4", "GPTQ 4-bit"),
            ("GPTQ 4-bit", "GPTQ 4-bit"),
        ]);
    }

    #[test]
    fn normalize_fp8_variants() {
        check(&[("FP8", "fp8"), ("float8", "fp8"), ("fp8", "fp8")]);
    }

    #[test]
    fn normalize_passthrough() {
        // Tokens that are not quantization spellings must come back unchanged.
        check(&[("llama 3", "llama 3"), ("RWKV-7", "RWKV-7")]);
    }
}