use crate::llm::client::{LLMClient, LLMResponse, ModelParams, TokenUsage};
use crate::llm::coordinator::{ConversationMessage, MessageRole};
use crate::types::{AppError, Result, ToolCall, ToolDefinition};
use async_stream::stream;
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use ollama_rs::{
    generation::chat::{request::ChatMessageRequest, ChatMessage},
    generation::tools::{ToolCall as OllamaToolCall, ToolFunctionInfo, ToolInfo, ToolType},
    models::ModelOptions,
    Ollama,
};
use schemars::Schema;

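/// Client for a local or remote Ollama server, implementing [`LLMClient`]
/// on top of the `ollama-rs` crate.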
pub struct OllamaClient {
    client: Ollama,
    model: String,
    params: ModelParams,
}

impl OllamaClient {
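    /// Creates a client for `model` at `base_url` with default [`ModelParams`].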
    pub async fn new(base_url: String, model: String) -> Result<Self> {
        Self::with_params(base_url, model, ModelParams::default()).await
    }

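    /// Creates a client with explicit model parameters.
    ///
    /// `base_url` accepts forms like `http://localhost:11434`,
    /// `https://host:port`, or a bare `host[:port]`; the port defaults
    /// to 11434 when omitted.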
    pub async fn with_params(base_url: String, model: String, params: ModelParams) -> Result<Self> {
        let trimmed = base_url.trim();
        if trimmed.is_empty() {
            return Err(AppError::Configuration(
                "OLLAMA_URL is empty/invalid; expected something like http://localhost:11434"
                    .to_string(),
            ));
        }

        // Preserve the scheme instead of silently downgrading https to http.
        let scheme = if trimmed.starts_with("https://") {
            "https"
        } else {
            "http"
        };
        let without_scheme = trimmed
            .strip_prefix("http://")
            .or_else(|| trimmed.strip_prefix("https://"))
            .unwrap_or(trimmed);

        // Drop any path, query, or fragment; keep only `host[:port]`.
        let host_port = without_scheme
            .split(&['/', '?', '#'][..])
            .next()
            .unwrap_or("localhost:11434");

        // `rfind` handles `host:port`; bare IPv6 literals without a port
        // are not supported here.
        let (host, port) = if let Some(colon_idx) = host_port.rfind(':') {
            let h = &host_port[..colon_idx];
            let p_str = &host_port[colon_idx + 1..];
            let p = p_str.parse::<u16>().map_err(|_| {
                AppError::Configuration(format!(
                    "Invalid OLLAMA_URL port in '{}'; expected e.g. http://localhost:11434",
                    base_url
                ))
            })?;
            (h.to_string(), p)
        } else {
            (host_port.to_string(), 11434)
        };

        let client = Ollama::new(format!("{}://{}", scheme, host), port);

        Ok(Self {
            client,
            model,
            params,
        })
    }

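    /// Translates the generic [`ModelParams`] into `ollama-rs` [`ModelOptions`].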
    fn build_model_options(&self) -> ModelOptions {
        let mut options = ModelOptions::default();
        if let Some(temp) = self.params.temperature {
            options = options.temperature(temp);
        }
        if let Some(max_tokens) = self.params.max_tokens {
            options = options.num_predict(max_tokens as i32);
        }
        if let Some(top_p) = self.params.top_p {
            options = options.top_p(top_p);
        }
        if let Some(pres_penalty) = self.params.presence_penalty {
            // Approximation: the presence penalty is mapped onto Ollama's
            // repeat_penalty, which discourages repetition similarly but
            // is not an exact equivalent.
            options = options.repeat_penalty(pres_penalty);
        }
        options
    }

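    /// Converts a generic [`ToolDefinition`] into the `ollama-rs` tool format.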
    fn convert_tool_definition(tool: &ToolDefinition) -> ToolInfo {
        // Fall back to a default schema if the tool's parameter JSON
        // cannot be parsed as a JSON Schema.
        let schema: Schema =
            serde_json::from_value(tool.parameters.clone()).unwrap_or_else(|_| Schema::default());

        ToolInfo {
            tool_type: ToolType::Function,
            function: ToolFunctionInfo {
                name: tool.name.clone(),
                description: tool.description.clone(),
                parameters: schema,
            },
        }
    }

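    /// Converts an Ollama tool call into the internal [`ToolCall`] type.
    /// Ollama does not supply call IDs, so a fresh UUID is generated.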
    fn convert_tool_call(call: &OllamaToolCall) -> ToolCall {
        ToolCall {
            id: uuid::Uuid::new_v4().to_string(),
            name: call.function.name.clone(),
            arguments: call.function.arguments.clone(),
        }
    }

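    /// Maps an internal [`ConversationMessage`] onto the matching
    /// `ollama-rs` [`ChatMessage`] constructor.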
    fn convert_conversation_message(&self, msg: &ConversationMessage) -> ChatMessage {
        match msg.role {
            MessageRole::System => ChatMessage::system(msg.content.clone()),
            MessageRole::User => ChatMessage::user(msg.content.clone()),
            MessageRole::Assistant => ChatMessage::assistant(msg.content.clone()),
            MessageRole::Tool => ChatMessage::tool(msg.content.clone()),
        }
    }
}

#[async_trait]
impl LLMClient for OllamaClient {
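    /// Single-turn generation: sends the prompt as one user message and
    /// returns the response text.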
    async fn generate(&self, prompt: &str) -> Result<String> {
        let messages = vec![ChatMessage::user(prompt.to_string())];

        let request = ChatMessageRequest::new(self.model.clone(), messages)
            .options(self.build_model_options());

        let response = self
            .client
            .send_chat_messages(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;

        Ok(response.message.content)
    }

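    /// Single-turn generation with a system prompt prepended.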
    async fn generate_with_system(&self, system: &str, prompt: &str) -> Result<String> {
        let messages = vec![
            ChatMessage::system(system.to_string()),
            ChatMessage::user(prompt.to_string()),
        ];

        let request = ChatMessageRequest::new(self.model.clone(), messages)
            .options(self.build_model_options());

        let response = self
            .client
            .send_chat_messages(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;

        Ok(response.message.content)
    }

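    /// Multi-turn generation over `(role, content)` pairs; unknown role
    /// strings are treated as user messages.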
    async fn generate_with_history(&self, messages: &[(String, String)]) -> Result<LLMResponse> {
        let chat_messages: Vec<ChatMessage> = messages
            .iter()
            .map(|(role, content)| match role.as_str() {
                "system" => ChatMessage::system(content.clone()),
                "user" => ChatMessage::user(content.clone()),
                "assistant" => ChatMessage::assistant(content.clone()),
                _ => ChatMessage::user(content.clone()),
            })
            .collect();

        let request = ChatMessageRequest::new(self.model.clone(), chat_messages)
            .options(self.build_model_options());

        let response = self
            .client
            .send_chat_messages(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;

        // Token counts are only available on the final response data,
        // when the server reports them.
        let usage = response
            .final_data
            .as_ref()
            .map(|data| TokenUsage::new(data.prompt_eval_count as u32, data.eval_count as u32));

        Ok(LLMResponse {
            content: response.message.content,
            tool_calls: vec![],
            finish_reason: "stop".to_string(),
            usage,
        })
    }

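    /// Single-turn generation with tool definitions attached; any tool
    /// calls in the response are converted to the internal format.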
    async fn generate_with_tools(
        &self,
        prompt: &str,
        tools: &[ToolDefinition],
    ) -> Result<LLMResponse> {
        let ollama_tools: Vec<ToolInfo> = tools.iter().map(Self::convert_tool_definition).collect();

        let messages = vec![ChatMessage::user(prompt.to_string())];

        let request = ChatMessageRequest::new(self.model.clone(), messages)
            .tools(ollama_tools)
            .options(self.build_model_options());

        let response = self
            .client
            .send_chat_messages(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;

        let content = response.message.content.clone();
        let tool_calls: Vec<ToolCall> = response
            .message
            .tool_calls
            .iter()
            .map(Self::convert_tool_call)
            .collect();

        let finish_reason = if tool_calls.is_empty() {
            "stop"
        } else {
            "tool_calls"
        };

        let usage = response
            .final_data
            .as_ref()
            .map(|data| TokenUsage::new(data.prompt_eval_count as u32, data.eval_count as u32));

        Ok(LLMResponse {
            content,
            tool_calls,
            finish_reason: finish_reason.to_string(),
            usage,
        })
    }

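    /// Multi-turn generation with optional tools; tools are only attached
    /// to the request when at least one definition is supplied.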
    async fn generate_with_tools_and_history(
        &self,
        messages: &[ConversationMessage],
        tools: &[ToolDefinition],
    ) -> Result<LLMResponse> {
        let ollama_tools: Vec<ToolInfo> = tools.iter().map(Self::convert_tool_definition).collect();

        let chat_messages: Vec<ChatMessage> = messages
            .iter()
            .map(|msg| self.convert_conversation_message(msg))
            .collect();

        let mut request = ChatMessageRequest::new(self.model.clone(), chat_messages)
            .options(self.build_model_options());

        if !ollama_tools.is_empty() {
            request = request.tools(ollama_tools);
        }

        let response = self
            .client
            .send_chat_messages(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama error: {}", e)))?;

        let content = response.message.content.clone();
        let tool_calls: Vec<ToolCall> = response
            .message
            .tool_calls
            .iter()
            .map(Self::convert_tool_call)
            .collect();

        let finish_reason = if tool_calls.is_empty() {
            "stop"
        } else {
            "tool_calls"
        };

        let usage = response
            .final_data
            .as_ref()
            .map(|data| TokenUsage::new(data.prompt_eval_count as u32, data.eval_count as u32));

        Ok(LLMResponse {
            content,
            tool_calls,
            finish_reason: finish_reason.to_string(),
            usage,
        })
    }

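    /// Streams a single-turn completion, yielding non-empty content chunks
    /// as they arrive.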
    async fn stream(
        &self,
        prompt: &str,
    ) -> Result<Box<dyn Stream<Item = Result<String>> + Send + Unpin>> {
        let messages = vec![ChatMessage::user(prompt.to_string())];
        let request = ChatMessageRequest::new(self.model.clone(), messages)
            .options(self.build_model_options());

        let mut stream_response = self
            .client
            .send_chat_messages_stream(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama stream error: {}", e)))?;

        let output_stream = stream! {
            while let Some(chunk_result) = stream_response.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        let content = chunk.message.content;
                        if !content.is_empty() {
                            yield Ok(content);
                        }
                    }
                    Err(_) => {
                        // Chunk errors carry no detail; surface a generic
                        // error and end the stream.
                        yield Err(AppError::LLM("Stream chunk error".to_string()));
                        break;
                    }
                }
            }
        };

        Ok(Box::new(Box::pin(output_stream)))
    }

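    /// Streams a completion with a system prompt prepended.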
    async fn stream_with_system(
        &self,
        system: &str,
        prompt: &str,
    ) -> Result<Box<dyn Stream<Item = Result<String>> + Send + Unpin>> {
        let messages = vec![
            ChatMessage::system(system.to_string()),
            ChatMessage::user(prompt.to_string()),
        ];
        let request = ChatMessageRequest::new(self.model.clone(), messages)
            .options(self.build_model_options());

        let mut stream_response = self
            .client
            .send_chat_messages_stream(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama stream error: {}", e)))?;

        let output_stream = stream! {
            while let Some(chunk_result) = stream_response.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        let content = chunk.message.content;
                        if !content.is_empty() {
                            yield Ok(content);
                        }
                    }
                    Err(_) => {
                        yield Err(AppError::LLM("Stream chunk error".to_string()));
                        break;
                    }
                }
            }
        };

        Ok(Box::new(Box::pin(output_stream)))
    }

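    /// Streams a completion over `(role, content)` history pairs.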
    async fn stream_with_history(
        &self,
        messages: &[(String, String)],
    ) -> Result<Box<dyn Stream<Item = Result<String>> + Send + Unpin>> {
        let chat_messages: Vec<ChatMessage> = messages
            .iter()
            .map(|(role, content)| match role.as_str() {
                "system" => ChatMessage::system(content.clone()),
                "user" => ChatMessage::user(content.clone()),
                "assistant" => ChatMessage::assistant(content.clone()),
                _ => ChatMessage::user(content.clone()),
            })
            .collect();

        let request = ChatMessageRequest::new(self.model.clone(), chat_messages)
            .options(self.build_model_options());

        let mut stream_response = self
            .client
            .send_chat_messages_stream(request)
            .await
            .map_err(|e| AppError::LLM(format!("Ollama stream error: {}", e)))?;

        let output_stream = stream! {
            while let Some(chunk_result) = stream_response.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        let content = chunk.message.content;
                        if !content.is_empty() {
                            yield Ok(content);
                        }
                    }
                    Err(_) => {
                        yield Err(AppError::LLM("Stream chunk error".to_string()));
                        break;
                    }
                }
            }
        };

        Ok(Box::new(Box::pin(output_stream)))
    }

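    /// Returns the configured model name.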
    fn model_name(&self) -> &str {
        &self.model
    }
}

impl OllamaClient {
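    /// Returns `Ok(true)` if the server responds to a model-list request,
    /// `Ok(false)` otherwise; never returns an error.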
    pub async fn health_check(&self) -> Result<bool> {
        Ok(self.client.list_local_models().await.is_ok())
    }

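    /// Lists the names of all models available locally on the server.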
    pub async fn list_models(&self) -> Result<Vec<String>> {
        let models = self
            .client
            .list_local_models()
            .await
            .map_err(|e| AppError::LLM(format!("Failed to list models: {}", e)))?;

        Ok(models.into_iter().map(|m| m.name).collect())
    }

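    /// Pulls `model_name` from the Ollama registry, waiting for the pull
    /// to complete.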
    pub async fn pull_model(&self, model_name: &str) -> Result<()> {
        self.client
            .pull_model(model_name.to_string(), false)
            .await
            .map_err(|e| AppError::LLM(format!("Failed to pull model '{}': {}", model_name, e)))?;
        Ok(())
    }

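    /// Fetches the modelfile, parameters, and template for `model_name`
    /// as a JSON object.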
    pub async fn model_info(&self, model_name: &str) -> Result<serde_json::Value> {
        let info = self
            .client
            .show_model_info(model_name.to_string())
            .await
            .map_err(|e| {
                AppError::LLM(format!(
                    "Failed to get model info for '{}': {}",
                    model_name, e
                ))
            })?;

        Ok(serde_json::json!({
            "modelfile": info.modelfile,
            "parameters": info.parameters,
            "template": info.template,
        }))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_url_parsing_full() {
        let base_url = "http://localhost:11434";
        let url_parts: Vec<&str> = base_url.split("://").collect();
        assert_eq!(url_parts.len(), 2);
        assert_eq!(url_parts[0], "http");
        assert_eq!(url_parts[1], "localhost:11434");

        let host_port: Vec<&str> = url_parts[1].split(':').collect();
        assert_eq!(host_port[0], "localhost");
        assert_eq!(host_port[1], "11434");
    }

    #[test]
    fn test_url_parsing_no_port() {
        let base_url = "http://localhost";
        let url_parts: Vec<&str> = base_url.split("://").collect();
        let host_port: Vec<&str> = url_parts[1].split(':').collect();

        let host = host_port[0].to_string();
        let port: u16 = if host_port.len() == 2 {
            host_port[1].parse().unwrap_or(11434)
        } else {
            11434
        };

        assert_eq!(host, "localhost");
        assert_eq!(port, 11434);
    }

    #[test]
    fn test_url_parsing_custom_port() {
        let base_url = "http://192.168.1.100:8080";
        let url_parts: Vec<&str> = base_url.split("://").collect();
        let host_port: Vec<&str> = url_parts[1].split(':').collect();

        let host = host_port[0].to_string();
        let port: u16 = host_port[1].parse().unwrap_or(11434);

        assert_eq!(host, "192.168.1.100");
        assert_eq!(port, 8080);
    }

    #[test]
    fn test_tool_definition_conversion() {
        let tool = ToolDefinition {
            name: "calculator".to_string(),
            description: "Performs basic math".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "operation": {"type": "string"},
                    "a": {"type": "number"},
                    "b": {"type": "number"}
                },
                "required": ["operation", "a", "b"]
            }),
        };

        let ollama_tool = OllamaClient::convert_tool_definition(&tool);
        assert_eq!(ollama_tool.function.name, "calculator");
        assert_eq!(ollama_tool.function.description, "Performs basic math");
    }

    #[test]
    fn test_tool_call_conversion() {
        let ollama_call = OllamaToolCall {
            function: ollama_rs::generation::tools::ToolCallFunction {
                name: "test_tool".to_string(),
                arguments: serde_json::json!({"arg1": "value1"}),
            },
        };

        let tool_call = OllamaClient::convert_tool_call(&ollama_call);
        assert_eq!(tool_call.name, "test_tool");
        assert_eq!(tool_call.arguments["arg1"], "value1");
        assert!(!tool_call.id.is_empty());
    }
}