use crate::core::error::{Error, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;

const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
const DEFAULT_MODEL: &str = "codellama";
const REQUEST_TIMEOUT: Duration = Duration::from_secs(120);

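/// Request body for Ollama's non-streaming `/api/generate` endpoint.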
#[derive(Debug, Serialize)]
struct GenerateRequest {
    model: String,
    prompt: String,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<GenerateOptions>,
}

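/// Sampling options sent with both generate and chat requests.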
#[derive(Debug, Serialize)]
struct GenerateOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<i32>,
}

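/// Response body returned by `/api/generate`.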
#[derive(Debug, Deserialize)]
struct GenerateResponse {
    response: String,
    #[allow(dead_code)]
    done: bool,
}

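/// Request body for Ollama's non-streaming `/api/chat` endpoint.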
#[derive(Debug, Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<ChatMessage>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<GenerateOptions>,
}

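/// A single chat message: a role ("system" or "user") plus its content.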
#[derive(Debug, Serialize)]
struct ChatMessage {
    role: String,
    content: String,
}

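/// Response body returned by `/api/chat`.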
#[derive(Debug, Deserialize)]
struct ChatResponse {
    message: ChatMessageResponse,
    #[allow(dead_code)]
    done: bool,
}

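/// The assistant message embedded in a chat response.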
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    content: String,
}

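/// Response body returned by `/api/tags`.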
#[derive(Debug, Deserialize)]
struct ModelsResponse {
    models: Vec<ModelInfo>,
}

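/// Basic metadata about a model installed on the Ollama server.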
#[derive(Debug, Deserialize)]
pub struct ModelInfo {
    pub name: String,
    #[allow(dead_code)]
    pub size: Option<u64>,
}

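/// Minimal HTTP client for a local Ollama server.
///
/// Wraps the `/api/generate`, `/api/chat`, and `/api/tags` endpoints and is
/// used for LLM-assisted features such as reranking code search results.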
pub struct OllamaClient {
    client: Client,
    base_url: String,
    model: String,
}

impl OllamaClient {
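    /// Creates a client using the default Ollama URL and model.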
    pub fn new() -> Self {
        Self::with_config(DEFAULT_OLLAMA_URL, DEFAULT_MODEL)
    }

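    /// Creates a client for the given base URL and model name.
    ///
    /// Trailing slashes are stripped from `base_url`.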
    pub fn with_config(base_url: &str, model: &str) -> Self {
        let client = Client::builder()
            .timeout(REQUEST_TIMEOUT)
            .build()
            .unwrap_or_else(|_| Client::new());

        Self {
            client,
            base_url: base_url.trim_end_matches('/').to_string(),
            model: model.to_string(),
        }
    }

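    /// Returns `true` if the Ollama server answers a quick `/api/tags` probe.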
    pub async fn is_available(&self) -> bool {
        let url = format!("{}/api/tags", self.base_url);
        self.client
            .get(&url)
            .timeout(Duration::from_secs(5))
            .send()
            .await
            .map(|r| r.status().is_success())
            .unwrap_or(false)
    }

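    /// Lists the models currently installed on the server.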
    pub async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        let url = format!("{}/api/tags", self.base_url);

        let res = self
            .client
            .get(&url)
            .send()
            .await
            .map_err(|e| self.connection_error(e))?;

        if !res.status().is_success() {
            return Err(self.api_error("Failed to list models", res).await);
        }

        let models: ModelsResponse = res.json().await.map_err(|e| Error::DaemonError {
            message: format!("Failed to parse models response: {}", e),
        })?;

        Ok(models.models)
    }

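    /// Returns `true` if an installed model name starts with `model`
    /// (e.g. "codellama" also matches "codellama:latest").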
    pub async fn has_model(&self, model: &str) -> bool {
        self.list_models()
            .await
            .map(|models| models.iter().any(|m| m.name.starts_with(model)))
            .unwrap_or(false)
    }

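    /// Sends a non-streaming completion request to `/api/generate` and
    /// returns the generated text.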
    pub async fn generate(&self, prompt: &str, system: Option<&str>) -> Result<String> {
        let url = format!("{}/api/generate", self.base_url);

        let request = GenerateRequest {
            model: self.model.clone(),
            prompt: prompt.to_string(),
            stream: false,
            system: system.map(|s| s.to_string()),
            options: Some(GenerateOptions {
                temperature: Some(0.1),
                num_predict: Some(512),
            }),
        };

        let res = self
            .client
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| self.connection_error(e))?;

        if !res.status().is_success() {
            return Err(self.api_error("Generation failed", res).await);
        }

        let response: GenerateResponse = res.json().await.map_err(|e| Error::DaemonError {
            message: format!("Failed to parse generate response: {}", e),
        })?;

        Ok(response.response)
    }

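    /// Sends a non-streaming chat request to `/api/chat`, with an optional
    /// system message, and returns the assistant's reply.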
    pub async fn chat(&self, user_message: &str, system: Option<&str>) -> Result<String> {
        let url = format!("{}/api/chat", self.base_url);

        let mut messages = Vec::new();

        if let Some(sys) = system {
            messages.push(ChatMessage {
                role: "system".to_string(),
                content: sys.to_string(),
            });
        }

        messages.push(ChatMessage {
            role: "user".to_string(),
            content: user_message.to_string(),
        });

        let request = ChatRequest {
            model: self.model.clone(),
            messages,
            stream: false,
            options: Some(GenerateOptions {
                temperature: Some(0.1),
                num_predict: Some(512),
            }),
        };

        let res = self
            .client
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| self.connection_error(e))?;

        if !res.status().is_success() {
            return Err(self.api_error("Chat failed", res).await);
        }

        let response: ChatResponse = res.json().await.map_err(|e| Error::DaemonError {
            message: format!("Failed to parse chat response: {}", e),
        })?;

        Ok(response.message.content)
    }

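    /// Reranks `chunks` by relevance to `query` and returns the chunk indices
    /// in ranked order.
    ///
    /// Falls back to the original order if Ollama is unavailable or the model
    /// output cannot be parsed; tries `generate` if the chat endpoint fails.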
    pub async fn rerank(&self, query: &str, chunks: &[String]) -> Result<Vec<usize>> {
        // If Ollama is not reachable, keep the original ordering.
        if !self.is_available().await {
            return Ok((0..chunks.len()).collect());
        }

        let system_prompt =
            "You are a code search reranker. Given a query and numbered code chunks, \
             return ONLY a JSON array of chunk indices ordered by relevance to the query. \
             Most relevant first. Example response: [2, 0, 5, 1, 3, 4]";

        let mut user_prompt = format!("Query: {}\n\nCode chunks:\n", query);
        for (i, chunk) in chunks.iter().enumerate() {
            user_prompt.push_str(&format!("\n--- Chunk {} ---\n{}\n", i, chunk));
        }
        user_prompt.push_str("\nReturn ONLY the JSON array of indices, nothing else.");

        // Prefer the chat endpoint; fall back to plain generation if it fails.
        let response = match self.chat(&user_prompt, Some(system_prompt)).await {
            Ok(r) => r,
            Err(_) => {
                let full_prompt = format!("{}\n\n{}", system_prompt, user_prompt);
                self.generate(&full_prompt, None).await?
            }
        };

        self.parse_rerank_response(&response, chunks.len())
    }

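    /// Extracts a JSON array of chunk indices from the model output, falling
    /// back to the original order when nothing parseable is found.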
    fn parse_rerank_response(&self, text: &str, chunk_count: usize) -> Result<Vec<usize>> {
        let text = text.trim();

        // Ideal case: the whole response is a JSON array of indices.
        if let Ok(indices) = serde_json::from_str::<Vec<usize>>(text) {
            return Ok(self.validate_indices(indices, chunk_count));
        }

        // Otherwise, try the slice between the first '[' and the last ']'.
        if let Some(start) = text.find('[') {
            if let Some(end) = text.rfind(']') {
                let json_str = &text[start..=end];
                if let Ok(indices) = serde_json::from_str::<Vec<usize>>(json_str) {
                    return Ok(self.validate_indices(indices, chunk_count));
                }
            }
        }

        // Nothing parseable: keep the original ordering.
        Ok((0..chunk_count).collect())
    }

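    /// Drops out-of-range and duplicate indices, then appends any missing
    /// indices so every chunk appears exactly once.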
    fn validate_indices(&self, indices: Vec<usize>, chunk_count: usize) -> Vec<usize> {
        let mut seen = std::collections::HashSet::new();
        let mut valid: Vec<usize> = indices
            .into_iter()
            .filter(|&i| i < chunk_count && seen.insert(i))
            .collect();

        for i in 0..chunk_count {
            if !seen.contains(&i) {
                valid.push(i);
            }
        }

        valid
    }

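    /// Converts a transport-level `reqwest` error into a user-facing error
    /// with a hint about the likely cause.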
    fn connection_error(&self, e: reqwest::Error) -> Error {
        if e.is_connect() {
            Error::DaemonError {
                message: format!(
                    "Cannot connect to Ollama at {}. \
                     Make sure Ollama is running (ollama serve) or check your config.",
                    self.base_url
                ),
            }
        } else if e.is_timeout() {
            Error::DaemonError {
                message: format!(
                    "Ollama request timed out. The model '{}' may be loading or too slow.",
                    self.model
                ),
            }
        } else {
            Error::DaemonError {
                message: format!("Ollama request failed: {}", e),
            }
        }
    }

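    /// Converts a non-success HTTP response into a user-facing error,
    /// with a specific hint when the configured model is missing.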
    async fn api_error(&self, context: &str, res: reqwest::Response) -> Error {
        let status = res.status();
        let text = res.text().await.unwrap_or_default();

        if status.as_u16() == 404 && text.contains("model") {
            Error::DaemonError {
                message: format!(
                    "Model '{}' not found. Run 'ollama pull {}' to download it.",
                    self.model, self.model
                ),
            }
        } else {
            Error::DaemonError {
                message: format!("{}: HTTP {} - {}", context, status, text),
            }
        }
    }
}

impl Default for OllamaClient {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_indices() {
        let client = OllamaClient::new();

        // A valid permutation is returned unchanged.
        let result = client.validate_indices(vec![2, 0, 1], 3);
        assert_eq!(result, vec![2, 0, 1]);

        // Out-of-range indices are dropped; missing ones are appended.
        let result = client.validate_indices(vec![5, 0, 1], 3);
        assert_eq!(result, vec![0, 1, 2]);

        // Duplicates are removed.
        let result = client.validate_indices(vec![0, 0, 1], 3);
        assert_eq!(result, vec![0, 1, 2]);

        // Partial rankings are completed with the remaining indices.
        let result = client.validate_indices(vec![2], 3);
        assert_eq!(result, vec![2, 0, 1]);
    }

    #[test]
    fn test_parse_rerank_response() {
        let client = OllamaClient::new();

        // A bare JSON array parses directly.
        let result = client.parse_rerank_response("[2, 0, 1]", 3).unwrap();
        assert_eq!(result, vec![2, 0, 1]);

        // An array embedded in surrounding prose is still extracted.
        let result = client
            .parse_rerank_response("Here's the ranking: [2, 0, 1] based on relevance", 3)
            .unwrap();
        assert_eq!(result, vec![2, 0, 1]);

        // Unparseable output falls back to the original order.
        let result = client
            .parse_rerank_response("I cannot rank these", 3)
            .unwrap();
        assert_eq!(result, vec![0, 1, 2]);
    }
}