use crate::config::Settings;
use gobby_core::local_backend::Backend;
use std::fmt;
use std::time::Duration;
/// Failures that can occur while making sure a model is available on a backend.
///
/// Every variant carries a human-readable string: the model name for
/// [`ModelError::NotFound`], and the underlying error text for the others.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModelError {
    /// The requested model name was not present in the server's model listing.
    NotFound(String),
    /// The pull request to the server failed.
    PullFailed(String),
    /// The warm-up (load) request to the server failed.
    WarmupFailed(String),
    /// The model listing could not be fetched or decoded.
    NetworkError(String),
}
impl fmt::Display for ModelError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::NotFound(m) => write!(f, "model '{}' not found", m),
Self::PullFailed(msg) => write!(f, "pull failed: {}", msg),
Self::WarmupFailed(msg) => write!(f, "warmup failed: {}", msg),
Self::NetworkError(msg) => write!(f, "network error: {}", msg),
}
}
}
/// Makes sure `model` is usable on `backend`, pulling and/or loading it when
/// the settings allow.
///
/// Backends whose `name` is not `"ollama"` are left untouched and always
/// yield `Ok(())`. For Ollama the model state is probed first:
///
/// * already loaded: nothing to do;
/// * downloaded but not loaded: a warm-up request is sent when
///   `settings.auto_load` is set;
/// * not downloaded: when `settings.auto_pull` is set the model is pulled
///   (and then warmed up if `settings.auto_load` is set), otherwise the
///   `NotFound` error is returned.
///
/// # Errors
///
/// Returns [`ModelError::NotFound`] when the model is missing and auto-pull is
/// disabled, and forwards any error from the probe, pull or warm-up requests.
pub fn ensure_model_ready(
    backend: &Backend,
    model: &str,
    settings: &Settings,
) -> Result<(), ModelError> {
    if backend.name != "ollama" {
        return Ok(());
    }

    let already_loaded = match ollama_check_model(backend, model, settings.probe_timeout_ms) {
        Ok(state) => state,
        Err(ModelError::NotFound(missing)) => {
            if !settings.auto_pull {
                return Err(ModelError::NotFound(missing));
            }
            eprintln!("gloc: pulling {} from ollama registry...", missing);
            ollama_pull_model(backend, &missing)?;
            if settings.auto_load {
                eprintln!("gloc: loading {} into ollama...", missing);
                ollama_warmup_model(backend, &missing)?;
            }
            return Ok(());
        }
        Err(other) => return Err(other),
    };

    // Downloaded but not resident: load it only when the user opted in.
    if !already_loaded && settings.auto_load {
        eprintln!("gloc: loading {} into ollama...", model);
        ollama_warmup_model(backend, model)?;
    }
    Ok(())
}
/// Probes the server at `backend.url` for the state of `model`.
///
/// First fetches `/api/tags`; if no entry of its `models` array matches
/// `model` (see `model_name_matches`), the model is reported as not found.
/// Otherwise `/api/ps` is fetched and the result says whether `model` appears
/// in that response's `models` array.
///
/// Returns `Ok(true)` when the model is listed by `/api/ps`, and `Ok(false)`
/// when it is only listed by `/api/tags` or when the `/api/ps` request or its
/// JSON decoding fails (that second probe is best-effort).
///
/// # Errors
///
/// * [`ModelError::NetworkError`] if `/api/tags` cannot be fetched or decoded.
/// * [`ModelError::NotFound`] if `/api/tags` does not list `model`.
fn ollama_check_model(backend: &Backend, model: &str, timeout_ms: u64) -> Result<bool, ModelError> {
    // `max` makes 5000 ms a floor: smaller configured values are raised to it.
    let timeout = Duration::from_millis(timeout_ms.max(5000));
    let agent = ureq::AgentBuilder::new()
        .timeout_connect(timeout)
        .timeout_read(timeout)
        .build();

    // True when the response's `models` array has an entry naming `model`.
    // Works on the borrowed array, so the JSON is never copied.
    let lists_model = |resp: &serde_json::Value| {
        resp.get("models")
            .and_then(|v| v.as_array())
            .is_some_and(|entries| entries.iter().any(|m| model_name_matches(m, model)))
    };

    let tags_url = format!("{}/api/tags", backend.url);
    let tags: serde_json::Value = agent
        .get(&tags_url)
        .call()
        .map_err(|e| ModelError::NetworkError(e.to_string()))?
        .into_json()
        .map_err(|e| ModelError::NetworkError(e.to_string()))?;
    if !lists_model(&tags) {
        return Err(ModelError::NotFound(model.to_string()));
    }

    // Best-effort: any failure here just means "not known to be loaded".
    let ps_url = format!("{}/api/ps", backend.url);
    if let Ok(ps_resp) = agent.get(&ps_url).call()
        && let Ok(ps_json) = ps_resp.into_json::<serde_json::Value>()
    {
        return Ok(lists_model(&ps_json));
    }
    Ok(false)
}
/// Sends a non-streaming pull request for `model` to `{backend.url}/api/pull`.
///
/// Uses a 10 s connect timeout and a 600 s read timeout. The response body is
/// not inspected; only the success or failure of the request matters.
///
/// # Errors
///
/// Returns [`ModelError::PullFailed`] with the request error's text when the
/// POST fails.
fn ollama_pull_model(backend: &Backend, model: &str) -> Result<(), ModelError> {
    let payload = serde_json::json!({
        "model": model,
        "stream": false
    });
    let endpoint = format!("{}/api/pull", backend.url);
    let client = ureq::AgentBuilder::new()
        .timeout_connect(Duration::from_secs(10))
        .timeout_read(Duration::from_secs(600))
        .build();

    match client.post(&endpoint).send_json(payload) {
        Ok(_response) => Ok(()),
        Err(err) => Err(ModelError::PullFailed(err.to_string())),
    }
}
/// Sends a prompt-less request for `model` to `{backend.url}/api/generate`.
///
/// The body contains only the model name; the caller uses this to get the
/// model loaded (presumably the server loads a model when given no prompt —
/// confirm against the Ollama API docs). Uses a 10 s connect timeout and a
/// 120 s read timeout. The response body is not inspected.
///
/// # Errors
///
/// Returns [`ModelError::WarmupFailed`] with the request error's text when the
/// POST fails.
fn ollama_warmup_model(backend: &Backend, model: &str) -> Result<(), ModelError> {
    let payload = serde_json::json!({
        "model": model
    });
    let endpoint = format!("{}/api/generate", backend.url);
    let client = ureq::AgentBuilder::new()
        .timeout_connect(Duration::from_secs(10))
        .timeout_read(Duration::from_secs(120))
        .build();

    match client.post(&endpoint).send_json(payload) {
        Ok(_response) => Ok(()),
        Err(err) => Err(ModelError::WarmupFailed(err.to_string())),
    }
}
/// Reports whether a model-list `entry` refers to `model`.
///
/// The entry's name is read from its `"name"` field, falling back to
/// `"model"` only when `"name"` is absent. It matches when it equals `model`
/// exactly or is `model` followed by a `:tag` suffix (any tag, including
/// `:latest`). Entries without a string name never match.
fn model_name_matches(entry: &serde_json::Value, model: &str) -> bool {
    entry
        .get("name")
        .or_else(|| entry.get("model"))
        .and_then(|n| n.as_str())
        // Whatever follows the requested name must be empty or start a tag, so
        // "foo" matches "foo" and "foo:latest" but not "foo-bar".
        .and_then(|n| n.strip_prefix(model))
        .is_some_and(|rest| rest.is_empty() || rest.starts_with(':'))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A backend pointing at a local port where nothing is expected to listen.
    fn unreachable_backend() -> Backend {
        Backend {
            name: "fake".into(),
            url: "http://127.0.0.1:19999".into(),
            probe: "/".into(),
            auth_token: "".into(),
        }
    }

    #[test]
    fn test_detect_backend_none_running() {
        let backends = vec![unreachable_backend()];
        assert!(gobby_core::local_backend::detect_backend(&backends, 100).is_none());
    }

    #[test]
    fn test_validate_backend_unreachable() {
        assert!(!gobby_core::local_backend::validate_backend(
            &unreachable_backend(),
            100
        ));
    }

    #[test]
    fn test_ensure_model_ready_non_ollama_is_noop() {
        // Non-ollama backends return early, so the unreachable URL is never hit.
        let backend = Backend {
            name: "lmstudio".into(),
            url: "http://127.0.0.1:19999".into(),
            probe: "/v1/models".into(),
            auth_token: "lmstudio".into(),
        };
        let settings = Settings::default();
        assert!(ensure_model_ready(&backend, "any-model", &settings).is_ok());
    }

    #[test]
    fn test_model_name_matches_exact() {
        let entry = serde_json::json!({"name": "qwen3-coder"});
        assert!(model_name_matches(&entry, "qwen3-coder"));
    }

    #[test]
    fn test_model_name_matches_with_latest() {
        let entry = serde_json::json!({"name": "qwen3-coder:latest"});
        assert!(model_name_matches(&entry, "qwen3-coder"));
    }

    #[test]
    fn test_model_name_matches_with_tag() {
        let entry = serde_json::json!({"name": "glm-4.7:cloud"});
        assert!(model_name_matches(&entry, "glm-4.7"));
    }

    #[test]
    fn test_model_name_no_match() {
        let entry = serde_json::json!({"name": "llama3"});
        assert!(!model_name_matches(&entry, "qwen3-coder"));
    }

    #[test]
    fn test_model_name_matches_model_field() {
        let entry = serde_json::json!({"model": "qwen3-coder:latest"});
        assert!(model_name_matches(&entry, "qwen3-coder"));
    }

    #[test]
    fn test_model_name_rejects_shared_prefix() {
        // A longer model name that merely starts with the requested one must
        // not be mistaken for a tagged variant of it.
        let entry = serde_json::json!({"name": "qwen3-coder-plus:latest"});
        assert!(!model_name_matches(&entry, "qwen3-coder"));
    }

    #[test]
    fn test_model_name_explicit_tag_mismatch() {
        let entry = serde_json::json!({"name": "qwen3-coder:latest"});
        assert!(!model_name_matches(&entry, "qwen3-coder:30b"));
        let exact = serde_json::json!({"name": "qwen3-coder:30b"});
        assert!(model_name_matches(&exact, "qwen3-coder:30b"));
    }

    #[test]
    fn test_model_name_missing_fields() {
        let entry = serde_json::json!({"size": 123});
        assert!(!model_name_matches(&entry, "qwen3-coder"));
    }

    #[test]
    fn test_model_error_display() {
        assert_eq!(
            ModelError::NotFound("foo".into()).to_string(),
            "model 'foo' not found"
        );
        assert_eq!(
            ModelError::PullFailed("timeout".into()).to_string(),
            "pull failed: timeout"
        );
        assert_eq!(
            ModelError::WarmupFailed("oom".into()).to_string(),
            "warmup failed: oom"
        );
        assert_eq!(
            ModelError::NetworkError("refused".into()).to_string(),
            "network error: refused"
        );
    }
}