use std::path::{Path, PathBuf};
use xybrid_core::cache_provider::CacheProvider;
use super::cache_manager::CacheManager;
use crate::model::SdkError;
/// SDK-side [`CacheProvider`] that resolves model locations from a local cache directory.
///
/// Lookups first consult the wrapped [`CacheManager`]. On a miss, the `CacheProvider` impl
/// falls back to a fuzzy search over directory names in the cache directory, and then to a
/// test-fixtures models directory.
pub struct SdkCacheProvider {
    /// Underlying cache manager; also supplies the root directory used for fuzzy matching.
    cache: CacheManager,
}
impl SdkCacheProvider {
    /// Creates a provider backed by a [`CacheManager`] rooted at `cache_dir`.
    ///
    /// # Errors
    ///
    /// Propagates any [`SdkError`] returned by `CacheManager::with_dir` for this directory.
    pub fn with_dir(cache_dir: PathBuf) -> Result<Self, SdkError> {
        let cache = CacheManager::with_dir(cache_dir)?;
        Ok(Self { cache })
    }

    /// Locates a model directory for `model_id` when the exact cache lookup misses.
    ///
    /// Search order:
    /// 1. Immediate subdirectories of the cache directory (see [`Self::scan_cache_dir`]).
    /// 2. `<fixtures>/<model_id>` from the xybrid test fixtures (see
    ///    [`Self::find_fixture_dir`]).
    ///
    /// Returns `None` for an empty id. An empty id would otherwise be a substring of every
    /// directory name and match an arbitrary model.
    fn find_matching_dir(&self, model_id: &str) -> Option<PathBuf> {
        if model_id.is_empty() {
            return None;
        }
        self.scan_cache_dir(model_id)
            .or_else(|| Self::find_fixture_dir(model_id))
    }

    /// Scans the cache directory's immediate subdirectories for a case-insensitive name match.
    ///
    /// A directory matches when its lower-cased name contains the lower-cased `model_id`,
    /// either verbatim or after both are stripped of `-`, `_` and `.` separators. Only
    /// directories holding a recognised model file qualify. A directory whose name *equals*
    /// the id (verbatim or normalized) wins over a substring match. Otherwise the first
    /// substring match in `read_dir` order is returned.
    fn scan_cache_dir(&self, model_id: &str) -> Option<PathBuf> {
        let wanted = model_id.to_lowercase();
        let wanted_normalized = normalize_name(&wanted);
        // An id made only of separators (e.g. "-") normalizes to "", which would be a
        // substring of every name; only use the normalized form when it is non-empty.
        let use_normalized = !wanted_normalized.is_empty();

        // A missing or unreadable cache directory yields no candidates. It must not
        // short-circuit the fixtures fallback in `find_matching_dir`.
        let entries = std::fs::read_dir(self.cache.cache_dir()).ok()?;

        let mut first_partial = None;
        for entry in entries.flatten() {
            let entry_path = entry.path();
            if !entry_path.is_dir() {
                continue;
            }
            let dir_name = entry.file_name().to_string_lossy().to_lowercase();
            let dir_name_normalized = normalize_name(&dir_name);

            let is_match = dir_name.contains(&wanted)
                || (use_normalized && dir_name_normalized.contains(&wanted_normalized));
            if !is_match || !has_model_files(&entry_path) {
                continue;
            }

            let is_exact = dir_name == wanted
                || (use_normalized && dir_name_normalized == wanted_normalized);
            if is_exact {
                return Some(entry_path);
            }
            // Remember the first substring hit, but keep looking for an exact match.
            first_partial.get_or_insert(entry_path);
        }
        first_partial
    }

    /// Looks up `<fixtures>/<model_id>` in the xybrid test-fixture models directory.
    ///
    /// Returns it only if it holds a recognised model file.
    fn find_fixture_dir(model_id: &str) -> Option<PathBuf> {
        let fixtures_path = xybrid_core::testing::model_fixtures::models_dir()?.join(model_id);
        if has_model_files(&fixtures_path) {
            Some(fixtures_path)
        } else {
            None
        }
    }
}
impl CacheProvider for SdkCacheProvider {
fn is_model_cached(&self, model_id: &str) -> bool {
if self.cache.is_cached(model_id) {
return true;
}
self.find_matching_dir(model_id).is_some()
}
fn get_model_path(&self, model_id: &str) -> Option<PathBuf> {
if let Some(path) = self.cache.get_cached_path(model_id) {
return path.parent().map(|p| p.to_path_buf());
}
self.find_matching_dir(model_id)
}
fn cache_dir(&self) -> PathBuf {
self.cache.cache_dir().to_path_buf()
}
fn name(&self) -> &'static str {
"sdk"
}
}
/// Strips every `-`, `_` and `.` character from `name`.
///
/// Case is left untouched; the callers in this file lower-case the input first. This makes
/// e.g. `kokoro-82m`, `kokoro_82m` and `kokoro.82m` compare equal.
fn normalize_name(name: &str) -> String {
    name.chars()
        .filter(|ch| !matches!(ch, '-' | '_' | '.'))
        .collect()
}
/// Returns `true` if `path` directly contains at least one recognised model file.
///
/// The recognised files are `universal.xyb`, `model_metadata.json` and `model.onnx`. They
/// are probed in that order, stopping at the first one that exists.
fn has_model_files(path: &Path) -> bool {
    const MARKER_FILES: [&str; 3] = ["universal.xyb", "model_metadata.json", "model.onnx"];
    MARKER_FILES
        .iter()
        .any(|marker| path.join(marker).exists())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a provider rooted at a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned too, so the directory is not deleted while the
    /// provider is still in use.
    fn provider_in_tempdir() -> (TempDir, SdkCacheProvider) {
        let root = TempDir::new().unwrap();
        let provider = SdkCacheProvider::with_dir(root.path().to_path_buf()).unwrap();
        (root, provider)
    }

    #[test]
    fn test_normalize_name() {
        let cases = [
            ("kokoro-82m", "kokoro82m"),
            ("kokoro_82m", "kokoro82m"),
            ("kokoro.82m.v1.0", "kokoro82mv10"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(normalize_name(input), expected);
        }
    }

    #[test]
    fn test_sdk_cache_provider_creation() {
        let (root, provider) = provider_in_tempdir();
        assert_eq!(provider.name(), "sdk");
        assert_eq!(provider.cache_dir(), root.path());
    }

    #[test]
    fn test_is_model_cached_empty() {
        let (_root, provider) = provider_in_tempdir();
        assert!(!provider.is_model_cached("nonexistent-model"));
    }

    #[test]
    fn test_get_model_path_empty() {
        let (_root, provider) = provider_in_tempdir();
        assert!(provider.get_model_path("nonexistent-model").is_none());
    }

    #[test]
    fn test_fuzzy_matching_with_directory() {
        // The model directory is populated before the provider is built.
        let root = TempDir::new().unwrap();
        let bundle_dir = root.path().join("Kokoro-82M-v1.0-ONNX");
        std::fs::create_dir_all(&bundle_dir).unwrap();
        std::fs::write(bundle_dir.join("universal.xyb"), b"fake bundle").unwrap();

        let provider = SdkCacheProvider::with_dir(root.path().to_path_buf()).unwrap();
        assert!(provider.is_model_cached("kokoro-82m"));

        let resolved = provider
            .get_model_path("kokoro-82m")
            .expect("fuzzy match should resolve the Kokoro bundle directory");
        assert!(resolved.to_string_lossy().contains("Kokoro-82M"));
    }
}