greppy/ai/ollama.rs

//! Ollama Local LLM Client
//!
//! Provides local LLM inference via Ollama for search reranking and trace enhancement.
//! Supports models like codellama, deepseek-coder, llama3, etc.
//!
//! @module ai/ollama
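//!
//! Illustrative usage sketch (added for documentation, not part of the original code).
//! It assumes the crate is named `greppy`, an async context such as a tokio runtime,
//! and a local Ollama server at the default address.
//!
//! ```ignore
//! use greppy::ai::ollama::OllamaClient;
//!
//! // Point at the default server, but use a different model.
//! let client = OllamaClient::with_config("http://localhost:11434", "deepseek-coder");
//!
//! if client.is_available().await {
//!     // List the models known to the local Ollama instance.
//!     for model in client.list_models().await.unwrap_or_default() {
//!         println!("{}", model.name);
//!     }
//! }
//! ```
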
use crate::core::error::{Error, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;

// =============================================================================
// CONSTANTS
// =============================================================================

const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
const DEFAULT_MODEL: &str = "codellama";
const REQUEST_TIMEOUT: Duration = Duration::from_secs(120);

// =============================================================================
// REQUEST/RESPONSE TYPES
// =============================================================================

/// Ollama generate request
#[derive(Debug, Serialize)]
struct GenerateRequest {
    model: String,
    prompt: String,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<GenerateOptions>,
}

/// Ollama generation options
#[derive(Debug, Serialize)]
struct GenerateOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<i32>,
}

/// Ollama generate response
#[derive(Debug, Deserialize)]
struct GenerateResponse {
    response: String,
    #[allow(dead_code)]
    done: bool,
}

/// Ollama chat request (alternative API)
#[derive(Debug, Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<ChatMessage>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<GenerateOptions>,
}

/// Chat message
#[derive(Debug, Serialize)]
struct ChatMessage {
    role: String,
    content: String,
}

/// Ollama chat response
#[derive(Debug, Deserialize)]
struct ChatResponse {
    message: ChatMessageResponse,
    #[allow(dead_code)]
    done: bool,
}

/// Chat message response
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    content: String,
}

/// Ollama model list response
#[derive(Debug, Deserialize)]
struct ModelsResponse {
    models: Vec<ModelInfo>,
}

/// Model information
#[derive(Debug, Deserialize)]
pub struct ModelInfo {
    pub name: String,
    #[allow(dead_code)]
    pub size: Option<u64>,
}

// =============================================================================
// OLLAMA CLIENT
// =============================================================================

/// Ollama client for local LLM inference
pub struct OllamaClient {
    client: Client,
    base_url: String,
    model: String,
}

impl OllamaClient {
    /// Create a new Ollama client with default settings
    pub fn new() -> Self {
        Self::with_config(DEFAULT_OLLAMA_URL, DEFAULT_MODEL)
    }

    /// Create a new Ollama client with custom URL and model
    pub fn with_config(base_url: &str, model: &str) -> Self {
        let client = Client::builder()
            .timeout(REQUEST_TIMEOUT)
            .build()
            .unwrap_or_else(|_| Client::new());

        Self {
            client,
            base_url: base_url.trim_end_matches('/').to_string(),
            model: model.to_string(),
        }
    }

    /// Check if Ollama is running and accessible
    pub async fn is_available(&self) -> bool {
        let url = format!("{}/api/tags", self.base_url);
        self.client
            .get(&url)
            .timeout(Duration::from_secs(5))
            .send()
            .await
            .map(|r| r.status().is_success())
            .unwrap_or(false)
    }

    /// List available models
    pub async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        let url = format!("{}/api/tags", self.base_url);

        let res = self
            .client
            .get(&url)
            .send()
            .await
            .map_err(|e| self.connection_error(e))?;

        if !res.status().is_success() {
            return Err(self.api_error("Failed to list models", res).await);
        }

        let models: ModelsResponse = res.json().await.map_err(|e| Error::DaemonError {
            message: format!("Failed to parse models response: {}", e),
        })?;

        Ok(models.models)
    }

    /// Check if a specific model is available
    pub async fn has_model(&self, model: &str) -> bool {
        self.list_models()
            .await
            .map(|models| models.iter().any(|m| m.name.starts_with(model)))
            .unwrap_or(false)
    }

    /// Generate completion using the generate API
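    ///
    /// Illustrative call (a sketch added for documentation; it assumes a running Ollama
    /// instance with the configured model already pulled, and an async context):
    ///
    /// ```ignore
    /// let answer = client
    ///     .generate("Explain what a trait object is in one sentence.", Some("Be concise."))
    ///     .await?;
    /// println!("{}", answer);
    /// ```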
    pub async fn generate(&self, prompt: &str, system: Option<&str>) -> Result<String> {
        let url = format!("{}/api/generate", self.base_url);

        let request = GenerateRequest {
            model: self.model.clone(),
            prompt: prompt.to_string(),
            stream: false,
            system: system.map(|s| s.to_string()),
            options: Some(GenerateOptions {
                temperature: Some(0.1), // Low temperature for deterministic output
                num_predict: Some(512), // Limit response length
            }),
        };

        let res = self
            .client
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| self.connection_error(e))?;

        if !res.status().is_success() {
            return Err(self.api_error("Generation failed", res).await);
        }

        let response: GenerateResponse = res.json().await.map_err(|e| Error::DaemonError {
            message: format!("Failed to parse generate response: {}", e),
        })?;

        Ok(response.response)
    }

    /// Generate completion using the chat API
    pub async fn chat(&self, user_message: &str, system: Option<&str>) -> Result<String> {
        let url = format!("{}/api/chat", self.base_url);

        let mut messages = Vec::new();

        if let Some(sys) = system {
            messages.push(ChatMessage {
                role: "system".to_string(),
                content: sys.to_string(),
            });
        }

        messages.push(ChatMessage {
            role: "user".to_string(),
            content: user_message.to_string(),
        });

        let request = ChatRequest {
            model: self.model.clone(),
            messages,
            stream: false,
            options: Some(GenerateOptions {
                temperature: Some(0.1),
                num_predict: Some(512),
            }),
        };

        let res = self
            .client
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| self.connection_error(e))?;

        if !res.status().is_success() {
            return Err(self.api_error("Chat failed", res).await);
        }

        let response: ChatResponse = res.json().await.map_err(|e| Error::DaemonError {
            message: format!("Failed to parse chat response: {}", e),
        })?;

        Ok(response.message.content)
    }

    /// Rerank search results by relevance to query
    /// Returns indices in order of relevance: [2, 0, 5, 1, ...]
    ///
    /// This matches the interface of ClaudeClient and GeminiClient
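    ///
    /// Illustrative call (a sketch added for documentation; `client` and the chunk
    /// texts are assumed to exist in an async context):
    ///
    /// ```ignore
    /// let chunks = vec![
    ///     "fn parse_config(path: &Path) -> Config { /* ... */ }".to_string(),
    ///     "fn main() { println!(\"hello\"); }".to_string(),
    /// ];
    /// let order = client.rerank("where is the config file parsed?", &chunks).await?;
    /// // `order` is a permutation of 0..chunks.len(), most relevant chunk first.
    /// // If Ollama is unreachable or the reply cannot be parsed, the original
    /// // order [0, 1, ...] is returned instead of an error.
    /// ```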
    pub async fn rerank(&self, query: &str, chunks: &[String]) -> Result<Vec<usize>> {
        // First check if Ollama is available
        if !self.is_available().await {
            // Graceful fallback: return original order if Ollama not running
            return Ok((0..chunks.len()).collect());
        }

        let system_prompt =
            "You are a code search reranker. Given a query and numbered code chunks, \
            return ONLY a JSON array of chunk indices ordered by relevance to the query. \
            Most relevant first. Example response: [2, 0, 5, 1, 3, 4]";

        let mut user_prompt = format!("Query: {}\n\nCode chunks:\n", query);
        for (i, chunk) in chunks.iter().enumerate() {
            user_prompt.push_str(&format!("\n--- Chunk {} ---\n{}\n", i, chunk));
        }
        user_prompt.push_str("\nReturn ONLY the JSON array of indices, nothing else.");

        // Try chat API first (better for instruction following)
        let response = match self.chat(&user_prompt, Some(system_prompt)).await {
            Ok(r) => r,
            Err(_) => {
                // Fallback to generate API
                let full_prompt = format!("{}\n\n{}", system_prompt, user_prompt);
                self.generate(&full_prompt, None).await?
            }
        };

        // Parse the JSON array from response
        self.parse_rerank_response(&response, chunks.len())
    }

    // =========================================================================
    // PRIVATE HELPERS
    // =========================================================================

    /// Parse rerank response, extracting JSON array
    fn parse_rerank_response(&self, text: &str, chunk_count: usize) -> Result<Vec<usize>> {
        let text = text.trim();

        // Try direct parse
        if let Ok(indices) = serde_json::from_str::<Vec<usize>>(text) {
            return Ok(self.validate_indices(indices, chunk_count));
        }

        // Try to find JSON array in the text
        if let Some(start) = text.find('[') {
            if let Some(end) = text.rfind(']') {
                let json_str = &text[start..=end];
                if let Ok(indices) = serde_json::from_str::<Vec<usize>>(json_str) {
                    return Ok(self.validate_indices(indices, chunk_count));
                }
            }
        }

        // Fallback: return original order
        Ok((0..chunk_count).collect())
    }

    /// Validate and filter indices to ensure they're within bounds
    fn validate_indices(&self, indices: Vec<usize>, chunk_count: usize) -> Vec<usize> {
        let mut seen = std::collections::HashSet::new();
        let mut valid: Vec<usize> = indices
            .into_iter()
            .filter(|&i| i < chunk_count && seen.insert(i))
            .collect();

        // Add any missing indices at the end
        for i in 0..chunk_count {
            if !seen.contains(&i) {
                valid.push(i);
            }
        }

        valid
    }

    /// Create connection error with helpful message
    fn connection_error(&self, e: reqwest::Error) -> Error {
        if e.is_connect() {
            Error::DaemonError {
                message: format!(
                    "Cannot connect to Ollama at {}. \
                    Make sure Ollama is running (ollama serve) or check your config.",
                    self.base_url
                ),
            }
        } else if e.is_timeout() {
            Error::DaemonError {
                message: format!(
                    "Ollama request timed out. The model '{}' may be loading or too slow.",
                    self.model
                ),
            }
        } else {
            Error::DaemonError {
                message: format!("Ollama request failed: {}", e),
            }
        }
    }

    /// Create API error from response
    async fn api_error(&self, context: &str, res: reqwest::Response) -> Error {
        let status = res.status();
        let text = res.text().await.unwrap_or_default();

        if status.as_u16() == 404 && text.contains("model") {
            Error::DaemonError {
                message: format!(
                    "Model '{}' not found. Run 'ollama pull {}' to download it.",
                    self.model, self.model
                ),
            }
        } else {
            Error::DaemonError {
                message: format!("{}: HTTP {} - {}", context, status, text),
            }
        }
    }
}

impl Default for OllamaClient {
    fn default() -> Self {
        Self::new()
    }
}

// =============================================================================
// TESTS
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_indices() {
        let client = OllamaClient::new();

        // Valid indices
        let result = client.validate_indices(vec![2, 0, 1], 3);
        assert_eq!(result, vec![2, 0, 1]);

        // Out of bounds filtered
        let result = client.validate_indices(vec![5, 0, 1], 3);
        assert_eq!(result, vec![0, 1, 2]);

        // Duplicates removed
        let result = client.validate_indices(vec![0, 0, 1], 3);
        assert_eq!(result, vec![0, 1, 2]);

        // Missing indices added
        let result = client.validate_indices(vec![2], 3);
        assert_eq!(result, vec![2, 0, 1]);
    }

    #[test]
    fn test_parse_rerank_response() {
        let client = OllamaClient::new();

        // Clean JSON
        let result = client.parse_rerank_response("[2, 0, 1]", 3).unwrap();
        assert_eq!(result, vec![2, 0, 1]);

        // JSON with surrounding text
        let result = client
            .parse_rerank_response("Here's the ranking: [2, 0, 1] based on relevance", 3)
            .unwrap();
        assert_eq!(result, vec![2, 0, 1]);

        // Invalid response falls back to original order
        let result = client
            .parse_rerank_response("I cannot rank these", 3)
            .unwrap();
        assert_eq!(result, vec![0, 1, 2]);
    }
}