language_barrier_core/provider/ollama.rs

use crate::error::{Error, Result};
use crate::message::{Content, ContentPart, Function, Message, ToolCall};

use crate::Chat;
use crate::model::{ModelInfo, Ollama, OllamaModelSize};
use crate::provider::HTTPProvider;
use crate::tool::{LlmToolInfo, ToolChoice};
use async_trait::async_trait;
use reqwest::{Client, Request, Url, header};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt::Debug;
use thiserror::Error;
use tracing::{debug, error, info, instrument};

// Note the trailing slash: `Url::join` resolves relative to the last path
// segment, so joining "chat" onto ".../api" (no slash) would yield "/chat"
// rather than "/api/chat".
const DEFAULT_OLLAMA_API_BASE_URL: &str = "http://localhost:11434/api/";

#[derive(Error, Debug)]
pub enum ProviderError {
    #[error("API error: {message:?}")]
    ApiError {
        source: reqwest::Error,
        message: Option<String>,
    },

    #[error("Deserialization error: {content}")]
    DeserializationError {
        content: String,
        source: serde_json::Error,
    },

    #[error("Unexpected response ({status}): {content}")]
    UnexpectedResponse { status: u16, content: String },

    #[error("Error: {0}")]
    Other(String),
}

// Implement From<ProviderError> for Error to allow conversion with the ? operator
impl From<ProviderError> for Error {
    fn from(err: ProviderError) -> Self {
        match err {
            ProviderError::ApiError { source, message } => {
                if let Some(msg) = message {
                    Error::ProviderUnavailable(format!("Ollama API error: {}: {}", source, msg))
                } else {
                    Error::Request(source)
                }
            }
            ProviderError::DeserializationError { content: _, source } => {
                Error::Serialization(source)
            }
            ProviderError::UnexpectedResponse { status, content } => {
                Error::ProviderUnavailable(format!(
                    "Unexpected response from Ollama API ({}): {}",
                    status, content
                ))
            }
            ProviderError::Other(msg) => Error::Other(format!("Ollama provider error: {}", msg)),
        }
    }
}

#[derive(Debug, Clone)]
pub struct OllamaConfig {
    pub base_url: Url,
    // Ollama doesn't typically use API keys in headers for local instances.
    // Authentication for remote Ollama instances is usually handled elsewhere,
    // e.g. by a reverse proxy or custom headers, and is not covered here.
}

impl Default for OllamaConfig {
    fn default() -> Self {
        Self {
            base_url: Url::parse(DEFAULT_OLLAMA_API_BASE_URL)
                .expect("Failed to parse default Ollama base URL"),
        }
    }
}

#[derive(Debug, Clone)]
pub struct OllamaProvider {
    config: OllamaConfig,
    client: Client,
}

impl OllamaProvider {
    /// Creates a new OllamaProvider with the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new OllamaProvider with the given configuration.
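    ///
    /// # Examples
    ///
    /// A minimal sketch pointing the provider at a non-default Ollama
    /// endpoint (the host and port are illustrative, and the import path
    /// for this module is assumed). Note the trailing slash, which
    /// `Url::join` needs in order to keep the `/api` path segment:
    ///
    /// ```no_run
    /// use reqwest::Url;
    /// use language_barrier_core::provider::ollama::{OllamaConfig, OllamaProvider};
    ///
    /// let config = OllamaConfig {
    ///     base_url: Url::parse("http://192.168.1.50:11434/api/").unwrap(),
    /// };
    /// let provider = OllamaProvider::with_config(config);
    /// ```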
    pub fn with_config(config: OllamaConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

    /// Returns the Ollama model ID for the given model.
    fn id_for_model(&self, model: &Ollama) -> String {
        model.ollama_model_id()
    }

    #[instrument(skip(self, messages, tools))]
    #[allow(clippy::too_many_arguments)]
    fn create_request_payload(
        &self,
        model: &Ollama,
        messages: &[Message],
        max_tokens: Option<u32>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        top_k: Option<u32>,
        tools: Option<&[LlmToolInfo]>,
        tool_choice: Option<&ToolChoice>,
        system_prompt: Option<&str>,
    ) -> Result<OllamaChatRequest> {
        let mut ollama_messages: Vec<OllamaMessage> = Vec::new();
        let mut current_system_prompt = system_prompt.map(|s| s.to_string());

        for message in messages {
            // System messages are folded into the request's `system` field;
            // everything else is converted into an OllamaMessage.
            match message {
                Message::System { content, .. } => {
                    // Extract text from the system message and append it to current_system_prompt
                    if !content.is_empty() {
                        if let Some(ref mut existing_prompt) = current_system_prompt {
                            existing_prompt.push('\n');
                            existing_prompt.push_str(content);
                        } else {
                            current_system_prompt = Some(content.clone());
                        }
                    }
                }
                Message::User { .. } | Message::Assistant { .. } | Message::Tool { .. } => {
                    // Convert other message types using the From trait
                    ollama_messages.push(OllamaMessage::from(message));
                }
            }
        }

        let mut options = OllamaRequestOptions::default();
        let mut options_set = false;

        if let Some(temp) = temperature {
            options.temperature = Some(temp);
            options_set = true;
        }
        if let Some(tk) = top_k {
            options.top_k = Some(tk);
            options_set = true;
        }
        if let Some(tp) = top_p {
            options.top_p = Some(tp);
            options_set = true;
        }
        if let Some(mt) = max_tokens {
            options.num_predict = Some(mt);
            options_set = true;
        }
        // `stop` sequences could be added here if available

        let ollama_tools = tools.and_then(|tool_infos| {
            if tool_infos.is_empty() {
                None
            } else {
                Some(tool_infos.iter().map(OllamaTool::from).collect())
            }
        });

        let mut format_option: Option<String> = None;
        if let Some(tc) = tool_choice {
            match tc {
                ToolChoice::Auto => {
                    // Auto is the default behavior: the model decides whether to use
                    // tools, which matches Ollama's default, so no action is needed.
                }
                ToolChoice::Any => {
                    // "Require the model to use tools" has no direct Ollama equivalent.
                    // Setting format to "json" is the closest approximation and may
                    // encourage structured output on models that support it.
                    format_option = Some("json".to_string());
                }
                ToolChoice::None => {
                    // Force the model not to use tools. Implemented by not sending
                    // the tools array (see the final_tools logic below).
                }
                ToolChoice::Specific(_name) => {
                    // Ollama has no native "use this specific tool" option. Callers
                    // can approximate it by passing only the desired tool in `tools`;
                    // no filtering is applied here.
                }
            }
        }

        // If `tools` is None or empty, ensure the request carries no tools, even if
        // `tool_choice` was something like ToolChoice::Any (an odd but possible combination).
        let final_tools = if tools.is_none_or(|t| t.is_empty()) {
            None
        } else {
            ollama_tools
        };

        Ok(OllamaChatRequest {
            model: self.id_for_model(model),
            messages: ollama_messages,
            system: current_system_prompt,
            format: format_option,
            options: if options_set { Some(options) } else { None },
            stream: false, // For non-streaming responses
            tools: final_tools,
            keep_alive: Some("5m".to_string()), // Default keep_alive
        })
    }
}

impl Default for OllamaProvider {
    fn default() -> Self {
        Self {
            config: OllamaConfig::default(),
            client: Client::new(),
        }
    }
}

// --- Ollama API Request Structs ---

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OllamaMessage {
    pub role: String,
    pub content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub images: Option<Vec<String>>, // List of base64-encoded images
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<OllamaResponseToolCall>>, // For assistant messages that previously made tool calls
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OllamaToolFunctionDefinition {
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    pub parameters: Value, // JSON Schema
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OllamaTool {
    #[serde(rename = "type")]
    pub type_field: String, // "function"
    pub function: OllamaToolFunctionDefinition,
}

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub(crate) struct OllamaRequestOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_predict: Option<u32>, // Max tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>, // Stop sequences
                                   // Add other options like mirostat, seed, etc. as needed
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OllamaChatRequest {
    pub model: String,
    pub messages: Vec<OllamaMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>, // System prompt
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<String>, // e.g., "json"
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<OllamaRequestOptions>,
    pub stream: bool, // For this implementation, always false
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<OllamaTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<String>, // e.g., "5m"
}

// --- Ollama API Response Structs ---

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OllamaResponseFunctionCall {
    pub name: String,
    pub arguments: Value, // JSON object
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OllamaResponseToolCall {
    #[serde(rename = "type")]
    pub type_field: String, // "function"
    pub function: OllamaResponseFunctionCall,
    // The Ollama API does not provide an 'id' for tool calls in the response,
    // so we generate one when converting to our internal `ToolCall` struct.
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OllamaResponseMessage {
    pub role: String,
    pub content: String, // May be empty if tool_calls are present
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<OllamaResponseToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub images: Option<Vec<String>>, // Not typical for assistant responses, but included in case the API supplies them
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OllamaChatResponse {
    pub model: String,
    pub created_at: String, // ISO 8601 timestamp
    pub message: OllamaResponseMessage,
    pub done: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub done_reason: Option<String>, // e.g., "stop", "length", "tool_calls"

    // Optional performance and usage statistics
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_duration: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub load_duration: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_eval_count: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_eval_duration: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub eval_count: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub eval_duration: Option<u64>,
}

// The HTTPProvider and From implementations are defined below.

/// Trait for providing Ollama-specific model IDs
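///
/// # Examples
///
/// A small sketch of the model-to-ID mapping (import paths assumed):
///
/// ```no_run
/// use language_barrier_core::model::{Ollama, OllamaModelSize};
/// use language_barrier_core::provider::ollama::OllamaModelInfo;
///
/// assert_eq!(Ollama::Llava.ollama_model_id(), "llava");
/// assert_eq!(
///     Ollama::Llama3 { size: OllamaModelSize::_8B }.ollama_model_id(),
///     "llama3:8b"
/// );
/// ```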
pub trait OllamaModelInfo {
    /// Returns the Ollama model ID for this model
    fn ollama_model_id(&self) -> String;
}

impl OllamaModelInfo for Ollama {
    fn ollama_model_id(&self) -> String {
        match self {
            Self::Llama3 { size } => match size {
                OllamaModelSize::_8B => "llama3:8b",
                OllamaModelSize::_7B => "llama3",
                OllamaModelSize::_3B => "llama3:3b",
                OllamaModelSize::_1B => "llama3:1b",
            },
            Self::Llava => "llava",
            Self::Mistral { size } => match size {
                OllamaModelSize::_8B => "mistral:8b",
                OllamaModelSize::_7B => "mistral",
                OllamaModelSize::_3B => "mistral:3b",
                OllamaModelSize::_1B => "mistral:1b",
            },
            Self::Custom { name } => name,
        }
        .to_string()
    }
}

#[async_trait]
pub trait Provider<M: ModelInfo>: Send + Sync {
    /// Generate a response from the LLM provider
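    ///
    /// # Examples
    ///
    /// A hedged sketch of a non-streaming call (import paths and the
    /// surrounding async runtime are assumptions; the argument order
    /// follows the signature below):
    ///
    /// ```no_run
    /// # use language_barrier_core::error::Result;
    /// # use language_barrier_core::message::Message;
    /// # use language_barrier_core::model::{Ollama, OllamaModelSize};
    /// # use language_barrier_core::provider::ollama::{OllamaProvider, Provider};
    /// # async fn run() -> Result<()> {
    /// let provider = OllamaProvider::new();
    /// let model = Ollama::Llama3 { size: OllamaModelSize::_8B };
    /// let messages = vec![Message::user("Why is the sky blue?")];
    /// let _reply = provider
    ///     .prompt(&model, &messages, Some(256), Some(0.7), None, None, None, None, None)
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```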
    #[allow(clippy::too_many_arguments)]
    async fn prompt(
        &self,
        model: &M,
        messages: &[Message],
        max_tokens: Option<u32>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        top_k: Option<u32>,
        tools: Option<&[LlmToolInfo]>,
        tool_choice: Option<&ToolChoice>,
        system_prompt: Option<&str>,
    ) -> Result<Message>;
}

#[async_trait]
impl Provider<Ollama> for OllamaProvider {
    #[instrument(skip(self), level = "debug")]
    #[allow(clippy::too_many_arguments)]
    async fn prompt(
        &self,
        model: &Ollama,
        messages: &[Message],
        max_tokens: Option<u32>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        top_k: Option<u32>,
        tools: Option<&[LlmToolInfo]>,
        tool_choice: Option<&ToolChoice>,
        system_prompt: Option<&str>,
    ) -> Result<Message> {
        info!("Creating chat completion with Ollama model");
        debug!("Model: {:?}", model);
        debug!("Number of messages: {}", messages.len());
        debug!("System prompt provided: {}", system_prompt.is_some());
        debug!("Tools provided: {}", tools.is_some_and(|t| !t.is_empty()));
        debug!("Tool choice provided: {}", tool_choice.is_some());

        let request_url = self
            .config
            .base_url
            .join("chat")
            .map_err(Error::BaseUrlError)?;
        debug!("Request URL: {}", request_url);

        let request_payload = self.create_request_payload(
            model,
            messages,
            max_tokens,
            temperature,
            top_p,
            top_k,
            tools,
            tool_choice,
            system_prompt,
        )?;

        // Build the request directly; `accept()` operates on a `Chat`,
        // which is not available here.
        let request = self
            .client
            .post(request_url)
            .header(header::CONTENT_TYPE, "application/json")
            .header(header::ACCEPT, "application/json")
            .json(&request_payload)
            .build()
            .map_err(|e| ProviderError::ApiError {
                source: e,
                message: Some("Failed to build request".to_string()),
            })?;

        debug!("Sending request to Ollama API");
        let response = self
            .client
            .execute(request)
            .await
            .map_err(|e| ProviderError::ApiError {
                source: e,
                message: Some("Failed to execute request".to_string()),
            })?;

        debug!("Response status: {}", response.status());

        // Get the response text
        let response_text = response.text().await.map_err(|e| {
            error!("Failed to get response text: {}", e);
            ProviderError::ApiError {
                source: e,
                message: Some("Failed to get response text".to_string()),
            }
        })?;

        // Extract the message from the response
        let message = match self.parse(response_text) {
            Ok(msg) => msg,
            Err(e) => {
                error!("Failed to parse response: {:?}", e);
                return Err(e);
            }
        };

        // Return the parsed message (usage metadata is attached in parse)
        Ok(message)
    }
}

// --- From Trait Implementations ---

impl From<&LlmToolInfo> for OllamaTool {
    fn from(tool_info: &LlmToolInfo) -> Self {
        OllamaTool {
            type_field: "function".to_string(),
            function: OllamaToolFunctionDefinition {
                name: tool_info.name.clone(),
                description: Some(tool_info.description.clone()),
                parameters: tool_info.parameters.clone(),
            },
        }
    }
}

impl From<&Message> for OllamaMessage {
    fn from(message: &Message) -> Self {
        // Determine the role based on the Message variant
        let role = match message {
            Message::User { .. } => "user".to_string(),
            Message::Assistant { .. } => "assistant".to_string(),
            Message::Tool { .. } => "tool".to_string(),
            Message::System { .. } => {
                // System messages should be handled separately by create_request_payload
                // and placed in the `system` field of OllamaChatRequest. If one reaches
                // this conversion anyway, fall back to treating it as a user message.
                tracing::warn!(
                    "System message encountered in From<&Message> for OllamaMessage conversion. This should be handled by the system prompt field."
                );
                "user".to_string()
            }
        };

        let mut content_texts = Vec::new();
        let mut image_data: Vec<String> = Vec::new();
        let mut assistant_tool_calls: Vec<OllamaResponseToolCall> = Vec::new();

        // Extract content based on message type
        match message {
            Message::User { content, .. } => {
                match content {
                    Content::Text(text) => content_texts.push(text.clone()),
                    Content::Parts(parts) => {
                        for part in parts {
                            match part {
                                ContentPart::Text { text } => content_texts.push(text.clone()),
                                ContentPart::ImageUrl { image_url } => {
                                    // Pass the URL string through as-is. Note that Ollama
                                    // expects base64-encoded image data here, so remote
                                    // URLs may not be resolved by the server.
                                    image_data.push(image_url.url.clone());
                                }
                            }
                        }
                    }
                }
            }
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // Handle content if present
                if let Some(content) = content {
                    match content {
                        Content::Text(text) => content_texts.push(text.clone()),
                        Content::Parts(parts) => {
                            for part in parts {
                                if let ContentPart::Text { text } = part {
                                    content_texts.push(text.clone());
                                }
                                // Images in assistant messages are ignored, as Ollama doesn't support them in responses
                            }
                        }
                    }
                }

                // Handle tool calls
                for tool_call in tool_calls {
                    assistant_tool_calls.push(OllamaResponseToolCall {
                        type_field: "function".to_string(),
                        function: OllamaResponseFunctionCall {
                            name: tool_call.function.name.clone(),
                            arguments: serde_json::from_str(&tool_call.function.arguments)
                                .unwrap_or(serde_json::Value::Null),
                        },
                    });
                }
            }
            Message::Tool { content, .. } => {
                // For tool messages, just use the content directly
                content_texts.push(content.clone());
            }
            Message::System { content, .. } => {
                // Normally filtered out by create_request_payload; see the role handling above
                content_texts.push(content.clone());
            }
        }

        let final_content = content_texts.join("\n");

        OllamaMessage {
            role,
            content: final_content,
            images: if image_data.is_empty() {
                None
            } else {
                Some(image_data)
            },
            tool_calls: if assistant_tool_calls.is_empty() {
                None
            } else {
                Some(assistant_tool_calls)
            },
        }
    }
}

#[async_trait]
impl HTTPProvider<Ollama> for OllamaProvider {
    #[instrument(skip(self, model, chat), level = "debug")]
    fn accept(&self, model: Ollama, chat: &Chat) -> Result<Request> {
        info!("Creating HTTP request for Ollama model: {:?}", model);
        debug!("Number of messages in chat: {}", chat.history.len());

        let url = self.config.base_url.join("chat").map_err(|e| {
            error!("Failed to join chat URL path to base URL: {}", e);
            crate::error::Error::Other(format!("Failed to join chat URL path to base URL: {}", e))
        })?;
        debug!("Request URL: {}", url);

        // Prepare the messages for the request payload
        let ollama_messages: Vec<_> = chat
            .history
            .iter()
            .filter(|msg| !matches!(msg, Message::System { .. })) // System messages are handled separately
            .map(OllamaMessage::from)
            .collect();
        debug!(
            "Converted {} messages for Ollama request",
            ollama_messages.len()
        );

        // Extract system prompt
        let system_prompt = if chat.system_prompt.is_empty() {
            None
        } else {
            debug!(
                "Using system prompt from chat: {} chars",
                chat.system_prompt.len()
            );
            Some(chat.system_prompt.clone())
        };

        // Handle tool configuration
        let tools = if let Some(ref tools) = chat.tools {
            if tools.is_empty() {
                debug!("No tools defined in chat");
                None
            } else {
                debug!("Converting {} tools for Ollama request", tools.len());
                Some(tools.iter().map(OllamaTool::from).collect::<Vec<_>>())
            }
        } else {
            None
        };

        // Handle format option based on tool_choice
        let format = match chat.tool_choice {
            Some(ToolChoice::Any) => {
                debug!(
                    "Using ToolChoice::Any - setting json format to encourage structured outputs"
                );
                Some("json".to_string())
            }
            Some(ToolChoice::Auto) => {
                debug!("Using ToolChoice::Auto - letting the model decide");
                None
            }
            Some(ToolChoice::None) => {
                debug!("Using ToolChoice::None - tools will not be used");
                None
            }
            Some(ToolChoice::Specific(_)) => {
                // Ollama has no direct support for choosing a specific tool;
                // the tools are sent unfiltered.
                debug!("Using specific tool choice - no Ollama equivalent, tools sent unfiltered");
                None
            }
            None => None,
        };

        // Create options
        let options = Some(OllamaRequestOptions {
            temperature: None, // TODO: Get from chat config when added
            top_k: None,       // TODO: Get from chat config when added
            top_p: None,       // TODO: Get from chat config when added
            num_predict: Some(chat.max_output_tokens as u32),
            stop: None, // TODO: Get from chat config when added
        });

        // Create the request payload
        let payload = OllamaChatRequest {
            model: model.ollama_model_id(),
            messages: ollama_messages,
            system: system_prompt,
            format,
            options,
            stream: false, // We don't use streaming in this implementation
            tools,
            keep_alive: Some("5m".to_string()),
        };

        debug!("Created Ollama request payload");

        // Build the HTTP request with JSON payload
        let request = self
            .client
            .post(url)
            .header(header::CONTENT_TYPE, "application/json")
            .header(header::ACCEPT, "application/json")
            .json(&payload)
            .build()
            .map_err(|e| {
                error!("Failed to build request: {}", e);
                crate::error::Error::Request(e)
            })?;

        debug!("Built Ollama HTTP request successfully");
        Ok(request)
    }

    #[instrument(skip(self, raw_response_text), level = "debug")]
    fn parse(&self, raw_response_text: String) -> Result<Message> {
        info!("Parsing response from Ollama API");
        debug!("Response text length: {}", raw_response_text.len());

        // First check if it's an error response
        if raw_response_text.contains("\"error\"") {
            let error_response: serde_json::Value = serde_json::from_str(&raw_response_text)
                .map_err(|e| {
                    error!("Failed to parse error response: {}", e);
                    Error::Serialization(e)
                })?;

            if let Some(error) = error_response.get("error") {
                let error_msg = error.as_str().unwrap_or("Unknown Ollama error");
                error!("Ollama API returned an error: {}", error_msg);
                return Err(Error::ProviderUnavailable(error_msg.to_string()));
            }
        }

        // Parse the response to our internal format
        let ollama_response: OllamaChatResponse = serde_json::from_str(&raw_response_text)
            .map_err(|e| {
                error!("Failed to deserialize Ollama response: {}", e);
                Error::Serialization(e)
            })?;

        debug!("Response deserialized successfully");
        debug!("Model: {}", ollama_response.model);
        debug!("Done reason: {:?}", ollama_response.done_reason);

        // Convert the response to our Message format based on role
        let response_role = ollama_response.message.role.as_str();
        let response_content = ollama_response.message.content.clone();

        let message = match response_role {
            "assistant" => {
                // For assistant messages, handle text content and tool calls

                // First, prepare tool calls if present
                let mut tool_calls = Vec::new();
                if let Some(tool_calls_data) = ollama_response.message.tool_calls {
                    for tool_call in tool_calls_data {
                        // Ollama does not return tool-call IDs, so generate one.
                        // A UUID (uuid::Uuid::new_v4().to_string()) would be
                        // preferable; a timestamp-based string is used for now.
                        let tool_call_id = format!(
                            "tc-{}",
                            std::time::SystemTime::now()
                                .duration_since(std::time::UNIX_EPOCH)
                                .unwrap_or_default()
                                .as_micros()
                        );

                        tool_calls.push(ToolCall {
                            id: tool_call_id,
                            tool_type: "function".to_string(),
                            function: Function {
                                name: tool_call.function.name,
                                arguments: serde_json::to_string(&tool_call.function.arguments)
                                    .unwrap_or_default(),
                            },
                        });
                    }
                }

                // Create content from the response text if present
                let content = if response_content.is_empty() && !tool_calls.is_empty() {
                    None
                } else {
                    // Use the Text content type for simplicity
                    Some(Content::Text(response_content))
                };

                Message::Assistant {
                    content,
                    tool_calls,
                    metadata: HashMap::new(),
                }
            }
            "user" => Message::User {
                content: Content::Text(response_content),
                name: None,
                metadata: HashMap::new(),
            },
            "system" => Message::System {
                content: response_content,
                metadata: HashMap::new(),
            },
            "tool" => Message::Tool {
                tool_call_id: "response-tool-call".to_string(), // This shouldn't happen in a response
                content: response_content,
                metadata: HashMap::new(),
            },
            _ => {
                // Default to assistant if the role is unknown
                error!(
                    "Unknown message role in Ollama response: {}",
                    ollama_response.message.role
                );
                Message::Assistant {
                    content: Some(Content::Text(response_content)),
                    tool_calls: Vec::new(),
                    metadata: HashMap::new(),
                }
            }
        };

        // Add usage metadata if available
        let message_with_meta = if let Some(tokens) = ollama_response.prompt_eval_count {
            message.with_metadata("input_tokens", serde_json::json!(tokens))
        } else {
            message
        };

        let message_with_meta = if let Some(tokens) = ollama_response.eval_count {
            message_with_meta.with_metadata("output_tokens", serde_json::json!(tokens))
        } else {
            message_with_meta
        };

        info!("Successfully parsed Ollama response");
        Ok(message_with_meta)
    }
}

// From implementations for request/response structs are defined above.

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_message_to_ollama_conversion() {
        // 1. User message with simple text
        let user_message_text = Message::user("Hello, Ollama!");
        let ollama_user_message_text = OllamaMessage::from(&user_message_text);
        assert_eq!(ollama_user_message_text.role, "user");
        assert_eq!(ollama_user_message_text.content, "Hello, Ollama!");
        assert!(ollama_user_message_text.images.is_none());
        assert!(ollama_user_message_text.tool_calls.is_none());

        // 2. Assistant message with simple text
        let assistant_message_text = Message::assistant("Hi there!");
        let ollama_assistant_message_text = OllamaMessage::from(&assistant_message_text);
        assert_eq!(ollama_assistant_message_text.role, "assistant");
        assert_eq!(ollama_assistant_message_text.content, "Hi there!");
        assert!(ollama_assistant_message_text.images.is_none());
        assert!(ollama_assistant_message_text.tool_calls.is_none());

        // 3. User message with multimodal content
        let parts = vec![
            crate::message::ContentPart::text("What is this?"),
            crate::message::ContentPart::image_url("https://example.com/image.jpg"),
        ];
        let user_message_image = Message::user_with_parts(parts);
        let ollama_user_message_image = OllamaMessage::from(&user_message_image);
        assert_eq!(ollama_user_message_image.role, "user");
        assert_eq!(ollama_user_message_image.content, "What is this?");
        assert_eq!(
            ollama_user_message_image.images.unwrap(),
            vec!["https://example.com/image.jpg"]
        );
        assert!(ollama_user_message_image.tool_calls.is_none());

        // 4. Assistant message with tool calls
        let tool_call = ToolCall {
            id: "tool_call_123".to_string(),
            tool_type: "function".to_string(),
            function: Function {
                name: "get_weather".to_string(),
                arguments: "{\"location\":\"Boston\"}".to_string(),
            },
        };
        let assistant_message = Message::assistant_with_tool_calls(vec![tool_call]);

        let ollama_assistant_message = OllamaMessage::from(&assistant_message);
        assert_eq!(ollama_assistant_message.role, "assistant");
        assert_eq!(ollama_assistant_message.content, ""); // Content is empty since we only have tool calls
        assert!(ollama_assistant_message.images.is_none());

        let tool_calls = ollama_assistant_message.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].type_field, "function");
        assert_eq!(tool_calls[0].function.name, "get_weather");
        // The arguments string was parsed into a serde_json::Value, so compare fields structurally
        assert!(tool_calls[0].function.arguments.get("location").is_some());
        assert_eq!(
            tool_calls[0].function.arguments.get("location").unwrap(),
            "Boston"
        );

        // 5. Tool message (response from a tool)
        let tool_message = Message::tool("tool_call_123", "{\"temperature\": \"72F\"}");
        let ollama_tool_message = OllamaMessage::from(&tool_message);
        assert_eq!(ollama_tool_message.role, "tool");
        assert_eq!(ollama_tool_message.content, "{\"temperature\": \"72F\"}");
        assert!(ollama_tool_message.images.is_none());
        assert!(ollama_tool_message.tool_calls.is_none());

        // 6. System message (special handling in From trait - logs warning, becomes user)
        // This tests the direct From conversion. `create_request_payload` handles system messages differently.
        let system_message = Message::system("You are a helpful assistant.");
        let ollama_system_message = OllamaMessage::from(&system_message);
        // The From<&Message> for OllamaMessage trait converts System to User with a warning.
        assert_eq!(ollama_system_message.role, "user");
        assert_eq!(
            ollama_system_message.content,
            "You are a helpful assistant."
        );
        assert!(ollama_system_message.images.is_none());
        assert!(ollama_system_message.tool_calls.is_none());
    }
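
    // A hedged sketch verifying the wire format: `skip_serializing_if` should
    // drop every `None` field, so the expected JSON below mirrors the struct
    // definitions above rather than any captured server traffic.
    #[test]
    fn test_request_serialization_omits_none_fields() {
        let request = OllamaChatRequest {
            model: "llama3".to_string(),
            messages: vec![OllamaMessage {
                role: "user".to_string(),
                content: "Hi".to_string(),
                images: None,
                tool_calls: None,
            }],
            system: None,
            format: None,
            options: None,
            stream: false,
            tools: None,
            keep_alive: None,
        };

        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(
            value,
            json!({
                "model": "llama3",
                "messages": [{"role": "user", "content": "Hi"}],
                "stream": false
            })
        );
    }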

    #[test]
    fn test_ollama_response_to_message_conversion() {
        // 1. Assistant response with text only
        let ollama_msg_text_only = OllamaResponseMessage {
            role: "assistant".to_string(),
            content: "This is a text response.".to_string(),
            tool_calls: None,
            images: None,
        };

        // Convert using the same approach as our parse implementation
        let message = match ollama_msg_text_only.role.as_str() {
            "assistant" => Message::Assistant {
                content: Some(Content::Text(ollama_msg_text_only.content)),
                tool_calls: Vec::new(),
                metadata: HashMap::new(),
            },
            _ => panic!("Unexpected role in test"),
        };

        // Verify
        match &message {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert!(content.is_some());
                if let Some(Content::Text(text)) = content {
                    assert_eq!(text, "This is a text response.");
                } else {
                    panic!("Expected text content");
                }
                assert!(tool_calls.is_empty());
            }
            _ => panic!("Expected Assistant message"),
        }

        // 2. Assistant response with a single tool call
        let _ollama_msg_tool_call = OllamaResponseMessage {
            role: "assistant".to_string(),
            content: "".to_string(), // Content can be empty if there are tool calls
            tool_calls: Some(vec![OllamaResponseToolCall {
                type_field: "function".to_string(),
                function: OllamaResponseFunctionCall {
                    name: "get_weather".to_string(),
                    arguments: json!({ "location": "Paris" }),
                },
            }]),
            images: None,
        };

        // Convert using our approach
        let tool_call_id = "generated-id-for-test";
        let message_tool_call = Message::Assistant {
            content: None, // Empty content for a tool-only response
            tool_calls: vec![ToolCall {
                id: tool_call_id.to_string(),
                tool_type: "function".to_string(),
                function: Function {
                    name: "get_weather".to_string(),
                    arguments: r#"{"location":"Paris"}"#.to_string(),
                },
            }],
            metadata: HashMap::new(),
        };

        // Verify
        match &message_tool_call {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert!(content.is_none()); // No content for a tool-only response
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, tool_call_id);
                assert_eq!(tool_calls[0].function.name, "get_weather");
                // The arguments are a JSON string here, so key order and whitespace
                // make exact string comparison unreliable; check the substance instead
                assert!(tool_calls[0].function.arguments.contains("Paris"));
            }
            _ => panic!("Expected Assistant message"),
        }

        // 3. Assistant response with text AND a tool call
        let _ollama_msg_text_and_tool = OllamaResponseMessage {
            role: "assistant".to_string(),
            content: "Sure, I can get the weather for you.".to_string(),
            tool_calls: Some(vec![OllamaResponseToolCall {
                type_field: "function".to_string(),
                function: OllamaResponseFunctionCall {
                    name: "get_current_weather".to_string(),
                    arguments: json!({ "city": "London" }),
                },
            }]),
            images: None,
        };

        // Convert using our approach - text + tool call
        let tool_call_id = "generated-id-for-test";
        let message_text_and_tool = Message::Assistant {
            content: Some(Content::Text(
                "Sure, I can get the weather for you.".to_string(),
            )),
            tool_calls: vec![ToolCall {
                id: tool_call_id.to_string(),
                tool_type: "function".to_string(),
                function: Function {
                    name: "get_current_weather".to_string(),
                    arguments: r#"{"city":"London"}"#.to_string(),
                },
            }],
            metadata: HashMap::new(),
        };

        // Verify
        match &message_text_and_tool {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // Verify content
                assert!(content.is_some());
                if let Some(Content::Text(text)) = content {
                    assert_eq!(text, "Sure, I can get the weather for you.");
                } else {
                    panic!("Expected text content");
                }

                // Verify tool calls
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, tool_call_id);
                assert_eq!(tool_calls[0].function.name, "get_current_weather");
                assert!(tool_calls[0].function.arguments.contains("London"));
            }
            _ => panic!("Expected Assistant message"),
        }
    }

    #[test]
    fn test_create_request_payload() {
        let provider = OllamaProvider::new();
        let model = Ollama::Custom { name: "test-model" };

        // Scenario 1: Simple request with user message and system prompt
        let messages_simple = vec![Message::user("Hello")];
        let system_prompt_simple = "You are a helpful bot.";
        let payload_simple = provider
            .create_request_payload(
                &model,
                &messages_simple,
                Some(100),
                Some(0.7),
                None,
                None,
                None,
                None,
                Some(system_prompt_simple),
            )
            .unwrap();

        assert_eq!(payload_simple.model, "test-model");
        assert_eq!(payload_simple.messages.len(), 1);
        assert_eq!(payload_simple.messages[0].role, "user");
        assert_eq!(payload_simple.messages[0].content, "Hello");
        assert_eq!(
            payload_simple.system,
            Some(system_prompt_simple.to_string())
        );
        assert!(payload_simple.tools.is_none());
        assert!(payload_simple.format.is_none());
        assert_eq!(
            payload_simple.options.as_ref().unwrap().num_predict,
            Some(100)
        );
        assert_eq!(
            payload_simple.options.as_ref().unwrap().temperature,
            Some(0.7)
        );

        // Scenario 2: Multi-message request (user, assistant) and a system message part
        let messages_multi = vec![
            Message::system("System directive."),
            Message::user("First question"),
            Message::assistant("First answer"),
        ];
        let payload_multi = provider
            .create_request_payload(
                &model,
                &messages_multi,
                None,
                None,
                None,
                None,
                None,
                None,
                Some("Initial system prompt."),
            )
            .unwrap();

        assert_eq!(
            payload_multi.system,
            Some("Initial system prompt.\nSystem directive.".to_string())
        );
        assert_eq!(payload_multi.messages.len(), 2);
        assert_eq!(payload_multi.messages[0].role, "user");
        assert_eq!(payload_multi.messages[0].content, "First question");
        assert_eq!(payload_multi.messages[1].role, "assistant");
        assert_eq!(payload_multi.messages[1].content, "First answer");

        // Scenario 3: Request with tools
        let tools_info = vec![LlmToolInfo {
            name: "get_weather".to_string(),
            description: "Get current weather".to_string(),
            parameters: json!({"type": "object", "properties": {"location": {"type": "string"}}}),
        }];
        let messages_with_tools = vec![Message::user("What's the weather in London?")];
        let payload_with_tools = provider
            .create_request_payload(
                &model,
                &messages_with_tools,
                None,
                None,
                None,
                None,
                Some(&tools_info),
                None, // ToolChoice::Auto is the default when providing tools
                None,
            )
            .unwrap();

        assert!(payload_with_tools.system.is_none());
        assert_eq!(payload_with_tools.messages.len(), 1);
        assert_eq!(payload_with_tools.messages[0].role, "user");
        let request_tools = payload_with_tools.tools.unwrap();
        assert_eq!(request_tools.len(), 1);
        assert_eq!(request_tools[0].type_field, "function");
        assert_eq!(request_tools[0].function.name, "get_weather");
        assert_eq!(
            request_tools[0].function.description,
            Some("Get current weather".to_string())
        );
        assert_eq!(
            request_tools[0].function.parameters,
            json!({"type": "object", "properties": {"location": {"type": "string"}}})
        );

        // Scenario 4: Request with ToolChoice::Any (implies format: "json" for Ollama)
        let messages_for_json = vec![Message::user("Give me a JSON object.")];
        let payload_json_mode = provider
            .create_request_payload(
                &model,
                &messages_for_json,
                None,
                None,
                None,
                None,
                None, // No specific tools, but setting ToolChoice::Any
                Some(&ToolChoice::Any),
                Some("Respond in JSON format."),
            )
            .unwrap();

        assert_eq!(
            payload_json_mode.system,
            Some("Respond in JSON format.".to_string())
        );
        assert_eq!(payload_json_mode.messages.len(), 1);
        assert_eq!(payload_json_mode.format, Some("json".to_string()));
        assert!(payload_json_mode.tools.is_none()); // Any without tools just sets the format for Ollama
    }
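
    // A hedged end-to-end sketch of `parse`. The body below is modeled on the
    // documented /api/chat response shape (field values are illustrative), and
    // the metadata assertions assume `with_metadata` stores the given key in
    // the message's metadata map, as `parse` relies on above.
    #[test]
    fn test_parse_chat_response() {
        let provider = OllamaProvider::new();
        let raw = r#"{
            "model": "llama3",
            "created_at": "2024-01-01T00:00:00Z",
            "message": {"role": "assistant", "content": "Hello!"},
            "done": true,
            "done_reason": "stop",
            "prompt_eval_count": 10,
            "eval_count": 5
        }"#;

        let message = provider.parse(raw.to_string()).unwrap();
        match message {
            Message::Assistant {
                content,
                tool_calls,
                metadata,
            } => {
                if let Some(Content::Text(text)) = content {
                    assert_eq!(text, "Hello!");
                } else {
                    panic!("Expected text content");
                }
                assert!(tool_calls.is_empty());
                // Assumes `with_metadata` inserts under the given key
                assert_eq!(metadata.get("input_tokens"), Some(&json!(10)));
                assert_eq!(metadata.get("output_tokens"), Some(&json!(5)));
            }
            _ => panic!("Expected Assistant message"),
        }

        // An Ollama-style error body should surface as a provider error.
        let err = provider.parse(r#"{"error": "model not found"}"#.to_string());
        assert!(err.is_err());
    }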
}