use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::env;
use std::fs;
use std::path::Path;
use std::path::PathBuf;
/// A model found on disk: either a single model file or a group of split-model shards.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DiscoveredModel {
    /// Human-readable name reported for the model.
    pub name: String,
    /// Path of the model file; for a shard group, one of the shard files.
    pub path: PathBuf,
    /// On-disk format of the model.
    pub format: ModelFormat,
    /// Size in bytes (summed over shards for a shard group), or `None` when it
    /// could not be determined.
    pub size_bytes: Option<u64>,
}
/// On-disk format of a discovered model.
///
/// Fieldless and cheap, so it is `Copy` and comparable with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelFormat {
    /// GGUF single-file (or split) model, `.gguf`.
    Gguf,
    /// Hugging Face safetensors weights, `.safetensors`.
    SafeTensors,
}
/// Finds model files by recursively scanning a list of search directories.
///
/// Build one with [`ModelDiscovery::new`] plus [`ModelDiscovery::add_search_path`], or
/// with [`ModelDiscovery::from_env`] to use environment variables and well-known
/// locations. The configuration is just a list of paths, so it is cheap to clone.
#[derive(Debug, Clone)]
pub struct ModelDiscovery {
    /// Directories to scan, in the order they were added.
    search_paths: Vec<PathBuf>,
}
impl Default for ModelDiscovery {
fn default() -> Self {
Self::new()
}
}
impl ModelDiscovery {
pub fn new() -> Self {
Self {
search_paths: Vec::new(),
}
}
pub fn from_env() -> Self {
let mut discovery = Self::new();
if let Ok(base_path) = env::var("SHIMMY_BASE_GGUF") {
if let Some(parent) = Path::new(&base_path).parent() {
discovery.add_search_path(parent.to_path_buf());
}
}
if let Ok(custom_dirs) = env::var("SHIMMY_MODEL_PATHS") {
for dir in custom_dirs.split(';').filter(|s| !s.is_empty()) {
discovery.add_search_path(PathBuf::from(dir));
}
}
if let Ok(ollama_models) = env::var("OLLAMA_MODELS") {
discovery.add_search_path(PathBuf::from(ollama_models));
}
if let Ok(home) = env::var("HOME").or_else(|_| env::var("USERPROFILE")) {
let home_path = PathBuf::from(home);
discovery.add_search_path(home_path.join(".cache/huggingface"));
discovery.add_search_path(home_path.join(".ollama/models"));
discovery.add_search_path(home_path.join(".cache/lm-studio/models"));
discovery.add_search_path(home_path.join("models"));
}
#[cfg(windows)]
{
for drive in &["C:", "D:", "E:", "F:"] {
let ollama_path = PathBuf::from(format!(
"{}\\Users\\{}\\AppData\\Local\\Ollama\\models",
drive,
env::var("USERNAME").unwrap_or_default()
));
discovery.add_search_path(ollama_path);
let alt_ollama = PathBuf::from(format!("{}\\Ollama\\models", drive));
discovery.add_search_path(alt_ollama);
}
}
discovery
}
pub fn add_search_path(&mut self, path: PathBuf) {
self.search_paths.push(path);
}
pub fn search_paths(&self) -> &[PathBuf] {
&self.search_paths
}
pub fn discover_models(&self) -> Result<Vec<DiscoveredModel>> {
println!(
"DEBUG: discover_models called, search_paths: {:?}",
self.search_paths
);
let mut models = Vec::new();
for path in &self.search_paths {
if path.exists() {
self.scan_directory(path, &mut models)?;
}
}
Ok(models)
}
fn scan_directory(&self, dir: &Path, models: &mut Vec<DiscoveredModel>) -> Result<()> {
println!("DEBUG: Scanning directory: {:?}", dir);
let mut model_files = Vec::new();
let mut subdirs = Vec::new();
for entry in fs::read_dir(dir)? {
let entry = entry?;
let path = entry.path();
if path.is_dir() {
subdirs.push(path);
} else if self.is_model_file(&path) {
model_files.push(path);
}
}
let grouped_models = self.group_sharded_models(dir, &model_files)?;
for model in grouped_models {
models.push(model);
}
for subdir in subdirs {
self.scan_directory(&subdir, models)?;
}
Ok(())
}
fn group_sharded_models(
&self,
dir: &Path,
model_files: &[PathBuf],
) -> Result<Vec<DiscoveredModel>> {
println!(
"DEBUG: group_sharded_models called for dir: {:?}, files: {}",
dir,
model_files.len()
);
use regex::Regex;
use std::collections::HashMap;
let mut grouped_models = Vec::new();
let mut processed_files = std::collections::HashSet::new();
let shard_pattern = Regex::new(r"^(.+)-\d{5}-of-\d{5}(\..+)$").unwrap();
let mut shard_groups: HashMap<String, Vec<PathBuf>> = HashMap::new();
for file_path in model_files {
if let Some(filename) = file_path.file_name().and_then(|f| f.to_str()) {
println!("DEBUG: Checking file: {}", filename);
if let Some(captures) = shard_pattern.captures(filename) {
let base_name = captures.get(1).unwrap().as_str();
let extension = captures.get(2).unwrap().as_str();
let group_key = format!("{}{}", base_name, extension);
println!(
"DEBUG: Matched sharded file - base: {}, ext: {}, key: {}",
base_name, extension, group_key
);
shard_groups
.entry(group_key)
.or_default()
.push(file_path.clone());
processed_files.insert(file_path.clone());
} else {
println!("DEBUG: No match for: {}", filename);
}
}
}
for (group_key, files) in shard_groups {
if files.len() > 1 {
let total_size: u64 = files
.iter()
.filter_map(|path| fs::metadata(path).ok().map(|m| m.len()))
.sum();
let model_name = dir
.file_name()
.and_then(|n| n.to_str())
.unwrap_or(&group_key)
.to_string();
let primary_path = files[0].clone();
let format = if group_key.ends_with(".safetensors") {
ModelFormat::SafeTensors
} else {
ModelFormat::Gguf
};
grouped_models.push(DiscoveredModel {
name: model_name,
path: primary_path,
format,
size_bytes: Some(total_size),
});
}
}
for file_path in model_files {
if !processed_files.contains(file_path) {
if let Ok(model) = self.analyze_model_file(file_path) {
grouped_models.push(model);
}
}
}
Ok(grouped_models)
}
fn is_model_file(&self, path: &Path) -> bool {
if let Some(ext) = path.extension() {
if matches!(ext.to_str(), Some("gguf") | Some("safetensors")) {
return self.is_llm_model(path);
}
}
false
}
fn is_llm_model(&self, path: &Path) -> bool {
let filename = path
.file_name()
.and_then(|f| f.to_str())
.unwrap_or("")
.to_lowercase();
let non_llm_patterns = [
"flux",
"sd",
"stable-diffusion",
"sdxl",
"dalle",
"midjourney",
"video",
"vid",
"animate",
"motion",
"whisper",
"audio",
"speech",
"tts",
"voice",
"clip",
"embed",
"encoder",
"vision",
"vae",
"unet",
"controlnet",
"lora",
"adapter",
];
if non_llm_patterns
.iter()
.any(|pattern| filename.contains(pattern))
{
return false;
}
if path.extension().and_then(|s| s.to_str()) == Some("safetensors") {
return true;
}
let llm_patterns = [
"llama",
"mistral",
"qwen",
"phi",
"gemma",
"codellama",
"vicuna",
"alpaca",
"orca",
"falcon",
"mpt",
"gpt",
"claude",
"chatglm",
"baichuan",
"yi",
"deepseek",
"mixtral",
"solar",
"openchat",
"starling",
"wizardlm",
"dolphin",
"nous",
"hermes",
"airoboros",
];
if llm_patterns
.iter()
.any(|pattern| filename.contains(pattern))
{
return true;
}
true
}
fn analyze_model_file(&self, path: &Path) -> Result<DiscoveredModel> {
let format = match path.extension().and_then(|s| s.to_str()) {
Some("gguf") => ModelFormat::Gguf,
Some("safetensors") => ModelFormat::SafeTensors,
_ => return Err(anyhow::anyhow!("Unknown model format")),
};
let size_bytes = fs::metadata(path).ok().map(|m| m.len());
let name = path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("unknown")
.to_string();
Ok(DiscoveredModel {
name,
path: path.to_path_buf(),
format,
size_bytes,
})
}
}
pub fn discover_models_from_directory(path: &Path) -> Result<Vec<DiscoveredModel>> {
let mut discovery = ModelDiscovery::new();
discovery.add_search_path(path.to_path_buf());
discovery.discover_models()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::ffi::OsString;
    use std::fs;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::TempDir;

    /// Every environment variable `ModelDiscovery::from_env` reads.
    const FROM_ENV_VARS: [&str; 6] = [
        "SHIMMY_BASE_GGUF",
        "SHIMMY_MODEL_PATHS",
        "OLLAMA_MODELS",
        "HOME",
        "USERPROFILE",
        "USERNAME",
    ];

    /// Serialises tests that mutate the process-wide environment. The test harness
    /// runs tests on parallel threads by default.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds `ENV_LOCK` and restores the variables in `FROM_ENV_VARS` when dropped, so a
    /// test cannot leak environment changes into other tests, even if it panics.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        fn new() -> Self {
            // A panicking test poisons the lock, but its guard still restored the
            // environment on drop, so recovering the lock is safe.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let saved = FROM_ENV_VARS
                .iter()
                .map(|&key| (key, env::var_os(key)))
                .collect();
            Self { saved, _lock: lock }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (key, value) in &self.saved {
                match value {
                    Some(value) => env::set_var(key, value),
                    None => env::remove_var(key),
                }
            }
        }
    }

    #[test]
    fn test_model_discovery_new() {
        let discovery = ModelDiscovery::new();
        assert_eq!(discovery.search_paths.len(), 0);
    }
    #[test]
    fn test_add_search_path() {
        let mut discovery = ModelDiscovery::new();
        let test_path = PathBuf::from("/test/path");
        discovery.add_search_path(test_path.clone());
        assert_eq!(discovery.search_paths.len(), 1);
        assert_eq!(discovery.search_paths[0], test_path);
    }
    #[test]
    fn test_from_env_with_shimmy_base_gguf() {
        let _env = EnvGuard::new();
        env::set_var("SHIMMY_BASE_GGUF", "/models/test.gguf");
        let discovery = ModelDiscovery::from_env();
        assert!(!discovery.search_paths.is_empty());
        // The parent directory of the base model must be searched.
        assert!(discovery
            .search_paths
            .iter()
            .any(|p| p.as_path() == Path::new("/models")));
    }
    #[test]
    fn test_from_env_with_home_directories() {
        let _env = EnvGuard::new();
        env::set_var("HOME", "/test/home");
        let discovery = ModelDiscovery::from_env();
        assert!(discovery
            .search_paths
            .iter()
            .any(|p| p.to_string_lossy().contains(".cache/huggingface")));
        assert!(discovery
            .search_paths
            .iter()
            .any(|p| p.to_string_lossy().contains("models")));
    }
    #[test]
    fn test_is_model_file() {
        let discovery = ModelDiscovery::new();
        assert!(discovery.is_model_file(&PathBuf::from("test.gguf")));
        assert!(discovery.is_model_file(&PathBuf::from("/path/to/model.gguf")));
        assert!(discovery.is_model_file(&PathBuf::from("test.safetensors")));
        assert!(discovery.is_model_file(&PathBuf::from("/path/to/model.safetensors")));
        assert!(!discovery.is_model_file(&PathBuf::from("test.txt")));
        assert!(!discovery.is_model_file(&PathBuf::from("test.bin")));
        assert!(!discovery.is_model_file(&PathBuf::from("test")));
    }
    #[test]
    fn test_analyze_model_file_gguf() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let model_path = temp_dir.path().join("test-model.gguf");
        fs::write(&model_path, "dummy gguf content")?;
        let discovery = ModelDiscovery::new();
        let model = discovery.analyze_model_file(&model_path)?;
        assert_eq!(model.name, "test-model");
        assert_eq!(model.path, model_path);
        assert!(matches!(model.format, ModelFormat::Gguf));
        assert!(model.size_bytes.is_some());
        assert_eq!(model.size_bytes.unwrap(), "dummy gguf content".len() as u64);
        Ok(())
    }
    #[test]
    fn test_analyze_model_file_safetensors() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let model_path = temp_dir.path().join("test-model.safetensors");
        fs::write(&model_path, "dummy safetensors content")?;
        let discovery = ModelDiscovery::new();
        let model = discovery.analyze_model_file(&model_path)?;
        assert_eq!(model.name, "test-model");
        assert_eq!(model.path, model_path);
        assert!(matches!(model.format, ModelFormat::SafeTensors));
        assert!(model.size_bytes.is_some());
        assert_eq!(
            model.size_bytes.unwrap(),
            "dummy safetensors content".len() as u64
        );
        Ok(())
    }
    #[test]
    fn test_analyze_model_file_unknown_format() {
        let temp_dir = TempDir::new().unwrap();
        let model_path = temp_dir.path().join("test-model.unknown");
        fs::write(&model_path, "dummy content").unwrap();
        let discovery = ModelDiscovery::new();
        let result = discovery.analyze_model_file(&model_path);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Unknown model format"));
    }
    #[test]
    fn test_analyze_model_file_no_metadata() {
        let discovery = ModelDiscovery::new();
        let nonexistent_path = PathBuf::from("/nonexistent/model.gguf");
        // The format is decided by the extension alone, so a missing file still succeeds.
        let model = discovery
            .analyze_model_file(&nonexistent_path)
            .expect("format is decided by extension, not by reading the file");
        assert_eq!(model.name, "model");
        assert!(matches!(model.format, ModelFormat::Gguf));
        assert!(model.size_bytes.is_none());
    }
    #[test]
    fn test_discover_models_empty_paths() {
        let discovery = ModelDiscovery::new();
        let models = discovery.discover_models().unwrap();
        assert_eq!(models.len(), 0);
    }
    #[test]
    fn test_discover_models_nonexistent_paths() {
        let mut discovery = ModelDiscovery::new();
        discovery.add_search_path(PathBuf::from("/nonexistent/path"));
        let models = discovery.discover_models().unwrap();
        assert_eq!(models.len(), 0);
    }
    #[test]
    fn test_discover_models_with_files() -> Result<()> {
        let temp_dir = TempDir::new()?;
        fs::write(temp_dir.path().join("model1.gguf"), "content1")?;
        fs::write(temp_dir.path().join("model2.safetensors"), "content2")?;
        fs::write(temp_dir.path().join("not_model.txt"), "not a model")?;
        let subdir = temp_dir.path().join("subdir");
        fs::create_dir(&subdir)?;
        fs::write(subdir.join("model3.gguf"), "content3")?;
        let mut discovery = ModelDiscovery::new();
        discovery.add_search_path(temp_dir.path().to_path_buf());
        let models = discovery.discover_models()?;
        assert_eq!(models.len(), 3);
        let names: Vec<String> = models.iter().map(|m| m.name.clone()).collect();
        assert!(names.contains(&"model1".to_string()));
        assert!(names.contains(&"model2".to_string()));
        assert!(names.contains(&"model3".to_string()));
        Ok(())
    }
    #[test]
    fn test_scan_directory_recursive() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let level1 = temp_dir.path().join("level1");
        let level2 = level1.join("level2");
        fs::create_dir_all(&level2)?;
        fs::write(temp_dir.path().join("root.gguf"), "root content")?;
        fs::write(level1.join("level1.gguf"), "level1 content")?;
        fs::write(level2.join("level2.safetensors"), "level2 content")?;
        let discovery = ModelDiscovery::new();
        let mut models = Vec::new();
        discovery.scan_directory(temp_dir.path(), &mut models)?;
        assert_eq!(models.len(), 3);
        let names: Vec<String> = models.iter().map(|m| m.name.clone()).collect();
        assert!(names.contains(&"root".to_string()));
        assert!(names.contains(&"level1".to_string()));
        assert!(names.contains(&"level2".to_string()));
        Ok(())
    }
    #[test]
    fn test_scan_directory_error_handling() {
        let discovery = ModelDiscovery::new();
        let mut models = Vec::new();
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("not_a_dir.txt");
        fs::write(&file_path, "content").unwrap();
        let result = discovery.scan_directory(&file_path, &mut models);
        assert!(result.is_err());
    }
    #[test]
    fn test_model_format_serialization() {
        let gguf = ModelFormat::Gguf;
        let safetensors = ModelFormat::SafeTensors;
        let gguf_json = serde_json::to_string(&gguf).unwrap();
        let safetensors_json = serde_json::to_string(&safetensors).unwrap();
        assert!(gguf_json.contains("Gguf"));
        assert!(safetensors_json.contains("SafeTensors"));
        let gguf_parsed: ModelFormat = serde_json::from_str(&gguf_json).unwrap();
        let safetensors_parsed: ModelFormat = serde_json::from_str(&safetensors_json).unwrap();
        assert!(matches!(gguf_parsed, ModelFormat::Gguf));
        assert!(matches!(safetensors_parsed, ModelFormat::SafeTensors));
    }
    #[test]
    fn test_discovered_model_serialization() {
        let model = DiscoveredModel {
            name: "test-model".to_string(),
            path: PathBuf::from("/path/to/model.gguf"),
            format: ModelFormat::Gguf,
            size_bytes: Some(1024),
        };
        let json = serde_json::to_string(&model).unwrap();
        let parsed: DiscoveredModel = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.name, "test-model");
        assert_eq!(parsed.path, PathBuf::from("/path/to/model.gguf"));
        assert!(matches!(parsed.format, ModelFormat::Gguf));
        assert_eq!(parsed.size_bytes, Some(1024));
    }
    #[test]
    fn test_discovered_model_debug_format() {
        let model = DiscoveredModel {
            name: "test".to_string(),
            path: PathBuf::from("/test.gguf"),
            format: ModelFormat::Gguf,
            size_bytes: Some(512),
        };
        let debug_str = format!("{:?}", model);
        assert!(debug_str.contains("test"));
        assert!(debug_str.contains("test.gguf"));
        assert!(debug_str.contains("Gguf"));
        assert!(debug_str.contains("512"));
    }
    #[test]
    fn test_model_discovery_debug_format() {
        let mut discovery = ModelDiscovery::new();
        discovery.add_search_path(PathBuf::from("/test"));
        let debug_str = format!("{:?}", discovery);
        assert!(debug_str.contains("ModelDiscovery"));
        assert!(debug_str.contains("/test"));
    }
    #[test]
    fn test_file_stem_edge_cases() {
        let discovery = ModelDiscovery::new();
        let temp_dir = TempDir::new().unwrap();
        let complex_name = temp_dir.path().join("model.v1.0.final.gguf");
        fs::write(&complex_name, "content").unwrap();
        let model = discovery.analyze_model_file(&complex_name).unwrap();
        assert_eq!(model.name, "model.v1.0.final");
        let no_stem = PathBuf::from(".gguf");
        if let Ok(model) = discovery.analyze_model_file(&no_stem) {
            assert_eq!(model.name, "unknown");
        }
    }
    #[test]
    fn test_environment_variable_edge_cases() {
        let _env = EnvGuard::new();
        env::set_var("SHIMMY_BASE_GGUF", "model.gguf");
        let discovery = ModelDiscovery::from_env();
        assert!(!discovery.search_paths.is_empty());
    }
    #[test]
    fn test_from_env_no_environment_variables() {
        let _env = EnvGuard::new();
        env::remove_var("SHIMMY_BASE_GGUF");
        env::remove_var("SHIMMY_MODEL_PATHS");
        env::remove_var("OLLAMA_MODELS");
        env::remove_var("HOME");
        env::remove_var("USERPROFILE");
        let discovery = ModelDiscovery::from_env();
        assert!(
            discovery.search_paths.len() < 100,
            "Unexpected explosion in default search paths: {}",
            discovery.search_paths.len()
        );
        // Without any variables only the Windows drive fallbacks can contribute paths.
        #[cfg(not(windows))]
        assert!(discovery.search_paths.is_empty());
    }
    #[test]
    fn test_multiple_search_paths() -> Result<()> {
        let temp_dir1 = TempDir::new()?;
        let temp_dir2 = TempDir::new()?;
        fs::write(temp_dir1.path().join("model1.gguf"), "content1")?;
        fs::write(temp_dir2.path().join("model2.safetensors"), "content2")?;
        let mut discovery = ModelDiscovery::new();
        discovery.add_search_path(temp_dir1.path().to_path_buf());
        discovery.add_search_path(temp_dir2.path().to_path_buf());
        let models = discovery.discover_models()?;
        assert_eq!(models.len(), 2);
        let names: Vec<String> = models.iter().map(|m| m.name.clone()).collect();
        assert!(names.contains(&"model1".to_string()));
        assert!(names.contains(&"model2".to_string()));
        Ok(())
    }
    #[test]
    fn test_is_llm_model_excludes_non_llm_models() {
        let discovery = ModelDiscovery::new();
        assert!(!discovery.is_llm_model(&PathBuf::from("flux-dev.gguf")));
        assert!(!discovery.is_llm_model(&PathBuf::from("stable-diffusion-xl.safetensors")));
        assert!(!discovery.is_llm_model(&PathBuf::from("sdxl-base.gguf")));
        assert!(!discovery.is_llm_model(&PathBuf::from("dalle-mini.gguf")));
        assert!(!discovery.is_llm_model(&PathBuf::from("video-generator.gguf")));
        assert!(!discovery.is_llm_model(&PathBuf::from("animate-diff.safetensors")));
        assert!(!discovery.is_llm_model(&PathBuf::from("motion-model.gguf")));
        assert!(!discovery.is_llm_model(&PathBuf::from("whisper-large.gguf")));
        assert!(!discovery.is_llm_model(&PathBuf::from("speech-t5.safetensors")));
        assert!(!discovery.is_llm_model(&PathBuf::from("tts-model.gguf")));
        assert!(!discovery.is_llm_model(&PathBuf::from("voice-clone.gguf")));
        assert!(!discovery.is_llm_model(&PathBuf::from("clip-vit-base.gguf")));
        assert!(!discovery.is_llm_model(&PathBuf::from("text-embeddings.safetensors")));
        assert!(!discovery.is_llm_model(&PathBuf::from("vision-encoder.gguf")));
        assert!(!discovery.is_llm_model(&PathBuf::from("vae-encoder.safetensors")));
        assert!(!discovery.is_llm_model(&PathBuf::from("unet-model.gguf")));
        assert!(!discovery.is_llm_model(&PathBuf::from("controlnet-canny.safetensors")));
        assert!(!discovery.is_llm_model(&PathBuf::from("lora-adapter.gguf")));
    }
    #[test]
    fn test_is_llm_model_includes_llm_models() {
        let discovery = ModelDiscovery::new();
        assert!(discovery.is_llm_model(&PathBuf::from("llama-2-7b.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("mistral-7b-instruct.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("qwen-14b.safetensors")));
        assert!(discovery.is_llm_model(&PathBuf::from("phi-3-mini.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("gemma-2b.safetensors")));
        assert!(discovery.is_llm_model(&PathBuf::from("codellama-34b.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("vicuna-13b.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("alpaca-7b.safetensors")));
        assert!(discovery.is_llm_model(&PathBuf::from("orca-2-7b.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("falcon-40b.safetensors")));
        assert!(discovery.is_llm_model(&PathBuf::from("mpt-7b-chat.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("gpt4all-falcon.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("chatglm-6b.safetensors")));
        assert!(discovery.is_llm_model(&PathBuf::from("baichuan-13b.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("yi-34b.safetensors")));
        assert!(discovery.is_llm_model(&PathBuf::from("deepseek-coder.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("mixtral-8x7b.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("solar-10.7b.safetensors")));
        assert!(discovery.is_llm_model(&PathBuf::from("openchat-3.5.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("starling-lm-7b.safetensors")));
        assert!(discovery.is_llm_model(&PathBuf::from("wizardlm-13b.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("dolphin-mixtral.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("nous-hermes-2.safetensors")));
        assert!(discovery.is_llm_model(&PathBuf::from("airoboros-34b.gguf")));
    }
    #[test]
    fn test_is_llm_model_safetensors_permissive() {
        let discovery = ModelDiscovery::new();
        assert!(discovery.is_llm_model(&PathBuf::from("unknown-model.safetensors")));
        assert!(discovery.is_llm_model(&PathBuf::from("custom-transformer.safetensors")));
        assert!(discovery.is_llm_model(&PathBuf::from("experimental-llm.safetensors")));
        assert!(!discovery.is_llm_model(&PathBuf::from("stable-diffusion.safetensors")));
        assert!(!discovery.is_llm_model(&PathBuf::from("whisper-base.safetensors")));
        assert!(!discovery.is_llm_model(&PathBuf::from("clip-model.safetensors")));
    }
    #[test]
    fn test_is_llm_model_gguf_default_inclusion() {
        let discovery = ModelDiscovery::new();
        assert!(discovery.is_llm_model(&PathBuf::from("unknown-model.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("custom-7b.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("experimental.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("language-model.gguf")));
        assert!(!discovery.is_llm_model(&PathBuf::from("flux-unknown.gguf")));
        assert!(!discovery.is_llm_model(&PathBuf::from("sd-custom.gguf")));
        assert!(!discovery.is_llm_model(&PathBuf::from("whisper-custom.gguf")));
    }
    #[test]
    fn test_is_llm_model_case_insensitive() {
        let discovery = ModelDiscovery::new();
        assert!(discovery.is_llm_model(&PathBuf::from("LLAMA-2-7B.GGUF")));
        assert!(discovery.is_llm_model(&PathBuf::from("Mistral-7B.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("PHI-3-mini.SAFETENSORS")));
        assert!(!discovery.is_llm_model(&PathBuf::from("FLUX-DEV.GGUF")));
        assert!(!discovery.is_llm_model(&PathBuf::from("Stable-Diffusion.safetensors")));
        assert!(!discovery.is_llm_model(&PathBuf::from("WHISPER-LARGE.gguf")));
    }
    #[test]
    fn test_is_llm_model_edge_cases() {
        let discovery = ModelDiscovery::new();
        assert!(discovery.is_llm_model(&PathBuf::from("llama-mistral-hybrid.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("phi-gemma-merged.safetensors")));
        assert!(discovery.is_llm_model(&PathBuf::from("gpt-4-turbo.gguf")));
        assert!(!discovery.is_llm_model(&PathBuf::from("vision-gpt-clip.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("model.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from("test.safetensors")));
        assert!(discovery.is_llm_model(&PathBuf::from("llama-2-7b-chat-hf-q4_0.gguf")));
        assert!(discovery.is_llm_model(&PathBuf::from(
            "mistral-7b-instruct-v0.1-fp16.safetensors"
        )));
        assert!(!discovery.is_llm_model(&PathBuf::from(
            "stable-diffusion-xl-base-1.0-fp16.safetensors"
        )));
    }
    #[test]
    fn test_model_filtering_integration() -> Result<()> {
        let temp_dir = TempDir::new()?;
        fs::write(temp_dir.path().join("llama-2-7b.gguf"), "llm content")?;
        fs::write(
            temp_dir.path().join("mistral-instruct.safetensors"),
            "llm content",
        )?;
        fs::write(temp_dir.path().join("flux-dev.gguf"), "image model content")?;
        fs::write(
            temp_dir.path().join("whisper-large.gguf"),
            "audio model content",
        )?;
        fs::write(
            temp_dir.path().join("clip-vit.safetensors"),
            "vision model content",
        )?;
        fs::write(
            temp_dir.path().join("unknown-model.gguf"),
            "unknown content",
        )?;
        let mut discovery = ModelDiscovery::new();
        discovery.add_search_path(temp_dir.path().to_path_buf());
        let models = discovery.discover_models()?;
        assert_eq!(models.len(), 3);
        let names: Vec<String> = models.iter().map(|m| m.name.clone()).collect();
        assert!(names.contains(&"llama-2-7b".to_string()));
        assert!(names.contains(&"mistral-instruct".to_string()));
        assert!(names.contains(&"unknown-model".to_string()));
        assert!(!names.contains(&"flux-dev".to_string()));
        assert!(!names.contains(&"whisper-large".to_string()));
        assert!(!names.contains(&"clip-vit".to_string()));
        Ok(())
    }
}