use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use crate::CacheError;
/// Minimal async HTTP GET abstraction so the cache can run on WASI (or any
/// host) without committing to a specific HTTP client implementation.
///
/// Implementors must be `Send + Sync` because the cache shares the client
/// behind an `Arc` across threads/tasks.
#[async_trait::async_trait]
pub trait WasiHttpFetch: Send + Sync + std::fmt::Debug {
    /// Fetches `url` and returns the HTTP status code together with the raw
    /// response body bytes. Transport-level failures map to `CacheError`.
    async fn get(&self, url: &str) -> Result<(u16, Vec<u8>), CacheError>;
}
/// In-memory, thread-safe cache of model files keyed by `repo_id/filename`,
/// downloading misses over HTTP. Cloning is cheap: all fields are `Arc`s, so
/// clones share the same underlying cache map and HTTP client.
#[derive(Debug, Clone)]
pub struct WasiModelCache {
    // Hub base URL without a trailing slash (normalized in `with_base_url`).
    base_url: String,
    // Pluggable HTTP client used to fetch cache misses.
    http: Arc<dyn WasiHttpFetch>,
    // Cached file bodies; `Arc<[u8]>` makes handing out hits an O(1) clone.
    inner: Arc<RwLock<HashMap<String, Arc<[u8]>>>>,
}
impl WasiModelCache {
    /// Creates a cache that fetches misses via `http` from the default
    /// Hugging Face hub (`https://huggingface.co`).
    #[must_use]
    pub fn new(http: Arc<dyn WasiHttpFetch>) -> Self {
        Self {
            base_url: "https://huggingface.co".to_owned(),
            http,
            inner: Arc::default(),
        }
    }
    /// Overrides the hub base URL. Trailing slashes are stripped so URL
    /// assembly in `download` produces exactly one separator.
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into().trim_end_matches('/').to_owned();
        self
    }
    /// Replaces the HTTP client used for cache misses.
    #[must_use]
    pub fn with_http_client(mut self, http: Arc<dyn WasiHttpFetch>) -> Self {
        self.http = http;
        self
    }
    /// Builds the cache-map key for a file within a repo.
    ///
    /// BUG FIX: the previous version ignored `filename` and hard-coded a
    /// placeholder, so every file in a repo collided on one key — `is_cached`
    /// reported false positives and `download` could return the wrong file.
    fn cache_key(repo_id: &str, filename: &str) -> String {
        format!("{repo_id}/{filename}")
    }
    /// Returns `true` if `filename` from `repo_id` is already cached.
    ///
    /// A poisoned lock is treated as "not cached" rather than a panic.
    #[must_use]
    pub fn is_cached(&self, repo_id: &str, filename: &str) -> bool {
        self.inner
            .read()
            .is_ok_and(|m| m.contains_key(&Self::cache_key(repo_id, filename)))
    }
    /// Returns the bytes of `filename` from `repo_id`, downloading and
    /// caching them on a miss.
    ///
    /// # Errors
    ///
    /// Returns `CacheError::Download` for HTTP status codes >= 400, and
    /// propagates transport errors from the underlying client.
    ///
    /// Note: concurrent misses for the same key may each fetch once; the
    /// last writer wins, which is harmless since bodies are identical.
    pub async fn download(&self, repo_id: &str, filename: &str) -> Result<Arc<[u8]>, CacheError> {
        let key = Self::cache_key(repo_id, filename);
        // Fast path: serve a cache hit without touching the network.
        if let Some(bytes) = self.inner.read().ok().and_then(|m| m.get(&key).cloned()) {
            return Ok(bytes);
        }
        let url = format!("{}/{}/resolve/main/{}", self.base_url, repo_id, filename);
        let (status, body) = self.http.get(&url).await?;
        if status >= 400 {
            return Err(CacheError::Download(format!(
                "HTTP {status} fetching {key}"
            )));
        }
        let bytes: Arc<[u8]> = Arc::from(body);
        // Best effort: a poisoned lock skips caching but still returns data.
        if let Ok(mut m) = self.inner.write() {
            m.insert(key, Arc::clone(&bytes));
        }
        Ok(bytes)
    }
}