
rig/providers/ollama.rs

//! Ollama API client and Rig integration
//!
//! # Example
//! ```rust,ignore
//! use rig::client::{Nothing, CompletionClient};
//! use rig::completion::Prompt;
//! use rig::providers::ollama;
//!
//! // Create a new Ollama client (defaults to http://localhost:11434)
//! // Ollama requires no API key, so we pass the `Nothing` struct
//! let client: ollama::Client = ollama::Client::new(Nothing).unwrap();
//!
//! // Create an agent with a preamble
//! let comedian_agent = client
//!     .agent("qwen2.5:14b")
//!     .preamble("You are a comedian here to entertain the user using humour and jokes.")
//!     .build();
//!
//! // Prompt the agent and print the response
//! let response = comedian_agent.prompt("Entertain me!").await?;
//! println!("{response}");
//!
//! // Create an embedding model using the "all-minilm" model
//! let emb_model = client.embedding_model("all-minilm", 384);
//! let embeddings = emb_model.embed_texts(vec![
//!     "Why is the sky blue?".to_owned(),
//!     "Why is the grass green?".to_owned()
//! ]).await?;
//! println!("Embedding response: {:?}", embeddings);
//!
//! // Create an extractor if needed
//! let extractor = client.extractor::<serde_json::Value>("llama3.2").build();
//! ```
use crate::client::{
    self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
};
use crate::completion::{GetTokenUsage, Usage};
use crate::http_client::{self, HttpClientExt};
use crate::message::DocumentSourceKind;
use crate::streaming::RawStreamingChoice;
use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils, message,
    message::{ImageDetail, Text},
    streaming,
};
use async_stream::try_stream;
use bytes::Bytes;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::{convert::TryFrom, str::FromStr};
use tracing::info_span;
use tracing_futures::Instrument;

// ---------- Main Client ----------

const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";

#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;

#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaBuilder;

impl Provider for OllamaExt {
    type Builder = OllamaBuilder;

    const VERIFY_PATH: &'static str = "api/tags";

    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}

impl<H> Capabilities<H> for OllamaExt {
    type Completion = Capable<CompletionModel<H>>;
    type Transcription = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;

    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for OllamaExt {}

impl ProviderBuilder for OllamaBuilder {
    type Output = OllamaExt;
    type ApiKey = Nothing;

    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;
}

pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, Nothing, H>;

impl ProviderClient for Client {
    type Input = Nothing;

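    /// Creates a client from the `OLLAMA_API_BASE_URL` environment variable,
    /// panicking if it is unset. A minimal sketch (the URL value is
    /// illustrative):
    /// ```rust,ignore
    /// // export OLLAMA_API_BASE_URL=http://localhost:11434
    /// let client = ollama::Client::from_env();
    /// ```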
    fn from_env() -> Self {
        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");

        Self::builder()
            .api_key(Nothing)
            .base_url(&api_base)
            .build()
            .unwrap()
    }

    fn from_val(_: Self::Input) -> Self {
        Self::builder().api_key(Nothing).build().unwrap()
    }
}

// ---------- API Error and Response Structures ----------

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

// ---------- Embedding API ----------

pub const ALL_MINILM: &str = "all-minilm";
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    pub embeddings: Vec<Vec<f64>>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

// ---------- Embedding Model ----------

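/// Embedding model backed by Ollama's `/api/embed` endpoint.
///
/// `ndims` is declared by the caller rather than validated against the model.
/// A minimal sketch, mirroring the module-level example:
/// ```rust,ignore
/// let emb_model = client.embedding_model(ollama::ALL_MINILM, 384);
/// let embeddings = emb_model.embed_texts(vec!["hello".to_owned()]).await?;
/// ```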
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    ndims: usize,
}

impl<T> EmbeddingModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }

    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        // Ollama embedding models require an explicit dimension count, so this
        // panics if `dims` is `None`.
        Self::new(client.clone(), model, dims.unwrap())
    }

    const MAX_DOCUMENTS: usize = 1024;

    fn ndims(&self) -> usize {
        self.ndims
    }

    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let docs: Vec<String> = documents.into_iter().collect();

        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": docs
        }))?;

        let req = self
            .client
            .post("api/embed")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if !response.status().is_success() {
            let text = http_client::text(response).await?;
            return Err(EmbeddingError::ProviderError(text));
        }

        let bytes: Vec<u8> = response.into_body().await?;

        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;

        if api_resp.embeddings.len() != docs.len() {
            return Err(EmbeddingError::ResponseError(
                "Number of returned embeddings does not match input".into(),
            ));
        }

        Ok(api_resp
            .embeddings
            .into_iter()
            .zip(docs.into_iter())
            .map(|(vec, document)| embeddings::Embedding { document, vec })
            .collect())
    }
}

// ---------- Completion API ----------

pub const LLAMA3_2: &str = "llama3.2";
pub const LLAVA: &str = "llava";
pub const MISTRAL: &str = "mistral";

#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            // Process only if an assistant message is present.
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                // Add the assistant's text content if any.
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                // Process tool calls following Ollama's chat response definition.
                // Ollama does not return tool-call ids, so the function name
                // doubles as the id.
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                        cached_input_tokens: 0,
                    },
                    raw_response,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}

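/// Request body for Ollama's `/api/chat` endpoint.
///
/// There is no dedicated builder option for Ollama's `think` flag; it is
/// lifted out of `additional_params` when this request is built. A hedged
/// sketch, assuming a completion builder with an `additional_params` method:
/// ```rust,ignore
/// let response = agent
///     .completion("Why is the sky blue?", vec![])
///     .await?
///     .additional_params(serde_json::json!({ "think": true }))
///     .send()
///     .await?;
/// ```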
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OllamaCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    pub stream: bool,
    think: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    options: serde_json::Value,
}

impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.tool_choice.is_some() {
            tracing::warn!("`tool_choice` is not supported by Ollama and will be ignored");
        }

        // Build up the order of messages (context, chat_history, prompt)
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        // Add preamble to chat history (if available)
        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let mut think = false;

        // TODO: Fix this up to include the full range of ollama options
        let options = if let Some(mut extra) = req.additional_params {
            // Lift `think` out of the extra params so it is sent as a
            // top-level field rather than inside `options`.
            if let Some(think_value) = extra.as_object_mut().and_then(|obj| obj.remove("think")) {
                think = think_value.as_bool().ok_or_else(|| {
                    CompletionError::RequestError("`think` must be a bool".into())
                })?;
            }
            json_utils::merge(json!({ "temperature": req.temperature }), extra)
        } else {
            json!({ "temperature": req.temperature })
        };

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            max_tokens: req.max_tokens,
            stream: false,
            think,
            tools: req
                .tools
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            options,
        })
    }
}

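/// Completion model backed by Ollama's `/api/chat` endpoint.
///
/// A minimal prompting sketch, mirroring the module-level example:
/// ```rust,ignore
/// let agent = client.agent("llama3.2").build();
/// let answer = agent.prompt("Why is the sky blue?").await?;
/// ```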
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_owned(),
        }
    }
}

// ---------- CompletionModel Implementation ----------

#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        let input_tokens = self.prompt_eval_count.unwrap_or_default();
        let output_tokens = self.eval_count.unwrap_or_default();
        usage.input_tokens = input_tokens;
        usage.output_tokens = output_tokens;
        usage.total_tokens = input_tokens + output_tokens;

        Some(usage)
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into().as_str())
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "ollama",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request = OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Ollama completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("api/chat")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let async_block = async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ));
            }

            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
            let span = tracing::Span::current();
            // Record under the field name declared on the span above.
            span.record("gen_ai.response.model", &response.model);
            span.record(
                "gen_ai.usage.input_tokens",
                response.prompt_eval_count.unwrap_or_default(),
            );
            span.record(
                "gen_ai.usage.output_tokens",
                response.eval_count.unwrap_or_default(),
            );

            if tracing::enabled!(tracing::Level::TRACE) {
                tracing::trace!(target: "rig::completions",
                    "Ollama completion response: {}",
                    serde_json::to_string_pretty(&response)?
                );
            }

            let response: completion::CompletionResponse<CompletionResponse> =
                response.try_into()?;

            Ok(response)
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
    {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "ollama",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &request.preamble);

        let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
        request.stream = true;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Ollama streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("api/chat")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let response = self.client.send_streaming(req).await?;
        let status = response.status();
        let mut byte_stream = response.into_body();

        if !status.is_success() {
            return Err(CompletionError::ProviderError(format!(
                "Got error status code trying to send a request to Ollama: {status}"
            )));
        }

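        // Ollama streams NDJSON: one complete JSON object per line. The
        // parsing below assumes each received chunk contains whole lines,
        // i.e. a JSON object is never split across chunk boundaries.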
        let stream = try_stream! {
            let span = tracing::Span::current();
            let mut tool_calls_final = Vec::new();
            let mut text_response = String::new();
            let mut thinking_response = String::new();

            while let Some(chunk) = byte_stream.next().await {
                let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;

                for line in bytes.split(|&b| b == b'\n') {
                    if line.is_empty() {
                        continue;
                    }

                    tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));

                    let response: CompletionResponse = serde_json::from_slice(line)?;

                    if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
                        if let Some(thinking_content) = thinking && !thinking_content.is_empty() {
                            thinking_response += &thinking_content;
                            yield RawStreamingChoice::ReasoningDelta {
                                id: None,
                                reasoning: thinking_content,
                            };
                        }

                        if !content.is_empty() {
                            text_response += &content;
                            yield RawStreamingChoice::Message(content);
                        }

                        for tool_call in tool_calls {
                            tool_calls_final.push(tool_call.clone());
                            // Ollama does not provide tool-call ids, hence the empty id.
                            yield RawStreamingChoice::ToolCall(
                                crate::streaming::RawStreamingToolCall::new(String::new(), tool_call.function.name, tool_call.function.arguments)
                            );
                        }
                    }

                    if response.done {
                        span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
                        span.record("gen_ai.usage.output_tokens", response.eval_count);
                        let message = Message::Assistant {
                            content: text_response.clone(),
                            thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
                            images: None,
                            name: None,
                            tool_calls: tool_calls_final.clone()
                        };
                        span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
                        yield RawStreamingChoice::FinalResponse(
                            StreamingCompletionResponse {
                                total_duration: response.total_duration,
                                load_duration: response.load_duration,
                                prompt_eval_count: response.prompt_eval_count,
                                prompt_eval_duration: response.prompt_eval_duration,
                                eval_count: response.eval_count,
                                eval_duration: response.eval_duration,
                                done_reason: response.done_reason,
                            }
                        );
                        break;
                    }
                }
            }
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}

// ---------- Tool Definition Conversion ----------

/// Ollama-required tool definition format.
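///
/// The conversion below wraps the internal definition under a fixed
/// `"type": "function"` tag, as exercised by the tests at the bottom of this
/// file:
/// ```rust,ignore
/// let ollama_tool: ToolDefinition = internal_tool.into();
/// assert_eq!(ollama_tool.type_field, "function");
/// ```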
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    #[serde(rename = "type")]
    pub type_field: String, // Fixed as "function"
    pub function: completion::ToolDefinition,
}

/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
impl From<crate::completion::ToolDefinition> for ToolDefinition {
    fn from(tool: crate::completion::ToolDefinition) -> Self {
        ToolDefinition {
            type_field: "function".to_owned(),
            function: tool,
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}

// ---------- Provider Message Definition ----------

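/// Wire-format chat message for Ollama's `/api/chat` endpoint, tagged by `role`.
///
/// For example, a system message serializes to a bare role/content pair,
/// since `images` and `name` are skipped when `None`:
/// ```rust,ignore
/// let msg = Message::system("You are terse.");
/// // serde_json::to_value(&msg)? == json!({ "role": "system", "content": "You are terse." })
/// ```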
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        #[serde(default)]
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}

// ---------- Provider Message Conversions ----------

/// Conversion from an internal Rig message (`crate::message::Message`) to provider messages.
/// (Only User and Assistant variants are supported.)
impl TryFrom<crate::message::Message> for Vec<Message> {
    type Error = crate::message::MessageError;

    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;
        match internal_msg {
            InternalMessage::User { content, .. } => {
                let (tool_results, other_content): (Vec<_>, Vec<_>) =
                    content.into_iter().partition(|content| {
                        matches!(content, crate::message::UserContent::ToolResult(_))
                    });

                if !tool_results.is_empty() {
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            crate::message::UserContent::ToolResult(
                                crate::message::ToolResult { id, content, .. },
                            ) => {
                                // Ollama expects a single string for tool results, so we concatenate
                                let content_string = content
                                    .into_iter()
                                    .map(|content| match content {
                                        crate::message::ToolResultContent::Text(text) => text.text,
                                        _ => "[Non-text content]".to_string(),
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n");

                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
                                    name: id,
                                    content: content_string,
                                })
                            }
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    // Ollama requires separate text content and images array
                    let (texts, images) = other_content.into_iter().fold(
                        (Vec::new(), Vec::new()),
                        |(mut texts, mut images), content| {
                            match content {
                                crate::message::UserContent::Text(crate::message::Text {
                                    text,
                                }) => texts.push(text),
                                crate::message::UserContent::Image(crate::message::Image {
                                    data: DocumentSourceKind::Base64(data),
                                    ..
                                }) => images.push(data),
                                crate::message::UserContent::Document(
                                    crate::message::Document {
                                        data:
                                            DocumentSourceKind::Base64(data)
                                            | DocumentSourceKind::String(data),
                                        ..
                                    },
                                ) => texts.push(data),
                                _ => {} // Audio not supported by Ollama
                            }
                            (texts, images)
                        },
                    );

                    Ok(vec![Message::User {
                        content: texts.join(" "),
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(
                                images
                                    .into_iter()
                                    .map(|x| x.to_string())
                                    .collect::<Vec<String>>(),
                            )
                        },
                        name: None,
                    }])
                }
            }
            InternalMessage::Assistant { content, .. } => {
                let mut thinking: Option<String> = None;
                let mut text_content = Vec::new();
                let mut tool_calls = Vec::new();

                for content in content.into_iter() {
                    match content {
                        crate::message::AssistantContent::Text(text) => {
                            text_content.push(text.text)
                        }
                        crate::message::AssistantContent::ToolCall(tool_call) => {
                            tool_calls.push(tool_call)
                        }
                        crate::message::AssistantContent::Reasoning(
                            crate::message::Reasoning { reasoning, .. },
                        ) => {
                            // Ollama's `thinking` is a single string, so only the
                            // first reasoning block is forwarded.
                            thinking = Some(reasoning.first().cloned().unwrap_or_default());
                        }
                        crate::message::AssistantContent::Image(_) => {
                            return Err(crate::message::MessageError::ConversionError(
                                "Ollama currently doesn't support images.".into(),
                            ));
                        }
                    }
                }

                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
                // so either `content` or `tool_calls` will have some content.
                Ok(vec![Message::Assistant {
                    content: text_content.join(" "),
                    thinking,
                    images: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}

/// Conversion from provider Message to a completion message.
/// This is needed so that responses can be converted back into chat history.
impl From<Message> for crate::completion::Message {
    fn from(msg: Message) -> Self {
        match msg {
            Message::User { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut assistant_contents =
                    vec![crate::completion::message::AssistantContent::Text(Text {
                        text: content,
                    })];
                for tc in tool_calls {
                    assistant_contents.push(
                        crate::completion::message::AssistantContent::tool_call(
                            tc.function.name.clone(),
                            tc.function.name,
                            tc.function.arguments,
                        ),
                    );
                }
                crate::completion::Message::Assistant {
                    id: None,
                    // Safe to unwrap: `assistant_contents` always holds at least the text element.
                    content: OneOrMany::many(assistant_contents).unwrap(),
                }
            }
            // System and ToolResult are converted to User messages as needed.
            Message::System { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::ToolResult { name, content } => crate::completion::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    name,
                    OneOrMany::one(message::ToolResultContent::text(content)),
                )),
            },
        }
    }
}

impl Message {
    /// Constructs a system message.
    pub fn system(content: &str) -> Self {
        Message::System {
            content: content.to_owned(),
            images: None,
            name: None,
        }
    }
}

// ---------- Additional Message Types ----------

impl From<crate::message::ToolCall> for ToolCall {
    fn from(tool_call: crate::message::ToolCall) -> Self {
        Self {
            r#type: ToolType::Function,
            function: Function {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}

impl From<String> for SystemContent {
    fn from(s: String) -> Self {
        SystemContent {
            r#type: SystemContentType::default(),
            text: s,
        }
    }
}

impl FromStr for SystemContent {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(SystemContent {
            r#type: SystemContentType::default(),
            text: s.to_string(),
        })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}

impl FromStr for AssistantContent {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(AssistantContent { text: s.to_owned() })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
    // Audio variant removed as the Ollama API does not support audio input.
}

impl FromStr for UserContent {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(UserContent::Text { text: s.to_owned() })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}

// =================================================================
// Tests
// =================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Test deserialization and conversion for the /api/chat endpoint.
    #[tokio::test]
    async fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    // Test conversion from provider Message to completion Message.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // `OneOrMany::first()` returns the first element.
                let first_content = content.first();
                // The expected type is crate::completion::message::UserContent::Text wrapping a Text struct.
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    // Test conversion of internal tool definition to Ollama's ToolDefinition format.
    #[test]
    fn test_tool_definition_conversion() {
        // Internal tool definition from the completion module.
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        // Convert internal tool to Ollama's tool definition.
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        // Check JSON fields in parameters.
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }

    // Test deserialization of chat response with thinking content.
    #[tokio::test]
    async fn test_chat_completion_with_thinking() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The answer is 42.",
                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        // Verify thinking field is present
        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "Let me think about this carefully. The question asks for the meaning of life..."
            );
            assert_eq!(content, "The answer is 42.");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test deserialization of chat response without thinking content.
    #[tokio::test]
    async fn test_chat_completion_without_thinking() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Hello!",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        // Verify thinking field is None when not provided
        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert!(thinking.is_none());
            assert_eq!(content, "Hello!");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test deserialization of streaming response with thinking content.
    #[test]
    fn test_streaming_response_with_thinking() {
        let sample_chunk = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "",
                "thinking": "Analyzing the problem...",
                "images": null,
                "tool_calls": []
            },
            "done": false
        });

        let chunk: CompletionResponse =
            serde_json::from_value(sample_chunk).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chunk.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
            assert_eq!(content, "");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test message conversion with thinking content.
    #[test]
    fn test_message_conversion_with_thinking() {
        // Create an internal message with reasoning content
        let reasoning_content = crate::message::Reasoning {
            id: None,
            reasoning: vec!["Step 1: Consider the problem".to_string()],
            signature: None,
        };

        let internal_msg = crate::message::Message::Assistant {
            id: None,
            content: crate::OneOrMany::many(vec![
                crate::message::AssistantContent::Reasoning(reasoning_content),
                crate::message::AssistantContent::Text(crate::message::Text {
                    text: "The answer is X".to_string(),
                }),
            ])
            .unwrap(),
        };

        // Convert to provider Message
        let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();
        assert_eq!(provider_msgs.len(), 1);

        if let Message::Assistant {
            thinking, content, ..
        } = &provider_msgs[0]
        {
            assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
            assert_eq!(content, "The answer is X");
        } else {
            panic!("Expected Assistant message with thinking");
        }
    }

    // Test empty thinking content is handled correctly.
    #[test]
    fn test_empty_thinking_content() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Response",
                "thinking": "",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            // Empty string should still deserialize as Some("")
            assert_eq!(thinking.as_ref().unwrap(), "");
            assert_eq!(content, "Response");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test thinking with tool calls.
    #[test]
    fn test_thinking_with_tool_calls() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Let me check the weather.",
                "thinking": "User wants weather info, I should use the weather tool",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {
                                "location": "San Francisco"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 30u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 50u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking,
            content,
            tool_calls,
            ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "User wants weather info, I should use the weather tool"
            );
            assert_eq!(content, "Let me check the weather.");
            assert_eq!(tool_calls.len(), 1);
            assert_eq!(tool_calls[0].function.name, "get_weather");
        } else {
            panic!("Expected Assistant message with thinking and tool calls");
        }
    }
}