use std::collections::HashSet;
use std::path::PathBuf;
/// Common interface over local LLM backends (Ollama, MLX, llama.cpp, ...).
pub trait ModelProvider {
    /// Human-readable backend name (e.g. "Ollama").
    fn name(&self) -> &str;
    /// Whether the backend looks usable right now (binary present or daemon reachable).
    fn is_available(&self) -> bool;
    /// Lowercased identifiers of locally installed models (may include alias forms).
    fn installed_models(&self) -> HashSet<String>;
    /// Kick off an asynchronous model download; events arrive on the handle's channel.
    fn start_pull(&self, model_tag: &str) -> Result<PullHandle, String>;
}
/// Handle to an in-flight model download started by `ModelProvider::start_pull`.
pub struct PullHandle {
    /// Tag of the model being pulled (providers may normalize the caller's tag).
    pub model_tag: String,
    /// Receives `Progress`/`Done`/`Error` events from the background download thread.
    pub receiver: std::sync::mpsc::Receiver<PullEvent>,
}
/// Events emitted on a `PullHandle`'s channel during a download.
#[derive(Debug, Clone)]
pub enum PullEvent {
    /// Intermediate status update; `percent` is `None` when indeterminate.
    Progress {
        status: String,
        percent: Option<f64>,
    },
    /// Download finished successfully; no further events follow.
    Done,
    /// Download failed; the string is a human-readable reason.
    Error(String),
}
/// Talks to a local Ollama daemon over its HTTP API.
pub struct OllamaProvider {
    /// Base URL of the daemon, e.g. "http://localhost:11434".
    base_url: String,
}
/// Normalize an `OLLAMA_HOST` value into a full base URL.
///
/// Bare `host:port` gains an `http://` prefix; full `http(s)://` URLs pass
/// through unchanged. Empty input and any other explicit scheme are rejected.
fn normalize_ollama_host(raw: &str) -> Option<String> {
    let host = raw.trim();
    match host {
        "" => None,
        h if h.starts_with("http://") || h.starts_with("https://") => Some(h.to_string()),
        // Any other explicit scheme (e.g. unix://) is unsupported.
        h if h.contains("://") => None,
        h => Some(format!("http://{h}")),
    }
}
impl Default for OllamaProvider {
    /// Build a provider from `OLLAMA_HOST`, falling back to the standard
    /// local daemon address when the variable is unset or unparseable.
    fn default() -> Self {
        let from_env = std::env::var("OLLAMA_HOST").ok().and_then(|raw| {
            let parsed = normalize_ollama_host(&raw);
            if parsed.is_none() {
                // Keep going with the default, but tell the user why.
                eprintln!(
                    "Warning: could not parse OLLAMA_HOST='{}'. Expected host:port or http(s)://host:port",
                    raw
                );
            }
            parsed
        });
        Self {
            base_url: from_env.unwrap_or_else(|| "http://localhost:11434".to_string()),
        }
    }
}
impl OllamaProvider {
    /// Convenience alias for `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }
    /// Join `path` onto the configured base URL under `/api/`.
    fn api_url(&self, path: &str) -> String {
        format!("{}/api/{}", self.base_url.trim_end_matches('/'), path)
    }
    /// Quick probe (800ms timeout) of the daemon plus installed models.
    ///
    /// Returns `(reachable, names, count)`: unreachable -> `(false, {}, 0)`;
    /// reachable but unparseable body -> `(true, {}, 0)`. `names` holds each
    /// tag lowercased plus its family prefix (text before the first ':').
    pub fn detect_with_installed(&self) -> (bool, HashSet<String>, usize) {
        let mut set = HashSet::new();
        let Ok(resp) = ureq::get(&self.api_url("tags"))
            .config()
            .timeout_global(Some(std::time::Duration::from_millis(800)))
            .build()
            .call()
        else {
            return (false, set, 0);
        };
        let Ok(tags): Result<TagsResponse, _> = resp.into_body().read_json() else {
            return (true, set, 0);
        };
        let count = tags.models.len();
        for m in tags.models {
            let lower = m.name.to_lowercase();
            set.insert(lower.clone());
            // Also index by family so "llama3" matches "llama3:8b".
            if let Some(family) = lower.split(':').next() {
                set.insert(family.to_string());
            }
        }
        (true, set, count)
    }
    /// Same tag listing as `detect_with_installed`, but with a patient 5s
    /// timeout and without the availability flag.
    pub fn installed_models_counted(&self) -> (HashSet<String>, usize) {
        let mut set = HashSet::new();
        let Ok(resp) = ureq::get(&self.api_url("tags"))
            .config()
            .timeout_global(Some(std::time::Duration::from_secs(5)))
            .build()
            .call()
        else {
            return (set, 0);
        };
        let Ok(tags): Result<TagsResponse, _> = resp.into_body().read_json() else {
            return (set, 0);
        };
        let count = tags.models.len();
        for m in tags.models {
            let lower = m.name.to_lowercase();
            set.insert(lower.clone());
            if let Some(family) = lower.split(':').next() {
                set.insert(family.to_string());
            }
        }
        (set, count)
    }
    /// Whether `model_tag` resolves via `POST /api/show`; any request
    /// failure reads as "no".
    pub fn has_remote_tag(&self, model_tag: &str) -> bool {
        let body = serde_json::json!({ "model": model_tag });
        ureq::post(&self.api_url("show"))
            .config()
            .timeout_global(Some(std::time::Duration::from_millis(1200)))
            .build()
            .send_json(&body)
            .is_ok()
    }
}
/// Response shape of Ollama's `GET /api/tags`.
#[derive(serde::Deserialize)]
struct TagsResponse {
    models: Vec<OllamaModel>,
}
/// Single entry in the `/api/tags` model list.
#[derive(serde::Deserialize)]
struct OllamaModel {
    name: String,
}
/// One NDJSON line of Ollama's streaming `/api/pull` response.
#[derive(serde::Deserialize)]
struct PullStreamLine {
    /// Free-form phase description; "success" marks completion.
    #[serde(default)]
    status: String,
    /// Total bytes for the current step, when known.
    #[serde(default)]
    total: Option<u64>,
    /// Bytes transferred so far for the current step.
    #[serde(default)]
    completed: Option<u64>,
    /// Server-side failure message; terminal when present.
    #[serde(default)]
    error: Option<String>,
}
impl ModelProvider for OllamaProvider {
    fn name(&self) -> &str {
        "Ollama"
    }
    /// Reachability check against `/api/tags` with a 2s timeout.
    fn is_available(&self) -> bool {
        ureq::get(&self.api_url("tags"))
            .config()
            .timeout_global(Some(std::time::Duration::from_secs(2)))
            .build()
            .call()
            .is_ok()
    }
    fn installed_models(&self) -> HashSet<String> {
        let (set, _) = self.installed_models_counted();
        set
    }
    /// Start a streaming `/api/pull` on a background thread.
    ///
    /// Each NDJSON line is forwarded as a `PullEvent`. The thread exits on
    /// the first server-reported error, on `status == "success"` (after
    /// sending `Done`), or with an `Error` if the stream ends without
    /// reaching success.
    fn start_pull(&self, model_tag: &str) -> Result<PullHandle, String> {
        let url = self.api_url("pull");
        let tag = model_tag.to_string();
        let (tx, rx) = std::sync::mpsc::channel();
        let body = serde_json::json!({
            "model": tag,
            "stream": true,
        });
        std::thread::spawn(move || {
            let resp = ureq::post(&url)
                .config()
                .timeout_global(Some(std::time::Duration::from_secs(3600)))
                .build()
                .send_json(&body);
            match resp {
                Ok(resp) => {
                    let reader = std::io::BufReader::new(resp.into_body().into_reader());
                    use std::io::BufRead;
                    for line in reader.lines() {
                        let Ok(line) = line else { break };
                        if line.is_empty() {
                            continue;
                        }
                        // Unparseable lines are skipped rather than fatal.
                        if let Ok(parsed) = serde_json::from_str::<PullStreamLine>(&line) {
                            if let Some(ref err) = parsed.error {
                                let _ = tx.send(PullEvent::Error(err.clone()));
                                return;
                            }
                            let percent = match (parsed.completed, parsed.total) {
                                (Some(c), Some(t)) if t > 0 => Some(c as f64 / t as f64 * 100.0),
                                _ => None,
                            };
                            let _ = tx.send(PullEvent::Progress {
                                status: parsed.status.clone(),
                                percent,
                            });
                            if parsed.status == "success" {
                                let _ = tx.send(PullEvent::Done);
                                return;
                            }
                        }
                    }
                    // Stream ended without a "success" status line.
                    let _ = tx.send(PullEvent::Error(
                        "Pull ended without success (model may not exist in Ollama registry)"
                            .to_string(),
                    ));
                }
                Err(e) => {
                    let _ = tx.send(PullEvent::Error(format!("{e}")));
                }
            }
        });
        Ok(PullHandle {
            model_tag: model_tag.to_string(),
            receiver: rx,
        })
    }
}
/// Talks to a local mlx_lm server (OpenAI-compatible API) and scans the
/// Hugging Face cache for MLX-format models.
pub struct MlxProvider {
    /// Base URL of the mlx server, e.g. "http://localhost:8080".
    server_url: String,
}
impl Default for MlxProvider {
    /// Build a provider from `MLX_LM_HOST` (must be a full http(s) URL),
    /// defaulting to the local server address.
    fn default() -> Self {
        let from_env = std::env::var("MLX_LM_HOST").ok().and_then(|url| {
            let is_http = url.starts_with("http://") || url.starts_with("https://");
            if is_http {
                return Some(url);
            }
            // Unlike OLLAMA_HOST, a bare host:port is rejected here.
            eprintln!(
                "Warning: MLX_LM_HOST must start with http:// or https://, ignoring: {}",
                url
            );
            None
        });
        Self {
            server_url: from_env.unwrap_or_else(|| "http://localhost:8080".to_string()),
        }
    }
}
impl MlxProvider {
    /// Convenience alias for `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }
    /// Detect MLX availability and installed models.
    ///
    /// Always seeds the model set from the local HF cache scan. Non-macOS
    /// reports unavailable. On macOS, a responding server marks the backend
    /// available and contributes its model list; with no server, availability
    /// falls back to the python `import mlx_lm` check.
    pub fn detect_with_installed(&self) -> (bool, HashSet<String>) {
        let mut set = scan_hf_cache_for_mlx();
        if !cfg!(target_os = "macos") {
            return (false, set);
        }
        let url = format!("{}/v1/models", self.server_url.trim_end_matches('/'));
        if let Ok(resp) = ureq::get(&url)
            .config()
            .timeout_global(Some(std::time::Duration::from_millis(800)))
            .build()
            .call()
        {
            if let Ok(json) = resp.into_body().read_json::<serde_json::Value>()
                && let Some(data) = json.get("data").and_then(|d| d.as_array())
            {
                for model in data {
                    if let Some(id) = model.get("id").and_then(|i| i.as_str()) {
                        set.insert(id.to_lowercase());
                    }
                }
            }
            return (true, set);
        }
        (check_mlx_python(), set)
    }
}
// Memoized result of the python `import mlx_lm` probe — the subprocess
// spawn is expensive, so it runs at most once per process.
static MLX_PYTHON_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
/// True when `python3 -c "import mlx_lm"` exits successfully; a missing
/// python3 (or any spawn failure) reads as false. Result is cached in
/// `MLX_PYTHON_AVAILABLE`.
fn check_mlx_python() -> bool {
    *MLX_PYTHON_AVAILABLE.get_or_init(|| {
        std::process::Command::new("python3")
            .args(["-c", "import mlx_lm"])
            .stdout(std::process::Stdio::null())
            .stderr(std::process::Stdio::null())
            .status()
            .map(|s| s.success())
            .unwrap_or(false)
    })
}
/// Heuristic: does `owner/repo` look like an MLX-format model repository?
///
/// True for anything under the `mlx-community` org, or any repo name
/// containing "mlx-" or ending in "mlx" (case-insensitive). The broader
/// "mlx-" / "mlx" checks subsume the original "-mlx-" / "-mlx" variants.
fn is_likely_mlx_repo(owner: &str, repo: &str) -> bool {
    if owner.to_lowercase() == "mlx-community" {
        return true;
    }
    let repo = repo.to_lowercase();
    repo.contains("mlx-") || repo.ends_with("mlx")
}
fn scan_hf_cache_for_mlx() -> HashSet<String> {
let mut set = HashSet::new();
let cache_dir = dirs_hf_cache();
let Ok(entries) = std::fs::read_dir(&cache_dir) else {
return set;
};
for entry in entries.flatten() {
let name = entry.file_name();
let name_str = name.to_string_lossy();
let Some(rest) = name_str.strip_prefix("models--") else {
continue;
};
let mut parts = rest.splitn(2, "--");
let Some(owner) = parts.next() else {
continue;
};
let Some(repo) = parts.next() else {
continue;
};
if !is_likely_mlx_repo(owner, repo) {
continue;
}
let owner_lower = owner.to_lowercase();
let repo_lower = repo.to_lowercase();
set.insert(format!("{}/{}", owner_lower, repo_lower));
set.insert(repo_lower);
}
set
}
/// Location of the Hugging Face hub cache: `$HF_HOME/hub` when set,
/// otherwise `$HOME/.cache/huggingface/hub`, with a /tmp fallback when
/// even `HOME` is missing.
fn dirs_hf_cache() -> std::path::PathBuf {
    match (std::env::var("HF_HOME"), std::env::var("HOME")) {
        (Ok(hf_home), _) => std::path::PathBuf::from(hf_home).join("hub"),
        (_, Ok(home)) => [home.as_str(), ".cache", "huggingface", "hub"]
            .iter()
            .collect(),
        _ => std::path::PathBuf::from("/tmp/.cache/huggingface/hub"),
    }
}
impl ModelProvider for MlxProvider {
    fn name(&self) -> &str {
        "MLX"
    }
    /// MLX is macOS-only; there, it is available when either a server
    /// answers `/v1/models` or python can `import mlx_lm`.
    fn is_available(&self) -> bool {
        if !cfg!(target_os = "macos") {
            return false;
        }
        let url = format!("{}/v1/models", self.server_url.trim_end_matches('/'));
        if ureq::get(&url)
            .config()
            .timeout_global(Some(std::time::Duration::from_secs(2)))
            .build()
            .call()
            .is_ok()
        {
            return true;
        }
        check_mlx_python()
    }
    /// HF-cache scan plus (on macOS) whatever a running server reports.
    fn installed_models(&self) -> HashSet<String> {
        let mut set = scan_hf_cache_for_mlx();
        if !cfg!(target_os = "macos") {
            return set;
        }
        let url = format!("{}/v1/models", self.server_url.trim_end_matches('/'));
        if let Ok(resp) = ureq::get(&url)
            .config()
            .timeout_global(Some(std::time::Duration::from_secs(2)))
            .build()
            .call()
            && let Ok(json) = resp.into_body().read_json::<serde_json::Value>()
            && let Some(data) = json.get("data").and_then(|d| d.as_array())
        {
            for model in data {
                if let Some(id) = model.get("id").and_then(|i| i.as_str()) {
                    set.insert(id.to_lowercase());
                }
            }
        }
        set
    }
    /// Download a model snapshot with the `hf` CLI on a background thread.
    ///
    /// Bare names are assumed to live under `mlx-community/`. Fails fast
    /// when the `hf` binary cannot be found. The CLI's output is not
    /// parsed, so a single indeterminate Progress event precedes Done/Error.
    fn start_pull(&self, model_tag: &str) -> Result<PullHandle, String> {
        let repo_id = if model_tag.contains('/') {
            model_tag.to_string()
        } else {
            format!("mlx-community/{}", model_tag)
        };
        let repo_for_thread = repo_id.clone();
        let (tx, rx) = std::sync::mpsc::channel();
        let hf_bin = find_binary("hf").ok_or_else(|| {
            "hf not found in PATH. Install it with: uv tool install 'huggingface_hub[cli]'"
                .to_string()
        })?;
        std::thread::spawn(move || {
            let _ = tx.send(PullEvent::Progress {
                status: format!("Downloading {}...", repo_for_thread),
                percent: None,
            });
            let result = std::process::Command::new(&hf_bin)
                .args(["download", &repo_for_thread])
                .stdout(std::process::Stdio::piped())
                .stderr(std::process::Stdio::piped())
                .output();
            match result {
                Ok(output) if output.status.success() => {
                    let _ = tx.send(PullEvent::Done);
                }
                Ok(output) => {
                    let stderr = String::from_utf8_lossy(&output.stderr);
                    let _ = tx.send(PullEvent::Error(format!(
                        "hf download failed (exit {}): {}",
                        output.status.code().unwrap_or(-1),
                        stderr.trim()
                    )));
                }
                Err(e) => {
                    let _ = tx.send(PullEvent::Error(format!("failed to run hf: {e}")));
                }
            }
        });
        Ok(PullHandle {
            model_tag: repo_id,
            receiver: rx,
        })
    }
}
/// Runs GGUF models via locally installed llama.cpp tooling.
pub struct LlamaCppProvider {
    /// Directory where downloaded .gguf files are cached.
    models_dir: PathBuf,
    /// Resolved path to `llama-cli`, if found at construction.
    llama_cli: Option<String>,
    /// Resolved path to `llama-server`, if found at construction.
    llama_server: Option<String>,
    /// Whether a running llama-server answered the health probe
    /// (only checked when neither binary was found).
    server_running: bool,
}
impl Default for LlamaCppProvider {
    /// Resolve tooling at construction time: look up the llama.cpp binaries
    /// and, only when neither is found, probe for an already-running
    /// llama-server (port from `LLAMA_SERVER_PORT`, default 8080).
    fn default() -> Self {
        let models_dir = llamacpp_models_dir();
        let llama_cli = find_binary("llama-cli");
        let llama_server = find_binary("llama-server");
        let server_running = if llama_cli.is_none() && llama_server.is_none() {
            let port = std::env::var("LLAMA_SERVER_PORT").unwrap_or_else(|_| "8080".to_string());
            probe_llama_server(&format!("http://localhost:{}", port))
        } else {
            // Binaries take precedence; skip the probe's subprocess cost.
            false
        };
        Self {
            models_dir,
            llama_cli,
            llama_server,
            server_running,
        }
    }
}
impl LlamaCppProvider {
    /// Convenience alias for `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }
    /// List GGUF models present in the local models directory.
    ///
    /// Returns `(name_set, file_count)`; the set holds each file stem
    /// (lowercased) and, when the quant-suffix helper (defined elsewhere in
    /// this file) can strip a suffix, the base model name too.
    pub fn installed_models_counted(&self) -> (HashSet<String>, usize) {
        let mut set = HashSet::new();
        let mut count = 0usize;
        for path in self.list_gguf_files() {
            if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
                count += 1;
                let lower = stem.to_lowercase();
                set.insert(lower.clone());
                if let Some(base) = strip_gguf_quant_suffix(&lower) {
                    set.insert(base);
                }
            }
        }
        (set, count)
    }
    /// Directory scanned/written by this provider for .gguf files.
    pub fn models_dir(&self) -> &std::path::Path {
        &self.models_dir
    }
    /// Resolved `llama-cli` path, if one was found at construction.
    pub fn llama_cli_path(&self) -> Option<&str> {
        self.llama_cli.as_deref()
    }
    /// Resolved `llama-server` path, if one was found at construction.
    pub fn llama_server_path(&self) -> Option<&str> {
        self.llama_server.as_deref()
    }
    /// Whether a running server was detected at construction.
    pub fn server_running(&self) -> bool {
        self.server_running
    }
    /// Short status suffix for display; empty when a binary was found.
    pub fn detection_hint(&self) -> &'static str {
        if self.llama_cli.is_some() || self.llama_server.is_some() {
            ""
        } else if self.server_running {
            "server detected"
        } else {
            "not in PATH, set LLAMA_CPP_PATH"
        }
    }
pub fn list_gguf_files(&self) -> Vec<PathBuf> {
let mut files = Vec::new();
if let Ok(entries) = std::fs::read_dir(&self.models_dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.extension().and_then(|e| e.to_str()) == Some("gguf") {
files.push(path);
}
}
}
files
}
    /// Search the Hugging Face model API for GGUF repos matching `query`.
    ///
    /// Returns up to 20 trending `(repo_id, pipeline_tag)` pairs; network or
    /// parse failures collapse to an empty list.
    pub fn search_hf_gguf(query: &str) -> Vec<(String, String)> {
        let url = format!(
            "https://huggingface.co/api/models?library=gguf&search={}&sort=trending&limit=20",
            urlencoding::encode(query)
        );
        let Ok(resp) = ureq::get(&url)
            .config()
            .timeout_global(Some(std::time::Duration::from_secs(15)))
            .build()
            .call()
        else {
            return Vec::new();
        };
        let Ok(models) = resp.into_body().read_json::<Vec<serde_json::Value>>() else {
            return Vec::new();
        };
        models
            .into_iter()
            .filter_map(|m| {
                let id = m.get("id")?.as_str()?.to_string();
                // pipeline_tag doubles as a short description; default "model".
                let desc = m
                    .get("pipeline_tag")
                    .and_then(|v| v.as_str())
                    .unwrap_or("model")
                    .to_string();
                Some((id, desc))
            })
            .collect()
    }
    /// List `(path, size)` of every GGUF file in a Hugging Face repo by
    /// walking its `main` tree recursively. Failures collapse to empty.
    pub fn list_repo_gguf_files(repo_id: &str) -> Vec<(String, u64)> {
        let url = format!(
            "https://huggingface.co/api/models/{}/tree/main?recursive=true",
            repo_id
        );
        let Ok(resp) = ureq::get(&url)
            .config()
            .timeout_global(Some(std::time::Duration::from_secs(15)))
            .build()
            .call()
        else {
            return Vec::new();
        };
        let Ok(entries) = resp.into_body().read_json::<Vec<serde_json::Value>>() else {
            return Vec::new();
        };
        parse_repo_gguf_entries(entries)
    }
pub fn select_best_gguf(files: &[(String, u64)], budget_gb: f64) -> Option<(String, u64)> {
let quant_order = [
"Q8_0", "q8_0", "Q6_K", "q6_k", "Q6_K_L", "q6_k_l", "Q5_K_M", "q5_k_m", "Q5_K_S",
"q5_k_s", "Q4_K_M", "q4_k_m", "Q4_K_S", "q4_k_s", "Q4_0", "q4_0", "Q3_K_M", "q3_k_m",
"Q3_K_S", "q3_k_s", "Q2_K", "q2_k", "IQ4_XS", "iq4_xs", "IQ3_M", "iq3_m", "IQ2_M",
"iq2_m", "IQ1_M", "iq1_m",
];
let budget_bytes = (budget_gb * 1024.0 * 1024.0 * 1024.0) as u64;
let candidates = build_gguf_candidates(files);
for quant in &quant_order {
for (filename, size) in &candidates {
if *size > 0 && *size <= budget_bytes && filename.contains(quant) {
return Some((filename.clone(), *size));
}
}
}
let mut fitting: Vec<_> = candidates
.iter()
.filter(|(_, s)| *s > 0 && *s <= budget_bytes)
.collect();
fitting.sort_by_key(|(_, s)| *s);
fitting.last().map(|(f, s)| (f.clone(), *s))
}
    /// Download `filename` from `repo_id`, expanding a sharded file
    /// (`...-00001-of-00003.gguf`) into its full shard set when the repo
    /// listing allows; otherwise downloads just the named file.
    pub fn download_gguf(&self, repo_id: &str, filename: &str) -> Result<PullHandle, String> {
        validate_gguf_repo_path(filename)?;
        let paths: Vec<String> = if parse_shard_info(filename).is_some() {
            let listing = Self::list_repo_gguf_files(repo_id);
            match collect_shard_set(&listing, filename) {
                Some(shards) if !shards.is_empty() => shards.into_iter().map(|(f, _)| f).collect(),
                // Listing failed or matched nothing: fall back to the one file.
                _ => vec![filename.to_string()],
            }
        } else {
            vec![filename.to_string()]
        };
        self.download_gguf_paths(repo_id, paths)
    }
    /// Download one or more repo paths into `models_dir` on a worker thread.
    ///
    /// Each path is re-validated, flattened to its basename, and checked
    /// (via canonicalization) to stay inside the models directory. Files
    /// stream into a `.gguf.part` temp file and are renamed into place only
    /// when complete; truncated or failed transfers are deleted.
    fn download_gguf_paths(&self, repo_id: &str, paths: Vec<String>) -> Result<PullHandle, String> {
        if paths.is_empty() {
            return Err("download_gguf_paths called with no paths".to_string());
        }
        let models_dir = self.models_dir.clone();
        // Resolve every (url, destination) pair up front so validation
        // errors surface before the worker thread starts.
        let mut jobs: Vec<(String, PathBuf)> = Vec::with_capacity(paths.len());
        for path in &paths {
            validate_gguf_repo_path(path)?;
            let local_filename = std::path::Path::new(path)
                .file_name()
                .and_then(|n| n.to_str())
                .ok_or_else(|| format!("Invalid filename in path: {}", path))?;
            validate_gguf_filename(local_filename)?;
            let dest_path = models_dir.join(local_filename);
            // Defense in depth: canonicalize both directories and reject a
            // destination that escapes the cache directory.
            if let (Ok(canonical_dir), Ok(canonical_dest)) = (
                std::fs::create_dir_all(&models_dir).and_then(|_| models_dir.canonicalize()),
                dest_path
                    .parent()
                    .ok_or_else(|| std::io::Error::other("no parent"))
                    .and_then(|p| {
                        std::fs::create_dir_all(p)?;
                        p.canonicalize()
                    }),
            ) && !canonical_dest.starts_with(&canonical_dir)
            {
                return Err(format!(
                    "Security: download path escapes cache directory: {}",
                    dest_path.display()
                ));
            }
            let url = format!("https://huggingface.co/{}/resolve/main/{}", repo_id, path);
            jobs.push((url, dest_path));
        }
        let tag = format!("{}/{}", repo_id, paths[0]);
        let total_parts = jobs.len();
        let (tx, rx) = std::sync::mpsc::channel();
        std::thread::spawn(move || {
            for (idx, (url, dest_path)) in jobs.into_iter().enumerate() {
                let part_num = idx + 1;
                // "[i/n] " prefix only when downloading a multi-part shard set.
                let part_label = if total_parts > 1 {
                    format!("[{}/{}] ", part_num, total_parts)
                } else {
                    String::new()
                };
                let display_name = dest_path
                    .file_name()
                    .and_then(|n| n.to_str())
                    .unwrap_or("")
                    .to_string();
                let _ = tx.send(PullEvent::Progress {
                    status: format!("{}Connecting to {}...", part_label, display_name),
                    percent: Some(0.0),
                });
                let resp = ureq::get(&url)
                    .config()
                    .timeout_global(Some(std::time::Duration::from_secs(7200)))
                    .build()
                    .call();
                let resp = match resp {
                    Ok(r) => r,
                    Err(e) => {
                        let _ = tx.send(PullEvent::Error(format!(
                            "{}Download failed: {}",
                            part_label, e
                        )));
                        return;
                    }
                };
                // content-length may be absent; 0 disables percent reporting
                // and the truncation check below.
                let total_size = resp
                    .headers()
                    .get("content-length")
                    .and_then(|v| v.to_str().ok())
                    .and_then(|s| s.parse::<u64>().ok())
                    .unwrap_or(0);
                let _ = tx.send(PullEvent::Progress {
                    status: format!(
                        "{}Downloading {} ({:.1} GB)...",
                        part_label,
                        display_name,
                        total_size as f64 / 1_073_741_824.0
                    ),
                    percent: Some(0.0),
                });
                // Stream into a temp file; rename to the final name only on
                // success so partial downloads never look installed.
                let tmp_path = dest_path.with_extension("gguf.part");
                let file = match std::fs::File::create(&tmp_path) {
                    Ok(f) => f,
                    Err(e) => {
                        let _ = tx.send(PullEvent::Error(format!("Failed to create file: {}", e)));
                        return;
                    }
                };
                let mut writer = std::io::BufWriter::new(file);
                let mut reader = resp.into_body().into_reader();
                let mut downloaded: u64 = 0;
                let mut buf = [0u8; 128 * 1024];
                // Throttle Progress events to roughly five per second.
                let mut last_report = std::time::Instant::now();
                loop {
                    match std::io::Read::read(&mut reader, &mut buf) {
                        Ok(0) => break,
                        Ok(n) => {
                            if let Err(e) = std::io::Write::write_all(&mut writer, &buf[..n]) {
                                let _ = tx.send(PullEvent::Error(format!("Write error: {}", e)));
                                let _ = std::fs::remove_file(&tmp_path);
                                return;
                            }
                            downloaded += n as u64;
                            if last_report.elapsed() >= std::time::Duration::from_millis(200) {
                                let pct = if total_size > 0 {
                                    downloaded as f64 / total_size as f64 * 100.0
                                } else {
                                    0.0
                                };
                                let dl_gb = downloaded as f64 / 1_073_741_824.0;
                                let total_gb = total_size as f64 / 1_073_741_824.0;
                                let _ = tx.send(PullEvent::Progress {
                                    status: format!(
                                        "{}Downloading {:.1}/{:.1} GB",
                                        part_label, dl_gb, total_gb
                                    ),
                                    percent: Some(pct),
                                });
                                last_report = std::time::Instant::now();
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(PullEvent::Error(format!("Download error: {}", e)));
                            let _ = std::fs::remove_file(&tmp_path);
                            return;
                        }
                    }
                }
                if let Err(e) = std::io::Write::flush(&mut writer) {
                    let _ = tx.send(PullEvent::Error(format!("Flush error: {}", e)));
                    let _ = std::fs::remove_file(&tmp_path);
                    return;
                }
                drop(writer);
                // Reject short reads when the expected size is known.
                if total_size > 0 && downloaded < total_size {
                    let _ = std::fs::remove_file(&tmp_path);
                    let _ = tx.send(PullEvent::Error(format!(
                        "{}Truncated download: got {} bytes, expected {}",
                        part_label, downloaded, total_size
                    )));
                    return;
                }
                if let Err(e) = std::fs::rename(&tmp_path, &dest_path) {
                    let _ = tx.send(PullEvent::Error(format!(
                        "Failed to finalize download: {}",
                        e
                    )));
                    let _ = std::fs::remove_file(&tmp_path);
                    return;
                }
                let _ = tx.send(PullEvent::Progress {
                    status: format!("{}Saved {}", part_label, display_name),
                    percent: Some(100.0),
                });
            }
            let _ = tx.send(PullEvent::Progress {
                status: "Download complete!".to_string(),
                percent: Some(100.0),
            });
            let _ = tx.send(PullEvent::Done);
        });
        Ok(PullHandle {
            model_tag: tag,
            receiver: rx,
        })
    }
}
/// Validate that `filename` is a safe, bare `*.gguf` basename.
///
/// Rejects empty names, anything containing a path separator, absolute
/// paths, names without the `.gguf` extension, and anything that is not a
/// plain basename. The error is a human-readable message.
fn validate_gguf_filename(filename: &str) -> Result<(), String> {
    let as_path = std::path::Path::new(filename);
    // First failing check wins; order matches severity of the message.
    let failure = if filename.is_empty() {
        Some("GGUF filename must not be empty".to_string())
    } else if filename.contains('/') || filename.contains('\\') {
        Some(format!(
            "Security: path separators not allowed in GGUF filename: {}",
            filename
        ))
    } else if as_path.is_absolute() {
        Some(format!(
            "Security: absolute paths not allowed in GGUF filename: {}",
            filename
        ))
    } else if !filename.ends_with(".gguf") {
        Some(format!("GGUF filename must end in .gguf, got: {}", filename))
    } else if as_path.file_name().and_then(|n| n.to_str()) != Some(filename) {
        Some(format!(
            "Security: GGUF filename must be a basename without path components: {}",
            filename
        ))
    } else {
        None
    };
    match failure {
        Some(msg) => Err(msg),
        None => Ok(()),
    }
}
}
/// Parse a shard filename of the form `<base>-<index>-of-<total>.gguf`.
///
/// Returns `Some((index, total))` only when both numbers are pure ASCII
/// digits, parse into `u32`, and `1 <= index <= total`. Splits on the LAST
/// `-of-` and the last `-` before it, so base names containing dashes work.
fn parse_shard_info(filename: &str) -> Option<(u32, u32)> {
    let stem = filename.strip_suffix(".gguf")?;
    let (head, total_str) = stem.rsplit_once("-of-")?;
    if total_str.is_empty() || !total_str.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    let total: u32 = total_str.parse().ok()?;
    let (_, index_str) = head.rsplit_once('-')?;
    if index_str.is_empty() || !index_str.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    let index: u32 = index_str.parse().ok()?;
    (1..=total).contains(&index).then_some((index, total))
}
/// Given one shard path (`...-NNNNN-of-MMMMM.gguf`), collect the full shard
/// set from `files`, sorted by shard index.
///
/// Matching is by shared prefix (up to and including the '-' before the
/// index) and shared `-of-MMMMM.gguf` suffix, with the same shard total.
/// Returns `None` when `path` is not a shard name or nothing matches.
/// NOTE: does not verify that all `MMMMM` indices are actually present.
pub fn collect_shard_set(files: &[(String, u64)], path: &str) -> Option<Vec<(String, u64)>> {
    let (_, total) = parse_shard_info(path)?;
    let stem = path.strip_suffix(".gguf")?;
    let of_pos = stem.rfind("-of-")?;
    let before = &stem[..of_pos];
    let dash_pos = before.rfind('-')?;
    // Byte offsets computed on `stem` remain valid in `path` (same prefix).
    let prefix = &path[..=dash_pos];
    let suffix = &path[of_pos..];
    let mut matches: Vec<(u32, String, u64)> = files
        .iter()
        .filter_map(|(f, s)| {
            if !f.starts_with(prefix) || !f.ends_with(suffix) {
                return None;
            }
            let (idx, t) = parse_shard_info(f)?;
            if t != total {
                return None;
            }
            Some((idx, f.clone(), *s))
        })
        .collect();
    matches.sort_by_key(|(i, _, _)| *i);
    if matches.is_empty() {
        return None;
    }
    Some(matches.into_iter().map(|(_, f, s)| (f, s)).collect())
}
/// Collapse a raw repo file listing into download candidates.
///
/// Non-sharded files pass through as-is. Sharded files are grouped (one
/// group per prefix/suffix pair), each group contributing a single
/// candidate: its first shard's name with the summed size of all shards.
fn build_gguf_candidates(files: &[(String, u64)]) -> Vec<(String, u64)> {
    let mut seen_groups: HashSet<String> = HashSet::new();
    let mut out: Vec<(String, u64)> = Vec::new();
    for (f, s) in files {
        if parse_shard_info(f).is_some() {
            let Some(stem) = f.strip_suffix(".gguf") else {
                continue;
            };
            let Some(of_pos) = stem.rfind("-of-") else {
                continue;
            };
            let before = &stem[..of_pos];
            let Some(dash_pos) = before.rfind('-') else {
                continue;
            };
            // Group key: shared prefix + shared "-of-N.gguf" suffix.
            let key = format!("{}|{}", &f[..=dash_pos], &f[of_pos..]);
            if !seen_groups.insert(key) {
                continue;
            }
            if let Some(shards) = collect_shard_set(files, f) {
                let total: u64 = shards.iter().map(|(_, sz)| *sz).sum();
                let rep = shards[0].0.clone();
                out.push((rep, total));
            }
        } else {
            out.push((f.clone(), *s));
        }
    }
    out
}
/// Validate a repo-relative GGUF path (subdirectories via `/` are allowed).
///
/// Rejects empty paths, `.`/`..` components, backslashes, absolute paths,
/// and paths without a `.gguf` extension.
fn validate_gguf_repo_path(path: &str) -> Result<(), String> {
    if path.is_empty() {
        return Err("GGUF path must not be empty".to_string());
    }
    if path.split('/').any(|c| c == ".." || c == ".") {
        return Err(format!(
            "Security: path traversal not allowed in GGUF path: {}",
            path
        ));
    }
    if path.contains('\\') {
        return Err(format!(
            "Security: backslash not allowed in GGUF path: {}",
            path
        ));
    }
    if path.starts_with('/') {
        return Err(format!(
            "Security: absolute paths not allowed in GGUF path: {}",
            path
        ));
    }
    if path.ends_with(".gguf") {
        Ok(())
    } else {
        Err(format!("GGUF path must end in .gguf, got: {}", path))
    }
}
/// Extract `(path, size)` pairs from a HF tree-API response, keeping only
/// entries whose path passes `validate_gguf_repo_path`. Missing sizes
/// default to 0.
fn parse_repo_gguf_entries(entries: Vec<serde_json::Value>) -> Vec<(String, u64)> {
    let mut out = Vec::new();
    for entry in entries {
        let Some(path) = entry.get("path").and_then(|p| p.as_str()) else {
            continue;
        };
        if validate_gguf_repo_path(path).is_ok() {
            let size = entry.get("size").and_then(|v| v.as_u64()).unwrap_or(0);
            out.push((path.to_string(), size));
        }
    }
    out
}
/// Directory where downloaded GGUF files are stored:
/// `$LLMFIT_MODELS_DIR` when set, else `$HOME/.cache/llmfit/models`,
/// else a /tmp fallback when `HOME` is also missing.
fn llamacpp_models_dir() -> PathBuf {
    if let Ok(explicit) = std::env::var("LLMFIT_MODELS_DIR") {
        return PathBuf::from(explicit);
    }
    match std::env::var("HOME") {
        Ok(home) => [home.as_str(), ".cache", "llmfit", "models"].iter().collect(),
        Err(_) => PathBuf::from("/tmp/.cache/llmfit/models"),
    }
}
/// Locate an executable by name.
///
/// Search order: `$LLAMA_CPP_PATH`, then a few well-known install dirs
/// (`/usr/local/bin`, `/opt/llama.cpp/build/bin`, `~/.local/bin`), then
/// the system `PATH` via the `which` crate.
fn find_binary(name: &str) -> Option<String> {
    let mut search_dirs: Vec<PathBuf> = Vec::new();
    if let Ok(dir) = std::env::var("LLAMA_CPP_PATH") {
        search_dirs.push(PathBuf::from(dir));
    }
    search_dirs.push(PathBuf::from("/usr/local/bin"));
    search_dirs.push(PathBuf::from("/opt/llama.cpp/build/bin"));
    if let Ok(home) = std::env::var("HOME") {
        search_dirs.push(PathBuf::from(home).join(".local").join("bin"));
    }
    for dir in search_dirs {
        let candidate = dir.join(name);
        if candidate.is_file() {
            return Some(candidate.to_string_lossy().to_string());
        }
    }
    // Fall back to a PATH lookup.
    which::which(name)
        .ok()
        .map(|p| p.to_string_lossy().to_string())
}
/// Best-effort health check of a llama-server instance using `curl`
/// (avoids pulling an HTTP client into this path). Any failure — curl
/// missing, timeout, non-success HTTP status — reads as "not running".
fn probe_llama_server(base_url: &str) -> bool {
    let health_url = format!("{}/health", base_url.trim_end_matches('/'));
    let status = std::process::Command::new("curl")
        .args(["-sf", "--max-time", "2", &health_url])
        .stdout(std::process::Stdio::null())
        .stderr(std::process::Stdio::null())
        .status();
    matches!(status, Ok(s) if s.success())
}
/// Minimal percent-encoding helper (avoids an external crate).
mod urlencoding {
    /// Percent-encode `s`, leaving only RFC 3986 unreserved characters
    /// (ALPHA / DIGIT / "-" / "_" / "." / "~") unescaped.
    pub fn encode(s: &str) -> String {
        let mut out = String::with_capacity(s.len() * 3);
        for byte in s.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
                out.push(byte as char);
            } else {
                out.push('%');
                out.push_str(&format!("{:02X}", byte));
            }
        }
        out
    }
}
impl ModelProvider for LlamaCppProvider {
    fn name(&self) -> &str {
        "llama.cpp"
    }
    /// Available when a llama binary was found or a running server was probed.
    fn is_available(&self) -> bool {
        self.llama_cli.is_some() || self.llama_server.is_some() || self.server_running
    }
    fn installed_models(&self) -> HashSet<String> {
        let (set, _) = self.installed_models_counted();
        set
    }
    /// Resolve `model_tag` to a concrete GGUF download.
    ///
    /// Accepted forms, tried in order:
    /// 1. `owner/repo/path/to/file.gguf` — download that exact file;
    /// 2. `owner/repo` — pick the best GGUF in the repo;
    /// 3. bare query — search HF and use the top result's best GGUF.
    /// The 999 GB budget effectively means "largest preferred quant".
    fn start_pull(&self, model_tag: &str) -> Result<PullHandle, String> {
        if model_tag.matches('/').count() >= 2 && model_tag.ends_with(".gguf") {
            let parts: Vec<&str> = model_tag.splitn(3, '/').collect();
            if parts.len() == 3 {
                let repo = format!("{}/{}", parts[0], parts[1]);
                let filename = parts[2];
                return self.download_gguf(&repo, filename);
            }
        }
        if model_tag.contains('/') {
            let files = Self::list_repo_gguf_files(model_tag);
            if files.is_empty() {
                return Err(format!("No GGUF files found in repository '{}'", model_tag));
            }
            if let Some((filename, _)) = Self::select_best_gguf(&files, 999.0) {
                return self.download_gguf(model_tag, &filename);
            }
            // No candidate selected: fall back to the first listed file.
            let (filename, _) = &files[0];
            return self.download_gguf(model_tag, filename);
        }
        let results = Self::search_hf_gguf(model_tag);
        if results.is_empty() {
            return Err(format!(
                "No GGUF models found on HuggingFace for '{}'",
                model_tag
            ));
        }
        let (repo_id, _) = &results[0];
        let files = Self::list_repo_gguf_files(repo_id);
        if files.is_empty() {
            return Err(format!("No GGUF files found in repository '{}'", repo_id));
        }
        if let Some((filename, _)) = Self::select_best_gguf(&files, 999.0) {
            return self.download_gguf(repo_id, &filename);
        }
        let (filename, _) = &files[0];
        self.download_gguf(repo_id, filename)
    }
}
/// Talks to Docker Model Runner's OpenAI-compatible HTTP API.
pub struct DockerModelRunnerProvider {
    /// Base URL, e.g. "http://localhost:12434".
    base_url: String,
}
/// Normalize a `DOCKER_MODEL_RUNNER_HOST` value into a base URL.
///
/// Bare `host:port` gains `http://`; explicit http(s) URLs pass through;
/// empty input and any other scheme are rejected.
fn normalize_docker_mr_host(raw: &str) -> Option<String> {
    let host = raw.trim();
    if host.is_empty() {
        return None;
    }
    let already_url = host.starts_with("http://") || host.starts_with("https://");
    if already_url {
        Some(host.to_string())
    } else if host.contains("://") {
        // Unsupported scheme.
        None
    } else {
        Some(format!("http://{host}"))
    }
}
impl Default for DockerModelRunnerProvider {
    /// Build from `DOCKER_MODEL_RUNNER_HOST`, defaulting to the standard
    /// local address when the variable is unset or unparseable.
    fn default() -> Self {
        let base_url = std::env::var("DOCKER_MODEL_RUNNER_HOST")
            .ok()
            .and_then(|raw| {
                let normalized = normalize_docker_mr_host(&raw);
                if normalized.is_none() {
                    // Fall through to the default, but explain why.
                    eprintln!(
                        "Warning: could not parse DOCKER_MODEL_RUNNER_HOST='{}'. \
                         Expected host:port or http(s)://host:port",
                        raw
                    );
                }
                normalized
            })
            .unwrap_or_else(|| "http://localhost:12434".to_string());
        Self { base_url }
    }
}
impl DockerModelRunnerProvider {
    /// Convenience alias for `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }
    /// OpenAI-compatible model listing endpoint.
    fn models_url(&self) -> String {
        format!("{}/v1/models", self.base_url.trim_end_matches('/'))
    }
    /// Quick probe (800ms timeout) of the runner plus installed models.
    ///
    /// Returns `(reachable, names, count)`. For each model id the set gets
    /// the full id (lowercased), its last '/'-segment when that differs,
    /// and the portion before the first ':'.
    pub fn detect_with_installed(&self) -> (bool, HashSet<String>, usize) {
        let mut set = HashSet::new();
        let Ok(resp) = ureq::get(&self.models_url())
            .config()
            .timeout_global(Some(std::time::Duration::from_millis(800)))
            .build()
            .call()
        else {
            return (false, set, 0);
        };
        let Ok(list) = resp.into_body().read_json::<DockerModelList>() else {
            return (true, set, 0);
        };
        let engines = list.data;
        let count = engines.len();
        for e in engines {
            let lower = e.id.to_lowercase();
            set.insert(lower.clone());
            if let Some(name) = lower.split('/').next_back()
                && name != lower
            {
                set.insert(name.to_string());
            }
            if let Some(base) = lower.split(':').next() {
                set.insert(base.to_string());
            }
        }
        (true, set, count)
    }
    /// Model set and count, reusing the detect probe.
    pub fn installed_models_counted(&self) -> (HashSet<String>, usize) {
        let (_, set, count) = self.detect_with_installed();
        (set, count)
    }
}
/// Response shape of Docker Model Runner's `GET /v1/models`.
#[derive(serde::Deserialize)]
struct DockerModelList {
    data: Vec<DockerEngine>,
}
/// Single model entry in the `/v1/models` list.
#[derive(serde::Deserialize)]
struct DockerEngine {
    id: String,
}
impl ModelProvider for DockerModelRunnerProvider {
    fn name(&self) -> &str {
        "Docker Model Runner"
    }
    /// Reachability check against `/v1/models` with a 2s timeout.
    fn is_available(&self) -> bool {
        ureq::get(&self.models_url())
            .config()
            .timeout_global(Some(std::time::Duration::from_secs(2)))
            .build()
            .call()
            .is_ok()
    }
    fn installed_models(&self) -> HashSet<String> {
        let (set, _) = self.installed_models_counted();
        set
    }
    /// Shell out to `docker model pull` on a background thread.
    ///
    /// The CLI's progress output is not parsed; a single indeterminate
    /// Progress event is followed by Done or Error.
    fn start_pull(&self, model_tag: &str) -> Result<PullHandle, String> {
        let tag = model_tag.to_string();
        let (tx, rx) = std::sync::mpsc::channel();
        std::thread::spawn(move || {
            let _ = tx.send(PullEvent::Progress {
                status: format!("Pulling {} via docker model pull...", tag),
                percent: None,
            });
            let result = std::process::Command::new("docker")
                .args(["model", "pull", &tag])
                .stdout(std::process::Stdio::piped())
                .stderr(std::process::Stdio::piped())
                .output();
            match result {
                Ok(output) if output.status.success() => {
                    let _ = tx.send(PullEvent::Done);
                }
                Ok(output) => {
                    let stderr = String::from_utf8_lossy(&output.stderr);
                    let _ = tx.send(PullEvent::Error(format!(
                        "docker model pull failed: {}",
                        stderr.trim()
                    )));
                }
                Err(e) => {
                    let _ = tx.send(PullEvent::Error(format!("Failed to run docker: {e}")));
                }
            }
        });
        Ok(PullHandle {
            model_tag: model_tag.to_string(),
            receiver: rx,
        })
    }
}
/// Talks to a local LM Studio server over its HTTP APIs.
pub struct LmStudioProvider {
    /// Base URL, e.g. "http://127.0.0.1:1234".
    base_url: String,
}
/// Normalize an `LMSTUDIO_HOST` value into a base URL.
///
/// Bare `host:port` gains `http://`; explicit http(s) URLs pass through;
/// empty input and any other scheme are rejected.
fn normalize_lmstudio_host(raw: &str) -> Option<String> {
    let host = raw.trim();
    let has_http_scheme = host.starts_with("http://") || host.starts_with("https://");
    if has_http_scheme {
        return Some(host.to_string());
    }
    // Empty input and non-http schemes are unusable.
    if host.is_empty() || host.contains("://") {
        return None;
    }
    Some(format!("http://{host}"))
}
impl Default for LmStudioProvider {
    /// Build from `LMSTUDIO_HOST`, defaulting to LM Studio's standard local
    /// server address when the variable is unset or unparseable.
    fn default() -> Self {
        let base_url = std::env::var("LMSTUDIO_HOST")
            .ok()
            .and_then(|raw| {
                let normalized = normalize_lmstudio_host(&raw);
                if normalized.is_none() {
                    // Fall through to the default, but explain why.
                    eprintln!(
                        "Warning: could not parse LMSTUDIO_HOST='{}'. \
                         Expected host:port or http(s)://host:port",
                        raw
                    );
                }
                normalized
            })
            .unwrap_or_else(|| "http://127.0.0.1:1234".to_string());
        Self { base_url }
    }
}
impl LmStudioProvider {
    /// Convenience alias for `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }
    /// OpenAI-compatible model listing endpoint.
    fn models_url(&self) -> String {
        format!("{}/v1/models", self.base_url.trim_end_matches('/'))
    }
    /// LM Studio's native download endpoint.
    fn download_url(&self) -> String {
        format!(
            "{}/api/v1/models/download",
            self.base_url.trim_end_matches('/')
        )
    }
    /// LM Studio's native download progress endpoint.
    fn download_status_url(&self) -> String {
        format!(
            "{}/api/v1/models/download-status",
            self.base_url.trim_end_matches('/')
        )
    }
    /// Quick probe (800ms timeout) of LM Studio plus its model list.
    ///
    /// Returns `(reachable, names, count)`. For each model id the set gets
    /// the full id (lowercased) plus its last '/'-segment when that differs.
    pub fn detect_with_installed(&self) -> (bool, HashSet<String>, usize) {
        let mut set = HashSet::new();
        let Ok(resp) = ureq::get(&self.models_url())
            .config()
            .timeout_global(Some(std::time::Duration::from_millis(800)))
            .build()
            .call()
        else {
            return (false, set, 0);
        };
        let Ok(list) = resp.into_body().read_json::<LmStudioModelList>() else {
            return (true, set, 0);
        };
        let models = list.data;
        let count = models.len();
        for m in models {
            let lower = m.id.to_lowercase();
            set.insert(lower.clone());
            if let Some(name) = lower.split('/').next_back()
                && name != lower
            {
                set.insert(name.to_string());
            }
        }
        (true, set, count)
    }
    /// Model set and count, reusing the detect probe.
    pub fn installed_models_counted(&self) -> (HashSet<String>, usize) {
        let (_, set, count) = self.detect_with_installed();
        (set, count)
    }
}
/// Response shape of LM Studio's `GET /v1/models`.
#[derive(serde::Deserialize)]
struct LmStudioModelList {
    data: Vec<LmStudioModel>,
}
/// Single model entry in the `/v1/models` list.
#[derive(serde::Deserialize)]
struct LmStudioModel {
    id: String,
}
/// Response of LM Studio's `POST /api/v1/models/download`.
#[derive(serde::Deserialize)]
struct LmStudioDownloadResponse {
    /// Identifier of the queued job (parsed but currently unused).
    #[serde(default)]
    #[allow(dead_code)]
    job_id: Option<String>,
    /// e.g. "already_downloaded" or "failed"; other values mean in progress.
    #[serde(default)]
    status: String,
    /// Expected total size, when reported (currently unused).
    #[serde(default)]
    #[allow(dead_code)]
    total_size_bytes: Option<u64>,
}
/// One entry of LM Studio's download-status poll response.
#[derive(serde::Deserialize)]
struct LmStudioDownloadStatus {
    /// "downloading", "completed", or "failed".
    #[serde(default)]
    status: String,
    /// Fractional progress in [0, 1], when reported.
    #[serde(default)]
    progress: Option<f64>,
    /// Bytes transferred so far, used when `progress` is absent.
    #[serde(default)]
    downloaded_bytes: Option<u64>,
    /// Expected total bytes, used when `progress` is absent.
    #[serde(default)]
    total_size_bytes: Option<u64>,
}
/// `ModelProvider` implementation backed by a local LM Studio server.
impl ModelProvider for LmStudioProvider {
    fn name(&self) -> &str {
        "LM Studio"
    }
    /// Reachability probe: any response from the models endpoint within
    /// 2 seconds counts as "running".
    fn is_available(&self) -> bool {
        ureq::get(&self.models_url())
            .config()
            .timeout_global(Some(std::time::Duration::from_secs(2)))
            .build()
            .call()
            .is_ok()
    }
    fn installed_models(&self) -> HashSet<String> {
        let (set, _) = self.installed_models_counted();
        set
    }
    /// Start a model download through LM Studio's REST API and stream
    /// progress back over a channel.
    ///
    /// A background thread POSTs the download request, then polls the
    /// status endpoint every 500 ms, translating responses into
    /// `PullEvent`s until a terminal status ("completed"/"failed")
    /// arrives.
    ///
    /// NOTE(review): poll errors and unparseable poll bodies are
    /// retried forever — there is no overall timeout if the server
    /// stops reporting a terminal status. Sends to a dropped receiver
    /// are ignored, so the thread keeps polling even after the caller
    /// goes away; confirm that is intended.
    fn start_pull(&self, model_tag: &str) -> Result<PullHandle, String> {
        let download_url = self.download_url();
        let status_url = self.download_status_url();
        let tag = model_tag.to_string();
        let (tx, rx) = std::sync::mpsc::channel();
        let body = serde_json::json!({
            "model": tag,
        });
        // All network work happens on a detached thread; progress flows
        // through the channel half stored in the returned PullHandle.
        std::thread::spawn(move || {
            let resp = ureq::post(&download_url)
                .config()
                .timeout_global(Some(std::time::Duration::from_secs(30)))
                .build()
                .send_json(&body);
            match resp {
                Ok(resp) => {
                    let Ok(dl_resp) = resp.into_body().read_json::<LmStudioDownloadResponse>()
                    else {
                        let _ = tx.send(PullEvent::Error(
                            "Failed to parse LM Studio download response".to_string(),
                        ));
                        return;
                    };
                    // Terminal statuses reported directly by the POST:
                    // nothing to poll for.
                    if dl_resp.status == "already_downloaded" {
                        let _ = tx.send(PullEvent::Progress {
                            status: "Already downloaded".to_string(),
                            percent: Some(100.0),
                        });
                        let _ = tx.send(PullEvent::Done);
                        return;
                    }
                    if dl_resp.status == "failed" {
                        let _ = tx.send(PullEvent::Error("LM Studio download failed".to_string()));
                        return;
                    }
                    let _ = tx.send(PullEvent::Progress {
                        status: format!("Downloading via LM Studio ({})", dl_resp.status),
                        percent: Some(0.0),
                    });
                    // Poll until a terminal status arrives.
                    loop {
                        std::thread::sleep(std::time::Duration::from_millis(500));
                        let poll = ureq::get(&status_url)
                            .config()
                            .timeout_global(Some(std::time::Duration::from_secs(10)))
                            .build()
                            .call();
                        match poll {
                            Ok(resp) => {
                                let body_str = match resp.into_body().read_to_string() {
                                    Ok(s) => s,
                                    Err(_) => continue,
                                };
                                // The endpoint may return either a JSON
                                // array of job records or a single
                                // object; accept both shapes. From an
                                // array, pick the first record in a
                                // recognized state.
                                let status_opt: Option<LmStudioDownloadStatus> =
                                    if let Ok(statuses) =
                                        serde_json::from_str::<Vec<LmStudioDownloadStatus>>(
                                            &body_str,
                                        )
                                    {
                                        statuses.into_iter().find(|s| {
                                            s.status == "downloading"
                                                || s.status == "completed"
                                                || s.status == "failed"
                                        })
                                    } else {
                                        serde_json::from_str(&body_str).ok()
                                    };
                                let Some(st) = status_opt else {
                                    continue;
                                };
                                // Prefer the server's fractional
                                // progress (assumed 0.0..=1.0 — TODO
                                // confirm); fall back to a byte-count
                                // ratio when both counts are present.
                                let percent = st.progress.map(|p| p * 100.0).or_else(|| {
                                    match (st.downloaded_bytes, st.total_size_bytes) {
                                        (Some(dl), Some(total)) if total > 0 => {
                                            Some(dl as f64 / total as f64 * 100.0)
                                        }
                                        _ => None,
                                    }
                                });
                                if st.status == "completed" {
                                    let _ = tx.send(PullEvent::Progress {
                                        status: "Download complete".to_string(),
                                        percent: Some(100.0),
                                    });
                                    let _ = tx.send(PullEvent::Done);
                                    return;
                                }
                                if st.status == "failed" {
                                    let _ = tx.send(PullEvent::Error(
                                        "LM Studio download failed".to_string(),
                                    ));
                                    return;
                                }
                                let _ = tx.send(PullEvent::Progress {
                                    status: "Downloading via LM Studio...".to_string(),
                                    percent,
                                });
                            }
                            Err(_) => {
                                continue;
                            }
                        }
                    }
                }
                Err(e) => {
                    let _ = tx.send(PullEvent::Error(format!("LM Studio download error: {e}")));
                }
            }
        });
        Ok(PullHandle {
            model_tag: model_tag.to_string(),
            receiver: rx,
        })
    }
}
/// Build lowercase search keys for matching a Hugging Face model name
/// against LM Studio's installed-model list.
///
/// Produces, in order: the full lowercased name, the bare repo name
/// (when the input carried an owner prefix), and the repo name with the
/// `-instruct`/`-chat`/`-hf`/`-it` markers removed.
pub fn hf_name_to_lmstudio_candidates(hf_name: &str) -> Vec<String> {
    let full = hf_name.to_lowercase();
    let repo = hf_name
        .split('/')
        .next_back()
        .unwrap_or(hf_name)
        .to_lowercase();
    let mut out = vec![full.clone()];
    if repo != full {
        out.push(repo.clone());
    }
    // NOTE: `replace` removes the markers anywhere in the string, not
    // only at the end.
    let bare = repo
        .replace("-instruct", "")
        .replace("-chat", "")
        .replace("-hf", "")
        .replace("-it", "");
    if bare != repo {
        out.push(bare);
    }
    out
}
/// True when any LM Studio candidate key for `hf_name` appears as a
/// substring of any installed model id.
pub fn is_model_installed_lmstudio(hf_name: &str, installed: &HashSet<String>) -> bool {
    hf_name_to_lmstudio_candidates(hf_name)
        .iter()
        .any(|needle| installed.iter().any(|id| id.contains(needle)))
}
/// LM Studio can attempt a pull for any non-empty Hugging Face name,
/// so every non-empty name counts as "mapped".
pub fn has_lmstudio_mapping(hf_name: &str) -> bool {
    match hf_name {
        "" => false,
        _ => true,
    }
}
/// Tag to pass to LM Studio's download API: the Hugging Face name
/// itself, or `None` for an empty name.
pub fn lmstudio_pull_tag(hf_name: &str) -> Option<String> {
    (!hf_name.is_empty()).then(|| hf_name.to_string())
}
/// Bundled JSON catalog mapping Hugging Face names to Docker Model
/// Runner tags; parsed lazily by `docker_mr_catalog`.
const DOCKER_MODELS_JSON: &str = include_str!("../data/docker_models.json");
/// Top-level shape of the bundled Docker Model Runner catalog file.
#[derive(serde::Deserialize)]
struct DockerModelCatalog {
    models: Vec<DockerModelEntry>,
}
/// One catalog entry: a Hugging Face model name and its corresponding
/// Docker Model Runner tag.
#[derive(serde::Deserialize)]
struct DockerModelEntry {
    hf_name: String,
    docker_tag: String,
}
/// Lazily parsed Docker Model Runner catalog as
/// `(lowercased hf name, docker tag)` pairs.
///
/// Parsed exactly once; a malformed bundled JSON yields an empty
/// catalog instead of a panic.
fn docker_mr_catalog() -> &'static [(String, String)] {
    use std::sync::OnceLock;
    static CATALOG: OnceLock<Vec<(String, String)>> = OnceLock::new();
    CATALOG.get_or_init(|| {
        match serde_json::from_str::<DockerModelCatalog>(DOCKER_MODELS_JSON) {
            Ok(catalog) => catalog
                .models
                .into_iter()
                .map(|entry| (entry.hf_name.to_lowercase(), entry.docker_tag))
                .collect(),
            Err(_) => Vec::new(),
        }
    })
}
pub fn has_docker_mr_mapping(hf_name: &str) -> bool {
docker_mr_pull_tag(hf_name).is_some()
}
/// Exact lookup of the Docker Model Runner tag for `hf_name`
/// (case-insensitive on the Hugging Face side).
pub fn docker_mr_pull_tag(hf_name: &str) -> Option<String> {
    let needle = hf_name.to_lowercase();
    docker_mr_catalog()
        .iter()
        .find_map(|(name, tag)| (*name == needle).then(|| tag.clone()))
}
/// Candidate installed-model names for a Docker Model Runner tag:
/// the tag itself, the tag without its `ai/` prefix, and the tag
/// without its `:version` part.
///
/// The last entry can duplicate the first when the tag has no colon;
/// callers only use `any`-style matching, so duplicates are harmless.
pub fn hf_name_to_docker_mr_candidates(hf_name: &str) -> Vec<String> {
    let Some(tag) = docker_mr_pull_tag(hf_name) else {
        return Vec::new();
    };
    let mut out = Vec::with_capacity(3);
    out.push(tag.clone());
    if let Some(unprefixed) = tag.strip_prefix("ai/") {
        out.push(unprefixed.to_string());
    }
    if let Some(repo) = tag.split(':').next() {
        out.push(repo.to_string());
    }
    out
}
/// True when any Docker Model Runner candidate for `hf_name` matches
/// any installed model name (see `docker_mr_installed_matches`).
pub fn is_model_installed_docker_mr(hf_name: &str, installed: &HashSet<String>) -> bool {
    hf_name_to_docker_mr_candidates(hf_name).iter().any(|cand| {
        installed
            .iter()
            .any(|name| docker_mr_installed_matches(name, cand))
    })
}
/// Match rule for Docker Model Runner names: exact equality always
/// counts, and a fully-qualified `name:version` candidate also matches
/// installed variants of the form `name:version-<suffix>`.
fn docker_mr_installed_matches(installed_name: &str, candidate: &str) -> bool {
    installed_name == candidate
        || (candidate.contains(':') && installed_name.starts_with(&format!("{candidate}-")))
}
/// Strip a recognized GGUF quantization marker (e.g. `-q4_k_m`,
/// `.q8_0`) from `stem`, returning everything before the marker's last
/// occurrence, or `None` when no known marker is present.
///
/// Markers are tried in declaration order; the first one found anywhere
/// in the string (via `rfind`) wins.
fn strip_gguf_quant_suffix(stem: &str) -> Option<String> {
    const QUANT_MARKERS: &[&str] = &[
        "-q8_0", "-q6_k", "-q6_k_l", "-q5_k_m", "-q5_k_s", "-q4_k_m", "-q4_k_s", "-q4_0",
        "-q3_k_m", "-q3_k_s", "-q2_k", "-iq4_xs", "-iq3_m", "-iq2_m", "-iq1_m", "-f16", "-f32",
        "-bf16", ".q8_0", ".q6_k", ".q5_k_m", ".q4_k_m", ".q4_0", ".q3_k_m", ".q2_k",
    ];
    QUANT_MARKERS
        .iter()
        .find_map(|marker| stem.rfind(marker).map(|pos| stem[..pos].to_string()))
}
const LLAMACPP_GGUF_MAPPINGS: &[(&str, &str)] = &[
(
"llama-3.3-70b-instruct",
"bartowski/Llama-3.3-70B-Instruct-GGUF",
),
(
"llama-3.2-3b-instruct",
"bartowski/Llama-3.2-3B-Instruct-GGUF",
),
(
"llama-3.2-1b-instruct",
"bartowski/Llama-3.2-1B-Instruct-GGUF",
),
(
"llama-3.1-8b-instruct",
"bartowski/Llama-3.1-8B-Instruct-GGUF",
),
(
"llama-3.1-70b-instruct",
"bartowski/Llama-3.1-70B-Instruct-GGUF",
),
(
"llama-3.1-405b-instruct",
"bartowski/Meta-Llama-3.1-405B-Instruct-GGUF",
),
(
"meta-llama-3-8b-instruct",
"bartowski/Meta-Llama-3-8B-Instruct-GGUF",
),
(
"qwen2.5-72b-instruct",
"bartowski/Qwen2.5-72B-Instruct-GGUF",
),
(
"qwen2.5-32b-instruct",
"bartowski/Qwen2.5-32B-Instruct-GGUF",
),
(
"qwen2.5-14b-instruct",
"bartowski/Qwen2.5-14B-Instruct-GGUF",
),
("qwen2.5-7b-instruct", "bartowski/Qwen2.5-7B-Instruct-GGUF"),
("qwen2.5-3b-instruct", "bartowski/Qwen2.5-3B-Instruct-GGUF"),
(
"qwen2.5-1.5b-instruct",
"bartowski/Qwen2.5-1.5B-Instruct-GGUF",
),
(
"qwen2.5-0.5b-instruct",
"bartowski/Qwen2.5-0.5B-Instruct-GGUF",
),
(
"qwen2.5-coder-32b-instruct",
"bartowski/Qwen2.5-Coder-32B-Instruct-GGUF",
),
(
"qwen2.5-coder-14b-instruct",
"bartowski/Qwen2.5-Coder-14B-Instruct-GGUF",
),
(
"qwen2.5-coder-7b-instruct",
"bartowski/Qwen2.5-Coder-7B-Instruct-GGUF",
),
("qwen3-32b", "bartowski/Qwen3-32B-GGUF"),
("qwen3-14b", "bartowski/Qwen3-14B-GGUF"),
("qwen3-8b", "bartowski/Qwen3-8B-GGUF"),
("qwen3-4b", "bartowski/Qwen3-4B-GGUF"),
("qwen3-0.6b", "bartowski/Qwen3-0.6B-GGUF"),
(
"mistral-7b-instruct-v0.3",
"bartowski/Mistral-7B-Instruct-v0.3-GGUF",
),
(
"mistral-small-24b-instruct-2501",
"bartowski/Mistral-Small-24B-Instruct-2501-GGUF",
),
(
"mixtral-8x7b-instruct-v0.1",
"bartowski/Mixtral-8x7B-Instruct-v0.1-GGUF",
),
("gemma-3-12b-it", "bartowski/gemma-3-12b-it-GGUF"),
("gemma-2-27b-it", "bartowski/gemma-2-27b-it-GGUF"),
("gemma-2-9b-it", "bartowski/gemma-2-9b-it-GGUF"),
("gemma-2-2b-it", "bartowski/gemma-2-2b-it-GGUF"),
("phi-4", "bartowski/phi-4-GGUF"),
("phi-4-mini-instruct", "bartowski/phi-4-mini-instruct-GGUF"),
(
"phi-3.5-mini-instruct",
"bartowski/Phi-3.5-mini-instruct-GGUF",
),
(
"phi-3-mini-4k-instruct",
"bartowski/Phi-3-mini-4k-instruct-GGUF",
),
("deepseek-r1", "bartowski/DeepSeek-R1-GGUF"),
(
"deepseek-r1-distill-qwen-32b",
"bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF",
),
(
"deepseek-r1-distill-qwen-14b",
"bartowski/DeepSeek-R1-Distill-Qwen-14B-GGUF",
),
(
"deepseek-r1-distill-qwen-7b",
"bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF",
),
("deepseek-v3", "bartowski/DeepSeek-V3-GGUF"),
(
"tinyllama-1.1b-chat-v1.0",
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
),
("falcon-7b-instruct", "TheBloke/falcon-7b-instruct-GGUF"),
(
"smollm2-135m-instruct",
"bartowski/SmolLM2-135M-Instruct-GGUF",
),
];
/// Exact lookup of a curated GGUF repo for `hf_name`, keyed on the
/// lowercased repo suffix (the part after the last `/`).
fn lookup_gguf_repo(hf_name: &str) -> Option<&'static str> {
    let repo = hf_name
        .split('/')
        .next_back()
        .unwrap_or(hf_name)
        .to_lowercase();
    LLAMACPP_GGUF_MAPPINGS
        .iter()
        .find_map(|&(suffix, gguf_repo)| (repo == suffix).then_some(gguf_repo))
}
/// GGUF repo candidates for `hf_name`: the curated mapping when known,
/// otherwise `<owner>/<repo>-GGUF` guesses for the common conversion
/// accounts.
pub fn hf_name_to_gguf_candidates(hf_name: &str) -> Vec<String> {
    if let Some(known) = lookup_gguf_repo(hf_name) {
        return vec![known.to_string()];
    }
    let repo = hf_name.split('/').next_back().unwrap_or(hf_name);
    ["bartowski", "ggml-org", "TheBloke"]
        .iter()
        .map(|owner| format!("{owner}/{repo}-GGUF"))
        .collect()
}
pub fn has_gguf_mapping(hf_name: &str) -> bool {
lookup_gguf_repo(hf_name).is_some()
}
/// Loose substring match of `hf_name` against installed llama.cpp model
/// names: exact repo match, repo or suffix-stripped repo contained in an
/// installed name, or an installed name contained in the repo.
pub fn is_model_installed_llamacpp(hf_name: &str, installed: &HashSet<String>) -> bool {
    let repo = hf_name
        .split('/')
        .next_back()
        .unwrap_or(hf_name)
        .to_lowercase();
    // Variant markers are removed anywhere in the string, mirroring the
    // candidate generation used for other providers.
    let bare = repo
        .replace("-instruct", "")
        .replace("-chat", "")
        .replace("-hf", "")
        .replace("-it", "");
    let loosely_matches =
        |name: &String| name.contains(&repo) || name.contains(&bare) || repo.contains(name.as_str());
    installed.contains(&repo) || installed.iter().any(loosely_matches)
}
/// Curated GGUF repo for `hf_name` as an owned `String`, if any.
pub fn gguf_pull_tag(hf_name: &str) -> Option<String> {
    lookup_gguf_repo(hf_name).map(String::from)
}
/// Quick existence probe against the Hugging Face model API.
///
/// Any successful HTTP response counts as "exists"; timeouts and HTTP
/// errors count as "does not exist".
pub fn hf_repo_exists(repo_id: &str) -> bool {
    let endpoint = format!("https://huggingface.co/api/models/{repo_id}");
    ureq::get(&endpoint)
        .config()
        .timeout_global(Some(std::time::Duration::from_millis(1200)))
        .build()
        .call()
        .is_ok()
}
pub fn first_existing_gguf_repo(hf_name: &str) -> Option<String> {
if let Some(repo) = gguf_pull_tag(hf_name)
&& hf_repo_exists(&repo)
{
return Some(repo);
}
let candidates = hf_name_to_gguf_candidates(hf_name);
candidates.into_iter().find(|repo| hf_repo_exists(repo))
}
/// Append `candidate` unless it is empty or already present, keeping
/// the candidate list ordered and duplicate-free.
fn push_unique_candidate(candidates: &mut Vec<String>, candidate: String) {
    if candidate.is_empty() || candidates.contains(&candidate) {
        return;
    }
    candidates.push(candidate);
}
/// Remove a single trailing MLX quantization suffix (`-4bit`, `-6bit`,
/// `-8bit`), returning the name unchanged when none applies.
fn strip_trailing_quant_suffix(name: &str) -> String {
    ["-4bit", "-6bit", "-8bit"]
        .iter()
        .find_map(|suffix| name.strip_suffix(suffix))
        .unwrap_or(name)
        .to_string()
}
/// Reduce a lowercased repo name to its MLX base: drop a trailing quant
/// suffix, then a trailing `-mlx` marker, then any leftover edge dashes.
fn normalize_mlx_repo_base(repo_lower: &str) -> String {
    let no_quant = strip_trailing_quant_suffix(repo_lower);
    match no_quant.strip_suffix("-mlx") {
        Some(base) => base.trim_matches('-').to_string(),
        None => no_quant.trim_matches('-').to_string(),
    }
}
/// Repeatedly strip trailing variant markers (`-instruct`, `-chat`,
/// `-hf`, `-it`, `-base`) and any dashes they leave behind, until no
/// marker matches.
fn strip_trailing_common_model_suffixes(name: &str) -> String {
    const MARKERS: [&str; 5] = ["-instruct", "-chat", "-hf", "-it", "-base"];
    let mut current = name.to_string();
    'strip: loop {
        for marker in MARKERS {
            if let Some(rest) = current.strip_suffix(marker) {
                current = rest.trim_end_matches('-').to_string();
                // A strip may expose another marker; start over.
                continue 'strip;
            }
        }
        return current;
    }
}
/// Lowercased `owner/repo` id when `hf_name` is a plain two-part repo
/// id that looks MLX-specific (per `is_likely_mlx_repo`); `None`
/// otherwise.
fn explicit_mlx_repo_id(hf_name: &str) -> Option<String> {
    // Accept exactly one '/': a plain "owner/repo" form.
    let (owner, repo) = hf_name.split_once('/')?;
    if repo.contains('/') {
        return None;
    }
    let owner = owner.trim();
    let repo = repo.trim();
    if owner.is_empty() || repo.is_empty() || !is_likely_mlx_repo(owner, repo) {
        return None;
    }
    Some(format!("{}/{}", owner.to_lowercase(), repo.to_lowercase()))
}
/// Generate candidate MLX model names (lowercased, duplicate-free) for
/// `hf_name`, in priority order, for matching against a local MLX cache.
///
/// Order of generation:
/// 1. An explicit owner-prefixed MLX repo id, plus its bare repo name.
/// 2. The raw lowercased repo name.
/// 3. For repos in the curated table: `-4bit`, `-8bit`, then the bare
///    base name — and generation stops there.
/// 4. Otherwise: quant and `-mlx` infix variants of the normalized
///    name, then the same variants with common trailing markers
///    (`-instruct`/`-chat`/`-hf`/`-it`/`-base`) removed.
pub fn hf_name_to_mlx_candidates(hf_name: &str) -> Vec<String> {
    let mut candidates = Vec::new();
    // An explicit "owner/repo" MLX id takes priority over everything
    // generated below.
    if let Some(repo_id) = explicit_mlx_repo_id(hf_name) {
        push_unique_candidate(&mut candidates, repo_id.clone());
        if let Some(repo_name) = repo_id.split('/').next_back() {
            push_unique_candidate(&mut candidates, repo_name.to_string());
        }
    }
    let repo = hf_name.split('/').next_back().unwrap_or(hf_name);
    let repo_lower = repo.to_lowercase();
    push_unique_candidate(&mut candidates, repo_lower.clone());
    // Strip quant ("-4bit"/"-6bit"/"-8bit") and "-mlx" suffixes so the
    // curated table below can match on the base name.
    let normalized_repo = normalize_mlx_repo_base(&repo_lower);
    // (HF repo suffix, MLX community base name). The pairs are
    // currently identical; kept as a mapping so the two sides can
    // diverge without reshaping the lookup.
    let mappings: &[(&str, &str)] = &[
        ("Llama-3.3-70B-Instruct", "Llama-3.3-70B-Instruct"),
        ("Llama-3.2-3B-Instruct", "Llama-3.2-3B-Instruct"),
        ("Llama-3.2-1B-Instruct", "Llama-3.2-1B-Instruct"),
        ("Llama-3.1-8B-Instruct", "Llama-3.1-8B-Instruct"),
        ("Llama-3.1-70B-Instruct", "Llama-3.1-70B-Instruct"),
        ("Qwen2.5-72B-Instruct", "Qwen2.5-72B-Instruct"),
        ("Qwen2.5-32B-Instruct", "Qwen2.5-32B-Instruct"),
        ("Qwen2.5-14B-Instruct", "Qwen2.5-14B-Instruct"),
        ("Qwen2.5-7B-Instruct", "Qwen2.5-7B-Instruct"),
        ("Qwen2.5-Coder-32B-Instruct", "Qwen2.5-Coder-32B-Instruct"),
        ("Qwen2.5-Coder-14B-Instruct", "Qwen2.5-Coder-14B-Instruct"),
        ("Qwen2.5-Coder-7B-Instruct", "Qwen2.5-Coder-7B-Instruct"),
        ("Qwen3-32B", "Qwen3-32B"),
        ("Qwen3-14B", "Qwen3-14B"),
        ("Qwen3-8B", "Qwen3-8B"),
        ("Qwen3-4B", "Qwen3-4B"),
        ("Qwen3-1.7B", "Qwen3-1.7B"),
        ("Qwen3-0.6B", "Qwen3-0.6B"),
        ("Qwen3-30B-A3B", "Qwen3-30B-A3B"),
        ("Qwen3-235B-A22B", "Qwen3-235B-A22B"),
        ("Qwen3.5-0.6B", "Qwen3.5-0.6B"),
        ("Qwen3.5-1.7B", "Qwen3.5-1.7B"),
        ("Qwen3.5-4B", "Qwen3.5-4B"),
        ("Qwen3.5-8B", "Qwen3.5-8B"),
        ("Qwen3.5-9B", "Qwen3.5-9B"),
        ("Qwen3.5-14B", "Qwen3.5-14B"),
        ("Qwen3.5-27B", "Qwen3.5-27B"),
        ("Qwen3.5-32B", "Qwen3.5-32B"),
        ("Qwen3.5-35B-A3B", "Qwen3.5-35B-A3B"),
        ("Qwen3.5-72B", "Qwen3.5-72B"),
        ("Qwen3.5-122B-A10B", "Qwen3.5-122B-A10B"),
        ("Qwen3.5-397B-A17B", "Qwen3.5-397B-A17B"),
        ("Mistral-7B-Instruct-v0.3", "Mistral-7B-Instruct-v0.3"),
        ("Mistral-Small-24B-Instruct-2501", "Mistral-Small-24B-Instruct-2501"),
        ("Mixtral-8x7B-Instruct-v0.1", "Mixtral-8x7B-Instruct-v0.1"),
        ("Mistral-Small-3.1-24B-Instruct-2503", "Mistral-Small-3.1-24B-Instruct-2503"),
        ("Ministral-8B-Instruct-2410", "Ministral-8B-Instruct-2410"),
        ("Mistral-Nemo-Instruct-2407", "Mistral-Nemo-Instruct-2407"),
        ("DeepSeek-R1-Distill-Qwen-32B", "DeepSeek-R1-Distill-Qwen-32B"),
        ("DeepSeek-R1-Distill-Qwen-7B", "DeepSeek-R1-Distill-Qwen-7B"),
        ("DeepSeek-R1-Distill-Qwen-14B", "DeepSeek-R1-Distill-Qwen-14B"),
        ("DeepSeek-R1-Distill-Llama-8B", "DeepSeek-R1-Distill-Llama-8B"),
        ("DeepSeek-R1-Distill-Llama-70B", "DeepSeek-R1-Distill-Llama-70B"),
        ("gemma-3-12b-it", "gemma-3-12b-it"),
        ("gemma-2-27b-it", "gemma-2-27b-it"),
        ("gemma-2-9b-it", "gemma-2-9b-it"),
        ("gemma-2-2b-it", "gemma-2-2b-it"),
        ("gemma-3-1b-it", "gemma-3-1b-it"),
        ("gemma-3-4b-it", "gemma-3-4b-it"),
        ("gemma-3-27b-it", "gemma-3-27b-it"),
        ("gemma-3n-E4B-it", "gemma-3n-E4B-it"),
        ("gemma-3n-E2B-it", "gemma-3n-E2B-it"),
        ("Phi-4", "Phi-4"),
        ("Phi-3.5-mini-instruct", "Phi-3.5-mini-instruct"),
        ("Phi-3-mini-4k-instruct", "Phi-3-mini-4k-instruct"),
        ("Phi-4-mini-instruct", "Phi-4-mini-instruct"),
        ("Phi-4-reasoning", "Phi-4-reasoning"),
        ("Phi-4-mini-reasoning", "Phi-4-mini-reasoning"),
        ("Llama-4-Scout-17B-16E-Instruct", "Llama-4-Scout-17B-16E-Instruct"),
        ("Llama-4-Maverick-17B-128E-Instruct", "Llama-4-Maverick-17B-128E-Instruct"),
    ];
    for &(hf_suffix, mlx_base) in mappings {
        let mapped_suffix = hf_suffix.to_lowercase();
        if repo_lower == mapped_suffix || normalized_repo == mapped_suffix {
            // Curated hit: prefer quantized community builds, then the
            // unquantized base name, and skip the generic fallbacks.
            let base_lower = mlx_base.to_lowercase();
            push_unique_candidate(&mut candidates, format!("{}-4bit", base_lower));
            push_unique_candidate(&mut candidates, format!("{}-8bit", base_lower));
            push_unique_candidate(&mut candidates, base_lower);
            return candidates;
        }
    }
    // Fallback: quant and "-mlx" infix variants of the normalized name...
    if !normalized_repo.is_empty() {
        push_unique_candidate(&mut candidates, format!("{}-4bit", normalized_repo));
        push_unique_candidate(&mut candidates, format!("{}-8bit", normalized_repo));
        push_unique_candidate(&mut candidates, format!("{}-mlx-4bit", normalized_repo));
        push_unique_candidate(&mut candidates, format!("{}-mlx-8bit", normalized_repo));
        push_unique_candidate(&mut candidates, normalized_repo.clone());
    }
    // ...and the same variants with trailing markers removed.
    let stripped = strip_trailing_common_model_suffixes(&normalized_repo);
    if !stripped.is_empty() && stripped != normalized_repo {
        push_unique_candidate(&mut candidates, format!("{}-4bit", stripped));
        push_unique_candidate(&mut candidates, format!("{}-8bit", stripped));
        push_unique_candidate(&mut candidates, format!("{}-mlx-4bit", stripped));
        push_unique_candidate(&mut candidates, format!("{}-mlx-8bit", stripped));
        push_unique_candidate(&mut candidates, stripped);
    }
    candidates
}
/// True when any MLX candidate name for `hf_name` is exactly present in
/// the installed set.
pub fn is_model_installed_mlx(hf_name: &str, installed: &HashSet<String>) -> bool {
    hf_name_to_mlx_candidates(hf_name)
        .into_iter()
        .any(|candidate| installed.contains(&candidate))
}
/// Tag to download for MLX: an explicit owner-prefixed repo id when the
/// input is one; otherwise the first `-4bit` candidate, the first
/// candidate of any kind, or (as a last resort) the lowercased repo
/// suffix of the input.
pub fn mlx_pull_tag(hf_name: &str) -> String {
    if let Some(repo_id) = explicit_mlx_repo_id(hf_name) {
        return repo_id;
    }
    let candidates = hf_name_to_mlx_candidates(hf_name);
    if let Some(four_bit) = candidates.iter().find(|c| c.ends_with("-4bit")) {
        return four_bit.clone();
    }
    match candidates.into_iter().next() {
        Some(first) => first,
        None => hf_name
            .split('/')
            .next_back()
            .unwrap_or(hf_name)
            .to_lowercase(),
    }
}
const OLLAMA_MAPPINGS: &[(&str, &str)] = &[
("llama-3.3-70b-instruct", "llama3.3:70b"),
("llama-3.2-11b-vision-instruct", "llama3.2-vision:11b"),
("llama-3.2-3b-instruct", "llama3.2:3b"),
("llama-3.2-3b", "llama3.2:3b"),
("llama-3.2-1b-instruct", "llama3.2:1b"),
("llama-3.2-1b", "llama3.2:1b"),
("llama-3.1-405b-instruct", "llama3.1:405b"),
("llama-3.1-405b", "llama3.1:405b"),
("llama-3.1-70b-instruct", "llama3.1:70b"),
("llama-3.1-8b-instruct", "llama3.1:8b"),
("llama-3.1-8b", "llama3.1:8b"),
("meta-llama-3-8b-instruct", "llama3:8b"),
("meta-llama-3-8b", "llama3:8b"),
("llama-2-7b-hf", "llama2:7b"),
("codellama-34b-instruct-hf", "codellama:34b"),
("codellama-13b-instruct-hf", "codellama:13b"),
("codellama-7b-instruct-hf", "codellama:7b"),
("gemma-3-12b-it", "gemma3:12b"),
("gemma-2-27b-it", "gemma2:27b"),
("gemma-2-9b-it", "gemma2:9b"),
("gemma-2-2b-it", "gemma2:2b"),
("phi-4", "phi4"),
("phi-4-mini-instruct", "phi4-mini"),
("phi-3.5-mini-instruct", "phi3.5"),
("phi-3-mini-4k-instruct", "phi3"),
("phi-3-medium-14b-instruct", "phi3:14b"),
("phi-2", "phi"),
("orca-2-7b", "orca2:7b"),
("orca-2-13b", "orca2:13b"),
("mistral-7b-instruct-v0.3", "mistral:7b"),
("mistral-7b-instruct-v0.2", "mistral:7b"),
("mistral-nemo-instruct-2407", "mistral-nemo"),
("mistral-small-24b-instruct-2501", "mistral-small:24b"),
("mistral-large-instruct-2407", "mistral-large"),
("mixtral-8x7b-instruct-v0.1", "mixtral:8x7b"),
("mixtral-8x22b-instruct-v0.1", "mixtral:8x22b"),
("qwen2-1.5b-instruct", "qwen2:1.5b"),
("qwen2.5-72b-instruct", "qwen2.5:72b"),
("qwen2.5-32b-instruct", "qwen2.5:32b"),
("qwen2.5-14b-instruct", "qwen2.5:14b"),
("qwen2.5-7b-instruct", "qwen2.5:7b"),
("qwen2.5-7b", "qwen2.5:7b"),
("qwen2.5-3b-instruct", "qwen2.5:3b"),
("qwen2.5-1.5b-instruct", "qwen2.5:1.5b"),
("qwen2.5-1.5b", "qwen2.5:1.5b"),
("qwen2.5-0.5b-instruct", "qwen2.5:0.5b"),
("qwen2.5-0.5b", "qwen2.5:0.5b"),
("qwen2.5-coder-32b-instruct", "qwen2.5-coder:32b"),
("qwen2.5-coder-14b-instruct", "qwen2.5-coder:14b"),
("qwen2.5-coder-7b-instruct", "qwen2.5-coder:7b"),
("qwen2.5-coder-1.5b-instruct", "qwen2.5-coder:1.5b"),
("qwen2.5-coder-0.5b-instruct", "qwen2.5-coder:0.5b"),
("qwen2.5-vl-7b-instruct", "qwen2.5vl:7b"),
("qwen2.5-vl-3b-instruct", "qwen2.5vl:3b"),
("qwen3-235b-a22b", "qwen3:235b"),
("qwen3-32b", "qwen3:32b"),
("qwen3-30b-a3b", "qwen3:30b-a3b"),
("qwen3-30b-a3b-instruct-2507", "qwen3:30b-a3b"),
("qwen3-14b", "qwen3:14b"),
("qwen3-8b", "qwen3:8b"),
("qwen3-4b", "qwen3:4b"),
("qwen3-4b-instruct-2507", "qwen3:4b"),
("qwen3-1.7b-base", "qwen3:1.7b"),
("qwen3-0.6b", "qwen3:0.6b"),
("qwen3-coder-30b-a3b-instruct", "qwen3-coder"),
("qwen3.5-27b", "qwen3.5"),
("qwen3.5-35b-a3b", "qwen3.5:35b"),
("qwen3.5-122b-a10b", "qwen3.5:122b"),
("qwen3-coder-next", "qwen3-coder-next"),
("deepseek-v3", "deepseek-v3"),
("deepseek-v3.2", "deepseek-v3"),
("deepseek-r1", "deepseek-r1"),
("deepseek-r1-0528", "deepseek-r1"),
("deepseek-r1-distill-qwen-32b", "deepseek-r1:32b"),
("deepseek-r1-distill-qwen-14b", "deepseek-r1:14b"),
("deepseek-r1-distill-qwen-7b", "deepseek-r1:7b"),
("deepseek-coder-v2-lite-instruct", "deepseek-coder-v2:16b"),
("tinyllama-1.1b-chat-v1.0", "tinyllama"),
("stablelm-2-1_6b-chat", "stablelm2:1.6b"),
("yi-6b-chat", "yi:6b"),
("yi-34b-chat", "yi:34b"),
("starcoder2-7b", "starcoder2:7b"),
("starcoder2-15b", "starcoder2:15b"),
("falcon-7b-instruct", "falcon:7b"),
("falcon-40b-instruct", "falcon:40b"),
("falcon-180b-chat", "falcon:180b"),
("falcon3-7b-instruct", "falcon3:7b"),
("openchat-3.5-0106", "openchat:7b"),
("vicuna-7b-v1.5", "vicuna:7b"),
("vicuna-13b-v1.5", "vicuna:13b"),
("glm-4-9b-chat", "glm4:9b"),
("solar-10.7b-instruct-v1.0", "solar:10.7b"),
("zephyr-7b-beta", "zephyr:7b"),
("c4ai-command-r-v01", "command-r"),
(
"nous-hermes-2-mixtral-8x7b-dpo",
"nous-hermes2-mixtral:8x7b",
),
("hermes-3-llama-3.1-8b", "hermes3:8b"),
("nomic-embed-text-v1.5", "nomic-embed-text"),
("bge-large-en-v1.5", "bge-large"),
("smollm2-135m-instruct", "smollm2:135m"),
("smollm2-135m", "smollm2:135m"),
("gemma-3n-e4b-it", "gemma3n:e4b"),
("gemma-3n-e2b-it", "gemma3n:e2b"),
("phi-4-reasoning", "phi4-reasoning"),
("phi-4-mini-reasoning", "phi4-mini-reasoning"),
("deepseek-v3.2-speciale", "deepseek-v3"),
("lfm2-350m", "lfm2:350m"),
("lfm2-700m", "lfm2:700m"),
("lfm2-1.2b", "lfm2:1.2b"),
("lfm2-2.6b", "lfm2:2.6b"),
("lfm2-2.6b-exp", "lfm2:2.6b"),
("lfm2-8b-a1b", "lfm2:8b-a1b"),
("lfm2-24b-a2b", "lfm2:24b"),
("lfm2.5-1.2b-instruct", "lfm2.5:1.2b"),
("lfm2.5-1.2b-thinking", "lfm2.5-thinking:1.2b"),
];
/// Exact lookup of the curated Ollama tag for `hf_name`, keyed on the
/// lowercased repo suffix (the part after the last `/`).
fn lookup_ollama_tag(hf_name: &str) -> Option<&'static str> {
    let repo = hf_name
        .split('/')
        .next_back()
        .unwrap_or(hf_name)
        .to_lowercase();
    OLLAMA_MAPPINGS
        .iter()
        .find_map(|&(suffix, tag)| (repo == suffix).then_some(tag))
}
/// Ollama candidate tags for `hf_name`: a single curated tag, or an
/// empty list for unmapped models.
pub fn hf_name_to_ollama_candidates(hf_name: &str) -> Vec<String> {
    lookup_ollama_tag(hf_name)
        .map(|tag| vec![tag.to_string()])
        .unwrap_or_default()
}
pub fn has_ollama_mapping(hf_name: &str) -> bool {
lookup_ollama_tag(hf_name).is_some()
}
/// Match rule for Ollama names: exact equality always counts, and a
/// fully-qualified `name:size` candidate also matches installed
/// variants of the form `name:size-<suffix>` (e.g. `:7b-instruct`).
fn ollama_installed_matches_candidate(installed_name: &str, candidate: &str) -> bool {
    installed_name == candidate
        || (candidate.contains(':') && installed_name.starts_with(&format!("{candidate}-")))
}
/// True when any Ollama candidate tag for `hf_name` matches any
/// installed model name (see `ollama_installed_matches_candidate`).
pub fn is_model_installed(hf_name: &str, installed: &HashSet<String>) -> bool {
    hf_name_to_ollama_candidates(hf_name).iter().any(|candidate| {
        installed
            .iter()
            .any(|name| ollama_installed_matches_candidate(name, candidate))
    })
}
/// Curated Ollama tag for `hf_name` as an owned `String`, if any.
pub fn ollama_pull_tag(hf_name: &str) -> Option<String> {
    lookup_ollama_tag(hf_name).map(String::from)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_hf_name_to_mlx_candidates() {
let candidates = hf_name_to_mlx_candidates("meta-llama/Llama-3.1-8B-Instruct");
assert!(
candidates
.iter()
.any(|c| c.contains("llama-3.1-8b-instruct"))
);
assert!(candidates.iter().any(|c| c.ends_with("-4bit")));
assert!(candidates.iter().any(|c| c.ends_with("-8bit")));
let qwen = hf_name_to_mlx_candidates("Qwen/Qwen2.5-Coder-14B-Instruct");
assert!(
qwen.iter()
.any(|c| c.contains("qwen2.5-coder-14b-instruct"))
);
}
#[test]
fn test_hf_name_to_mlx_candidates_qwen35() {
let candidates = hf_name_to_mlx_candidates("Qwen/Qwen3.5-9B");
assert!(candidates.iter().any(|c| c == "qwen3.5-9b-4bit"));
assert!(candidates.iter().any(|c| c == "qwen3.5-9b-8bit"));
}
#[test]
fn test_hf_name_to_mlx_candidates_llama4() {
let candidates = hf_name_to_mlx_candidates("meta-llama/Llama-4-Scout-17B-16E-Instruct");
assert!(candidates.iter().any(|c| c.contains("llama-4-scout")));
assert!(candidates.iter().any(|c| c.ends_with("-4bit")));
}
#[test]
fn test_hf_name_to_mlx_candidates_gemma3() {
let candidates = hf_name_to_mlx_candidates("google/gemma-3-27b-it");
assert!(candidates.iter().any(|c| c == "gemma-3-27b-it-4bit"));
assert!(candidates.iter().any(|c| c == "gemma-3-27b-it-8bit"));
}
#[test]
fn test_hf_name_to_mlx_fallback_generates_mlx_infix_candidates() {
let candidates = hf_name_to_mlx_candidates("SomeOrg/SomeNewModel-7B");
assert!(candidates.iter().any(|c| c == "somenewmodel-7b-mlx-4bit"));
assert!(candidates.iter().any(|c| c == "somenewmodel-7b-mlx-8bit"));
}
#[test]
fn test_hf_name_to_mlx_candidates_normalizes_explicit_mlx_repo() {
let candidates =
hf_name_to_mlx_candidates("lmstudio-community/Qwen3-Coder-30B-A3B-Instruct-MLX-8bit");
assert!(
candidates
.contains(&"lmstudio-community/qwen3-coder-30b-a3b-instruct-mlx-8bit".to_string())
);
assert!(candidates.contains(&"qwen3-coder-30b-a3b-instruct-4bit".to_string()));
assert!(candidates.contains(&"qwen3-coder-30b-a3b-instruct-8bit".to_string()));
assert!(!candidates.iter().any(|c| c.contains("-8bit-4bit")));
assert!(!candidates.iter().any(|c| c.contains("-8bit-8bit")));
}
#[test]
fn test_mlx_pull_tag_prefers_explicit_repo_id() {
let tag = mlx_pull_tag("lmstudio-community/Qwen3-Coder-30B-A3B-Instruct-MLX-8bit");
assert_eq!(
tag,
"lmstudio-community/qwen3-coder-30b-a3b-instruct-mlx-8bit"
);
}
#[test]
fn test_mlx_cache_scan_parsing() {
let mut installed = HashSet::new();
installed.insert("llama-3.1-8b-instruct-4bit".to_string());
assert!(is_model_installed_mlx(
"meta-llama/Llama-3.1-8B-Instruct",
&installed
));
assert!(!is_model_installed_mlx(
"Qwen/Qwen2.5-7B-Instruct",
&installed
));
}
#[test]
fn test_is_model_installed_mlx() {
let mut installed = HashSet::new();
installed.insert("qwen2.5-coder-14b-instruct-8bit".to_string());
assert!(is_model_installed_mlx(
"Qwen/Qwen2.5-Coder-14B-Instruct",
&installed
));
assert!(!is_model_installed_mlx(
"Qwen/Qwen2.5-14B-Instruct",
&installed
));
}
#[test]
fn test_is_model_installed_mlx_with_owner_prefixed_repo_id() {
let mut installed = HashSet::new();
installed.insert("lmstudio-community/qwen3-coder-30b-a3b-instruct-mlx-8bit".to_string());
assert!(is_model_installed_mlx(
"lmstudio-community/Qwen3-Coder-30B-A3B-Instruct-MLX-8bit",
&installed
));
}
#[test]
fn test_qwen_coder_14b_matches_coder_entry() {
let mut installed = HashSet::new();
installed.insert("qwen2.5-coder:14b".to_string());
installed.insert("qwen2.5-coder".to_string());
assert!(is_model_installed(
"Qwen/Qwen2.5-Coder-14B-Instruct",
&installed
));
assert!(!is_model_installed("Qwen/Qwen2.5-14B-Instruct", &installed));
}
#[test]
fn test_qwen_base_does_not_match_coder() {
let mut installed = HashSet::new();
installed.insert("qwen2.5:14b".to_string());
installed.insert("qwen2.5".to_string());
assert!(is_model_installed("Qwen/Qwen2.5-14B-Instruct", &installed));
assert!(!is_model_installed(
"Qwen/Qwen2.5-Coder-14B-Instruct",
&installed
));
}
#[test]
fn test_installed_variant_suffix_matches_ollama_candidate() {
let mut installed = HashSet::new();
installed.insert("qwen2.5-coder:7b-instruct".to_string());
assert!(is_model_installed(
"Qwen/Qwen2.5-Coder-7B-Instruct",
&installed
));
}
#[test]
fn test_candidates_for_coder_model() {
let candidates = hf_name_to_ollama_candidates("Qwen/Qwen2.5-Coder-14B-Instruct");
assert!(candidates.contains(&"qwen2.5-coder:14b".to_string()));
}
#[test]
fn test_candidates_for_base_model() {
let candidates = hf_name_to_ollama_candidates("Qwen/Qwen2.5-14B-Instruct");
assert!(candidates.contains(&"qwen2.5:14b".to_string()));
}
#[test]
fn test_llama_mapping() {
let candidates = hf_name_to_ollama_candidates("meta-llama/Llama-3.1-8B-Instruct");
assert!(candidates.contains(&"llama3.1:8b".to_string()));
}
#[test]
fn test_deepseek_coder_mapping() {
let candidates =
hf_name_to_ollama_candidates("deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct");
assert!(candidates.contains(&"deepseek-coder-v2:16b".to_string()));
}
#[test]
fn test_normalize_ollama_host_with_scheme() {
assert_eq!(
normalize_ollama_host("https://ollama.example.com:11434"),
Some("https://ollama.example.com:11434".to_string())
);
}
#[test]
fn test_normalize_ollama_host_without_scheme() {
assert_eq!(
normalize_ollama_host("ollama.example.com:11434"),
Some("http://ollama.example.com:11434".to_string())
);
}
#[test]
fn test_normalize_ollama_host_rejects_unsupported_scheme() {
assert_eq!(
normalize_ollama_host("ftp://ollama.example.com:11434"),
None
);
}
#[test]
fn test_validate_gguf_filename_valid() {
assert!(validate_gguf_filename("Llama-3.1-8B-Q4_K_M.gguf").is_ok());
assert!(validate_gguf_filename("model.gguf").is_ok());
}
#[test]
fn test_validate_gguf_filename_traversal() {
assert!(validate_gguf_filename("../../outside.gguf").is_err());
assert!(validate_gguf_filename("../evil.gguf").is_err());
assert!(validate_gguf_filename("foo/../bar.gguf").is_err());
}
#[test]
fn test_validate_gguf_filename_absolute() {
assert!(validate_gguf_filename("/etc/passwd").is_err());
assert!(validate_gguf_filename("/tmp/model.gguf").is_err());
}
#[test]
fn test_validate_gguf_filename_bad_extension() {
assert!(validate_gguf_filename("malware.exe").is_err());
assert!(validate_gguf_filename("script.sh").is_err());
assert!(validate_gguf_filename("./model.guuf").is_err());
}
#[test]
fn test_validate_gguf_filename_empty() {
assert!(validate_gguf_filename("").is_err());
}
#[test]
fn test_validate_gguf_filename_subdirectory() {
assert!(validate_gguf_filename("subdir/model.gguf").is_err());
}
#[test]
fn test_validate_gguf_filename_rejects_non_basename_forms() {
assert!(validate_gguf_filename("./model.gguf").is_err());
assert!(validate_gguf_filename("model.gguf/").is_err());
assert!(validate_gguf_filename(".\\model.gguf").is_err());
assert!(validate_gguf_filename("C:/models/model.gguf").is_err());
assert!(validate_gguf_filename("C:\\models\\model.gguf").is_err());
}
#[test]
fn test_validate_gguf_repo_path_valid() {
assert!(validate_gguf_repo_path("model.gguf").is_ok());
assert!(validate_gguf_repo_path("Q4_K_M/model.gguf").is_ok());
assert!(validate_gguf_repo_path("deep/nested/model.gguf").is_ok());
}
#[test]
fn test_validate_gguf_repo_path_rejects_traversal() {
assert!(validate_gguf_repo_path("../escape.gguf").is_err());
assert!(validate_gguf_repo_path("foo/../bar.gguf").is_err());
assert!(validate_gguf_repo_path("./model.gguf").is_err());
}
#[test]
fn test_validate_gguf_repo_path_rejects_absolute() {
assert!(validate_gguf_repo_path("/etc/passwd").is_err());
assert!(validate_gguf_repo_path("/tmp/model.gguf").is_err());
}
#[test]
fn test_validate_gguf_repo_path_rejects_backslash() {
assert!(validate_gguf_repo_path("dir\\model.gguf").is_err());
assert!(validate_gguf_repo_path("C:\\models\\model.gguf").is_err());
}
#[test]
fn test_validate_gguf_repo_path_rejects_non_gguf() {
assert!(validate_gguf_repo_path("malware.exe").is_err());
assert!(validate_gguf_repo_path("subdir/readme.md").is_err());
}
#[test]
fn test_validate_gguf_repo_path_rejects_empty() {
assert!(validate_gguf_repo_path("").is_err());
}
#[test]
fn test_parse_repo_gguf_entries_filters_unsafe_paths() {
let entries = vec![
serde_json::json!({"path": "good.gguf", "size": 123u64}),
serde_json::json!({"path": "../escape.gguf", "size": 456u64}),
serde_json::json!({"path": "nested/model.gguf", "size": 789u64}),
serde_json::json!({"path": "./model.gguf", "size": 99u64}),
serde_json::json!({"path": "readme.md", "size": 12u64}),
];
let files = parse_repo_gguf_entries(entries);
assert_eq!(
files,
vec![
("good.gguf".to_string(), 123u64),
("nested/model.gguf".to_string(), 789u64),
]
);
}
#[test]
fn test_hf_name_to_gguf_candidates_generates_common_patterns() {
    // Unknown models should fan out to the well-known GGUF re-packaging orgs.
    let candidates = hf_name_to_gguf_candidates("SomeOrg/Cool-Model-7B");
    let expectations = [
        ("bartowski", "bartowski/Cool-Model-7B-GGUF"),
        ("ggml-org", "ggml-org/Cool-Model-7B-GGUF"),
        ("TheBloke", "TheBloke/Cool-Model-7B-GGUF"),
    ];
    for (org, repo) in expectations {
        assert!(
            candidates.iter().any(|c| c == repo),
            "Should generate {} candidate, got: {:?}",
            org,
            candidates
        );
    }
}
#[test]
fn test_hf_name_to_gguf_candidates_strips_owner() {
    // The original owner must not leak into generated candidate repo names.
    for candidate in &hf_name_to_gguf_candidates("Qwen/Qwen2.5-7B-Instruct") {
        assert!(
            !candidate.contains("Qwen/Qwen"),
            "Candidate should not contain original owner prefix: {}",
            candidate
        );
    }
}
#[test]
fn test_lookup_gguf_repo_known_mappings() {
    // Both a fully-qualified HF name and a bare model name should resolve.
    assert!(lookup_gguf_repo("meta-llama/Llama-3.1-8B-Instruct").is_some());
    assert!(lookup_gguf_repo("deepseek-r1").is_some());
}
#[test]
fn test_lookup_gguf_repo_unknown_returns_none() {
    let result = lookup_gguf_repo("totally-unknown/model-xyz");
    assert!(result.is_none());
}
#[test]
fn test_has_gguf_mapping_matches_known_models() {
    assert!(!has_gguf_mapping("some-random/UnknownModel"));
    assert!(has_gguf_mapping("meta-llama/Llama-3.1-8B-Instruct"));
}
#[test]
fn test_gguf_candidates_fallback_covers_major_providers() {
    // Fallback candidates must cover each major GGUF provider and all end in -GGUF.
    let candidates = hf_name_to_gguf_candidates("SomeOrg/NewModel-7B");
    for prefix in ["bartowski/", "ggml-org/", "TheBloke/"] {
        assert!(candidates.iter().any(|c| c.starts_with(prefix)));
    }
    assert!(candidates.iter().all(|c| c.ends_with("-GGUF")));
}
#[test]
fn test_gguf_candidates_known_mapping_returns_single() {
    // A model with an explicit mapping yields exactly one candidate.
    let candidates = hf_name_to_gguf_candidates("meta-llama/Llama-3.1-8B-Instruct");
    assert_eq!(candidates.len(), 1);
    assert!(candidates[0].contains("GGUF"));
}
#[test]
fn test_select_best_gguf_prefers_higher_quality() {
    // With a generous 10 GB budget the largest (highest-quality) quant wins.
    let files: Vec<(String, u64)> = [
        ("model-Q2_K.gguf", 2_000_000_000u64),
        ("model-Q4_K_M.gguf", 4_000_000_000u64),
        ("model-Q8_0.gguf", 8_000_000_000u64),
    ]
    .into_iter()
    .map(|(name, size)| (name.to_string(), size))
    .collect();
    let (name, _) = LlamaCppProvider::select_best_gguf(&files, 10.0).unwrap();
    assert!(name.contains("Q8_0"), "should prefer Q8, got: {}", name);
}
#[test]
fn test_select_best_gguf_respects_budget() {
    // A 3.7 GB budget only leaves room for the 2 GB Q2_K file.
    let files: Vec<(String, u64)> = [
        ("model-Q2_K.gguf", 2_000_000_000u64),
        ("model-Q4_K_M.gguf", 4_000_000_000u64),
        ("model-Q8_0.gguf", 8_000_000_000u64),
    ]
    .into_iter()
    .map(|(name, size)| (name.to_string(), size))
    .collect();
    let (name, _) = LlamaCppProvider::select_best_gguf(&files, 3.7).unwrap();
    assert!(
        name.contains("Q2_K"),
        "should select Q2_K for 3.7GB budget, got: {}",
        name
    );
}
#[test]
fn test_select_best_gguf_nothing_fits() {
    let files = vec![("model-Q2_K.gguf".to_string(), 8_000_000_000u64)];
    assert!(LlamaCppProvider::select_best_gguf(&files, 1.0).is_none());
}
#[test]
fn test_select_best_gguf_prefers_shard_group_over_lower_quant() {
    // Three Q4_K_M shards (12 GB total) beat a single 2 GB Q2_K file when they fit.
    let mut files: Vec<(String, u64)> = (1..=3)
        .map(|i| {
            (
                format!("model-Q4_K_M-0000{}-of-00003.gguf", i),
                4_000_000_000u64,
            )
        })
        .collect();
    files.push(("model-Q2_K.gguf".to_string(), 2_000_000_000u64));
    let (name, size) = LlamaCppProvider::select_best_gguf(&files, 16.0).unwrap();
    assert!(name.contains("Q4_K_M-00001-of-00003"), "got: {}", name);
    // Reported size is the whole shard group, not just the first shard.
    assert_eq!(size, 12_000_000_000u64);
}
#[test]
fn test_select_best_gguf_empty_list() {
    assert!(LlamaCppProvider::select_best_gguf(&[], 10.0).is_none());
}
#[test]
fn test_parse_shard_info_smoke() {
    // Shard-numbered names parse; plain names do not.
    assert!(parse_shard_info("model-00001-of-00003.gguf").is_some());
    for plain in ["model-Q4_K_M.gguf", "model.gguf"] {
        assert!(parse_shard_info(plain).is_none());
    }
}
#[test]
fn test_parse_shard_info_basic() {
    // (shard index, shard count) is extracted whether or not a directory prefix exists.
    let cases = [
        ("Qwen3-Coder-Next-Q5_K_M-00001-of-00003.gguf", (1, 3)),
        ("Q5_K_M/Qwen3-Coder-Next-Q5_K_M-00003-of-00003.gguf", (3, 3)),
    ];
    for (path, expected) in cases {
        assert_eq!(parse_shard_info(path), Some(expected));
    }
}
#[test]
fn test_parse_shard_info_rejects_non_shards() {
    // Wrong extension, index out of range (0 or > count), or no shard marker at all.
    let non_shards = [
        "model.gguf",
        "model-Q4_K_M.gguf",
        "model-of-tea.gguf",
        "model-00001-of-00003.bin",
        "model-00004-of-00003.gguf",
        "model-00000-of-00003.gguf",
    ];
    for path in non_shards {
        assert_eq!(parse_shard_info(path), None);
    }
}
#[test]
fn test_collect_shard_set_returns_all_shards_sorted() {
    // Shards are listed out of order and a different-quant file is mixed in.
    let files: Vec<(String, u64)> = [
        (
            "Q5_K_M/Qwen3-Coder-Next-Q5_K_M-00002-of-00003.gguf",
            3_000_000_000u64,
        ),
        (
            "Q5_K_M/Qwen3-Coder-Next-Q5_K_M-00001-of-00003.gguf",
            3_000_000_000u64,
        ),
        (
            "Q5_K_M/Qwen3-Coder-Next-Q5_K_M-00003-of-00003.gguf",
            2_500_000_000u64,
        ),
        ("Q4_K_M/Qwen3-Coder-Next-Q4_K_M.gguf", 4_000_000_000u64),
    ]
    .into_iter()
    .map(|(path, size)| (path.to_string(), size))
    .collect();
    let shards =
        collect_shard_set(&files, "Q5_K_M/Qwen3-Coder-Next-Q5_K_M-00001-of-00003.gguf")
            .expect("should detect shard set");
    // All three Q5 shards come back in shard order; the Q4 file is excluded.
    let markers = ["00001-of-00003", "00002-of-00003", "00003-of-00003"];
    assert_eq!(shards.len(), markers.len());
    for (shard, marker) in shards.iter().zip(markers) {
        assert!(shard.0.contains(marker));
    }
}
#[test]
fn test_collect_shard_set_returns_none_for_non_shard() {
    let files = vec![("model-Q4_K_M.gguf".to_string(), 4_000_000_000u64)];
    assert!(collect_shard_set(&files, "model-Q4_K_M.gguf").is_none());
}
#[test]
fn test_collect_shard_set_does_not_mix_groups() {
    // Two distinct shard groups with different quant levels must not bleed together.
    let files: Vec<(String, u64)> = [
        ("Q4_K_M/m-Q4_K_M-00001-of-00002.gguf", 1_000u64),
        ("Q4_K_M/m-Q4_K_M-00002-of-00002.gguf", 1_000u64),
        ("Q5_K_M/m-Q5_K_M-00001-of-00003.gguf", 2_000u64),
        ("Q5_K_M/m-Q5_K_M-00002-of-00003.gguf", 2_000u64),
        ("Q5_K_M/m-Q5_K_M-00003-of-00003.gguf", 2_000u64),
    ]
    .into_iter()
    .map(|(path, size)| (path.to_string(), size))
    .collect();
    let q4 = collect_shard_set(&files, "Q4_K_M/m-Q4_K_M-00001-of-00002.gguf").unwrap();
    assert_eq!(q4.len(), 2);
    let q5 = collect_shard_set(&files, "Q5_K_M/m-Q5_K_M-00002-of-00003.gguf").unwrap();
    assert_eq!(q5.len(), 3);
}
#[test]
fn test_select_best_gguf_picks_shard_group() {
    // A pure shard group (3 + 3 + 2 GB) should be selectable as a unit.
    let shard_sizes = [3_000_000_000u64, 3_000_000_000, 2_000_000_000];
    let files: Vec<(String, u64)> = shard_sizes
        .iter()
        .enumerate()
        .map(|(i, &size)| {
            (
                format!("Q5_K_M/m-Q5_K_M-0000{}-of-00003.gguf", i + 1),
                size,
            )
        })
        .collect();
    let (path, size) = LlamaCppProvider::select_best_gguf(&files, 16.0)
        .expect("shard group should be selectable");
    // Selection is reported via the first shard; size covers the whole group.
    assert!(path.contains("00001-of-00003"), "got: {}", path);
    assert_eq!(size, 8_000_000_000u64);
}
#[test]
fn test_select_best_gguf_shard_group_respects_budget() {
    // The Q5 shard group totals 8 GB and blows the 4 GB budget, so the
    // single 1.5 GB Q2_K file must be chosen instead.
    let mut files: Vec<(String, u64)> = (1..=3)
        .map(|i| {
            let size = if i == 3 { 2_000_000_000 } else { 3_000_000_000 };
            (format!("Q5_K_M/m-Q5_K_M-0000{}-of-00003.gguf", i), size)
        })
        .collect();
    files.push(("Q2_K/m-Q2_K.gguf".to_string(), 1_500_000_000));
    let (path, _) = LlamaCppProvider::select_best_gguf(&files, 4.0).unwrap();
    assert!(path.contains("Q2_K") && !path.contains("-of-"));
}
#[test]
fn test_urlencoding_ascii() {
    // Unreserved characters pass through untouched.
    for plain in ["hello", "test-model_v1.0"] {
        assert_eq!(urlencoding::encode(plain), plain);
    }
}
#[test]
fn test_urlencoding_special_chars() {
    // Reserved characters are percent-encoded.
    let cases = [
        ("hello world", "hello%20world"),
        ("a+b", "a%2Bb"),
        ("foo/bar", "foo%2Fbar"),
    ];
    for (raw, encoded) in cases {
        assert_eq!(urlencoding::encode(raw), encoded);
    }
}
#[test]
fn test_urlencoding_empty() {
    assert_eq!(urlencoding::encode(""), "");
}
#[test]
fn test_is_model_installed_llamacpp_exact() {
    // Installed set stores lowercase names; lookup is case-insensitive.
    let installed = HashSet::from(["llama-3.1-8b-instruct".to_string()]);
    assert!(is_model_installed_llamacpp(
        "meta-llama/Llama-3.1-8B-Instruct",
        &installed
    ));
}
#[test]
fn test_is_model_installed_llamacpp_stripped_suffixes() {
    // A base name without the "-instruct" suffix should still match.
    let installed = HashSet::from(["llama-3.1-8b".to_string()]);
    assert!(is_model_installed_llamacpp(
        "meta-llama/Llama-3.1-8B-Instruct",
        &installed
    ));
}
#[test]
fn test_is_model_installed_llamacpp_not_installed() {
    let empty: HashSet<String> = HashSet::new();
    assert!(!is_model_installed_llamacpp(
        "meta-llama/Llama-3.1-8B-Instruct",
        &empty
    ));
}
#[test]
fn test_gguf_pull_tag_known() {
    let tag = gguf_pull_tag("meta-llama/Llama-3.1-8B-Instruct")
        .expect("known model should map to a GGUF tag");
    assert!(tag.contains("GGUF"));
}
#[test]
fn test_gguf_pull_tag_unknown() {
    assert!(gguf_pull_tag("totally-unknown/model-xyz").is_none());
}
#[test]
fn test_has_ollama_mapping_known() {
    for known in ["meta-llama/Llama-3.1-8B-Instruct", "Qwen/Qwen2.5-7B-Instruct"] {
        assert!(has_ollama_mapping(known));
    }
}
#[test]
fn test_has_ollama_mapping_unknown() {
    assert!(!has_ollama_mapping("totally-unknown/model-xyz"));
}
#[test]
fn test_ollama_pull_tag_known() {
    assert_eq!(
        ollama_pull_tag("meta-llama/Llama-3.1-8B-Instruct").as_deref(),
        Some("llama3.1:8b")
    );
}
#[test]
fn test_ollama_pull_tag_unknown() {
    assert!(ollama_pull_tag("totally-unknown/model-xyz").is_none());
}
#[test]
fn test_mlx_pull_tag_prefers_4bit() {
    let tag = mlx_pull_tag("meta-llama/Llama-3.1-8B-Instruct");
    assert!(tag.ends_with("-4bit"), "should prefer 4bit, got: {}", tag);
}
#[test]
fn test_mlx_pull_tag_fallback() {
    // Even unmapped models must yield some usable tag.
    assert!(!mlx_pull_tag("SomeUnknown/Model-7B").is_empty());
}
#[test]
fn test_ollama_installed_matches_exact() {
    assert!(ollama_installed_matches_candidate("llama3.1:8b", "llama3.1:8b"));
}
#[test]
fn test_ollama_installed_matches_variant_suffix() {
    // An installed tag carrying extra variant qualifiers still satisfies the base candidate.
    let installed = "llama3.1:8b-instruct-q4_K_M";
    assert!(ollama_installed_matches_candidate(installed, "llama3.1:8b"));
}
#[test]
fn test_ollama_installed_no_match() {
    // A completely different model family must not match.
    assert!(!ollama_installed_matches_candidate("qwen2.5:7b", "llama3.1:8b"));
}
#[test]
fn test_parse_repo_gguf_entries_valid() {
    let entries: Vec<_> = [
        ("model-Q4_K_M.gguf", 4_000_000_000u64),
        ("model-Q8_0.gguf", 8_000_000_000u64),
    ]
    .into_iter()
    .map(|(path, size)| serde_json::json!({"path": path, "size": size}))
    .collect();
    let files = parse_repo_gguf_entries(entries);
    // Both .gguf entries survive, in input order.
    assert_eq!(files.len(), 2);
    assert_eq!(files[0].0, "model-Q4_K_M.gguf");
    assert_eq!(files[1].0, "model-Q8_0.gguf");
}
#[test]
fn test_parse_repo_gguf_entries_missing_size_defaults_to_zero() {
    // An entry without a "size" field is kept with size 0.
    let files = parse_repo_gguf_entries(vec![serde_json::json!({"path": "model.gguf"})]);
    assert_eq!(files.len(), 1);
    assert_eq!(files[0].1, 0);
}
#[test]
fn test_parse_repo_gguf_entries_skips_non_gguf() {
    // Only the .gguf entry survives; docs and config files are dropped.
    let files = parse_repo_gguf_entries(vec![
        serde_json::json!({"path": "README.md", "size": 1000u64}),
        serde_json::json!({"path": "config.json", "size": 500u64}),
        serde_json::json!({"path": "model.gguf", "size": 4_000_000_000u64}),
    ]);
    assert_eq!(files.len(), 1);
    assert_eq!(files[0].0, "model.gguf");
}
#[test]
fn test_hf_name_to_mlx_candidates_bare_model_name() {
    // A bare name (no "owner/") still yields lowercase and 4-bit variants.
    let candidates = hf_name_to_mlx_candidates("Phi-4");
    assert!(candidates.iter().any(|c| c.contains("phi-4")));
    assert!(candidates.iter().any(|c| c.ends_with("-4bit")));
}
#[test]
fn test_hf_name_to_mlx_candidates_no_duplicates() {
    let candidates = hf_name_to_mlx_candidates("meta-llama/Llama-3.1-8B-Instruct");
    // Deduplicate into a set and compare cardinality.
    let unique: HashSet<_> = candidates.iter().collect();
    assert_eq!(
        unique.len(),
        candidates.len(),
        "candidates should have no duplicates: {:?}",
        candidates
    );
}
#[test]
fn test_hf_name_to_ollama_candidates_unknown_returns_empty() {
    assert!(hf_name_to_ollama_candidates("totally-unknown/model-xyz").is_empty());
}
#[test]
fn test_hf_name_to_ollama_candidates_multiple_models() {
    // Each well-known model family should produce at least one candidate.
    let known = [
        "meta-llama/Llama-3.1-8B-Instruct",
        "Qwen/Qwen2.5-Coder-7B-Instruct",
        "google/gemma-2-9b-it",
    ];
    for model in known {
        assert!(!hf_name_to_ollama_candidates(model).is_empty());
    }
}
#[test]
fn test_docker_mr_catalog_parses() {
    assert!(
        !docker_mr_catalog().is_empty(),
        "Docker MR catalog should not be empty"
    );
}
#[test]
fn test_has_docker_mr_mapping_known() {
    assert!(has_docker_mr_mapping("meta-llama/Llama-3.1-70B-Instruct"));
}
#[test]
fn test_has_docker_mr_mapping_unknown() {
    assert!(!has_docker_mr_mapping("totally-unknown/model-xyz"));
}
#[test]
fn test_docker_mr_pull_tag_returns_ai_prefixed() {
    // Docker Model Runner tags live under the "ai/" namespace.
    let tag = docker_mr_pull_tag("meta-llama/Llama-3.1-70B-Instruct")
        .expect("known model should have a Docker MR tag");
    assert!(tag.starts_with("ai/"));
}
#[test]
fn test_docker_mr_candidates_includes_ai_prefix() {
    let candidates = hf_name_to_docker_mr_candidates("meta-llama/Llama-3.1-70B-Instruct");
    assert!(candidates.iter().any(|c| c.starts_with("ai/")));
}
#[test]
fn test_docker_mr_candidates_unknown_returns_empty() {
    assert!(hf_name_to_docker_mr_candidates("totally-unknown/model-xyz").is_empty());
}
#[test]
fn test_is_model_installed_docker_mr_exact() {
    // All the normalized spellings a local Docker MR listing might contain.
    let installed: HashSet<String> = ["ai/llama3.1:70b", "llama3.1:70b", "llama3.1"]
        .iter()
        .map(|s| s.to_string())
        .collect();
    assert!(is_model_installed_docker_mr(
        "meta-llama/Llama-3.1-70B-Instruct",
        &installed
    ));
}
#[test]
fn test_is_model_installed_docker_mr_variant_suffix() {
    // A quant-variant suffix on the installed tag still counts as installed.
    let installed = HashSet::from(["ai/llama3.1:70b-q4_k_m".to_string()]);
    assert!(is_model_installed_docker_mr(
        "meta-llama/Llama-3.1-70B-Instruct",
        &installed
    ));
}
#[test]
fn test_is_model_installed_docker_mr_not_installed() {
    let nothing: HashSet<String> = HashSet::new();
    assert!(!is_model_installed_docker_mr(
        "meta-llama/Llama-3.1-70B-Instruct",
        &nothing
    ));
}
#[test]
fn test_normalize_docker_mr_host_with_scheme() {
    // http(s) URLs are accepted verbatim.
    let url = "https://docker.example.com:12434";
    assert_eq!(normalize_docker_mr_host(url), Some(url.to_string()));
}
#[test]
fn test_normalize_docker_mr_host_without_scheme() {
    // A bare host:port gets an http:// prefix.
    assert_eq!(
        normalize_docker_mr_host("docker.example.com:12434"),
        Some("http://docker.example.com:12434".to_string())
    );
}
#[test]
fn test_normalize_docker_mr_host_rejects_unsupported_scheme() {
    // Any non-http(s) scheme is rejected outright.
    assert_eq!(
        normalize_docker_mr_host("ftp://docker.example.com:12434"),
        None
    );
}
}