//! Ollama API client and Rig integration
//!
//! # Example
//! ```rust,ignore
//! use rig::providers::ollama;
//!
//! // Create a new Ollama client (defaults to http://localhost:11434)
//! let client = ollama::Client::new();
//!
//! // Create a completion model interface using, for example, the "llama3.2" model
//! let comp_model = client.completion_model("llama3.2");
//!
//! let req = rig::completion::CompletionRequest {
//!     preamble: Some("You are now a humorous AI assistant.".to_owned()),
//!     chat_history: vec![],  // internal messages (if any)
//!     prompt: rig::message::Message::User {
//!         content: rig::one_or_many::OneOrMany::one(rig::message::UserContent::text("Please tell me why the sky is blue.")),
//!         name: None
//!     },
//!     temperature: Some(0.7),
//!     max_tokens: None,
//!     tool_choice: None,
//!     additional_params: None,
//!     tools: vec![],
//! };
//!
//! let response = comp_model.completion(req).await.unwrap();
//! println!("Ollama completion response: {:?}", response.choice);
//!
//! // Create an embedding interface using the "all-minilm" model
//! let emb_model = ollama::Client::new().embedding_model("all-minilm");
//! let docs = vec![
//!     "Why is the sky blue?".to_owned(),
//!     "Why is the grass green?".to_owned()
//! ];
//! let embeddings = emb_model.embed_texts(docs).await.unwrap();
//! println!("Embedding response: {:?}", embeddings);
//!
//! // Also create an agent and extractor if needed
//! let agent = client.agent("llama3.2");
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
//! ```
use crate::client::{
    self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
};
use crate::completion::{GetTokenUsage, Usage};
use crate::http_client::{self, HttpClientExt};
use crate::message::DocumentSourceKind;
use crate::streaming::RawStreamingChoice;
use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils, message,
    message::{ImageDetail, Text},
    streaming,
};
use async_stream::try_stream;
use bytes::Bytes;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::{convert::TryFrom, str::FromStr};
use tracing::info_span;
use tracing_futures::Instrument;

// ---------- Main Client ----------

const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";

#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;

#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaBuilder;

impl Provider for OllamaExt {
    type Builder = OllamaBuilder;

    const VERIFY_PATH: &'static str = "api/tags";

    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}

impl<H> Capabilities<H> for OllamaExt {
    type Completion = Capable<CompletionModel<H>>;
    type Transcription = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for OllamaExt {}

impl ProviderBuilder for OllamaBuilder {
    type Output = OllamaExt;
    type ApiKey = Nothing;

    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;
}

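/// Ollama client specialized to the default `reqwest` HTTP client. A non-default host
/// can be supplied through the builder, mirroring `from_env` below (illustrative
/// sketch; the hostname is a placeholder and `?` assumes a fallible caller):
/// ```rust,ignore
/// let client = ollama::Client::builder()
///     .api_key(Nothing)
///     .base_url("http://my-ollama-host:11434")
///     .build()?;
/// ```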
pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, Nothing, H>;

impl ProviderClient for Client {
    type Input = Nothing;

    fn from_env() -> Self {
        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");

        Self::builder()
            .api_key(Nothing)
            .base_url(&api_base)
            .build()
            .unwrap()
    }

    fn from_val(_: Self::Input) -> Self {
        Self::builder().api_key(Nothing).build().unwrap()
    }
}

// ---------- API Error and Response Structures ----------

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

// ---------- Embedding API ----------

pub const ALL_MINILM: &str = "all-minilm";
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    pub embeddings: Vec<Vec<f64>>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

// ---------- Embedding Model ----------

#[derive(Clone)]
pub struct EmbeddingModel<T> {
    client: Client<T>,
    pub model: String,
    ndims: usize,
}

impl<T> EmbeddingModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }

    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        // Ollama does not report embedding dimensions, so the caller must supply them.
        let ndims = dims.expect("embedding dimensions must be provided for Ollama models");
        Self::new(client.clone(), model, ndims)
    }

    const MAX_DOCUMENTS: usize = 1024;

    fn ndims(&self) -> usize {
        self.ndims
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let docs: Vec<String> = documents.into_iter().collect();

        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": docs
        }))?;

        let req = self
            .client
            .post("api/embed")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if !response.status().is_success() {
            let text = http_client::text(response).await?;
            return Err(EmbeddingError::ProviderError(text));
        }

        let bytes: Vec<u8> = response.into_body().await?;

        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;

        if api_resp.embeddings.len() != docs.len() {
            return Err(EmbeddingError::ResponseError(
                "Number of returned embeddings does not match input".into(),
            ));
        }

        Ok(api_resp
            .embeddings
            .into_iter()
            .zip(docs.into_iter())
            .map(|(vec, document)| embeddings::Embedding { document, vec })
            .collect())
    }
}

// ---------- Completion API ----------

pub const LLAMA3_2: &str = "llama3.2";
pub const LLAVA: &str = "llava";
pub const MISTRAL: &str = "mistral";

#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            // Process only if an assistant message is present.
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();

                // Add the assistant's text content, if any.
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                // Process tool_calls following Ollama's chat response definition.
                // Each ToolCall carries a type and a function (name + arguments); Ollama
                // does not assign call ids, so the function name doubles as the id.
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;

                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                    },
                    raw_response,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OllamaCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    pub stream: bool,
    think: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    options: serde_json::Value,
}

impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.tool_choice.is_some() {
            tracing::warn!("`tool_choice` is not supported by Ollama and will be ignored");
        }
        // Build up the order of messages (context, chat_history, prompt)
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        // Add preamble to chat history (if available)
        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let mut think = false;

        // TODO: Fix this up to include the full range of ollama options
        let options = if let Some(mut extra) = req.additional_params {
            // `think` is a top-level field in Ollama's chat API rather than an inference
            // option, so pull it out of the extras instead of forwarding it in `options`.
            if let Some(think_val) = extra.as_object_mut().and_then(|obj| obj.remove("think")) {
                think = think_val.as_bool().ok_or_else(|| {
                    CompletionError::RequestError("`think` must be a bool".into())
                })?;
            }
            json_utils::merge(json!({ "temperature": req.temperature }), extra)
        } else {
            json!({ "temperature": req.temperature })
        };

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            max_tokens: req.max_tokens,
            stream: false,
            think,
            tools: req
                .tools
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            options,
        })
    }
}

#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_owned(),
        }
    }
}

// ---------- CompletionModel Implementation ----------

#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        let input_tokens = self.prompt_eval_count.unwrap_or_default();
        let output_tokens = self.eval_count.unwrap_or_default();
        usage.input_tokens = input_tokens;
        usage.output_tokens = output_tokens;
        usage.total_tokens = input_tokens + output_tokens;

        Some(usage)
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into().as_str())
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "ollama",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);

        let request = OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Ollama completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("api/chat")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let async_block = async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ));
            }

            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
            let span = tracing::Span::current();
            // Record against the `gen_ai.response.model` field declared on the span above.
            span.record("gen_ai.response.model", &response.model);
            span.record(
                "gen_ai.usage.input_tokens",
                response.prompt_eval_count.unwrap_or_default(),
            );
            span.record(
                "gen_ai.usage.output_tokens",
                response.eval_count.unwrap_or_default(),
            );

            if tracing::enabled!(tracing::Level::TRACE) {
                tracing::trace!(target: "rig::completions",
                    "Ollama completion response: {}",
                    serde_json::to_string_pretty(&response)?
                );
            }

            let response: completion::CompletionResponse<CompletionResponse> =
                response.try_into()?;

            Ok(response)
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
    {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "ollama",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &request.preamble);

        let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
        request.stream = true;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Ollama streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("api/chat")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let response = self.client.send_streaming(req).await?;
        let status = response.status();
        let mut byte_stream = response.into_body();

        if !status.is_success() {
            return Err(CompletionError::ProviderError(format!(
                "Got error status code trying to send a request to Ollama: {status}"
            )));
        }

        let stream = try_stream! {
            let span = tracing::Span::current();
            let mut tool_calls_final = Vec::new();
            let mut text_response = String::new();
            let mut thinking_response = String::new();

            while let Some(chunk) = byte_stream.next().await {
                let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;

                for line in bytes.split(|&b| b == b'\n') {
                    if line.is_empty() {
                        continue;
                    }

                    tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));

                    let response: CompletionResponse = serde_json::from_slice(line)?;

                    if response.done {
                        span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
                        span.record("gen_ai.usage.output_tokens", response.eval_count);
                        let message = Message::Assistant {
                            content: text_response.clone(),
                            thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
                            images: None,
                            name: None,
                            tool_calls: tool_calls_final.clone()
                        };
                        span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
                        yield RawStreamingChoice::FinalResponse(
                            StreamingCompletionResponse {
                                total_duration: response.total_duration,
                                load_duration: response.load_duration,
                                prompt_eval_count: response.prompt_eval_count,
                                prompt_eval_duration: response.prompt_eval_duration,
                                eval_count: response.eval_count,
                                eval_duration: response.eval_duration,
                                done_reason: response.done_reason,
                            }
                        );
                        break;
                    }

                    if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
                        if let Some(thinking_content) = thinking
                            && !thinking_content.is_empty() {
                            thinking_response += &thinking_content;
                            yield RawStreamingChoice::Reasoning {
                                reasoning: thinking_content,
                                id: None,
                                signature: None,
                            };
                        }

                        if !content.is_empty() {
                            text_response += &content;
                            yield RawStreamingChoice::Message(content);
                        }

                        for tool_call in tool_calls {
                            tool_calls_final.push(tool_call.clone());
                            yield RawStreamingChoice::ToolCall {
                                // Ollama does not assign tool-call ids, so leave the id
                                // empty and let downstream consumers key on the name.
                                id: String::new(),
                                name: tool_call.function.name,
                                arguments: tool_call.function.arguments,
                                call_id: None,
                            };
                        }
                    }
                }
            }
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}

// ---------- Tool Definition Conversion ----------

/// Ollama-required tool definition format.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    #[serde(rename = "type")]
    pub type_field: String, // Fixed as "function"
    pub function: completion::ToolDefinition,
}

/// Convert an internal `ToolDefinition` (from the completion module) into Ollama's tool definition.
impl From<crate::completion::ToolDefinition> for ToolDefinition {
    fn from(tool: crate::completion::ToolDefinition) -> Self {
        ToolDefinition {
            type_field: "function".to_owned(),
            function: completion::ToolDefinition {
                name: tool.name,
                description: tool.description,
                parameters: tool.parameters,
            },
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}

// ---------- Provider Message Definition ----------

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        #[serde(default)]
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}

// ---------- Provider Message Conversions ----------

/// Conversion from an internal Rig message (`crate::message::Message`) to a provider [`Message`].
/// (Only `User` and `Assistant` variants are supported.)
impl TryFrom<crate::message::Message> for Vec<Message> {
    type Error = crate::message::MessageError;

    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;
        match internal_msg {
            InternalMessage::User { content, .. } => {
                let (tool_results, other_content): (Vec<_>, Vec<_>) =
                    content.into_iter().partition(|content| {
                        matches!(content, crate::message::UserContent::ToolResult(_))
                    });

                if !tool_results.is_empty() {
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            crate::message::UserContent::ToolResult(
                                crate::message::ToolResult { id, content, .. },
                            ) => {
                                // Ollama expects a single string for tool results, so we concatenate
                                let content_string = content
                                    .into_iter()
                                    .map(|content| match content {
                                        crate::message::ToolResultContent::Text(text) => text.text,
                                        _ => "[Non-text content]".to_string(),
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n");

                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
                                    name: id,
                                    content: content_string,
                                })
                            }
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    // Ollama requires separate text content and images array
                    let (texts, images) = other_content.into_iter().fold(
                        (Vec::new(), Vec::new()),
                        |(mut texts, mut images), content| {
                            match content {
                                crate::message::UserContent::Text(crate::message::Text {
                                    text,
                                }) => texts.push(text),
                                crate::message::UserContent::Image(crate::message::Image {
                                    data: DocumentSourceKind::Base64(data),
                                    ..
                                }) => images.push(data),
                                crate::message::UserContent::Document(
                                    crate::message::Document {
                                        data:
                                            DocumentSourceKind::Base64(data)
                                            | DocumentSourceKind::String(data),
                                        ..
                                    },
                                ) => texts.push(data),
                                _ => {} // Audio not supported by Ollama
                            }
                            (texts, images)
                        },
                    );

                    Ok(vec![Message::User {
                        content: texts.join(" "),
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(
                                images
                                    .into_iter()
                                    .map(|x| x.to_string())
                                    .collect::<Vec<String>>(),
                            )
                        },
                        name: None,
                    }])
                }
            }
            InternalMessage::Assistant { content, .. } => {
                let mut thinking: Option<String> = None;
                let mut text_content = Vec::new();
                let mut tool_calls = Vec::new();

                for content in content.into_iter() {
                    match content {
                        crate::message::AssistantContent::Text(text) => {
                            text_content.push(text.text)
                        }
                        crate::message::AssistantContent::ToolCall(tool_call) => {
                            tool_calls.push(tool_call)
                        }
                        crate::message::AssistantContent::Reasoning(
                            crate::message::Reasoning { reasoning, .. },
                        ) => {
                            // Join multi-part reasoning rather than silently dropping
                            // everything after the first entry.
                            thinking = Some(reasoning.join("\n"));
                        }
                        crate::message::AssistantContent::Image(_) => {
                            return Err(crate::message::MessageError::ConversionError(
                                "Ollama currently doesn't support images.".into(),
                            ));
                        }
                    }
                }

                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
                //  so either `content` or `tool_calls` will have some content.
                Ok(vec![Message::Assistant {
                    content: text_content.join(" "),
                    thinking,
                    images: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}

/// Conversion from provider Message to a completion message.
/// This is needed so that responses can be converted back into chat history.
impl From<Message> for crate::completion::Message {
    fn from(msg: Message) -> Self {
        match msg {
            Message::User { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut assistant_contents =
                    vec![crate::completion::message::AssistantContent::Text(Text {
                        text: content,
                    })];
                for tc in tool_calls {
                    assistant_contents.push(
                        crate::completion::message::AssistantContent::tool_call(
                            tc.function.name.clone(),
                            tc.function.name,
                            tc.function.arguments,
                        ),
                    );
                }
                crate::completion::Message::Assistant {
                    id: None,
                    // Safe to unwrap: `assistant_contents` always contains at least the
                    // text element pushed above.
                    content: OneOrMany::many(assistant_contents).unwrap(),
                }
            }
            // System and ToolResult are converted to User messages as needed.
            Message::System { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::ToolResult { name, content } => crate::completion::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    name,
                    OneOrMany::one(message::ToolResultContent::text(content)),
                )),
            },
        }
    }
}

impl Message {
    /// Constructs a system message.
    pub fn system(content: &str) -> Self {
        Message::System {
            content: content.to_owned(),
            images: None,
            name: None,
        }
    }
}

// ---------- Additional Message Types ----------

impl From<crate::message::ToolCall> for ToolCall {
    fn from(tool_call: crate::message::ToolCall) -> Self {
        Self {
            r#type: ToolType::Function,
            function: Function {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}

impl From<String> for SystemContent {
    fn from(s: String) -> Self {
        SystemContent {
            r#type: SystemContentType::default(),
            text: s,
        }
    }
}

impl FromStr for SystemContent {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(SystemContent {
            r#type: SystemContentType::default(),
            text: s.to_string(),
        })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}

impl FromStr for AssistantContent {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(AssistantContent { text: s.to_owned() })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
    // Audio variant removed as the Ollama API does not support audio input.
}

impl FromStr for UserContent {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(UserContent::Text { text: s.to_owned() })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}

// =================================================================
// Tests
// =================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Test deserialization and conversion for the /api/chat endpoint.
    #[tokio::test]
    async fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }
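
    // Added check (sketch): the /api/embed payload below mirrors the fields of
    // `EmbeddingResponse` above; the values are illustrative, not captured from a
    // live Ollama server.
    #[test]
    fn test_embedding_response_deserialization() {
        let sample = json!({
            "model": "all-minilm",
            "embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
            "total_duration": 14143917u64,
            "load_duration": 1019500u64,
            "prompt_eval_count": 8u64
        });
        let resp: EmbeddingResponse =
            serde_json::from_value(sample).expect("Failed to deserialize");
        assert_eq!(resp.model, "all-minilm");
        assert_eq!(resp.embeddings.len(), 2);
        assert_eq!(resp.embeddings[0].len(), 3);
    }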

    // Test conversion from provider Message to completion Message.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // `OneOrMany<T>::first()` returns the first element.
                let first_content = content.first();
                // The expected type is crate::completion::message::UserContent::Text wrapping a Text struct.
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }
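
    // Added check (sketch): `Message::ToolResult` should serialize to Ollama's
    // `role: "tool"` shape with the `tool_name` rename applied; the tool name and
    // content here are made up.
    #[test]
    fn test_tool_result_message_serialization() {
        let msg = Message::ToolResult {
            name: "get_weather".to_owned(),
            content: "{\"temperature\": 20}".to_owned(),
        };
        let value = serde_json::to_value(&msg).expect("Failed to serialize");
        assert_eq!(value["role"], "tool");
        assert_eq!(value["tool_name"], "get_weather");
        assert_eq!(value["content"], "{\"temperature\": 20}");
    }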

    // Test conversion of internal tool definition to Ollama's ToolDefinition format.
    #[test]
    fn test_tool_definition_conversion() {
        // Internal tool definition from the completion module.
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        // Convert internal tool to Ollama's tool definition.
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        // Check JSON fields in parameters.
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }
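
    // Added check (sketch): fields marked `skip_serializing_if` on
    // `OllamaCompletionRequest` should be omitted from the serialized payload. The
    // request is built by hand here; real requests go through
    // `TryFrom<(&str, CompletionRequest)>`.
    #[test]
    fn test_request_serialization_skips_empty_fields() {
        let request = OllamaCompletionRequest {
            model: "llama3.2".to_owned(),
            messages: vec![Message::system("You are a helpful assistant.")],
            temperature: None,
            tools: vec![],
            stream: false,
            think: false,
            max_tokens: None,
            options: json!({}),
        };
        let value = serde_json::to_value(&request).expect("Failed to serialize");
        assert!(value.get("temperature").is_none());
        assert!(value.get("tools").is_none());
        assert!(value.get("max_tokens").is_none());
        assert_eq!(value["model"], "llama3.2");
    }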

    // Test deserialization of chat response with thinking content
    #[tokio::test]
    async fn test_chat_completion_with_thinking() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The answer is 42.",
                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        // Verify thinking field is present
        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "Let me think about this carefully. The question asks for the meaning of life..."
            );
            assert_eq!(content, "The answer is 42.");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test deserialization of chat response without thinking content
    #[tokio::test]
    async fn test_chat_completion_without_thinking() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Hello!",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        // Verify thinking field is None when not provided
        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert!(thinking.is_none());
            assert_eq!(content, "Hello!");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test deserialization of streaming response with thinking content
    #[test]
    fn test_streaming_response_with_thinking() {
        let sample_chunk = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "",
                "thinking": "Analyzing the problem...",
                "images": null,
                "tool_calls": []
            },
            "done": false
        });

        let chunk: CompletionResponse =
            serde_json::from_value(sample_chunk).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chunk.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
            assert_eq!(content, "");
        } else {
            panic!("Expected Assistant message");
        }
    }
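
    // Added check (sketch): token usage is derived from Ollama's eval counters by
    // summing prompt and completion tokens; the counts here are arbitrary.
    #[test]
    fn test_streaming_token_usage() {
        let resp = StreamingCompletionResponse {
            done_reason: Some("stop".to_owned()),
            total_duration: None,
            load_duration: None,
            prompt_eval_count: Some(10),
            prompt_eval_duration: None,
            eval_count: Some(5),
            eval_duration: None,
        };
        let usage = resp.token_usage().expect("usage should always be present");
        assert_eq!(usage.input_tokens, 10);
        assert_eq!(usage.output_tokens, 5);
        assert_eq!(usage.total_tokens, 15);
    }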

    // Test message conversion with thinking content
    #[test]
    fn test_message_conversion_with_thinking() {
        // Create an internal message with reasoning content
        let reasoning_content = crate::message::Reasoning {
            id: None,
            reasoning: vec!["Step 1: Consider the problem".to_string()],
            signature: None,
        };

        let internal_msg = crate::message::Message::Assistant {
            id: None,
            content: crate::OneOrMany::many(vec![
                crate::message::AssistantContent::Reasoning(reasoning_content),
                crate::message::AssistantContent::Text(crate::message::Text {
                    text: "The answer is X".to_string(),
                }),
            ])
            .unwrap(),
        };

        // Convert to provider Message
        let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();
        assert_eq!(provider_msgs.len(), 1);

        if let Message::Assistant {
            thinking, content, ..
        } = &provider_msgs[0]
        {
            assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
            assert_eq!(content, "The answer is X");
        } else {
            panic!("Expected Assistant message with thinking");
        }
    }

    // Test empty thinking content is handled correctly
    #[test]
    fn test_empty_thinking_content() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Response",
                "thinking": "",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            // Empty string should still deserialize as Some("")
            assert_eq!(thinking.as_ref().unwrap(), "");
            assert_eq!(content, "Response");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test thinking with tool calls
    #[test]
    fn test_thinking_with_tool_calls() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Let me check the weather.",
                "thinking": "User wants weather info, I should use the weather tool",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {
                                "location": "San Francisco"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 30u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 50u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking,
            content,
            tool_calls,
            ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "User wants weather info, I should use the weather tool"
            );
            assert_eq!(content, "Let me check the weather.");
            assert_eq!(tool_calls.len(), 1);
            assert_eq!(tool_calls[0].function.name, "get_weather");
        } else {
            panic!("Expected Assistant message with thinking and tool calls");
        }
    }
}