rig/providers/ollama.rs

//! Ollama API client and Rig integration
//!
//! # Example
//! ```rust,ignore
//! use rig::client::{Nothing, CompletionClient};
//! use rig::completion::Prompt;
//! use rig::providers::ollama;
//!
//! // Create a new Ollama client (defaults to http://localhost:11434)
//! // Ollama requires no API key, so we pass the `Nothing` placeholder struct
//! let client: ollama::Client = ollama::Client::new(Nothing).unwrap();
//!
//! // Create an agent with a preamble
//! let comedian_agent = client
//!     .agent("qwen2.5:14b")
//!     .preamble("You are a comedian here to entertain the user using humour and jokes.")
//!     .build();
//!
//! // Prompt the agent and print the response
//! let response = comedian_agent.prompt("Entertain me!").await?;
//! println!("{response}");
//!
//! // Create an embedding model using the "all-minilm" model
//! let emb_model = client.embedding_model("all-minilm", 384);
//! let embeddings = emb_model.embed_texts(vec![
//!     "Why is the sky blue?".to_owned(),
//!     "Why is the grass green?".to_owned()
//! ]).await?;
//! println!("Embedding response: {:?}", embeddings);
//!
//! // Create an extractor if needed
//! let extractor = client.extractor::<serde_json::Value>("llama3.2").build();
//! ```
use crate::client::{
    self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
};
use crate::completion::{GetTokenUsage, Usage};
use crate::http_client::{self, HttpClientExt};
use crate::message::DocumentSourceKind;
use crate::streaming::RawStreamingChoice;
use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils, message,
    message::{ImageDetail, Text},
    streaming,
};
use async_stream::try_stream;
use bytes::Bytes;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::{convert::TryFrom, str::FromStr};
use tracing::info_span;
use tracing_futures::Instrument;

// ---------- Main Client ----------

const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";

#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;

#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaBuilder;

impl Provider for OllamaExt {
    type Builder = OllamaBuilder;

    const VERIFY_PATH: &'static str = "api/tags";

    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}

impl<H> Capabilities<H> for OllamaExt {
    type Completion = Capable<CompletionModel<H>>;
    type Transcription = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for OllamaExt {}

impl ProviderBuilder for OllamaBuilder {
    type Output = OllamaExt;
    type ApiKey = Nothing;

    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;
}

pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, Nothing, H>;

impl ProviderClient for Client {
    type Input = Nothing;

    fn from_env() -> Self {
        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");

        Self::builder()
            .api_key(Nothing)
            .base_url(&api_base)
            .build()
            .unwrap()
    }

    fn from_val(_: Self::Input) -> Self {
        Self::builder().api_key(Nothing).build().unwrap()
    }
}
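
// Construction sketch (assumption: `http://remote-host:11434` is a placeholder
// URL; the builder calls mirror `from_env` above):
//
// let client: Client = Client::builder()
//     .api_key(Nothing)
//     .base_url("http://remote-host:11434")
//     .build()?;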

// ---------- API Error and Response Structures ----------

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

// ---------- Embedding API ----------

pub const ALL_MINILM: &str = "all-minilm";
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    match identifier {
        ALL_MINILM => Some(384),
        NOMIC_EMBED_TEXT => Some(768),
        _ => None,
    }
}

#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    pub embeddings: Vec<Vec<f64>>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

// ---------- Embedding Model ----------

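/// Ollama embedding model.
///
/// Usage sketch (assumes a running Ollama server with the `all-minilm` model
/// pulled; error handling elided):
/// ```rust,ignore
/// let model = EmbeddingModel::new(client.clone(), ALL_MINILM, 384);
/// let embeddings = model.embed_texts(vec!["hello".to_owned()]).await?;
/// ```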
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    ndims: usize,
}

impl<T> EmbeddingModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }

    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        let model = model.into();
        let dims = dims
            .or(model_dimensions_from_identifier(&model))
            .unwrap_or_default();
        Self::new(client.clone(), model, dims)
    }

    const MAX_DOCUMENTS: usize = 1024;

    fn ndims(&self) -> usize {
        self.ndims
    }

    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let docs: Vec<String> = documents.into_iter().collect();

        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": docs
        }))?;

        let req = self
            .client
            .post("api/embed")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if !response.status().is_success() {
            let text = http_client::text(response).await?;
            return Err(EmbeddingError::ProviderError(text));
        }

        let bytes: Vec<u8> = response.into_body().await?;

        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;

        if api_resp.embeddings.len() != docs.len() {
            return Err(EmbeddingError::ResponseError(
                "Number of returned embeddings does not match input".into(),
            ));
        }
        Ok(api_resp
            .embeddings
            .into_iter()
            .zip(docs.into_iter())
            .map(|(vec, document)| embeddings::Embedding { document, vec })
            .collect())
    }
}

// ---------- Completion API ----------

pub const LLAMA3_2: &str = "llama3.2";
pub const LLAVA: &str = "llava";
pub const MISTRAL: &str = "mistral";

#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;
    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            // Process only if an assistant message is present.
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                // Add the assistant's text content if any.
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                // Process tool_calls following Ollama's chat response definition.
                // Ollama tool calls carry only a type and a function (no id),
                // so the function name doubles as the call id.
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                        cached_input_tokens: 0,
                    },
                    raw_response,
                    message_id: None,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}

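/// Request body for Ollama's `/api/chat` endpoint. Serialized shape (sketch;
/// values are illustrative, and an empty `tools` array is omitted entirely):
///
/// ```json
/// {
///   "model": "llama3.2",
///   "messages": [{ "role": "user", "content": "Hi" }],
///   "temperature": 0.7,
///   "stream": false,
///   "think": false,
///   "options": { "temperature": 0.7, "num_ctx": 4096 }
/// }
/// ```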
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OllamaCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    pub stream: bool,
    think: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    keep_alive: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<schemars::Schema>,
    options: serde_json::Value,
}

impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        let model = req.model.clone().unwrap_or_else(|| model.to_string());
        if req.tool_choice.is_some() {
            tracing::warn!("`tool_choice` is not supported by Ollama and will be ignored");
        }
        // Build up the order of messages (context, chat_history, prompt)
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        // Add preamble to chat history (if available)
        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let mut think = false;
        let mut keep_alive: Option<String> = None;

        let options = if let Some(mut extra) = req.additional_params {
            // Extract top-level parameters that should not be in `options`
            if let Some(obj) = extra.as_object_mut() {
                // Extract `think` parameter
                if let Some(think_val) = obj.remove("think") {
                    think = think_val.as_bool().ok_or_else(|| {
                        CompletionError::RequestError("`think` must be a bool".into())
                    })?;
                }

                // Extract `keep_alive` parameter
                if let Some(keep_alive_val) = obj.remove("keep_alive") {
                    keep_alive = Some(
                        keep_alive_val
                            .as_str()
                            .ok_or_else(|| {
                                CompletionError::RequestError(
                                    "`keep_alive` must be a string".into(),
                                )
                            })?
                            .to_string(),
                    );
                }
            }

            json_utils::merge(json!({ "temperature": req.temperature }), extra)
        } else {
            json!({ "temperature": req.temperature })
        };

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            max_tokens: req.max_tokens,
            stream: false,
            think,
            keep_alive,
            format: req.output_schema,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            options,
        })
    }
}
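
// Caller-side sketch: `think` and `keep_alive` ride along in
// `additional_params` and are hoisted to top-level request fields; any other
// keys (e.g. the illustrative `num_ctx`) land in Ollama's `options` map.
//
// let params = serde_json::json!({ "think": true, "keep_alive": "5m", "num_ctx": 4096 });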

#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_owned(),
        }
    }
}

// ---------- CompletionModel Implementation ----------

#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        let input_tokens = self.prompt_eval_count.unwrap_or_default();
        let output_tokens = self.eval_count.unwrap_or_default();
        usage.input_tokens = input_tokens;
        usage.output_tokens = output_tokens;
        usage.total_tokens = input_tokens + output_tokens;

        Some(usage)
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into().as_str())
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "ollama",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request = OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Ollama completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("api/chat")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let async_block = async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ));
            }

            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
            let span = tracing::Span::current();
            // Record on the declared `gen_ai.response.model` field; recording an
            // undeclared field name would be silently dropped by `tracing`.
            span.record("gen_ai.response.model", &response.model);
            span.record(
                "gen_ai.usage.input_tokens",
                response.prompt_eval_count.unwrap_or_default(),
            );
            span.record(
                "gen_ai.usage.output_tokens",
                response.eval_count.unwrap_or_default(),
            );

            if tracing::enabled!(tracing::Level::TRACE) {
                tracing::trace!(target: "rig::completions",
                    "Ollama completion response: {}",
                    serde_json::to_string_pretty(&response)?
                );
            }

            let response: completion::CompletionResponse<CompletionResponse> =
                response.try_into()?;

            Ok(response)
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
    {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "ollama",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                // Declared up front so the stream below can record the final
                // assistant message; `tracing` drops records to undeclared fields.
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &request.preamble);

        let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
        request.stream = true;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Ollama streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("api/chat")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let response = self.client.send_streaming(req).await?;
        let status = response.status();
        let mut byte_stream = response.into_body();

        if !status.is_success() {
            return Err(CompletionError::ProviderError(format!(
                "Got error status code trying to send a request to Ollama: {status}"
            )));
        }

        let stream = try_stream! {
            let span = tracing::Span::current();
            let mut tool_calls_final = Vec::new();
            let mut text_response = String::new();
            let mut thinking_response = String::new();

            while let Some(chunk) = byte_stream.next().await {
                let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;

                for line in bytes.split(|&b| b == b'\n') {
                    if line.is_empty() {
                        continue;
                    }

                    tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));

                    let response: CompletionResponse = serde_json::from_slice(line)?;

                    if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
                        if let Some(thinking_content) = thinking && !thinking_content.is_empty() {
                            thinking_response += &thinking_content;
                            yield RawStreamingChoice::ReasoningDelta {
                                id: None,
                                reasoning: thinking_content,
                            };
                        }

                        if !content.is_empty() {
                            text_response += &content;
                            yield RawStreamingChoice::Message(content);
                        }

                        for tool_call in tool_calls {
                            tool_calls_final.push(tool_call.clone());
                            yield RawStreamingChoice::ToolCall(
                                crate::streaming::RawStreamingToolCall::new(String::new(), tool_call.function.name, tool_call.function.arguments)
                            );
                        }
                    }

                    if response.done {
                        span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
                        span.record("gen_ai.usage.output_tokens", response.eval_count);
                        let message = Message::Assistant {
                            content: text_response.clone(),
                            thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
                            images: None,
                            name: None,
                            tool_calls: tool_calls_final.clone()
                        };
                        span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
                        yield RawStreamingChoice::FinalResponse(
                            StreamingCompletionResponse {
                                total_duration: response.total_duration,
                                load_duration: response.load_duration,
                                prompt_eval_count: response.prompt_eval_count,
                                prompt_eval_duration: response.prompt_eval_duration,
                                eval_count: response.eval_count,
                                eval_duration: response.eval_duration,
                                done_reason: response.done_reason,
                            }
                        );
                        break;
                    }
                }
            }
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
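
// Consumption sketch (assumption: an async context with `futures::StreamExt`
// in scope; error handling elided):
//
// let mut stream = model.stream(request).await?;
// while let Some(item) = stream.next().await {
//     // handle text deltas, reasoning deltas, and tool calls as they arrive
// }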

// ---------- Tool Definition Conversion ----------

/// Ollama-required tool definition format.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    #[serde(rename = "type")]
    pub type_field: String, // Fixed as "function"
    pub function: completion::ToolDefinition,
}

/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
impl From<crate::completion::ToolDefinition> for ToolDefinition {
    fn from(tool: crate::completion::ToolDefinition) -> Self {
        ToolDefinition {
            type_field: "function".to_owned(),
            function: completion::ToolDefinition {
                name: tool.name,
                description: tool.description,
                parameters: tool.parameters,
            },
        }
    }
}
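
// Wire shape sent to Ollama (sketch; `get_weather` is a placeholder name):
// { "type": "function",
//   "function": { "name": "get_weather", "description": "...", "parameters": { ... } } }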

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}

// ---------- Provider Message Definition ----------

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        #[serde(default)]
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
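
// Wire-format sketch for the variants above (role tag per the serde attributes;
// values are placeholders):
// user:      {"role":"user","content":"Hi","images":["<base64>"]}
// assistant: {"role":"assistant","content":"...","thinking":"...","tool_calls":[...]}
// system:    {"role":"system","content":"..."}
// tool:      {"role":"tool","tool_name":"get_weather","content":"{...}"}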

// ---------- Provider Message Conversions ----------

/// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
/// (Only User and Assistant variants are supported.)
impl TryFrom<crate::message::Message> for Vec<Message> {
    type Error = crate::message::MessageError;
    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;
        match internal_msg {
            InternalMessage::User { content, .. } => {
                let (tool_results, other_content): (Vec<_>, Vec<_>) =
                    content.into_iter().partition(|content| {
                        matches!(content, crate::message::UserContent::ToolResult(_))
                    });

                if !tool_results.is_empty() {
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            crate::message::UserContent::ToolResult(
                                crate::message::ToolResult { id, content, .. },
                            ) => {
                                // Ollama expects a single string for tool results, so we concatenate
                                let content_string = content
                                    .into_iter()
                                    .map(|content| match content {
                                        crate::message::ToolResultContent::Text(text) => text.text,
                                        _ => "[Non-text content]".to_string(),
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n");

                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
                                    name: id,
                                    content: content_string,
                                })
                            }
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    // Ollama requires separate text content and images array
                    let (texts, images) = other_content.into_iter().fold(
                        (Vec::new(), Vec::new()),
                        |(mut texts, mut images), content| {
                            match content {
                                crate::message::UserContent::Text(crate::message::Text {
                                    text,
                                }) => texts.push(text),
                                crate::message::UserContent::Image(crate::message::Image {
                                    data: DocumentSourceKind::Base64(data),
                                    ..
                                }) => images.push(data),
                                crate::message::UserContent::Document(
                                    crate::message::Document {
                                        data:
                                            DocumentSourceKind::Base64(data)
                                            | DocumentSourceKind::String(data),
                                        ..
                                    },
                                ) => texts.push(data),
                                _ => {} // Audio not supported by Ollama
                            }
                            (texts, images)
                        },
                    );

                    Ok(vec![Message::User {
                        content: texts.join(" "),
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(
                                images
                                    .into_iter()
                                    .map(|x| x.to_string())
                                    .collect::<Vec<String>>(),
                            )
                        },
                        name: None,
                    }])
                }
            }
            InternalMessage::Assistant { content, .. } => {
                let mut thinking: Option<String> = None;
                let mut text_content = Vec::new();
                let mut tool_calls = Vec::new();

                for content in content.into_iter() {
                    match content {
                        crate::message::AssistantContent::Text(text) => {
                            text_content.push(text.text)
                        }
                        crate::message::AssistantContent::ToolCall(tool_call) => {
                            tool_calls.push(tool_call)
                        }
                        crate::message::AssistantContent::Reasoning(reasoning) => {
                            let display = reasoning.display_text();
                            if !display.is_empty() {
                                thinking = Some(display);
                            }
                        }
                        crate::message::AssistantContent::Image(_) => {
                            return Err(crate::message::MessageError::ConversionError(
                                "Ollama currently doesn't support images.".into(),
                            ));
                        }
                    }
                }

                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
                //  so either `content` or `tool_calls` will have some content.
                Ok(vec![Message::Assistant {
                    content: text_content.join(" "),
                    thinking,
                    images: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}

/// Conversion from provider Message to a completion message.
/// This is needed so that responses can be converted back into chat history.
impl From<Message> for crate::completion::Message {
    fn from(msg: Message) -> Self {
        match msg {
            Message::User { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut assistant_contents =
                    vec![crate::completion::message::AssistantContent::Text(Text {
                        text: content,
                    })];
                for tc in tool_calls {
                    assistant_contents.push(
                        crate::completion::message::AssistantContent::tool_call(
                            tc.function.name.clone(),
                            tc.function.name,
                            tc.function.arguments,
                        ),
                    );
                }
                crate::completion::Message::Assistant {
                    id: None,
                    content: OneOrMany::many(assistant_contents).unwrap(),
                }
            }
            // System and ToolResult are converted to User messages as needed.
            Message::System { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::ToolResult { name, content } => crate::completion::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    name,
                    OneOrMany::one(message::ToolResultContent::text(content)),
                )),
            },
        }
    }
}

impl Message {
    /// Constructs a system message.
    pub fn system(content: &str) -> Self {
        Message::System {
            content: content.to_owned(),
            images: None,
            name: None,
        }
    }
}

// ---------- Additional Message Types ----------

impl From<crate::message::ToolCall> for ToolCall {
    fn from(tool_call: crate::message::ToolCall) -> Self {
        Self {
            r#type: ToolType::Function,
            function: Function {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}

impl From<String> for SystemContent {
    fn from(s: String) -> Self {
        SystemContent {
            r#type: SystemContentType::default(),
            text: s,
        }
    }
}

impl FromStr for SystemContent {
    type Err = std::convert::Infallible;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(SystemContent {
            r#type: SystemContentType::default(),
            text: s.to_string(),
        })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}

impl FromStr for AssistantContent {
    type Err = std::convert::Infallible;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(AssistantContent { text: s.to_owned() })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
    // Audio variant removed as the Ollama API does not support audio input.
}

impl FromStr for UserContent {
    type Err = std::convert::Infallible;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(UserContent::Text { text: s.to_owned() })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}

// =================================================================
// Tests
// =================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Test deserialization and conversion for the /api/chat endpoint.
    #[tokio::test]
    async fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    // Test conversion from provider Message to completion Message.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // `OneOrMany::first()` returns the first element of the collection.
1143                let first_content = content.first();
1144                // The expected type is crate::completion::message::UserContent::Text wrapping a Text struct.
1145                match first_content {
1146                    crate::completion::message::UserContent::Text(text_struct) => {
1147                        assert_eq!(text_struct.text, "Test message");
1148                    }
1149                    _ => panic!("Expected text content in conversion"),
1150                }
1151            }
1152            _ => panic!("Conversion from provider Message to completion Message failed"),
1153        }
1154    }
1155
1156    // Test conversion of internal tool definition to Ollama's ToolDefinition format.
1157    #[test]
1158    fn test_tool_definition_conversion() {
1159        // Internal tool definition from the completion module.
1160        let internal_tool = crate::completion::ToolDefinition {
1161            name: "get_current_weather".to_owned(),
1162            description: "Get the current weather for a location".to_owned(),
1163            parameters: json!({
1164                "type": "object",
1165                "properties": {
1166                    "location": {
1167                        "type": "string",
1168                        "description": "The location to get the weather for, e.g. San Francisco, CA"
1169                    },
1170                    "format": {
1171                        "type": "string",
1172                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
1173                        "enum": ["celsius", "fahrenheit"]
1174                    }
1175                },
1176                "required": ["location", "format"]
1177            }),
1178        };
1179        // Convert internal tool to Ollama's tool definition.
1180        let ollama_tool: ToolDefinition = internal_tool.into();
1181        assert_eq!(ollama_tool.type_field, "function");
1182        assert_eq!(ollama_tool.function.name, "get_current_weather");
1183        assert_eq!(
1184            ollama_tool.function.description,
1185            "Get the current weather for a location"
1186        );
1187        // Check JSON fields in parameters.
1188        let params = &ollama_tool.function.parameters;
1189        assert_eq!(params["properties"]["location"]["type"], "string");
1190    }
1191
1192    // Test deserialization of chat response with thinking content
1193    #[tokio::test]
1194    async fn test_chat_completion_with_thinking() {
1195        let sample_response = json!({
1196            "model": "qwen-thinking",
1197            "created_at": "2023-08-04T19:22:45.499127Z",
1198            "message": {
1199                "role": "assistant",
1200                "content": "The answer is 42.",
1201                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
1202                "images": null,
1203                "tool_calls": []
1204            },
1205            "done": true,
1206            "total_duration": 8000000000u64,
1207            "load_duration": 6000000u64,
1208            "prompt_eval_count": 61u64,
1209            "prompt_eval_duration": 400000000u64,
1210            "eval_count": 468u64,
1211            "eval_duration": 7700000000u64
1212        });
1213
1214        let chat_resp: CompletionResponse =
1215            serde_json::from_value(sample_response).expect("Failed to deserialize");
1216
1217        // Verify thinking field is present
1218        if let Message::Assistant {
1219            thinking, content, ..
1220        } = &chat_resp.message
1221        {
1222            assert_eq!(
1223                thinking.as_ref().unwrap(),
1224                "Let me think about this carefully. The question asks for the meaning of life..."
1225            );
1226            assert_eq!(content, "The answer is 42.");
1227        } else {
1228            panic!("Expected Assistant message");
1229        }
1230    }
1231
1232    // Test deserialization of chat response without thinking content
1233    #[tokio::test]
1234    async fn test_chat_completion_without_thinking() {
1235        let sample_response = json!({
1236            "model": "llama3.2",
1237            "created_at": "2023-08-04T19:22:45.499127Z",
1238            "message": {
1239                "role": "assistant",
1240                "content": "Hello!",
1241                "images": null,
1242                "tool_calls": []
1243            },
1244            "done": true,
1245            "total_duration": 8000000000u64,
1246            "load_duration": 6000000u64,
1247            "prompt_eval_count": 10u64,
1248            "prompt_eval_duration": 400000000u64,
1249            "eval_count": 5u64,
1250            "eval_duration": 7700000000u64
1251        });
1252
1253        let chat_resp: CompletionResponse =
1254            serde_json::from_value(sample_response).expect("Failed to deserialize");
1255
1256        // Verify thinking field is None when not provided
1257        if let Message::Assistant {
1258            thinking, content, ..
1259        } = &chat_resp.message
1260        {
1261            assert!(thinking.is_none());
1262            assert_eq!(content, "Hello!");
1263        } else {
1264            panic!("Expected Assistant message");
1265        }
1266    }
1267
1268    // Test deserialization of streaming response with thinking content
1269    #[test]
1270    fn test_streaming_response_with_thinking() {
1271        let sample_chunk = json!({
1272            "model": "qwen-thinking",
1273            "created_at": "2023-08-04T19:22:45.499127Z",
1274            "message": {
1275                "role": "assistant",
1276                "content": "",
1277                "thinking": "Analyzing the problem...",
1278                "images": null,
1279                "tool_calls": []
1280            },
1281            "done": false
1282        });
1283
1284        let chunk: CompletionResponse =
1285            serde_json::from_value(sample_chunk).expect("Failed to deserialize");
1286
1287        if let Message::Assistant {
1288            thinking, content, ..
1289        } = &chunk.message
1290        {
1291            assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
1292            assert_eq!(content, "");
1293        } else {
1294            panic!("Expected Assistant message");
1295        }
1296    }
1297
1298    // Test message conversion with thinking content
1299    #[test]
1300    fn test_message_conversion_with_thinking() {
1301        // Create an internal message with reasoning content
1302        let reasoning_content = crate::message::Reasoning::new("Step 1: Consider the problem");
1303
1304        let internal_msg = crate::message::Message::Assistant {
1305            id: None,
1306            content: crate::OneOrMany::many(vec![
1307                crate::message::AssistantContent::Reasoning(reasoning_content),
1308                crate::message::AssistantContent::Text(crate::message::Text {
1309                    text: "The answer is X".to_string(),
1310                }),
1311            ])
1312            .unwrap(),
1313        };
1314
1315        // Convert to provider Message
1316        let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();
1317        assert_eq!(provider_msgs.len(), 1);
1318
1319        if let Message::Assistant {
1320            thinking, content, ..
1321        } = &provider_msgs[0]
1322        {
1323            assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
1324            assert_eq!(content, "The answer is X");
1325        } else {
1326            panic!("Expected Assistant message with thinking");
1327        }
1328    }
1329
1330    // Test empty thinking content is handled correctly
1331    #[test]
1332    fn test_empty_thinking_content() {
1333        let sample_response = json!({
1334            "model": "llama3.2",
1335            "created_at": "2023-08-04T19:22:45.499127Z",
1336            "message": {
1337                "role": "assistant",
1338                "content": "Response",
1339                "thinking": "",
1340                "images": null,
1341                "tool_calls": []
1342            },
1343            "done": true,
1344            "total_duration": 8000000000u64,
1345            "load_duration": 6000000u64,
1346            "prompt_eval_count": 10u64,
1347            "prompt_eval_duration": 400000000u64,
1348            "eval_count": 5u64,
1349            "eval_duration": 7700000000u64
1350        });
1351
1352        let chat_resp: CompletionResponse =
1353            serde_json::from_value(sample_response).expect("Failed to deserialize");
1354
1355        if let Message::Assistant {
1356            thinking, content, ..
1357        } = &chat_resp.message
1358        {
1359            // Empty string should still deserialize as Some("")
1360            assert_eq!(thinking.as_ref().unwrap(), "");
1361            assert_eq!(content, "Response");
1362        } else {
1363            panic!("Expected Assistant message");
1364        }
1365    }
1366
1367    // Test thinking with tool calls
1368    #[test]
1369    fn test_thinking_with_tool_calls() {
1370        let sample_response = json!({
1371            "model": "qwen-thinking",
1372            "created_at": "2023-08-04T19:22:45.499127Z",
1373            "message": {
1374                "role": "assistant",
1375                "content": "Let me check the weather.",
1376                "thinking": "User wants weather info, I should use the weather tool",
1377                "images": null,
1378                "tool_calls": [
1379                    {
1380                        "type": "function",
1381                        "function": {
1382                            "name": "get_weather",
1383                            "arguments": {
1384                                "location": "San Francisco"
1385                            }
1386                        }
1387                    }
1388                ]
1389            },
1390            "done": true,
1391            "total_duration": 8000000000u64,
1392            "load_duration": 6000000u64,
1393            "prompt_eval_count": 30u64,
1394            "prompt_eval_duration": 400000000u64,
1395            "eval_count": 50u64,
1396            "eval_duration": 7700000000u64
1397        });
1398
1399        let chat_resp: CompletionResponse =
1400            serde_json::from_value(sample_response).expect("Failed to deserialize");
1401
1402        if let Message::Assistant {
1403            thinking,
1404            content,
1405            tool_calls,
1406            ..
1407        } = &chat_resp.message
1408        {
1409            assert_eq!(
1410                thinking.as_ref().unwrap(),
1411                "User wants weather info, I should use the weather tool"
1412            );
1413            assert_eq!(content, "Let me check the weather.");
1414            assert_eq!(tool_calls.len(), 1);
1415            assert_eq!(tool_calls[0].function.name, "get_weather");
1416        } else {
1417            panic!("Expected Assistant message with thinking and tool calls");
1418        }
1419    }
1420
    // Test that `think` and `keep_alive` are extracted as top-level params, not in `options`
    #[test]
    fn test_completion_request_with_think_param() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        // Create a CompletionRequest with "think": true, "keep_alive", and "num_ctx" in additional_params
        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "What is 2 + 2?".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.7),
            max_tokens: Some(1024),
            tool_choice: None,
            additional_params: Some(json!({
                "think": true,
                "keep_alive": "-1m",
                "num_ctx": 4096
            })),
            output_schema: None,
        };

        // Convert to OllamaCompletionRequest
        let ollama_request = OllamaCompletionRequest::try_from(("qwen3:8b", completion_request))
            .expect("Failed to create Ollama request");

        // Serialize to JSON
        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        // Assert equality with expected JSON
        // - "tools" is skipped when empty (skip_serializing_if)
        // - "think" should be a top-level boolean, NOT in options
        // - "keep_alive" should be a top-level string, NOT in options
        // - "num_ctx" should be in options (it's a model parameter)
        let expected = json!({
            "model": "qwen3:8b",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "What is 2 + 2?"
                }
            ],
            "temperature": 0.7,
            "stream": false,
            "think": true,
            "max_tokens": 1024,
            "keep_alive": "-1m",
            "options": {
                "temperature": 0.7,
                "num_ctx": 4096
            }
        });

        assert_eq!(serialized, expected);
    }
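
    // Caller-side sketch added for illustration: `think` and `keep_alive`
    // normally reach `additional_params` through the agent builder rather than
    // a hand-built `CompletionRequest`. This assumes rig's
    // `AgentBuilder::additional_params` merges the JSON into the request,
    // which is what the conversion test above exercises. Building the agent
    // sends nothing, so no local Ollama server is needed.
    #[test]
    fn test_agent_builder_forwards_think_param() {
        let client: crate::providers::ollama::Client =
            crate::providers::ollama::Client::new(Nothing).expect("Client::new() failed");
        let _agent = client
            .agent("qwen3:8b")
            .additional_params(json!({ "think": true, "keep_alive": "-1m" }))
            .build();
    }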

    // Test that `think` defaults to false when not specified
    #[test]
    fn test_completion_request_with_think_false_default() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        // Create a CompletionRequest WITHOUT "think" in additional_params
        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "Hello!".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.5),
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        // Convert to OllamaCompletionRequest
        let ollama_request = OllamaCompletionRequest::try_from(("llama3.2", completion_request))
            .expect("Failed to create Ollama request");

        // Serialize to JSON
        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        // Assert that "think" defaults to false and "keep_alive" is not present
        let expected = json!({
            "model": "llama3.2",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "Hello!"
                }
            ],
            "temperature": 0.5,
            "stream": false,
            "think": false,
            "options": {
                "temperature": 0.5
            }
        });

        assert_eq!(serialized, expected);
    }

    #[test]
    fn test_completion_request_with_output_schema() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let schema: schemars::Schema = serde_json::from_value(json!({
            "type": "object",
            "properties": {
                "age": { "type": "integer" },
                "available": { "type": "boolean" }
            },
            "required": ["age", "available"]
        }))
        .expect("Failed to parse schema");

        let completion_request = CompletionRequest {
            model: Some("llama3.1".to_string()),
            preamble: None,
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "How old is Ollama?".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: Some(schema),
        };

        let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
            .expect("Failed to create Ollama request");

        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        let format = serialized
            .get("format")
            .expect("format field should be present");
        assert_eq!(
            *format,
            json!({
                "type": "object",
                "properties": {
                    "age": { "type": "integer" },
                    "available": { "type": "boolean" }
                },
                "required": ["age", "available"]
            })
        );
    }
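
    // Sketch added for illustration: Ollama's structured outputs use the
    // `format` field asserted above to constrain generation, so the assistant
    // reply should parse as the schema's target type. The reply string below
    // is fabricated; only the parsing step is exercised and no request is made.
    #[test]
    fn test_schema_constrained_reply_parses() {
        #[derive(serde::Deserialize)]
        struct Answer {
            age: u32,
            available: bool,
        }

        // A well-formed reply matching the schema used in the test above.
        let reply = r#"{ "age": 7, "available": true }"#;
        let parsed: Answer = serde_json::from_str(reply).expect("reply should match schema");
        assert_eq!(parsed.age, 7);
        assert!(parsed.available);
    }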

    #[test]
    fn test_completion_request_without_output_schema() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: Some("llama3.1".to_string()),
            preamble: None,
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "Hello!".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
            .expect("Failed to create Ollama request");

        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        assert!(
            serialized.get("format").is_none(),
            "format field should be absent when output_schema is None"
        );
    }

    #[test]
    fn test_client_initialization() {
        let _client: crate::providers::ollama::Client =
            crate::providers::ollama::Client::new(Nothing).expect("Client::new() failed");
        let _client_from_builder: crate::providers::ollama::Client =
            crate::providers::ollama::Client::builder()
                .api_key(Nothing)
                .build()
                .expect("Client::builder() failed");
    }
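
    // Sketch added for illustration: a construction-only smoke test for the
    // embeddings path, mirroring the client test above. `embedding_model(name,
    // ndims)` follows the signature shown in the module docs; nothing is sent
    // to a server, so no local Ollama instance is required.
    #[test]
    fn test_embedding_model_initialization() {
        let client: crate::providers::ollama::Client =
            crate::providers::ollama::Client::new(Nothing).expect("Client::new() failed");
        let _model = client.embedding_model("all-minilm", 384);
    }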
}