1use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25use std::time::{Duration, Instant};
26use tracing::{debug, info, instrument, warn};
27
/// Kind of asset handled by the hub client.
///
/// Serialized by serde using `snake_case` variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HubAssetType {
    /// A model.
    Model,
    /// A dataset.
    Dataset,
    /// A space.
    Space,
}
40
41impl std::fmt::Display for HubAssetType {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 match self {
44 Self::Model => write!(f, "model"),
45 Self::Dataset => write!(f, "dataset"),
46 Self::Space => write!(f, "space"),
47 }
48 }
49}
50
/// Metadata describing a single hub asset (model, dataset or space).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HubAsset {
    /// Asset identifier, e.g. `meta-llama/Llama-2-7b-hf`.
    pub id: String,
    /// Which kind of asset this is.
    pub asset_type: HubAssetType,
    /// Author of the asset; `HubAsset::new` derives it from `id`.
    pub author: String,
    /// Download count.
    pub downloads: u64,
    /// Like count.
    pub likes: u64,
    /// Free-form tags attached to the asset.
    pub tags: Vec<String>,
    /// Pipeline/task tag (e.g. `text-generation`), if any.
    pub pipeline_tag: Option<String>,
    /// Library the asset is associated with (e.g. `transformers`), if any.
    pub library: Option<String>,
    /// License identifier (e.g. `apache-2.0`), if any.
    pub license: Option<String>,
    /// Last-modified marker as a string; `HubAsset::new` leaves it empty.
    /// NOTE(review): the expected timestamp format is not established in this file.
    pub last_modified: String,
    /// Card text for the asset, if loaded; `HubAsset::new` leaves it as `None`.
    pub card_content: Option<String>,
}
77
78impl HubAsset {
79 pub fn new(id: impl Into<String>, asset_type: HubAssetType) -> Self {
80 let id_str = id.into();
81 let author = id_str.split('/').next().unwrap_or("unknown").to_string();
82 Self {
83 id: id_str,
84 asset_type,
85 author,
86 downloads: 0,
87 likes: 0,
88 tags: Vec::new(),
89 pipeline_tag: None,
90 library: None,
91 license: None,
92 last_modified: String::new(),
93 card_content: None,
94 }
95 }
96
97 pub fn with_downloads(mut self, downloads: u64) -> Self {
98 self.downloads = downloads;
99 self
100 }
101
102 pub fn with_likes(mut self, likes: u64) -> Self {
103 self.likes = likes;
104 self
105 }
106
107 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
108 self.tags = tags;
109 self
110 }
111
112 pub fn with_pipeline_tag(mut self, tag: impl Into<String>) -> Self {
113 self.pipeline_tag = Some(tag.into());
114 self
115 }
116
117 pub fn with_library(mut self, library: impl Into<String>) -> Self {
118 self.library = Some(library.into());
119 self
120 }
121
122 pub fn with_license(mut self, license: impl Into<String>) -> Self {
123 self.license = Some(license.into());
124 self
125 }
126}
127
/// Filter, limit and ordering options for a hub search.
///
/// The `Debug` representation of this struct is used by `HubClient` to build
/// cache keys, so the field order and the derived `Debug` impl are kept as is.
#[derive(Debug, Clone)]
pub struct SearchFilters {
    /// Task to match against an asset's pipeline tag.
    pub task: Option<String>,
    /// Library name to match.
    pub library: Option<String>,
    /// Author to match.
    pub author: Option<String>,
    /// License identifier to match.
    pub license: Option<String>,
    /// Minimum download count (inclusive).
    pub min_downloads: Option<u64>,
    /// Minimum like count (inclusive).
    pub min_likes: Option<u64>,
    /// Free-text query.
    pub query: Option<String>,
    /// Maximum number of results; result lists are truncated to this length.
    pub limit: usize,
    /// Field to sort by.
    pub sort: Option<String>,
    /// Sort direction; set together with `sort` by [`SearchFilters::with_sort`].
    pub sort_direction: Option<String>,
}

/// Implemented by hand rather than derived so that `default()` and `new()`
/// agree: a derived impl would set `limit` to 0, and a zero limit truncates
/// every result list to nothing.
impl Default for SearchFilters {
    fn default() -> Self {
        Self {
            task: None,
            library: None,
            author: None,
            license: None,
            min_downloads: None,
            min_likes: None,
            query: None,
            limit: Self::DEFAULT_LIMIT,
            sort: None,
            sort_direction: None,
        }
    }
}

impl SearchFilters {
    /// Result limit used when the caller does not set one.
    const DEFAULT_LIMIT: usize = 20;

    /// Creates an empty filter set with a limit of 20 results.
    pub fn new() -> Self {
        Self::default()
    }

    /// Restricts results to the given task.
    pub fn with_task(mut self, task: impl Into<String>) -> Self {
        self.task = Some(task.into());
        self
    }

    /// Restricts results to the given library.
    pub fn with_library(mut self, library: impl Into<String>) -> Self {
        self.library = Some(library.into());
        self
    }

    /// Restricts results to the given author.
    pub fn with_author(mut self, author: impl Into<String>) -> Self {
        self.author = Some(author.into());
        self
    }

    /// Restricts results to the given license.
    pub fn with_license(mut self, license: impl Into<String>) -> Self {
        self.license = Some(license.into());
        self
    }

    /// Requires at least `min` downloads.
    pub fn with_min_downloads(mut self, min: u64) -> Self {
        self.min_downloads = Some(min);
        self
    }

    /// Requires at least `min` likes.
    pub fn with_min_likes(mut self, min: u64) -> Self {
        self.min_likes = Some(min);
        self
    }

    /// Sets the free-text query.
    pub fn with_query(mut self, query: impl Into<String>) -> Self {
        self.query = Some(query.into());
        self
    }

    /// Sets the maximum number of results.
    pub fn with_limit(mut self, limit: usize) -> Self {
        self.limit = limit;
        self
    }

    /// Sets the sort field and direction together.
    pub fn with_sort(mut self, field: impl Into<String>, direction: impl Into<String>) -> Self {
        self.sort = Some(field.into());
        self.sort_direction = Some(direction.into());
        self
    }
}
208
/// A cached value together with the instant it was stored and its time-to-live.
#[derive(Debug, Clone)]
struct CacheEntry<T> {
    /// The cached payload.
    data: T,
    /// When the entry was created.
    created: Instant,
    /// How long after `created` the entry stays valid.
    ttl: Duration,
}

impl<T> CacheEntry<T> {
    /// Wraps `data`, stamping it with the current instant.
    fn new(data: T, ttl: Duration) -> Self {
        let created = Instant::now();
        Self { data, created, ttl }
    }

    /// Returns `true` once strictly more than `ttl` has passed since creation.
    fn is_expired(&self) -> bool {
        let age = self.created.elapsed();
        age > self.ttl
    }
}
230
231#[derive(Debug, Default)]
233pub struct ResponseCache {
234 search_cache: HashMap<String, CacheEntry<Vec<HubAsset>>>,
235 asset_cache: HashMap<String, CacheEntry<HubAsset>>,
236 ttl: Duration,
237}
238
239impl ResponseCache {
240 pub fn new(ttl: Duration) -> Self {
241 Self { search_cache: HashMap::new(), asset_cache: HashMap::new(), ttl }
242 }
243
244 pub fn default_ttl() -> Self {
246 Self::new(Duration::from_secs(15 * 60))
247 }
248
249 pub fn cache_search(&mut self, key: &str, results: Vec<HubAsset>) {
251 self.search_cache.insert(key.to_string(), CacheEntry::new(results, self.ttl));
252 }
253
254 pub fn get_search(&self, key: &str) -> Option<&Vec<HubAsset>> {
256 self.search_cache.get(key).and_then(|entry| {
257 if entry.is_expired() {
258 None
259 } else {
260 Some(&entry.data)
261 }
262 })
263 }
264
265 pub fn cache_asset(&mut self, id: &str, asset: HubAsset) {
267 self.asset_cache.insert(id.to_string(), CacheEntry::new(asset, self.ttl));
268 }
269
270 pub fn get_asset(&self, id: &str) -> Option<&HubAsset> {
272 self.asset_cache.get(id).and_then(
273 |entry| {
274 if entry.is_expired() {
275 None
276 } else {
277 Some(&entry.data)
278 }
279 },
280 )
281 }
282
283 pub fn clear_expired(&mut self) {
285 self.search_cache.retain(|_, entry| !entry.is_expired());
286 self.asset_cache.retain(|_, entry| !entry.is_expired());
287 }
288
289 pub fn clear(&mut self) {
291 self.search_cache.clear();
292 self.asset_cache.clear();
293 }
294
295 pub fn stats(&self) -> CacheStats {
297 CacheStats {
298 search_entries: self.search_cache.len(),
299 asset_entries: self.asset_cache.len(),
300 ttl_secs: self.ttl.as_secs(),
301 }
302 }
303}
304
/// Snapshot of the response cache's size and configuration.
#[derive(Debug, Clone, Serialize)]
pub struct CacheStats {
    /// Number of stored search-result entries.
    pub search_entries: usize,
    /// Number of stored single-asset entries.
    pub asset_entries: usize,
    /// Cache TTL in whole seconds.
    pub ttl_secs: u64,
}
312
/// Client for hub search and metadata lookups, backed by a response cache.
#[derive(Debug)]
pub struct HubClient {
    /// API base URL.
    /// NOTE(review): stored but not read by any code visible in this file — confirm intended use.
    base_url: String,
    /// Cache consulted before, and filled after, every search or lookup.
    cache: ResponseCache,
    /// When `true`, a cache miss fails with `HubError::OfflineMode` instead of fetching.
    offline_mode: bool,
}
324
325impl HubClient {
326 pub fn new() -> Self {
328 Self {
329 base_url: "https://huggingface.co/api".to_string(),
330 cache: ResponseCache::default_ttl(),
331 offline_mode: false,
332 }
333 }
334
335 pub fn with_base_url(base_url: impl Into<String>) -> Self {
337 Self { base_url: base_url.into(), cache: ResponseCache::default_ttl(), offline_mode: false }
338 }
339
340 pub fn offline(mut self) -> Self {
342 self.offline_mode = true;
343 self
344 }
345
346 pub fn cache_stats(&self) -> CacheStats {
348 self.cache.stats()
349 }
350
351 pub fn clear_cache(&mut self) {
353 self.cache.clear();
354 }
355
356 #[instrument(name = "hf.search.models", skip(self), fields(
362 task = filters.task.as_deref(),
363 limit = filters.limit,
364 cache_hit = tracing::field::Empty,
365 result_count = tracing::field::Empty
366 ))]
367 pub fn search_models(&mut self, filters: &SearchFilters) -> Result<Vec<HubAsset>, HubError> {
368 let cache_key = format!("models:{:?}", filters);
369
370 if let Some(cached) = self.cache.get_search(&cache_key) {
372 debug!(cache_hit = true, "Model search cache hit");
373 tracing::Span::current().record("cache_hit", true);
374 tracing::Span::current().record("result_count", cached.len());
375 return Ok(cached.clone());
376 }
377
378 if self.offline_mode {
379 warn!("Model search attempted in offline mode");
380 return Err(HubError::OfflineMode);
381 }
382
383 let results = self.mock_model_search(filters);
386 self.cache.cache_search(&cache_key, results.clone());
387 info!(result_count = results.len(), "Model search completed");
388 tracing::Span::current().record("cache_hit", false);
389 tracing::Span::current().record("result_count", results.len());
390 Ok(results)
391 }
392
393 #[instrument(name = "hf.search.datasets", skip(self), fields(
395 limit = filters.limit,
396 cache_hit = tracing::field::Empty,
397 result_count = tracing::field::Empty
398 ))]
399 pub fn search_datasets(&mut self, filters: &SearchFilters) -> Result<Vec<HubAsset>, HubError> {
400 let cache_key = format!("datasets:{:?}", filters);
401
402 if let Some(cached) = self.cache.get_search(&cache_key) {
403 debug!(cache_hit = true, "Dataset search cache hit");
404 tracing::Span::current().record("cache_hit", true);
405 tracing::Span::current().record("result_count", cached.len());
406 return Ok(cached.clone());
407 }
408
409 if self.offline_mode {
410 warn!("Dataset search attempted in offline mode");
411 return Err(HubError::OfflineMode);
412 }
413
414 let results = self.mock_dataset_search(filters);
415 self.cache.cache_search(&cache_key, results.clone());
416 info!(result_count = results.len(), "Dataset search completed");
417 tracing::Span::current().record("cache_hit", false);
418 tracing::Span::current().record("result_count", results.len());
419 Ok(results)
420 }
421
422 #[instrument(name = "hf.search.spaces", skip(self), fields(
424 limit = filters.limit,
425 cache_hit = tracing::field::Empty,
426 result_count = tracing::field::Empty
427 ))]
428 pub fn search_spaces(&mut self, filters: &SearchFilters) -> Result<Vec<HubAsset>, HubError> {
429 let cache_key = format!("spaces:{:?}", filters);
430
431 if let Some(cached) = self.cache.get_search(&cache_key) {
432 debug!(cache_hit = true, "Space search cache hit");
433 tracing::Span::current().record("cache_hit", true);
434 tracing::Span::current().record("result_count", cached.len());
435 return Ok(cached.clone());
436 }
437
438 if self.offline_mode {
439 warn!("Space search attempted in offline mode");
440 return Err(HubError::OfflineMode);
441 }
442
443 let results = self.mock_space_search(filters);
444 self.cache.cache_search(&cache_key, results.clone());
445 info!(result_count = results.len(), "Space search completed");
446 tracing::Span::current().record("cache_hit", false);
447 tracing::Span::current().record("result_count", results.len());
448 Ok(results)
449 }
450
451 #[instrument(name = "hf.get.model", skip(self), fields(
457 asset_id = id,
458 cache_hit = tracing::field::Empty
459 ))]
460 pub fn get_model(&mut self, id: &str) -> Result<HubAsset, HubError> {
461 let cache_key = format!("model:{}", id);
462
463 if let Some(cached) = self.cache.get_asset(&cache_key) {
464 debug!(cache_hit = true, "Model metadata cache hit");
465 tracing::Span::current().record("cache_hit", true);
466 return Ok(cached.clone());
467 }
468
469 if self.offline_mode {
470 warn!(asset_id = id, "Model get attempted in offline mode");
471 return Err(HubError::OfflineMode);
472 }
473
474 let asset = self.mock_get_model(id)?;
475 self.cache.cache_asset(&cache_key, asset.clone());
476 info!(asset_id = id, "Model metadata retrieved");
477 tracing::Span::current().record("cache_hit", false);
478 Ok(asset)
479 }
480
481 #[instrument(name = "hf.get.dataset", skip(self), fields(
483 asset_id = id,
484 cache_hit = tracing::field::Empty
485 ))]
486 pub fn get_dataset(&mut self, id: &str) -> Result<HubAsset, HubError> {
487 let cache_key = format!("dataset:{}", id);
488
489 if let Some(cached) = self.cache.get_asset(&cache_key) {
490 debug!(cache_hit = true, "Dataset metadata cache hit");
491 tracing::Span::current().record("cache_hit", true);
492 return Ok(cached.clone());
493 }
494
495 if self.offline_mode {
496 warn!(asset_id = id, "Dataset get attempted in offline mode");
497 return Err(HubError::OfflineMode);
498 }
499
500 let asset = self.mock_get_dataset(id)?;
501 self.cache.cache_asset(&cache_key, asset.clone());
502 info!(asset_id = id, "Dataset metadata retrieved");
503 tracing::Span::current().record("cache_hit", false);
504 Ok(asset)
505 }
506
507 #[instrument(name = "hf.get.space", skip(self), fields(
509 asset_id = id,
510 cache_hit = tracing::field::Empty
511 ))]
512 pub fn get_space(&mut self, id: &str) -> Result<HubAsset, HubError> {
513 let cache_key = format!("space:{}", id);
514
515 if let Some(cached) = self.cache.get_asset(&cache_key) {
516 debug!(cache_hit = true, "Space metadata cache hit");
517 tracing::Span::current().record("cache_hit", true);
518 return Ok(cached.clone());
519 }
520
521 if self.offline_mode {
522 warn!(asset_id = id, "Space get attempted in offline mode");
523 return Err(HubError::OfflineMode);
524 }
525
526 let asset = self.mock_get_space(id)?;
527 self.cache.cache_asset(&cache_key, asset.clone());
528 info!(asset_id = id, "Space metadata retrieved");
529 tracing::Span::current().record("cache_hit", false);
530 Ok(asset)
531 }
532
533 fn mock_model_search(&self, filters: &SearchFilters) -> Vec<HubAsset> {
538 let mut results = vec![
539 HubAsset::new("meta-llama/Llama-2-7b-hf", HubAssetType::Model)
540 .with_downloads(5_000_000)
541 .with_likes(10_000)
542 .with_pipeline_tag("text-generation")
543 .with_library("transformers")
544 .with_license("llama2"),
545 HubAsset::new("openai/whisper-large-v3", HubAssetType::Model)
546 .with_downloads(2_000_000)
547 .with_likes(5_000)
548 .with_pipeline_tag("automatic-speech-recognition")
549 .with_library("transformers")
550 .with_license("apache-2.0"),
551 HubAsset::new("stabilityai/stable-diffusion-xl-base-1.0", HubAssetType::Model)
552 .with_downloads(3_000_000)
553 .with_likes(8_000)
554 .with_pipeline_tag("text-to-image")
555 .with_library("diffusers")
556 .with_license("openrail++"),
557 HubAsset::new("sentence-transformers/all-MiniLM-L6-v2", HubAssetType::Model)
558 .with_downloads(10_000_000)
559 .with_likes(2_000)
560 .with_pipeline_tag("sentence-similarity")
561 .with_library("sentence-transformers")
562 .with_license("apache-2.0"),
563 HubAsset::new("bert-base-uncased", HubAssetType::Model)
564 .with_downloads(50_000_000)
565 .with_likes(15_000)
566 .with_pipeline_tag("fill-mask")
567 .with_library("transformers")
568 .with_license("apache-2.0"),
569 ];
570
571 if let Some(ref task) = filters.task {
573 results.retain(|m| m.pipeline_tag.as_ref().is_some_and(|t| t == task));
574 }
575 if let Some(ref library) = filters.library {
576 results.retain(|m| m.library.as_ref().is_some_and(|l| l == library));
577 }
578 if let Some(min) = filters.min_downloads {
579 results.retain(|m| m.downloads >= min);
580 }
581 if let Some(min) = filters.min_likes {
582 results.retain(|m| m.likes >= min);
583 }
584
585 results.truncate(filters.limit);
586 results
587 }
588
589 fn mock_dataset_search(&self, filters: &SearchFilters) -> Vec<HubAsset> {
590 let mut results = vec![
591 HubAsset::new("squad", HubAssetType::Dataset)
592 .with_downloads(5_000_000)
593 .with_likes(1_000)
594 .with_tags(vec!["question-answering".into(), "english".into()]),
595 HubAsset::new("imdb", HubAssetType::Dataset)
596 .with_downloads(3_000_000)
597 .with_likes(500)
598 .with_tags(vec!["text-classification".into(), "sentiment".into()]),
599 HubAsset::new("wikipedia", HubAssetType::Dataset)
600 .with_downloads(10_000_000)
601 .with_likes(2_000)
602 .with_tags(vec!["text".into(), "multilingual".into()]),
603 ];
604
605 if let Some(min) = filters.min_downloads {
606 results.retain(|d| d.downloads >= min);
607 }
608
609 results.truncate(filters.limit);
610 results
611 }
612
613 fn mock_space_search(&self, filters: &SearchFilters) -> Vec<HubAsset> {
614 let mut results = vec![
615 HubAsset::new("gradio/chatbot", HubAssetType::Space)
616 .with_downloads(100_000)
617 .with_likes(500)
618 .with_tags(vec!["gradio".into(), "chat".into()]),
619 HubAsset::new("stabilityai/stable-diffusion", HubAssetType::Space)
620 .with_downloads(500_000)
621 .with_likes(2_000)
622 .with_tags(vec!["gradio".into(), "image-generation".into()]),
623 ];
624
625 if let Some(min) = filters.min_downloads {
626 results.retain(|s| s.downloads >= min);
627 }
628
629 results.truncate(filters.limit);
630 results
631 }
632
633 fn mock_get_model(&self, id: &str) -> Result<HubAsset, HubError> {
634 match id {
636 "meta-llama/Llama-2-7b-hf" => Ok(HubAsset::new(id, HubAssetType::Model)
637 .with_downloads(5_000_000)
638 .with_likes(10_000)
639 .with_pipeline_tag("text-generation")
640 .with_library("transformers")
641 .with_license("llama2")),
642 "bert-base-uncased" => Ok(HubAsset::new(id, HubAssetType::Model)
643 .with_downloads(50_000_000)
644 .with_likes(15_000)
645 .with_pipeline_tag("fill-mask")
646 .with_library("transformers")
647 .with_license("apache-2.0")),
648 _ => Err(HubError::NotFound(id.to_string())),
649 }
650 }
651
652 fn mock_get_dataset(&self, id: &str) -> Result<HubAsset, HubError> {
653 match id {
654 "squad" => Ok(HubAsset::new(id, HubAssetType::Dataset)
655 .with_downloads(5_000_000)
656 .with_likes(1_000)
657 .with_tags(vec!["question-answering".into()])),
658 _ => Err(HubError::NotFound(id.to_string())),
659 }
660 }
661
662 fn mock_get_space(&self, id: &str) -> Result<HubAsset, HubError> {
663 match id {
664 "gradio/chatbot" => Ok(HubAsset::new(id, HubAssetType::Space)
665 .with_downloads(100_000)
666 .with_likes(500)
667 .with_tags(vec!["gradio".into(), "chat".into()])),
668 _ => Err(HubError::NotFound(id.to_string())),
669 }
670 }
671}
672
673impl Default for HubClient {
674 fn default() -> Self {
675 Self::new()
676 }
677}
678
/// Errors reported by the hub client.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HubError {
    /// No asset exists with the contained id.
    NotFound(String),
    /// The request was rate limited; `retry_after` is the suggested wait in
    /// seconds, when known.
    RateLimited { retry_after: Option<u64> },
    /// A network failure, with a description.
    NetworkError(String),
    /// The client is offline and had no usable cached data.
    OfflineMode,
    /// A response could not be interpreted, with a description.
    InvalidResponse(String),
}

/// Human-readable message for each error variant.
impl std::fmt::Display for HubError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotFound(id) => write!(f, "Asset not found: {}", id),
            // The two rate-limit shapes are matched directly instead of
            // branching on the option inside a single arm.
            Self::RateLimited { retry_after: Some(secs) } => {
                write!(f, "Rate limited, retry after {} seconds", secs)
            }
            Self::RateLimited { retry_after: None } => f.write_str("Rate limited"),
            Self::NetworkError(msg) => write!(f, "Network error: {}", msg),
            Self::OfflineMode => f.write_str("Offline mode: no cached data available"),
            Self::InvalidResponse(msg) => write!(f, "Invalid response: {}", msg),
        }
    }
}

impl std::error::Error for HubError {}
717
// Unit tests are kept in the sibling file `hub_client_tests.rs` and are only
// compiled for test builds.
// NOTE(review): `non_snake_case` is presumably allowed for descriptive test
// function names — confirm against the test file.
#[cfg(test)]
#[allow(non_snake_case)]
#[path = "hub_client_tests.rs"]
mod tests;