1use crate::error::LocalModelError;
4use crate::models::{LocalModel, ModelMetadata, PullProgress};
5use crate::Result;
6use reqwest::Client;
7use serde::Deserialize;
8use std::sync::Arc;
9use tracing::{debug, error, info, warn};
10
11#[derive(Debug, Deserialize)]
13struct OllamaPullResponse {
14 status: String,
15 digest: String,
16 total: Option<u64>,
17 completed: Option<u64>,
18}
19
/// Shape of a JSON body carrying a `status` string, intended for `DELETE /api/delete`.
// Not currently deserialized anywhere: `remove_model` only inspects the HTTP
// status code. `dead_code` is allowed so the type can be kept without a warning.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OllamaDeleteResponse {
    /// Status text from the server.
    status: String,
}
27
28pub struct LocalModelManager {
30 client: Arc<Client>,
31 base_url: String,
32}
33
34impl LocalModelManager {
35 pub fn new(base_url: String) -> Result<Self> {
37 if base_url.is_empty() {
38 return Err(LocalModelError::ConfigError(
39 "Ollama base URL is required".to_string(),
40 ));
41 }
42
43 Ok(Self {
44 client: Arc::new(Client::new()),
45 base_url,
46 })
47 }
48
49 pub fn base_url(&self) -> &str {
51 &self.base_url
52 }
53
54 pub fn with_default_endpoint() -> Result<Self> {
56 Self::new("http://localhost:11434".to_string())
57 }
58
59 pub async fn pull_model(&self, model_name: &str) -> Result<Vec<PullProgress>> {
62 if model_name.is_empty() {
63 return Err(LocalModelError::InvalidModelName(
64 "Model name cannot be empty".to_string(),
65 ));
66 }
67
68 debug!("Pulling model: {}", model_name);
69
70 let url = format!("{}/api/pull", self.base_url);
71 let request_body = serde_json::json!({
72 "name": model_name,
73 "stream": true
74 });
75
76 let response = self
77 .client
78 .post(&url)
79 .json(&request_body)
80 .send()
81 .await
82 .map_err(|e| LocalModelError::NetworkError(e.to_string()))?;
83
84 if !response.status().is_success() {
85 let status = response.status();
86 let error_text = response.text().await.unwrap_or_default();
87 error!("Failed to pull model {}: {}", model_name, error_text);
88 return Err(LocalModelError::PullFailed(format!(
89 "HTTP {}: {}",
90 status, error_text
91 )));
92 }
93
94 let body = response.text().await.map_err(|e| {
95 error!("Failed to read pull response: {}", e);
96 LocalModelError::NetworkError(e.to_string())
97 })?;
98
99 let mut progress_updates = Vec::new();
101 for line in body.lines() {
102 if line.is_empty() {
103 continue;
104 }
105
106 match serde_json::from_str::<OllamaPullResponse>(line) {
107 Ok(resp) => {
108 let progress = PullProgress {
109 model: model_name.to_string(),
110 status: resp.status,
111 digest: resp.digest,
112 total: resp.total.unwrap_or(0),
113 completed: resp.completed.unwrap_or(0),
114 };
115 progress_updates.push(progress);
116 }
117 Err(e) => {
118 warn!("Failed to parse pull response line: {}", e);
119 }
120 }
121 }
122
123 info!("Successfully pulled model: {}", model_name);
124 Ok(progress_updates)
125 }
126
127 pub async fn remove_model(&self, model_name: &str) -> Result<()> {
129 if model_name.is_empty() {
130 return Err(LocalModelError::InvalidModelName(
131 "Model name cannot be empty".to_string(),
132 ));
133 }
134
135 debug!("Removing model: {}", model_name);
136
137 let url = format!("{}/api/delete", self.base_url);
138 let request_body = serde_json::json!({
139 "name": model_name
140 });
141
142 let response = self
143 .client
144 .delete(&url)
145 .json(&request_body)
146 .send()
147 .await
148 .map_err(|e| LocalModelError::NetworkError(e.to_string()))?;
149
150 if !response.status().is_success() {
151 let status = response.status();
152 let error_text = response.text().await.unwrap_or_default();
153 error!("Failed to remove model {}: {}", model_name, error_text);
154 return Err(LocalModelError::RemovalFailed(format!(
155 "HTTP {}: {}",
156 status, error_text
157 )));
158 }
159
160 info!("Successfully removed model: {}", model_name);
161 Ok(())
162 }
163
164 pub async fn update_model(&self, model_name: &str) -> Result<Vec<PullProgress>> {
166 if model_name.is_empty() {
167 return Err(LocalModelError::InvalidModelName(
168 "Model name cannot be empty".to_string(),
169 ));
170 }
171
172 debug!("Updating model: {}", model_name);
173
174 let model_with_tag = if model_name.contains(':') {
176 model_name.to_string()
177 } else {
178 format!("{}:latest", model_name)
179 };
180
181 self.pull_model(&model_with_tag).await
182 }
183
184 pub async fn get_model_info(&self, model_name: &str) -> Result<LocalModel> {
186 if model_name.is_empty() {
187 return Err(LocalModelError::InvalidModelName(
188 "Model name cannot be empty".to_string(),
189 ));
190 }
191
192 debug!("Getting model info: {}", model_name);
193
194 let url = format!("{}/api/show", self.base_url);
195 let request_body = serde_json::json!({
196 "name": model_name
197 });
198
199 let response = self
200 .client
201 .post(&url)
202 .json(&request_body)
203 .send()
204 .await
205 .map_err(|e| LocalModelError::NetworkError(e.to_string()))?;
206
207 let status = response.status();
208 if !status.is_success() {
209 if status == 404 {
210 return Err(LocalModelError::ModelNotFound(model_name.to_string()));
211 }
212 let error_text = response.text().await.unwrap_or_default();
213 error!(
214 "Failed to get model info for {}: {}",
215 model_name, error_text
216 );
217 return Err(LocalModelError::Unknown(format!(
218 "HTTP {}: {}",
219 status, error_text
220 )));
221 }
222
223 let model_info: OllamaModelInfo = response.json().await.map_err(|e| {
224 error!("Failed to parse model info response: {}", e);
225 LocalModelError::NetworkError(e.to_string())
226 })?;
227
228 Ok(LocalModel {
229 name: model_info.name,
230 size: model_info.details.parameter_size.parse().unwrap_or(0),
231 digest: model_info.digest,
232 modified_at: model_info.modified_at,
233 metadata: ModelMetadata {
234 format: model_info.details.format,
235 family: model_info.details.family,
236 parameter_size: model_info.details.parameter_size,
237 quantization_level: model_info.details.quantization_level,
238 },
239 })
240 }
241
242 pub async fn list_models(&self) -> Result<Vec<LocalModel>> {
244 debug!("Listing all models");
245
246 let url = format!("{}/api/tags", self.base_url);
247
248 let response = self
249 .client
250 .get(&url)
251 .send()
252 .await
253 .map_err(|e| LocalModelError::NetworkError(e.to_string()))?;
254
255 if !response.status().is_success() {
256 let status = response.status();
257 let error_text = response.text().await.unwrap_or_default();
258 error!("Failed to list models: {}", error_text);
259 return Err(LocalModelError::Unknown(format!(
260 "HTTP {}: {}",
261 status, error_text
262 )));
263 }
264
265 let tags_response: OllamaTagsResponse = response.json().await.map_err(|e| {
266 error!("Failed to parse tags response: {}", e);
267 LocalModelError::NetworkError(e.to_string())
268 })?;
269
270 let models: Vec<LocalModel> = tags_response
271 .models
272 .unwrap_or_default()
273 .into_iter()
274 .map(|m| LocalModel {
275 name: m.name,
276 size: m.size,
277 digest: m.digest,
278 modified_at: m.modified_at,
279 metadata: ModelMetadata {
280 format: "gguf".to_string(), family: "unknown".to_string(),
282 parameter_size: "unknown".to_string(),
283 quantization_level: "unknown".to_string(),
284 },
285 })
286 .collect();
287
288 debug!("Listed {} models", models.len());
289 Ok(models)
290 }
291
292 pub async fn model_exists(&self, model_name: &str) -> Result<bool> {
294 match self.get_model_info(model_name).await {
295 Ok(_) => Ok(true),
296 Err(LocalModelError::ModelNotFound(_)) => Ok(false),
297 Err(e) => Err(e),
298 }
299 }
300}
301
302#[derive(Debug, Deserialize)]
304struct OllamaModelInfo {
305 name: String,
306 digest: String,
307 modified_at: chrono::DateTime<chrono::Utc>,
308 #[allow(dead_code)]
309 size: u64,
310 details: OllamaModelDetails,
311}
312
/// The `details` object nested in the `/api/show` response; each field is
/// copied into the matching `ModelMetadata` field by `get_model_info`.
#[derive(Debug, Deserialize)]
struct OllamaModelDetails {
    /// Model file format.
    format: String,
    /// Model family name.
    family: String,
    /// Parameter-count label as reported by the server (a string, not a number).
    parameter_size: String,
    /// Quantization label as reported by the server.
    quantization_level: String,
}
321
/// Response body of `GET /api/tags`, deserialized by `list_models`.
#[derive(Debug, Deserialize)]
struct OllamaTagsResponse {
    /// Installed models. `None` when the key is absent or `null`, which
    /// `list_models` treats as an empty list.
    models: Option<Vec<OllamaModelTag>>,
}
327
/// One entry of `OllamaTagsResponse::models`; `list_models` maps each entry to
/// a `LocalModel`.
#[derive(Debug, Deserialize)]
struct OllamaModelTag {
    /// Model name as listed by the server.
    name: String,
    /// Model digest as listed by the server.
    digest: String,
    /// Last-modified timestamp, parsed as UTC.
    modified_at: chrono::DateTime<chrono::Utc>,
    /// Size value reported by the server, passed through unchanged.
    size: u64,
}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint shared by the tests. None of them sends a request (blank names
    /// are rejected up front), so nothing has to be listening here.
    const LOCAL_ENDPOINT: &str = "http://localhost:11434";

    /// Builds a manager for [`LOCAL_ENDPOINT`], panicking if construction fails.
    fn local_manager() -> LocalModelManager {
        LocalModelManager::new(LOCAL_ENDPOINT.to_string()).unwrap()
    }

    /// Drives `fut` to completion on a freshly created Tokio runtime.
    fn run<F: std::future::Future>(fut: F) -> F::Output {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(fut)
    }

    #[test]
    fn test_local_model_manager_creation() {
        let outcome = LocalModelManager::new(LOCAL_ENDPOINT.to_string());
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_local_model_manager_empty_url() {
        let outcome = LocalModelManager::new(String::new());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_local_model_manager_default_endpoint() {
        let outcome = LocalModelManager::with_default_endpoint();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_pull_model_empty_name() {
        let mgr = local_manager();
        let outcome = run(mgr.pull_model(""));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_remove_model_empty_name() {
        let mgr = local_manager();
        let outcome = run(mgr.remove_model(""));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_update_model_empty_name() {
        let mgr = local_manager();
        let outcome = run(mgr.update_model(""));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_get_model_info_empty_name() {
        let mgr = local_manager();
        let outcome = run(mgr.get_model_info(""));
        assert!(outcome.is_err());
    }
}