Skip to main content

batuta/hf/
hub_client.rs

1//! HuggingFace Hub API Client
2//!
3//! Implements HF-QUERY-002 (Hub Search) and HF-QUERY-003 (Asset Metadata)
4//!
5//! Provides live queries to HuggingFace Hub API:
6//! - Model search with filters
7//! - Dataset search with filters
8//! - Space search with filters
9//! - Asset metadata retrieval
10//!
11//! ## Observability (HF-OBS-001, HF-OBS-002)
12//!
13//! All Hub operations are instrumented with tracing spans:
14//! - `hf.search.models` - Model search operations
15//! - `hf.search.datasets` - Dataset search operations
16//! - `hf.search.spaces` - Space search operations
17//! - `hf.get.model` - Model metadata retrieval
18//! - `hf.get.dataset` - Dataset metadata retrieval
19//! - `hf.get.space` - Space metadata retrieval
20
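// These types are exercised by the unit tests (hub_client_tests.rs) but the client is served
// from mock data until live Hub API integration lands (HUB-API milestone).
// NOTE(review): no `#[allow(dead_code)]` is applied in this file — confirm whether the parent
// module still sets it.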
23use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25use std::time::{Duration, Instant};
26use tracing::{debug, info, instrument, warn};
27
28// ============================================================================
29// HF-QUERY-002: Hub Asset Types
30// ============================================================================
31
32/// Type of Hub asset
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
34#[serde(rename_all = "snake_case")]
35pub enum HubAssetType {
36    Model,
37    Dataset,
38    Space,
39}
40
41impl std::fmt::Display for HubAssetType {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        match self {
44            Self::Model => write!(f, "model"),
45            Self::Dataset => write!(f, "dataset"),
46            Self::Space => write!(f, "space"),
47        }
48    }
49}
50
51/// Hub asset metadata (model, dataset, or space)
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct HubAsset {
54    /// Asset ID (e.g., "meta-llama/Llama-2-7b-hf")
55    pub id: String,
56    /// Asset type
57    pub asset_type: HubAssetType,
58    /// Author/organization
59    pub author: String,
60    /// Downloads count
61    pub downloads: u64,
62    /// Likes count
63    pub likes: u64,
64    /// Tags
65    pub tags: Vec<String>,
66    /// Pipeline tag (task) - for models
67    pub pipeline_tag: Option<String>,
68    /// Library (transformers, diffusers, etc.) - for models
69    pub library: Option<String>,
70    /// License
71    pub license: Option<String>,
72    /// Last modified timestamp
73    pub last_modified: String,
74    /// Model card/README content (optional, fetched separately)
75    pub card_content: Option<String>,
76}
77
78impl HubAsset {
79    pub fn new(id: impl Into<String>, asset_type: HubAssetType) -> Self {
80        let id_str = id.into();
81        let author = id_str.split('/').next().unwrap_or("unknown").to_string();
82        Self {
83            id: id_str,
84            asset_type,
85            author,
86            downloads: 0,
87            likes: 0,
88            tags: Vec::new(),
89            pipeline_tag: None,
90            library: None,
91            license: None,
92            last_modified: String::new(),
93            card_content: None,
94        }
95    }
96
97    pub fn with_downloads(mut self, downloads: u64) -> Self {
98        self.downloads = downloads;
99        self
100    }
101
102    pub fn with_likes(mut self, likes: u64) -> Self {
103        self.likes = likes;
104        self
105    }
106
107    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
108        self.tags = tags;
109        self
110    }
111
112    pub fn with_pipeline_tag(mut self, tag: impl Into<String>) -> Self {
113        self.pipeline_tag = Some(tag.into());
114        self
115    }
116
117    pub fn with_library(mut self, library: impl Into<String>) -> Self {
118        self.library = Some(library.into());
119        self
120    }
121
122    pub fn with_license(mut self, license: impl Into<String>) -> Self {
123        self.license = Some(license.into());
124        self
125    }
126}
127
128// ============================================================================
129// HF-QUERY-002: Search Filters
130// ============================================================================
131
/// Search filters for Hub queries.
///
/// Every filter is optional. `limit` caps how many results a search returns.
/// [`Default`] is identical to `SearchFilters::new()`: no filters and a limit of 20.
#[derive(Debug, Clone)]
pub struct SearchFilters {
    /// Filter by task (pipeline_tag)
    pub task: Option<String>,
    /// Filter by library
    pub library: Option<String>,
    /// Filter by author/organization
    pub author: Option<String>,
    /// Filter by license
    pub license: Option<String>,
    /// Minimum downloads threshold
    pub min_downloads: Option<u64>,
    /// Minimum likes threshold
    pub min_likes: Option<u64>,
    /// Search query text
    pub query: Option<String>,
    /// Maximum results to return
    pub limit: usize,
    /// Sort field
    pub sort: Option<String>,
    /// Sort direction (asc/desc)
    pub sort_direction: Option<String>,
}

impl Default for SearchFilters {
    /// No filters and a limit of 20 results, the same as `SearchFilters::new()`.
    ///
    /// A derived `Default` would set `limit` to 0. Because search results are truncated
    /// to `limit`, such filters would silently return nothing.
    fn default() -> Self {
        Self {
            task: None,
            library: None,
            author: None,
            license: None,
            min_downloads: None,
            min_likes: None,
            query: None,
            limit: 20,
            sort: None,
            sort_direction: None,
        }
    }
}
156
157impl SearchFilters {
158    pub fn new() -> Self {
159        Self { limit: 20, ..Default::default() }
160    }
161
162    pub fn with_task(mut self, task: impl Into<String>) -> Self {
163        self.task = Some(task.into());
164        self
165    }
166
167    pub fn with_library(mut self, library: impl Into<String>) -> Self {
168        self.library = Some(library.into());
169        self
170    }
171
172    pub fn with_author(mut self, author: impl Into<String>) -> Self {
173        self.author = Some(author.into());
174        self
175    }
176
177    pub fn with_license(mut self, license: impl Into<String>) -> Self {
178        self.license = Some(license.into());
179        self
180    }
181
182    pub fn with_min_downloads(mut self, min: u64) -> Self {
183        self.min_downloads = Some(min);
184        self
185    }
186
187    pub fn with_min_likes(mut self, min: u64) -> Self {
188        self.min_likes = Some(min);
189        self
190    }
191
192    pub fn with_query(mut self, query: impl Into<String>) -> Self {
193        self.query = Some(query.into());
194        self
195    }
196
197    pub fn with_limit(mut self, limit: usize) -> Self {
198        self.limit = limit;
199        self
200    }
201
202    pub fn with_sort(mut self, field: impl Into<String>, direction: impl Into<String>) -> Self {
203        self.sort = Some(field.into());
204        self.sort_direction = Some(direction.into());
205        self
206    }
207}
208
209// ============================================================================
210// HF-QUERY-002/003: Response Cache
211// ============================================================================
212
/// A cached value together with its insertion time and lifetime.
#[derive(Clone, Debug)]
struct CacheEntry<T> {
    /// The cached payload.
    data: T,
    /// When the entry was inserted.
    created: Instant,
    /// How long the entry stays valid after `created`.
    ttl: Duration,
}
220
221impl<T> CacheEntry<T> {
222    fn new(data: T, ttl: Duration) -> Self {
223        Self { data, created: Instant::now(), ttl }
224    }
225
226    fn is_expired(&self) -> bool {
227        self.created.elapsed() > self.ttl
228    }
229}
230
231/// Response cache for Hub queries
232#[derive(Debug, Default)]
233pub struct ResponseCache {
234    search_cache: HashMap<String, CacheEntry<Vec<HubAsset>>>,
235    asset_cache: HashMap<String, CacheEntry<HubAsset>>,
236    ttl: Duration,
237}
238
239impl ResponseCache {
240    pub fn new(ttl: Duration) -> Self {
241        Self { search_cache: HashMap::new(), asset_cache: HashMap::new(), ttl }
242    }
243
244    /// Default cache with 15 minute TTL
245    pub fn default_ttl() -> Self {
246        Self::new(Duration::from_secs(15 * 60))
247    }
248
249    /// Cache a search result
250    pub fn cache_search(&mut self, key: &str, results: Vec<HubAsset>) {
251        self.search_cache.insert(key.to_string(), CacheEntry::new(results, self.ttl));
252    }
253
254    /// Get cached search result
255    pub fn get_search(&self, key: &str) -> Option<&Vec<HubAsset>> {
256        self.search_cache.get(key).and_then(|entry| {
257            if entry.is_expired() {
258                None
259            } else {
260                Some(&entry.data)
261            }
262        })
263    }
264
265    /// Cache an asset
266    pub fn cache_asset(&mut self, id: &str, asset: HubAsset) {
267        self.asset_cache.insert(id.to_string(), CacheEntry::new(asset, self.ttl));
268    }
269
270    /// Get cached asset
271    pub fn get_asset(&self, id: &str) -> Option<&HubAsset> {
272        self.asset_cache.get(id).and_then(
273            |entry| {
274                if entry.is_expired() {
275                    None
276                } else {
277                    Some(&entry.data)
278                }
279            },
280        )
281    }
282
283    /// Clear expired entries
284    pub fn clear_expired(&mut self) {
285        self.search_cache.retain(|_, entry| !entry.is_expired());
286        self.asset_cache.retain(|_, entry| !entry.is_expired());
287    }
288
289    /// Clear all cache
290    pub fn clear(&mut self) {
291        self.search_cache.clear();
292        self.asset_cache.clear();
293    }
294
295    /// Get cache statistics
296    pub fn stats(&self) -> CacheStats {
297        CacheStats {
298            search_entries: self.search_cache.len(),
299            asset_entries: self.asset_cache.len(),
300            ttl_secs: self.ttl.as_secs(),
301        }
302    }
303}
304
305/// Cache statistics
306#[derive(Debug, Clone, Serialize)]
307pub struct CacheStats {
308    pub search_entries: usize,
309    pub asset_entries: usize,
310    pub ttl_secs: u64,
311}
312
313// ============================================================================
314// HF-QUERY-002/003: Hub Client
315// ============================================================================
316
317/// HuggingFace Hub API client
318#[derive(Debug)]
319pub struct HubClient {
320    base_url: String,
321    cache: ResponseCache,
322    offline_mode: bool,
323}
324
325impl HubClient {
326    /// Create new client with default settings
327    pub fn new() -> Self {
328        Self {
329            base_url: "https://huggingface.co/api".to_string(),
330            cache: ResponseCache::default_ttl(),
331            offline_mode: false,
332        }
333    }
334
335    /// Create client with custom base URL (for testing)
336    pub fn with_base_url(base_url: impl Into<String>) -> Self {
337        Self { base_url: base_url.into(), cache: ResponseCache::default_ttl(), offline_mode: false }
338    }
339
340    /// Enable offline mode (only return cached data)
341    pub fn offline(mut self) -> Self {
342        self.offline_mode = true;
343        self
344    }
345
346    /// Get cache statistics
347    pub fn cache_stats(&self) -> CacheStats {
348        self.cache.stats()
349    }
350
351    /// Clear cache
352    pub fn clear_cache(&mut self) {
353        self.cache.clear();
354    }
355
356    // ========================================================================
357    // HF-QUERY-002: Search Methods (HF-OBS-001: Instrumented with tracing)
358    // ========================================================================
359
360    /// Search models on HuggingFace Hub
361    #[instrument(name = "hf.search.models", skip(self), fields(
362        task = filters.task.as_deref(),
363        limit = filters.limit,
364        cache_hit = tracing::field::Empty,
365        result_count = tracing::field::Empty
366    ))]
367    pub fn search_models(&mut self, filters: &SearchFilters) -> Result<Vec<HubAsset>, HubError> {
368        let cache_key = format!("models:{:?}", filters);
369
370        // Check cache first
371        if let Some(cached) = self.cache.get_search(&cache_key) {
372            debug!(cache_hit = true, "Model search cache hit");
373            tracing::Span::current().record("cache_hit", true);
374            tracing::Span::current().record("result_count", cached.len());
375            return Ok(cached.clone());
376        }
377
378        if self.offline_mode {
379            warn!("Model search attempted in offline mode");
380            return Err(HubError::OfflineMode);
381        }
382
383        // In a real implementation, this would make an HTTP request
384        // For now, return mock data for testing
385        let results = self.mock_model_search(filters);
386        self.cache.cache_search(&cache_key, results.clone());
387        info!(result_count = results.len(), "Model search completed");
388        tracing::Span::current().record("cache_hit", false);
389        tracing::Span::current().record("result_count", results.len());
390        Ok(results)
391    }
392
393    /// Search datasets on HuggingFace Hub
394    #[instrument(name = "hf.search.datasets", skip(self), fields(
395        limit = filters.limit,
396        cache_hit = tracing::field::Empty,
397        result_count = tracing::field::Empty
398    ))]
399    pub fn search_datasets(&mut self, filters: &SearchFilters) -> Result<Vec<HubAsset>, HubError> {
400        let cache_key = format!("datasets:{:?}", filters);
401
402        if let Some(cached) = self.cache.get_search(&cache_key) {
403            debug!(cache_hit = true, "Dataset search cache hit");
404            tracing::Span::current().record("cache_hit", true);
405            tracing::Span::current().record("result_count", cached.len());
406            return Ok(cached.clone());
407        }
408
409        if self.offline_mode {
410            warn!("Dataset search attempted in offline mode");
411            return Err(HubError::OfflineMode);
412        }
413
414        let results = self.mock_dataset_search(filters);
415        self.cache.cache_search(&cache_key, results.clone());
416        info!(result_count = results.len(), "Dataset search completed");
417        tracing::Span::current().record("cache_hit", false);
418        tracing::Span::current().record("result_count", results.len());
419        Ok(results)
420    }
421
422    /// Search spaces on HuggingFace Hub
423    #[instrument(name = "hf.search.spaces", skip(self), fields(
424        limit = filters.limit,
425        cache_hit = tracing::field::Empty,
426        result_count = tracing::field::Empty
427    ))]
428    pub fn search_spaces(&mut self, filters: &SearchFilters) -> Result<Vec<HubAsset>, HubError> {
429        let cache_key = format!("spaces:{:?}", filters);
430
431        if let Some(cached) = self.cache.get_search(&cache_key) {
432            debug!(cache_hit = true, "Space search cache hit");
433            tracing::Span::current().record("cache_hit", true);
434            tracing::Span::current().record("result_count", cached.len());
435            return Ok(cached.clone());
436        }
437
438        if self.offline_mode {
439            warn!("Space search attempted in offline mode");
440            return Err(HubError::OfflineMode);
441        }
442
443        let results = self.mock_space_search(filters);
444        self.cache.cache_search(&cache_key, results.clone());
445        info!(result_count = results.len(), "Space search completed");
446        tracing::Span::current().record("cache_hit", false);
447        tracing::Span::current().record("result_count", results.len());
448        Ok(results)
449    }
450
451    // ========================================================================
452    // HF-QUERY-003: Asset Metadata Methods (HF-OBS-002: Instrumented with tracing)
453    // ========================================================================
454
455    /// Get model metadata
456    #[instrument(name = "hf.get.model", skip(self), fields(
457        asset_id = id,
458        cache_hit = tracing::field::Empty
459    ))]
460    pub fn get_model(&mut self, id: &str) -> Result<HubAsset, HubError> {
461        let cache_key = format!("model:{}", id);
462
463        if let Some(cached) = self.cache.get_asset(&cache_key) {
464            debug!(cache_hit = true, "Model metadata cache hit");
465            tracing::Span::current().record("cache_hit", true);
466            return Ok(cached.clone());
467        }
468
469        if self.offline_mode {
470            warn!(asset_id = id, "Model get attempted in offline mode");
471            return Err(HubError::OfflineMode);
472        }
473
474        let asset = self.mock_get_model(id)?;
475        self.cache.cache_asset(&cache_key, asset.clone());
476        info!(asset_id = id, "Model metadata retrieved");
477        tracing::Span::current().record("cache_hit", false);
478        Ok(asset)
479    }
480
481    /// Get dataset metadata
482    #[instrument(name = "hf.get.dataset", skip(self), fields(
483        asset_id = id,
484        cache_hit = tracing::field::Empty
485    ))]
486    pub fn get_dataset(&mut self, id: &str) -> Result<HubAsset, HubError> {
487        let cache_key = format!("dataset:{}", id);
488
489        if let Some(cached) = self.cache.get_asset(&cache_key) {
490            debug!(cache_hit = true, "Dataset metadata cache hit");
491            tracing::Span::current().record("cache_hit", true);
492            return Ok(cached.clone());
493        }
494
495        if self.offline_mode {
496            warn!(asset_id = id, "Dataset get attempted in offline mode");
497            return Err(HubError::OfflineMode);
498        }
499
500        let asset = self.mock_get_dataset(id)?;
501        self.cache.cache_asset(&cache_key, asset.clone());
502        info!(asset_id = id, "Dataset metadata retrieved");
503        tracing::Span::current().record("cache_hit", false);
504        Ok(asset)
505    }
506
507    /// Get space metadata
508    #[instrument(name = "hf.get.space", skip(self), fields(
509        asset_id = id,
510        cache_hit = tracing::field::Empty
511    ))]
512    pub fn get_space(&mut self, id: &str) -> Result<HubAsset, HubError> {
513        let cache_key = format!("space:{}", id);
514
515        if let Some(cached) = self.cache.get_asset(&cache_key) {
516            debug!(cache_hit = true, "Space metadata cache hit");
517            tracing::Span::current().record("cache_hit", true);
518            return Ok(cached.clone());
519        }
520
521        if self.offline_mode {
522            warn!(asset_id = id, "Space get attempted in offline mode");
523            return Err(HubError::OfflineMode);
524        }
525
526        let asset = self.mock_get_space(id)?;
527        self.cache.cache_asset(&cache_key, asset.clone());
528        info!(asset_id = id, "Space metadata retrieved");
529        tracing::Span::current().record("cache_hit", false);
530        Ok(asset)
531    }
532
533    // ========================================================================
534    // Mock implementations (replace with real API calls)
535    // ========================================================================
536
537    fn mock_model_search(&self, filters: &SearchFilters) -> Vec<HubAsset> {
538        let mut results = vec![
539            HubAsset::new("meta-llama/Llama-2-7b-hf", HubAssetType::Model)
540                .with_downloads(5_000_000)
541                .with_likes(10_000)
542                .with_pipeline_tag("text-generation")
543                .with_library("transformers")
544                .with_license("llama2"),
545            HubAsset::new("openai/whisper-large-v3", HubAssetType::Model)
546                .with_downloads(2_000_000)
547                .with_likes(5_000)
548                .with_pipeline_tag("automatic-speech-recognition")
549                .with_library("transformers")
550                .with_license("apache-2.0"),
551            HubAsset::new("stabilityai/stable-diffusion-xl-base-1.0", HubAssetType::Model)
552                .with_downloads(3_000_000)
553                .with_likes(8_000)
554                .with_pipeline_tag("text-to-image")
555                .with_library("diffusers")
556                .with_license("openrail++"),
557            HubAsset::new("sentence-transformers/all-MiniLM-L6-v2", HubAssetType::Model)
558                .with_downloads(10_000_000)
559                .with_likes(2_000)
560                .with_pipeline_tag("sentence-similarity")
561                .with_library("sentence-transformers")
562                .with_license("apache-2.0"),
563            HubAsset::new("bert-base-uncased", HubAssetType::Model)
564                .with_downloads(50_000_000)
565                .with_likes(15_000)
566                .with_pipeline_tag("fill-mask")
567                .with_library("transformers")
568                .with_license("apache-2.0"),
569        ];
570
571        // Apply filters
572        if let Some(ref task) = filters.task {
573            results.retain(|m| m.pipeline_tag.as_ref().is_some_and(|t| t == task));
574        }
575        if let Some(ref library) = filters.library {
576            results.retain(|m| m.library.as_ref().is_some_and(|l| l == library));
577        }
578        if let Some(min) = filters.min_downloads {
579            results.retain(|m| m.downloads >= min);
580        }
581        if let Some(min) = filters.min_likes {
582            results.retain(|m| m.likes >= min);
583        }
584
585        results.truncate(filters.limit);
586        results
587    }
588
589    fn mock_dataset_search(&self, filters: &SearchFilters) -> Vec<HubAsset> {
590        let mut results = vec![
591            HubAsset::new("squad", HubAssetType::Dataset)
592                .with_downloads(5_000_000)
593                .with_likes(1_000)
594                .with_tags(vec!["question-answering".into(), "english".into()]),
595            HubAsset::new("imdb", HubAssetType::Dataset)
596                .with_downloads(3_000_000)
597                .with_likes(500)
598                .with_tags(vec!["text-classification".into(), "sentiment".into()]),
599            HubAsset::new("wikipedia", HubAssetType::Dataset)
600                .with_downloads(10_000_000)
601                .with_likes(2_000)
602                .with_tags(vec!["text".into(), "multilingual".into()]),
603        ];
604
605        if let Some(min) = filters.min_downloads {
606            results.retain(|d| d.downloads >= min);
607        }
608
609        results.truncate(filters.limit);
610        results
611    }
612
613    fn mock_space_search(&self, filters: &SearchFilters) -> Vec<HubAsset> {
614        let mut results = vec![
615            HubAsset::new("gradio/chatbot", HubAssetType::Space)
616                .with_downloads(100_000)
617                .with_likes(500)
618                .with_tags(vec!["gradio".into(), "chat".into()]),
619            HubAsset::new("stabilityai/stable-diffusion", HubAssetType::Space)
620                .with_downloads(500_000)
621                .with_likes(2_000)
622                .with_tags(vec!["gradio".into(), "image-generation".into()]),
623        ];
624
625        if let Some(min) = filters.min_downloads {
626            results.retain(|s| s.downloads >= min);
627        }
628
629        results.truncate(filters.limit);
630        results
631    }
632
633    fn mock_get_model(&self, id: &str) -> Result<HubAsset, HubError> {
634        // Return mock data for known models
635        match id {
636            "meta-llama/Llama-2-7b-hf" => Ok(HubAsset::new(id, HubAssetType::Model)
637                .with_downloads(5_000_000)
638                .with_likes(10_000)
639                .with_pipeline_tag("text-generation")
640                .with_library("transformers")
641                .with_license("llama2")),
642            "bert-base-uncased" => Ok(HubAsset::new(id, HubAssetType::Model)
643                .with_downloads(50_000_000)
644                .with_likes(15_000)
645                .with_pipeline_tag("fill-mask")
646                .with_library("transformers")
647                .with_license("apache-2.0")),
648            _ => Err(HubError::NotFound(id.to_string())),
649        }
650    }
651
652    fn mock_get_dataset(&self, id: &str) -> Result<HubAsset, HubError> {
653        match id {
654            "squad" => Ok(HubAsset::new(id, HubAssetType::Dataset)
655                .with_downloads(5_000_000)
656                .with_likes(1_000)
657                .with_tags(vec!["question-answering".into()])),
658            _ => Err(HubError::NotFound(id.to_string())),
659        }
660    }
661
662    fn mock_get_space(&self, id: &str) -> Result<HubAsset, HubError> {
663        match id {
664            "gradio/chatbot" => Ok(HubAsset::new(id, HubAssetType::Space)
665                .with_downloads(100_000)
666                .with_likes(500)
667                .with_tags(vec!["gradio".into(), "chat".into()])),
668            _ => Err(HubError::NotFound(id.to_string())),
669        }
670    }
671}
672
673impl Default for HubClient {
674    fn default() -> Self {
675        Self::new()
676    }
677}
678
679// ============================================================================
680// Error Types
681// ============================================================================
682
/// Errors returned by [`HubClient`] queries.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum HubError {
    /// No asset exists with the given id.
    NotFound(String),
    /// The API refused the request due to rate limiting; `retry_after` is in seconds, if given.
    RateLimited { retry_after: Option<u64> },
    /// Transport-level failure, with a description.
    NetworkError(String),
    /// Offline mode is enabled and the request was not in the cache.
    OfflineMode,
    /// The API returned something that could not be interpreted.
    InvalidResponse(String),
}
697
698impl std::fmt::Display for HubError {
699    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
700        match self {
701            Self::NotFound(id) => write!(f, "Asset not found: {}", id),
702            Self::RateLimited { retry_after } => {
703                if let Some(secs) = retry_after {
704                    write!(f, "Rate limited, retry after {} seconds", secs)
705                } else {
706                    write!(f, "Rate limited")
707                }
708            }
709            Self::NetworkError(msg) => write!(f, "Network error: {}", msg),
710            Self::OfflineMode => write!(f, "Offline mode: no cached data available"),
711            Self::InvalidResponse(msg) => write!(f, "Invalid response: {}", msg),
712        }
713    }
714}
715
// No variant wraps an underlying error value, so the default `source()` (None) is kept.
impl std::error::Error for HubError {}
717
718// ============================================================================
719// Tests - Extreme TDD
720// ============================================================================
721
// Unit tests live in the sibling file `hub_client_tests.rs`, loaded via `#[path]`.
// NOTE(review): `non_snake_case` is presumably allowed because test names embed uppercase
// spec IDs (e.g. HF-QUERY-002) — confirm against hub_client_tests.rs.
#[cfg(test)]
#[allow(non_snake_case)]
#[path = "hub_client_tests.rs"]
mod tests;