ollama_kit/model/
manager.rs1use ollama_rs::Ollama;
2
3use crate::config::RuntimeMode;
4use crate::error::{map_ollama_error, Result, RuntimeError};
5
6#[derive(Clone, Debug)]
8pub struct ModelManager {
9 pub client: Ollama,
10 pub auto_pull: bool,
11 pub mode: RuntimeMode,
12}
13
14impl ModelManager {
15 pub fn new(client: Ollama, auto_pull: bool, mode: RuntimeMode) -> Self {
16 Self {
17 client,
18 auto_pull,
19 mode,
20 }
21 }
22
23 pub async fn ensure(&self, model: &str) -> Result<()> {
26 let local = self
27 .client
28 .list_local_models()
29 .await
30 .map_err(map_ollama_error)?;
31
32 if local
33 .iter()
34 .any(|m| model_identifier_matches(model, m.name.as_str()))
35 {
36 return Ok(());
37 }
38
39 if self.auto_pull && self.mode == RuntimeMode::Development {
40 self.client
41 .pull_model(model.to_string(), false)
42 .await
43 .map_err(map_ollama_error)?;
44 return Ok(());
45 }
46
47 Err(RuntimeError::ModelNotFound(model.to_string()))
48 }
49}
50
51fn model_identifier_matches(requested: &str, listed: &str) -> bool {
52 if requested == listed {
53 return true;
54 }
55 let Some(req_base) = requested.split(':').next() else {
56 return false;
57 };
58 let Some(list_base) = listed.split(':').next() else {
59 return false;
60 };
61 if req_base != list_base {
62 return false;
63 }
64 match requested.split(':').nth(1) {
65 None => true,
66 Some(rt) => listed.split(':').nth(1).is_some_and(|lt| rt == lt),
67 }
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// `(requested, listed, expected)` triples covering tag-less requests,
    /// exact matches, tag mismatches and unrelated names.
    const CASES: &[(&str, &str, bool)] = &[
        ("mistral", "mistral:latest", true),
        ("mistral", "mistral:7b", true),
        ("mistral:latest", "mistral:latest", true),
        ("mistral:latest", "mistral:7b", false),
        ("a", "b", false),
    ];

    #[test]
    fn model_name_matching() {
        for &(requested, listed, expected) in CASES {
            assert_eq!(
                model_identifier_matches(requested, listed),
                expected,
                "requested={requested:?}, listed={listed:?}"
            );
        }
    }
}