1use crate::error::{Error, Result};
2use crate::message::{Content, ContentPart, Function, Message, ToolCall};
3
4use crate::Chat;
5use crate::model::{ModelInfo, Ollama, OllamaModelSize};
6use crate::provider::HTTPProvider;
7use crate::tool::{LlmToolInfo, ToolChoice};
8use async_trait::async_trait;
9use reqwest::{Client, Request, Url, header};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::collections::HashMap;
13use std::fmt::Debug;
14use thiserror::Error;
15use tracing::{debug, error, info, instrument};
16
/// Default base URL of a locally running Ollama server's REST API.
///
/// The trailing slash is significant: `Url::join` treats the last path segment
/// of a base that does not end in `/` as a file name and replaces it, so
/// joining `"chat"` onto `.../api` would produce `.../chat` rather than
/// `.../api/chat`.
const DEFAULT_OLLAMA_API_BASE_URL: &str = "http://localhost:11434/api/";
18
/// Failures that can occur while calling the Ollama HTTP API from this module.
///
/// Each variant is mapped onto the crate-wide `Error` by the
/// `From<ProviderError> for Error` implementation in this file.
#[derive(Error, Debug)]
pub enum ProviderError {
    /// An HTTP-level failure reported by `reqwest`, optionally annotated with
    /// a short description of the step that failed.
    #[error("API error: {message:?}")]
    ApiError {
        /// The underlying `reqwest` error.
        source: reqwest::Error,
        /// Optional context, e.g. "Failed to build request".
        message: Option<String>,
    },

    /// A payload could not be decoded as JSON.
    #[error("Deserialization error: {content}")]
    DeserializationError {
        /// Text shown in the error message.
        content: String,
        /// The underlying `serde_json` error.
        source: serde_json::Error,
    },

    /// The server answered with a status/body combination that was not expected.
    #[error("Unexpected response ({status}): {content}")]
    UnexpectedResponse { status: u16, content: String },

    /// Any other failure, described by a free-form message.
    #[error("Error: {0}")]
    Other(String),
}
39
40impl From<ProviderError> for Error {
42 fn from(err: ProviderError) -> Self {
43 match err {
44 ProviderError::ApiError { source, message } => {
45 if let Some(msg) = message {
46 Error::ProviderUnavailable(format!("Ollama API error: {}: {}", source, msg))
47 } else {
48 Error::Request(source)
49 }
50 }
51 ProviderError::DeserializationError { content: _, source } => {
52 Error::Serialization(source)
53 }
54 ProviderError::UnexpectedResponse { status, content } => {
55 Error::ProviderUnavailable(format!(
56 "Unexpected response from Ollama API ({}): {}",
57 status, content
58 ))
59 }
60 ProviderError::Other(msg) => Error::Other(format!("Ollama provider error: {}", msg)),
61 }
62 }
63}
64
/// Connection settings for `OllamaProvider`.
#[derive(Debug, Clone)]
pub struct OllamaConfig {
    /// Base URL onto which the `chat` endpoint path is joined when building requests.
    pub base_url: Url,
}
72
73impl Default for OllamaConfig {
74 fn default() -> Self {
75 Self {
76 base_url: Url::parse(DEFAULT_OLLAMA_API_BASE_URL)
77 .expect("Failed to parse default Ollama base URL"),
78 }
79 }
80}
81
/// Provider that talks to an Ollama server over HTTP.
#[derive(Debug, Clone)]
pub struct OllamaProvider {
    /// Where to reach the server.
    config: OllamaConfig,
    /// `reqwest` client used to build and execute requests.
    client: Client,
}
87
88impl OllamaProvider {
89 pub fn new() -> Self {
91 Self::default()
92 }
93
94 pub fn with_config(config: OllamaConfig) -> Self {
96 Self {
97 config,
98 client: Client::new(),
99 }
100 }
101
102 fn id_for_model(&self, model: &Ollama) -> String {
104 model.ollama_model_id()
105 }
106
107 #[instrument(skip(self, messages, tools))]
108 #[allow(clippy::too_many_arguments)]
109 fn create_request_payload(
110 &self,
111 model: &Ollama,
112 messages: &[Message],
113 max_tokens: Option<u32>,
114 temperature: Option<f32>,
115 top_p: Option<f32>,
116 top_k: Option<u32>,
117 tools: Option<&[LlmToolInfo]>,
118 tool_choice: Option<&ToolChoice>,
119 system_prompt: Option<&str>,
120 ) -> Result<OllamaChatRequest> {
121 let mut ollama_messages: Vec<OllamaMessage> = Vec::new();
122 let mut current_system_prompt = system_prompt.map(|s| s.to_string());
123
124 for message in messages {
125 match message {
127 Message::System { content, .. } => {
128 if !content.is_empty() {
130 if let Some(ref mut existing_prompt) = current_system_prompt {
131 existing_prompt.push('\n');
132 existing_prompt.push_str(content);
133 } else {
134 current_system_prompt = Some(content.clone());
135 }
136 }
137 }
138 Message::User { .. } | Message::Assistant { .. } | Message::Tool { .. } => {
139 ollama_messages.push(OllamaMessage::from(message));
141 }
142 }
143 }
144
145 let mut options = OllamaRequestOptions::default();
146 let mut options_set = false;
147
148 if let Some(temp) = temperature {
149 options.temperature = Some(temp);
150 options_set = true;
151 }
152 if let Some(tk) = top_k {
153 options.top_k = Some(tk);
154 options_set = true;
155 }
156 if let Some(tp) = top_p {
157 options.top_p = Some(tp);
158 options_set = true;
159 }
160 if let Some(mt) = max_tokens {
161 options.num_predict = Some(mt);
162 options_set = true;
163 }
164 let ollama_tools = tools.and_then(|tool_infos| {
167 if tool_infos.is_empty() {
168 None
169 } else {
170 Some(tool_infos.iter().map(OllamaTool::from).collect())
171 }
172 });
173
174 let mut format_option: Option<String> = None;
175 if let Some(tc) = tool_choice {
176 match tc {
177 ToolChoice::Auto => {
178 }
181 ToolChoice::Any => {
182 format_option = Some("json".to_string());
186 }
187 ToolChoice::None => {
188 }
191 ToolChoice::Specific(_name) => {
192 }
197 }
198 }
199
200 let final_tools = if tools.is_none_or(|t| t.is_empty()) {
205 None
206 } else {
207 ollama_tools
208 };
209
210 Ok(OllamaChatRequest {
211 model: self.id_for_model(model),
212 messages: ollama_messages,
213 system: current_system_prompt,
214 format: format_option,
215 options: if options_set { Some(options) } else { None },
216 stream: false, tools: final_tools,
218 keep_alive: Some("5m".to_string()), })
220 }
221}
222
223impl Default for OllamaProvider {
224 fn default() -> Self {
225 Self {
226 config: OllamaConfig::default(),
227 client: Client::new(),
228 }
229 }
230}
231
/// One chat message in the shape sent to Ollama.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OllamaMessage {
    /// Speaker role; this file produces "user", "assistant" and "tool".
    pub role: String,
    /// Text of the message (multiple text parts are joined with newlines).
    pub content: String,
    /// Image references attached to a user message; omitted when there are none.
    // NOTE(review): populated from image URLs as-is — confirm the server accepts
    // URLs rather than requiring encoded image data.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub images: Option<Vec<String>>,
    /// Tool calls previously issued by the assistant; omitted when there are none.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<OllamaResponseToolCall>>,
}
243
/// Description of a callable function advertised to the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OllamaToolFunctionDefinition {
    /// Function name, copied from `LlmToolInfo::name`.
    pub name: String,
    /// Human-readable description; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Parameter specification as arbitrary JSON, copied from `LlmToolInfo::parameters`.
    pub parameters: Value,
}
251
/// A tool entry of the request's `tools` array.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OllamaTool {
    /// Serialized as `type`; this file always sets it to "function".
    #[serde(rename = "type")]
    pub type_field: String,
    /// The function being advertised.
    pub function: OllamaToolFunctionDefinition,
}
258
/// Optional generation parameters sent under the request's `options` key.
///
/// Every field is skipped during serialization when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub(crate) struct OllamaRequestOptions {
    /// Forwarded from the caller's `temperature` argument.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Forwarded from the caller's `top_k` argument.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    /// Forwarded from the caller's `top_p` argument.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Output token limit; filled from `max_tokens` / `max_output_tokens`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_predict: Option<u32>,
    /// Stop sequences; never populated by the code in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
}
273
/// JSON body posted to the `chat` endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OllamaChatRequest {
    /// Model identifier as produced by `OllamaModelInfo::ollama_model_id`.
    pub model: String,
    /// Conversation history, excluding system messages.
    pub messages: Vec<OllamaMessage>,
    /// System prompt text; omitted when `None`.
    // NOTE(review): Ollama's `/api/chat` documentation describes system prompts
    // as `role: "system"` messages — confirm this top-level field is honoured.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Output format hint; this file only ever sets it to "json".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<String>,
    /// Generation parameters; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<OllamaRequestOptions>,
    /// Streaming flag; this file always sends `false`.
    pub stream: bool,
    /// Tools advertised to the model; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<OllamaTool>>,
    /// Keep-alive value for the model; this file always sends "5m".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<String>,
}
290
/// The function part of a tool call: which function and with what arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OllamaResponseFunctionCall {
    /// Name of the function to invoke.
    pub name: String,
    /// Arguments as a JSON value (not a JSON-encoded string).
    pub arguments: Value,
}
298
299#[derive(Debug, Clone, Serialize, Deserialize)]
300pub(crate) struct OllamaResponseToolCall {
301 #[serde(rename = "type")]
302 pub type_field: String, pub function: OllamaResponseFunctionCall,
304 }
307
/// The `message` object of a chat response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OllamaResponseMessage {
    /// Role reported by the server; `parse` recognises "assistant", "user",
    /// "system" and "tool" and treats anything else as an assistant message.
    pub role: String,
    /// Text content of the reply.
    pub content: String,
    /// Tool calls requested by the model, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<OllamaResponseToolCall>>,
    /// Images attached to the message, if any; not read by the code in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub images: Option<Vec<String>>,
}
317
/// Top-level JSON object of a non-streaming chat response.
///
/// `model`, `created_at`, `message` and `done` are required; the remaining
/// fields are optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OllamaChatResponse {
    /// Model name echoed by the server (logged at debug level by `parse`).
    pub model: String,
    /// Creation timestamp, kept as the raw string.
    pub created_at: String,
    /// The generated message.
    pub message: OllamaResponseMessage,
    /// Completion flag as reported by the server.
    pub done: bool,
    /// Reason generation stopped (logged at debug level by `parse`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub done_reason: Option<String>,
    /// Timing figure reported by the server; not read in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_duration: Option<u64>,
    /// Timing figure reported by the server; not read in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub load_duration: Option<u64>,
    /// Recorded by `parse` as the `input_tokens` metadata entry.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_eval_count: Option<u32>,
    /// Timing figure reported by the server; not read in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_eval_duration: Option<u64>,
    /// Recorded by `parse` as the `output_tokens` metadata entry.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub eval_count: Option<u32>,
    /// Timing figure reported by the server; not read in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub eval_duration: Option<u64>,
}
341
/// Maps a model description onto the identifier string Ollama expects.
pub trait OllamaModelInfo {
    /// Returns the model identifier to put in the request's `model` field.
    fn ollama_model_id(&self) -> String;
}
349
350impl OllamaModelInfo for Ollama {
351 fn ollama_model_id(&self) -> String {
352 match self {
353 Self::Llama3 { size } => match size {
354 OllamaModelSize::_8B => "llama3:8b",
355 OllamaModelSize::_7B => "llama3",
356 OllamaModelSize::_3B => "llama3:3b",
357 OllamaModelSize::_1B => "llama3:1b",
358 },
359 Self::Llava => "llava",
360 Self::Mistral { size } => match size {
361 OllamaModelSize::_8B => "mistral:8b",
362 OllamaModelSize::_7B => "mistral",
363 OllamaModelSize::_3B => "mistral:3b",
364 OllamaModelSize::_1B => "mistral:1b",
365 },
366 Self::Custom { name } => name,
367 }
368 .to_string()
369 }
370}
371
/// A backend able to answer a chat prompt for models of type `M`.
#[async_trait]
pub trait Provider<M: ModelInfo>: Send + Sync {
    /// Sends `messages` to `model` and returns the model's reply.
    ///
    /// The optional arguments carry generation limits and sampling settings
    /// (`max_tokens`, `temperature`, `top_p`, `top_k`), the tools the model may
    /// call together with the tool-choice policy, and an extra system prompt.
    #[allow(clippy::too_many_arguments)]
    async fn prompt(
        &self,
        model: &M,
        messages: &[Message],
        max_tokens: Option<u32>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        top_k: Option<u32>,
        tools: Option<&[LlmToolInfo]>,
        tool_choice: Option<&ToolChoice>,
        system_prompt: Option<&str>,
    ) -> Result<Message>;
}
389
390#[async_trait]
391impl Provider<Ollama> for OllamaProvider {
392 #[instrument(skip(self), level = "debug")]
393 #[allow(clippy::too_many_arguments)]
394 async fn prompt(
395 &self,
396 model: &Ollama,
397 messages: &[Message],
398 max_tokens: Option<u32>,
399 temperature: Option<f32>,
400 top_p: Option<f32>,
401 top_k: Option<u32>,
402 tools: Option<&[LlmToolInfo]>,
403 tool_choice: Option<&ToolChoice>,
404 system_prompt: Option<&str>,
405 ) -> Result<Message> {
406 info!("Creating chat completion with Ollama model");
407 debug!("Model: {:?}", model);
408 debug!("Number of messages: {}", messages.len());
409 debug!("System prompt provided: {}", system_prompt.is_some());
410 debug!("Tools provided: {}", tools.is_some_and(|t| !t.is_empty()));
411 debug!("Tool choice provided: {}", tool_choice.is_some());
412
413 let request_url = self
414 .config
415 .base_url
416 .join("chat")
417 .map_err(Error::BaseUrlError)?;
418 debug!("Request URL: {}", request_url);
419
420 let request_payload = self.create_request_payload(
421 model,
422 messages,
423 max_tokens,
424 temperature,
425 top_p,
426 top_k,
427 tools,
428 tool_choice,
429 system_prompt,
430 )?;
431
432 let request = self
434 .client
435 .post(request_url)
436 .header(header::CONTENT_TYPE, "application/json")
437 .header(header::ACCEPT, "application/json")
438 .json(&request_payload)
439 .build()
440 .map_err(|e| ProviderError::ApiError {
441 source: e,
442 message: Some("Failed to build request".to_string()),
443 })?;
444
445 debug!("Sending request to Ollama API");
446 let response = self
447 .client
448 .execute(request)
449 .await
450 .map_err(|e| ProviderError::ApiError {
451 source: e,
452 message: Some("Failed to execute request".to_string()),
453 })?;
454
455 debug!("Response status: {}", response.status());
456
457 let response_text = response.text().await.map_err(|e| {
459 error!("Failed to get response text: {}", e);
460 ProviderError::ApiError {
461 source: e,
462 message: Some("Failed to get response text".to_string()),
463 }
464 })?;
465
466 let message = match self.parse(response_text) {
468 Ok(msg) => msg,
469 Err(e) => {
470 error!("Failed to parse response: {:?}", e);
471 return Err(e);
472 }
473 };
474
475 Ok(message)
477 }
478}
479
480impl From<&LlmToolInfo> for OllamaTool {
483 fn from(tool_info: &LlmToolInfo) -> Self {
484 OllamaTool {
485 type_field: "function".to_string(),
486 function: OllamaToolFunctionDefinition {
487 name: tool_info.name.clone(),
488 description: Some(tool_info.description.clone()),
489 parameters: tool_info.parameters.clone(),
490 },
491 }
492 }
493}
494
495impl From<&Message> for OllamaMessage {
496 fn from(message: &Message) -> Self {
497 let role = match message {
499 Message::User { .. } => "user".to_string(),
500 Message::Assistant { .. } => "assistant".to_string(),
501 Message::Tool { .. } => "tool".to_string(),
502 Message::System { .. } => {
503 tracing::warn!(
508 "System message encountered in From<&Message> for OllamaMessage conversion. This should be handled by the system prompt field."
509 );
510 "user".to_string()
511 }
512 };
513
514 let mut content_texts = Vec::new();
515 let mut image_data: Vec<String> = Vec::new();
516 let mut assistant_tool_calls: Vec<OllamaResponseToolCall> = Vec::new();
517
518 match message {
520 Message::User { content, .. } => {
521 match content {
522 Content::Text(text) => content_texts.push(text.clone()),
523 Content::Parts(parts) => {
524 for part in parts {
525 match part {
526 ContentPart::Text { text } => content_texts.push(text.clone()),
527 ContentPart::ImageUrl { image_url } => {
528 image_data.push(image_url.url.clone());
530 }
531 }
532 }
533 }
534 }
535 }
536 Message::Assistant {
537 content,
538 tool_calls,
539 ..
540 } => {
541 if let Some(content) = content {
543 match content {
544 Content::Text(text) => content_texts.push(text.clone()),
545 Content::Parts(parts) => {
546 for part in parts {
547 if let ContentPart::Text { text } = part {
548 content_texts.push(text.clone());
549 }
550 }
552 }
553 }
554 }
555
556 for tool_call in tool_calls {
558 assistant_tool_calls.push(OllamaResponseToolCall {
559 type_field: "function".to_string(),
560 function: OllamaResponseFunctionCall {
561 name: tool_call.function.name.clone(),
562 arguments: serde_json::from_str(&tool_call.function.arguments)
563 .unwrap_or(serde_json::Value::Null),
564 },
565 });
566 }
567 }
568 Message::Tool { content, .. } => {
569 content_texts.push(content.clone());
571 }
572 Message::System { content, .. } => {
573 content_texts.push(content.clone());
575 }
576 }
577
578 let final_content = content_texts.join("\n");
579
580 OllamaMessage {
581 role,
582 content: final_content,
583 images: if image_data.is_empty() {
584 None
585 } else {
586 Some(image_data)
587 },
588 tool_calls: if assistant_tool_calls.is_empty() {
589 None
590 } else {
591 Some(assistant_tool_calls)
592 },
593 }
594 }
595}
596
597#[async_trait]
598impl HTTPProvider<Ollama> for OllamaProvider {
599 #[instrument(skip(self, model, chat), level = "debug")]
600 fn accept(&self, model: Ollama, chat: &Chat) -> Result<Request> {
601 info!("Creating HTTP request for Ollama model: {:?}", model);
602 debug!("Number of messages in chat: {}", chat.history.len());
603
604 let url = self.config.base_url.join("chat").map_err(|e| {
605 error!("Failed to join chat URL path to base URL: {}", e);
606 crate::error::Error::Other(format!("Failed to join chat URL path to base URL: {}", e))
607 })?;
608 debug!("Request URL: {}", url);
609
610 let ollama_messages: Vec<_> = chat
612 .history
613 .iter()
614 .filter(|msg| !matches!(msg, Message::System { .. })) .map(OllamaMessage::from)
616 .collect();
617 debug!(
618 "Converted {} messages for Ollama request",
619 ollama_messages.len()
620 );
621
622 let system_prompt = if chat.system_prompt.is_empty() {
624 None
625 } else {
626 debug!(
627 "Using system prompt from chat: {} chars",
628 chat.system_prompt.len()
629 );
630 Some(chat.system_prompt.clone())
631 };
632
633 let tools = if let Some(ref tools) = chat.tools {
635 if tools.is_empty() {
636 debug!("No tools defined in chat");
637 None
638 } else {
639 debug!("Converting {} tools for Ollama request", tools.len());
640 Some(tools.iter().map(OllamaTool::from).collect::<Vec<_>>())
641 }
642 } else {
643 None
644 };
645
646 let format = match chat.tool_choice {
648 Some(ToolChoice::Any) => {
649 debug!(
650 "Using ToolChoice::Any - setting json format to encourage structured outputs"
651 );
652 Some("json".to_string())
653 }
654 Some(ToolChoice::Auto) => {
655 debug!("Using ToolChoice::Auto - letting the model decide");
656 None
657 }
658 Some(ToolChoice::None) => {
659 debug!("Using ToolChoice::None - tools will not be used");
660 None
661 }
662 Some(ToolChoice::Specific(_)) => {
663 debug!("Using specific tool choice - filter applied to tools");
664 None
666 }
667 None => None,
668 };
669
670 let options = Some(OllamaRequestOptions {
672 temperature: None, top_k: None, top_p: None, num_predict: Some(chat.max_output_tokens as u32),
676 stop: None, });
678
679 let payload = OllamaChatRequest {
681 model: model.ollama_model_id(),
682 messages: ollama_messages,
683 system: system_prompt,
684 format,
685 options,
686 stream: false, tools,
688 keep_alive: Some("5m".to_string()),
689 };
690
691 debug!("Created Ollama request payload");
692
693 let request = self
695 .client
696 .post(url)
697 .header(header::CONTENT_TYPE, "application/json")
698 .header(header::ACCEPT, "application/json")
699 .json(&payload)
700 .build()
701 .map_err(|e| {
702 error!("Failed to build request: {}", e);
703 crate::error::Error::Request(e)
704 })?;
705
706 debug!("Built Ollama HTTP request successfully");
707 Ok(request)
708 }
709
710 #[instrument(skip(self, raw_response_text), level = "debug")]
711 fn parse(&self, raw_response_text: String) -> Result<Message> {
712 info!("Parsing response from Ollama API");
713 debug!("Response text length: {}", raw_response_text.len());
714
715 if raw_response_text.contains("\"error\"") {
717 let error_response: serde_json::Value = serde_json::from_str(&raw_response_text)
718 .map_err(|e| {
719 error!("Failed to parse error response: {}", e);
720 Error::Serialization(e)
721 })?;
722
723 if let Some(error) = error_response.get("error") {
724 let error_msg = error.as_str().unwrap_or("Unknown Ollama error");
725 error!("Ollama API returned an error: {}", error_msg);
726 return Err(Error::ProviderUnavailable(error_msg.to_string()));
727 }
728 }
729
730 let ollama_response: OllamaChatResponse = serde_json::from_str(&raw_response_text)
732 .map_err(|e| {
733 error!("Failed to deserialize Ollama response: {}", e);
734 Error::Serialization(e)
735 })?;
736
737 debug!("Response deserialized successfully");
738 debug!("Model: {}", ollama_response.model);
739 debug!("Done reason: {:?}", ollama_response.done_reason);
740
741 let response_role = ollama_response.message.role.as_str();
743 let response_content = ollama_response.message.content.clone();
744
745 let message = match response_role {
746 "assistant" => {
747 let mut tool_calls = Vec::new();
751 if let Some(tool_calls_data) = ollama_response.message.tool_calls {
752 for tool_call in tool_calls_data {
753 let tool_call_id = format!(
756 "tc-{}",
757 std::time::SystemTime::now()
758 .duration_since(std::time::UNIX_EPOCH)
759 .unwrap_or_default()
760 .as_micros()
761 );
762
763 tool_calls.push(ToolCall {
764 id: tool_call_id,
765 tool_type: "function".to_string(),
766 function: Function {
767 name: tool_call.function.name,
768 arguments: serde_json::to_string(&tool_call.function.arguments)
769 .unwrap_or_default(),
770 },
771 });
772 }
773 }
774
775 let content = if response_content.is_empty() && !tool_calls.is_empty() {
777 None
778 } else {
779 Some(Content::Text(response_content))
781 };
782
783 Message::Assistant {
784 content,
785 tool_calls,
786 metadata: HashMap::new(),
787 }
788 }
789 "user" => Message::User {
790 content: Content::Text(response_content),
791 name: None,
792 metadata: HashMap::new(),
793 },
794 "system" => Message::System {
795 content: response_content,
796 metadata: HashMap::new(),
797 },
798 "tool" => Message::Tool {
799 tool_call_id: "response-tool-call".to_string(), content: response_content,
801 metadata: HashMap::new(),
802 },
803 _ => {
804 error!(
806 "Unknown message role in Ollama response: {}",
807 ollama_response.message.role
808 );
809 Message::Assistant {
810 content: Some(Content::Text(response_content)),
811 tool_calls: Vec::new(),
812 metadata: HashMap::new(),
813 }
814 }
815 };
816
817 let message_with_meta = if let Some(tokens) = ollama_response.prompt_eval_count {
819 message.with_metadata("input_tokens", serde_json::json!(tokens))
820 } else {
821 message
822 };
823
824 let message_with_meta = if let Some(tokens) = ollama_response.eval_count {
825 message_with_meta.with_metadata("output_tokens", serde_json::json!(tokens))
826 } else {
827 message_with_meta
828 };
829
830 info!("Successfully parsed Ollama response");
831 Ok(message_with_meta)
832 }
833}
834
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Checks `OllamaMessage::from` for every message kind.
    #[test]
    fn test_message_to_ollama_conversion() {
        // Plain user text.
        let user_message_text = Message::user("Hello, Ollama!");
        let ollama_user_message_text = OllamaMessage::from(&user_message_text);
        assert_eq!(ollama_user_message_text.role, "user");
        assert_eq!(ollama_user_message_text.content, "Hello, Ollama!");
        assert!(ollama_user_message_text.images.is_none());
        assert!(ollama_user_message_text.tool_calls.is_none());

        // Plain assistant text.
        let assistant_message_text = Message::assistant("Hi there!");
        let ollama_assistant_message_text = OllamaMessage::from(&assistant_message_text);
        assert_eq!(ollama_assistant_message_text.role, "assistant");
        assert_eq!(ollama_assistant_message_text.content, "Hi there!");
        assert!(ollama_assistant_message_text.images.is_none());
        assert!(ollama_assistant_message_text.tool_calls.is_none());

        // User message mixing text and an image part.
        let parts = vec![
            crate::message::ContentPart::text("What is this?"),
            crate::message::ContentPart::image_url("https://example.com/image.jpg"),
        ];
        let user_message_image = Message::user_with_parts(parts);
        let ollama_user_message_image = OllamaMessage::from(&user_message_image);
        assert_eq!(ollama_user_message_image.role, "user");
        assert_eq!(ollama_user_message_image.content, "What is this?");
        assert_eq!(
            ollama_user_message_image.images.unwrap(),
            vec!["https://example.com/image.jpg"]
        );
        assert!(ollama_user_message_image.tool_calls.is_none());

        // Assistant message that only carries a tool call.
        let tool_call = ToolCall {
            id: "tool_call_123".to_string(),
            tool_type: "function".to_string(),
            function: Function {
                name: "get_weather".to_string(),
                arguments: "{\"location\":\"Boston\"}".to_string(),
            },
        };
        let assistant_message = Message::assistant_with_tool_calls(vec![tool_call]);

        let ollama_assistant_message = OllamaMessage::from(&assistant_message);
        assert_eq!(ollama_assistant_message.role, "assistant");
        assert_eq!(ollama_assistant_message.content, "");
        assert!(ollama_assistant_message.images.is_none());

        let tool_calls = ollama_assistant_message.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].type_field, "function");
        assert_eq!(tool_calls[0].function.name, "get_weather");
        assert!(tool_calls[0].function.arguments.get("location").is_some());
        assert_eq!(
            tool_calls[0].function.arguments.get("location").unwrap(),
            "Boston"
        );

        // Tool result message.
        let tool_message = Message::tool("tool_call_123", "{\"temperature\": \"72F\"}");
        let ollama_tool_message = OllamaMessage::from(&tool_message);
        assert_eq!(ollama_tool_message.role, "tool");
        assert_eq!(ollama_tool_message.content, "{\"temperature\": \"72F\"}");
        assert!(ollama_tool_message.images.is_none());
        assert!(ollama_tool_message.tool_calls.is_none());

        // System messages fall back to the "user" role in this conversion.
        let system_message = Message::system("You are a helpful assistant.");
        let ollama_system_message = OllamaMessage::from(&system_message);
        assert_eq!(ollama_system_message.role, "user");
        assert_eq!(
            ollama_system_message.content,
            "You are a helpful assistant."
        );
        assert!(ollama_system_message.images.is_none());
        assert!(ollama_system_message.tool_calls.is_none());
    }

    /// Drives `parse` with raw response bodies so the real conversion code is
    /// exercised (text only, tool call only, text plus tool call, API error).
    #[test]
    fn test_ollama_response_to_message_conversion() {
        let provider = OllamaProvider::new();

        // Text-only assistant reply.
        let raw_text_only = json!({
            "model": "test-model",
            "created_at": "2024-01-01T00:00:00Z",
            "message": {
                "role": "assistant",
                "content": "This is a text response."
            },
            "done": true
        })
        .to_string();

        let message = provider.parse(raw_text_only).unwrap();
        match &message {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                match content {
                    Some(Content::Text(text)) => assert_eq!(text, "This is a text response."),
                    _ => panic!("Expected text content"),
                }
                assert!(tool_calls.is_empty());
            }
            _ => panic!("Expected Assistant message"),
        }

        // Tool call with empty text: the content must be dropped.
        let raw_tool_call = json!({
            "model": "test-model",
            "created_at": "2024-01-01T00:00:00Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [{
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": { "location": "Paris" }
                    }
                }]
            },
            "done": true
        })
        .to_string();

        let message_tool_call = provider.parse(raw_tool_call).unwrap();
        match &message_tool_call {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert!(content.is_none());
                assert_eq!(tool_calls.len(), 1);
                assert!(tool_calls[0].id.starts_with("tc-"));
                assert_eq!(tool_calls[0].tool_type, "function");
                assert_eq!(tool_calls[0].function.name, "get_weather");
                assert!(tool_calls[0].function.arguments.contains("Paris"));
            }
            _ => panic!("Expected Assistant message"),
        }

        // Text together with a tool call: both must survive.
        let raw_text_and_tool = json!({
            "model": "test-model",
            "created_at": "2024-01-01T00:00:00Z",
            "message": {
                "role": "assistant",
                "content": "Sure, I can get the weather for you.",
                "tool_calls": [{
                    "type": "function",
                    "function": {
                        "name": "get_current_weather",
                        "arguments": { "city": "London" }
                    }
                }]
            },
            "done": true
        })
        .to_string();

        let message_text_and_tool = provider.parse(raw_text_and_tool).unwrap();
        match &message_text_and_tool {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                match content {
                    Some(Content::Text(text)) => {
                        assert_eq!(text, "Sure, I can get the weather for you.")
                    }
                    _ => panic!("Expected text content"),
                }

                assert_eq!(tool_calls.len(), 1);
                assert!(tool_calls[0].id.starts_with("tc-"));
                assert_eq!(tool_calls[0].function.name, "get_current_weather");
                assert!(tool_calls[0].function.arguments.contains("London"));
            }
            _ => panic!("Expected Assistant message"),
        }

        // An error body is reported as `ProviderUnavailable` with the server's text.
        let raw_error = json!({ "error": "model 'missing' not found" }).to_string();
        match provider.parse(raw_error) {
            Err(Error::ProviderUnavailable(msg)) => {
                assert_eq!(msg, "model 'missing' not found")
            }
            Err(other) => panic!("Expected ProviderUnavailable, got {:?}", other),
            Ok(_) => panic!("Expected an error for an error response body"),
        }
    }

    /// Checks the payload produced by `create_request_payload` for several
    /// argument combinations.
    #[test]
    fn test_create_request_payload() {
        let provider = OllamaProvider::new();
        let model = Ollama::Custom { name: "test-model" };

        // Simple prompt with options and an explicit system prompt.
        let messages_simple = vec![Message::user("Hello")];
        let system_prompt_simple = "You are a helpful bot.";
        let payload_simple = provider
            .create_request_payload(
                &model,
                &messages_simple,
                Some(100),
                Some(0.7),
                None,
                None,
                None,
                None,
                Some(system_prompt_simple),
            )
            .unwrap();

        assert_eq!(payload_simple.model, "test-model");
        assert_eq!(payload_simple.messages.len(), 1);
        assert_eq!(payload_simple.messages[0].role, "user");
        assert_eq!(payload_simple.messages[0].content, "Hello");
        assert_eq!(
            payload_simple.system,
            Some(system_prompt_simple.to_string())
        );
        assert!(payload_simple.tools.is_none());
        assert!(payload_simple.format.is_none());
        assert_eq!(
            payload_simple.options.as_ref().unwrap().num_predict,
            Some(100)
        );
        assert_eq!(
            payload_simple.options.as_ref().unwrap().temperature,
            Some(0.7)
        );

        // System messages in the history are merged into the system prompt.
        let messages_multi = vec![
            Message::system("System directive."),
            Message::user("First question"),
            Message::assistant("First answer"),
        ];
        let payload_multi = provider
            .create_request_payload(
                &model,
                &messages_multi,
                None,
                None,
                None,
                None,
                None,
                None,
                Some("Initial system prompt."),
            )
            .unwrap();

        assert_eq!(
            payload_multi.system,
            Some("Initial system prompt.\nSystem directive.".to_string())
        );
        assert_eq!(payload_multi.messages.len(), 2);
        assert_eq!(payload_multi.messages[0].role, "user");
        assert_eq!(payload_multi.messages[0].content, "First question");
        assert_eq!(payload_multi.messages[1].role, "assistant");
        assert_eq!(payload_multi.messages[1].content, "First answer");

        // Tools are converted and forwarded.
        let tools_info = vec![LlmToolInfo {
            name: "get_weather".to_string(),
            description: "Get current weather".to_string(),
            parameters: json!({"type": "object", "properties": {"location": {"type": "string"}}}),
        }];
        let messages_with_tools = vec![Message::user("What's the weather in London?")];
        let payload_with_tools = provider
            .create_request_payload(
                &model,
                &messages_with_tools,
                None,
                None,
                None,
                None,
                Some(&tools_info),
                None,
                None,
            )
            .unwrap();

        assert!(payload_with_tools.system.is_none());
        assert_eq!(payload_with_tools.messages.len(), 1);
        assert_eq!(payload_with_tools.messages[0].role, "user");
        let request_tools = payload_with_tools.tools.unwrap();
        assert_eq!(request_tools.len(), 1);
        assert_eq!(request_tools[0].type_field, "function");
        assert_eq!(request_tools[0].function.name, "get_weather");
        assert_eq!(
            request_tools[0].function.description,
            Some("Get current weather".to_string())
        );
        assert_eq!(
            request_tools[0].function.parameters,
            json!({"type": "object", "properties": {"location": {"type": "string"}}})
        );

        // `ToolChoice::Any` switches the request to JSON output mode.
        let messages_for_json = vec![Message::user("Give me a JSON object.")];
        let payload_json_mode = provider
            .create_request_payload(
                &model,
                &messages_for_json,
                None,
                None,
                None,
                None,
                None,
                Some(&ToolChoice::Any),
                Some("Respond in JSON format."),
            )
            .unwrap();

        assert_eq!(
            payload_json_mode.system,
            Some("Respond in JSON format.".to_string())
        );
        assert_eq!(payload_json_mode.messages.len(), 1);
        assert_eq!(payload_json_mode.format, Some("json".to_string()));
        assert!(payload_json_mode.tools.is_none());
    }
}