Skip to main content

rig/providers/
ollama.rs

1//! Ollama API client and Rig integration
2//!
3//! # Example
4//! ```rust,ignore
5//! use rig::client::{Nothing, CompletionClient};
6//! use rig::completion::Prompt;
7//! use rig::providers::ollama;
8//!
9//! // Create a new Ollama client (defaults to http://localhost:11434)
10//! // In the case of ollama, no API key is necessary, so we use the `Nothing` struct
11//! let client = ollama::Client::new(Nothing).unwrap();
12//!
13//! // Create an agent with a preamble
14//! let comedian_agent = client
15//!     .agent("qwen2.5:14b")
16//!     .preamble("You are a comedian here to entertain the user using humour and jokes.")
17//!     .build();
18//!
19//! // Prompt the agent and print the response
20//! let response = comedian_agent.prompt("Entertain me!").await?;
21//! println!("{response}");
22//!
23//! // Create an embedding model using the "all-minilm" model
24//! let emb_model = client.embedding_model("all-minilm", 384);
25//! let embeddings = emb_model.embed_texts(vec![
26//!     "Why is the sky blue?".to_owned(),
27//!     "Why is the grass green?".to_owned()
28//! ]).await?;
29//! println!("Embedding response: {:?}", embeddings);
30//!
31//! // Create an extractor if needed
32//! let extractor = client.extractor::<serde_json::Value>("llama3.2").build();
33//! ```
34use crate::client::{
35    self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
36};
37use crate::completion::{GetTokenUsage, Usage};
38use crate::http_client::{self, HttpClientExt};
39use crate::message::DocumentSourceKind;
40use crate::streaming::RawStreamingChoice;
41use crate::{
42    OneOrMany,
43    completion::{self, CompletionError, CompletionRequest},
44    embeddings::{self, EmbeddingError},
45    json_utils, message,
46    message::{ImageDetail, Text},
47    streaming,
48};
49use async_stream::try_stream;
50use bytes::Bytes;
51use futures::StreamExt;
52use serde::{Deserialize, Serialize};
53use serde_json::{Value, json};
54use std::{convert::TryFrom, str::FromStr};
55use tracing::info_span;
56use tracing_futures::Instrument;
57// ---------- Main Client ----------
58
/// Default base URL of a locally running Ollama server.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
60
/// Marker type carrying the Ollama provider extension (capability selection
/// and server-verification path).
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;
63
/// Marker type used as the [`ProviderBuilder`] for Ollama clients.
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaBuilder;
66
impl Provider for OllamaExt {
    type Builder = OllamaBuilder;
    // `GET api/tags` is a cheap endpoint used to verify the server is reachable.
    const VERIFY_PATH: &'static str = "api/tags";
}
71
// Ollama supports chat completions and embeddings; transcription, model
// listing and the feature-gated image/audio generation are marked `Nothing`.
impl<H> Capabilities<H> for OllamaExt {
    type Completion = Capable<CompletionModel<H>>;
    type Transcription = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
83
// Default `DebugExt` behavior suffices; no provider-specific debug output.
impl DebugExt for OllamaExt {}
85
impl ProviderBuilder for OllamaBuilder {
    type Extension<H>
        = OllamaExt
    where
        H: HttpClientExt;
    // Ollama requires no API key.
    type ApiKey = Nothing;

    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;

    /// Produce the (stateless) Ollama extension; this never fails.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(OllamaExt)
    }
}
104
/// Ollama client over a pluggable HTTP backend (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
/// Builder for [`Client`]; Ollama takes no API key, hence the `Nothing` slot.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, Nothing, H>;
107
108impl ProviderClient for Client {
109    type Input = Nothing;
110
111    fn from_env() -> Self {
112        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
113
114        Self::builder()
115            .api_key(Nothing)
116            .base_url(&api_base)
117            .build()
118            .unwrap()
119    }
120
121    fn from_val(_: Self::Input) -> Self {
122        Self::builder().api_key(Nothing).build().unwrap()
123    }
124}
125
126// ---------- API Error and Response Structures ----------
127
/// Error payload returned by the Ollama HTTP API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
132
/// Either a successful payload of type `T` or an Ollama API error.
/// `untagged` lets serde pick whichever variant the JSON actually matches.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
139
140// ---------- Embedding API ----------
141
/// Embedding model identifier: `all-minilm` (384-dimensional vectors).
pub const ALL_MINILM: &str = "all-minilm";
/// Embedding model identifier: `nomic-embed-text` (768-dimensional vectors).
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

/// Look up the known output dimensionality for a well-known embedding model,
/// returning `None` for unrecognized identifiers.
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    if identifier == ALL_MINILM {
        Some(384)
    } else if identifier == NOMIC_EMBED_TEXT {
        Some(768)
    } else {
        None
    }
}
152
/// Response from Ollama's `api/embed` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    /// One embedding vector per input document, in input order.
    pub embeddings: Vec<Vec<f64>>,
    // Timing/usage statistics are optional; the server may omit them.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
164
impl From<ApiErrorResponse> for EmbeddingError {
    // Surface the provider's error message verbatim.
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}
170
impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    /// Unwrap the untagged API response into a `Result`.
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}
179
180// ---------- Embedding Model ----------
181
/// Handle for a specific Ollama embedding model.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    // Output dimensionality reported by `ndims()`.
    ndims: usize,
}
188
189impl<T> EmbeddingModel<T> {
190    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
191        Self {
192            client,
193            model: model.into(),
194            ndims,
195        }
196    }
197
198    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
199        Self {
200            client,
201            model: model.into(),
202            ndims,
203        }
204    }
205}
206
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Client = Client<T>;

    /// Construct the model, inferring dimensions for well-known model names
    /// when the caller does not supply them (falls back to 0 otherwise).
    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        let model = model.into();
        let dims = dims
            .or(model_dimensions_from_identifier(&model))
            .unwrap_or_default();
        Self::new(client.clone(), model, dims)
    }

    const MAX_DOCUMENTS: usize = 1024;
    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embed a batch of texts via `POST api/embed`.
    ///
    /// # Errors
    /// Returns `ProviderError` on a non-success HTTP status and
    /// `ResponseError` when the server returns a different number of
    /// embeddings than documents sent.
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let docs: Vec<String> = documents.into_iter().collect();

        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": docs
        }))?;

        let req = self
            .client
            .post("api/embed")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if !response.status().is_success() {
            let text = http_client::text(response).await?;
            return Err(EmbeddingError::ProviderError(text));
        }

        let bytes: Vec<u8> = response.into_body().await?;

        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;

        // Guard against the server silently dropping inputs.
        if api_resp.embeddings.len() != docs.len() {
            return Err(EmbeddingError::ResponseError(
                "Number of returned embeddings does not match input".into(),
            ));
        }
        // Pair each returned vector with the document it embeds (order is preserved).
        Ok(api_resp
            .embeddings
            .into_iter()
            .zip(docs.into_iter())
            .map(|(vec, document)| embeddings::Embedding { document, vec })
            .collect())
    }
}
267
268// ---------- Completion API ----------
269
/// Common Ollama chat model identifiers.
pub const LLAMA3_2: &str = "llama3.2";
pub const LLAVA: &str = "llava";
pub const MISTRAL: &str = "mistral";
273
/// A single response from Ollama's `api/chat` endpoint.
/// Also used to decode each NDJSON line of a streaming response.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    /// `true` on the final chunk of a stream (and on any non-streaming response).
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    // Optional timing and token-usage statistics; the server may omit any of them.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;
    /// Convert a raw Ollama chat response into Rig's generic completion
    /// response, extracting text and tool calls from the assistant message.
    ///
    /// # Errors
    /// Fails when the response carries no assistant message, or when the
    /// assistant message contains neither text nor tool calls.
    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            // Process only if an assistant message is present.
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                // Add the assistant's text content if any.
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                // Process tool_calls following Ollama's chat response definition.
                // Each ToolCall has an id, a type, and a function field.
                // Ollama provides no call id, so the function name doubles as the id.
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                // `many` fails on an empty vec, i.e. a fully empty assistant message.
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                // Rebuild the raw response: the message was partially moved above.
                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                        cached_input_tokens: 0,
                    },
                    raw_response,
                    message_id: None,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}
364
/// Request payload for Ollama's `api/chat` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OllamaCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// Whether the server should stream NDJSON chunks.
    pub stream: bool,
    /// Enables "thinking" output on models that support it.
    think: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    /// How long the model stays loaded after the request, as a string
    /// understood by Ollama (taken verbatim from `additional_params`).
    #[serde(skip_serializing_if = "Option::is_none")]
    keep_alive: Option<String>,
    /// Optional JSON schema constraining the output format.
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<schemars::Schema>,
    /// Model options: temperature plus any remaining `additional_params`.
    options: serde_json::Value,
}
383
384impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
385    type Error = CompletionError;
386
387    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
388        let model = req.model.clone().unwrap_or_else(|| model.to_string());
389        if req.tool_choice.is_some() {
390            tracing::warn!("WARNING: `tool_choice` not supported for Ollama");
391        }
392        // Build up the order of messages (context, chat_history, prompt)
393        let mut partial_history = vec![];
394        if let Some(docs) = req.normalized_documents() {
395            partial_history.push(docs);
396        }
397        partial_history.extend(req.chat_history);
398
399        // Add preamble to chat history (if available)
400        let mut full_history: Vec<Message> = match &req.preamble {
401            Some(preamble) => vec![Message::system(preamble)],
402            None => vec![],
403        };
404
405        // Convert and extend the rest of the history
406        full_history.extend(
407            partial_history
408                .into_iter()
409                .map(message::Message::try_into)
410                .collect::<Result<Vec<Vec<Message>>, _>>()?
411                .into_iter()
412                .flatten()
413                .collect::<Vec<_>>(),
414        );
415
416        let mut think = false;
417        let mut keep_alive: Option<String> = None;
418
419        let options = if let Some(mut extra) = req.additional_params {
420            // Extract top-level parameters that should not be in `options`
421            if let Some(obj) = extra.as_object_mut() {
422                // Extract `think` parameter
423                if let Some(think_val) = obj.remove("think") {
424                    think = think_val.as_bool().ok_or_else(|| {
425                        CompletionError::RequestError("`think` must be a bool".into())
426                    })?;
427                }
428
429                // Extract `keep_alive` parameter
430                if let Some(keep_alive_val) = obj.remove("keep_alive") {
431                    keep_alive = Some(
432                        keep_alive_val
433                            .as_str()
434                            .ok_or_else(|| {
435                                CompletionError::RequestError(
436                                    "`keep_alive` must be a string".into(),
437                                )
438                            })?
439                            .to_string(),
440                    );
441                }
442            }
443
444            json_utils::merge(json!({ "temperature": req.temperature }), extra)
445        } else {
446            json!({ "temperature": req.temperature })
447        };
448
449        Ok(Self {
450            model: model.to_string(),
451            messages: full_history,
452            temperature: req.temperature,
453            max_tokens: req.max_tokens,
454            stream: false,
455            think,
456            keep_alive,
457            format: req.output_schema,
458            tools: req
459                .tools
460                .clone()
461                .into_iter()
462                .map(ToolDefinition::from)
463                .collect::<Vec<_>>(),
464            options,
465        })
466    }
467}
468
/// Handle for a specific Ollama chat model.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}
474
475impl<T> CompletionModel<T> {
476    pub fn new(client: Client<T>, model: &str) -> Self {
477        Self {
478            client,
479            model: model.to_owned(),
480        }
481    }
482}
483
484// ---------- CompletionModel Implementation ----------
485
/// Final metadata emitted at the end of a streaming completion
/// (the fields of the `done: true` NDJSON chunk, minus the message).
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
496
497impl GetTokenUsage for StreamingCompletionResponse {
498    fn token_usage(&self) -> Option<crate::completion::Usage> {
499        let mut usage = crate::completion::Usage::new();
500        let input_tokens = self.prompt_eval_count.unwrap_or_default();
501        let output_tokens = self.eval_count.unwrap_or_default();
502        usage.input_tokens = input_tokens;
503        usage.output_tokens = output_tokens;
504        usage.total_tokens = input_tokens + output_tokens;
505
506        Some(usage)
507    }
508}
509
510impl<T> completion::CompletionModel for CompletionModel<T>
511where
512    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
513{
514    type Response = CompletionResponse;
515    type StreamingResponse = StreamingCompletionResponse;
516
517    type Client = Client<T>;
518
519    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
520        Self::new(client.clone(), model.into().as_str())
521    }
522
523    async fn completion(
524        &self,
525        completion_request: CompletionRequest,
526    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
527        let span = if tracing::Span::current().is_disabled() {
528            info_span!(
529                target: "rig::completions",
530                "chat",
531                gen_ai.operation.name = "chat",
532                gen_ai.provider.name = "ollama",
533                gen_ai.request.model = self.model,
534                gen_ai.system_instructions = tracing::field::Empty,
535                gen_ai.response.id = tracing::field::Empty,
536                gen_ai.response.model = tracing::field::Empty,
537                gen_ai.usage.output_tokens = tracing::field::Empty,
538                gen_ai.usage.input_tokens = tracing::field::Empty,
539                gen_ai.usage.cached_tokens = tracing::field::Empty,
540            )
541        } else {
542            tracing::Span::current()
543        };
544
545        span.record("gen_ai.system_instructions", &completion_request.preamble);
546        let request = OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
547
548        if tracing::enabled!(tracing::Level::TRACE) {
549            tracing::trace!(target: "rig::completions",
550                "Ollama completion request: {}",
551                serde_json::to_string_pretty(&request)?
552            );
553        }
554
555        let body = serde_json::to_vec(&request)?;
556
557        let req = self
558            .client
559            .post("api/chat")?
560            .body(body)
561            .map_err(http_client::Error::from)?;
562
563        let async_block = async move {
564            let response = self.client.send::<_, Bytes>(req).await?;
565            let status = response.status();
566            let response_body = response.into_body().into_future().await?.to_vec();
567
568            if !status.is_success() {
569                return Err(CompletionError::ProviderError(
570                    String::from_utf8_lossy(&response_body).to_string(),
571                ));
572            }
573
574            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
575            let span = tracing::Span::current();
576            span.record("gen_ai.response.model_name", &response.model);
577            span.record(
578                "gen_ai.usage.input_tokens",
579                response.prompt_eval_count.unwrap_or_default(),
580            );
581            span.record(
582                "gen_ai.usage.output_tokens",
583                response.eval_count.unwrap_or_default(),
584            );
585
586            if tracing::enabled!(tracing::Level::TRACE) {
587                tracing::trace!(target: "rig::completions",
588                    "Ollama completion response: {}",
589                    serde_json::to_string_pretty(&response)?
590                );
591            }
592
593            let response: completion::CompletionResponse<CompletionResponse> =
594                response.try_into()?;
595
596            Ok(response)
597        };
598
599        tracing::Instrument::instrument(async_block, span).await
600    }
601
602    async fn stream(
603        &self,
604        request: CompletionRequest,
605    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
606    {
607        let span = if tracing::Span::current().is_disabled() {
608            info_span!(
609                target: "rig::completions",
610                "chat_streaming",
611                gen_ai.operation.name = "chat_streaming",
612                gen_ai.provider.name = "ollama",
613                gen_ai.request.model = self.model,
614                gen_ai.system_instructions = tracing::field::Empty,
615                gen_ai.response.id = tracing::field::Empty,
616                gen_ai.response.model = self.model,
617                gen_ai.usage.output_tokens = tracing::field::Empty,
618                gen_ai.usage.input_tokens = tracing::field::Empty,
619                gen_ai.usage.cached_tokens = tracing::field::Empty,
620            )
621        } else {
622            tracing::Span::current()
623        };
624
625        span.record("gen_ai.system_instructions", &request.preamble);
626
627        let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
628        request.stream = true;
629
630        if tracing::enabled!(tracing::Level::TRACE) {
631            tracing::trace!(target: "rig::completions",
632                "Ollama streaming completion request: {}",
633                serde_json::to_string_pretty(&request)?
634            );
635        }
636
637        let body = serde_json::to_vec(&request)?;
638
639        let req = self
640            .client
641            .post("api/chat")?
642            .body(body)
643            .map_err(http_client::Error::from)?;
644
645        let response = self.client.send_streaming(req).await?;
646        let status = response.status();
647        let mut byte_stream = response.into_body();
648
649        if !status.is_success() {
650            return Err(CompletionError::ProviderError(format!(
651                "Got error status code trying to send a request to Ollama: {status}"
652            )));
653        }
654
655        let stream = try_stream! {
656            let span = tracing::Span::current();
657            let mut tool_calls_final = Vec::new();
658            let mut text_response = String::new();
659            let mut thinking_response = String::new();
660
661            while let Some(chunk) = byte_stream.next().await {
662                let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;
663
664                for line in bytes.split(|&b| b == b'\n') {
665                    if line.is_empty() {
666                        continue;
667                    }
668
669                    tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));
670
671                    let response: CompletionResponse = serde_json::from_slice(line)?;
672
673                    if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
674                        if let Some(thinking_content) = thinking && !thinking_content.is_empty() {
675                            thinking_response += &thinking_content;
676                            yield RawStreamingChoice::ReasoningDelta {
677                                id: None,
678                                reasoning: thinking_content,
679                            };
680                        }
681
682                        if !content.is_empty() {
683                            text_response += &content;
684                            yield RawStreamingChoice::Message(content);
685                        }
686
687                        for tool_call in tool_calls {
688                            tool_calls_final.push(tool_call.clone());
689                            yield RawStreamingChoice::ToolCall(
690                                crate::streaming::RawStreamingToolCall::new(String::new(), tool_call.function.name, tool_call.function.arguments)
691                            );
692                        }
693                    }
694
695                    if response.done {
696                        span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
697                        span.record("gen_ai.usage.output_tokens", response.eval_count);
698                        let message = Message::Assistant {
699                            content: text_response.clone(),
700                            thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
701                            images: None,
702                            name: None,
703                            tool_calls: tool_calls_final.clone()
704                        };
705                        span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
706                        yield RawStreamingChoice::FinalResponse(
707                            StreamingCompletionResponse {
708                                total_duration: response.total_duration,
709                                load_duration: response.load_duration,
710                                prompt_eval_count: response.prompt_eval_count,
711                                prompt_eval_duration: response.prompt_eval_duration,
712                                eval_count: response.eval_count,
713                                eval_duration: response.eval_duration,
714                                done_reason: response.done_reason,
715                            }
716                        );
717                        break;
718                    }
719                }
720            }
721        }.instrument(span);
722
723        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
724            stream,
725        )))
726    }
727}
728
729// ---------- Tool Definition Conversion ----------
730
/// Ollama-required tool definition format.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Always the string `"function"`.
    #[serde(rename = "type")]
    pub type_field: String, // Fixed as "function"
    pub function: completion::ToolDefinition,
}
738
739/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
740impl From<crate::completion::ToolDefinition> for ToolDefinition {
741    fn from(tool: crate::completion::ToolDefinition) -> Self {
742        ToolDefinition {
743            type_field: "function".to_owned(),
744            function: completion::ToolDefinition {
745                name: tool.name,
746                description: tool.description,
747                parameters: tool.parameters,
748            },
749        }
750    }
751}
752
/// A tool invocation requested by the model.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Tool-call kind; `function` is the only variant Ollama defines here.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// Function name and JSON arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
770
771// ---------- Provider Message Definition ----------
772
/// Provider-side chat message, tagged by `role` in the wire format.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        /// Base64-encoded images, if any.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        #[serde(default)]
        content: String,
        /// "Thinking"/reasoning text for models that emit it.
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Ollama may send `null` instead of an empty array here.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Result of a tool invocation, sent back with role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
809
810/// -----------------------------
/// Provider Message Conversions
/// -----------------------------
/// Conversion from an internal Rig message (crate::message::Message) to one or more
/// provider Messages. System, User and Assistant variants are all handled; a single
/// internal message may expand into several provider messages (one per tool result).
impl TryFrom<crate::message::Message> for Vec<Message> {
    type Error = crate::message::MessageError;
    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;
        match internal_msg {
            // System prompts map 1:1 onto the provider's system role.
            InternalMessage::System { content } => Ok(vec![Message::System {
                content,
                images: None,
                name: None,
            }]),
            InternalMessage::User { content, .. } => {
                // Split tool results from all other user content: tool results are
                // emitted as dedicated `Message::ToolResult` entries below.
                let (tool_results, other_content): (Vec<_>, Vec<_>) =
                    content.into_iter().partition(|content| {
                        matches!(content, crate::message::UserContent::ToolResult(_))
                    });

                if !tool_results.is_empty() {
                    // NOTE(review): when tool results are present, any remaining
                    // user content in `other_content` is silently discarded.
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            crate::message::UserContent::ToolResult(
                                crate::message::ToolResult { id, content, .. },
                            ) => {
                                // Ollama expects a single string for tool results, so we concatenate
                                let content_string = content
                                    .into_iter()
                                    .map(|content| match content {
                                        crate::message::ToolResultContent::Text(text) => text.text,
                                        _ => "[Non-text content]".to_string(),
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n");

                                // The internal tool-call id doubles as the tool name here.
                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
                                    name: id,
                                    content: content_string,
                                })
                            }
                            // Unreachable: the partition above guarantees only
                            // `ToolResult` variants reach this closure.
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    // Ollama requires separate text content and images array
                    let (texts, images) = other_content.into_iter().fold(
                        (Vec::new(), Vec::new()),
                        |(mut texts, mut images), content| {
                            match content {
                                crate::message::UserContent::Text(crate::message::Text {
                                    text,
                                }) => texts.push(text),
                                // Only base64 image payloads are accepted; other
                                // source kinds (e.g. URLs) fall through to `_` and
                                // are skipped — NOTE(review): confirm this is intended.
                                crate::message::UserContent::Image(crate::message::Image {
                                    data: DocumentSourceKind::Base64(data),
                                    ..
                                }) => images.push(data),
                                // Documents are flattened into the text stream.
                                crate::message::UserContent::Document(
                                    crate::message::Document {
                                        data:
                                            DocumentSourceKind::Base64(data)
                                            | DocumentSourceKind::String(data),
                                        ..
                                    },
                                ) => texts.push(data),
                                _ => {} // Audio not supported by Ollama
                            }
                            (texts, images)
                        },
                    );

                    Ok(vec![Message::User {
                        content: texts.join(" "),
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(
                                images
                                    .into_iter()
                                    .map(|x| x.to_string())
                                    .collect::<Vec<String>>(),
                            )
                        },
                        name: None,
                    }])
                }
            }
            InternalMessage::Assistant { content, .. } => {
                // Accumulators for the three assistant content kinds Ollama supports.
                let mut thinking: Option<String> = None;
                let mut text_content = Vec::new();
                let mut tool_calls = Vec::new();

                for content in content.into_iter() {
                    match content {
                        crate::message::AssistantContent::Text(text) => {
                            text_content.push(text.text)
                        }
                        crate::message::AssistantContent::ToolCall(tool_call) => {
                            tool_calls.push(tool_call)
                        }
                        crate::message::AssistantContent::Reasoning(reasoning) => {
                            // Last non-empty reasoning block wins; earlier ones are overwritten.
                            let display = reasoning.display_text();
                            if !display.is_empty() {
                                thinking = Some(display);
                            }
                        }
                        crate::message::AssistantContent::Image(_) => {
                            return Err(crate::message::MessageError::ConversionError(
                                "Ollama currently doesn't support images.".into(),
                            ));
                        }
                    }
                }

                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
                //  so either `content` or `tool_calls` will have some content.
                Ok(vec![Message::Assistant {
                    content: text_content.join(" "),
                    thinking,
                    images: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}
942
943/// Conversion from provider Message to a completion message.
944/// This is needed so that responses can be converted back into chat history.
945impl From<Message> for crate::completion::Message {
946    fn from(msg: Message) -> Self {
947        match msg {
948            Message::User { content, .. } => crate::completion::Message::User {
949                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
950                    text: content,
951                })),
952            },
953            Message::Assistant {
954                content,
955                tool_calls,
956                ..
957            } => {
958                let mut assistant_contents =
959                    vec![crate::completion::message::AssistantContent::Text(Text {
960                        text: content,
961                    })];
962                for tc in tool_calls {
963                    assistant_contents.push(
964                        crate::completion::message::AssistantContent::tool_call(
965                            tc.function.name.clone(),
966                            tc.function.name,
967                            tc.function.arguments,
968                        ),
969                    );
970                }
971                crate::completion::Message::Assistant {
972                    id: None,
973                    content: OneOrMany::many(assistant_contents).unwrap(),
974                }
975            }
976            // System and ToolResult are converted to User message as needed.
977            Message::System { content, .. } => crate::completion::Message::User {
978                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
979                    text: content,
980                })),
981            },
982            Message::ToolResult { name, content } => crate::completion::Message::User {
983                content: OneOrMany::one(message::UserContent::tool_result(
984                    name,
985                    OneOrMany::one(message::ToolResultContent::text(content)),
986                )),
987            },
988        }
989    }
990}
991
992impl Message {
993    /// Constructs a system message.
994    pub fn system(content: &str) -> Self {
995        Message::System {
996            content: content.to_owned(),
997            images: None,
998            name: None,
999        }
1000    }
1001}
1002
1003// ---------- Additional Message Types ----------
1004
1005impl From<crate::message::ToolCall> for ToolCall {
1006    fn from(tool_call: crate::message::ToolCall) -> Self {
1007        Self {
1008            r#type: ToolType::Function,
1009            function: Function {
1010                name: tool_call.function.name,
1011                arguments: tool_call.function.arguments,
1012            },
1013        }
1014    }
1015}
1016
/// A single piece of system-prompt content; currently always plain text.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    // Defaults to `Text` when the field is absent from incoming JSON.
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}
1023
/// Discriminator for [`SystemContent`]; `text` is the only supported kind.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
1030
1031impl From<String> for SystemContent {
1032    fn from(s: String) -> Self {
1033        SystemContent {
1034            r#type: SystemContentType::default(),
1035            text: s,
1036        }
1037    }
1038}
1039
1040impl FromStr for SystemContent {
1041    type Err = std::convert::Infallible;
1042    fn from_str(s: &str) -> Result<Self, Self::Err> {
1043        Ok(SystemContent {
1044            r#type: SystemContentType::default(),
1045            text: s.to_string(),
1046        })
1047    }
1048}
1049
/// Plain-text content of an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}
1054
1055impl FromStr for AssistantContent {
1056    type Err = std::convert::Infallible;
1057    fn from_str(s: &str) -> Result<Self, Self::Err> {
1058        Ok(AssistantContent { text: s.to_owned() })
1059    }
1060}
1061
/// User-supplied content, tagged by `type` in JSON (lowercase variant names).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
    // Audio variant removed as Ollama API does not support audio input.
}
1069
1070impl FromStr for UserContent {
1071    type Err = std::convert::Infallible;
1072    fn from_str(s: &str) -> Result<Self, Self::Err> {
1073        Ok(UserContent::Text { text: s.to_owned() })
1074    }
1075}
1076
/// An image URL plus a requested level of detail.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    // Falls back to `ImageDetail::default()` when absent from incoming JSON.
    #[serde(default)]
    pub detail: ImageDetail,
}
1083
1084// =================================================================
1085// Tests
1086// =================================================================
1087
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Test deserialization and conversion for the /api/chat endpoint.
    // NOTE(review): none of the #[tokio::test] tests in this module ever await;
    // plain #[test] would work and avoid spinning up a runtime — confirm intent.
    #[tokio::test]
    async fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    // Test conversion from provider Message to completion Message.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // Assume OneOrMany<T> has a method first() to access the first element.
                let first_content = content.first();
                // The expected type is crate::completion::message::UserContent::Text wrapping a Text struct.
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    // Test conversion of internal tool definition to Ollama's ToolDefinition format.
    #[test]
    fn test_tool_definition_conversion() {
        // Internal tool definition from the completion module.
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        // Convert internal tool to Ollama's tool definition.
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        // Check JSON fields in parameters.
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }

    // Test deserialization of chat response with thinking content
    #[tokio::test]
    async fn test_chat_completion_with_thinking() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The answer is 42.",
                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        // Verify thinking field is present
        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "Let me think about this carefully. The question asks for the meaning of life..."
            );
            assert_eq!(content, "The answer is 42.");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test deserialization of chat response without thinking content
    #[tokio::test]
    async fn test_chat_completion_without_thinking() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Hello!",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        // Verify thinking field is None when not provided
        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert!(thinking.is_none());
            assert_eq!(content, "Hello!");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test deserialization of streaming response with thinking content
    #[test]
    fn test_streaming_response_with_thinking() {
        // A single streamed chunk ("done": false) carrying only thinking text.
        let sample_chunk = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "",
                "thinking": "Analyzing the problem...",
                "images": null,
                "tool_calls": []
            },
            "done": false
        });

        let chunk: CompletionResponse =
            serde_json::from_value(sample_chunk).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chunk.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
            assert_eq!(content, "");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test message conversion with thinking content
    #[test]
    fn test_message_conversion_with_thinking() {
        // Create an internal message with reasoning content
        let reasoning_content = crate::message::Reasoning::new("Step 1: Consider the problem");

        let internal_msg = crate::message::Message::Assistant {
            id: None,
            content: crate::OneOrMany::many(vec![
                crate::message::AssistantContent::Reasoning(reasoning_content),
                crate::message::AssistantContent::Text(crate::message::Text {
                    text: "The answer is X".to_string(),
                }),
            ])
            .unwrap(),
        };

        // Convert to provider Message
        let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();
        assert_eq!(provider_msgs.len(), 1);

        if let Message::Assistant {
            thinking, content, ..
        } = &provider_msgs[0]
        {
            assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
            assert_eq!(content, "The answer is X");
        } else {
            panic!("Expected Assistant message with thinking");
        }
    }

    // Test empty thinking content is handled correctly
    #[test]
    fn test_empty_thinking_content() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Response",
                "thinking": "",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            // Empty string should still deserialize as Some("")
            assert_eq!(thinking.as_ref().unwrap(), "");
            assert_eq!(content, "Response");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test thinking with tool calls
    #[test]
    fn test_thinking_with_tool_calls() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Let me check the weather.",
                "thinking": "User wants weather info, I should use the weather tool",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {
                                "location": "San Francisco"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 30u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 50u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking,
            content,
            tool_calls,
            ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "User wants weather info, I should use the weather tool"
            );
            assert_eq!(content, "Let me check the weather.");
            assert_eq!(tool_calls.len(), 1);
            assert_eq!(tool_calls[0].function.name, "get_weather");
        } else {
            panic!("Expected Assistant message with thinking and tool calls");
        }
    }

    // Test that `think` and `keep_alive` are extracted as top-level params, not in `options`
    #[test]
    fn test_completion_request_with_think_param() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        // Create a CompletionRequest with "think": true, "keep_alive", and "num_ctx" in additional_params
        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "What is 2 + 2?".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.7),
            max_tokens: Some(1024),
            tool_choice: None,
            additional_params: Some(json!({
                "think": true,
                "keep_alive": "-1m",
                "num_ctx": 4096
            })),
            output_schema: None,
        };

        // Convert to OllamaCompletionRequest
        let ollama_request = OllamaCompletionRequest::try_from(("qwen3:8b", completion_request))
            .expect("Failed to create Ollama request");

        // Serialize to JSON
        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        // Assert equality with expected JSON
        // - "tools" is skipped when empty (skip_serializing_if)
        // - "think" should be a top-level boolean, NOT in options
        // - "keep_alive" should be a top-level string, NOT in options
        // - "num_ctx" should be in options (it's a model parameter)
        let expected = json!({
            "model": "qwen3:8b",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "What is 2 + 2?"
                }
            ],
            "temperature": 0.7,
            "stream": false,
            "think": true,
            "max_tokens": 1024,
            "keep_alive": "-1m",
            "options": {
                "temperature": 0.7,
                "num_ctx": 4096
            }
        });

        assert_eq!(serialized, expected);
    }

    // Test that `think` defaults to false when not specified
    #[test]
    fn test_completion_request_with_think_false_default() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        // Create a CompletionRequest WITHOUT "think" in additional_params
        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "Hello!".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.5),
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        // Convert to OllamaCompletionRequest
        let ollama_request = OllamaCompletionRequest::try_from(("llama3.2", completion_request))
            .expect("Failed to create Ollama request");

        // Serialize to JSON
        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        // Assert that "think" defaults to false and "keep_alive" is not present
        let expected = json!({
            "model": "llama3.2",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "Hello!"
                }
            ],
            "temperature": 0.5,
            "stream": false,
            "think": false,
            "options": {
                "temperature": 0.5
            }
        });

        assert_eq!(serialized, expected);
    }

    // Test that an output schema is forwarded as Ollama's top-level "format" field.
    #[test]
    fn test_completion_request_with_output_schema() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let schema: schemars::Schema = serde_json::from_value(json!({
            "type": "object",
            "properties": {
                "age": { "type": "integer" },
                "available": { "type": "boolean" }
            },
            "required": ["age", "available"]
        }))
        .expect("Failed to parse schema");

        let completion_request = CompletionRequest {
            model: Some("llama3.1".to_string()),
            preamble: None,
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "How old is Ollama?".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: Some(schema),
        };

        let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
            .expect("Failed to create Ollama request");

        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        let format = serialized
            .get("format")
            .expect("format field should be present");
        assert_eq!(
            *format,
            json!({
                "type": "object",
                "properties": {
                    "age": { "type": "integer" },
                    "available": { "type": "boolean" }
                },
                "required": ["age", "available"]
            })
        );
    }

    // Test that "format" is omitted entirely when no output schema is set.
    #[test]
    fn test_completion_request_without_output_schema() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: Some("llama3.1".to_string()),
            preamble: None,
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "Hello!".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
            .expect("Failed to create Ollama request");

        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        assert!(
            serialized.get("format").is_none(),
            "format field should be absent when output_schema is None"
        );
    }

    // Smoke-test both client construction paths (direct and builder).
    #[test]
    fn test_client_initialization() {
        let _client = crate::providers::ollama::Client::new(Nothing).expect("Client::new() failed");
        let _client_from_builder = crate::providers::ollama::Client::builder()
            .api_key(Nothing)
            .build()
            .expect("Client::builder() failed");
    }
}
1652}