// ares/llm/ollama.rs

//! Ollama LLM client implementation
//!
//! This module provides integration with Ollama for local LLM inference.
//! Supports chat, generation, streaming, and tool calling.
//!
//! # Features
//!
//! Enable with the `ollama` feature flag.
//!
//! # Example
//!
//! ```rust,ignore
//! use ares::llm::{LLMClient, Provider};
//!
//! let provider = Provider::Ollama {
//!     base_url: "http://localhost:11434".to_string(),
//!     model: "ministral-3:3b".to_string(),
//! };
//! let client = provider.create_client().await?;
//! let response = client.generate("Hello!").await?;
//! ```

use crate::llm::client::{LLMClient, LLMResponse, ModelParams, TokenUsage};
use crate::llm::coordinator::{ConversationMessage, MessageRole};
use crate::types::{AppError, Result, ToolCall, ToolDefinition};
use async_stream::stream;
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use ollama_rs::{
    generation::chat::{request::ChatMessageRequest, ChatMessage},
    generation::tools::{ToolCall as OllamaToolCall, ToolFunctionInfo, ToolInfo, ToolType},
    models::ModelOptions,
    Ollama,
};
use schemars::Schema;

/// Ollama LLM client implementation.
///
/// Connects to a local or remote Ollama server for inference.
pub struct OllamaClient {
    client: Ollama,
    model: String,
    params: ModelParams,
}

impl OllamaClient {
    /// Creates a new OllamaClient with default parameters.
    pub async fn new(base_url: String, model: String) -> Result<Self> {
        Self::with_params(base_url, model, ModelParams::default()).await
    }

    /// Creates a new OllamaClient with model parameters.
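    ///
    /// The `base_url` may be given with or without a scheme, and any trailing
    /// path is ignored (see the normalization below). A minimal sketch of
    /// direct construction; the model tag shown is illustrative only:
    ///
    /// ```rust,ignore
    /// // "localhost:11434", "http://localhost:11434", and "localhost:11434/api"
    /// // all resolve to the same server address.
    /// let client = OllamaClient::with_params(
    ///     "http://localhost:11434".to_string(),
    ///     "llama3.2:3b".to_string(), // illustrative model tag
    ///     ModelParams::default(),
    /// )
    /// .await?;
    /// ```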
    pub async fn with_params(base_url: String, model: String, params: ModelParams) -> Result<Self> {
        // ollama-rs' `Ollama::new(host, port)` parses `host` using reqwest's IntoUrl.
        // If `host` is something like "localhost" (no scheme), it panics with
        // `RelativeUrlWithoutBase`. To avoid server crashes, normalize user input
        // so we *always* pass an absolute URL like "http://localhost".
        //
        // Accepted input formats include:
        // - http://localhost:11434
        // - https://example.com:11434
        // - localhost:11434
        // - localhost
        // - localhost:11434/api (path ignored)
        let trimmed = base_url.trim();
        if trimmed.is_empty() {
            return Err(AppError::Configuration(
                "OLLAMA_URL is empty; expected something like http://localhost:11434"
                    .to_string(),
            ));
        }

        // Remember the scheme (defaulting to http), then strip it to get host[:port][/path...]
        let scheme = if trimmed.starts_with("https://") {
            "https"
        } else {
            "http"
        };
        let without_scheme = trimmed
            .strip_prefix("http://")
            .or_else(|| trimmed.strip_prefix("https://"))
            .unwrap_or(trimmed);

        // Drop anything after the first '/', '?' or '#'. E.g. "localhost:11434/api" → "localhost:11434"
        let host_port = without_scheme
            .split(&['/', '?', '#'][..])
            .next()
            .unwrap_or("localhost:11434");

        // Split host and port, defaulting to Ollama's standard port 11434
        let (host, port) = if let Some(colon_idx) = host_port.rfind(':') {
            let h = &host_port[..colon_idx];
            let p_str = &host_port[colon_idx + 1..];
            let p = p_str.parse::<u16>().map_err(|_| {
                AppError::Configuration(format!(
                    "Invalid OLLAMA_URL port in '{}'; expected e.g. http://localhost:11434",
                    base_url
                ))
            })?;
            (h.to_string(), p)
        } else {
            (host_port.to_string(), 11434)
        };

        // ollama-rs Ollama::new expects an absolute URL; preserve the original scheme
        let client = Ollama::new(format!("{}://{}", scheme, host), port);

        Ok(Self {
            client,
            model,
            params,
        })
    }

    /// Build ModelOptions from the stored params
    fn build_model_options(&self) -> ModelOptions {
        let mut options = ModelOptions::default();
        if let Some(temp) = self.params.temperature {
            options = options.temperature(temp);
        }
        if let Some(max_tokens) = self.params.max_tokens {
            options = options.num_predict(max_tokens as i32);
        }
        if let Some(top_p) = self.params.top_p {
            options = options.top_p(top_p);
        }
        // Note: ollama-rs uses repeat_penalty instead of separate frequency/presence penalties
        // We use presence_penalty as a fallback for repeat_penalty if set
        if let Some(pres_penalty) = self.params.presence_penalty {
            options = options.repeat_penalty(pres_penalty);
        }
        options
    }

    /// Convert our ToolDefinition to ollama-rs ToolInfo
    fn convert_tool_definition(tool: &ToolDefinition) -> ToolInfo {
        // Convert serde_json::Value to schemars Schema
        // ollama-rs expects a schemars Schema for parameters
        let schema: Schema =
            serde_json::from_value(tool.parameters.clone()).unwrap_or_else(|_| Schema::default());

        ToolInfo {
            tool_type: ToolType::Function,
            function: ToolFunctionInfo {
                name: tool.name.clone(),
                description: tool.description.clone(),
                parameters: schema,
            },
        }
    }

    /// Convert ollama-rs ToolCall to our ToolCall type
    fn convert_tool_call(call: &OllamaToolCall) -> ToolCall {
        ToolCall {
            id: uuid::Uuid::new_v4().to_string(),
            name: call.function.name.clone(),
            arguments: call.function.arguments.clone(),
        }
    }

    /// Convert a ConversationMessage to Ollama's ChatMessage
    fn convert_conversation_message(&self, msg: &ConversationMessage) -> ChatMessage {
        match msg.role {
            MessageRole::System => ChatMessage::system(msg.content.clone()),
            MessageRole::User => ChatMessage::user(msg.content.clone()),
            MessageRole::Assistant => {
                // Assistant messages - content only (tool calls are handled by context)
                ChatMessage::assistant(msg.content.clone())
            }
            MessageRole::Tool => {
                // For tool result messages, use Ollama's native tool message type
                ChatMessage::tool(msg.content.clone())
            }
        }
    }
}

#[async_trait]
impl LLMClient for OllamaClient {
    async fn generate(&self, prompt: &str) -> Result<String> {
        let messages = vec![ChatMessage::user(prompt.to_string())];

        let request = ChatMessageRequest::new(self.model.clone(), messages)
            .options(self.build_model_options());

        let response = self
            .client
            .send_chat_messages(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;

        // response.message is a ChatMessage, not Option<ChatMessage>
        Ok(response.message.content)
    }

    async fn generate_with_system(&self, system: &str, prompt: &str) -> Result<String> {
        let messages = vec![
            ChatMessage::system(system.to_string()),
            ChatMessage::user(prompt.to_string()),
        ];

        let request = ChatMessageRequest::new(self.model.clone(), messages)
            .options(self.build_model_options());

        let response = self
            .client
            .send_chat_messages(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;

        Ok(response.message.content)
    }

    async fn generate_with_history(&self, messages: &[(String, String)]) -> Result<String> {
        let chat_messages: Vec<ChatMessage> = messages
            .iter()
            .map(|(role, content)| match role.as_str() {
                "system" => ChatMessage::system(content.clone()),
                "user" => ChatMessage::user(content.clone()),
                "assistant" => ChatMessage::assistant(content.clone()),
                _ => ChatMessage::user(content.clone()),
            })
            .collect();

        let request = ChatMessageRequest::new(self.model.clone(), chat_messages)
            .options(self.build_model_options());

        let response = self
            .client
            .send_chat_messages(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;

        Ok(response.message.content)
    }

    async fn generate_with_tools(
        &self,
        prompt: &str,
        tools: &[ToolDefinition],
    ) -> Result<LLMResponse> {
        // Convert our tool definitions to ollama-rs format
        let ollama_tools: Vec<ToolInfo> = tools.iter().map(Self::convert_tool_definition).collect();

        let messages = vec![ChatMessage::user(prompt.to_string())];

        // Create request with tools and model options
        let request = ChatMessageRequest::new(self.model.clone(), messages)
            .tools(ollama_tools)
            .options(self.build_model_options());

        let response = self
            .client
            .send_chat_messages(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;

        // Extract content and tool calls from the message
        let content = response.message.content.clone();
        let tool_calls: Vec<ToolCall> = response
            .message
            .tool_calls
            .iter()
            .map(Self::convert_tool_call)
            .collect();

        // Determine finish reason based on whether tools were called
        let finish_reason = if tool_calls.is_empty() {
            "stop"
        } else {
            "tool_calls"
        };

        // Extract token usage from final_data if available
        let usage = response
            .final_data
            .as_ref()
            .map(|data| TokenUsage::new(data.prompt_eval_count as u32, data.eval_count as u32));

        Ok(LLMResponse {
            content,
            tool_calls,
            finish_reason: finish_reason.to_string(),
            usage,
        })
    }

    async fn generate_with_tools_and_history(
        &self,
        messages: &[ConversationMessage],
        tools: &[ToolDefinition],
    ) -> Result<LLMResponse> {
        // Convert our tool definitions to ollama-rs format
        let ollama_tools: Vec<ToolInfo> = tools.iter().map(Self::convert_tool_definition).collect();

        // Convert ConversationMessage to Ollama ChatMessage
        let chat_messages: Vec<ChatMessage> = messages
            .iter()
            .map(|msg| self.convert_conversation_message(msg))
            .collect();

        // Create request with tools and model options
        let mut request = ChatMessageRequest::new(self.model.clone(), chat_messages)
            .options(self.build_model_options());

        if !ollama_tools.is_empty() {
            request = request.tools(ollama_tools);
        }

        let response = self
            .client
            .send_chat_messages(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;

        // Extract content and tool calls from the message
        let content = response.message.content.clone();
        let tool_calls: Vec<ToolCall> = response
            .message
            .tool_calls
            .iter()
            .map(Self::convert_tool_call)
            .collect();

        // Determine finish reason based on whether tools were called
        let finish_reason = if tool_calls.is_empty() {
            "stop"
        } else {
            "tool_calls"
        };

        // Extract token usage from final_data if available
        let usage = response
            .final_data
            .as_ref()
            .map(|data| TokenUsage::new(data.prompt_eval_count as u32, data.eval_count as u32));

        Ok(LLMResponse {
            content,
            tool_calls,
            finish_reason: finish_reason.to_string(),
            usage,
        })
    }

    async fn stream(
        &self,
        prompt: &str,
    ) -> Result<Box<dyn Stream<Item = Result<String>> + Send + Unpin>> {
        let messages = vec![ChatMessage::user(prompt.to_string())];
        let request = ChatMessageRequest::new(self.model.clone(), messages)
            .options(self.build_model_options());

        let mut stream_response = self
            .client
            .send_chat_messages_stream(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama stream error: {}", e)))?;

        // Create an async stream that yields content chunks
        let output_stream = stream! {
            while let Some(chunk_result) = stream_response.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        // Each chunk has a message with content
                        let content = chunk.message.content;
                        if !content.is_empty() {
                            yield Ok(content);
                        }
                    }
                    Err(_) => {
                        yield Err(AppError::LLM("Stream chunk error".to_string()));
                        break;
                    }
                }
            }
        };

        Ok(Box::new(Box::pin(output_stream)))
    }

    async fn stream_with_system(
        &self,
        system: &str,
        prompt: &str,
    ) -> Result<Box<dyn Stream<Item = Result<String>> + Send + Unpin>> {
        let messages = vec![
            ChatMessage::system(system.to_string()),
            ChatMessage::user(prompt.to_string()),
        ];
        let request = ChatMessageRequest::new(self.model.clone(), messages)
            .options(self.build_model_options());

        let mut stream_response = self
            .client
            .send_chat_messages_stream(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama stream error: {}", e)))?;

        let output_stream = stream! {
            while let Some(chunk_result) = stream_response.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        let content = chunk.message.content;
                        if !content.is_empty() {
                            yield Ok(content);
                        }
                    }
                    Err(_) => {
                        yield Err(AppError::LLM("Stream chunk error".to_string()));
                        break;
                    }
                }
            }
        };

        Ok(Box::new(Box::pin(output_stream)))
    }

    async fn stream_with_history(
        &self,
        messages: &[(String, String)],
    ) -> Result<Box<dyn Stream<Item = Result<String>> + Send + Unpin>> {
        let chat_messages: Vec<ChatMessage> = messages
            .iter()
            .map(|(role, content)| match role.as_str() {
                "system" => ChatMessage::system(content.clone()),
                "user" => ChatMessage::user(content.clone()),
                "assistant" => ChatMessage::assistant(content.clone()),
                _ => ChatMessage::user(content.clone()),
            })
            .collect();

        let request = ChatMessageRequest::new(self.model.clone(), chat_messages)
            .options(self.build_model_options());

        let mut stream_response = self
            .client
            .send_chat_messages_stream(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama stream error: {}", e)))?;

        let output_stream = stream! {
            while let Some(chunk_result) = stream_response.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        let content = chunk.message.content;
                        if !content.is_empty() {
                            yield Ok(content);
                        }
                    }
                    Err(_) => {
                        yield Err(AppError::LLM("Stream chunk error".to_string()));
                        break;
                    }
                }
            }
        };

        Ok(Box::new(Box::pin(output_stream)))
    }

    fn model_name(&self) -> &str {
        &self.model
    }
}

/// Extended Ollama client methods for convenience
impl OllamaClient {
    /// Check if the Ollama server is available
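    ///
    /// Probes the server by listing local models; returns `Ok(false)` rather
    /// than an error when the server is unreachable. A sketch of typical use,
    /// assuming an existing `OllamaClient` named `client`:
    ///
    /// ```rust,ignore
    /// if client.health_check().await? {
    ///     let models = client.list_models().await?;
    ///     println!("Ollama is up with {} local model(s)", models.len());
    /// }
    /// ```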
    pub async fn health_check(&self) -> Result<bool> {
        // Try to list models - if this works, the server is up
        match self.client.list_local_models().await {
            Ok(_) => Ok(true),
            Err(_) => Ok(false),
        }
    }

    /// List available models on the Ollama server
    pub async fn list_models(&self) -> Result<Vec<String>> {
        let models = self
            .client
            .list_local_models()
            .await
            .map_err(|e| AppError::LLM(format!("Failed to list models: {}", e)))?;

        // list_local_models returns Vec<LocalModel> directly
        Ok(models.into_iter().map(|m| m.name).collect())
    }

    /// Pull a model from the Ollama registry
    pub async fn pull_model(&self, model_name: &str) -> Result<()> {
        self.client
            .pull_model(model_name.to_string(), false)
            .await
            .map_err(|e| AppError::LLM(format!("Failed to pull model '{}': {}", model_name, e)))?;
        Ok(())
    }

    /// Get information about a specific model
    pub async fn model_info(&self, model_name: &str) -> Result<serde_json::Value> {
        let info = self
            .client
            .show_model_info(model_name.to_string())
            .await
            .map_err(|e| {
                AppError::LLM(format!(
                    "Failed to get model info for '{}': {}",
                    model_name, e
                ))
            })?;

        // Convert to JSON value
        Ok(serde_json::json!({
            "modelfile": info.modelfile,
            "parameters": info.parameters,
            "template": info.template,
        }))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_url_parsing_full() {
        let base_url = "http://localhost:11434";
        let url_parts: Vec<&str> = base_url.split("://").collect();
        assert_eq!(url_parts.len(), 2);
        assert_eq!(url_parts[0], "http");
        assert_eq!(url_parts[1], "localhost:11434");

        let host_port: Vec<&str> = url_parts[1].split(':').collect();
        assert_eq!(host_port[0], "localhost");
        assert_eq!(host_port[1], "11434");
    }

    #[test]
    fn test_url_parsing_no_port() {
        let base_url = "http://localhost";
        let url_parts: Vec<&str> = base_url.split("://").collect();
        let host_port: Vec<&str> = url_parts[1].split(':').collect();

        let host = host_port[0].to_string();
        let port = if host_port.len() == 2 {
            host_port[1].parse().unwrap_or(11434)
        } else {
            11434
        };

        assert_eq!(host, "localhost");
        assert_eq!(port, 11434);
    }

    #[test]
    fn test_url_parsing_custom_port() {
        let base_url = "http://192.168.1.100:8080";
        let url_parts: Vec<&str> = base_url.split("://").collect();
        let host_port: Vec<&str> = url_parts[1].split(':').collect();

        let host = host_port[0].to_string();
        let port: u16 = host_port[1].parse().unwrap_or(11434);

        assert_eq!(host, "192.168.1.100");
        assert_eq!(port, 8080);
    }

    #[test]
    fn test_tool_definition_conversion() {
        let tool = ToolDefinition {
            name: "calculator".to_string(),
            description: "Performs basic math".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "operation": {"type": "string"},
                    "a": {"type": "number"},
                    "b": {"type": "number"}
                },
                "required": ["operation", "a", "b"]
            }),
        };

        let ollama_tool = OllamaClient::convert_tool_definition(&tool);
        assert_eq!(ollama_tool.function.name, "calculator");
        assert_eq!(ollama_tool.function.description, "Performs basic math");
    }

    #[test]
    fn test_tool_call_conversion() {
        let ollama_call = OllamaToolCall {
            function: ollama_rs::generation::tools::ToolCallFunction {
                name: "test_tool".to_string(),
                arguments: serde_json::json!({"arg1": "value1"}),
            },
        };

        let tool_call = OllamaClient::convert_tool_call(&ollama_call);
        assert_eq!(tool_call.name, "test_tool");
        assert_eq!(tool_call.arguments["arg1"], "value1");
        // ID should be a valid UUID
        assert!(!tool_call.id.is_empty());
    }
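
    // Exercises the URL normalization in `with_params` end to end; constructing
    // the client does not contact a server, so this runs offline. A sketch that
    // assumes the `futures` dependency has its default `executor` feature
    // enabled (for `block_on`); swap in the project's async test harness if not.
    #[test]
    fn test_with_params_url_normalization() {
        use futures::executor::block_on;

        // Scheme-less input is accepted and gets a scheme prepended.
        assert!(block_on(OllamaClient::with_params(
            "localhost:11434".to_string(),
            "test-model".to_string(),
            ModelParams::default(),
        ))
        .is_ok());

        // A trailing path such as "/api" is ignored.
        assert!(block_on(OllamaClient::with_params(
            "http://localhost:11434/api".to_string(),
            "test-model".to_string(),
            ModelParams::default(),
        ))
        .is_ok());

        // An empty URL is rejected as a configuration error.
        assert!(block_on(OllamaClient::with_params(
            String::new(),
            "test-model".to_string(),
            ModelParams::default(),
        ))
        .is_err());

        // A non-numeric port is rejected.
        assert!(block_on(OllamaClient::with_params(
            "http://localhost:abc".to_string(),
            "test-model".to_string(),
            ModelParams::default(),
        ))
        .is_err());
    }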
}