Skip to main content

ollama_kit/model/manager.rs

1use ollama_rs::Ollama;
2
3use crate::config::RuntimeMode;
4use crate::error::{map_ollama_error, Result, RuntimeError};
5
6/// Ensures requested models exist locally according to [`RuntimeMode`] and `auto_pull` policy.
7#[derive(Clone, Debug)]
8pub struct ModelManager {
9    pub client: Ollama,
10    pub auto_pull: bool,
11    pub mode: RuntimeMode,
12}
13
14impl ModelManager {
15    pub fn new(client: Ollama, auto_pull: bool, mode: RuntimeMode) -> Self {
16        Self {
17            client,
18            auto_pull,
19            mode,
20        }
21    }
22
23    /// Lists local models; pulls in Development when `auto_pull` is enabled; otherwise reports
24    /// [`RuntimeError::ModelNotFound`].
25    pub async fn ensure(&self, model: &str) -> Result<()> {
26        let local = self
27            .client
28            .list_local_models()
29            .await
30            .map_err(map_ollama_error)?;
31
32        if local
33            .iter()
34            .any(|m| model_identifier_matches(model, m.name.as_str()))
35        {
36            return Ok(());
37        }
38
39        if self.auto_pull && self.mode == RuntimeMode::Development {
40            self.client
41                .pull_model(model.to_string(), false)
42                .await
43                .map_err(map_ollama_error)?;
44            return Ok(());
45        }
46
47        Err(RuntimeError::ModelNotFound(model.to_string()))
48    }
49}
50
/// Reports whether a locally listed model name satisfies a requested model identifier.
///
/// An exact string match always succeeds. Otherwise both names are split into a base name and
/// an optional tag. The tag is whatever follows the last `:`, provided that suffix contains no
/// `/`. This keeps a registry port (e.g. the `5000` in `localhost:5000/ns/model:tag`) from
/// being mistaken for a tag.
///
/// The bases must be equal. An untagged request then accepts any listed tag, while a tagged
/// request requires the listed tag to be identical.
///
/// NOTE(review): Ollama resolves an untagged name to `:latest`. Accepting *any* tag for an
/// untagged request (e.g. `mistral` vs `mistral:7b`) may therefore report a model as present
/// that a later request for plain `mistral` cannot resolve. Confirm this leniency is intended.
fn model_identifier_matches(requested: &str, listed: &str) -> bool {
    if requested == listed {
        return true;
    }
    let (req_base, req_tag) = split_name_and_tag(requested);
    let (list_base, list_tag) = split_name_and_tag(listed);
    if req_base != list_base {
        return false;
    }
    match req_tag {
        None => true,
        Some(tag) => list_tag == Some(tag),
    }
}

/// Splits `name` into `(base, tag)` at the last `:` whose suffix contains no `/`.
/// A `:` followed by a path segment (registry `host:port/...`) is not treated as a tag separator.
fn split_name_and_tag(name: &str) -> (&str, Option<&str>) {
    match name.rsplit_once(':') {
        Some((base, tag)) if !tag.contains('/') => (base, Some(tag)),
        _ => (name, None),
    }
}
69
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn untagged_request_accepts_any_tag() {
        assert!(model_identifier_matches("mistral", "mistral:latest"));
        assert!(model_identifier_matches("mistral", "mistral:7b"));
    }

    #[test]
    fn tagged_request_requires_same_tag() {
        assert!(model_identifier_matches("mistral:latest", "mistral:latest"));
        assert!(!model_identifier_matches("mistral:latest", "mistral:7b"));
    }

    #[test]
    fn different_models_do_not_match() {
        assert!(!model_identifier_matches("a", "b"));
    }
}