rig/providers/
ollama.rs

//! Ollama API client and Rig integration
//!
//! # Example
//!
//! The example below needs an async runtime and a running Ollama server, so it is marked
//! `ignore` and is not compiled as a doctest.
//!
//! ```rust,ignore
//! use rig::providers::ollama;
//!
//! // Create a new Ollama client (defaults to http://localhost:11434)
//! let client = ollama::Client::new();
//!
//! // Create a completion model interface using, for example, the "llama3.2" model
//! let comp_model = client.completion_model("llama3.2");
//!
//! let req = rig::completion::CompletionRequest {
//!     preamble: Some("You are now a humorous AI assistant.".to_owned()),
//!     chat_history: vec![],  // previous conversation turns (if any)
//!     prompt: rig::message::Message::User {
//!         content: rig::one_or_many::OneOrMany::one(rig::message::UserContent::text("Please tell me why the sky is blue.")),
//!     },
//!     temperature: 0.7,
//!     additional_params: None,
//!     tools: vec![],
//! };
//!
//! let response = comp_model.completion(req).await.unwrap();
//! println!("Ollama completion response: {:?}", response.choice);
//!
//! // Create an embedding interface using the "all-minilm" model
//! let emb_model = client.embedding_model("all-minilm");
//! let docs = vec![
//!     "Why is the sky blue?".to_owned(),
//!     "Why is the grass green?".to_owned()
//! ];
//! let embeddings = emb_model.embed_texts(docs).await.unwrap();
//! println!("Embedding response: {:?}", embeddings);
//!
//! // Also create an agent and extractor if needed
//! let agent = client.agent("llama3.2");
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
//! ```
41use crate::client::{ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient};
42use crate::completion::Usage;
43use crate::json_utils::merge_inplace;
44use crate::streaming::RawStreamingChoice;
45use crate::{
46    Embed, OneOrMany,
47    completion::{self, CompletionError, CompletionRequest},
48    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
49    impl_conversion_traits, json_utils, message,
50    message::{ImageDetail, Text},
51    streaming,
52};
53use async_stream::stream;
54use futures::StreamExt;
55use reqwest;
56use serde::{Deserialize, Serialize};
57use serde_json::{Value, json};
58use std::{convert::TryFrom, str::FromStr};
59// ---------- Main Client ----------
60
61const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
62
63pub struct ClientBuilder<'a> {
64    base_url: &'a str,
65    http_client: Option<reqwest::Client>,
66}
67
68impl<'a> ClientBuilder<'a> {
69    #[allow(clippy::new_without_default)]
70    pub fn new() -> Self {
71        Self {
72            base_url: OLLAMA_API_BASE_URL,
73            http_client: None,
74        }
75    }
76
77    pub fn base_url(mut self, base_url: &'a str) -> Self {
78        self.base_url = base_url;
79        self
80    }
81
82    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
83        self.http_client = Some(client);
84        self
85    }
86
87    pub fn build(self) -> Result<Client, ClientBuilderError> {
88        let http_client = if let Some(http_client) = self.http_client {
89            http_client
90        } else {
91            reqwest::Client::builder().build()?
92        };
93
94        Ok(Client {
95            base_url: self.base_url.to_string(),
96            http_client,
97        })
98    }
99}
100
101#[derive(Clone, Debug)]
102pub struct Client {
103    base_url: String,
104    http_client: reqwest::Client,
105}
106
107impl Default for Client {
108    fn default() -> Self {
109        Self::new()
110    }
111}
112
113impl Client {
114    /// Create a new Ollama client builder.
115    ///
116    /// # Example
117    /// ```
118    /// use rig::providers::ollama::{ClientBuilder, self};
119    ///
120    /// // Initialize the Ollama client
121    /// let client = Client::builder()
122    ///    .build()
123    /// ```
124    pub fn builder() -> ClientBuilder<'static> {
125        ClientBuilder::new()
126    }
127
128    /// Create a new Ollama client. For more control, use the `builder` method.
129    ///
130    /// # Panics
131    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
132    pub fn new() -> Self {
133        Self::builder().build().expect("Ollama client should build")
134    }
135
136    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
137        let url = format!("{}/{}", self.base_url, path);
138        self.http_client.post(url)
139    }
140}
141
142impl ProviderClient for Client {
143    fn from_env() -> Self
144    where
145        Self: Sized,
146    {
147        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
148        Self::builder().base_url(&api_base).build().unwrap()
149    }
150
151    fn from_val(input: crate::client::ProviderValue) -> Self {
152        let crate::client::ProviderValue::Simple(_) = input else {
153            panic!("Incorrect provider value type")
154        };
155
156        Self::new()
157    }
158}
159
160impl CompletionClient for Client {
161    type CompletionModel = CompletionModel;
162
163    fn completion_model(&self, model: &str) -> CompletionModel {
164        CompletionModel::new(self.clone(), model)
165    }
166}
167
168impl EmbeddingsClient for Client {
169    type EmbeddingModel = EmbeddingModel;
170    fn embedding_model(&self, model: &str) -> EmbeddingModel {
171        EmbeddingModel::new(self.clone(), model, 0)
172    }
173    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
174        EmbeddingModel::new(self.clone(), model, ndims)
175    }
176    fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
177        EmbeddingsBuilder::new(self.embedding_model(model))
178    }
179}
180
181impl_conversion_traits!(
182    AsTranscription,
183    AsImageGeneration,
184    AsAudioGeneration for Client
185);
186
187// ---------- API Error and Response Structures ----------
188
189#[derive(Debug, Deserialize)]
190struct ApiErrorResponse {
191    message: String,
192}
193
194#[derive(Debug, Deserialize)]
195#[serde(untagged)]
196enum ApiResponse<T> {
197    Ok(T),
198    Err(ApiErrorResponse),
199}
200
201// ---------- Embedding API ----------
202
203pub const ALL_MINILM: &str = "all-minilm";
204pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
205
206#[derive(Debug, Serialize, Deserialize)]
207pub struct EmbeddingResponse {
208    pub model: String,
209    pub embeddings: Vec<Vec<f64>>,
210    #[serde(default)]
211    pub total_duration: Option<u64>,
212    #[serde(default)]
213    pub load_duration: Option<u64>,
214    #[serde(default)]
215    pub prompt_eval_count: Option<u64>,
216}
217
218impl From<ApiErrorResponse> for EmbeddingError {
219    fn from(err: ApiErrorResponse) -> Self {
220        EmbeddingError::ProviderError(err.message)
221    }
222}
223
224impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
225    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
226        match value {
227            ApiResponse::Ok(response) => Ok(response),
228            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
229        }
230    }
231}
232
233// ---------- Embedding Model ----------
234
235#[derive(Clone)]
236pub struct EmbeddingModel {
237    client: Client,
238    pub model: String,
239    ndims: usize,
240}
241
242impl EmbeddingModel {
243    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
244        Self {
245            client,
246            model: model.to_owned(),
247            ndims,
248        }
249    }
250}
251
252impl embeddings::EmbeddingModel for EmbeddingModel {
253    const MAX_DOCUMENTS: usize = 1024;
254    fn ndims(&self) -> usize {
255        self.ndims
256    }
257    #[cfg_attr(feature = "worker", worker::send)]
258    async fn embed_texts(
259        &self,
260        documents: impl IntoIterator<Item = String>,
261    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
262        let docs: Vec<String> = documents.into_iter().collect();
263        let payload = json!({
264            "model": self.model,
265            "input": docs,
266        });
267        let response = self
268            .client
269            .post("api/embed")
270            .json(&payload)
271            .send()
272            .await
273            .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
274        if response.status().is_success() {
275            let api_resp: EmbeddingResponse = response
276                .json()
277                .await
278                .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
279            if api_resp.embeddings.len() != docs.len() {
280                return Err(EmbeddingError::ResponseError(
281                    "Number of returned embeddings does not match input".into(),
282                ));
283            }
284            Ok(api_resp
285                .embeddings
286                .into_iter()
287                .zip(docs.into_iter())
288                .map(|(vec, document)| embeddings::Embedding { document, vec })
289                .collect())
290        } else {
291            Err(EmbeddingError::ProviderError(response.text().await?))
292        }
293    }
294}
295
296// ---------- Completion API ----------
297
298pub const LLAMA3_2: &str = "llama3.2";
299pub const LLAVA: &str = "llava";
300pub const MISTRAL: &str = "mistral";
301
302#[derive(Debug, Serialize, Deserialize)]
303pub struct CompletionResponse {
304    pub model: String,
305    pub created_at: String,
306    pub message: Message,
307    pub done: bool,
308    #[serde(default)]
309    pub done_reason: Option<String>,
310    #[serde(default)]
311    pub total_duration: Option<u64>,
312    #[serde(default)]
313    pub load_duration: Option<u64>,
314    #[serde(default)]
315    pub prompt_eval_count: Option<u64>,
316    #[serde(default)]
317    pub prompt_eval_duration: Option<u64>,
318    #[serde(default)]
319    pub eval_count: Option<u64>,
320    #[serde(default)]
321    pub eval_duration: Option<u64>,
322}
323impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
324    type Error = CompletionError;
325    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
326        match resp.message {
327            // Process only if an assistant message is present.
328            Message::Assistant {
329                content,
330                thinking,
331                tool_calls,
332                ..
333            } => {
334                let mut assistant_contents = Vec::new();
335                // Add the assistant's text content if any.
336                if !content.is_empty() {
337                    assistant_contents.push(completion::AssistantContent::text(&content));
338                }
339                // Process tool_calls following Ollama's chat response definition.
340                // Each ToolCall has an id, a type, and a function field.
341                for tc in tool_calls.iter() {
342                    assistant_contents.push(completion::AssistantContent::tool_call(
343                        tc.function.name.clone(),
344                        tc.function.name.clone(),
345                        tc.function.arguments.clone(),
346                    ));
347                }
348                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
349                    CompletionError::ResponseError("No content provided".to_owned())
350                })?;
351                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
352                let completion_tokens = resp.eval_count.unwrap_or(0);
353
354                let raw_response = CompletionResponse {
355                    model: resp.model,
356                    created_at: resp.created_at,
357                    done: resp.done,
358                    done_reason: resp.done_reason,
359                    total_duration: resp.total_duration,
360                    load_duration: resp.load_duration,
361                    prompt_eval_count: resp.prompt_eval_count,
362                    prompt_eval_duration: resp.prompt_eval_duration,
363                    eval_count: resp.eval_count,
364                    eval_duration: resp.eval_duration,
365                    message: Message::Assistant {
366                        content,
367                        thinking,
368                        images: None,
369                        name: None,
370                        tool_calls,
371                    },
372                };
373
374                Ok(completion::CompletionResponse {
375                    choice,
376                    usage: Usage {
377                        input_tokens: prompt_tokens,
378                        output_tokens: completion_tokens,
379                        total_tokens: prompt_tokens + completion_tokens,
380                    },
381                    raw_response,
382                })
383            }
384            _ => Err(CompletionError::ResponseError(
385                "Chat response does not include an assistant message".into(),
386            )),
387        }
388    }
389}
390
391// ---------- Completion Model ----------
392
393#[derive(Clone)]
394pub struct CompletionModel {
395    client: Client,
396    pub model: String,
397}
398
399impl CompletionModel {
400    pub fn new(client: Client, model: &str) -> Self {
401        Self {
402            client,
403            model: model.to_owned(),
404        }
405    }
406
407    fn create_completion_request(
408        &self,
409        completion_request: CompletionRequest,
410    ) -> Result<Value, CompletionError> {
411        // Build up the order of messages (context, chat_history)
412        let mut partial_history = vec![];
413        if let Some(docs) = completion_request.normalized_documents() {
414            partial_history.push(docs);
415        }
416        partial_history.extend(completion_request.chat_history);
417
418        // Initialize full history with preamble (or empty if non-existent)
419        let mut full_history: Vec<Message> = completion_request
420            .preamble
421            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
422
423        // Convert and extend the rest of the history
424        full_history.extend(
425            partial_history
426                .into_iter()
427                .map(|msg| msg.try_into())
428                .collect::<Result<Vec<Vec<Message>>, _>>()?
429                .into_iter()
430                .flatten()
431                .collect::<Vec<Message>>(),
432        );
433
434        // Convert internal prompt into a provider Message
435        let options = if let Some(extra) = completion_request.additional_params {
436            json_utils::merge(
437                json!({ "temperature": completion_request.temperature }),
438                extra,
439            )
440        } else {
441            json!({ "temperature": completion_request.temperature })
442        };
443
444        let mut request_payload = json!({
445            "model": self.model,
446            "messages": full_history,
447            "options": options,
448            "stream": false,
449        });
450        if !completion_request.tools.is_empty() {
451            request_payload["tools"] = json!(
452                completion_request
453                    .tools
454                    .into_iter()
455                    .map(|tool| tool.into())
456                    .collect::<Vec<ToolDefinition>>()
457            );
458        }
459
460        tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
461
462        Ok(request_payload)
463    }
464}
465
466// ---------- CompletionModel Implementation ----------
467
468#[derive(Clone, Serialize, Deserialize, Debug)]
469pub struct StreamingCompletionResponse {
470    pub done_reason: Option<String>,
471    pub total_duration: Option<u64>,
472    pub load_duration: Option<u64>,
473    pub prompt_eval_count: Option<u64>,
474    pub prompt_eval_duration: Option<u64>,
475    pub eval_count: Option<u64>,
476    pub eval_duration: Option<u64>,
477}
478impl completion::CompletionModel for CompletionModel {
479    type Response = CompletionResponse;
480    type StreamingResponse = StreamingCompletionResponse;
481
482    #[cfg_attr(feature = "worker", worker::send)]
483    async fn completion(
484        &self,
485        completion_request: CompletionRequest,
486    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
487        let request_payload = self.create_completion_request(completion_request)?;
488
489        let response = self
490            .client
491            .post("api/chat")
492            .json(&request_payload)
493            .send()
494            .await
495            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
496        if response.status().is_success() {
497            let text = response
498                .text()
499                .await
500                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
501            tracing::debug!(target: "rig", "Ollama chat response: {}", text);
502            let chat_resp: CompletionResponse = serde_json::from_str(&text)
503                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
504            let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
505            Ok(conv)
506        } else {
507            let err_text = response
508                .text()
509                .await
510                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
511            Err(CompletionError::ProviderError(err_text))
512        }
513    }
514
515    #[cfg_attr(feature = "worker", worker::send)]
516    async fn stream(
517        &self,
518        request: CompletionRequest,
519    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
520    {
521        let mut request_payload = self.create_completion_request(request)?;
522        merge_inplace(&mut request_payload, json!({"stream": true}));
523
524        let response = self
525            .client
526            .post("api/chat")
527            .json(&request_payload)
528            .send()
529            .await
530            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
531
532        if !response.status().is_success() {
533            let err_text = response
534                .text()
535                .await
536                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
537            return Err(CompletionError::ProviderError(err_text));
538        }
539
540        let stream = Box::pin(stream! {
541            let mut stream = response.bytes_stream();
542            while let Some(chunk_result) = stream.next().await {
543                let chunk = match chunk_result {
544                    Ok(c) => c,
545                    Err(e) => {
546                        yield Err(CompletionError::from(e));
547                        break;
548                    }
549                };
550
551                let text = match String::from_utf8(chunk.to_vec()) {
552                    Ok(t) => t,
553                    Err(e) => {
554                        yield Err(CompletionError::ResponseError(e.to_string()));
555                        break;
556                    }
557                };
558
559
560                for line in text.lines() {
561                    let line = line.to_string();
562
563                    let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
564                        continue;
565                    };
566
567                    match response.message {
568                        Message::Assistant{ content, tool_calls, .. } => {
569                            if !content.is_empty() {
570                                yield Ok(RawStreamingChoice::Message(content))
571                            }
572
573                            for tool_call in tool_calls.iter() {
574                                let function = tool_call.function.clone();
575
576                                yield Ok(RawStreamingChoice::ToolCall {
577                                    id: "".to_string(),
578                                    name: function.name,
579                                    arguments: function.arguments,
580                                    call_id: None
581                                });
582                            }
583                        }
584                        _ => {
585                            continue;
586                        }
587                    }
588
589                    if response.done {
590                        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
591                            total_duration: response.total_duration,
592                            load_duration: response.load_duration,
593                            prompt_eval_count: response.prompt_eval_count,
594                            prompt_eval_duration: response.prompt_eval_duration,
595                            eval_count: response.eval_count,
596                            eval_duration: response.eval_duration,
597                            done_reason: response.done_reason,
598                        }));
599                    }
600                }
601            }
602        });
603
604        Ok(streaming::StreamingCompletionResponse::stream(stream))
605    }
606}
607
608// ---------- Tool Definition Conversion ----------
609
610/// Ollama-required tool definition format.
611#[derive(Clone, Debug, Deserialize, Serialize)]
612pub struct ToolDefinition {
613    #[serde(rename = "type")]
614    pub type_field: String, // Fixed as "function"
615    pub function: completion::ToolDefinition,
616}
617
618/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
619impl From<crate::completion::ToolDefinition> for ToolDefinition {
620    fn from(tool: crate::completion::ToolDefinition) -> Self {
621        ToolDefinition {
622            type_field: "function".to_owned(),
623            function: completion::ToolDefinition {
624                name: tool.name,
625                description: tool.description,
626                parameters: tool.parameters,
627            },
628        }
629    }
630}
631
632#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
633pub struct ToolCall {
634    // pub id: String,
635    #[serde(default, rename = "type")]
636    pub r#type: ToolType,
637    pub function: Function,
638}
639#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
640#[serde(rename_all = "lowercase")]
641pub enum ToolType {
642    #[default]
643    Function,
644}
645#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
646pub struct Function {
647    pub name: String,
648    pub arguments: Value,
649}
650
651// ---------- Provider Message Definition ----------
652
653#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
654#[serde(tag = "role", rename_all = "lowercase")]
655pub enum Message {
656    User {
657        content: String,
658        #[serde(skip_serializing_if = "Option::is_none")]
659        images: Option<Vec<String>>,
660        #[serde(skip_serializing_if = "Option::is_none")]
661        name: Option<String>,
662    },
663    Assistant {
664        #[serde(default)]
665        content: String,
666        #[serde(skip_serializing_if = "Option::is_none")]
667        thinking: Option<String>,
668        #[serde(skip_serializing_if = "Option::is_none")]
669        images: Option<Vec<String>>,
670        #[serde(skip_serializing_if = "Option::is_none")]
671        name: Option<String>,
672        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
673        tool_calls: Vec<ToolCall>,
674    },
675    System {
676        content: String,
677        #[serde(skip_serializing_if = "Option::is_none")]
678        images: Option<Vec<String>>,
679        #[serde(skip_serializing_if = "Option::is_none")]
680        name: Option<String>,
681    },
682    #[serde(rename = "tool")]
683    ToolResult {
684        #[serde(rename = "tool_name")]
685        name: String,
686        content: String,
687    },
688}
689
690/// -----------------------------
691/// Provider Message Conversions
692/// -----------------------------
693/// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
694/// (Only User and Assistant variants are supported.)
695impl TryFrom<crate::message::Message> for Vec<Message> {
696    type Error = crate::message::MessageError;
697    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
698        use crate::message::Message as InternalMessage;
699        match internal_msg {
700            InternalMessage::User { content, .. } => {
701                let (tool_results, other_content): (Vec<_>, Vec<_>) =
702                    content.into_iter().partition(|content| {
703                        matches!(content, crate::message::UserContent::ToolResult(_))
704                    });
705
706                if !tool_results.is_empty() {
707                    tool_results
708                        .into_iter()
709                        .map(|content| match content {
710                            crate::message::UserContent::ToolResult(
711                                crate::message::ToolResult { id, content, .. },
712                            ) => {
713                                // Ollama expects a single string for tool results, so we concatenate
714                                let content_string = content
715                                    .into_iter()
716                                    .map(|content| match content {
717                                        crate::message::ToolResultContent::Text(text) => text.text,
718                                        _ => "[Non-text content]".to_string(),
719                                    })
720                                    .collect::<Vec<_>>()
721                                    .join("\n");
722
723                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
724                                    name: id,
725                                    content: content_string,
726                                })
727                            }
728                            _ => unreachable!(),
729                        })
730                        .collect::<Result<Vec<_>, _>>()
731                } else {
732                    // Ollama requires separate text content and images array
733                    let (texts, images) = other_content.into_iter().fold(
734                        (Vec::new(), Vec::new()),
735                        |(mut texts, mut images), content| {
736                            match content {
737                                crate::message::UserContent::Text(crate::message::Text {
738                                    text,
739                                }) => texts.push(text),
740                                crate::message::UserContent::Image(crate::message::Image {
741                                    data,
742                                    ..
743                                }) => images.push(data),
744                                _ => {} // Audio/Document not supported by Ollama
745                            }
746                            (texts, images)
747                        },
748                    );
749
750                    Ok(vec![Message::User {
751                        content: texts.join(" "),
752                        images: if images.is_empty() {
753                            None
754                        } else {
755                            Some(images)
756                        },
757                        name: None,
758                    }])
759                }
760            }
761            InternalMessage::Assistant { content, .. } => {
762                let mut thinking: Option<String> = None;
763                let (text_content, tool_calls) = content.into_iter().fold(
764                    (Vec::new(), Vec::new()),
765                    |(mut texts, mut tools), content| {
766                        match content {
767                            crate::message::AssistantContent::Text(text) => texts.push(text.text),
768                            crate::message::AssistantContent::ToolCall(tool_call) => {
769                                tools.push(tool_call)
770                            }
771                            crate::message::AssistantContent::Reasoning(
772                                crate::message::Reasoning { reasoning },
773                            ) => {
774                                thinking = Some(reasoning);
775                            }
776                        }
777                        (texts, tools)
778                    },
779                );
780
781                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
782                //  so either `content` or `tool_calls` will have some content.
783                Ok(vec![Message::Assistant {
784                    content: text_content.join(" "),
785                    thinking,
786                    images: None,
787                    name: None,
788                    tool_calls: tool_calls
789                        .into_iter()
790                        .map(|tool_call| tool_call.into())
791                        .collect::<Vec<_>>(),
792                }])
793            }
794        }
795    }
796}
797
798/// Conversion from provider Message to a completion message.
799/// This is needed so that responses can be converted back into chat history.
800impl From<Message> for crate::completion::Message {
801    fn from(msg: Message) -> Self {
802        match msg {
803            Message::User { content, .. } => crate::completion::Message::User {
804                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
805                    text: content,
806                })),
807            },
808            Message::Assistant {
809                content,
810                tool_calls,
811                ..
812            } => {
813                let mut assistant_contents =
814                    vec![crate::completion::message::AssistantContent::Text(Text {
815                        text: content,
816                    })];
817                for tc in tool_calls {
818                    assistant_contents.push(
819                        crate::completion::message::AssistantContent::tool_call(
820                            tc.function.name.clone(),
821                            tc.function.name,
822                            tc.function.arguments,
823                        ),
824                    );
825                }
826                crate::completion::Message::Assistant {
827                    id: None,
828                    content: OneOrMany::many(assistant_contents).unwrap(),
829                }
830            }
831            // System and ToolResult are converted to User message as needed.
832            Message::System { content, .. } => crate::completion::Message::User {
833                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
834                    text: content,
835                })),
836            },
837            Message::ToolResult { name, content } => crate::completion::Message::User {
838                content: OneOrMany::one(message::UserContent::tool_result(
839                    name,
840                    OneOrMany::one(message::ToolResultContent::text(content)),
841                )),
842            },
843        }
844    }
845}
846
847impl Message {
848    /// Constructs a system message.
849    pub fn system(content: &str) -> Self {
850        Message::System {
851            content: content.to_owned(),
852            images: None,
853            name: None,
854        }
855    }
856}
857
858// ---------- Additional Message Types ----------
859
860impl From<crate::message::ToolCall> for ToolCall {
861    fn from(tool_call: crate::message::ToolCall) -> Self {
862        Self {
863            r#type: ToolType::Function,
864            function: Function {
865                name: tool_call.function.name,
866                arguments: tool_call.function.arguments,
867            },
868        }
869    }
870}
871
/// A typed piece of system-prompt content: a `type` tag plus its text.
///
/// When deserializing, a missing `type` falls back to `SystemContentType::default()`
/// (i.e. `Text`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    /// Kind of content; defaulted when absent from the input.
    #[serde(default)]
    r#type: SystemContentType,
    /// The content text.
    text: String,
}
878
/// Kind tag for system content. Variants are (de)serialized in lowercase,
/// so `Text` appears as `"text"`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    /// Plain text; this is also the `Default` variant.
    #[default]
    Text,
}
885
886impl From<String> for SystemContent {
887    fn from(s: String) -> Self {
888        SystemContent {
889            r#type: SystemContentType::default(),
890            text: s,
891        }
892    }
893}
894
895impl FromStr for SystemContent {
896    type Err = std::convert::Infallible;
897    fn from_str(s: &str) -> Result<Self, Self::Err> {
898        Ok(SystemContent {
899            r#type: SystemContentType::default(),
900            text: s.to_string(),
901        })
902    }
903}
904
/// A single text fragment of assistant content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    /// The fragment's text.
    pub text: String,
}
909
910impl FromStr for AssistantContent {
911    type Err = std::convert::Infallible;
912    fn from_str(s: &str) -> Result<Self, Self::Err> {
913        Ok(AssistantContent { text: s.to_owned() })
914    }
915}
916
/// One piece of user content. It is internally tagged by a `"type"` field whose
/// value is the lowercase variant name (`"text"` or `"image"`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text input.
    Text { text: String },
    /// An image referenced through an [`ImageUrl`].
    Image { image_url: ImageUrl },
    // Audio variant removed as Ollama API does not support audio input.
}
924
925impl FromStr for UserContent {
926    type Err = std::convert::Infallible;
927    fn from_str(s: &str) -> Result<Self, Self::Err> {
928        Ok(UserContent::Text { text: s.to_owned() })
929    }
930}
931
/// Reference to an image, plus a detail hint.
///
/// When deserializing, a missing `detail` falls back to `ImageDetail::default()`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// Image location. NOTE(review): may be a plain URL or a data URI — confirm against callers.
    pub url: String,
    /// Detail hint for the image; defaulted when absent from the input.
    #[serde(default)]
    pub detail: ImageDetail,
}
938
939// =================================================================
940// Tests
941// =================================================================
942
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Test deserialization and conversion for the /api/chat endpoint.
    // Nothing here is awaited, so a plain synchronous test is enough; no async runtime.
    #[test]
    fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    // Test conversion from provider Message to completion Message.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // A user message converts to exactly one text part; inspect it.
                let first_content = content.first();
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    // System messages are converted into user messages carrying the same text.
    #[test]
    fn test_system_message_conversion() {
        let comp_msg: crate::completion::Message = Message::system("Be concise.").into();
        match comp_msg {
            crate::completion::Message::User { content } => match content.first() {
                crate::completion::message::UserContent::Text(text_struct) => {
                    assert_eq!(text_struct.text, "Be concise.");
                }
                _ => panic!("Expected text content for converted system message"),
            },
            _ => panic!("Expected system message to convert into a user message"),
        }
    }

    // Test conversion of internal tool definition to Ollama's ToolDefinition format.
    #[test]
    fn test_tool_definition_conversion() {
        // Internal tool definition from the completion module.
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        // Convert internal tool to Ollama's tool definition.
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        // Check JSON fields in parameters.
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }
}