rig/providers/
ollama.rs

1//! Ollama API client and Rig integration
2//!
3//! # Example
4//! ```rust
5//! use rig::providers::ollama;
6//!
7//! // Create a new Ollama client (defaults to http://localhost:11434)
8//! let client = ollama::Client::new();
9//!
10//! // Create a completion model interface using, for example, the "llama3.2" model
11//! let comp_model = client.completion_model("llama3.2");
12//!
13//! let req = rig::completion::CompletionRequest {
14//!     preamble: Some("You are now a humorous AI assistant.".to_owned()),
15//!     chat_history: vec![],  // internal messages (if any)
16//!     prompt: rig::message::Message::User {
17//!         content: rig::one_or_many::OneOrMany::one(rig::message::UserContent::text("Please tell me why the sky is blue.")),
18//!         name: None
19//!     },
20//!     temperature: 0.7,
21//!     additional_params: None,
22//!     tools: vec![],
23//! };
24//!
25//! let response = comp_model.completion(req).await.unwrap();
26//! println!("Ollama completion response: {:?}", response.choice);
27//!
28//! // Create an embedding interface using the "all-minilm" model
29//! let emb_model = ollama::Client::new().embedding_model("all-minilm");
30//! let docs = vec![
31//!     "Why is the sky blue?".to_owned(),
32//!     "Why is the grass green?".to_owned()
33//! ];
34//! let embeddings = emb_model.embed_texts(docs).await.unwrap();
35//! println!("Embedding response: {:?}", embeddings);
36//!
37//! // Also create an agent and extractor if needed
38//! let agent = client.agent("llama3.2");
39//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
40//! ```
41use crate::client::{
42    ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient,
43    VerifyError,
44};
45use crate::completion::{GetTokenUsage, Usage};
46use crate::json_utils::merge_inplace;
47use crate::streaming::RawStreamingChoice;
48use crate::{
49    Embed, OneOrMany,
50    completion::{self, CompletionError, CompletionRequest},
51    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
52    impl_conversion_traits, json_utils, message,
53    message::{ImageDetail, Text},
54    streaming,
55};
56use async_stream::stream;
57use futures::StreamExt;
58use reqwest;
59use reqwest_eventsource::{Event, RequestBuilderExt};
60use serde::{Deserialize, Serialize};
61use serde_json::{Value, json};
62use std::{convert::TryFrom, str::FromStr};
63use url::Url;
64// ---------- Main Client ----------
65
66const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
67
68pub struct ClientBuilder<'a> {
69    base_url: &'a str,
70    http_client: Option<reqwest::Client>,
71}
72
73impl<'a> ClientBuilder<'a> {
74    #[allow(clippy::new_without_default)]
75    pub fn new() -> Self {
76        Self {
77            base_url: OLLAMA_API_BASE_URL,
78            http_client: None,
79        }
80    }
81
82    pub fn base_url(mut self, base_url: &'a str) -> Self {
83        self.base_url = base_url;
84        self
85    }
86
87    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
88        self.http_client = Some(client);
89        self
90    }
91
92    pub fn build(self) -> Result<Client, ClientBuilderError> {
93        let http_client = if let Some(http_client) = self.http_client {
94            http_client
95        } else {
96            reqwest::Client::builder().build()?
97        };
98
99        Ok(Client {
100            base_url: Url::parse(self.base_url)
101                .map_err(|_| ClientBuilderError::InvalidProperty("base_url"))?,
102            http_client,
103        })
104    }
105}
106
107#[derive(Clone, Debug)]
108pub struct Client {
109    base_url: Url,
110    http_client: reqwest::Client,
111}
112
113impl Default for Client {
114    fn default() -> Self {
115        Self::new()
116    }
117}
118
119impl Client {
120    /// Create a new Ollama client builder.
121    ///
122    /// # Example
123    /// ```
124    /// use rig::providers::ollama::{ClientBuilder, self};
125    ///
126    /// // Initialize the Ollama client
127    /// let client = Client::builder()
128    ///    .build()
129    /// ```
130    pub fn builder() -> ClientBuilder<'static> {
131        ClientBuilder::new()
132    }
133
134    /// Create a new Ollama client. For more control, use the `builder` method.
135    ///
136    /// # Panics
137    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
138    pub fn new() -> Self {
139        Self::builder().build().expect("Ollama client should build")
140    }
141
142    pub(crate) fn post(&self, path: &str) -> Result<reqwest::RequestBuilder, url::ParseError> {
143        let url = self.base_url.join(path)?;
144        Ok(self.http_client.post(url))
145    }
146
147    pub(crate) fn get(&self, path: &str) -> Result<reqwest::RequestBuilder, url::ParseError> {
148        let url = self.base_url.join(path)?;
149        Ok(self.http_client.get(url))
150    }
151}
152
153impl ProviderClient for Client {
154    fn from_env() -> Self
155    where
156        Self: Sized,
157    {
158        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
159        Self::builder().base_url(&api_base).build().unwrap()
160    }
161
162    fn from_val(input: crate::client::ProviderValue) -> Self {
163        let crate::client::ProviderValue::Simple(_) = input else {
164            panic!("Incorrect provider value type")
165        };
166
167        Self::new()
168    }
169}
170
171impl CompletionClient for Client {
172    type CompletionModel = CompletionModel;
173
174    fn completion_model(&self, model: &str) -> CompletionModel {
175        CompletionModel::new(self.clone(), model)
176    }
177}
178
179impl EmbeddingsClient for Client {
180    type EmbeddingModel = EmbeddingModel;
181    fn embedding_model(&self, model: &str) -> EmbeddingModel {
182        EmbeddingModel::new(self.clone(), model, 0)
183    }
184    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
185        EmbeddingModel::new(self.clone(), model, ndims)
186    }
187    fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
188        EmbeddingsBuilder::new(self.embedding_model(model))
189    }
190}
191
192impl VerifyClient for Client {
193    #[cfg_attr(feature = "worker", worker::send)]
194    async fn verify(&self) -> Result<(), VerifyError> {
195        let response = self
196            .get("api/tags")
197            .expect("Failed to build request")
198            .send()
199            .await?;
200        match response.status() {
201            reqwest::StatusCode::OK => Ok(()),
202            _ => {
203                response.error_for_status()?;
204                Ok(())
205            }
206        }
207    }
208}
209
210impl_conversion_traits!(
211    AsTranscription,
212    AsImageGeneration,
213    AsAudioGeneration for Client
214);
215
216// ---------- API Error and Response Structures ----------
217
218#[derive(Debug, Deserialize)]
219struct ApiErrorResponse {
220    message: String,
221}
222
223#[derive(Debug, Deserialize)]
224#[serde(untagged)]
225enum ApiResponse<T> {
226    Ok(T),
227    Err(ApiErrorResponse),
228}
229
230// ---------- Embedding API ----------
231
232pub const ALL_MINILM: &str = "all-minilm";
233pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
234
235#[derive(Debug, Serialize, Deserialize)]
236pub struct EmbeddingResponse {
237    pub model: String,
238    pub embeddings: Vec<Vec<f64>>,
239    #[serde(default)]
240    pub total_duration: Option<u64>,
241    #[serde(default)]
242    pub load_duration: Option<u64>,
243    #[serde(default)]
244    pub prompt_eval_count: Option<u64>,
245}
246
247impl From<ApiErrorResponse> for EmbeddingError {
248    fn from(err: ApiErrorResponse) -> Self {
249        EmbeddingError::ProviderError(err.message)
250    }
251}
252
253impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
254    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
255        match value {
256            ApiResponse::Ok(response) => Ok(response),
257            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
258        }
259    }
260}
261
262// ---------- Embedding Model ----------
263
264#[derive(Clone)]
265pub struct EmbeddingModel {
266    client: Client,
267    pub model: String,
268    ndims: usize,
269}
270
271impl EmbeddingModel {
272    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
273        Self {
274            client,
275            model: model.to_owned(),
276            ndims,
277        }
278    }
279}
280
281impl embeddings::EmbeddingModel for EmbeddingModel {
282    const MAX_DOCUMENTS: usize = 1024;
283    fn ndims(&self) -> usize {
284        self.ndims
285    }
286    #[cfg_attr(feature = "worker", worker::send)]
287    async fn embed_texts(
288        &self,
289        documents: impl IntoIterator<Item = String>,
290    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
291        let docs: Vec<String> = documents.into_iter().collect();
292        let payload = json!({
293            "model": self.model,
294            "input": docs,
295        });
296        let response = self
297            .client
298            .post("api/embed")?
299            .json(&payload)
300            .send()
301            .await
302            .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
303        if response.status().is_success() {
304            let api_resp: EmbeddingResponse = response
305                .json()
306                .await
307                .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
308            if api_resp.embeddings.len() != docs.len() {
309                return Err(EmbeddingError::ResponseError(
310                    "Number of returned embeddings does not match input".into(),
311                ));
312            }
313            Ok(api_resp
314                .embeddings
315                .into_iter()
316                .zip(docs.into_iter())
317                .map(|(vec, document)| embeddings::Embedding { document, vec })
318                .collect())
319        } else {
320            Err(EmbeddingError::ProviderError(response.text().await?))
321        }
322    }
323}
324
325// ---------- Completion API ----------
326
327pub const LLAMA3_2: &str = "llama3.2";
328pub const LLAVA: &str = "llava";
329pub const MISTRAL: &str = "mistral";
330
331#[derive(Debug, Serialize, Deserialize)]
332pub struct CompletionResponse {
333    pub model: String,
334    pub created_at: String,
335    pub message: Message,
336    pub done: bool,
337    #[serde(default)]
338    pub done_reason: Option<String>,
339    #[serde(default)]
340    pub total_duration: Option<u64>,
341    #[serde(default)]
342    pub load_duration: Option<u64>,
343    #[serde(default)]
344    pub prompt_eval_count: Option<u64>,
345    #[serde(default)]
346    pub prompt_eval_duration: Option<u64>,
347    #[serde(default)]
348    pub eval_count: Option<u64>,
349    #[serde(default)]
350    pub eval_duration: Option<u64>,
351}
352impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
353    type Error = CompletionError;
354    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
355        match resp.message {
356            // Process only if an assistant message is present.
357            Message::Assistant {
358                content,
359                thinking,
360                tool_calls,
361                ..
362            } => {
363                let mut assistant_contents = Vec::new();
364                // Add the assistant's text content if any.
365                if !content.is_empty() {
366                    assistant_contents.push(completion::AssistantContent::text(&content));
367                }
368                // Process tool_calls following Ollama's chat response definition.
369                // Each ToolCall has an id, a type, and a function field.
370                for tc in tool_calls.iter() {
371                    assistant_contents.push(completion::AssistantContent::tool_call(
372                        tc.function.name.clone(),
373                        tc.function.name.clone(),
374                        tc.function.arguments.clone(),
375                    ));
376                }
377                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
378                    CompletionError::ResponseError("No content provided".to_owned())
379                })?;
380                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
381                let completion_tokens = resp.eval_count.unwrap_or(0);
382
383                let raw_response = CompletionResponse {
384                    model: resp.model,
385                    created_at: resp.created_at,
386                    done: resp.done,
387                    done_reason: resp.done_reason,
388                    total_duration: resp.total_duration,
389                    load_duration: resp.load_duration,
390                    prompt_eval_count: resp.prompt_eval_count,
391                    prompt_eval_duration: resp.prompt_eval_duration,
392                    eval_count: resp.eval_count,
393                    eval_duration: resp.eval_duration,
394                    message: Message::Assistant {
395                        content,
396                        thinking,
397                        images: None,
398                        name: None,
399                        tool_calls,
400                    },
401                };
402
403                Ok(completion::CompletionResponse {
404                    choice,
405                    usage: Usage {
406                        input_tokens: prompt_tokens,
407                        output_tokens: completion_tokens,
408                        total_tokens: prompt_tokens + completion_tokens,
409                    },
410                    raw_response,
411                })
412            }
413            _ => Err(CompletionError::ResponseError(
414                "Chat response does not include an assistant message".into(),
415            )),
416        }
417    }
418}
419
420// ---------- Completion Model ----------
421
422#[derive(Clone)]
423pub struct CompletionModel {
424    client: Client,
425    pub model: String,
426}
427
428impl CompletionModel {
429    pub fn new(client: Client, model: &str) -> Self {
430        Self {
431            client,
432            model: model.to_owned(),
433        }
434    }
435
436    fn create_completion_request(
437        &self,
438        completion_request: CompletionRequest,
439    ) -> Result<Value, CompletionError> {
440        // Build up the order of messages (context, chat_history)
441        let mut partial_history = vec![];
442        if let Some(docs) = completion_request.normalized_documents() {
443            partial_history.push(docs);
444        }
445        partial_history.extend(completion_request.chat_history);
446
447        // Initialize full history with preamble (or empty if non-existent)
448        let mut full_history: Vec<Message> = completion_request
449            .preamble
450            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
451
452        // Convert and extend the rest of the history
453        full_history.extend(
454            partial_history
455                .into_iter()
456                .map(|msg| msg.try_into())
457                .collect::<Result<Vec<Vec<Message>>, _>>()?
458                .into_iter()
459                .flatten()
460                .collect::<Vec<Message>>(),
461        );
462
463        // Convert internal prompt into a provider Message
464        let options = if let Some(extra) = completion_request.additional_params {
465            json_utils::merge(
466                json!({ "temperature": completion_request.temperature }),
467                extra,
468            )
469        } else {
470            json!({ "temperature": completion_request.temperature })
471        };
472
473        let mut request_payload = json!({
474            "model": self.model,
475            "messages": full_history,
476            "options": options,
477            "stream": false,
478        });
479        if !completion_request.tools.is_empty() {
480            request_payload["tools"] = json!(
481                completion_request
482                    .tools
483                    .into_iter()
484                    .map(|tool| tool.into())
485                    .collect::<Vec<ToolDefinition>>()
486            );
487        }
488
489        tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
490
491        Ok(request_payload)
492    }
493}
494
495// ---------- CompletionModel Implementation ----------
496
497#[derive(Clone, Serialize, Deserialize, Debug)]
498pub struct StreamingCompletionResponse {
499    pub done_reason: Option<String>,
500    pub total_duration: Option<u64>,
501    pub load_duration: Option<u64>,
502    pub prompt_eval_count: Option<u64>,
503    pub prompt_eval_duration: Option<u64>,
504    pub eval_count: Option<u64>,
505    pub eval_duration: Option<u64>,
506}
507
508impl GetTokenUsage for StreamingCompletionResponse {
509    fn token_usage(&self) -> Option<crate::completion::Usage> {
510        let mut usage = crate::completion::Usage::new();
511        let input_tokens = self.prompt_eval_count.unwrap_or_default();
512        let output_tokens = self.eval_count.unwrap_or_default();
513        usage.input_tokens = input_tokens;
514        usage.output_tokens = output_tokens;
515        usage.total_tokens = input_tokens + output_tokens;
516
517        Some(usage)
518    }
519}
520
521impl completion::CompletionModel for CompletionModel {
522    type Response = CompletionResponse;
523    type StreamingResponse = StreamingCompletionResponse;
524
525    #[cfg_attr(feature = "worker", worker::send)]
526    async fn completion(
527        &self,
528        completion_request: CompletionRequest,
529    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
530        let request_payload = self.create_completion_request(completion_request)?;
531
532        let response = self
533            .client
534            .post("api/chat")?
535            .json(&request_payload)
536            .send()
537            .await
538            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
539        if response.status().is_success() {
540            let text = response
541                .text()
542                .await
543                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
544            tracing::debug!(target: "rig", "Ollama chat response: {}", text);
545            let chat_resp: CompletionResponse = serde_json::from_str(&text)
546                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
547            let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
548            Ok(conv)
549        } else {
550            let err_text = response
551                .text()
552                .await
553                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
554            Err(CompletionError::ProviderError(err_text))
555        }
556    }
557
558    #[cfg_attr(feature = "worker", worker::send)]
559    async fn stream(
560        &self,
561        request: CompletionRequest,
562    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
563    {
564        let mut request_payload = self.create_completion_request(request)?;
565        merge_inplace(&mut request_payload, json!({"stream": true}));
566
567        let mut event_source = self
568            .client
569            .post("api/chat")?
570            .json(&request_payload)
571            .eventsource()
572            .expect("Cloning request must succeed");
573
574        let stream = Box::pin(stream! {
575        while let Some(event_result) = event_source.next().await {
576            match event_result {
577                Ok(Event::Open) => {
578                    tracing::trace!("SSE connection opened");
579                    continue;
580                }
581
582                Ok(Event::Message(message)) => {
583                    let data_str = message.data.trim();
584
585                    let parsed = serde_json::from_str::<CompletionResponse>(data_str);
586                    let Ok(response) = parsed else {
587                        tracing::debug!("Couldn't parse SSE payload as CompletionResponse");
588                        continue;
589                    };
590
591                    match response.message {
592                        Message::Assistant { content, tool_calls, .. } => {
593                            if !content.is_empty() {
594                                yield Ok(RawStreamingChoice::Message(content));
595                            }
596
597                            for tool_call in tool_calls {
598                                let function = tool_call.function.clone();
599                                yield Ok(RawStreamingChoice::ToolCall {
600                                    id: "".to_string(),
601                                    name: function.name,
602                                    arguments: function.arguments,
603                                    call_id: None,
604                                });
605                            }
606                        }
607                        _ => continue,
608                    }
609
610                    if response.done {
611                        yield Ok(RawStreamingChoice::FinalResponse(
612                            StreamingCompletionResponse {
613                                total_duration: response.total_duration,
614                                load_duration: response.load_duration,
615                                prompt_eval_count: response.prompt_eval_count,
616                                prompt_eval_duration: response.prompt_eval_duration,
617                                eval_count: response.eval_count,
618                                eval_duration: response.eval_duration,
619                                done_reason: response.done_reason,
620                            }
621                        ));
622                    }
623                }
624
625                Err(reqwest_eventsource::Error::StreamEnded) => break,
626
627                Err(err) => {
628                    tracing::error!(?err, "SSE error");
629                    yield Err(CompletionError::ResponseError(err.to_string()));
630                    break;
631                }
632            };
633        }});
634
635        Ok(streaming::StreamingCompletionResponse::stream(stream))
636    }
637}
638
639// ---------- Tool Definition Conversion ----------
640
641/// Ollama-required tool definition format.
642#[derive(Clone, Debug, Deserialize, Serialize)]
643pub struct ToolDefinition {
644    #[serde(rename = "type")]
645    pub type_field: String, // Fixed as "function"
646    pub function: completion::ToolDefinition,
647}
648
649/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
650impl From<crate::completion::ToolDefinition> for ToolDefinition {
651    fn from(tool: crate::completion::ToolDefinition) -> Self {
652        ToolDefinition {
653            type_field: "function".to_owned(),
654            function: completion::ToolDefinition {
655                name: tool.name,
656                description: tool.description,
657                parameters: tool.parameters,
658            },
659        }
660    }
661}
662
663#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
664pub struct ToolCall {
665    // pub id: String,
666    #[serde(default, rename = "type")]
667    pub r#type: ToolType,
668    pub function: Function,
669}
670#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
671#[serde(rename_all = "lowercase")]
672pub enum ToolType {
673    #[default]
674    Function,
675}
676#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
677pub struct Function {
678    pub name: String,
679    pub arguments: Value,
680}
681
682// ---------- Provider Message Definition ----------
683
684#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
685#[serde(tag = "role", rename_all = "lowercase")]
686pub enum Message {
687    User {
688        content: String,
689        #[serde(skip_serializing_if = "Option::is_none")]
690        images: Option<Vec<String>>,
691        #[serde(skip_serializing_if = "Option::is_none")]
692        name: Option<String>,
693    },
694    Assistant {
695        #[serde(default)]
696        content: String,
697        #[serde(skip_serializing_if = "Option::is_none")]
698        thinking: Option<String>,
699        #[serde(skip_serializing_if = "Option::is_none")]
700        images: Option<Vec<String>>,
701        #[serde(skip_serializing_if = "Option::is_none")]
702        name: Option<String>,
703        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
704        tool_calls: Vec<ToolCall>,
705    },
706    System {
707        content: String,
708        #[serde(skip_serializing_if = "Option::is_none")]
709        images: Option<Vec<String>>,
710        #[serde(skip_serializing_if = "Option::is_none")]
711        name: Option<String>,
712    },
713    #[serde(rename = "tool")]
714    ToolResult {
715        #[serde(rename = "tool_name")]
716        name: String,
717        content: String,
718    },
719}
720
721/// -----------------------------
722/// Provider Message Conversions
723/// -----------------------------
724/// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
725/// (Only User and Assistant variants are supported.)
726impl TryFrom<crate::message::Message> for Vec<Message> {
727    type Error = crate::message::MessageError;
728    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
729        use crate::message::Message as InternalMessage;
730        match internal_msg {
731            InternalMessage::User { content, .. } => {
732                let (tool_results, other_content): (Vec<_>, Vec<_>) =
733                    content.into_iter().partition(|content| {
734                        matches!(content, crate::message::UserContent::ToolResult(_))
735                    });
736
737                if !tool_results.is_empty() {
738                    tool_results
739                        .into_iter()
740                        .map(|content| match content {
741                            crate::message::UserContent::ToolResult(
742                                crate::message::ToolResult { id, content, .. },
743                            ) => {
744                                // Ollama expects a single string for tool results, so we concatenate
745                                let content_string = content
746                                    .into_iter()
747                                    .map(|content| match content {
748                                        crate::message::ToolResultContent::Text(text) => text.text,
749                                        _ => "[Non-text content]".to_string(),
750                                    })
751                                    .collect::<Vec<_>>()
752                                    .join("\n");
753
754                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
755                                    name: id,
756                                    content: content_string,
757                                })
758                            }
759                            _ => unreachable!(),
760                        })
761                        .collect::<Result<Vec<_>, _>>()
762                } else {
763                    // Ollama requires separate text content and images array
764                    let (texts, images) = other_content.into_iter().fold(
765                        (Vec::new(), Vec::new()),
766                        |(mut texts, mut images), content| {
767                            match content {
768                                crate::message::UserContent::Text(crate::message::Text {
769                                    text,
770                                }) => texts.push(text),
771                                crate::message::UserContent::Image(crate::message::Image {
772                                    data,
773                                    ..
774                                }) => images.push(data),
775                                crate::message::UserContent::Document(
776                                    crate::message::Document { data, .. },
777                                ) => texts.push(data),
778                                _ => {} // Audio not supported by Ollama
779                            }
780                            (texts, images)
781                        },
782                    );
783
784                    Ok(vec![Message::User {
785                        content: texts.join(" "),
786                        images: if images.is_empty() {
787                            None
788                        } else {
789                            Some(
790                                images
791                                    .into_iter()
792                                    .map(|x| x.to_string())
793                                    .collect::<Vec<String>>(),
794                            )
795                        },
796                        name: None,
797                    }])
798                }
799            }
800            InternalMessage::Assistant { content, .. } => {
801                let mut thinking: Option<String> = None;
802                let (text_content, tool_calls) = content.into_iter().fold(
803                    (Vec::new(), Vec::new()),
804                    |(mut texts, mut tools), content| {
805                        match content {
806                            crate::message::AssistantContent::Text(text) => texts.push(text.text),
807                            crate::message::AssistantContent::ToolCall(tool_call) => {
808                                tools.push(tool_call)
809                            }
810                            crate::message::AssistantContent::Reasoning(
811                                crate::message::Reasoning { reasoning, .. },
812                            ) => {
813                                thinking =
814                                    Some(reasoning.first().cloned().unwrap_or(String::new()));
815                            }
816                        }
817                        (texts, tools)
818                    },
819                );
820
821                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
822                //  so either `content` or `tool_calls` will have some content.
823                Ok(vec![Message::Assistant {
824                    content: text_content.join(" "),
825                    thinking,
826                    images: None,
827                    name: None,
828                    tool_calls: tool_calls
829                        .into_iter()
830                        .map(|tool_call| tool_call.into())
831                        .collect::<Vec<_>>(),
832                }])
833            }
834        }
835    }
836}
837
838/// Conversion from provider Message to a completion message.
839/// This is needed so that responses can be converted back into chat history.
840impl From<Message> for crate::completion::Message {
841    fn from(msg: Message) -> Self {
842        match msg {
843            Message::User { content, .. } => crate::completion::Message::User {
844                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
845                    text: content,
846                })),
847            },
848            Message::Assistant {
849                content,
850                tool_calls,
851                ..
852            } => {
853                let mut assistant_contents =
854                    vec![crate::completion::message::AssistantContent::Text(Text {
855                        text: content,
856                    })];
857                for tc in tool_calls {
858                    assistant_contents.push(
859                        crate::completion::message::AssistantContent::tool_call(
860                            tc.function.name.clone(),
861                            tc.function.name,
862                            tc.function.arguments,
863                        ),
864                    );
865                }
866                crate::completion::Message::Assistant {
867                    id: None,
868                    content: OneOrMany::many(assistant_contents).unwrap(),
869                }
870            }
871            // System and ToolResult are converted to User message as needed.
872            Message::System { content, .. } => crate::completion::Message::User {
873                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
874                    text: content,
875                })),
876            },
877            Message::ToolResult { name, content } => crate::completion::Message::User {
878                content: OneOrMany::one(message::UserContent::tool_result(
879                    name,
880                    OneOrMany::one(message::ToolResultContent::text(content)),
881                )),
882            },
883        }
884    }
885}
886
887impl Message {
888    /// Constructs a system message.
889    pub fn system(content: &str) -> Self {
890        Message::System {
891            content: content.to_owned(),
892            images: None,
893            name: None,
894        }
895    }
896}
897
898// ---------- Additional Message Types ----------
899
900impl From<crate::message::ToolCall> for ToolCall {
901    fn from(tool_call: crate::message::ToolCall) -> Self {
902        Self {
903            r#type: ToolType::Function,
904            function: Function {
905                name: tool_call.function.name,
906                arguments: tool_call.function.arguments,
907            },
908        }
909    }
910}
911
/// A system content item: a type tag plus its text.
///
/// Both fields are private; build values with the `From<String>` or `FromStr` impls,
/// which always use the default (`Text`) type tag.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    // Serialized as `type`; defaults to `SystemContentType::Text` when absent on input.
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}
918
/// Type tag for [`SystemContent`]. `Text` is the only variant and the default.
///
/// Variants are serialized in lowercase, so `Text` appears as `"text"`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
925
926impl From<String> for SystemContent {
927    fn from(s: String) -> Self {
928        SystemContent {
929            r#type: SystemContentType::default(),
930            text: s,
931        }
932    }
933}
934
935impl FromStr for SystemContent {
936    type Err = std::convert::Infallible;
937    fn from_str(s: &str) -> Result<Self, Self::Err> {
938        Ok(SystemContent {
939            r#type: SystemContentType::default(),
940            text: s.to_string(),
941        })
942    }
943}
944
/// Plain-text assistant content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    /// The assistant's text.
    pub text: String,
}
949
950impl FromStr for AssistantContent {
951    type Err = std::convert::Infallible;
952    fn from_str(s: &str) -> Result<Self, Self::Err> {
953        Ok(AssistantContent { text: s.to_owned() })
954    }
955}
956
/// A user content item, serialized as an internally tagged object whose `type` field is
/// the lowercase variant name, e.g. `{"type": "text", "text": "..."}` or
/// `{"type": "image", "image_url": {...}}`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text.
    Text { text: String },
    /// An image, given as an [`ImageUrl`].
    Image { image_url: ImageUrl },
    // Audio variant removed as Ollama API does not support audio input.
}
964
965impl FromStr for UserContent {
966    type Err = std::convert::Infallible;
967    fn from_str(s: &str) -> Result<Self, Self::Err> {
968        Ok(UserContent::Text { text: s.to_owned() })
969    }
970}
971
/// Image reference used by [`UserContent::Image`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// The image URL string.
    /// NOTE(review): whether data URLs / base64 payloads are accepted is not shown here — confirm.
    pub url: String,
    /// Requested detail level; falls back to `ImageDetail::default()` when absent on input.
    #[serde(default)]
    pub detail: ImageDetail,
}
978
979// =================================================================
980// Tests
981// =================================================================
982
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Test deserialization and conversion for the /api/chat endpoint.
    // Both steps are synchronous, so a plain `#[test]` is enough (no async runtime needed).
    #[test]
    fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    // Test conversion from provider Message to completion Message.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // Take the first (and only) content item.
                let first_content = content.first();
                // The expected type is crate::completion::message::UserContent::Text wrapping a Text struct.
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    // Test conversion of an assistant message carrying both text and a tool call: the text
    // item comes first, followed by the tool call with its name and arguments preserved.
    #[test]
    fn test_assistant_message_conversion_with_tool_call() {
        let arguments = json!({
            "location": "San Francisco, CA",
            "format": "celsius"
        });
        let provider_msg = Message::Assistant {
            content: "Let me check the weather.".to_owned(),
            thinking: None,
            images: None,
            name: None,
            tool_calls: vec![ToolCall {
                r#type: ToolType::Function,
                function: Function {
                    name: "get_current_weather".to_owned(),
                    arguments: arguments.clone(),
                },
            }],
        };
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::Assistant { content, .. } => {
                let items: Vec<crate::completion::message::AssistantContent> =
                    content.into_iter().collect();
                assert_eq!(items.len(), 2, "Expected one text item and one tool call");
                match &items[0] {
                    crate::completion::message::AssistantContent::Text(text) => {
                        assert_eq!(text.text, "Let me check the weather.");
                    }
                    _ => panic!("Expected the first assistant item to be text"),
                }
                match &items[1] {
                    crate::completion::message::AssistantContent::ToolCall(tool_call) => {
                        assert_eq!(tool_call.function.name, "get_current_weather");
                        assert_eq!(tool_call.function.arguments, arguments);
                    }
                    _ => panic!("Expected the second assistant item to be a tool call"),
                }
            }
            _ => panic!("Expected an assistant message after conversion"),
        }
    }

    // Test conversion of internal tool definition to Ollama's ToolDefinition format.
    #[test]
    fn test_tool_definition_conversion() {
        // Internal tool definition from the completion module.
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        // Convert internal tool to Ollama's tool definition.
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        // Check JSON fields in parameters.
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }
}