ricecoder_local_models/
manager.rs

1//! Local model manager for handling model lifecycle operations
2
3use crate::error::LocalModelError;
4use crate::models::{LocalModel, ModelMetadata, PullProgress};
5use crate::Result;
6use reqwest::Client;
7use serde::Deserialize;
8use std::sync::Arc;
9use tracing::{debug, error, info, warn};
10
11/// Ollama API response for model pull
12#[derive(Debug, Deserialize)]
13struct OllamaPullResponse {
14    status: String,
15    digest: String,
16    total: Option<u64>,
17    completed: Option<u64>,
18}
19
20/// Ollama API response for model deletion
21/// Reserved for future use when model deletion is implemented
22#[allow(dead_code)]
23#[derive(Debug, Deserialize)]
24struct OllamaDeleteResponse {
25    status: String,
26}
27
28/// Local model manager for Ollama
29pub struct LocalModelManager {
30    client: Arc<Client>,
31    base_url: String,
32}
33
34impl LocalModelManager {
35    /// Create a new local model manager
36    pub fn new(base_url: String) -> Result<Self> {
37        if base_url.is_empty() {
38            return Err(LocalModelError::ConfigError(
39                "Ollama base URL is required".to_string(),
40            ));
41        }
42
43        Ok(Self {
44            client: Arc::new(Client::new()),
45            base_url,
46        })
47    }
48
49    /// Get the base URL
50    pub fn base_url(&self) -> &str {
51        &self.base_url
52    }
53
54    /// Create a new local model manager with default localhost endpoint
55    pub fn with_default_endpoint() -> Result<Self> {
56        Self::new("http://localhost:11434".to_string())
57    }
58
59    /// Pull a model from Ollama registry
60    /// Returns a stream of progress updates
61    pub async fn pull_model(&self, model_name: &str) -> Result<Vec<PullProgress>> {
62        if model_name.is_empty() {
63            return Err(LocalModelError::InvalidModelName(
64                "Model name cannot be empty".to_string(),
65            ));
66        }
67
68        debug!("Pulling model: {}", model_name);
69
70        let url = format!("{}/api/pull", self.base_url);
71        let request_body = serde_json::json!({
72            "name": model_name,
73            "stream": true
74        });
75
76        let response = self
77            .client
78            .post(&url)
79            .json(&request_body)
80            .send()
81            .await
82            .map_err(|e| LocalModelError::NetworkError(e.to_string()))?;
83
84        if !response.status().is_success() {
85            let status = response.status();
86            let error_text = response.text().await.unwrap_or_default();
87            error!("Failed to pull model {}: {}", model_name, error_text);
88            return Err(LocalModelError::PullFailed(format!(
89                "HTTP {}: {}",
90                status, error_text
91            )));
92        }
93
94        let body = response.text().await.map_err(|e| {
95            error!("Failed to read pull response: {}", e);
96            LocalModelError::NetworkError(e.to_string())
97        })?;
98
99        // Parse streaming responses
100        let mut progress_updates = Vec::new();
101        for line in body.lines() {
102            if line.is_empty() {
103                continue;
104            }
105
106            match serde_json::from_str::<OllamaPullResponse>(line) {
107                Ok(resp) => {
108                    let progress = PullProgress {
109                        model: model_name.to_string(),
110                        status: resp.status,
111                        digest: resp.digest,
112                        total: resp.total.unwrap_or(0),
113                        completed: resp.completed.unwrap_or(0),
114                    };
115                    progress_updates.push(progress);
116                }
117                Err(e) => {
118                    warn!("Failed to parse pull response line: {}", e);
119                }
120            }
121        }
122
123        info!("Successfully pulled model: {}", model_name);
124        Ok(progress_updates)
125    }
126
127    /// Remove a model from local storage
128    pub async fn remove_model(&self, model_name: &str) -> Result<()> {
129        if model_name.is_empty() {
130            return Err(LocalModelError::InvalidModelName(
131                "Model name cannot be empty".to_string(),
132            ));
133        }
134
135        debug!("Removing model: {}", model_name);
136
137        let url = format!("{}/api/delete", self.base_url);
138        let request_body = serde_json::json!({
139            "name": model_name
140        });
141
142        let response = self
143            .client
144            .delete(&url)
145            .json(&request_body)
146            .send()
147            .await
148            .map_err(|e| LocalModelError::NetworkError(e.to_string()))?;
149
150        if !response.status().is_success() {
151            let status = response.status();
152            let error_text = response.text().await.unwrap_or_default();
153            error!("Failed to remove model {}: {}", model_name, error_text);
154            return Err(LocalModelError::RemovalFailed(format!(
155                "HTTP {}: {}",
156                status, error_text
157            )));
158        }
159
160        info!("Successfully removed model: {}", model_name);
161        Ok(())
162    }
163
164    /// Update a model to the latest version
165    pub async fn update_model(&self, model_name: &str) -> Result<Vec<PullProgress>> {
166        if model_name.is_empty() {
167            return Err(LocalModelError::InvalidModelName(
168                "Model name cannot be empty".to_string(),
169            ));
170        }
171
172        debug!("Updating model: {}", model_name);
173
174        // Update is essentially a pull with the latest tag
175        let model_with_tag = if model_name.contains(':') {
176            model_name.to_string()
177        } else {
178            format!("{}:latest", model_name)
179        };
180
181        self.pull_model(&model_with_tag).await
182    }
183
184    /// Get information about a specific model
185    pub async fn get_model_info(&self, model_name: &str) -> Result<LocalModel> {
186        if model_name.is_empty() {
187            return Err(LocalModelError::InvalidModelName(
188                "Model name cannot be empty".to_string(),
189            ));
190        }
191
192        debug!("Getting model info: {}", model_name);
193
194        let url = format!("{}/api/show", self.base_url);
195        let request_body = serde_json::json!({
196            "name": model_name
197        });
198
199        let response = self
200            .client
201            .post(&url)
202            .json(&request_body)
203            .send()
204            .await
205            .map_err(|e| LocalModelError::NetworkError(e.to_string()))?;
206
207        let status = response.status();
208        if !status.is_success() {
209            if status == 404 {
210                return Err(LocalModelError::ModelNotFound(model_name.to_string()));
211            }
212            let error_text = response.text().await.unwrap_or_default();
213            error!(
214                "Failed to get model info for {}: {}",
215                model_name, error_text
216            );
217            return Err(LocalModelError::Unknown(format!(
218                "HTTP {}: {}",
219                status, error_text
220            )));
221        }
222
223        let model_info: OllamaModelInfo = response.json().await.map_err(|e| {
224            error!("Failed to parse model info response: {}", e);
225            LocalModelError::NetworkError(e.to_string())
226        })?;
227
228        Ok(LocalModel {
229            name: model_info.name,
230            size: model_info.details.parameter_size.parse().unwrap_or(0),
231            digest: model_info.digest,
232            modified_at: model_info.modified_at,
233            metadata: ModelMetadata {
234                format: model_info.details.format,
235                family: model_info.details.family,
236                parameter_size: model_info.details.parameter_size,
237                quantization_level: model_info.details.quantization_level,
238            },
239        })
240    }
241
242    /// List all available models
243    pub async fn list_models(&self) -> Result<Vec<LocalModel>> {
244        debug!("Listing all models");
245
246        let url = format!("{}/api/tags", self.base_url);
247
248        let response = self
249            .client
250            .get(&url)
251            .send()
252            .await
253            .map_err(|e| LocalModelError::NetworkError(e.to_string()))?;
254
255        if !response.status().is_success() {
256            let status = response.status();
257            let error_text = response.text().await.unwrap_or_default();
258            error!("Failed to list models: {}", error_text);
259            return Err(LocalModelError::Unknown(format!(
260                "HTTP {}: {}",
261                status, error_text
262            )));
263        }
264
265        let tags_response: OllamaTagsResponse = response.json().await.map_err(|e| {
266            error!("Failed to parse tags response: {}", e);
267            LocalModelError::NetworkError(e.to_string())
268        })?;
269
270        let models: Vec<LocalModel> = tags_response
271            .models
272            .unwrap_or_default()
273            .into_iter()
274            .map(|m| LocalModel {
275                name: m.name,
276                size: m.size,
277                digest: m.digest,
278                modified_at: m.modified_at,
279                metadata: ModelMetadata {
280                    format: "gguf".to_string(), // Default format
281                    family: "unknown".to_string(),
282                    parameter_size: "unknown".to_string(),
283                    quantization_level: "unknown".to_string(),
284                },
285            })
286            .collect();
287
288        debug!("Listed {} models", models.len());
289        Ok(models)
290    }
291
292    /// Check if a model exists
293    pub async fn model_exists(&self, model_name: &str) -> Result<bool> {
294        match self.get_model_info(model_name).await {
295            Ok(_) => Ok(true),
296            Err(LocalModelError::ModelNotFound(_)) => Ok(false),
297            Err(e) => Err(e),
298        }
299    }
300}
301
302/// Ollama API response for model info
303#[derive(Debug, Deserialize)]
304struct OllamaModelInfo {
305    name: String,
306    digest: String,
307    modified_at: chrono::DateTime<chrono::Utc>,
308    #[allow(dead_code)]
309    size: u64,
310    details: OllamaModelDetails,
311}
312
313/// Ollama model details
314#[derive(Debug, Deserialize)]
315struct OllamaModelDetails {
316    format: String,
317    family: String,
318    parameter_size: String,
319    quantization_level: String,
320}
321
322/// Ollama API response for tags
323#[derive(Debug, Deserialize)]
324struct OllamaTagsResponse {
325    models: Option<Vec<OllamaModelTag>>,
326}
327
328/// Ollama model tag
329#[derive(Debug, Deserialize)]
330struct OllamaModelTag {
331    name: String,
332    digest: String,
333    modified_at: chrono::DateTime<chrono::Utc>,
334    size: u64,
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint used by every test; no request is ever sent to it.
    const TEST_ENDPOINT: &str = "http://localhost:11434";

    /// Build a manager pointed at [`TEST_ENDPOINT`].
    fn test_manager() -> LocalModelManager {
        LocalModelManager::new(TEST_ENDPOINT.to_string()).unwrap()
    }

    /// Drive `future` to completion on a freshly created Tokio runtime.
    fn run<F: std::future::Future>(future: F) -> F::Output {
        tokio::runtime::Runtime::new().unwrap().block_on(future)
    }

    #[test]
    fn test_local_model_manager_creation() {
        assert!(LocalModelManager::new(TEST_ENDPOINT.to_string()).is_ok());
    }

    #[test]
    fn test_local_model_manager_empty_url() {
        assert!(LocalModelManager::new(String::new()).is_err());
    }

    #[test]
    fn test_local_model_manager_default_endpoint() {
        assert!(LocalModelManager::with_default_endpoint().is_ok());
    }

    // The remaining tests only cover the empty-name guard, which returns
    // before any network request is made.

    #[test]
    fn test_pull_model_empty_name() {
        let manager = test_manager();
        assert!(run(manager.pull_model("")).is_err());
    }

    #[test]
    fn test_remove_model_empty_name() {
        let manager = test_manager();
        assert!(run(manager.remove_model("")).is_err());
    }

    #[test]
    fn test_update_model_empty_name() {
        let manager = test_manager();
        assert!(run(manager.update_model("")).is_err());
    }

    #[test]
    fn test_get_model_info_empty_name() {
        let manager = test_manager();
        assert!(run(manager.get_model_info("")).is_err());
    }
}
394}