rig/providers/
ollama.rs

1//! Ollama API client and Rig integration
2//!
3//! # Example
4//! ```rust
5//! use rig::providers::ollama;
6//!
7//! // Create a new Ollama client (defaults to http://localhost:11434)
8//! let client = ollama::Client::new();
9//!
10//! // Create a completion model interface using, for example, the "llama3.2" model
11//! let comp_model = client.completion_model("llama3.2");
12//!
13//! let req = rig::completion::CompletionRequest {
14//!     preamble: Some("You are now a humorous AI assistant.".to_owned()),
15//!     chat_history: vec![],  // internal messages (if any)
//!     prompt: rig::message::Message::User {
//!         content: rig::one_or_many::OneOrMany::one(rig::message::UserContent::text("Please tell me why the sky is blue.")),
//!     },
20//!     temperature: 0.7,
21//!     additional_params: None,
22//!     tools: vec![],
23//! };
24//!
25//! let response = comp_model.completion(req).await.unwrap();
26//! println!("Ollama completion response: {:?}", response.choice);
27//!
28//! // Create an embedding interface using the "all-minilm" model
29//! let emb_model = ollama::Client::new().embedding_model("all-minilm");
30//! let docs = vec![
31//!     "Why is the sky blue?".to_owned(),
32//!     "Why is the grass green?".to_owned()
33//! ];
34//! let embeddings = emb_model.embed_texts(docs).await.unwrap();
35//! println!("Embedding response: {:?}", embeddings);
36//!
37//! // Also create an agent and extractor if needed
38//! let agent = client.agent("llama3.2");
39//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
40//! ```
41use crate::client::{ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient};
42use crate::completion::{GetTokenUsage, Usage};
43use crate::json_utils::merge_inplace;
44use crate::streaming::RawStreamingChoice;
45use crate::{
46    Embed, OneOrMany,
47    completion::{self, CompletionError, CompletionRequest},
48    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
49    impl_conversion_traits, json_utils, message,
50    message::{ImageDetail, Text},
51    streaming,
52};
53use async_stream::stream;
54use futures::StreamExt;
55use reqwest;
56use serde::{Deserialize, Serialize};
57use serde_json::{Value, json};
58use std::{convert::TryFrom, str::FromStr};
59use url::Url;
60// ---------- Main Client ----------
61
62const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
63
64pub struct ClientBuilder<'a> {
65    base_url: &'a str,
66    http_client: Option<reqwest::Client>,
67}
68
69impl<'a> ClientBuilder<'a> {
70    #[allow(clippy::new_without_default)]
71    pub fn new() -> Self {
72        Self {
73            base_url: OLLAMA_API_BASE_URL,
74            http_client: None,
75        }
76    }
77
78    pub fn base_url(mut self, base_url: &'a str) -> Self {
79        self.base_url = base_url;
80        self
81    }
82
83    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
84        self.http_client = Some(client);
85        self
86    }
87
88    pub fn build(self) -> Result<Client, ClientBuilderError> {
89        let http_client = if let Some(http_client) = self.http_client {
90            http_client
91        } else {
92            reqwest::Client::builder().build()?
93        };
94
95        Ok(Client {
96            base_url: Url::parse(self.base_url)
97                .map_err(|_| ClientBuilderError::InvalidProperty("base_url"))?,
98            http_client,
99        })
100    }
101}
102
103#[derive(Clone, Debug)]
104pub struct Client {
105    base_url: Url,
106    http_client: reqwest::Client,
107}
108
109impl Default for Client {
110    fn default() -> Self {
111        Self::new()
112    }
113}
114
115impl Client {
116    /// Create a new Ollama client builder.
117    ///
118    /// # Example
119    /// ```
120    /// use rig::providers::ollama::{ClientBuilder, self};
121    ///
122    /// // Initialize the Ollama client
123    /// let client = Client::builder()
124    ///    .build()
125    /// ```
126    pub fn builder() -> ClientBuilder<'static> {
127        ClientBuilder::new()
128    }
129
130    /// Create a new Ollama client. For more control, use the `builder` method.
131    ///
132    /// # Panics
133    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
134    pub fn new() -> Self {
135        Self::builder().build().expect("Ollama client should build")
136    }
137
138    pub fn post(&self, path: &str) -> Result<reqwest::RequestBuilder, url::ParseError> {
139        let url = self.base_url.join(path)?;
140        Ok(self.http_client.post(url))
141    }
142}
143
144impl ProviderClient for Client {
145    fn from_env() -> Self
146    where
147        Self: Sized,
148    {
149        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
150        Self::builder().base_url(&api_base).build().unwrap()
151    }
152
153    fn from_val(input: crate::client::ProviderValue) -> Self {
154        let crate::client::ProviderValue::Simple(_) = input else {
155            panic!("Incorrect provider value type")
156        };
157
158        Self::new()
159    }
160}
161
162impl CompletionClient for Client {
163    type CompletionModel = CompletionModel;
164
165    fn completion_model(&self, model: &str) -> CompletionModel {
166        CompletionModel::new(self.clone(), model)
167    }
168}
169
170impl EmbeddingsClient for Client {
171    type EmbeddingModel = EmbeddingModel;
172    fn embedding_model(&self, model: &str) -> EmbeddingModel {
173        EmbeddingModel::new(self.clone(), model, 0)
174    }
175    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
176        EmbeddingModel::new(self.clone(), model, ndims)
177    }
178    fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
179        EmbeddingsBuilder::new(self.embedding_model(model))
180    }
181}
182
// Conversion-trait impls for the capabilities listed below, generated by `impl_conversion_traits!`.
// NOTE(review): presumably these declare transcription, image generation and audio generation
// as not provided by the Ollama client — confirm against the macro definition.
impl_conversion_traits!(
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
188
189// ---------- API Error and Response Structures ----------
190
191#[derive(Debug, Deserialize)]
192struct ApiErrorResponse {
193    message: String,
194}
195
196#[derive(Debug, Deserialize)]
197#[serde(untagged)]
198enum ApiResponse<T> {
199    Ok(T),
200    Err(ApiErrorResponse),
201}
202
203// ---------- Embedding API ----------
204
205pub const ALL_MINILM: &str = "all-minilm";
206pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
207
208#[derive(Debug, Serialize, Deserialize)]
209pub struct EmbeddingResponse {
210    pub model: String,
211    pub embeddings: Vec<Vec<f64>>,
212    #[serde(default)]
213    pub total_duration: Option<u64>,
214    #[serde(default)]
215    pub load_duration: Option<u64>,
216    #[serde(default)]
217    pub prompt_eval_count: Option<u64>,
218}
219
220impl From<ApiErrorResponse> for EmbeddingError {
221    fn from(err: ApiErrorResponse) -> Self {
222        EmbeddingError::ProviderError(err.message)
223    }
224}
225
226impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
227    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
228        match value {
229            ApiResponse::Ok(response) => Ok(response),
230            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
231        }
232    }
233}
234
235// ---------- Embedding Model ----------
236
237#[derive(Clone)]
238pub struct EmbeddingModel {
239    client: Client,
240    pub model: String,
241    ndims: usize,
242}
243
244impl EmbeddingModel {
245    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
246        Self {
247            client,
248            model: model.to_owned(),
249            ndims,
250        }
251    }
252}
253
254impl embeddings::EmbeddingModel for EmbeddingModel {
255    const MAX_DOCUMENTS: usize = 1024;
256    fn ndims(&self) -> usize {
257        self.ndims
258    }
259    #[cfg_attr(feature = "worker", worker::send)]
260    async fn embed_texts(
261        &self,
262        documents: impl IntoIterator<Item = String>,
263    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
264        let docs: Vec<String> = documents.into_iter().collect();
265        let payload = json!({
266            "model": self.model,
267            "input": docs,
268        });
269        let response = self
270            .client
271            .post("api/embed")?
272            .json(&payload)
273            .send()
274            .await
275            .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
276        if response.status().is_success() {
277            let api_resp: EmbeddingResponse = response
278                .json()
279                .await
280                .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
281            if api_resp.embeddings.len() != docs.len() {
282                return Err(EmbeddingError::ResponseError(
283                    "Number of returned embeddings does not match input".into(),
284                ));
285            }
286            Ok(api_resp
287                .embeddings
288                .into_iter()
289                .zip(docs.into_iter())
290                .map(|(vec, document)| embeddings::Embedding { document, vec })
291                .collect())
292        } else {
293            Err(EmbeddingError::ProviderError(response.text().await?))
294        }
295    }
296}
297
298// ---------- Completion API ----------
299
300pub const LLAMA3_2: &str = "llama3.2";
301pub const LLAVA: &str = "llava";
302pub const MISTRAL: &str = "mistral";
303
304#[derive(Debug, Serialize, Deserialize)]
305pub struct CompletionResponse {
306    pub model: String,
307    pub created_at: String,
308    pub message: Message,
309    pub done: bool,
310    #[serde(default)]
311    pub done_reason: Option<String>,
312    #[serde(default)]
313    pub total_duration: Option<u64>,
314    #[serde(default)]
315    pub load_duration: Option<u64>,
316    #[serde(default)]
317    pub prompt_eval_count: Option<u64>,
318    #[serde(default)]
319    pub prompt_eval_duration: Option<u64>,
320    #[serde(default)]
321    pub eval_count: Option<u64>,
322    #[serde(default)]
323    pub eval_duration: Option<u64>,
324}
325impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
326    type Error = CompletionError;
327    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
328        match resp.message {
329            // Process only if an assistant message is present.
330            Message::Assistant {
331                content,
332                thinking,
333                tool_calls,
334                ..
335            } => {
336                let mut assistant_contents = Vec::new();
337                // Add the assistant's text content if any.
338                if !content.is_empty() {
339                    assistant_contents.push(completion::AssistantContent::text(&content));
340                }
341                // Process tool_calls following Ollama's chat response definition.
342                // Each ToolCall has an id, a type, and a function field.
343                for tc in tool_calls.iter() {
344                    assistant_contents.push(completion::AssistantContent::tool_call(
345                        tc.function.name.clone(),
346                        tc.function.name.clone(),
347                        tc.function.arguments.clone(),
348                    ));
349                }
350                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
351                    CompletionError::ResponseError("No content provided".to_owned())
352                })?;
353                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
354                let completion_tokens = resp.eval_count.unwrap_or(0);
355
356                let raw_response = CompletionResponse {
357                    model: resp.model,
358                    created_at: resp.created_at,
359                    done: resp.done,
360                    done_reason: resp.done_reason,
361                    total_duration: resp.total_duration,
362                    load_duration: resp.load_duration,
363                    prompt_eval_count: resp.prompt_eval_count,
364                    prompt_eval_duration: resp.prompt_eval_duration,
365                    eval_count: resp.eval_count,
366                    eval_duration: resp.eval_duration,
367                    message: Message::Assistant {
368                        content,
369                        thinking,
370                        images: None,
371                        name: None,
372                        tool_calls,
373                    },
374                };
375
376                Ok(completion::CompletionResponse {
377                    choice,
378                    usage: Usage {
379                        input_tokens: prompt_tokens,
380                        output_tokens: completion_tokens,
381                        total_tokens: prompt_tokens + completion_tokens,
382                    },
383                    raw_response,
384                })
385            }
386            _ => Err(CompletionError::ResponseError(
387                "Chat response does not include an assistant message".into(),
388            )),
389        }
390    }
391}
392
393// ---------- Completion Model ----------
394
395#[derive(Clone)]
396pub struct CompletionModel {
397    client: Client,
398    pub model: String,
399}
400
401impl CompletionModel {
402    pub fn new(client: Client, model: &str) -> Self {
403        Self {
404            client,
405            model: model.to_owned(),
406        }
407    }
408
409    fn create_completion_request(
410        &self,
411        completion_request: CompletionRequest,
412    ) -> Result<Value, CompletionError> {
413        // Build up the order of messages (context, chat_history)
414        let mut partial_history = vec![];
415        if let Some(docs) = completion_request.normalized_documents() {
416            partial_history.push(docs);
417        }
418        partial_history.extend(completion_request.chat_history);
419
420        // Initialize full history with preamble (or empty if non-existent)
421        let mut full_history: Vec<Message> = completion_request
422            .preamble
423            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
424
425        // Convert and extend the rest of the history
426        full_history.extend(
427            partial_history
428                .into_iter()
429                .map(|msg| msg.try_into())
430                .collect::<Result<Vec<Vec<Message>>, _>>()?
431                .into_iter()
432                .flatten()
433                .collect::<Vec<Message>>(),
434        );
435
436        // Convert internal prompt into a provider Message
437        let options = if let Some(extra) = completion_request.additional_params {
438            json_utils::merge(
439                json!({ "temperature": completion_request.temperature }),
440                extra,
441            )
442        } else {
443            json!({ "temperature": completion_request.temperature })
444        };
445
446        let mut request_payload = json!({
447            "model": self.model,
448            "messages": full_history,
449            "options": options,
450            "stream": false,
451        });
452        if !completion_request.tools.is_empty() {
453            request_payload["tools"] = json!(
454                completion_request
455                    .tools
456                    .into_iter()
457                    .map(|tool| tool.into())
458                    .collect::<Vec<ToolDefinition>>()
459            );
460        }
461
462        tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
463
464        Ok(request_payload)
465    }
466}
467
468// ---------- CompletionModel Implementation ----------
469
470#[derive(Clone, Serialize, Deserialize, Debug)]
471pub struct StreamingCompletionResponse {
472    pub done_reason: Option<String>,
473    pub total_duration: Option<u64>,
474    pub load_duration: Option<u64>,
475    pub prompt_eval_count: Option<u64>,
476    pub prompt_eval_duration: Option<u64>,
477    pub eval_count: Option<u64>,
478    pub eval_duration: Option<u64>,
479}
480
481impl GetTokenUsage for StreamingCompletionResponse {
482    fn token_usage(&self) -> Option<crate::completion::Usage> {
483        let mut usage = crate::completion::Usage::new();
484        let input_tokens = self.prompt_eval_count.unwrap_or_default();
485        let output_tokens = self.eval_count.unwrap_or_default();
486        usage.input_tokens = input_tokens;
487        usage.output_tokens = output_tokens;
488        usage.total_tokens = input_tokens + output_tokens;
489
490        Some(usage)
491    }
492}
493
494impl completion::CompletionModel for CompletionModel {
495    type Response = CompletionResponse;
496    type StreamingResponse = StreamingCompletionResponse;
497
498    #[cfg_attr(feature = "worker", worker::send)]
499    async fn completion(
500        &self,
501        completion_request: CompletionRequest,
502    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
503        let request_payload = self.create_completion_request(completion_request)?;
504
505        let response = self
506            .client
507            .post("api/chat")?
508            .json(&request_payload)
509            .send()
510            .await
511            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
512        if response.status().is_success() {
513            let text = response
514                .text()
515                .await
516                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
517            tracing::debug!(target: "rig", "Ollama chat response: {}", text);
518            let chat_resp: CompletionResponse = serde_json::from_str(&text)
519                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
520            let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
521            Ok(conv)
522        } else {
523            let err_text = response
524                .text()
525                .await
526                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
527            Err(CompletionError::ProviderError(err_text))
528        }
529    }
530
531    #[cfg_attr(feature = "worker", worker::send)]
532    async fn stream(
533        &self,
534        request: CompletionRequest,
535    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
536    {
537        let mut request_payload = self.create_completion_request(request)?;
538        merge_inplace(&mut request_payload, json!({"stream": true}));
539
540        let response = self
541            .client
542            .post("api/chat")?
543            .json(&request_payload)
544            .send()
545            .await
546            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
547
548        if !response.status().is_success() {
549            let err_text = response
550                .text()
551                .await
552                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
553            return Err(CompletionError::ProviderError(err_text));
554        }
555
556        let stream = Box::pin(stream! {
557            let mut stream = response.bytes_stream();
558            while let Some(chunk_result) = stream.next().await {
559                let chunk = match chunk_result {
560                    Ok(c) => c,
561                    Err(e) => {
562                        yield Err(CompletionError::from(e));
563                        break;
564                    }
565                };
566
567                let text = match String::from_utf8(chunk.to_vec()) {
568                    Ok(t) => t,
569                    Err(e) => {
570                        yield Err(CompletionError::ResponseError(e.to_string()));
571                        break;
572                    }
573                };
574
575
576                for line in text.lines() {
577                    let line = line.to_string();
578
579                    let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
580                        continue;
581                    };
582
583                    match response.message {
584                        Message::Assistant{ content, tool_calls, .. } => {
585                            if !content.is_empty() {
586                                yield Ok(RawStreamingChoice::Message(content))
587                            }
588
589                            for tool_call in tool_calls.iter() {
590                                let function = tool_call.function.clone();
591
592                                yield Ok(RawStreamingChoice::ToolCall {
593                                    id: "".to_string(),
594                                    name: function.name,
595                                    arguments: function.arguments,
596                                    call_id: None
597                                });
598                            }
599                        }
600                        _ => {
601                            continue;
602                        }
603                    }
604
605                    if response.done {
606                        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
607                            total_duration: response.total_duration,
608                            load_duration: response.load_duration,
609                            prompt_eval_count: response.prompt_eval_count,
610                            prompt_eval_duration: response.prompt_eval_duration,
611                            eval_count: response.eval_count,
612                            eval_duration: response.eval_duration,
613                            done_reason: response.done_reason,
614                        }));
615                    }
616                }
617            }
618        });
619
620        Ok(streaming::StreamingCompletionResponse::stream(stream))
621    }
622}
623
624// ---------- Tool Definition Conversion ----------
625
626/// Ollama-required tool definition format.
627#[derive(Clone, Debug, Deserialize, Serialize)]
628pub struct ToolDefinition {
629    #[serde(rename = "type")]
630    pub type_field: String, // Fixed as "function"
631    pub function: completion::ToolDefinition,
632}
633
634/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
635impl From<crate::completion::ToolDefinition> for ToolDefinition {
636    fn from(tool: crate::completion::ToolDefinition) -> Self {
637        ToolDefinition {
638            type_field: "function".to_owned(),
639            function: completion::ToolDefinition {
640                name: tool.name,
641                description: tool.description,
642                parameters: tool.parameters,
643            },
644        }
645    }
646}
647
/// A tool call as it appears in an Ollama assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    // NOTE(review): no `id` field is modeled. The non-streaming conversions in this module
    // use the function name where Rig needs a tool-call id.
    /// Serialized as `type`; defaults to `function` when absent.
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Kind of tool call; `function` (lowercase on the wire) is the only variant modeled.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// Name and JSON arguments of the function the model asked to call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}

// ---------- Provider Message Definition ----------

/// A chat message in Ollama's wire format.
///
/// The variant is selected by the `role` field: `user`, `assistant`, `system` or `tool`.
/// Optional fields are omitted from the serialized JSON when `None`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        // NOTE(review): presumably base64 image data; the user conversion passes Rig image
        // `data` through unchanged — confirm the expected encoding.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        /// Defaults to an empty string when absent.
        #[serde(default)]
        content: String,
        /// Reasoning text; Rig `Reasoning` content is placed here when converting history.
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Missing is an empty list; `json_utils::null_or_vec` presumably also maps `null` to
        // an empty list.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Tool output sent back to the model; serialized with role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        /// Name of the tool that produced this result; serialized as `tool_name`.
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
705
706/// -----------------------------
707/// Provider Message Conversions
708/// -----------------------------
709/// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
710/// (Only User and Assistant variants are supported.)
711impl TryFrom<crate::message::Message> for Vec<Message> {
712    type Error = crate::message::MessageError;
713    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
714        use crate::message::Message as InternalMessage;
715        match internal_msg {
716            InternalMessage::User { content, .. } => {
717                let (tool_results, other_content): (Vec<_>, Vec<_>) =
718                    content.into_iter().partition(|content| {
719                        matches!(content, crate::message::UserContent::ToolResult(_))
720                    });
721
722                if !tool_results.is_empty() {
723                    tool_results
724                        .into_iter()
725                        .map(|content| match content {
726                            crate::message::UserContent::ToolResult(
727                                crate::message::ToolResult { id, content, .. },
728                            ) => {
729                                // Ollama expects a single string for tool results, so we concatenate
730                                let content_string = content
731                                    .into_iter()
732                                    .map(|content| match content {
733                                        crate::message::ToolResultContent::Text(text) => text.text,
734                                        _ => "[Non-text content]".to_string(),
735                                    })
736                                    .collect::<Vec<_>>()
737                                    .join("\n");
738
739                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
740                                    name: id,
741                                    content: content_string,
742                                })
743                            }
744                            _ => unreachable!(),
745                        })
746                        .collect::<Result<Vec<_>, _>>()
747                } else {
748                    // Ollama requires separate text content and images array
749                    let (texts, images) = other_content.into_iter().fold(
750                        (Vec::new(), Vec::new()),
751                        |(mut texts, mut images), content| {
752                            match content {
753                                crate::message::UserContent::Text(crate::message::Text {
754                                    text,
755                                }) => texts.push(text),
756                                crate::message::UserContent::Image(crate::message::Image {
757                                    data,
758                                    ..
759                                }) => images.push(data),
760                                crate::message::UserContent::Document(
761                                    crate::message::Document { data, .. },
762                                ) => texts.push(data),
763                                _ => {} // Audio not supported by Ollama
764                            }
765                            (texts, images)
766                        },
767                    );
768
769                    Ok(vec![Message::User {
770                        content: texts.join(" "),
771                        images: if images.is_empty() {
772                            None
773                        } else {
774                            Some(images)
775                        },
776                        name: None,
777                    }])
778                }
779            }
780            InternalMessage::Assistant { content, .. } => {
781                let mut thinking: Option<String> = None;
782                let (text_content, tool_calls) = content.into_iter().fold(
783                    (Vec::new(), Vec::new()),
784                    |(mut texts, mut tools), content| {
785                        match content {
786                            crate::message::AssistantContent::Text(text) => texts.push(text.text),
787                            crate::message::AssistantContent::ToolCall(tool_call) => {
788                                tools.push(tool_call)
789                            }
790                            crate::message::AssistantContent::Reasoning(
791                                crate::message::Reasoning { reasoning, .. },
792                            ) => {
793                                thinking =
794                                    Some(reasoning.first().cloned().unwrap_or(String::new()));
795                            }
796                        }
797                        (texts, tools)
798                    },
799                );
800
801                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
802                //  so either `content` or `tool_calls` will have some content.
803                Ok(vec![Message::Assistant {
804                    content: text_content.join(" "),
805                    thinking,
806                    images: None,
807                    name: None,
808                    tool_calls: tool_calls
809                        .into_iter()
810                        .map(|tool_call| tool_call.into())
811                        .collect::<Vec<_>>(),
812                }])
813            }
814        }
815    }
816}
817
818/// Conversion from provider Message to a completion message.
819/// This is needed so that responses can be converted back into chat history.
820impl From<Message> for crate::completion::Message {
821    fn from(msg: Message) -> Self {
822        match msg {
823            Message::User { content, .. } => crate::completion::Message::User {
824                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
825                    text: content,
826                })),
827            },
828            Message::Assistant {
829                content,
830                tool_calls,
831                ..
832            } => {
833                let mut assistant_contents =
834                    vec![crate::completion::message::AssistantContent::Text(Text {
835                        text: content,
836                    })];
837                for tc in tool_calls {
838                    assistant_contents.push(
839                        crate::completion::message::AssistantContent::tool_call(
840                            tc.function.name.clone(),
841                            tc.function.name,
842                            tc.function.arguments,
843                        ),
844                    );
845                }
846                crate::completion::Message::Assistant {
847                    id: None,
848                    content: OneOrMany::many(assistant_contents).unwrap(),
849                }
850            }
851            // System and ToolResult are converted to User message as needed.
852            Message::System { content, .. } => crate::completion::Message::User {
853                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
854                    text: content,
855                })),
856            },
857            Message::ToolResult { name, content } => crate::completion::Message::User {
858                content: OneOrMany::one(message::UserContent::tool_result(
859                    name,
860                    OneOrMany::one(message::ToolResultContent::text(content)),
861                )),
862            },
863        }
864    }
865}
866
867impl Message {
868    /// Constructs a system message.
869    pub fn system(content: &str) -> Self {
870        Message::System {
871            content: content.to_owned(),
872            images: None,
873            name: None,
874        }
875    }
876}
877
878// ---------- Additional Message Types ----------
879
880impl From<crate::message::ToolCall> for ToolCall {
881    fn from(tool_call: crate::message::ToolCall) -> Self {
882        Self {
883            r#type: ToolType::Function,
884            function: Function {
885                name: tool_call.function.name,
886                arguments: tool_call.function.arguments,
887            },
888        }
889    }
890}
891
/// Text content for a system message.
///
/// The `r#type` field is (de)serialized under the key `"type"`. If that key is
/// missing when deserializing, it falls back to `SystemContentType::default()` (`Text`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    /// Content kind; optional on input because of `#[serde(default)]`.
    #[serde(default)]
    r#type: SystemContentType,
    /// The system text itself.
    text: String,
}
898
/// Kind tag for [`SystemContent`].
///
/// Variants are serialized in lowercase (`Text` <-> `"text"`). `Text` is currently
/// the only variant and is also the `Default`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
905
906impl From<String> for SystemContent {
907    fn from(s: String) -> Self {
908        SystemContent {
909            r#type: SystemContentType::default(),
910            text: s,
911        }
912    }
913}
914
915impl FromStr for SystemContent {
916    type Err = std::convert::Infallible;
917    fn from_str(s: &str) -> Result<Self, Self::Err> {
918        Ok(SystemContent {
919            r#type: SystemContentType::default(),
920            text: s.to_string(),
921        })
922    }
923}
924
/// Plain-text content produced by the assistant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    /// The assistant's text.
    pub text: String,
}
929
930impl FromStr for AssistantContent {
931    type Err = std::convert::Infallible;
932    fn from_str(s: &str) -> Result<Self, Self::Err> {
933        Ok(AssistantContent { text: s.to_owned() })
934    }
935}
936
/// A single content item of a user message.
///
/// Internally tagged by a lowercase `"type"` key, e.g.
/// `{"type": "text", "text": "..."}` or `{"type": "image", "image_url": {...}}`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text.
    Text { text: String },
    /// An image referenced by URL.
    Image { image_url: ImageUrl },
    // Audio variant removed as Ollama API does not support audio input.
}
944
945impl FromStr for UserContent {
946    type Err = std::convert::Infallible;
947    fn from_str(s: &str) -> Result<Self, Self::Err> {
948        Ok(UserContent::Text { text: s.to_owned() })
949    }
950}
951
/// Image reference carried by [`UserContent::Image`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// Location of the image.
    pub url: String,
    /// Requested detail level; uses `ImageDetail::default()` when absent on input.
    #[serde(default)]
    pub detail: ImageDetail,
}
958
959// =================================================================
960// Tests
961// =================================================================
962
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Test deserialization and conversion for the /api/chat endpoint.
    // Nothing here awaits, so a plain synchronous test is enough (no async runtime needed).
    #[test]
    fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    // Test conversion from provider Message to completion Message.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // A user message converts into exactly one content item; `first()` yields it.
                let first_content = content.first();
                // The expected type is crate::completion::message::UserContent::Text wrapping a Text struct.
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    // Test that an assistant message with tool calls converts into a leading text item
    // followed by one tool-call item per call, with the function name reused as the id.
    #[test]
    fn test_assistant_message_conversion_with_tool_calls() {
        let arguments = json!({
            "location": "San Francisco, CA",
            "format": "celsius"
        });
        let provider_msg = Message::Assistant {
            content: "Checking the weather.".to_owned(),
            thinking: None,
            images: None,
            name: None,
            tool_calls: vec![ToolCall {
                r#type: ToolType::Function,
                function: Function {
                    name: "get_current_weather".to_owned(),
                    arguments: arguments.clone(),
                },
            }],
        };

        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::Assistant { content, .. } => {
                match content.first() {
                    crate::completion::message::AssistantContent::Text(text) => {
                        assert_eq!(text.text, "Checking the weather.");
                    }
                    _ => panic!("Expected the first assistant item to be text"),
                }
                let rest = content.rest();
                assert_eq!(rest.len(), 1, "Expected exactly one tool call item");
                match &rest[0] {
                    crate::completion::message::AssistantContent::ToolCall(tool_call) => {
                        assert_eq!(tool_call.id, "get_current_weather");
                        assert_eq!(tool_call.function.name, "get_current_weather");
                        assert_eq!(tool_call.function.arguments, arguments);
                    }
                    _ => panic!("Expected the second assistant item to be a tool call"),
                }
            }
            _ => panic!("Conversion from provider Assistant message failed"),
        }
    }

    // Test conversion of internal tool definition to Ollama's ToolDefinition format.
    #[test]
    fn test_tool_definition_conversion() {
        // Internal tool definition from the completion module.
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        // Convert internal tool to Ollama's tool definition.
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        // Check JSON fields in parameters.
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }
}