// ollama-kit 0.1.0
//
// Runtime control (lifecycle + execution guards) for ollama-rs without wrapping its API.
use ollama_rs::Ollama;

use crate::config::RuntimeMode;
use crate::error::{map_ollama_error, Result, RuntimeError};

/// Policy object that makes sure a requested model is present in the local Ollama store,
/// pulling it only when [`RuntimeMode`] and the `auto_pull` flag both allow it.
#[derive(Debug, Clone)]
pub struct ModelManager {
    /// Client used to list locally installed models and, when permitted, to pull missing ones.
    pub client: Ollama,
    /// When `true`, a missing model may be pulled — but only in [`RuntimeMode::Development`].
    pub auto_pull: bool,
    /// Runtime mode that gates whether `auto_pull` is honoured.
    pub mode: RuntimeMode,
}

impl ModelManager {
    /// Builds a manager from an already-configured client plus the pull policy.
    pub fn new(client: Ollama, auto_pull: bool, mode: RuntimeMode) -> Self {
        Self { client, auto_pull, mode }
    }

    /// Makes sure `model` is available locally.
    ///
    /// The locally installed models are listed first. If one of them satisfies `model`, this
    /// returns immediately. Otherwise the model is pulled, but only when `auto_pull` is set and
    /// the mode is [`RuntimeMode::Development`].
    ///
    /// # Errors
    ///
    /// * Any client failure while listing or pulling, converted through `map_ollama_error`.
    /// * [`RuntimeError::ModelNotFound`] when the model is missing and pulling is not permitted.
    pub async fn ensure(&self, model: &str) -> Result<()> {
        let installed = self
            .client
            .list_local_models()
            .await
            .map_err(map_ollama_error)?;

        let already_present = installed
            .iter()
            .any(|entry| model_identifier_matches(model, entry.name.as_str()));
        if already_present {
            return Ok(());
        }

        // Pulling is opt-in and restricted to development; every other configuration reports
        // the model as missing instead of touching the network.
        let pulling_allowed = self.auto_pull && self.mode == RuntimeMode::Development;
        if !pulling_allowed {
            return Err(RuntimeError::ModelNotFound(model.to_string()));
        }

        // NOTE(review): the `false` argument is presumably ollama-rs' `allow_insecure` flag —
        // confirm against the pinned ollama-rs version.
        self.client
            .pull_model(model.to_string(), false)
            .await
            .map_err(map_ollama_error)?;
        Ok(())
    }
}

/// Reports whether a locally `listed` model name satisfies a `requested` model reference.
///
/// Both names are split into a base name and an optional tag. The tag is the text after the
/// last `:`, unless that text contains a `/`. In that case the `:` belongs to a registry
/// host/port (e.g. `localhost:5000/mistral`) and the name is treated as untagged.
///
/// Matching rules:
/// * Identical strings always match.
/// * Base names must be equal.
/// * An untagged request matches any tag of the same base (e.g. `mistral` matches `mistral:7b`).
/// * A tagged request requires the listed name to carry exactly the same tag.
///
/// NOTE(review): Ollama generally resolves an untagged name to `:latest`, so accepting
/// `mistral:7b` for a request of `mistral` may not mean `mistral` itself is runnable — confirm
/// this leniency is intended.
fn model_identifier_matches(requested: &str, listed: &str) -> bool {
    if requested == listed {
        return true;
    }
    let (req_base, req_tag) = split_model_tag(requested);
    let (list_base, list_tag) = split_model_tag(listed);
    if req_base != list_base {
        return false;
    }
    match req_tag {
        None => true,
        Some(tag) => list_tag == Some(tag),
    }
}

/// Splits `name` into `(base, tag)` at the last `:` that is not part of a registry host/port.
fn split_model_tag(name: &str) -> (&str, Option<&str>) {
    match name.rsplit_once(':') {
        // A `/` after the colon means the colon separated a host from its port, not a tag.
        Some((base, tag)) if !tag.contains('/') => (base, Some(tag)),
        _ => (name, None),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Table of `(requested, listed, expected)` triples covering exact, untagged and mismatched
    /// tag lookups.
    #[test]
    fn model_name_matching() {
        let cases = [
            ("mistral", "mistral:latest", true),
            ("mistral", "mistral:7b", true),
            ("mistral:latest", "mistral:latest", true),
            ("mistral:latest", "mistral:7b", false),
            ("a", "b", false),
        ];
        for (requested, listed, expected) in cases {
            assert_eq!(
                model_identifier_matches(requested, listed),
                expected,
                "requested={requested:?} listed={listed:?}"
            );
        }
    }
}