use ollama_rs::Ollama;
use crate::config::RuntimeMode;
use crate::error::{map_ollama_error, Result, RuntimeError};
/// Makes sure `model` is present in the local Ollama model store.
///
/// Returns `Ok(())` immediately when any locally listed model satisfies `model` according to
/// [`model_identifier_matches`]. Otherwise, when `auto_pull` is set *and* `mode` is
/// [`RuntimeMode::Development`], the model is pulled and `Ok(())` is returned once the pull
/// call completes successfully.
///
/// # Errors
///
/// - Failures from listing or pulling models, mapped through `map_ollama_error`.
/// - [`RuntimeError::ModelNotFound`] when the model is not available locally and pulling is
///   not permitted.
pub(crate) async fn ensure_local_model(
    auto_pull: bool,
    mode: RuntimeMode,
    client: &Ollama,
    model: &str,
) -> Result<()> {
    let installed = client.list_local_models().await.map_err(map_ollama_error)?;

    let already_present = installed
        .iter()
        .any(|entry| model_identifier_matches(model, entry.name.as_str()));
    if already_present {
        return Ok(());
    }

    // Automatic pulling is only allowed when explicitly enabled and running in Development.
    let may_pull = auto_pull && mode == RuntimeMode::Development;
    if !may_pull {
        return Err(RuntimeError::ModelNotFound(model.to_string()));
    }

    // NOTE(review): the `false` flag is presumably ollama-rs's `allow_insecure` — confirm.
    client
        .pull_model(model.to_string(), false)
        .await
        .map_err(map_ollama_error)?;
    Ok(())
}
/// Reports whether a locally `listed` model satisfies a `requested` model identifier.
///
/// An identifier is split into a *base* and an optional *tag*. The tag is the text after the
/// last `:` that is not followed by a `/`. A registry prefix with a port, such as
/// `localhost:5000/mistral`, therefore has no tag rather than the tag `5000/mistral`.
///
/// Matching rules:
/// - identical strings always match;
/// - otherwise the bases must be equal, and
///   - an untagged request matches any tag of that base;
///   - a tagged request matches only the identical tag.
pub(crate) fn model_identifier_matches(requested: &str, listed: &str) -> bool {
    if requested == listed {
        return true;
    }
    let (req_base, req_tag) = split_model_tag(requested);
    let (list_base, list_tag) = split_model_tag(listed);
    if req_base != list_base {
        return false;
    }
    match req_tag {
        // An untagged request accepts whichever tag of the base is installed locally.
        None => true,
        Some(tag) => list_tag == Some(tag),
    }
}

/// Splits a model identifier into `(base, tag)`, where `tag` is the trailing `:tag` segment.
fn split_model_tag(name: &str) -> (&str, Option<&str>) {
    match name.rsplit_once(':') {
        // A `/` after the last `:` means that colon belongs to a `host:port` prefix, not a tag.
        Some((base, tag)) if !tag.contains('/') => (base, Some(tag)),
        _ => (name, None),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `model_identifier_matches` against a table of `(requested, listed, expected)`.
    #[test]
    fn model_name_matching() {
        let cases = [
            ("mistral", "mistral:latest", true),
            ("mistral", "mistral:7b", true),
            ("mistral:latest", "mistral:latest", true),
            ("mistral:latest", "mistral:7b", false),
            ("a", "b", false),
        ];
        for (requested, listed, expected) in cases {
            assert_eq!(
                model_identifier_matches(requested, listed),
                expected,
                "requested={requested:?} listed={listed:?}"
            );
        }
    }
}