Skip to main content

rig/providers/
ollama.rs

1//! Ollama API client and Rig integration
2//!
3//! # Example
4//! ```rust,ignore
5//! use rig::client::{Nothing, CompletionClient};
6//! use rig::completion::Prompt;
7//! use rig::providers::ollama;
8//!
9//! // Create a new Ollama client (defaults to http://localhost:11434)
10//! // In the case of ollama, no API key is necessary, so we use the `Nothing` struct
11//! let client = ollama::Client::new(Nothing).unwrap();
12//!
13//! // Create an agent with a preamble
14//! let comedian_agent = client
15//!     .agent("qwen2.5:14b")
16//!     .preamble("You are a comedian here to entertain the user using humour and jokes.")
17//!     .build();
18//!
19//! // Prompt the agent and print the response
20//! let response = comedian_agent.prompt("Entertain me!").await?;
21//! println!("{response}");
22//!
23//! // Create an embedding model using the "all-minilm" model
24//! let emb_model = client.embedding_model("all-minilm", 384);
25//! let embeddings = emb_model.embed_texts(vec![
26//!     "Why is the sky blue?".to_owned(),
27//!     "Why is the grass green?".to_owned()
28//! ]).await?;
29//! println!("Embedding response: {:?}", embeddings);
30//!
31//! // Create an extractor if needed
32//! let extractor = client.extractor::<serde_json::Value>("llama3.2").build();
33//! ```
34use crate::client::{
35    self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
36};
37use crate::completion::{GetTokenUsage, Usage};
38use crate::http_client::{self, HttpClientExt};
39use crate::message::DocumentSourceKind;
40use crate::streaming::RawStreamingChoice;
41use crate::{
42    OneOrMany,
43    completion::{self, CompletionError, CompletionRequest},
44    embeddings::{self, EmbeddingError},
45    json_utils, message,
46    message::{ImageDetail, Text},
47    streaming,
48};
49use async_stream::try_stream;
50use bytes::Bytes;
51use futures::StreamExt;
52use serde::{Deserialize, Serialize};
53use serde_json::{Value, json};
54use std::{convert::TryFrom, str::FromStr};
55use tracing::info_span;
56use tracing_futures::Instrument;
// ---------- Main Client ----------

/// Default base URL of a locally running Ollama daemon.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";

/// Zero-sized marker type representing the Ollama provider extension.
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;

/// Zero-sized builder that produces [`OllamaExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaBuilder;
66
/// Registers Ollama as a provider; connectivity is verified against
/// `api/tags` (Ollama's model-listing route, which needs no auth).
impl Provider for OllamaExt {
    type Builder = OllamaBuilder;
    const VERIFY_PATH: &'static str = "api/tags";
}

/// Capabilities offered by this integration: completion and embeddings are
/// supported; transcription, model listing, and (behind their feature
/// flags) image/audio generation are not.
impl<H> Capabilities<H> for OllamaExt {
    type Completion = Capable<CompletionModel<H>>;
    type Transcription = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

// Use the trait's default debug formatting; no extra state to display.
impl DebugExt for OllamaExt {}
85
/// Builder wiring for the Ollama provider.
///
/// Ollama requires no API key (`ApiKey = Nothing`) and defaults to the
/// local daemon at `http://localhost:11434`.
impl ProviderBuilder for OllamaBuilder {
    type Extension<H>
        = OllamaExt
    where
        H: HttpClientExt;
    type ApiKey = Nothing;

    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;

    /// Produces the stateless extension marker; this cannot fail.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(OllamaExt)
    }
}
104
/// Ollama client, defaulting to the `reqwest` HTTP backend.
pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
/// Client builder alias; the `Nothing` slot reflects that no API key is used.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, Nothing, H>;
107
108impl ProviderClient for Client {
109    type Input = Nothing;
110
111    fn from_env() -> Self {
112        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
113
114        Self::builder()
115            .api_key(Nothing)
116            .base_url(&api_base)
117            .build()
118            .unwrap()
119    }
120
121    fn from_val(_: Self::Input) -> Self {
122        Self::builder().api_key(Nothing).build().unwrap()
123    }
124}
125
126// ---------- API Error and Response Structures ----------
127
/// Error payload returned by the Ollama API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload of type `T` or an API error.
///
/// `#[serde(untagged)]` first attempts to deserialize `T` and falls back to
/// the error shape, since Ollama does not tag its responses.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
139
// ---------- Embedding API ----------

/// Identifier of the `all-minilm` embedding model.
pub const ALL_MINILM: &str = "all-minilm";
/// Identifier of the `nomic-embed-text` embedding model.
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

/// Looks up the output dimensionality of a well-known embedding model.
/// Returns `None` when the identifier is not recognized.
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    let ndims = match identifier {
        ALL_MINILM => 384,
        NOMIC_EMBED_TEXT => 768,
        _ => return None,
    };
    Some(ndims)
}
152
/// Response body of Ollama's `POST api/embed` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    /// Model that produced the embeddings.
    pub model: String,
    /// One embedding vector per input document (paired with inputs by index).
    pub embeddings: Vec<Vec<f64>>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}

/// Maps an Ollama error payload onto the embedding error type.
impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

/// Unwraps an untagged API response into a `Result`.
impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}
179
180// ---------- Embedding Model ----------
181
182#[derive(Clone)]
183pub struct EmbeddingModel<T = reqwest::Client> {
184    client: Client<T>,
185    pub model: String,
186    ndims: usize,
187}
188
189impl<T> EmbeddingModel<T> {
190    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
191        Self {
192            client,
193            model: model.into(),
194            ndims,
195        }
196    }
197
198    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
199        Self {
200            client,
201            model: model.into(),
202            ndims,
203        }
204    }
205}
206
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Client = Client<T>;

    /// Builds a model handle from a client. When the caller does not supply
    /// dimensions, falls back to the known dimensions of well-known model
    /// names; unknown models default to 0 dimensions.
    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        let model = model.into();
        let dims = dims
            .or(model_dimensions_from_identifier(&model))
            .unwrap_or_default();
        Self::new(client.clone(), model, dims)
    }

    // Maximum number of documents accepted per embed request.
    const MAX_DOCUMENTS: usize = 1024;
    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embeds `documents` via Ollama's `POST api/embed`, pairing each
    /// returned vector with its source document by position.
    ///
    /// # Errors
    /// * `EmbeddingError::ProviderError` — non-success HTTP status (the
    ///   response body is returned verbatim as the error text).
    /// * `EmbeddingError::ResponseError` — the server returned a different
    ///   number of embeddings than documents submitted.
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let docs: Vec<String> = documents.into_iter().collect();

        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": docs
        }))?;

        let req = self
            .client
            .post("api/embed")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if !response.status().is_success() {
            let text = http_client::text(response).await?;
            return Err(EmbeddingError::ProviderError(text));
        }

        let bytes: Vec<u8> = response.into_body().await?;

        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;

        // The zip below pairs by index, so the counts must match exactly.
        if api_resp.embeddings.len() != docs.len() {
            return Err(EmbeddingError::ResponseError(
                "Number of returned embeddings does not match input".into(),
            ));
        }
        Ok(api_resp
            .embeddings
            .into_iter()
            .zip(docs.into_iter())
            .map(|(vec, document)| embeddings::Embedding { document, vec })
            .collect())
    }
}
267
// ---------- Completion API ----------

/// `llama3.2` model identifier.
pub const LLAMA3_2: &str = "llama3.2";
/// `llava` model identifier.
pub const LLAVA: &str = "llava";
/// `mistral` model identifier.
pub const MISTRAL: &str = "mistral";

/// Response body of Ollama's `POST api/chat`; also the shape of each NDJSON
/// line when streaming.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    /// Creation timestamp as a string, exactly as sent by the server.
    pub created_at: String,
    pub message: Message,
    /// Whether generation has finished; the final streamed chunk sets this.
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    // Timing/usage metadata below is optional and reported by Ollama;
    // `prompt_eval_count` and `eval_count` are used as input/output token
    // counts respectively.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
/// Converts Ollama's raw chat response into the generic completion response,
/// extracting text and tool calls from the assistant message.
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;
    /// # Errors
    /// Fails when the response's message is not an assistant message, or when
    /// it carries neither text content nor tool calls.
    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            // Process only if an assistant message is present.
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                // Add the assistant's text content if any.
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                // Process tool_calls following Ollama's chat response definition.
                // Each ToolCall has an id, a type, and a function field.
                // The function name is passed for both the id and name
                // arguments here, since Ollama does not return a call id.
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                // `OneOrMany::many` rejects an empty vector, which surfaces
                // as a "No content provided" response error.
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                // Rebuild the raw response; the assistant message fields were
                // moved out by the destructure above.
                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                        cached_input_tokens: 0,
                        cache_creation_input_tokens: 0,
                    },
                    raw_response,
                    message_id: None,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}
365
/// Request body for Ollama's `POST api/chat` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OllamaCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Tool definitions; omitted from the payload entirely when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// Whether the server should stream NDJSON chunks instead of a single body.
    pub stream: bool,
    // Enables the model's "thinking" output (lifted from additional_params).
    think: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    // Keep-alive duration string passed through to Ollama verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    keep_alive: Option<String>,
    // Optional JSON schema constraining the output format.
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<schemars::Schema>,
    // Extra model options: temperature plus pass-through additional params.
    options: serde_json::Value,
}
384
385impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
386    type Error = CompletionError;
387
388    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
389        let model = req.model.clone().unwrap_or_else(|| model.to_string());
390        if req.tool_choice.is_some() {
391            tracing::warn!("WARNING: `tool_choice` not supported for Ollama");
392        }
393        // Build up the order of messages (context, chat_history, prompt)
394        let mut partial_history = vec![];
395        if let Some(docs) = req.normalized_documents() {
396            partial_history.push(docs);
397        }
398        partial_history.extend(req.chat_history);
399
400        // Add preamble to chat history (if available)
401        let mut full_history: Vec<Message> = match &req.preamble {
402            Some(preamble) => vec![Message::system(preamble)],
403            None => vec![],
404        };
405
406        // Convert and extend the rest of the history
407        full_history.extend(
408            partial_history
409                .into_iter()
410                .map(message::Message::try_into)
411                .collect::<Result<Vec<Vec<Message>>, _>>()?
412                .into_iter()
413                .flatten()
414                .collect::<Vec<_>>(),
415        );
416
417        let mut think = false;
418        let mut keep_alive: Option<String> = None;
419
420        let options = if let Some(mut extra) = req.additional_params {
421            // Extract top-level parameters that should not be in `options`
422            if let Some(obj) = extra.as_object_mut() {
423                // Extract `think` parameter
424                if let Some(think_val) = obj.remove("think") {
425                    think = think_val.as_bool().ok_or_else(|| {
426                        CompletionError::RequestError("`think` must be a bool".into())
427                    })?;
428                }
429
430                // Extract `keep_alive` parameter
431                if let Some(keep_alive_val) = obj.remove("keep_alive") {
432                    keep_alive = Some(
433                        keep_alive_val
434                            .as_str()
435                            .ok_or_else(|| {
436                                CompletionError::RequestError(
437                                    "`keep_alive` must be a string".into(),
438                                )
439                            })?
440                            .to_string(),
441                    );
442                }
443            }
444
445            json_utils::merge(json!({ "temperature": req.temperature }), extra)
446        } else {
447            json!({ "temperature": req.temperature })
448        };
449
450        Ok(Self {
451            model: model.to_string(),
452            messages: full_history,
453            temperature: req.temperature,
454            max_tokens: req.max_tokens,
455            stream: false,
456            think,
457            keep_alive,
458            format: req.output_schema,
459            tools: req
460                .tools
461                .clone()
462                .into_iter()
463                .map(ToolDefinition::from)
464                .collect::<Vec<_>>(),
465            options,
466        })
467    }
468}
469
/// An Ollama chat/completion model handle bound to a client.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Model identifier, e.g. `llama3.2`.
    pub model: String,
}

impl<T> CompletionModel<T> {
    /// Creates a completion model handle for `model`.
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_owned(),
        }
    }
}
484
485// ---------- CompletionModel Implementation ----------
486
/// Final metadata chunk emitted at the end of a streamed Ollama completion,
/// carrying the timing and token-count fields of the `done` NDJSON line.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    // Input token count.
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    // Output token count.
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
497
498impl GetTokenUsage for StreamingCompletionResponse {
499    fn token_usage(&self) -> Option<crate::completion::Usage> {
500        let mut usage = crate::completion::Usage::new();
501        let input_tokens = self.prompt_eval_count.unwrap_or_default();
502        let output_tokens = self.eval_count.unwrap_or_default();
503        usage.input_tokens = input_tokens;
504        usage.output_tokens = output_tokens;
505        usage.total_tokens = input_tokens + output_tokens;
506
507        Some(usage)
508    }
509}
510
511impl<T> completion::CompletionModel for CompletionModel<T>
512where
513    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
514{
515    type Response = CompletionResponse;
516    type StreamingResponse = StreamingCompletionResponse;
517
518    type Client = Client<T>;
519
520    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
521        Self::new(client.clone(), model.into().as_str())
522    }
523
524    async fn completion(
525        &self,
526        completion_request: CompletionRequest,
527    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
528        let span = if tracing::Span::current().is_disabled() {
529            info_span!(
530                target: "rig::completions",
531                "chat",
532                gen_ai.operation.name = "chat",
533                gen_ai.provider.name = "ollama",
534                gen_ai.request.model = self.model,
535                gen_ai.system_instructions = tracing::field::Empty,
536                gen_ai.response.id = tracing::field::Empty,
537                gen_ai.response.model = tracing::field::Empty,
538                gen_ai.usage.output_tokens = tracing::field::Empty,
539                gen_ai.usage.input_tokens = tracing::field::Empty,
540                gen_ai.usage.cached_tokens = tracing::field::Empty,
541            )
542        } else {
543            tracing::Span::current()
544        };
545
546        span.record("gen_ai.system_instructions", &completion_request.preamble);
547        let request = OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
548
549        if tracing::enabled!(tracing::Level::TRACE) {
550            tracing::trace!(target: "rig::completions",
551                "Ollama completion request: {}",
552                serde_json::to_string_pretty(&request)?
553            );
554        }
555
556        let body = serde_json::to_vec(&request)?;
557
558        let req = self
559            .client
560            .post("api/chat")?
561            .body(body)
562            .map_err(http_client::Error::from)?;
563
564        let async_block = async move {
565            let response = self.client.send::<_, Bytes>(req).await?;
566            let status = response.status();
567            let response_body = response.into_body().into_future().await?.to_vec();
568
569            if !status.is_success() {
570                return Err(CompletionError::ProviderError(
571                    String::from_utf8_lossy(&response_body).to_string(),
572                ));
573            }
574
575            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
576            let span = tracing::Span::current();
577            span.record("gen_ai.response.model_name", &response.model);
578            span.record(
579                "gen_ai.usage.input_tokens",
580                response.prompt_eval_count.unwrap_or_default(),
581            );
582            span.record(
583                "gen_ai.usage.output_tokens",
584                response.eval_count.unwrap_or_default(),
585            );
586
587            if tracing::enabled!(tracing::Level::TRACE) {
588                tracing::trace!(target: "rig::completions",
589                    "Ollama completion response: {}",
590                    serde_json::to_string_pretty(&response)?
591                );
592            }
593
594            let response: completion::CompletionResponse<CompletionResponse> =
595                response.try_into()?;
596
597            Ok(response)
598        };
599
600        tracing::Instrument::instrument(async_block, span).await
601    }
602
603    async fn stream(
604        &self,
605        request: CompletionRequest,
606    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
607    {
608        let span = if tracing::Span::current().is_disabled() {
609            info_span!(
610                target: "rig::completions",
611                "chat_streaming",
612                gen_ai.operation.name = "chat_streaming",
613                gen_ai.provider.name = "ollama",
614                gen_ai.request.model = self.model,
615                gen_ai.system_instructions = tracing::field::Empty,
616                gen_ai.response.id = tracing::field::Empty,
617                gen_ai.response.model = self.model,
618                gen_ai.usage.output_tokens = tracing::field::Empty,
619                gen_ai.usage.input_tokens = tracing::field::Empty,
620                gen_ai.usage.cached_tokens = tracing::field::Empty,
621            )
622        } else {
623            tracing::Span::current()
624        };
625
626        span.record("gen_ai.system_instructions", &request.preamble);
627
628        let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
629        request.stream = true;
630
631        if tracing::enabled!(tracing::Level::TRACE) {
632            tracing::trace!(target: "rig::completions",
633                "Ollama streaming completion request: {}",
634                serde_json::to_string_pretty(&request)?
635            );
636        }
637
638        let body = serde_json::to_vec(&request)?;
639
640        let req = self
641            .client
642            .post("api/chat")?
643            .body(body)
644            .map_err(http_client::Error::from)?;
645
646        let response = self.client.send_streaming(req).await?;
647        let status = response.status();
648        let mut byte_stream = response.into_body();
649
650        if !status.is_success() {
651            return Err(CompletionError::ProviderError(format!(
652                "Got error status code trying to send a request to Ollama: {status}"
653            )));
654        }
655
656        let stream = try_stream! {
657            let span = tracing::Span::current();
658            let mut tool_calls_final = Vec::new();
659            let mut text_response = String::new();
660            let mut thinking_response = String::new();
661
662            while let Some(chunk) = byte_stream.next().await {
663                let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;
664
665                for line in bytes.split(|&b| b == b'\n') {
666                    if line.is_empty() {
667                        continue;
668                    }
669
670                    tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));
671
672                    let response: CompletionResponse = serde_json::from_slice(line)?;
673
674                    if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
675                        if let Some(thinking_content) = thinking && !thinking_content.is_empty() {
676                            thinking_response += &thinking_content;
677                            yield RawStreamingChoice::ReasoningDelta {
678                                id: None,
679                                reasoning: thinking_content,
680                            };
681                        }
682
683                        if !content.is_empty() {
684                            text_response += &content;
685                            yield RawStreamingChoice::Message(content);
686                        }
687
688                        for tool_call in tool_calls {
689                            tool_calls_final.push(tool_call.clone());
690                            yield RawStreamingChoice::ToolCall(
691                                crate::streaming::RawStreamingToolCall::new(String::new(), tool_call.function.name, tool_call.function.arguments)
692                            );
693                        }
694                    }
695
696                    if response.done {
697                        span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
698                        span.record("gen_ai.usage.output_tokens", response.eval_count);
699                        let message = Message::Assistant {
700                            content: text_response.clone(),
701                            thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
702                            images: None,
703                            name: None,
704                            tool_calls: tool_calls_final.clone()
705                        };
706                        span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
707                        yield RawStreamingChoice::FinalResponse(
708                            StreamingCompletionResponse {
709                                total_duration: response.total_duration,
710                                load_duration: response.load_duration,
711                                prompt_eval_count: response.prompt_eval_count,
712                                prompt_eval_duration: response.prompt_eval_duration,
713                                eval_count: response.eval_count,
714                                eval_duration: response.eval_duration,
715                                done_reason: response.done_reason,
716                            }
717                        );
718                        break;
719                    }
720                }
721            }
722        }.instrument(span);
723
724        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
725            stream,
726        )))
727    }
728}
729
// ---------- Tool Definition Conversion ----------

/// Ollama-required tool definition format: the internal definition wrapped
/// under `{ "type": "function", "function": … }`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    #[serde(rename = "type")]
    pub type_field: String, // Fixed as "function"
    pub function: completion::ToolDefinition,
}
739
740/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
741impl From<crate::completion::ToolDefinition> for ToolDefinition {
742    fn from(tool: crate::completion::ToolDefinition) -> Self {
743        ToolDefinition {
744            type_field: "function".to_owned(),
745            function: completion::ToolDefinition {
746                name: tool.name,
747                description: tool.description,
748                parameters: tool.parameters,
749            },
750        }
751    }
752}
753
/// A tool invocation as returned by Ollama.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Tool kind; only `function` exists, and it is the default when the field
/// is absent from the payload.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// The function name and JSON arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
771
// ---------- Provider Message Definition ----------

/// A chat message in Ollama's wire format, tagged by the `role` field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// End-user message; images travel in a separate array of strings
    /// rather than inline in the content.
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Model reply; `content` may be empty when only tool calls are present.
    Assistant {
        #[serde(default)]
        content: String,
        /// Reasoning text (present when thinking output is enabled).
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// A `null` in the payload deserializes as an empty vector.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// System/preamble instructions.
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Result of a tool invocation, sent back under role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
810
811/// -----------------------------
812/// Provider Message Conversions
813/// -----------------------------
814/// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
815/// (Only User and Assistant variants are supported.)
816impl TryFrom<crate::message::Message> for Vec<Message> {
817    type Error = crate::message::MessageError;
818    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
819        use crate::message::Message as InternalMessage;
820        match internal_msg {
821            InternalMessage::System { content } => Ok(vec![Message::System {
822                content,
823                images: None,
824                name: None,
825            }]),
826            InternalMessage::User { content, .. } => {
827                let (tool_results, other_content): (Vec<_>, Vec<_>) =
828                    content.into_iter().partition(|content| {
829                        matches!(content, crate::message::UserContent::ToolResult(_))
830                    });
831
832                if !tool_results.is_empty() {
833                    tool_results
834                        .into_iter()
835                        .map(|content| match content {
836                            crate::message::UserContent::ToolResult(
837                                crate::message::ToolResult { id, content, .. },
838                            ) => {
839                                // Ollama expects a single string for tool results, so we concatenate
840                                let content_string = content
841                                    .into_iter()
842                                    .map(|content| match content {
843                                        crate::message::ToolResultContent::Text(text) => text.text,
844                                        _ => "[Non-text content]".to_string(),
845                                    })
846                                    .collect::<Vec<_>>()
847                                    .join("\n");
848
849                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
850                                    name: id,
851                                    content: content_string,
852                                })
853                            }
854                            _ => unreachable!(),
855                        })
856                        .collect::<Result<Vec<_>, _>>()
857                } else {
858                    // Ollama requires separate text content and images array
859                    let (texts, images) = other_content.into_iter().fold(
860                        (Vec::new(), Vec::new()),
861                        |(mut texts, mut images), content| {
862                            match content {
863                                crate::message::UserContent::Text(crate::message::Text {
864                                    text,
865                                }) => texts.push(text),
866                                crate::message::UserContent::Image(crate::message::Image {
867                                    data: DocumentSourceKind::Base64(data),
868                                    ..
869                                }) => images.push(data),
870                                crate::message::UserContent::Document(
871                                    crate::message::Document {
872                                        data:
873                                            DocumentSourceKind::Base64(data)
874                                            | DocumentSourceKind::String(data),
875                                        ..
876                                    },
877                                ) => texts.push(data),
878                                _ => {} // Audio not supported by Ollama
879                            }
880                            (texts, images)
881                        },
882                    );
883
884                    Ok(vec![Message::User {
885                        content: texts.join(" "),
886                        images: if images.is_empty() {
887                            None
888                        } else {
889                            Some(
890                                images
891                                    .into_iter()
892                                    .map(|x| x.to_string())
893                                    .collect::<Vec<String>>(),
894                            )
895                        },
896                        name: None,
897                    }])
898                }
899            }
900            InternalMessage::Assistant { content, .. } => {
901                let mut thinking: Option<String> = None;
902                let mut text_content = Vec::new();
903                let mut tool_calls = Vec::new();
904
905                for content in content.into_iter() {
906                    match content {
907                        crate::message::AssistantContent::Text(text) => {
908                            text_content.push(text.text)
909                        }
910                        crate::message::AssistantContent::ToolCall(tool_call) => {
911                            tool_calls.push(tool_call)
912                        }
913                        crate::message::AssistantContent::Reasoning(reasoning) => {
914                            let display = reasoning.display_text();
915                            if !display.is_empty() {
916                                thinking = Some(display);
917                            }
918                        }
919                        crate::message::AssistantContent::Image(_) => {
920                            return Err(crate::message::MessageError::ConversionError(
921                                "Ollama currently doesn't support images.".into(),
922                            ));
923                        }
924                    }
925                }
926
927                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
928                //  so either `content` or `tool_calls` will have some content.
929                Ok(vec![Message::Assistant {
930                    content: text_content.join(" "),
931                    thinking,
932                    images: None,
933                    name: None,
934                    tool_calls: tool_calls
935                        .into_iter()
936                        .map(|tool_call| tool_call.into())
937                        .collect::<Vec<_>>(),
938                }])
939            }
940        }
941    }
942}
943
944/// Conversion from provider Message to a completion message.
945/// This is needed so that responses can be converted back into chat history.
946impl From<Message> for crate::completion::Message {
947    fn from(msg: Message) -> Self {
948        match msg {
949            Message::User { content, .. } => crate::completion::Message::User {
950                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
951                    text: content,
952                })),
953            },
954            Message::Assistant {
955                content,
956                tool_calls,
957                ..
958            } => {
959                let mut assistant_contents =
960                    vec![crate::completion::message::AssistantContent::Text(Text {
961                        text: content,
962                    })];
963                for tc in tool_calls {
964                    assistant_contents.push(
965                        crate::completion::message::AssistantContent::tool_call(
966                            tc.function.name.clone(),
967                            tc.function.name,
968                            tc.function.arguments,
969                        ),
970                    );
971                }
972                crate::completion::Message::Assistant {
973                    id: None,
974                    content: OneOrMany::many(assistant_contents).unwrap(),
975                }
976            }
977            // System and ToolResult are converted to User message as needed.
978            Message::System { content, .. } => crate::completion::Message::User {
979                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
980                    text: content,
981                })),
982            },
983            Message::ToolResult { name, content } => crate::completion::Message::User {
984                content: OneOrMany::one(message::UserContent::tool_result(
985                    name,
986                    OneOrMany::one(message::ToolResultContent::text(content)),
987                )),
988            },
989        }
990    }
991}
992
993impl Message {
994    /// Constructs a system message.
995    pub fn system(content: &str) -> Self {
996        Message::System {
997            content: content.to_owned(),
998            images: None,
999            name: None,
1000        }
1001    }
1002}
1003
1004// ---------- Additional Message Types ----------
1005
1006impl From<crate::message::ToolCall> for ToolCall {
1007    fn from(tool_call: crate::message::ToolCall) -> Self {
1008        Self {
1009            r#type: ToolType::Function,
1010            function: Function {
1011                name: tool_call.function.name,
1012                arguments: tool_call.function.arguments,
1013            },
1014        }
1015    }
1016}
1017
/// A single block of system-message content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    // Defaults to `text` when absent from the JSON payload.
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}

/// Content type tag for [`SystemContent`]; serialized lowercase (`"text"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
1031
1032impl From<String> for SystemContent {
1033    fn from(s: String) -> Self {
1034        SystemContent {
1035            r#type: SystemContentType::default(),
1036            text: s,
1037        }
1038    }
1039}
1040
1041impl FromStr for SystemContent {
1042    type Err = std::convert::Infallible;
1043    fn from_str(s: &str) -> Result<Self, Self::Err> {
1044        Ok(SystemContent {
1045            r#type: SystemContentType::default(),
1046            text: s.to_string(),
1047        })
1048    }
1049}
1050
/// A plain-text block of assistant content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}
1055
1056impl FromStr for AssistantContent {
1057    type Err = std::convert::Infallible;
1058    fn from_str(s: &str) -> Result<Self, Self::Err> {
1059        Ok(AssistantContent { text: s.to_owned() })
1060    }
1061}
1062
/// User-facing content blocks, internally tagged on `type` (lowercase) for serde.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
    // Audio variant removed as Ollama API does not support audio input.
}
1070
1071impl FromStr for UserContent {
1072    type Err = std::convert::Infallible;
1073    fn from_str(s: &str) -> Result<Self, Self::Err> {
1074        Ok(UserContent::Text { text: s.to_owned() })
1075    }
1076}
1077
/// An image referenced by URL, with an optional detail hint.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    // Falls back to `ImageDetail::default()` when absent from the payload.
    #[serde(default)]
    pub detail: ImageDetail,
}
1084
// =================================================================
// Tests
// =================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Test deserialization and conversion for the /api/chat endpoint.
    #[tokio::test]
    async fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    // Test conversion from provider Message to completion Message.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // `OneOrMany::first()` yields the first (guaranteed) element.
                let first_content = content.first();
                // The expected type is crate::completion::message::UserContent::Text wrapping a Text struct.
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    // Test conversion of internal tool definition to Ollama's ToolDefinition format.
    #[test]
    fn test_tool_definition_conversion() {
        // Internal tool definition from the completion module.
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        // Convert internal tool to Ollama's tool definition.
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        // Check JSON fields in parameters.
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }

    // Test deserialization of chat response with thinking content
    #[tokio::test]
    async fn test_chat_completion_with_thinking() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The answer is 42.",
                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        // Verify thinking field is present
        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "Let me think about this carefully. The question asks for the meaning of life..."
            );
            assert_eq!(content, "The answer is 42.");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test deserialization of chat response without thinking content
    #[tokio::test]
    async fn test_chat_completion_without_thinking() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Hello!",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        // Verify thinking field is None when not provided
        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert!(thinking.is_none());
            assert_eq!(content, "Hello!");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test deserialization of streaming response with thinking content
    #[test]
    fn test_streaming_response_with_thinking() {
        // Streaming chunks have `done: false` and omit the usage counters.
        let sample_chunk = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "",
                "thinking": "Analyzing the problem...",
                "images": null,
                "tool_calls": []
            },
            "done": false
        });

        let chunk: CompletionResponse =
            serde_json::from_value(sample_chunk).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chunk.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
            assert_eq!(content, "");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test message conversion with thinking content
    #[test]
    fn test_message_conversion_with_thinking() {
        // Create an internal message with reasoning content
        let reasoning_content = crate::message::Reasoning::new("Step 1: Consider the problem");

        let internal_msg = crate::message::Message::Assistant {
            id: None,
            content: crate::OneOrMany::many(vec![
                crate::message::AssistantContent::Reasoning(reasoning_content),
                crate::message::AssistantContent::Text(crate::message::Text {
                    text: "The answer is X".to_string(),
                }),
            ])
            .unwrap(),
        };

        // Convert to provider Message
        let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();
        assert_eq!(provider_msgs.len(), 1);

        if let Message::Assistant {
            thinking, content, ..
        } = &provider_msgs[0]
        {
            assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
            assert_eq!(content, "The answer is X");
        } else {
            panic!("Expected Assistant message with thinking");
        }
    }

    // Test empty thinking content is handled correctly
    #[test]
    fn test_empty_thinking_content() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Response",
                "thinking": "",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            // Empty string should still deserialize as Some("")
            assert_eq!(thinking.as_ref().unwrap(), "");
            assert_eq!(content, "Response");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test thinking with tool calls
    #[test]
    fn test_thinking_with_tool_calls() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Let me check the weather.",
                "thinking": "User wants weather info, I should use the weather tool",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {
                                "location": "San Francisco"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 30u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 50u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking,
            content,
            tool_calls,
            ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "User wants weather info, I should use the weather tool"
            );
            assert_eq!(content, "Let me check the weather.");
            assert_eq!(tool_calls.len(), 1);
            assert_eq!(tool_calls[0].function.name, "get_weather");
        } else {
            panic!("Expected Assistant message with thinking and tool calls");
        }
    }

    // Test that `think` and `keep_alive` are extracted as top-level params, not in `options`
    #[test]
    fn test_completion_request_with_think_param() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        // Create a CompletionRequest with "think": true, "keep_alive", and "num_ctx" in additional_params
        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "What is 2 + 2?".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.7),
            max_tokens: Some(1024),
            tool_choice: None,
            additional_params: Some(json!({
                "think": true,
                "keep_alive": "-1m",
                "num_ctx": 4096
            })),
            output_schema: None,
        };

        // Convert to OllamaCompletionRequest
        let ollama_request = OllamaCompletionRequest::try_from(("qwen3:8b", completion_request))
            .expect("Failed to create Ollama request");

        // Serialize to JSON
        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        // Assert equality with expected JSON
        // - "tools" is skipped when empty (skip_serializing_if)
        // - "think" should be a top-level boolean, NOT in options
        // - "keep_alive" should be a top-level string, NOT in options
        // - "num_ctx" should be in options (it's a model parameter)
        let expected = json!({
            "model": "qwen3:8b",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "What is 2 + 2?"
                }
            ],
            "temperature": 0.7,
            "stream": false,
            "think": true,
            "max_tokens": 1024,
            "keep_alive": "-1m",
            "options": {
                "temperature": 0.7,
                "num_ctx": 4096
            }
        });

        assert_eq!(serialized, expected);
    }

    // Test that `think` defaults to false when not specified
    #[test]
    fn test_completion_request_with_think_false_default() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        // Create a CompletionRequest WITHOUT "think" in additional_params
        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "Hello!".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.5),
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        // Convert to OllamaCompletionRequest
        let ollama_request = OllamaCompletionRequest::try_from(("llama3.2", completion_request))
            .expect("Failed to create Ollama request");

        // Serialize to JSON
        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        // Assert that "think" defaults to false and "keep_alive" is not present
        let expected = json!({
            "model": "llama3.2",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "Hello!"
                }
            ],
            "temperature": 0.5,
            "stream": false,
            "think": false,
            "options": {
                "temperature": 0.5
            }
        });

        assert_eq!(serialized, expected);
    }

    // Test that an output schema is forwarded as Ollama's top-level `format` field.
    #[test]
    fn test_completion_request_with_output_schema() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let schema: schemars::Schema = serde_json::from_value(json!({
            "type": "object",
            "properties": {
                "age": { "type": "integer" },
                "available": { "type": "boolean" }
            },
            "required": ["age", "available"]
        }))
        .expect("Failed to parse schema");

        let completion_request = CompletionRequest {
            model: Some("llama3.1".to_string()),
            preamble: None,
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "How old is Ollama?".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: Some(schema),
        };

        let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
            .expect("Failed to create Ollama request");

        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        let format = serialized
            .get("format")
            .expect("format field should be present");
        assert_eq!(
            *format,
            json!({
                "type": "object",
                "properties": {
                    "age": { "type": "integer" },
                    "available": { "type": "boolean" }
                },
                "required": ["age", "available"]
            })
        );
    }

    // Test that `format` is omitted entirely when no output schema is supplied.
    #[test]
    fn test_completion_request_without_output_schema() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: Some("llama3.1".to_string()),
            preamble: None,
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "Hello!".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
            .expect("Failed to create Ollama request");

        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        assert!(
            serialized.get("format").is_none(),
            "format field should be absent when output_schema is None"
        );
    }

    // Smoke test: both construction paths (direct and builder) succeed without an API key.
    #[test]
    fn test_client_initialization() {
        let _client = crate::providers::ollama::Client::new(Nothing).expect("Client::new() failed");
        let _client_from_builder = crate::providers::ollama::Client::builder()
            .api_key(Nothing)
            .build()
            .expect("Client::builder() failed");
    }
}
1653}