// rig/providers/ollama.rs
1//! Ollama API client and Rig integration
2//!
3//! # Example
4//! ```rust,ignore
5//! use rig::client::{Nothing, CompletionClient};
6//! use rig::completion::Prompt;
7//! use rig::providers::ollama;
8//!
9//! // Create a new Ollama client (defaults to http://localhost:11434)
10//! // In the case of ollama, no API key is necessary, so we use the `Nothing` struct
11//! let client = ollama::Client::new(Nothing).unwrap();
12//!
13//! // Create an agent with a preamble
14//! let comedian_agent = client
15//!     .agent("qwen2.5:14b")
16//!     .preamble("You are a comedian here to entertain the user using humour and jokes.")
17//!     .build();
18//!
19//! // Prompt the agent and print the response
20//! let response = comedian_agent.prompt("Entertain me!").await?;
21//! println!("{response}");
22//!
23//! // Create an embedding model using the "all-minilm" model
24//! let emb_model = client.embedding_model("all-minilm", 384);
25//! let embeddings = emb_model.embed_texts(vec![
26//!     "Why is the sky blue?".to_owned(),
27//!     "Why is the grass green?".to_owned()
28//! ]).await?;
29//! println!("Embedding response: {:?}", embeddings);
30//!
31//! // Create an extractor if needed
32//! let extractor = client.extractor::<serde_json::Value>("llama3.2").build();
33//! ```
34use crate::client::{
35    self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
36};
37use crate::completion::{GetTokenUsage, Usage};
38use crate::http_client::{self, HttpClientExt};
39use crate::message::DocumentSourceKind;
40use crate::streaming::RawStreamingChoice;
41use crate::{
42    OneOrMany,
43    completion::{self, CompletionError, CompletionRequest},
44    embeddings::{self, EmbeddingError},
45    json_utils, message,
46    message::{ImageDetail, Text},
47    streaming,
48};
49use async_stream::try_stream;
50use bytes::Bytes;
51use futures::StreamExt;
52use serde::{Deserialize, Serialize};
53use serde_json::{Value, json};
54use std::{convert::TryFrom, str::FromStr};
55use tracing::info_span;
56use tracing_futures::Instrument;
// ---------- Main Client ----------

/// Default Ollama server address; Ollama serves locally and needs no API key.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";

/// Marker type carrying Ollama's provider capabilities.
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;

/// Marker type implementing the provider builder for Ollama.
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaBuilder;
66
impl Provider for OllamaExt {
    type Builder = OllamaBuilder;
    // `api/tags` lists installed models; it is cheap and unauthenticated,
    // so it doubles as a connectivity check.
    const VERIFY_PATH: &'static str = "api/tags";
}

// Capability map for Ollama: completions and embeddings are supported;
// transcription, model listing, and the feature-gated image/audio
// generation are not (`Nothing`).
impl<H> Capabilities<H> for OllamaExt {
    type Completion = Capable<CompletionModel<H>>;
    type Transcription = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

// Use the default debug formatting provided by the trait.
impl DebugExt for OllamaExt {}
85
impl ProviderBuilder for OllamaBuilder {
    type Extension<H>
        = OllamaExt
    where
        H: HttpClientExt;
    // Ollama is unauthenticated, so the API-key slot is the unit `Nothing`.
    type ApiKey = Nothing;

    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;

    /// Build the (stateless) Ollama extension; this never fails.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(OllamaExt)
    }
}
104
/// Ollama client over a pluggable HTTP transport (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
/// Builder for [`Client`]; the API-key slot is `Nothing` (no credentials).
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, Nothing, H>;
107
108impl ProviderClient for Client {
109    type Input = Nothing;
110
111    fn from_env() -> Self {
112        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
113
114        Self::builder()
115            .api_key(Nothing)
116            .base_url(&api_base)
117            .build()
118            .unwrap()
119    }
120
121    fn from_val(_: Self::Input) -> Self {
122        Self::builder().api_key(Nothing).build().unwrap()
123    }
124}
125
// ---------- API Error and Response Structures ----------

/// Error payload returned by the Ollama API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload of type `T` or an API error.
/// `untagged` lets serde pick whichever variant deserializes first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
139
140// ---------- Embedding API ----------
141
pub const ALL_MINILM: &str = "all-minilm";
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

/// Look up the embedding dimensionality of a known Ollama embedding model.
/// Returns `None` for identifiers not in the table.
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    let ndims = match identifier {
        ALL_MINILM => 384,
        NOMIC_EMBED_TEXT => 768,
        _ => return None,
    };
    Some(ndims)
}
152
/// Response body of Ollama's `api/embed` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    /// One embedding vector per input document, in input order.
    pub embeddings: Vec<Vec<f64>>,
    // Timing/usage counters are optional; servers may omit them.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
164
// Map an Ollama API error payload onto the provider-error variant.
impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}
170
171impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
172    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
173        match value {
174            ApiResponse::Ok(response) => Ok(response),
175            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
176        }
177    }
178}
179
// ---------- Embedding Model ----------

/// Ollama embedding model bound to a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    /// Model identifier, e.g. `all-minilm`.
    pub model: String,
    // Output dimensionality reported by `ndims()`.
    ndims: usize,
}
188
189impl<T> EmbeddingModel<T> {
190    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
191        Self {
192            client,
193            model: model.into(),
194            ndims,
195        }
196    }
197
198    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
199        Self {
200            client,
201            model: model.into(),
202            ndims,
203        }
204    }
205}
206
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Client = Client<T>;

    /// Build the model from a client; when `dims` is not supplied, fall back
    /// to the known-model dimension table, then to 0.
    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        let model = model.into();
        let dims = dims
            .or(model_dimensions_from_identifier(&model))
            .unwrap_or_default();
        Self::new(client.clone(), model, dims)
    }

    const MAX_DOCUMENTS: usize = 1024;
    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embed `documents` via Ollama's `api/embed` endpoint.
    ///
    /// # Errors
    /// Fails on transport errors, non-success HTTP statuses, or when the
    /// server returns a different number of embeddings than inputs.
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let docs: Vec<String> = documents.into_iter().collect();

        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": docs
        }))?;

        let req = self
            .client
            .post("api/embed")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if !response.status().is_success() {
            let text = http_client::text(response).await?;
            return Err(EmbeddingError::ProviderError(text));
        }

        let bytes: Vec<u8> = response.into_body().await?;

        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;

        // Guard against a malformed reply before zipping vectors to inputs.
        if api_resp.embeddings.len() != docs.len() {
            return Err(EmbeddingError::ResponseError(
                "Number of returned embeddings does not match input".into(),
            ));
        }
        Ok(api_resp
            .embeddings
            .into_iter()
            .zip(docs.into_iter())
            .map(|(vec, document)| embeddings::Embedding { document, vec })
            .collect())
    }
}
267
// ---------- Completion API ----------

pub const LLAMA3_2: &str = "llama3.2";
pub const LLAVA: &str = "llava";
pub const MISTRAL: &str = "mistral";

/// Response body of Ollama's `api/chat` endpoint — both the complete
/// non-streaming reply and each NDJSON chunk of a streaming reply.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    /// `true` on the final chunk / complete response.
    pub done: bool,
    // The counters below are optional; Ollama only sends them on the
    // final (`done: true`) response.
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;
    /// Convert a raw Ollama chat response into Rig's provider-agnostic form.
    ///
    /// # Errors
    /// Fails when the response carries no assistant message, or when the
    /// assistant message has neither text nor tool calls.
    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            // Process only if an assistant message is present.
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                // Add the assistant's text content if any.
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                // Process tool_calls following Ollama's chat response definition.
                // Each ToolCall has an id, a type, and a function field.
                // Ollama does not return a call id, so the function name is
                // reused as the id (first argument).
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                // `OneOrMany::many` rejects an empty list, covering the
                // "no content at all" case.
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                // Rebuild the raw response (the match above moved `message`
                // out of `resp`, so it must be reassembled here).
                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                        cached_input_tokens: 0,
                    },
                    raw_response,
                    message_id: None,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}
364
/// Request body for Ollama's `api/chat` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OllamaCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// Whether the server should stream NDJSON chunks.
    pub stream: bool,
    /// Enables "thinking" output on models that support it.
    think: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    /// How long the model stays loaded after the request (e.g. `"5m"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    keep_alive: Option<String>,
    /// Structured-output JSON schema, if requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<schemars::Schema>,
    /// Remaining model options passed through to Ollama.
    options: serde_json::Value,
}
383
impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
    type Error = CompletionError;

    /// Build the Ollama request body from a default model name and a generic
    /// completion request. The request's own `model` overrides the default.
    ///
    /// # Errors
    /// Fails when message conversion fails, or when `additional_params`
    /// carries a `think`/`keep_alive` value of the wrong JSON type.
    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        let model = req.model.clone().unwrap_or_else(|| model.to_string());
        if req.tool_choice.is_some() {
            tracing::warn!("WARNING: `tool_choice` not supported for Ollama");
        }
        // Build up the order of messages (context, chat_history, prompt)
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        // Add preamble to chat history (if available)
        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        // Convert and extend the rest of the history; each internal message
        // may expand to several provider messages (e.g. tool results).
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let mut think = false;
        let mut keep_alive: Option<String> = None;

        // `think` and `keep_alive` are top-level Ollama fields, so they are
        // pulled out of `additional_params` before the rest becomes `options`.
        let options = if let Some(mut extra) = req.additional_params {
            // Extract top-level parameters that should not be in `options`
            if let Some(obj) = extra.as_object_mut() {
                // Extract `think` parameter
                if let Some(think_val) = obj.remove("think") {
                    think = think_val.as_bool().ok_or_else(|| {
                        CompletionError::RequestError("`think` must be a bool".into())
                    })?;
                }

                // Extract `keep_alive` parameter
                if let Some(keep_alive_val) = obj.remove("keep_alive") {
                    keep_alive = Some(
                        keep_alive_val
                            .as_str()
                            .ok_or_else(|| {
                                CompletionError::RequestError(
                                    "`keep_alive` must be a string".into(),
                                )
                            })?
                            .to_string(),
                    );
                }
            }

            // NOTE(review): when `req.temperature` is `None` this inserts a
            // JSON `null` temperature into `options` — confirm Ollama ignores
            // null option values.
            json_utils::merge(json!({ "temperature": req.temperature }), extra)
        } else {
            json!({ "temperature": req.temperature })
        };

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            max_tokens: req.max_tokens,
            // Defaults to non-streaming; `stream()` flips this before sending.
            stream: false,
            think,
            keep_alive,
            format: req.output_schema,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            options,
        })
    }
}
468
/// Ollama chat-completion model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Model identifier, e.g. `llama3.2`.
    pub model: String,
}
474
475impl<T> CompletionModel<T> {
476    pub fn new(client: Client<T>, model: &str) -> Self {
477        Self {
478            client,
479            model: model.to_owned(),
480        }
481    }
482}
483
// ---------- CompletionModel Implementation ----------

/// Final metadata chunk yielded at the end of a streaming completion,
/// built from the counters of Ollama's `done: true` NDJSON line.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
496
497impl GetTokenUsage for StreamingCompletionResponse {
498    fn token_usage(&self) -> Option<crate::completion::Usage> {
499        let mut usage = crate::completion::Usage::new();
500        let input_tokens = self.prompt_eval_count.unwrap_or_default();
501        let output_tokens = self.eval_count.unwrap_or_default();
502        usage.input_tokens = input_tokens;
503        usage.output_tokens = output_tokens;
504        usage.total_tokens = input_tokens + output_tokens;
505
506        Some(usage)
507    }
508}
509
510impl<T> completion::CompletionModel for CompletionModel<T>
511where
512    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
513{
514    type Response = CompletionResponse;
515    type StreamingResponse = StreamingCompletionResponse;
516
517    type Client = Client<T>;
518
519    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
520        Self::new(client.clone(), model.into().as_str())
521    }
522
523    async fn completion(
524        &self,
525        completion_request: CompletionRequest,
526    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
527        let span = if tracing::Span::current().is_disabled() {
528            info_span!(
529                target: "rig::completions",
530                "chat",
531                gen_ai.operation.name = "chat",
532                gen_ai.provider.name = "ollama",
533                gen_ai.request.model = self.model,
534                gen_ai.system_instructions = tracing::field::Empty,
535                gen_ai.response.id = tracing::field::Empty,
536                gen_ai.response.model = tracing::field::Empty,
537                gen_ai.usage.output_tokens = tracing::field::Empty,
538                gen_ai.usage.input_tokens = tracing::field::Empty,
539            )
540        } else {
541            tracing::Span::current()
542        };
543
544        span.record("gen_ai.system_instructions", &completion_request.preamble);
545        let request = OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
546
547        if tracing::enabled!(tracing::Level::TRACE) {
548            tracing::trace!(target: "rig::completions",
549                "Ollama completion request: {}",
550                serde_json::to_string_pretty(&request)?
551            );
552        }
553
554        let body = serde_json::to_vec(&request)?;
555
556        let req = self
557            .client
558            .post("api/chat")?
559            .body(body)
560            .map_err(http_client::Error::from)?;
561
562        let async_block = async move {
563            let response = self.client.send::<_, Bytes>(req).await?;
564            let status = response.status();
565            let response_body = response.into_body().into_future().await?.to_vec();
566
567            if !status.is_success() {
568                return Err(CompletionError::ProviderError(
569                    String::from_utf8_lossy(&response_body).to_string(),
570                ));
571            }
572
573            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
574            let span = tracing::Span::current();
575            span.record("gen_ai.response.model_name", &response.model);
576            span.record(
577                "gen_ai.usage.input_tokens",
578                response.prompt_eval_count.unwrap_or_default(),
579            );
580            span.record(
581                "gen_ai.usage.output_tokens",
582                response.eval_count.unwrap_or_default(),
583            );
584
585            if tracing::enabled!(tracing::Level::TRACE) {
586                tracing::trace!(target: "rig::completions",
587                    "Ollama completion response: {}",
588                    serde_json::to_string_pretty(&response)?
589                );
590            }
591
592            let response: completion::CompletionResponse<CompletionResponse> =
593                response.try_into()?;
594
595            Ok(response)
596        };
597
598        tracing::Instrument::instrument(async_block, span).await
599    }
600
601    async fn stream(
602        &self,
603        request: CompletionRequest,
604    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
605    {
606        let span = if tracing::Span::current().is_disabled() {
607            info_span!(
608                target: "rig::completions",
609                "chat_streaming",
610                gen_ai.operation.name = "chat_streaming",
611                gen_ai.provider.name = "ollama",
612                gen_ai.request.model = self.model,
613                gen_ai.system_instructions = tracing::field::Empty,
614                gen_ai.response.id = tracing::field::Empty,
615                gen_ai.response.model = self.model,
616                gen_ai.usage.output_tokens = tracing::field::Empty,
617                gen_ai.usage.input_tokens = tracing::field::Empty,
618            )
619        } else {
620            tracing::Span::current()
621        };
622
623        span.record("gen_ai.system_instructions", &request.preamble);
624
625        let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
626        request.stream = true;
627
628        if tracing::enabled!(tracing::Level::TRACE) {
629            tracing::trace!(target: "rig::completions",
630                "Ollama streaming completion request: {}",
631                serde_json::to_string_pretty(&request)?
632            );
633        }
634
635        let body = serde_json::to_vec(&request)?;
636
637        let req = self
638            .client
639            .post("api/chat")?
640            .body(body)
641            .map_err(http_client::Error::from)?;
642
643        let response = self.client.send_streaming(req).await?;
644        let status = response.status();
645        let mut byte_stream = response.into_body();
646
647        if !status.is_success() {
648            return Err(CompletionError::ProviderError(format!(
649                "Got error status code trying to send a request to Ollama: {status}"
650            )));
651        }
652
653        let stream = try_stream! {
654            let span = tracing::Span::current();
655            let mut tool_calls_final = Vec::new();
656            let mut text_response = String::new();
657            let mut thinking_response = String::new();
658
659            while let Some(chunk) = byte_stream.next().await {
660                let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;
661
662                for line in bytes.split(|&b| b == b'\n') {
663                    if line.is_empty() {
664                        continue;
665                    }
666
667                    tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));
668
669                    let response: CompletionResponse = serde_json::from_slice(line)?;
670
671                    if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
672                        if let Some(thinking_content) = thinking && !thinking_content.is_empty() {
673                            thinking_response += &thinking_content;
674                            yield RawStreamingChoice::ReasoningDelta {
675                                id: None,
676                                reasoning: thinking_content,
677                            };
678                        }
679
680                        if !content.is_empty() {
681                            text_response += &content;
682                            yield RawStreamingChoice::Message(content);
683                        }
684
685                        for tool_call in tool_calls {
686                            tool_calls_final.push(tool_call.clone());
687                            yield RawStreamingChoice::ToolCall(
688                                crate::streaming::RawStreamingToolCall::new(String::new(), tool_call.function.name, tool_call.function.arguments)
689                            );
690                        }
691                    }
692
693                    if response.done {
694                        span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
695                        span.record("gen_ai.usage.output_tokens", response.eval_count);
696                        let message = Message::Assistant {
697                            content: text_response.clone(),
698                            thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
699                            images: None,
700                            name: None,
701                            tool_calls: tool_calls_final.clone()
702                        };
703                        span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
704                        yield RawStreamingChoice::FinalResponse(
705                            StreamingCompletionResponse {
706                                total_duration: response.total_duration,
707                                load_duration: response.load_duration,
708                                prompt_eval_count: response.prompt_eval_count,
709                                prompt_eval_duration: response.prompt_eval_duration,
710                                eval_count: response.eval_count,
711                                eval_duration: response.eval_duration,
712                                done_reason: response.done_reason,
713                            }
714                        );
715                        break;
716                    }
717                }
718            }
719        }.instrument(span);
720
721        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
722            stream,
723        )))
724    }
725}
726
// ---------- Tool Definition Conversion ----------

/// Ollama-required tool definition format.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    #[serde(rename = "type")]
    pub type_field: String, // Fixed as "function"
    pub function: completion::ToolDefinition,
}
736
737/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
738impl From<crate::completion::ToolDefinition> for ToolDefinition {
739    fn from(tool: crate::completion::ToolDefinition) -> Self {
740        ToolDefinition {
741            type_field: "function".to_owned(),
742            function: completion::ToolDefinition {
743                name: tool.name,
744                description: tool.description,
745                parameters: tool.parameters,
746            },
747        }
748    }
749}
750
/// A tool invocation requested by the model.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Tool-call kind; Ollama currently only emits `function`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// The function name and JSON arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
768
// ---------- Provider Message Definition ----------

/// Chat message in Ollama's wire format, tagged by `role`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        /// Base64-encoded images, if any.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        // `default` because tool-call-only chunks may omit `content`.
        #[serde(default)]
        content: String,
        /// Reasoning text produced when `think` is enabled.
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Custom deserializer: Ollama may send `null` instead of an
        // empty array.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Result of a tool invocation, sent back with role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
807
808/// -----------------------------
809/// Provider Message Conversions
810/// -----------------------------
811/// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
812/// (Only User and Assistant variants are supported.)
impl TryFrom<crate::message::Message> for Vec<Message> {
    type Error = crate::message::MessageError;
    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;
        match internal_msg {
            InternalMessage::User { content, .. } => {
                // Separate tool results from the rest of the user content.
                let (tool_results, other_content): (Vec<_>, Vec<_>) =
                    content.into_iter().partition(|content| {
                        matches!(content, crate::message::UserContent::ToolResult(_))
                    });

                if !tool_results.is_empty() {
                    // NOTE(review): when tool results are present, `other_content`
                    // (text/images/documents) is dropped entirely — confirm intended.
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            crate::message::UserContent::ToolResult(
                                crate::message::ToolResult { id, content, .. },
                            ) => {
                                // Ollama expects a single string for tool results, so we concatenate
                                let content_string = content
                                    .into_iter()
                                    .map(|content| match content {
                                        crate::message::ToolResultContent::Text(text) => text.text,
                                        _ => "[Non-text content]".to_string(),
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n");

                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
                                    name: id,
                                    content: content_string,
                                })
                            }
                            // The partition above guarantees only ToolResult reaches here.
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    // Ollama requires separate text content and images array
                    let (texts, images) = other_content.into_iter().fold(
                        (Vec::new(), Vec::new()),
                        |(mut texts, mut images), content| {
                            match content {
                                crate::message::UserContent::Text(crate::message::Text {
                                    text,
                                }) => texts.push(text),
                                // Only base64-encoded image payloads are forwarded.
                                crate::message::UserContent::Image(crate::message::Image {
                                    data: DocumentSourceKind::Base64(data),
                                    ..
                                }) => images.push(data),
                                // Documents (base64 or raw string) are folded into the text.
                                crate::message::UserContent::Document(
                                    crate::message::Document {
                                        data:
                                            DocumentSourceKind::Base64(data)
                                            | DocumentSourceKind::String(data),
                                        ..
                                    },
                                ) => texts.push(data),
                                _ => {} // Audio not supported by Ollama
                            }
                            (texts, images)
                        },
                    );

                    Ok(vec![Message::User {
                        content: texts.join(" "),
                        // `images: None` rather than `Some(vec![])` so the field is
                        // omitted from the serialized request when there are none.
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(
                                images
                                    .into_iter()
                                    .map(|x| x.to_string())
                                    .collect::<Vec<String>>(),
                            )
                        },
                        name: None,
                    }])
                }
            }
            InternalMessage::Assistant { content, .. } => {
                let mut thinking: Option<String> = None;
                let mut text_content = Vec::new();
                let mut tool_calls = Vec::new();

                for content in content.into_iter() {
                    match content {
                        crate::message::AssistantContent::Text(text) => {
                            text_content.push(text.text)
                        }
                        crate::message::AssistantContent::ToolCall(tool_call) => {
                            tool_calls.push(tool_call)
                        }
                        crate::message::AssistantContent::Reasoning(reasoning) => {
                            // Later non-empty reasoning blocks overwrite earlier ones;
                            // only the last one is kept as `thinking`.
                            let display = reasoning.display_text();
                            if !display.is_empty() {
                                thinking = Some(display);
                            }
                        }
                        crate::message::AssistantContent::Image(_) => {
                            return Err(crate::message::MessageError::ConversionError(
                                "Ollama currently doesn't support images.".into(),
                            ));
                        }
                    }
                }

                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
                //  so either `content` or `tool_calls` will have some content.
                Ok(vec![Message::Assistant {
                    content: text_content.join(" "),
                    thinking,
                    images: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}
935
936/// Conversion from provider Message to a completion message.
937/// This is needed so that responses can be converted back into chat history.
938impl From<Message> for crate::completion::Message {
939    fn from(msg: Message) -> Self {
940        match msg {
941            Message::User { content, .. } => crate::completion::Message::User {
942                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
943                    text: content,
944                })),
945            },
946            Message::Assistant {
947                content,
948                tool_calls,
949                ..
950            } => {
951                let mut assistant_contents =
952                    vec![crate::completion::message::AssistantContent::Text(Text {
953                        text: content,
954                    })];
955                for tc in tool_calls {
956                    assistant_contents.push(
957                        crate::completion::message::AssistantContent::tool_call(
958                            tc.function.name.clone(),
959                            tc.function.name,
960                            tc.function.arguments,
961                        ),
962                    );
963                }
964                crate::completion::Message::Assistant {
965                    id: None,
966                    content: OneOrMany::many(assistant_contents).unwrap(),
967                }
968            }
969            // System and ToolResult are converted to User message as needed.
970            Message::System { content, .. } => crate::completion::Message::User {
971                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
972                    text: content,
973                })),
974            },
975            Message::ToolResult { name, content } => crate::completion::Message::User {
976                content: OneOrMany::one(message::UserContent::tool_result(
977                    name,
978                    OneOrMany::one(message::ToolResultContent::text(content)),
979                )),
980            },
981        }
982    }
983}
984
985impl Message {
986    /// Constructs a system message.
987    pub fn system(content: &str) -> Self {
988        Message::System {
989            content: content.to_owned(),
990            images: None,
991            name: None,
992        }
993    }
994}
995
996// ---------- Additional Message Types ----------
997
998impl From<crate::message::ToolCall> for ToolCall {
999    fn from(tool_call: crate::message::ToolCall) -> Self {
1000        Self {
1001            r#type: ToolType::Function,
1002            function: Function {
1003                name: tool_call.function.name,
1004                arguments: tool_call.function.arguments,
1005            },
1006        }
1007    }
1008}
1009
/// A typed system-content entry; the only supported type today is plain text.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    // Falls back to `SystemContentType::default()` (text) when absent in JSON.
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}
1016
/// Discriminator for [`SystemContent`]; serialized in lowercase (`"text"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    /// Plain text — the default and currently the only variant.
    #[default]
    Text,
}
1023
1024impl From<String> for SystemContent {
1025    fn from(s: String) -> Self {
1026        SystemContent {
1027            r#type: SystemContentType::default(),
1028            text: s,
1029        }
1030    }
1031}
1032
1033impl FromStr for SystemContent {
1034    type Err = std::convert::Infallible;
1035    fn from_str(s: &str) -> Result<Self, Self::Err> {
1036        Ok(SystemContent {
1037            r#type: SystemContentType::default(),
1038            text: s.to_string(),
1039        })
1040    }
1041}
1042
/// Assistant-authored content; Ollama represents it as bare text.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}
1047
1048impl FromStr for AssistantContent {
1049    type Err = std::convert::Infallible;
1050    fn from_str(s: &str) -> Result<Self, Self::Err> {
1051        Ok(AssistantContent { text: s.to_owned() })
1052    }
1053}
1054
/// User-authored content, internally tagged on `type` in lowercase
/// (`"text"` / `"image"`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
    // Audio variant removed as Ollama API does not support audio input.
}
1062
1063impl FromStr for UserContent {
1064    type Err = std::convert::Infallible;
1065    fn from_str(s: &str) -> Result<Self, Self::Err> {
1066        Ok(UserContent::Text { text: s.to_owned() })
1067    }
1068}
1069
/// An image reference by URL, with an optional detail level.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    // Defaults via `ImageDetail::default()` when absent in incoming JSON.
    #[serde(default)]
    pub detail: ImageDetail,
}
1076
1077// =================================================================
1078// Tests
1079// =================================================================
1080
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Test deserialization and conversion for the /api/chat endpoint.
    // Plain #[test]: nothing here is awaited, so no async runtime is needed.
    #[test]
    fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    // Test conversion from provider Message to completion Message.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // Assume OneOrMany<T> has a method first() to access the first element.
                let first_content = content.first();
                // The expected type is crate::completion::message::UserContent::Text wrapping a Text struct.
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    // Test conversion of internal tool definition to Ollama's ToolDefinition format.
    #[test]
    fn test_tool_definition_conversion() {
        // Internal tool definition from the completion module.
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        // Convert internal tool to Ollama's tool definition.
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        // Check JSON fields in parameters.
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }

    // Test deserialization of chat response with thinking content.
    // Plain #[test]: the body is fully synchronous.
    #[test]
    fn test_chat_completion_with_thinking() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The answer is 42.",
                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        // Verify thinking field is present
        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "Let me think about this carefully. The question asks for the meaning of life..."
            );
            assert_eq!(content, "The answer is 42.");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test deserialization of chat response without thinking content.
    // Plain #[test]: the body is fully synchronous.
    #[test]
    fn test_chat_completion_without_thinking() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Hello!",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        // Verify thinking field is None when not provided
        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert!(thinking.is_none());
            assert_eq!(content, "Hello!");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test deserialization of streaming response with thinking content
    #[test]
    fn test_streaming_response_with_thinking() {
        let sample_chunk = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "",
                "thinking": "Analyzing the problem...",
                "images": null,
                "tool_calls": []
            },
            "done": false
        });

        let chunk: CompletionResponse =
            serde_json::from_value(sample_chunk).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chunk.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
            assert_eq!(content, "");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test message conversion with thinking content
    #[test]
    fn test_message_conversion_with_thinking() {
        // Create an internal message with reasoning content
        let reasoning_content = crate::message::Reasoning::new("Step 1: Consider the problem");

        let internal_msg = crate::message::Message::Assistant {
            id: None,
            content: crate::OneOrMany::many(vec![
                crate::message::AssistantContent::Reasoning(reasoning_content),
                crate::message::AssistantContent::Text(crate::message::Text {
                    text: "The answer is X".to_string(),
                }),
            ])
            .unwrap(),
        };

        // Convert to provider Message
        let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();
        assert_eq!(provider_msgs.len(), 1);

        if let Message::Assistant {
            thinking, content, ..
        } = &provider_msgs[0]
        {
            assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
            assert_eq!(content, "The answer is X");
        } else {
            panic!("Expected Assistant message with thinking");
        }
    }

    // Test empty thinking content is handled correctly
    #[test]
    fn test_empty_thinking_content() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Response",
                "thinking": "",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            // Empty string should still deserialize as Some("")
            assert_eq!(thinking.as_ref().unwrap(), "");
            assert_eq!(content, "Response");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test thinking with tool calls
    #[test]
    fn test_thinking_with_tool_calls() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Let me check the weather.",
                "thinking": "User wants weather info, I should use the weather tool",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {
                                "location": "San Francisco"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 30u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 50u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking,
            content,
            tool_calls,
            ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "User wants weather info, I should use the weather tool"
            );
            assert_eq!(content, "Let me check the weather.");
            assert_eq!(tool_calls.len(), 1);
            assert_eq!(tool_calls[0].function.name, "get_weather");
        } else {
            panic!("Expected Assistant message with thinking and tool calls");
        }
    }

    // Test that `think` and `keep_alive` are extracted as top-level params, not in `options`
    #[test]
    fn test_completion_request_with_think_param() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        // Create a CompletionRequest with "think": true, "keep_alive", and "num_ctx" in additional_params
        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "What is 2 + 2?".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.7),
            max_tokens: Some(1024),
            tool_choice: None,
            additional_params: Some(json!({
                "think": true,
                "keep_alive": "-1m",
                "num_ctx": 4096
            })),
            output_schema: None,
        };

        // Convert to OllamaCompletionRequest
        let ollama_request = OllamaCompletionRequest::try_from(("qwen3:8b", completion_request))
            .expect("Failed to create Ollama request");

        // Serialize to JSON
        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        // Assert equality with expected JSON
        // - "tools" is skipped when empty (skip_serializing_if)
        // - "think" should be a top-level boolean, NOT in options
        // - "keep_alive" should be a top-level string, NOT in options
        // - "num_ctx" should be in options (it's a model parameter)
        let expected = json!({
            "model": "qwen3:8b",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "What is 2 + 2?"
                }
            ],
            "temperature": 0.7,
            "stream": false,
            "think": true,
            "max_tokens": 1024,
            "keep_alive": "-1m",
            "options": {
                "temperature": 0.7,
                "num_ctx": 4096
            }
        });

        assert_eq!(serialized, expected);
    }

    // Test that `think` defaults to false when not specified
    #[test]
    fn test_completion_request_with_think_false_default() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        // Create a CompletionRequest WITHOUT "think" in additional_params
        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "Hello!".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.5),
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        // Convert to OllamaCompletionRequest
        let ollama_request = OllamaCompletionRequest::try_from(("llama3.2", completion_request))
            .expect("Failed to create Ollama request");

        // Serialize to JSON
        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        // Assert that "think" defaults to false and "keep_alive" is not present
        let expected = json!({
            "model": "llama3.2",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "Hello!"
                }
            ],
            "temperature": 0.5,
            "stream": false,
            "think": false,
            "options": {
                "temperature": 0.5
            }
        });

        assert_eq!(serialized, expected);
    }

    #[test]
    fn test_completion_request_with_output_schema() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let schema: schemars::Schema = serde_json::from_value(json!({
            "type": "object",
            "properties": {
                "age": { "type": "integer" },
                "available": { "type": "boolean" }
            },
            "required": ["age", "available"]
        }))
        .expect("Failed to parse schema");

        let completion_request = CompletionRequest {
            model: Some("llama3.1".to_string()),
            preamble: None,
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "How old is Ollama?".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: Some(schema),
        };

        let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
            .expect("Failed to create Ollama request");

        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        // The schema must be forwarded verbatim as Ollama's `format` field.
        let format = serialized
            .get("format")
            .expect("format field should be present");
        assert_eq!(
            *format,
            json!({
                "type": "object",
                "properties": {
                    "age": { "type": "integer" },
                    "available": { "type": "boolean" }
                },
                "required": ["age", "available"]
            })
        );
    }

    #[test]
    fn test_completion_request_without_output_schema() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: Some("llama3.1".to_string()),
            preamble: None,
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "Hello!".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
            .expect("Failed to create Ollama request");

        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        assert!(
            serialized.get("format").is_none(),
            "format field should be absent when output_schema is None"
        );
    }

    #[test]
    fn test_client_initialization() {
        let _client = crate::providers::ollama::Client::new(Nothing).expect("Client::new() failed");
        let _client_from_builder = crate::providers::ollama::Client::builder()
            .api_key(Nothing)
            .build()
            .expect("Client::builder() failed");
    }
}