use super::{cache::LoRACache, source::LoRASource};
use anyhow::Result;
use std::{path::PathBuf, sync::Arc};
/// Resolves LoRA URIs to local filesystem paths by consulting a list of
/// sources and a local cache.
pub struct LoRADownloader {
/// Sources that are probed, in the order given, when resolving a URI.
sources: Vec<Arc<dyn LoRASource>>,
/// Cache used to derive destination paths and to validate downloaded files.
cache: LoRACache,
}
impl LoRADownloader {
    /// Creates a downloader that probes `sources` in the given order and
    /// stores non-local downloads under paths provided by `cache`.
    pub fn new(sources: Vec<Arc<dyn LoRASource>>, cache: LoRACache) -> Self {
        Self { sources, cache }
    }

    /// Resolves `lora_uri` to a local path, downloading it only when necessary.
    ///
    /// * `file://` URIs skip the cache entirely: the first source that reports
    ///   the URI as existing is asked to "download" it with an empty
    ///   destination path, and its result is returned as-is.
    /// * Any other URI is mapped to a cache key. A cached entry that passes
    ///   validation is returned immediately; otherwise each source that
    ///   reports the URI as existing is asked to download it to the cache
    ///   path, and the first download that passes validation wins.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// * a `file://` URI is not reported as existing by any source,
    /// * cache validation itself fails,
    /// * a source's `download` call fails (the error is propagated and later
    ///   sources are not tried),
    /// * every matching download failed cache validation, or
    /// * no source reports the URI as existing.
    pub async fn download_if_needed(&self, lora_uri: &str) -> Result<PathBuf> {
        if lora_uri.starts_with("file://") {
            // Local files are not copied into the cache, so no destination is
            // supplied; the source decides which path to return.
            let no_destination = PathBuf::new();
            for source in &self.sources {
                if Self::source_has(source, lora_uri).await {
                    return source.download(lora_uri, &no_destination).await;
                }
            }
            anyhow::bail!("Local LoRA not found: {}", lora_uri);
        }

        let cache_key = self.uri_to_cache_key(lora_uri);
        if self.cache.is_cached(&cache_key) && self.cache.validate_cached(&cache_key)? {
            tracing::debug!("LoRA found in cache: {}", cache_key);
            return Ok(self.cache.get_cache_path(&cache_key));
        }

        let dest_path = self.cache.get_cache_path(&cache_key);
        // Remembered so the final error can distinguish "nobody has it" from
        // "somebody had it, but the result was unusable".
        let mut validation_failed = false;
        for source in &self.sources {
            if !Self::source_has(source, lora_uri).await {
                continue;
            }
            let downloaded_path = source.download(lora_uri, &dest_path).await?;
            if self.cache.validate_cached(&cache_key)? {
                return Ok(downloaded_path);
            }
            tracing::warn!(
                "Downloaded LoRA at {} failed validation",
                downloaded_path.display()
            );
            validation_failed = true;
        }

        if validation_failed {
            anyhow::bail!(
                "LoRA {} was downloaded but failed cache validation",
                lora_uri
            );
        }
        anyhow::bail!("LoRA {} not found in any source", lora_uri)
    }

    /// Returns `true` when `source` reports that `lora_uri` exists.
    ///
    /// A failing existence probe is logged and treated as "not present" so
    /// that one misbehaving source does not prevent the others from being
    /// tried.
    async fn source_has(source: &Arc<dyn LoRASource>, lora_uri: &str) -> bool {
        match source.exists(lora_uri).await {
            Ok(found) => found,
            Err(err) => {
                tracing::warn!(
                    "Failed to check whether LoRA {} exists in source: {:#}",
                    lora_uri,
                    err
                );
                false
            }
        }
    }

    /// Derives the cache key for `uri` by delegating to
    /// [`LoRACache::uri_to_cache_key`].
    fn uri_to_cache_key(&self, uri: &str) -> String {
        LoRACache::uri_to_cache_key(uri)
    }
}