use ollama_rs::Ollama;
use crate::config::RuntimeMode;
use crate::error::{map_ollama_error, Result, RuntimeError};
/// Checks that Ollama models are present locally and, when allowed, pulls
/// missing ones.
///
/// A missing model is only pulled when `auto_pull` is set *and* `mode` is
/// [`RuntimeMode::Development`]; otherwise [`ModelManager::ensure`] reports it
/// as not found.
#[derive(Clone, Debug)]
pub struct ModelManager {
    /// Client used to list local models and to pull missing ones.
    pub client: Ollama,

    /// Whether a missing model may be pulled automatically.
    pub auto_pull: bool,

    /// Runtime mode; automatic pulling is restricted to development.
    pub mode: RuntimeMode,
}
impl ModelManager {
    /// Builds a manager from an Ollama client, the auto-pull flag, and the
    /// runtime mode.
    pub fn new(client: Ollama, auto_pull: bool, mode: RuntimeMode) -> Self {
        Self { client, auto_pull, mode }
    }

    /// Makes sure `model` is available on the Ollama instance.
    ///
    /// The local model list is checked first, using the same name/tag
    /// matching as `model_identifier_matches`. If nothing matches and
    /// automatic pulling is allowed (`auto_pull` is set and the mode is
    /// [`RuntimeMode::Development`]), the model is pulled.
    ///
    /// # Errors
    ///
    /// - Any client failure while listing or pulling, mapped through
    ///   `map_ollama_error`.
    /// - [`RuntimeError::ModelNotFound`] when the model is absent and pulling
    ///   is not allowed.
    pub async fn ensure(&self, model: &str) -> Result<()> {
        let installed = self
            .client
            .list_local_models()
            .await
            .map_err(map_ollama_error)?;

        let already_present = installed
            .iter()
            .any(|entry| model_identifier_matches(model, entry.name.as_str()));
        if already_present {
            return Ok(());
        }

        let pull_permitted = self.auto_pull && self.mode == RuntimeMode::Development;
        if !pull_permitted {
            return Err(RuntimeError::ModelNotFound(model.to_string()));
        }

        // NOTE(review): the `false` flag is presumably ollama-rs's
        // `allow_insecure` parameter — confirm. The pull status itself is
        // discarded; only a client error is surfaced.
        self.client
            .pull_model(model.to_string(), false)
            .await
            .map_err(map_ollama_error)?;
        Ok(())
    }
}
/// Returns `true` when the locally listed model name `listed` satisfies the
/// requested model name `requested`.
///
/// Both names are split into a base and an optional tag (see
/// `split_model_tag`). The bases must be equal.
/// - If `requested` carries a tag, `listed` must carry the same tag.
/// - An untagged request accepts any tag of the same base, so `mistral`
///   matches both `mistral:latest` and `mistral:7b`.
///
/// NOTE(review): Ollama treats an omitted tag as `latest`. Accepting
/// `mistral:7b` for a bare `mistral` request may therefore let
/// `ModelManager::ensure` succeed for a name that a later API call cannot
/// resolve — confirm this leniency is intended.
fn model_identifier_matches(requested: &str, listed: &str) -> bool {
    let (req_base, req_tag) = split_model_tag(requested);
    let (list_base, list_tag) = split_model_tag(listed);
    // An untagged request is satisfied by any tag; a tagged one needs an exact tag.
    req_base == list_base && (req_tag.is_none() || req_tag == list_tag)
}

/// Splits a model name into `(base, tag)` at its last `:`.
///
/// The colon counts as a tag separator only when no `/` follows it. This
/// keeps a registry host port (e.g. the `:5000` in
/// `localhost:5000/mistral`) from being mistaken for a tag.
fn split_model_tag(name: &str) -> (&str, Option<&str>) {
    match name.rsplit_once(':') {
        Some((base, tag)) if !tag.contains('/') => (base, Some(tag)),
        _ => (name, None),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of `model_identifier_matches`.
    ///
    /// Each case is `(requested, listed, expected)`.
    #[test]
    fn model_name_matching() {
        let cases: [(&str, &str, bool); 5] = [
            ("mistral", "mistral:latest", true),
            ("mistral", "mistral:7b", true),
            ("mistral:latest", "mistral:latest", true),
            ("mistral:latest", "mistral:7b", false),
            ("a", "b", false),
        ];

        for (requested, listed, expected) in cases {
            assert_eq!(
                model_identifier_matches(requested, listed),
                expected,
                "requested={requested:?} listed={listed:?}"
            );
        }
    }
}