rig/providers/
ollama.rs

//! Ollama API client and Rig integration
//!
//! # Example
//! ```rust
//! use rig::providers::ollama;
//!
//! // Create a new Ollama client (defaults to http://localhost:11434)
//! let client = ollama::Client::new();
//!
//! // Create a completion model interface using, for example, the "llama3.2" model
//! let comp_model = client.completion_model("llama3.2");
//!
//! let req = rig::completion::CompletionRequest {
//!     preamble: Some("You are now a humorous AI assistant.".to_owned()),
//!     chat_history: vec![],  // internal messages (if any)
//!     prompt: rig::message::Message::User {
//!         content: rig::one_or_many::OneOrMany::one(rig::message::UserContent::text("Please tell me why the sky is blue.")),
//!         name: None
//!     },
//!     temperature: 0.7,
//!     additional_params: None,
//!     tools: vec![],
//! };
//!
//! let response = comp_model.completion(req).await.unwrap();
//! println!("Ollama completion response: {:?}", response.choice);
//!
//! // Create an embedding interface using the "all-minilm" model
//! let emb_model = ollama::Client::new().embedding_model("all-minilm");
//! let docs = vec![
//!     "Why is the sky blue?".to_owned(),
//!     "Why is the grass green?".to_owned()
//! ];
//! let embeddings = emb_model.embed_texts(docs).await.unwrap();
//! println!("Embedding response: {:?}", embeddings);
//!
//! // Also create an agent and extractor if needed
//! let agent = client.agent("llama3.2");
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
//! ```
use async_stream::stream;
use futures::StreamExt;
use reqwest;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::{convert::TryFrom, str::FromStr};
use url::Url;

use crate::client::{
    ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient,
    VerifyError,
};
use crate::completion::{GetTokenUsage, Usage};
use crate::json_utils::merge_inplace;
use crate::streaming::RawStreamingChoice;
use crate::{
    Embed, OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
    impl_conversion_traits, json_utils, message,
    message::{ImageDetail, Text},
    streaming,
};
// ---------- Main Client ----------

/// Default Ollama server address (the port Ollama listens on out of the box).
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
66
/// Builder for [`Client`]; lets callers override the base URL and supply a
/// preconfigured HTTP client.
pub struct ClientBuilder<'a> {
    // Server base URL; defaults to `OLLAMA_API_BASE_URL`.
    base_url: &'a str,
    // Optional caller-supplied HTTP client; one is built on demand otherwise.
    http_client: Option<reqwest::Client>,
}
71
72impl<'a> ClientBuilder<'a> {
73    #[allow(clippy::new_without_default)]
74    pub fn new() -> Self {
75        Self {
76            base_url: OLLAMA_API_BASE_URL,
77            http_client: None,
78        }
79    }
80
81    pub fn base_url(mut self, base_url: &'a str) -> Self {
82        self.base_url = base_url;
83        self
84    }
85
86    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
87        self.http_client = Some(client);
88        self
89    }
90
91    pub fn build(self) -> Result<Client, ClientBuilderError> {
92        let http_client = if let Some(http_client) = self.http_client {
93            http_client
94        } else {
95            reqwest::Client::builder().build()?
96        };
97
98        Ok(Client {
99            base_url: Url::parse(self.base_url)
100                .map_err(|_| ClientBuilderError::InvalidProperty("base_url"))?,
101            http_client,
102        })
103    }
104}
105
/// Ollama HTTP client: a parsed base URL plus a shared reqwest client.
#[derive(Clone, Debug)]
pub struct Client {
    base_url: Url,
    http_client: reqwest::Client,
}
111
112impl Default for Client {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
impl Client {
    /// Create a new Ollama client builder.
    ///
    /// # Example
    /// ```
    /// use rig::providers::ollama;
    ///
    /// // Initialize the Ollama client
    /// let client = ollama::Client::builder()
    ///    .build();
    /// ```
    pub fn builder() -> ClientBuilder<'static> {
        ClientBuilder::new()
    }

    /// Create a new Ollama client. For more control, use the `builder` method.
    ///
    /// # Panics
    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
    pub fn new() -> Self {
        Self::builder().build().expect("Ollama client should build")
    }

    /// Build a POST request for `path` resolved against the configured base URL.
    pub(crate) fn post(&self, path: &str) -> Result<reqwest::RequestBuilder, url::ParseError> {
        let url = self.base_url.join(path)?;
        Ok(self.http_client.post(url))
    }

    /// Build a GET request for `path` resolved against the configured base URL.
    pub(crate) fn get(&self, path: &str) -> Result<reqwest::RequestBuilder, url::ParseError> {
        let url = self.base_url.join(path)?;
        Ok(self.http_client.get(url))
    }
}
151
152impl ProviderClient for Client {
153    fn from_env() -> Self
154    where
155        Self: Sized,
156    {
157        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
158        Self::builder().base_url(&api_base).build().unwrap()
159    }
160
161    fn from_val(input: crate::client::ProviderValue) -> Self {
162        let crate::client::ProviderValue::Simple(_) = input else {
163            panic!("Incorrect provider value type")
164        };
165
166        Self::new()
167    }
168}
169
impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

    /// Create a completion handle for the given Ollama model name.
    fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
177
impl EmbeddingsClient for Client {
    type EmbeddingModel = EmbeddingModel;
    /// Embedding model handle with unspecified (0) dimensionality.
    fn embedding_model(&self, model: &str) -> EmbeddingModel {
        EmbeddingModel::new(self.clone(), model, 0)
    }
    /// Embedding model handle reporting a caller-specified dimensionality.
    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
        EmbeddingModel::new(self.clone(), model, ndims)
    }
    /// Convenience: start an `EmbeddingsBuilder` on this model.
    fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
        EmbeddingsBuilder::new(self.embedding_model(model))
    }
}
190
191impl VerifyClient for Client {
192    #[cfg_attr(feature = "worker", worker::send)]
193    async fn verify(&self) -> Result<(), VerifyError> {
194        let response = self
195            .get("api/tags")
196            .expect("Failed to build request")
197            .send()
198            .await?;
199        match response.status() {
200            reqwest::StatusCode::OK => Ok(()),
201            _ => {
202                response.error_for_status()?;
203                Ok(())
204            }
205        }
206    }
207}
208
// Ollama exposes no transcription, image-generation, or audio-generation
// endpoints; this macro presumably generates "not supported" capability
// conversions for those traits — see `impl_conversion_traits!`'s definition.
impl_conversion_traits!(
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
214
// ---------- API Error and Response Structures ----------

/// Error body returned by the Ollama API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload or an API error, decided by shape (untagged:
/// serde tries `Ok(T)` first, then falls back to the error body).
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
228
// ---------- Embedding API ----------

/// Well-known Ollama embedding model names.
pub const ALL_MINILM: &str = "all-minilm";
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

/// Response body of Ollama's `/api/embed` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    /// One embedding vector per input document, in input order.
    pub embeddings: Vec<Vec<f64>>,
    // Optional diagnostics; durations are presumably nanoseconds per
    // Ollama's API docs — TODO confirm.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
245
/// Map an Ollama API error body onto an embedding provider error.
impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}
251
252impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
253    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
254        match value {
255            ApiResponse::Ok(response) => Ok(response),
256            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
257        }
258    }
259}
260
// ---------- Embedding Model ----------

/// Rig embedding model backed by an Ollama model name.
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    pub model: String,
    // Reported dimensionality; 0 when the caller did not specify one.
    ndims: usize,
}
269
270impl EmbeddingModel {
271    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
272        Self {
273            client,
274            model: model.to_owned(),
275            ndims,
276        }
277    }
278}
279
280impl embeddings::EmbeddingModel for EmbeddingModel {
281    const MAX_DOCUMENTS: usize = 1024;
282    fn ndims(&self) -> usize {
283        self.ndims
284    }
285    #[cfg_attr(feature = "worker", worker::send)]
286    async fn embed_texts(
287        &self,
288        documents: impl IntoIterator<Item = String>,
289    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
290        let docs: Vec<String> = documents.into_iter().collect();
291        let payload = json!({
292            "model": self.model,
293            "input": docs,
294        });
295        let response = self
296            .client
297            .post("api/embed")?
298            .json(&payload)
299            .send()
300            .await
301            .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
302        if response.status().is_success() {
303            let api_resp: EmbeddingResponse = response
304                .json()
305                .await
306                .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
307            if api_resp.embeddings.len() != docs.len() {
308                return Err(EmbeddingError::ResponseError(
309                    "Number of returned embeddings does not match input".into(),
310                ));
311            }
312            Ok(api_resp
313                .embeddings
314                .into_iter()
315                .zip(docs.into_iter())
316                .map(|(vec, document)| embeddings::Embedding { document, vec })
317                .collect())
318        } else {
319            Err(EmbeddingError::ProviderError(response.text().await?))
320        }
321    }
322}
323
// ---------- Completion API ----------

/// Well-known Ollama chat model names.
pub const LLAMA3_2: &str = "llama3.2";
pub const LLAVA: &str = "llava";
pub const MISTRAL: &str = "mistral";

/// One response object from Ollama's `/api/chat` endpoint. In streaming mode
/// one of these arrives per NDJSON line, with `done: true` on the final frame.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    /// True on the final (or only) response object.
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    // Optional diagnostics; durations are presumably nanoseconds per
    // Ollama's API docs — TODO confirm.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    /// Prompt token count (mapped to input tokens in `Usage`).
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    /// Generated token count (mapped to output tokens in `Usage`).
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;
    /// Convert a raw Ollama chat response into Rig's provider-agnostic form,
    /// extracting text/tool-call choices and token usage.
    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            // Process only if an assistant message is present.
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                // Add the assistant's text content if any.
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                // Process tool_calls following Ollama's chat response definition.
                // NOTE(review): the function name is passed twice — once as the
                // call id and once as the name — because Ollama returns no
                // tool-call ids; confirm against `AssistantContent::tool_call`'s
                // parameter order.
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                // Fails only when both `content` and `tool_calls` were empty.
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                // Rebuild the raw response for callers wanting provider detail.
                // The original message's `images`/`name` are not carried through
                // (they were discarded by the destructuring above).
                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                    },
                    raw_response,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}
418
// ---------- Completion Model ----------

/// Rig completion model backed by an Ollama chat model name.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Ollama model name, e.g. `"llama3.2"`.
    pub model: String,
}
426
427impl CompletionModel {
428    pub fn new(client: Client, model: &str) -> Self {
429        Self {
430            client,
431            model: model.to_owned(),
432        }
433    }
434
435    fn create_completion_request(
436        &self,
437        completion_request: CompletionRequest,
438    ) -> Result<Value, CompletionError> {
439        // Build up the order of messages (context, chat_history)
440        let mut partial_history = vec![];
441        if let Some(docs) = completion_request.normalized_documents() {
442            partial_history.push(docs);
443        }
444        partial_history.extend(completion_request.chat_history);
445
446        // Initialize full history with preamble (or empty if non-existent)
447        let mut full_history: Vec<Message> = completion_request
448            .preamble
449            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
450
451        // Convert and extend the rest of the history
452        full_history.extend(
453            partial_history
454                .into_iter()
455                .map(|msg| msg.try_into())
456                .collect::<Result<Vec<Vec<Message>>, _>>()?
457                .into_iter()
458                .flatten()
459                .collect::<Vec<Message>>(),
460        );
461
462        // Convert internal prompt into a provider Message
463        let options = if let Some(extra) = completion_request.additional_params {
464            json_utils::merge(
465                json!({ "temperature": completion_request.temperature }),
466                extra,
467            )
468        } else {
469            json!({ "temperature": completion_request.temperature })
470        };
471
472        let mut request_payload = json!({
473            "model": self.model,
474            "messages": full_history,
475            "options": options,
476            "stream": false,
477        });
478        if !completion_request.tools.is_empty() {
479            request_payload["tools"] = json!(
480                completion_request
481                    .tools
482                    .into_iter()
483                    .map(|tool| tool.into())
484                    .collect::<Vec<ToolDefinition>>()
485            );
486        }
487
488        tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
489
490        Ok(request_payload)
491    }
492}
493
// ---------- CompletionModel Implementation ----------

/// Final-frame metadata captured from a streaming chat session
/// (everything except the incremental message content).
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
506
507impl GetTokenUsage for StreamingCompletionResponse {
508    fn token_usage(&self) -> Option<crate::completion::Usage> {
509        let mut usage = crate::completion::Usage::new();
510        let input_tokens = self.prompt_eval_count.unwrap_or_default();
511        let output_tokens = self.eval_count.unwrap_or_default();
512        usage.input_tokens = input_tokens;
513        usage.output_tokens = output_tokens;
514        usage.total_tokens = input_tokens + output_tokens;
515
516        Some(usage)
517    }
518}
519
520impl completion::CompletionModel for CompletionModel {
521    type Response = CompletionResponse;
522    type StreamingResponse = StreamingCompletionResponse;
523
524    #[cfg_attr(feature = "worker", worker::send)]
525    async fn completion(
526        &self,
527        completion_request: CompletionRequest,
528    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
529        let request_payload = self.create_completion_request(completion_request)?;
530
531        let response = self
532            .client
533            .post("api/chat")?
534            .json(&request_payload)
535            .send()
536            .await
537            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
538        if response.status().is_success() {
539            let text = response
540                .text()
541                .await
542                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
543            tracing::debug!(target: "rig", "Ollama chat response: {}", text);
544            let chat_resp: CompletionResponse = serde_json::from_str(&text)
545                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
546            let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
547            Ok(conv)
548        } else {
549            let err_text = response
550                .text()
551                .await
552                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
553            Err(CompletionError::ProviderError(err_text))
554        }
555    }
556
557    #[cfg_attr(feature = "worker", worker::send)]
558    async fn stream(
559        &self,
560        request: CompletionRequest,
561    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
562    {
563        let mut request_payload = self.create_completion_request(request)?;
564        merge_inplace(&mut request_payload, json!({"stream": true}));
565
566        let response = self
567            .client
568            .post("api/chat")?
569            .json(&request_payload)
570            .send()
571            .await
572            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
573
574        if !response.status().is_success() {
575            let err_text = response
576                .text()
577                .await
578                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
579            return Err(CompletionError::ProviderError(err_text));
580        }
581
582        let stream = Box::pin(stream! {
583            let mut stream = response.bytes_stream();
584            while let Some(chunk_result) = stream.next().await {
585                let chunk = match chunk_result {
586                    Ok(c) => c,
587                    Err(e) => {
588                        yield Err(CompletionError::from(e));
589                        break;
590                    }
591                };
592
593                let text = match String::from_utf8(chunk.to_vec()) {
594                    Ok(t) => t,
595                    Err(e) => {
596                        yield Err(CompletionError::ResponseError(e.to_string()));
597                        break;
598                    }
599                };
600
601
602                for line in text.lines() {
603                    let line = line.to_string();
604
605                    let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
606                        continue;
607                    };
608
609                    match response.message {
610                        Message::Assistant{ content, tool_calls, .. } => {
611                            if !content.is_empty() {
612                                yield Ok(RawStreamingChoice::Message(content))
613                            }
614
615                            for tool_call in tool_calls.iter() {
616                                let function = tool_call.function.clone();
617
618                                yield Ok(RawStreamingChoice::ToolCall {
619                                    id: "".to_string(),
620                                    name: function.name,
621                                    arguments: function.arguments,
622                                    call_id: None
623                                });
624                            }
625                        }
626                        _ => {
627                            continue;
628                        }
629                    }
630
631                    if response.done {
632                        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
633                            total_duration: response.total_duration,
634                            load_duration: response.load_duration,
635                            prompt_eval_count: response.prompt_eval_count,
636                            prompt_eval_duration: response.prompt_eval_duration,
637                            eval_count: response.eval_count,
638                            eval_duration: response.eval_duration,
639                            done_reason: response.done_reason,
640                        }));
641                    }
642                }
643            }
644        });
645
646        Ok(streaming::StreamingCompletionResponse::stream(stream))
647    }
648}
649
// ---------- Tool Definition Conversion ----------

/// Ollama-required tool definition format.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Always serialized as `"function"`.
    #[serde(rename = "type")]
    pub type_field: String, // Fixed as "function"
    pub function: completion::ToolDefinition,
}
659
660/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
661impl From<crate::completion::ToolDefinition> for ToolDefinition {
662    fn from(tool: crate::completion::ToolDefinition) -> Self {
663        ToolDefinition {
664            type_field: "function".to_owned(),
665            function: completion::ToolDefinition {
666                name: tool.name,
667                description: tool.description,
668                parameters: tool.parameters,
669            },
670        }
671    }
672}
673
/// A tool invocation as produced by Ollama's chat endpoint.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    // Ollama does not return a tool-call id (unlike OpenAI-style APIs).
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Tool-call kind; only `function` exists today.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// The function name and JSON arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
692
// ---------- Provider Message Definition ----------

/// Chat message in Ollama's wire format; `role` is the serde tag.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        // Attached images (presumably base64-encoded per Ollama's API —
        // TODO confirm); omitted from JSON when absent.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        // Defaults to empty when the response carries only tool calls.
        #[serde(default)]
        content: String,
        // Model "thinking" text from reasoning-capable models.
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // `null` in the payload deserializes to an empty vec.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Tool output sent back to the model (`role: "tool"` on the wire).
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
731
/// -----------------------------
/// Provider Message Conversions
/// -----------------------------
/// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
/// (Only User and Assistant variants are supported.)
///
/// A single internal message may fan out into several provider messages
/// (e.g. one `role: tool` message per tool result), hence `Vec<Message>`.
impl TryFrom<crate::message::Message> for Vec<Message> {
    type Error = crate::message::MessageError;
    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;
        match internal_msg {
            InternalMessage::User { content, .. } => {
                // Split tool results from the rest; Ollama represents tool
                // results as separate `role: tool` messages.
                let (tool_results, other_content): (Vec<_>, Vec<_>) =
                    content.into_iter().partition(|content| {
                        matches!(content, crate::message::UserContent::ToolResult(_))
                    });

                // NOTE(review): when a user message mixes tool results with
                // other content, the non-tool content is dropped by this
                // branch — verify callers never rely on mixed user messages.
                if !tool_results.is_empty() {
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            crate::message::UserContent::ToolResult(
                                crate::message::ToolResult { id, content, .. },
                            ) => {
                                // Ollama expects a single string for tool results, so we concatenate
                                let content_string = content
                                    .into_iter()
                                    .map(|content| match content {
                                        crate::message::ToolResultContent::Text(text) => text.text,
                                        _ => "[Non-text content]".to_string(),
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n");

                                // The tool-call id doubles as the tool name.
                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
                                    name: id,
                                    content: content_string,
                                })
                            }
                            // partition() above guarantees only ToolResult here.
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    // Ollama requires separate text content and images array
                    let (texts, images) = other_content.into_iter().fold(
                        (Vec::new(), Vec::new()),
                        |(mut texts, mut images), content| {
                            match content {
                                crate::message::UserContent::Text(crate::message::Text {
                                    text,
                                }) => texts.push(text),
                                // Image data passed through verbatim.
                                crate::message::UserContent::Image(crate::message::Image {
                                    data,
                                    ..
                                }) => images.push(data),
                                // Documents are folded into the text content.
                                crate::message::UserContent::Document(
                                    crate::message::Document { data, .. },
                                ) => texts.push(data),
                                _ => {} // Audio not supported by Ollama
                            }
                            (texts, images)
                        },
                    );

                    Ok(vec![Message::User {
                        content: texts.join(" "),
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(images)
                        },
                        name: None,
                    }])
                }
            }
            InternalMessage::Assistant { content, .. } => {
                let mut thinking: Option<String> = None;
                let (text_content, tool_calls) = content.into_iter().fold(
                    (Vec::new(), Vec::new()),
                    |(mut texts, mut tools), content| {
                        match content {
                            crate::message::AssistantContent::Text(text) => texts.push(text.text),
                            crate::message::AssistantContent::ToolCall(tool_call) => {
                                tools.push(tool_call)
                            }
                            // NOTE(review): only the first string of each
                            // Reasoning block survives, and a later block
                            // overwrites an earlier one — confirm multi-part
                            // reasoning is not expected here.
                            crate::message::AssistantContent::Reasoning(
                                crate::message::Reasoning { reasoning, .. },
                            ) => {
                                thinking =
                                    Some(reasoning.first().cloned().unwrap_or(String::new()));
                            }
                        }
                        (texts, tools)
                    },
                );

                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
                //  so either `content` or `tool_calls` will have some content.
                Ok(vec![Message::Assistant {
                    content: text_content.join(" "),
                    thinking,
                    images: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}
843
844/// Conversion from provider Message to a completion message.
845/// This is needed so that responses can be converted back into chat history.
846impl From<Message> for crate::completion::Message {
847    fn from(msg: Message) -> Self {
848        match msg {
849            Message::User { content, .. } => crate::completion::Message::User {
850                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
851                    text: content,
852                })),
853            },
854            Message::Assistant {
855                content,
856                tool_calls,
857                ..
858            } => {
859                let mut assistant_contents =
860                    vec![crate::completion::message::AssistantContent::Text(Text {
861                        text: content,
862                    })];
863                for tc in tool_calls {
864                    assistant_contents.push(
865                        crate::completion::message::AssistantContent::tool_call(
866                            tc.function.name.clone(),
867                            tc.function.name,
868                            tc.function.arguments,
869                        ),
870                    );
871                }
872                crate::completion::Message::Assistant {
873                    id: None,
874                    content: OneOrMany::many(assistant_contents).unwrap(),
875                }
876            }
877            // System and ToolResult are converted to User message as needed.
878            Message::System { content, .. } => crate::completion::Message::User {
879                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
880                    text: content,
881                })),
882            },
883            Message::ToolResult { name, content } => crate::completion::Message::User {
884                content: OneOrMany::one(message::UserContent::tool_result(
885                    name,
886                    OneOrMany::one(message::ToolResultContent::text(content)),
887                )),
888            },
889        }
890    }
891}
892
893impl Message {
894    /// Constructs a system message.
895    pub fn system(content: &str) -> Self {
896        Message::System {
897            content: content.to_owned(),
898            images: None,
899            name: None,
900        }
901    }
902}
903
904// ---------- Additional Message Types ----------
905
906impl From<crate::message::ToolCall> for ToolCall {
907    fn from(tool_call: crate::message::ToolCall) -> Self {
908        Self {
909            r#type: ToolType::Function,
910            function: Function {
911                name: tool_call.function.name,
912                arguments: tool_call.function.arguments,
913            },
914        }
915    }
916}
917
/// Typed text payload for a system message.
///
/// Fields are crate-private; build values via the `From<String>` / `FromStr`
/// impls declared just below in this file.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    // Content discriminator; `#[serde(default)]` lets it be omitted in JSON.
    #[serde(default)]
    r#type: SystemContentType,
    // The system prompt text itself.
    text: String,
}
924
/// Discriminator for `SystemContent`; only plain text is supported.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    /// Plain-text content (serialized as `"text"`).
    #[default]
    Text,
}
931
932impl From<String> for SystemContent {
933    fn from(s: String) -> Self {
934        SystemContent {
935            r#type: SystemContentType::default(),
936            text: s,
937        }
938    }
939}
940
941impl FromStr for SystemContent {
942    type Err = std::convert::Infallible;
943    fn from_str(s: &str) -> Result<Self, Self::Err> {
944        Ok(SystemContent {
945            r#type: SystemContentType::default(),
946            text: s.to_string(),
947        })
948    }
949}
950
/// Plain-text content of an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    /// The assistant's reply text.
    pub text: String,
}
955
956impl FromStr for AssistantContent {
957    type Err = std::convert::Infallible;
958    fn from_str(s: &str) -> Result<Self, Self::Err> {
959        Ok(AssistantContent { text: s.to_owned() })
960    }
961}
962
/// Content of a user message: either plain text or an image reference.
/// Serialized with an adjacent `"type"` tag in lowercase.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text (`"type": "text"`).
    Text { text: String },
    /// An image by URL (`"type": "image"`).
    Image { image_url: ImageUrl },
    // Audio variant removed as Ollama API does not support audio input.
}
970
971impl FromStr for UserContent {
972    type Err = std::convert::Infallible;
973    fn from_str(s: &str) -> Result<Self, Self::Err> {
974        Ok(UserContent::Text { text: s.to_owned() })
975    }
976}
977
/// An image reference carried inside `UserContent::Image`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// Image location — presumably a URL or data URI; TODO confirm what the
    /// Ollama API accepts here.
    pub url: String,
    /// Requested level of detail; `#[serde(default)]` falls back to
    /// `ImageDetail::default()` when the field is absent in JSON.
    #[serde(default)]
    pub detail: ImageDetail,
}
984
985// =================================================================
986// Tests
987// =================================================================
988
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Test deserialization and conversion for the /api/chat endpoint.
    // Plain #[test]: nothing in the body awaits, so the #[tokio::test]
    // async runtime previously used here was needless overhead.
    #[test]
    fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        // Deserialize the provider response, then convert to the generic
        // completion response; both steps must succeed.
        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    // Test conversion from provider Message to completion Message.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // `OneOrMany::first()` yields the first (and here, only) element.
                let first_content = content.first();
                // The expected type is crate::completion::message::UserContent::Text wrapping a Text struct.
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    // Test conversion of internal tool definition to Ollama's ToolDefinition format.
    #[test]
    fn test_tool_definition_conversion() {
        // Internal tool definition from the completion module.
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        // Convert internal tool to Ollama's tool definition and check that the
        // name, description, and parameter schema all survived the conversion.
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        // Check JSON fields in parameters.
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }
}