Skip to main content

ares/llm/
ollama.rs

1//! Ollama LLM client implementation
2//!
3//! This module provides integration with Ollama for local LLM inference.
4//! Supports chat, generation, streaming, and tool calling.
5//!
6//! # Features
7//!
8//! Enable with the `ollama` feature flag.
9//!
10//! # Example
11//!
12//! ```rust,ignore
13//! use ares::llm::{LLMClient, Provider};
14//!
15//! let provider = Provider::Ollama {
16//!     base_url: "http://localhost:11434".to_string(),
17//!     model: "ministral-3:3b".to_string(),
18//! };
19//! let client = provider.create_client().await?;
20//! let response = client.generate("Hello!").await?;
21//! ```
22
23use crate::llm::client::{LLMClient, LLMResponse, ModelParams, TokenUsage};
24use crate::llm::coordinator::{ConversationMessage, MessageRole};
25use crate::types::{AppError, Result, ToolCall, ToolDefinition};
26use async_stream::stream;
27use async_trait::async_trait;
28use futures::{Stream, StreamExt};
29use ollama_rs::{
30    generation::chat::{request::ChatMessageRequest, ChatMessage},
31    generation::tools::{ToolCall as OllamaToolCall, ToolFunctionInfo, ToolInfo, ToolType},
32    models::ModelOptions,
33    Ollama,
34};
35use schemars::Schema;
36
37/// Ollama LLM client implementation.
38///
39/// Connects to a local or remote Ollama server for inference.
40pub struct OllamaClient {
41    client: Ollama,
42    model: String,
43    params: ModelParams,
44}
45
46impl OllamaClient {
47    /// Creates a new OllamaClient with default parameters.
48    pub async fn new(base_url: String, model: String) -> Result<Self> {
49        Self::with_params(base_url, model, ModelParams::default()).await
50    }
51
52    /// Creates a new OllamaClient with model parameters.
53    pub async fn with_params(base_url: String, model: String, params: ModelParams) -> Result<Self> {
54        // ollama-rs' `Ollama::new(host, port)` parses `host` using reqwest's IntoUrl.
55        // If `host` is something like "localhost" (no scheme), it panics with
56        // `RelativeUrlWithoutBase`. To avoid server crashes, normalize user input
57        // so we *always* pass an absolute URL like "http://localhost".
58        //
59        // Accept incoming configs like:
60        // - http://localhost:11434
61        // - https://example.com:11434
62        // - localhost:11434
63        // - localhost
64        // - localhost:11434/api (path ignored)
65        let trimmed = base_url.trim();
66        if trimmed.is_empty() {
67            return Err(AppError::Configuration(
68                "OLLAMA_URL is empty/invalid; expected something like http://localhost:11434"
69                    .to_string(),
70            ));
71        }
72
73        // Strip scheme if present to get host[:port][/path...]
74        let without_scheme = trimmed
75            .strip_prefix("http://")
76            .or_else(|| trimmed.strip_prefix("https://"))
77            .unwrap_or(trimmed);
78
79        // Drop any path/query fragments after the first '/'. E.g. "localhost:11434/api" → "localhost:11434"
80        let host_port = without_scheme
81            .split(&['/', '?', '#'][..])
82            .next()
83            .unwrap_or("localhost:11434");
84
85        // Split host and port
86        let (host, port) = if let Some(colon_idx) = host_port.rfind(':') {
87            let h = &host_port[..colon_idx];
88            let p_str = &host_port[colon_idx + 1..];
89            let p = p_str.parse::<u16>().map_err(|_| {
90                AppError::Configuration(format!(
91                    "Invalid OLLAMA_URL port in '{}'; expected e.g. http://localhost:11434",
92                    base_url
93                ))
94            })?;
95            (h.to_string(), p)
96        } else {
97            (host_port.to_string(), 11434)
98        };
99
100        // ollama-rs Ollama::new expects an absolute URL; pass scheme+host
101        let client = Ollama::new(format!("http://{}", host), port);
102
103        Ok(Self {
104            client,
105            model,
106            params,
107        })
108    }
109
110    /// Build ModelOptions from the stored params
111    fn build_model_options(&self) -> ModelOptions {
112        let mut options = ModelOptions::default();
113        if let Some(temp) = self.params.temperature {
114            options = options.temperature(temp);
115        }
116        if let Some(max_tokens) = self.params.max_tokens {
117            options = options.num_predict(max_tokens as i32);
118        }
119        if let Some(top_p) = self.params.top_p {
120            options = options.top_p(top_p);
121        }
122        // Note: ollama-rs uses repeat_penalty instead of separate frequency/presence penalties
123        // We use presence_penalty as a fallback for repeat_penalty if set
124        if let Some(pres_penalty) = self.params.presence_penalty {
125            options = options.repeat_penalty(pres_penalty);
126        }
127        options
128    }
129
130    /// Convert our ToolDefinition to ollama-rs ToolInfo
131    fn convert_tool_definition(tool: &ToolDefinition) -> ToolInfo {
132        // Convert serde_json::Value to schemars Schema
133        // ollama-rs expects a schemars Schema for parameters
134        let schema: Schema =
135            serde_json::from_value(tool.parameters.clone()).unwrap_or_else(|_| Schema::default());
136
137        ToolInfo {
138            tool_type: ToolType::Function,
139            function: ToolFunctionInfo {
140                name: tool.name.clone(),
141                description: tool.description.clone(),
142                parameters: schema,
143            },
144        }
145    }
146
147    /// Convert ollama-rs ToolCall to our ToolCall type
148    fn convert_tool_call(call: &OllamaToolCall) -> ToolCall {
149        ToolCall {
150            id: uuid::Uuid::new_v4().to_string(),
151            name: call.function.name.clone(),
152            arguments: call.function.arguments.clone(),
153        }
154    }
155
156    /// Convert a ConversationMessage to Ollama's ChatMessage
157    fn convert_conversation_message(&self, msg: &ConversationMessage) -> ChatMessage {
158        match msg.role {
159            MessageRole::System => ChatMessage::system(msg.content.clone()),
160            MessageRole::User => ChatMessage::user(msg.content.clone()),
161            MessageRole::Assistant => {
162                // Assistant messages - content only (tool calls are handled by context)
163                ChatMessage::assistant(msg.content.clone())
164            }
165            MessageRole::Tool => {
166                // For tool result messages, use Ollama's native tool message type
167                ChatMessage::tool(msg.content.clone())
168            }
169        }
170    }
171}
172
173#[async_trait]
174impl LLMClient for OllamaClient {
175    async fn generate(&self, prompt: &str) -> Result<String> {
176        let messages = vec![ChatMessage::user(prompt.to_string())];
177
178        let request = ChatMessageRequest::new(self.model.clone(), messages)
179            .options(self.build_model_options());
180
181        let response = self
182            .client
183            .send_chat_messages(request)
184            .await
185            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;
186
187        // response.message is a ChatMessage, not Option<ChatMessage>
188        Ok(response.message.content)
189    }
190
191    async fn generate_with_system(&self, system: &str, prompt: &str) -> Result<String> {
192        let messages = vec![
193            ChatMessage::system(system.to_string()),
194            ChatMessage::user(prompt.to_string()),
195        ];
196
197        let request = ChatMessageRequest::new(self.model.clone(), messages)
198            .options(self.build_model_options());
199
200        let response = self
201            .client
202            .send_chat_messages(request)
203            .await
204            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;
205
206        Ok(response.message.content)
207    }
208
209    async fn generate_with_history(&self, messages: &[(String, String)]) -> Result<LLMResponse> {
210        let chat_messages: Vec<ChatMessage> = messages
211            .iter()
212            .map(|(role, content)| match role.as_str() {
213                "system" => ChatMessage::system(content.clone()),
214                "user" => ChatMessage::user(content.clone()),
215                "assistant" => ChatMessage::assistant(content.clone()),
216                _ => ChatMessage::user(content.clone()),
217            })
218            .collect();
219
220        let request = ChatMessageRequest::new(self.model.clone(), chat_messages)
221            .options(self.build_model_options());
222
223        let response = self
224            .client
225            .send_chat_messages(request)
226            .await
227            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;
228
229        // Extract token usage from final_data if available
230        let usage = response
231            .final_data
232            .as_ref()
233            .map(|data| TokenUsage::new(data.prompt_eval_count as u32, data.eval_count as u32));
234
235        Ok(LLMResponse {
236            content: response.message.content,
237            tool_calls: vec![],
238            finish_reason: "stop".to_string(),
239            usage,
240        })
241    }
242
243    async fn generate_with_tools(
244        &self,
245        prompt: &str,
246        tools: &[ToolDefinition],
247    ) -> Result<LLMResponse> {
248        // Convert our tool definitions to ollama-rs format
249        let ollama_tools: Vec<ToolInfo> = tools.iter().map(Self::convert_tool_definition).collect();
250
251        let messages = vec![ChatMessage::user(prompt.to_string())];
252
253        // Create request with tools and model options
254        let request = ChatMessageRequest::new(self.model.clone(), messages)
255            .tools(ollama_tools)
256            .options(self.build_model_options());
257
258        let response = self
259            .client
260            .send_chat_messages(request)
261            .await
262            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;
263
264        // Extract content and tool calls from the message
265        let content = response.message.content.clone();
266        let tool_calls: Vec<ToolCall> = response
267            .message
268            .tool_calls
269            .iter()
270            .map(Self::convert_tool_call)
271            .collect();
272
273        // Determine finish reason based on whether tools were called
274        let finish_reason = if tool_calls.is_empty() {
275            "stop"
276        } else {
277            "tool_calls"
278        };
279
280        // Extract token usage from final_data if available
281        let usage = response
282            .final_data
283            .as_ref()
284            .map(|data| TokenUsage::new(data.prompt_eval_count as u32, data.eval_count as u32));
285
286        Ok(LLMResponse {
287            content,
288            tool_calls,
289            finish_reason: finish_reason.to_string(),
290            usage,
291        })
292    }
293
294    async fn generate_with_tools_and_history(
295        &self,
296        messages: &[ConversationMessage],
297        tools: &[ToolDefinition],
298    ) -> Result<LLMResponse> {
299        // Convert our tool definitions to ollama-rs format
300        let ollama_tools: Vec<ToolInfo> = tools.iter().map(Self::convert_tool_definition).collect();
301
302        // Convert ConversationMessage to Ollama ChatMessage
303        let chat_messages: Vec<ChatMessage> = messages
304            .iter()
305            .map(|msg| self.convert_conversation_message(msg))
306            .collect();
307
308        // Create request with tools and model options
309        let mut request = ChatMessageRequest::new(self.model.clone(), chat_messages)
310            .options(self.build_model_options());
311
312        if !ollama_tools.is_empty() {
313            request = request.tools(ollama_tools);
314        }
315
316        let response = self
317            .client
318            .send_chat_messages(request)
319            .await
320            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;
321
322        // Extract content and tool calls from the message
323        let content = response.message.content.clone();
324        let tool_calls: Vec<ToolCall> = response
325            .message
326            .tool_calls
327            .iter()
328            .map(Self::convert_tool_call)
329            .collect();
330
331        // Determine finish reason based on whether tools were called
332        let finish_reason = if tool_calls.is_empty() {
333            "stop"
334        } else {
335            "tool_calls"
336        };
337
338        // Extract token usage from final_data if available
339        let usage = response
340            .final_data
341            .as_ref()
342            .map(|data| TokenUsage::new(data.prompt_eval_count as u32, data.eval_count as u32));
343
344        Ok(LLMResponse {
345            content,
346            tool_calls,
347            finish_reason: finish_reason.to_string(),
348            usage,
349        })
350    }
351
352    async fn stream(
353        &self,
354        prompt: &str,
355    ) -> Result<Box<dyn Stream<Item = Result<String>> + Send + Unpin>> {
356        let messages = vec![ChatMessage::user(prompt.to_string())];
357        let request = ChatMessageRequest::new(self.model.clone(), messages)
358            .options(self.build_model_options());
359
360        let mut stream_response = self
361            .client
362            .send_chat_messages_stream(request)
363            .await
364            .map_err(|e| AppError::LLM(format!("Ollama stream error: {}", e)))?;
365
366        // Create an async stream that yields content chunks
367        let output_stream = stream! {
368            while let Some(chunk_result) = stream_response.next().await {
369                match chunk_result {
370                    Ok(chunk) => {
371                        // Each chunk has a message with content
372                        let content = chunk.message.content;
373                        if !content.is_empty() {
374                            yield Ok(content);
375                        }
376                    }
377                    Err(_) => {
378                        yield Err(AppError::LLM("Stream chunk error".to_string()));
379                        break;
380                    }
381                }
382            }
383        };
384
385        Ok(Box::new(Box::pin(output_stream)))
386    }
387
388    async fn stream_with_system(
389        &self,
390        system: &str,
391        prompt: &str,
392    ) -> Result<Box<dyn Stream<Item = Result<String>> + Send + Unpin>> {
393        let messages = vec![
394            ChatMessage::system(system.to_string()),
395            ChatMessage::user(prompt.to_string()),
396        ];
397        let request = ChatMessageRequest::new(self.model.clone(), messages)
398            .options(self.build_model_options());
399
400        let mut stream_response = self
401            .client
402            .send_chat_messages_stream(request)
403            .await
404            .map_err(|e| AppError::LLM(format!("Ollama stream error: {}", e)))?;
405
406        let output_stream = stream! {
407            while let Some(chunk_result) = stream_response.next().await {
408                match chunk_result {
409                    Ok(chunk) => {
410                        let content = chunk.message.content;
411                        if !content.is_empty() {
412                            yield Ok(content);
413                        }
414                    }
415                    Err(_) => {
416                        yield Err(AppError::LLM("Stream chunk error".to_string()));
417                        break;
418                    }
419                }
420            }
421        };
422
423        Ok(Box::new(Box::pin(output_stream)))
424    }
425
426    async fn stream_with_history(
427        &self,
428        messages: &[(String, String)],
429    ) -> Result<Box<dyn Stream<Item = Result<String>> + Send + Unpin>> {
430        let chat_messages: Vec<ChatMessage> = messages
431            .iter()
432            .map(|(role, content)| match role.as_str() {
433                "system" => ChatMessage::system(content.clone()),
434                "user" => ChatMessage::user(content.clone()),
435                "assistant" => ChatMessage::assistant(content.clone()),
436                _ => ChatMessage::user(content.clone()),
437            })
438            .collect();
439
440        let request = ChatMessageRequest::new(self.model.clone(), chat_messages)
441            .options(self.build_model_options());
442
443        let mut stream_response = self
444            .client
445            .send_chat_messages_stream(request)
446            .await
447            .map_err(|e| AppError::LLM(format!("Ollama stream error: {}", e)))?;
448
449        let output_stream = stream! {
450            while let Some(chunk_result) = stream_response.next().await {
451                match chunk_result {
452                    Ok(chunk) => {
453                        let content = chunk.message.content;
454                        if !content.is_empty() {
455                            yield Ok(content);
456                        }
457                    }
458                    Err(_) => {
459                        yield Err(AppError::LLM("Stream chunk error".to_string()));
460                        break;
461                    }
462                }
463            }
464        };
465
466        Ok(Box::new(Box::pin(output_stream)))
467    }
468
469    fn model_name(&self) -> &str {
470        &self.model
471    }
472}
473
474/// Extended Ollama client methods for convenience
475impl OllamaClient {
476    /// Check if the Ollama server is available
477    pub async fn health_check(&self) -> Result<bool> {
478        // Try to list models - if this works, the server is up
479        match self.client.list_local_models().await {
480            Ok(_) => Ok(true),
481            Err(_) => Ok(false),
482        }
483    }
484
485    /// List available models on the Ollama server
486    pub async fn list_models(&self) -> Result<Vec<String>> {
487        let models = self
488            .client
489            .list_local_models()
490            .await
491            .map_err(|e| AppError::LLM(format!("Failed to list models: {}", e)))?;
492
493        // list_local_models returns Vec<LocalModel> directly
494        Ok(models.into_iter().map(|m| m.name).collect())
495    }
496
497    /// Pull a model from the Ollama registry
498    pub async fn pull_model(&self, model_name: &str) -> Result<()> {
499        self.client
500            .pull_model(model_name.to_string(), false)
501            .await
502            .map_err(|e| AppError::LLM(format!("Failed to pull model '{}': {}", model_name, e)))?;
503        Ok(())
504    }
505
506    /// Get information about a specific model
507    pub async fn model_info(&self, model_name: &str) -> Result<serde_json::Value> {
508        let info = self
509            .client
510            .show_model_info(model_name.to_string())
511            .await
512            .map_err(|e| {
513                AppError::LLM(format!(
514                    "Failed to get model info for '{}': {}",
515                    model_name, e
516                ))
517            })?;
518
519        // Convert to JSON value
520        Ok(serde_json::json!({
521            "modelfile": info.modelfile,
522            "parameters": info.parameters,
523            "template": info.template,
524        }))
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives `OllamaClient::new` to completion on the current thread.
    ///
    /// Construction awaits nothing network-bound, so a minimal executor suffices.
    /// Requires the `futures` crate's `executor` feature (enabled by default).
    fn build(base_url: &str) -> Result<OllamaClient> {
        futures::executor::block_on(OllamaClient::new(
            base_url.to_string(),
            "test-model".to_string(),
        ))
    }

    #[test]
    fn test_url_parsing_full() {
        let client = build("http://localhost:11434");
        assert!(matches!(client, Ok(ref c) if c.model == "test-model"));
    }

    #[test]
    fn test_url_parsing_no_port() {
        assert!(build("http://localhost").is_ok());
    }

    #[test]
    fn test_url_parsing_custom_port() {
        assert!(build("http://192.168.1.100:8080").is_ok());
    }

    #[test]
    fn test_url_parsing_without_scheme() {
        assert!(build("localhost:11434").is_ok());
        assert!(build("localhost").is_ok());
    }

    #[test]
    fn test_url_parsing_ignores_path_and_whitespace() {
        assert!(build("localhost:11434/api").is_ok());
        assert!(build("  http://localhost:11434  ").is_ok());
    }

    #[test]
    fn test_url_parsing_rejects_empty() {
        assert!(matches!(build(""), Err(AppError::Configuration(_))));
        assert!(matches!(build("   "), Err(AppError::Configuration(_))));
    }

    #[test]
    fn test_url_parsing_rejects_bad_port() {
        assert!(matches!(
            build("http://localhost:abc"),
            Err(AppError::Configuration(_))
        ));
        // 70000 does not fit in a u16.
        assert!(matches!(
            build("localhost:70000"),
            Err(AppError::Configuration(_))
        ));
    }

    #[test]
    fn test_tool_definition_conversion() {
        let tool = ToolDefinition {
            name: "calculator".to_string(),
            description: "Performs basic math".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "operation": {"type": "string"},
                    "a": {"type": "number"},
                    "b": {"type": "number"}
                },
                "required": ["operation", "a", "b"]
            }),
        };

        let ollama_tool = OllamaClient::convert_tool_definition(&tool);
        assert_eq!(ollama_tool.function.name, "calculator");
        assert_eq!(ollama_tool.function.description, "Performs basic math");
        // A valid schema must be carried over, not replaced by the default fallback.
        assert_eq!(
            serde_json::to_value(&ollama_tool.function.parameters).unwrap(),
            tool.parameters
        );
    }

    #[test]
    fn test_tool_call_conversion() {
        let ollama_call = OllamaToolCall {
            function: ollama_rs::generation::tools::ToolCallFunction {
                name: "test_tool".to_string(),
                arguments: serde_json::json!({"arg1": "value1"}),
            },
        };

        let tool_call = OllamaClient::convert_tool_call(&ollama_call);
        assert_eq!(tool_call.name, "test_tool");
        assert_eq!(tool_call.arguments["arg1"], "value1");
        // The generated id must be a valid UUID.
        assert!(uuid::Uuid::parse_str(&tool_call.id).is_ok());
    }
}