use crate::llm::client::{LLMClient, LLMResponse, ModelParams, TokenUsage};
use crate::llm::coordinator::{ConversationMessage, MessageRole};
use crate::types::{AppError, Result, ToolCall, ToolDefinition};
use async_stream::stream;
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use ollama_rs::{
    generation::chat::{request::ChatMessageRequest, ChatMessage},
    generation::tools::{ToolCall as OllamaToolCall, ToolFunctionInfo, ToolInfo, ToolType},
    models::ModelOptions,
    Ollama,
};
use schemars::Schema;

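/// Chat client backed by an Ollama server, built on top of `ollama-rs`.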
pub struct OllamaClient {
    client: Ollama,
    model: String,
    params: ModelParams,
}

impl OllamaClient {
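    /// Create a client with default generation parameters.
    ///
    /// A minimal usage sketch (marked `ignore` since it needs a running
    /// Ollama server, and the model name here is just a placeholder):
    ///
    /// ```ignore
    /// let client = OllamaClient::new(
    ///     "http://localhost:11434".to_string(),
    ///     "llama3".to_string(),
    /// ).await?;
    /// let reply = client.generate("Say hello").await?;
    /// ```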
    pub async fn new(base_url: String, model: String) -> Result<Self> {
        Self::with_params(base_url, model, ModelParams::default()).await
    }

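    /// Create a client with explicit `ModelParams`.
    ///
    /// `base_url` may be `http://host:port`, `https://host:port`, or a bare
    /// `host[:port]`; any path, query, or fragment is ignored, and the port
    /// defaults to 11434.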
    pub async fn with_params(base_url: String, model: String, params: ModelParams) -> Result<Self> {
        let trimmed = base_url.trim();
        if trimmed.is_empty() {
            return Err(AppError::Configuration(
                "OLLAMA_URL is empty; expected something like http://localhost:11434"
                    .to_string(),
            ));
        }

        // Remember the scheme before stripping it, so an https endpoint is
        // not silently downgraded to http when the client is built below.
        let is_https = trimmed.starts_with("https://");
        let without_scheme = trimmed
            .strip_prefix("http://")
            .or_else(|| trimmed.strip_prefix("https://"))
            .unwrap_or(trimmed);

        // Keep only the authority: drop any path, query, or fragment.
        let host_port = without_scheme
            .split(&['/', '?', '#'][..])
            .next()
            .unwrap_or("localhost:11434");

        let (host, port) = if let Some(colon_idx) = host_port.rfind(':') {
            let h = &host_port[..colon_idx];
            let p_str = &host_port[colon_idx + 1..];
            let p = p_str.parse::<u16>().map_err(|_| {
                AppError::Configuration(format!(
                    "Invalid OLLAMA_URL port in '{}'; expected e.g. http://localhost:11434",
                    base_url
                ))
            })?;
            (h.to_string(), p)
        } else {
            (host_port.to_string(), 11434)
        };

        let scheme = if is_https { "https" } else { "http" };
        let client = Ollama::new(format!("{}://{}", scheme, host), port);

        Ok(Self {
            client,
            model,
            params,
        })
    }

    fn build_model_options(&self) -> ModelOptions {
        let mut options = ModelOptions::default();
        if let Some(temp) = self.params.temperature {
            options = options.temperature(temp);
        }
        if let Some(max_tokens) = self.params.max_tokens {
            options = options.num_predict(max_tokens as i32);
        }
        if let Some(top_p) = self.params.top_p {
            options = options.top_p(top_p);
        }
        if let Some(pres_penalty) = self.params.presence_penalty {
            // `repeat_penalty` stands in for presence penalty here; it is the
            // closest knob exposed through `ModelOptions`.
            options = options.repeat_penalty(pres_penalty);
        }
        options
    }

    fn convert_tool_definition(tool: &ToolDefinition) -> ToolInfo {
        // Tool parameters arrive as raw JSON Schema; if they fail to parse,
        // fall back to an empty schema rather than erroring out.
        let schema: Schema =
            serde_json::from_value(tool.parameters.clone()).unwrap_or_else(|_| Schema::default());

        ToolInfo {
            tool_type: ToolType::Function,
            function: ToolFunctionInfo {
                name: tool.name.clone(),
                description: tool.description.clone(),
                parameters: schema,
            },
        }
    }

    fn convert_tool_call(call: &OllamaToolCall) -> ToolCall {
        ToolCall {
            // Ollama does not return tool-call ids, so synthesize one.
            id: uuid::Uuid::new_v4().to_string(),
            name: call.function.name.clone(),
            arguments: call.function.arguments.clone(),
        }
    }

    fn convert_conversation_message(&self, msg: &ConversationMessage) -> ChatMessage {
        match msg.role {
            MessageRole::System => ChatMessage::system(msg.content.clone()),
            MessageRole::User => ChatMessage::user(msg.content.clone()),
            MessageRole::Assistant => ChatMessage::assistant(msg.content.clone()),
            MessageRole::Tool => ChatMessage::tool(msg.content.clone()),
        }
    }
}

#[async_trait]
impl LLMClient for OllamaClient {
    async fn generate(&self, prompt: &str) -> Result<String> {
        let messages = vec![ChatMessage::user(prompt.to_string())];

        let request = ChatMessageRequest::new(self.model.clone(), messages)
            .options(self.build_model_options());

        let response = self
            .client
            .send_chat_messages(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;

        Ok(response.message.content)
    }

    async fn generate_with_system(&self, system: &str, prompt: &str) -> Result<String> {
        let messages = vec![
            ChatMessage::system(system.to_string()),
            ChatMessage::user(prompt.to_string()),
        ];

        let request = ChatMessageRequest::new(self.model.clone(), messages)
            .options(self.build_model_options());

        let response = self
            .client
            .send_chat_messages(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;

        Ok(response.message.content)
    }

    async fn generate_with_history(&self, messages: &[(String, String)]) -> Result<String> {
        let chat_messages: Vec<ChatMessage> = messages
            .iter()
            .map(|(role, content)| match role.as_str() {
                "system" => ChatMessage::system(content.clone()),
                "user" => ChatMessage::user(content.clone()),
                "assistant" => ChatMessage::assistant(content.clone()),
                _ => ChatMessage::user(content.clone()),
            })
            .collect();

        let request = ChatMessageRequest::new(self.model.clone(), chat_messages)
            .options(self.build_model_options());

        let response = self
            .client
            .send_chat_messages(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;

        Ok(response.message.content)
    }

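    /// Single-turn generation with tool definitions attached.
    ///
    /// A sketch of one tool-calling round (`ignore`d; `calculator` is the
    /// same hypothetical tool used in the tests below):
    ///
    /// ```ignore
    /// let resp = client.generate_with_tools("What is 2 + 2?", &[calculator]).await?;
    /// if resp.finish_reason == "tool_calls" {
    ///     for call in &resp.tool_calls {
    ///         // Dispatch on call.name, run the tool with call.arguments,
    ///         // then feed results back via generate_with_tools_and_history.
    ///     }
    /// }
    /// ```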
    async fn generate_with_tools(
        &self,
        prompt: &str,
        tools: &[ToolDefinition],
    ) -> Result<LLMResponse> {
        let ollama_tools: Vec<ToolInfo> = tools.iter().map(Self::convert_tool_definition).collect();

        let messages = vec![ChatMessage::user(prompt.to_string())];

        let request = ChatMessageRequest::new(self.model.clone(), messages)
            .tools(ollama_tools)
            .options(self.build_model_options());

        let response = self
            .client
            .send_chat_messages(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;

        let content = response.message.content.clone();
        let tool_calls: Vec<ToolCall> = response
            .message
            .tool_calls
            .iter()
            .map(Self::convert_tool_call)
            .collect();

        // Mirror the OpenAI-style finish reasons so callers can branch on them.
        let finish_reason = if tool_calls.is_empty() {
            "stop"
        } else {
            "tool_calls"
        };

        let usage = response
            .final_data
            .as_ref()
            .map(|data| TokenUsage::new(data.prompt_eval_count as u32, data.eval_count as u32));

        Ok(LLMResponse {
            content,
            tool_calls,
            finish_reason: finish_reason.to_string(),
            usage,
        })
    }

    async fn generate_with_tools_and_history(
        &self,
        messages: &[ConversationMessage],
        tools: &[ToolDefinition],
    ) -> Result<LLMResponse> {
        let ollama_tools: Vec<ToolInfo> = tools.iter().map(Self::convert_tool_definition).collect();

        let chat_messages: Vec<ChatMessage> = messages
            .iter()
            .map(|msg| self.convert_conversation_message(msg))
            .collect();

        let mut request = ChatMessageRequest::new(self.model.clone(), chat_messages)
            .options(self.build_model_options());

        if !ollama_tools.is_empty() {
            request = request.tools(ollama_tools);
        }

        let response = self
            .client
            .send_chat_messages(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;

        let content = response.message.content.clone();
        let tool_calls: Vec<ToolCall> = response
            .message
            .tool_calls
            .iter()
            .map(Self::convert_tool_call)
            .collect();

        let finish_reason = if tool_calls.is_empty() {
            "stop"
        } else {
            "tool_calls"
        };

        let usage = response
            .final_data
            .as_ref()
            .map(|data| TokenUsage::new(data.prompt_eval_count as u32, data.eval_count as u32));

        Ok(LLMResponse {
            content,
            tool_calls,
            finish_reason: finish_reason.to_string(),
            usage,
        })
    }

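    /// Stream a completion chunk by chunk.
    ///
    /// Consuming the stream looks roughly like this (`ignore`d; assumes an
    /// async runtime and `futures::StreamExt` in scope):
    ///
    /// ```ignore
    /// let mut chunks = client.stream("Tell me a story").await?;
    /// while let Some(chunk) = chunks.next().await {
    ///     print!("{}", chunk?);
    /// }
    /// ```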
    async fn stream(
        &self,
        prompt: &str,
    ) -> Result<Box<dyn Stream<Item = Result<String>> + Send + Unpin>> {
        let messages = vec![ChatMessage::user(prompt.to_string())];
        let request = ChatMessageRequest::new(self.model.clone(), messages)
            .options(self.build_model_options());

        let mut stream_response = self
            .client
            .send_chat_messages_stream(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama stream error: {}", e)))?;

        let output_stream = stream! {
            while let Some(chunk_result) = stream_response.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        let content = chunk.message.content;
                        if !content.is_empty() {
                            yield Ok(content);
                        }
                    }
                    Err(_) => {
                        // The chunk error carries no detail we can surface,
                        // so report a generic message and stop streaming.
                        yield Err(AppError::LLM("Stream chunk error".to_string()));
                        break;
                    }
                }
            }
        };

        Ok(Box::new(Box::pin(output_stream)))
    }

    async fn stream_with_system(
        &self,
        system: &str,
        prompt: &str,
    ) -> Result<Box<dyn Stream<Item = Result<String>> + Send + Unpin>> {
        let messages = vec![
            ChatMessage::system(system.to_string()),
            ChatMessage::user(prompt.to_string()),
        ];
        let request = ChatMessageRequest::new(self.model.clone(), messages)
            .options(self.build_model_options());

        let mut stream_response = self
            .client
            .send_chat_messages_stream(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama stream error: {}", e)))?;

        let output_stream = stream! {
            while let Some(chunk_result) = stream_response.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        let content = chunk.message.content;
                        if !content.is_empty() {
                            yield Ok(content);
                        }
                    }
                    Err(_) => {
                        yield Err(AppError::LLM("Stream chunk error".to_string()));
                        break;
                    }
                }
            }
        };

        Ok(Box::new(Box::pin(output_stream)))
    }

    async fn stream_with_history(
        &self,
        messages: &[(String, String)],
    ) -> Result<Box<dyn Stream<Item = Result<String>> + Send + Unpin>> {
        let chat_messages: Vec<ChatMessage> = messages
            .iter()
            .map(|(role, content)| match role.as_str() {
                "system" => ChatMessage::system(content.clone()),
                "user" => ChatMessage::user(content.clone()),
                "assistant" => ChatMessage::assistant(content.clone()),
                _ => ChatMessage::user(content.clone()),
            })
            .collect();

        let request = ChatMessageRequest::new(self.model.clone(), chat_messages)
            .options(self.build_model_options());

        let mut stream_response = self
            .client
            .send_chat_messages_stream(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama stream error: {}", e)))?;

        let output_stream = stream! {
            while let Some(chunk_result) = stream_response.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        let content = chunk.message.content;
                        if !content.is_empty() {
                            yield Ok(content);
                        }
                    }
                    Err(_) => {
                        yield Err(AppError::LLM("Stream chunk error".to_string()));
                        break;
                    }
                }
            }
        };

        Ok(Box::new(Box::pin(output_stream)))
    }

    fn model_name(&self) -> &str {
        &self.model
    }
}

impl OllamaClient {
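    /// Cheap liveness probe: returns `Ok(true)` if the server answers a
    /// model-list request, `Ok(false)` otherwise (errors are swallowed on
    /// purpose so callers can poll).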
    pub async fn health_check(&self) -> Result<bool> {
        match self.client.list_local_models().await {
            Ok(_) => Ok(true),
            Err(_) => Ok(false),
        }
    }

    /// List the names of all models available locally on the server.
    pub async fn list_models(&self) -> Result<Vec<String>> {
        let models = self
            .client
            .list_local_models()
            .await
            .map_err(|e| AppError::LLM(format!("Failed to list models: {}", e)))?;

        Ok(models.into_iter().map(|m| m.name).collect())
    }

    /// Pull `model_name` from the Ollama registry (non-streaming variant).
    pub async fn pull_model(&self, model_name: &str) -> Result<()> {
        self.client
            .pull_model(model_name.to_string(), false)
            .await
            .map_err(|e| AppError::LLM(format!("Failed to pull model '{}': {}", model_name, e)))?;
        Ok(())
    }

    /// Fetch a model's metadata, reduced to the fields callers use.
    pub async fn model_info(&self, model_name: &str) -> Result<serde_json::Value> {
        let info = self
            .client
            .show_model_info(model_name.to_string())
            .await
            .map_err(|e| {
                AppError::LLM(format!(
                    "Failed to get model info for '{}': {}",
                    model_name, e
                ))
            })?;

        Ok(serde_json::json!({
            "modelfile": info.modelfile,
            "parameters": info.parameters,
            "template": info.template,
        }))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_url_parsing_full() {
        let base_url = "http://localhost:11434";
        let url_parts: Vec<&str> = base_url.split("://").collect();
        assert_eq!(url_parts.len(), 2);
        assert_eq!(url_parts[0], "http");
        assert_eq!(url_parts[1], "localhost:11434");

        let host_port: Vec<&str> = url_parts[1].split(':').collect();
        assert_eq!(host_port[0], "localhost");
        assert_eq!(host_port[1], "11434");
    }

    #[test]
    fn test_url_parsing_no_port() {
        let base_url = "http://localhost";
        let url_parts: Vec<&str> = base_url.split("://").collect();
        let host_port: Vec<&str> = url_parts[1].split(':').collect();

        let host = host_port[0].to_string();
        let port = if host_port.len() == 2 {
            host_port[1].parse().unwrap_or(11434)
        } else {
            11434
        };

        assert_eq!(host, "localhost");
        assert_eq!(port, 11434);
    }

    #[test]
    fn test_url_parsing_custom_port() {
        let base_url = "http://192.168.1.100:8080";
        let url_parts: Vec<&str> = base_url.split("://").collect();
        let host_port: Vec<&str> = url_parts[1].split(':').collect();

        let host = host_port[0].to_string();
        let port: u16 = host_port[1].parse().unwrap_or(11434);

        assert_eq!(host, "192.168.1.100");
        assert_eq!(port, 8080);
    }
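
    // The tests above re-derive the URL parsing by hand; this one drives the
    // real constructor. `Ollama::new` performs no I/O, so
    // `futures::executor::block_on` is enough here ("llama3" is just a
    // placeholder model name).
    #[test]
    fn test_with_params_url_handling() {
        // A blank URL is rejected with a configuration error.
        let res = futures::executor::block_on(OllamaClient::new(
            "  ".to_string(),
            "llama3".to_string(),
        ));
        assert!(res.is_err());

        // Paths, queries, and fragments are stripped before host:port parsing.
        let client = futures::executor::block_on(OllamaClient::with_params(
            "http://192.168.1.100:8080/api/chat?x=1".to_string(),
            "llama3".to_string(),
            ModelParams::default(),
        ))
        .expect("URL with path should parse");
        assert_eq!(client.model_name(), "llama3");

        // A non-numeric port is reported as a configuration error.
        let res = futures::executor::block_on(OllamaClient::new(
            "http://localhost:notaport".to_string(),
            "llama3".to_string(),
        ));
        assert!(res.is_err());
    }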

    #[test]
    fn test_tool_definition_conversion() {
        let tool = ToolDefinition {
            name: "calculator".to_string(),
            description: "Performs basic math".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "operation": {"type": "string"},
                    "a": {"type": "number"},
                    "b": {"type": "number"}
                },
                "required": ["operation", "a", "b"]
            }),
        };

        let ollama_tool = OllamaClient::convert_tool_definition(&tool);
        assert_eq!(ollama_tool.function.name, "calculator");
        assert_eq!(ollama_tool.function.description, "Performs basic math");
    }

    #[test]
    fn test_tool_call_conversion() {
        let ollama_call = OllamaToolCall {
            function: ollama_rs::generation::tools::ToolCallFunction {
                name: "test_tool".to_string(),
                arguments: serde_json::json!({"arg1": "value1"}),
            },
        };

        let tool_call = OllamaClient::convert_tool_call(&ollama_call);
        assert_eq!(tool_call.name, "test_tool");
        assert_eq!(tool_call.arguments["arg1"], "value1");
        assert!(!tool_call.id.is_empty());
    }
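
    // A sketch checking the params -> ModelOptions mapping by round-tripping
    // through serde_json. It assumes ModelOptions serializes with Ollama's
    // option names (temperature, num_predict, top_p, repeat_penalty); adjust
    // the keys if ollama-rs renames them. Values are dyadic fractions so the
    // f32 -> f64 comparison is exact.
    #[test]
    fn test_build_model_options_mapping() {
        let params = ModelParams {
            temperature: Some(0.5),
            max_tokens: Some(256),
            top_p: Some(0.25),
            presence_penalty: Some(1.5),
            ..Default::default()
        };
        let client = futures::executor::block_on(OllamaClient::with_params(
            "http://localhost:11434".to_string(),
            "llama3".to_string(),
            params,
        ))
        .unwrap();

        let options = serde_json::to_value(client.build_model_options()).unwrap();
        assert_eq!(options["temperature"], 0.5);
        assert_eq!(options["num_predict"], 256);
        assert_eq!(options["top_p"], 0.25);
        // presence_penalty is approximated via repeat_penalty upstream.
        assert_eq!(options["repeat_penalty"], 1.5);
    }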
}