use crate::invariant_ppt::shimmy_invariants;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::{Path, PathBuf};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscoveredModel {
pub name: String,
pub path: PathBuf,
pub lora_path: Option<PathBuf>,
pub size_bytes: u64,
pub model_type: String,
pub parameter_count: Option<String>,
pub quantization: Option<String>,
}
#[derive(Debug, Deserialize)]
struct OllamaManifest {
#[serde(rename = "schemaVersion")]
#[allow(dead_code)]
schema_version: i32,
#[serde(rename = "mediaType")]
#[allow(dead_code)]
media_type: String,
#[allow(dead_code)]
config: OllamaConfig,
layers: Vec<OllamaLayer>,
}
#[derive(Debug, Deserialize)]
struct OllamaConfig {
#[serde(rename = "mediaType")]
#[allow(dead_code)]
media_type: String,
#[allow(dead_code)]
digest: String,
#[allow(dead_code)]
size: i64,
}
#[derive(Debug, Deserialize)]
struct OllamaLayer {
#[serde(rename = "mediaType")]
media_type: String,
digest: String,
size: i64,
}
pub struct ModelAutoDiscovery {
pub search_paths: Vec<PathBuf>,
}
impl ModelAutoDiscovery {
pub fn new() -> Self {
let mut search_paths = vec![PathBuf::from("./models")];
if let Ok(shimmy_base) = std::env::var("SHIMMY_BASE_GGUF") {
let path = PathBuf::from(shimmy_base);
if let Some(parent) = path.parent() {
search_paths.push(parent.to_path_buf());
}
}
if let Ok(custom_dirs) = std::env::var("SHIMMY_MODEL_PATHS") {
for dir in custom_dirs.split(';').filter(|s| !s.is_empty()) {
search_paths.push(PathBuf::from(dir));
}
}
if let Ok(ollama_models) = std::env::var("OLLAMA_MODELS") {
search_paths.push(PathBuf::from(ollama_models));
}
if let Some(home) = std::env::var_os("HOME") {
search_paths.push(PathBuf::from(home.clone()).join(".cache/huggingface/hub"));
search_paths.push(PathBuf::from(home.clone()).join(".ollama/models"));
search_paths.push(PathBuf::from(home.clone()).join(".lmstudio/models"));
search_paths.push(PathBuf::from(home.clone()).join("models"));
search_paths.push(PathBuf::from(home).join(".local/share/shimmy/models"));
}
if let Some(user_profile) = std::env::var_os("USERPROFILE") {
search_paths.push(PathBuf::from(user_profile.clone()).join(".cache\\huggingface\\hub"));
search_paths.push(PathBuf::from(user_profile.clone()).join(".ollama\\models"));
search_paths.push(PathBuf::from(user_profile.clone()).join(".lmstudio\\models"));
search_paths.push(PathBuf::from(user_profile.clone()).join("models"));
search_paths
.push(PathBuf::from(user_profile.clone()).join("AppData\\Local\\shimmy\\models"));
search_paths.push(PathBuf::from(user_profile).join("Downloads"));
}
#[cfg(windows)]
{
if let Ok(username) = std::env::var("USERNAME") {
for drive in &["C:", "D:", "E:", "F:"] {
let ollama_path = PathBuf::from(format!(
"{}\\Users\\{}\\AppData\\Local\\Ollama\\models",
drive, username
));
search_paths.push(ollama_path);
let alt_ollama = PathBuf::from(format!("{}\\Ollama\\models", drive));
search_paths.push(alt_ollama);
let models_path = PathBuf::from(format!("{}\\models", drive));
search_paths.push(models_path);
}
}
}
Self { search_paths }
}
#[allow(dead_code)]
pub fn add_search_path(&mut self, path: PathBuf) {
self.search_paths.push(path);
}
pub fn discover_models(&self) -> Result<Vec<DiscoveredModel>> {
let mut discovered = Vec::new();
for search_path in &self.search_paths {
if search_path.exists() && search_path.is_dir() {
match self.scan_directory(search_path) {
Ok(models) => discovered.extend(models),
Err(e) => {
eprintln!("Warning: Failed to scan {}: {}", search_path.display(), e);
continue; }
}
}
}
match self.discover_ollama_models() {
Ok(ollama_models) => discovered.extend(ollama_models),
Err(e) => eprintln!("Warning: Failed to discover Ollama models: {}", e),
}
discovered.sort_by(|a, b| a.path.cmp(&b.path));
discovered.dedup_by(|a, b| a.path == b.path);
shimmy_invariants::assert_discovery_valid(discovered.len());
for model in &discovered {
let path_str = if cfg!(target_os = "windows") {
model.path.to_string_lossy().replace('\\', "/")
} else {
model.path.to_string_lossy().to_string()
};
shimmy_invariants::assert_backend_selection_valid(&path_str, &model.model_type);
}
Ok(discovered)
}
fn scan_directory(&self, dir: &Path) -> Result<Vec<DiscoveredModel>> {
self.scan_directory_with_depth(dir, 0)
}
fn scan_directory_with_depth(&self, dir: &Path, depth: usize) -> Result<Vec<DiscoveredModel>> {
if depth >= 4 {
return Ok(Vec::new());
}
if let Some(dir_name) = dir.file_name().and_then(|n| n.to_str()) {
if dir_name.starts_with('.')
&& dir_name != ".cache"
&& dir_name != ".ollama"
&& dir_name != ".local"
{
return Ok(Vec::new());
}
match dir_name {
"Library" | "Applications" | "System" | "Developer" | "usr" | "var" | "tmp"
| "private" | "Volumes" | "cores" | "dev" | "etc" | "home" | "net" | "proc"
| "opt" | "sbin" | "bin" => {
return Ok(Vec::new());
}
_ => {}
}
#[cfg(windows)]
match dir_name.to_lowercase().as_str() {
"windows"
| "program files"
| "program files (x86)"
| "programdata"
| "users"
| "system volume information"
| "$recycle.bin"
| "recovery" => {
return Ok(Vec::new());
}
_ => {}
}
}
let mut models = Vec::new();
let mut model_files = Vec::new();
let entries = match fs::read_dir(dir) {
Ok(entries) => entries,
Err(_) => return Ok(Vec::new()), };
for entry in entries {
let entry = match entry {
Ok(entry) => entry,
Err(_) => continue, };
let path = entry.path();
if path.is_dir() {
let dir_name = path
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("")
.to_lowercase();
if dir_name == "target"
|| dir_name == "cmake"
|| dir_name == "incremental"
|| dir_name.starts_with(".git")
|| dir_name.contains("whisper")
|| dir_name.contains("wav2vec")
|| dir_name.contains("bert")
|| dir_name.contains("clip")
{
continue;
}
if path.to_string_lossy().contains("huggingface") {
let path_str = path.to_string_lossy().to_lowercase();
if !(path_str.contains("llama")
|| path_str.contains("phi")
|| path_str.contains("mistral")
|| path_str.contains("qwen")
|| path_str.contains("gemma")
|| path_str.contains("gguf"))
{
continue;
}
}
models.extend(self.scan_directory_with_depth(&path, depth + 1)?);
} else if self.is_model_file(&path) {
model_files.push(path);
}
}
let grouped_models = self.group_sharded_models(dir, &model_files)?;
models.extend(grouped_models);
Ok(models)
}
fn group_sharded_models(
&self,
dir: &Path,
model_files: &[PathBuf],
) -> Result<Vec<DiscoveredModel>> {
use regex::Regex;
use std::collections::HashMap;
let mut grouped_models = Vec::new();
let mut processed_files = std::collections::HashSet::new();
let shard_pattern = Regex::new(r"^(.+)-\d{5}-of-\d{5}(\..+)$").unwrap();
let mut shard_groups: HashMap<String, Vec<PathBuf>> = HashMap::new();
for file_path in model_files {
if let Some(filename) = file_path.file_name().and_then(|f| f.to_str()) {
if let Some(captures) = shard_pattern.captures(filename) {
let base_name = captures.get(1).unwrap().as_str();
let extension = captures.get(2).unwrap().as_str();
let group_key = format!("{}{}", base_name, extension);
shard_groups
.entry(group_key)
.or_default()
.push(file_path.clone());
processed_files.insert(file_path.clone());
}
}
}
for (group_key, files) in shard_groups {
if files.len() > 1 {
let total_size: u64 = files
.iter()
.filter_map(|path| fs::metadata(path).ok().map(|m| m.len()))
.sum();
let model_name = dir
.file_name()
.and_then(|n| n.to_str())
.unwrap_or(&group_key)
.to_string();
let first_file = &files[0];
let filename = first_file
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("unknown");
let descriptive_path = if files.len() == 1 {
first_file.clone()
} else {
PathBuf::from(format!(
"{} (+{} more files)",
first_file.display(),
files.len() - 1
))
};
let (model_type, parameter_count, quantization) = self.parse_filename(filename);
let backend_type =
if first_file.extension().and_then(|s| s.to_str()) == Some("gguf") {
"Llama".to_string()
} else {
model_type
};
let lora_path = files.iter().find_map(|path| self.find_lora_for_model(path));
grouped_models.push(DiscoveredModel {
name: model_name,
path: descriptive_path,
lora_path,
size_bytes: total_size,
model_type: backend_type,
parameter_count,
quantization,
});
}
}
for file_path in model_files {
if !processed_files.contains(file_path) {
if let Ok(model) = self.analyze_model_file(file_path) {
grouped_models.push(model);
}
}
}
Ok(grouped_models)
}
fn is_model_file(&self, path: &Path) -> bool {
if let Some(extension) = path.extension() {
let ext = extension.to_string_lossy().to_lowercase();
if ext == "gguf" {
return true;
}
if ext == "safetensors" {
let path_str = path.to_string_lossy().to_lowercase();
return !path_str.contains("tokenizer") && !path_str.contains("config");
}
if ext == "bin" {
let path_str = path.to_string_lossy().to_lowercase();
if path_str.contains("target\\")
|| path_str.contains("target/")
|| path_str.contains("cmake")
|| path_str.contains("incremental")
|| path_str.contains("work-products")
|| path_str.contains("dep-graph")
|| path_str.contains("query-cache")
|| path_str.contains("ompver")
|| path_str.contains("whisper")
|| path_str.contains("wav2vec")
|| path_str.contains("pytorch_model")
{
return false;
}
return (path_str.contains("model")
|| path_str.contains("llama")
|| path_str.contains("phi")
|| path_str.contains("mistral")
|| path_str.contains("qwen")
|| path_str.contains("gemma"))
&& !path_str.contains("config")
&& !path_str.contains("tokenizer");
}
}
false
}
fn is_lora_file(&self, path: &Path) -> bool {
if let Some(extension) = path.extension() {
let ext = extension.to_string_lossy().to_lowercase();
if ext == "gguf" || ext == "ggml" {
let filename = path
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("")
.to_lowercase();
return filename.contains("lora") || filename.contains("adapter");
}
}
false
}
pub fn find_lora_for_model(&self, model_path: &Path) -> Option<PathBuf> {
let model_dir = model_path.parent()?;
let model_stem = model_path.file_stem()?.to_str()?;
if let Ok(entries) = fs::read_dir(model_dir) {
for entry in entries.flatten() {
let path = entry.path();
if self.is_lora_file(&path) {
let lora_stem = path.file_stem()?.to_str()?;
if lora_stem.contains(model_stem) || model_stem.contains(lora_stem) {
return Some(path);
}
}
}
}
None
}
fn analyze_model_file(&self, path: &Path) -> Result<DiscoveredModel> {
let metadata = fs::metadata(path)?;
let filename = path
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("unknown")
.to_string();
let (model_type, parameter_count, quantization) = self.parse_filename(&filename);
let backend_type = if path.extension().and_then(|s| s.to_str()) == Some("gguf") {
"Llama".to_string()
} else {
model_type
};
let name = self.generate_model_name(&filename);
let lora_path = self.find_lora_for_model(path);
Ok(DiscoveredModel {
name,
path: path.to_path_buf(),
lora_path,
size_bytes: metadata.len(),
model_type: backend_type,
parameter_count,
quantization,
})
}
fn parse_filename(&self, filename: &str) -> (String, Option<String>, Option<String>) {
let lower = filename.to_lowercase();
let model_type = if lower.contains("llama") {
"Llama"
} else if lower.contains("phi") {
"Phi"
} else if lower.contains("gemma") {
"Gemma"
} else if lower.contains("mistral") {
"Mistral"
} else if lower.contains("qwen") {
"Qwen"
} else {
"Unknown"
}
.to_string();
let parameter_count = if lower.contains("3b") || lower.contains("3.0b") {
Some("3B".to_string())
} else if lower.contains("7b") || lower.contains("7.0b") {
Some("7B".to_string())
} else if lower.contains("13b") || lower.contains("13.0b") {
Some("13B".to_string())
} else if lower.contains("70b") || lower.contains("70.0b") {
Some("70B".to_string())
} else {
None
};
let quantization = if lower.contains("q4_k_m") {
Some("Q4_K_M".to_string())
} else if lower.contains("q4_0") {
Some("Q4_0".to_string())
} else if lower.contains("q8_0") {
Some("Q8_0".to_string())
} else if lower.contains("f16") {
Some("F16".to_string())
} else if lower.contains("f32") {
Some("F32".to_string())
} else {
None
};
(model_type, parameter_count, quantization)
}
fn generate_model_name(&self, filename: &str) -> String {
let name = if let Some(pos) = filename.rfind('.') {
&filename[..pos]
} else {
filename
};
name.replace("_", "-").replace(" ", "-").to_lowercase()
}
fn discover_ollama_models(&self) -> Result<Vec<DiscoveredModel>> {
let mut models = Vec::new();
let mut ollama_dirs = Vec::new();
if let Ok(ollama_models) = std::env::var("OLLAMA_MODELS") {
ollama_dirs.push(PathBuf::from(ollama_models));
}
if let Ok(shimmy_base) = std::env::var("SHIMMY_BASE_GGUF") {
let path = PathBuf::from(shimmy_base);
if let Some(parent) = path.parent() {
ollama_dirs.push(parent.to_path_buf());
if parent.file_name().and_then(|n| n.to_str()) == Some("blobs") {
if let Some(grandparent) = parent.parent() {
ollama_dirs.push(grandparent.to_path_buf());
}
}
}
}
if let Some(home) = std::env::var_os("HOME") {
ollama_dirs.push(PathBuf::from(home).join(".ollama/models"));
}
if let Some(user_profile) = std::env::var_os("USERPROFILE") {
ollama_dirs.push(PathBuf::from(user_profile).join(".ollama").join("models"));
}
for ollama_dir in ollama_dirs {
if !ollama_dir.exists() {
continue;
}
let manifests_dir = ollama_dir.join("manifests");
let blobs_dir = ollama_dir.join("blobs");
if manifests_dir.exists() && blobs_dir.exists() {
models.extend(self.discover_ollama_manifest_models(&manifests_dir, &blobs_dir)?);
}
models.extend(self.discover_ollama_direct_models(&ollama_dir)?);
}
Ok(models)
}
fn discover_ollama_manifest_models(
&self,
manifests_dir: &Path,
blobs_dir: &Path,
) -> Result<Vec<DiscoveredModel>> {
let mut models = Vec::new();
self.scan_manifest_directory(manifests_dir, blobs_dir, &mut models, Vec::new())?;
Ok(models)
}
fn scan_manifest_directory(
&self,
dir: &Path,
blobs_dir: &Path,
models: &mut Vec<DiscoveredModel>,
path_components: Vec<String>,
) -> Result<()> {
for entry in
fs::read_dir(dir).map_err(|_| anyhow::anyhow!("Cannot read directory: {:?}", dir))?
{
let entry = entry?;
let entry_name = entry.file_name().to_string_lossy().to_string();
let mut new_path_components = path_components.clone();
new_path_components.push(entry_name.clone());
if entry.path().is_dir() {
self.scan_manifest_directory(
&entry.path(),
blobs_dir,
models,
new_path_components,
)?;
} else if entry.path().is_file() {
if let Ok(manifest_content) = fs::read_to_string(entry.path()) {
if let Ok(manifest) = serde_json::from_str::<OllamaManifest>(&manifest_content)
{
for layer in &manifest.layers {
if layer.media_type == "application/vnd.ollama.image.model" {
if let Some(hash) = layer.digest.strip_prefix("sha256:") {
let blob_path = blobs_dir.join(format!("sha256-{}", hash));
if blob_path.exists()
&& self.is_gguf_blob(&blob_path).unwrap_or(false)
{
let display_name = if path_components.len() >= 2 {
let mut name_parts = path_components.clone();
name_parts.push(entry_name.clone());
name_parts.join("/")
} else {
format!("{}:{}", path_components.join("/"), entry_name)
};
let discovered = DiscoveredModel {
name: display_name,
path: blob_path,
lora_path: None,
size_bytes: layer.size as u64,
model_type: "Ollama".to_string(),
parameter_count: None,
quantization: None,
};
models.push(discovered);
}
}
}
}
}
}
}
}
Ok(())
}
fn discover_ollama_direct_models(&self, ollama_dir: &Path) -> Result<Vec<DiscoveredModel>> {
let mut models = Vec::new();
let skip_dirs = ["manifests", "blobs"];
if let Ok(entries) = fs::read_dir(ollama_dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.is_dir() {
let dir_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
if skip_dirs.contains(&dir_name) || dir_name.starts_with('.') {
continue;
}
models.extend(self.discover_ollama_direct_models_recursive(&path, 0)?);
} else if self.is_model_file(&path) {
if let Ok(mut model) = self.analyze_model_file(&path) {
model.name = format!("ollama:{}", model.name);
model.model_type = "Ollama".to_string();
models.push(model);
}
}
}
}
Ok(models)
}
fn discover_ollama_direct_models_recursive(
&self,
dir: &Path,
depth: usize,
) -> Result<Vec<DiscoveredModel>> {
let mut models = Vec::new();
if depth >= 5 {
return Ok(models);
}
if let Ok(entries) = fs::read_dir(dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.is_dir() {
let dir_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
if dir_name.starts_with('.') || dir_name == "tmp" || dir_name == "cache" {
continue;
}
models.extend(self.discover_ollama_direct_models_recursive(&path, depth + 1)?);
} else if self.is_model_file(&path) {
if let Ok(mut model) = self.analyze_model_file(&path) {
let relative_path = path
.strip_prefix(dir.ancestors().nth(depth).unwrap_or(dir))
.unwrap_or(&path);
let parent_name = relative_path
.parent()
.and_then(|p| p.file_name())
.and_then(|n| n.to_str())
.unwrap_or("unknown");
model.name = format!("ollama:{}", parent_name);
model.model_type = "Ollama".to_string();
models.push(model);
}
}
}
}
Ok(models)
}
fn is_gguf_blob(&self, path: &Path) -> Result<bool> {
let mut file = std::fs::File::open(path)?;
let mut buffer = [0u8; 4];
use std::io::Read;
file.read_exact(&mut buffer)?;
Ok(&buffer == b"GGUF")
}
}
impl Default for ModelAutoDiscovery {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_discovered_model_creation() {
let model = DiscoveredModel {
name: "test".to_string(),
path: PathBuf::from("/test"),
lora_path: None,
size_bytes: 1024,
model_type: "Llama".to_string(),
parameter_count: Some("7B".to_string()),
quantization: Some("Q4_K_M".to_string()),
};
assert_eq!(model.name, "test");
assert_eq!(model.size_bytes, 1024);
}
#[test]
fn test_model_auto_discovery_new() {
let discovery = ModelAutoDiscovery::new();
assert!(!discovery.search_paths.is_empty());
}
#[test]
fn test_filename_parsing() {
let discovery = ModelAutoDiscovery::new();
let (model_type, params, quant) = discovery.parse_filename("llama-7b-q4_k_m.gguf");
assert_eq!(model_type, "Llama");
assert_eq!(params, Some("7B".to_string()));
assert_eq!(quant, Some("Q4_K_M".to_string()));
}
}