Skip to main content

rig_core/providers/
ollama.rs

1//! Ollama API client and Rig integration
2//!
3//! # Example
4//! ```rust,ignore
5//! use rig_core::client::{Nothing, CompletionClient};
6//! use rig_core::completion::Prompt;
7//! use rig_core::providers::ollama;
8//!
9//! // Create a new Ollama client (defaults to http://localhost:11434, no auth)
10//! let client = ollama::Client::new(Nothing).unwrap();
11//!
12//! // Or connect to a remote/proxied Ollama instance with authentication
13//! let client = ollama::Client::builder()
14//!     .api_key("my-secret-key")
15//!     .base_url("http://remote-ollama:11434")
16//!     .build()
17//!     .unwrap();
18//!
19//! // Create an agent with a preamble
20//! let comedian_agent = client
21//!     .agent("qwen2.5:14b")
22//!     .preamble("You are a comedian here to entertain the user using humour and jokes.")
23//!     .build();
24//!
25//! // Prompt the agent and print the response
26//! let response = comedian_agent.prompt("Entertain me!").await?;
27//! println!("{response}");
28//!
29//! // Create an embedding model using the "all-minilm" model
30//! let emb_model = client.embedding_model("all-minilm", 384);
31//! let embeddings = emb_model.embed_texts(vec![
32//!     "Why is the sky blue?".to_owned(),
33//!     "Why is the grass green?".to_owned()
34//! ]).await?;
35//! println!("Embedding response: {:?}", embeddings);
36//!
37//! // Create an extractor if needed
38//! let extractor = client.extractor::<serde_json::Value>("llama3.2").build();
39//! ```
40use crate::client::{
41    self, ApiKey, Capabilities, Capable, DebugExt, ModelLister, Nothing, Provider, ProviderBuilder,
42    ProviderClient,
43};
44use crate::completion::{GetTokenUsage, Usage};
45use crate::http_client::{self, HttpClientExt};
46use crate::message::DocumentSourceKind;
47use crate::model::{Model, ModelList, ModelListingError};
48use crate::streaming::RawStreamingChoice;
49use crate::{
50    OneOrMany,
51    completion::{self, CompletionError, CompletionRequest},
52    embeddings::{self, EmbeddingError},
53    json_utils, message,
54    message::{ImageDetail, Text},
55    streaming,
56    wasm_compat::{WasmCompatSend, WasmCompatSync},
57};
58use async_stream::try_stream;
59use bytes::Bytes;
60use futures::StreamExt;
61use serde::{Deserialize, Serialize};
62use serde_json::{Value, json};
63use std::{convert::TryFrom, str::FromStr};
64use tracing::info_span;
65use tracing_futures::Instrument;
66// ---------- Main Client ----------
67
68const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
69
/// Bearer token for an Ollama deployment, when one is needed.
///
/// Ollama requires no authentication by default, so the wrapped value is `None`
/// unless a proxied or secured deployment supplies a key.
#[derive(Clone, Debug, Default)]
pub struct OllamaApiKey(Option<String>);
74
75impl ApiKey for OllamaApiKey {
76    fn into_header(
77        self,
78    ) -> Option<http_client::Result<(http::header::HeaderName, http::header::HeaderValue)>> {
79        self.0.map(http_client::make_auth_header)
80    }
81}
82
83impl From<Nothing> for OllamaApiKey {
84    fn from(_: Nothing) -> Self {
85        Self(None)
86    }
87}
88
89impl From<String> for OllamaApiKey {
90    fn from(key: String) -> Self {
91        if key.is_empty() {
92            Self(None)
93        } else {
94            Self(Some(key))
95        }
96    }
97}
98
99impl From<&str> for OllamaApiKey {
100    fn from(key: &str) -> Self {
101        if key.is_empty() {
102            Self(None)
103        } else {
104            Self(Some(key.to_owned()))
105        }
106    }
107}
108
/// Field-less marker type used as the provider extension of the Ollama client.
#[derive(Clone, Copy, Debug, Default)]
pub struct OllamaExt;
111
/// Field-less marker type used as the provider builder of the Ollama client.
#[derive(Clone, Copy, Debug, Default)]
pub struct OllamaBuilder;
114
115impl Provider for OllamaExt {
116    type Builder = OllamaBuilder;
117    const VERIFY_PATH: &'static str = "api/tags";
118}
119
120impl<H> Capabilities<H> for OllamaExt {
121    type Completion = Capable<CompletionModel<H>>;
122    type Transcription = Nothing;
123    type Embeddings = Capable<EmbeddingModel<H>>;
124    type ModelListing = Capable<OllamaModelLister<H>>;
125    #[cfg(feature = "image")]
126    type ImageGeneration = Nothing;
127
128    #[cfg(feature = "audio")]
129    type AudioGeneration = Nothing;
130    type Rerank = Nothing;
131}
132
// Empty impl body: `OllamaExt` relies entirely on the items `DebugExt` provides by default.
impl DebugExt for OllamaExt {}
134
135impl ProviderBuilder for OllamaBuilder {
136    type Extension<H>
137        = OllamaExt
138    where
139        H: HttpClientExt;
140    type ApiKey = OllamaApiKey;
141
142    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;
143
144    fn build<H>(
145        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
146    ) -> http_client::Result<Self::Extension<H>>
147    where
148        H: HttpClientExt,
149    {
150        Ok(OllamaExt)
151    }
152}
153
/// Ollama client: the generic [`client::Client`] specialised with [`OllamaExt`].
///
/// `H` is the HTTP client type and defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
/// Builder for the Ollama [`Client`], keyed by [`OllamaApiKey`].
///
/// `H` defaults to `crate::markers::Missing`.
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<OllamaBuilder, OllamaApiKey, H>;
157
158impl ProviderClient for Client {
159    type Input = OllamaApiKey;
160    type Error = crate::client::ProviderClientError;
161
162    fn from_env() -> Result<Self, Self::Error> {
163        let api_base = crate::client::optional_env_var("OLLAMA_API_BASE_URL")?
164            .unwrap_or_else(|| OLLAMA_API_BASE_URL.to_string());
165
166        let api_key = crate::client::optional_env_var("OLLAMA_API_KEY")?
167            .map(OllamaApiKey::from)
168            .unwrap_or_default();
169
170        Self::builder()
171            .api_key(api_key)
172            .base_url(&api_base)
173            .build()
174            .map_err(Into::into)
175    }
176
177    fn from_val(api_key: Self::Input) -> Result<Self, Self::Error> {
178        Self::builder().api_key(api_key).build().map_err(Into::into)
179    }
180}
181
182// ---------- API Error and Response Structures ----------
183
184#[derive(Debug, Deserialize)]
185struct ApiErrorResponse {
186    message: String,
187}
188
189#[derive(Debug, Deserialize)]
190#[serde(untagged)]
191enum ApiResponse<T> {
192    Ok(T),
193    Err(ApiErrorResponse),
194}
195
196// ---------- Embedding API ----------
197
/// Identifier of the `all-minilm` embedding model.
pub const ALL_MINILM: &str = "all-minilm";
/// Identifier of the `nomic-embed-text` embedding model.
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

/// Looks up the embedding dimensionality for a known model identifier.
///
/// The match is exact and case-sensitive. Returns `None` for any identifier
/// that is not in the built-in table.
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    const KNOWN_DIMENSIONS: [(&str, usize); 2] = [(ALL_MINILM, 384), (NOMIC_EMBED_TEXT, 768)];

    KNOWN_DIMENSIONS
        .iter()
        .find(|(name, _)| *name == identifier)
        .map(|&(_, dims)| dims)
}
208
209#[derive(Debug, Serialize, Deserialize)]
210pub struct EmbeddingResponse {
211    pub model: String,
212    pub embeddings: Vec<Vec<f64>>,
213    #[serde(default)]
214    pub total_duration: Option<u64>,
215    #[serde(default)]
216    pub load_duration: Option<u64>,
217    #[serde(default)]
218    pub prompt_eval_count: Option<u64>,
219}
220
221impl From<ApiErrorResponse> for EmbeddingError {
222    fn from(err: ApiErrorResponse) -> Self {
223        EmbeddingError::ProviderError(err.message)
224    }
225}
226
227impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
228    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
229        match value {
230            ApiResponse::Ok(response) => Ok(response),
231            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
232        }
233    }
234}
235
236// ---------- Embedding Model ----------
237
238#[derive(Clone)]
239pub struct EmbeddingModel<T = reqwest::Client> {
240    client: Client<T>,
241    pub model: String,
242    ndims: usize,
243}
244
245impl<T> EmbeddingModel<T> {
246    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
247        Self {
248            client,
249            model: model.into(),
250            ndims,
251        }
252    }
253
254    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
255        Self {
256            client,
257            model: model.into(),
258            ndims,
259        }
260    }
261}
262
263impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
264where
265    T: HttpClientExt + Clone + 'static,
266{
267    type Client = Client<T>;
268
269    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
270        let model = model.into();
271        let dims = dims
272            .or(model_dimensions_from_identifier(&model))
273            .unwrap_or_default();
274        Self::new(client.clone(), model, dims)
275    }
276
277    const MAX_DOCUMENTS: usize = 1024;
278    fn ndims(&self) -> usize {
279        self.ndims
280    }
281
282    async fn embed_texts(
283        &self,
284        documents: impl IntoIterator<Item = String>,
285    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
286        let docs: Vec<String> = documents.into_iter().collect();
287
288        let body = serde_json::to_vec(&json!({
289            "model": self.model,
290            "input": docs
291        }))?;
292
293        let req = self
294            .client
295            .post("api/embed")?
296            .body(body)
297            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
298
299        let response = self.client.send::<_, Vec<u8>>(req).await?;
300
301        if !response.status().is_success() {
302            let text = http_client::text(response).await?;
303            return Err(EmbeddingError::ProviderError(text));
304        }
305
306        let bytes: Vec<u8> = response.into_body().await?;
307
308        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;
309
310        if api_resp.embeddings.len() != docs.len() {
311            return Err(EmbeddingError::ResponseError(
312                "Number of returned embeddings does not match input".into(),
313            ));
314        }
315        Ok(api_resp
316            .embeddings
317            .into_iter()
318            .zip(docs.into_iter())
319            .map(|(vec, document)| embeddings::Embedding { document, vec })
320            .collect())
321    }
322}
323
324// ---------- Completion API ----------
325
/// Model identifier `llama3.2`.
pub const LLAMA3_2: &str = "llama3.2";
/// Model identifier `llava`.
pub const LLAVA: &str = "llava";
/// Model identifier `mistral`.
pub const MISTRAL: &str = "mistral";
329
330#[derive(Debug, Serialize, Deserialize)]
331pub struct CompletionResponse {
332    pub model: String,
333    pub created_at: String,
334    pub message: Message,
335    pub done: bool,
336    #[serde(default)]
337    pub done_reason: Option<String>,
338    #[serde(default)]
339    pub total_duration: Option<u64>,
340    #[serde(default)]
341    pub load_duration: Option<u64>,
342    #[serde(default)]
343    pub prompt_eval_count: Option<u64>,
344    #[serde(default)]
345    pub prompt_eval_duration: Option<u64>,
346    #[serde(default)]
347    pub eval_count: Option<u64>,
348    #[serde(default)]
349    pub eval_duration: Option<u64>,
350}
351impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
352    type Error = CompletionError;
353    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
354        match resp.message {
355            // Process only if an assistant message is present.
356            Message::Assistant {
357                content,
358                thinking,
359                tool_calls,
360                ..
361            } => {
362                let mut assistant_contents = Vec::new();
363                // Add the assistant's text content if any.
364                if !content.is_empty() {
365                    assistant_contents.push(completion::AssistantContent::text(&content));
366                }
367                // Process tool_calls following Ollama's chat response definition.
368                // Each ToolCall has an id, a type, and a function field.
369                for tc in tool_calls.iter() {
370                    assistant_contents.push(completion::AssistantContent::tool_call(
371                        tc.function.name.clone(),
372                        tc.function.name.clone(),
373                        tc.function.arguments.clone(),
374                    ));
375                }
376                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
377                    CompletionError::ResponseError("No content provided".to_owned())
378                })?;
379                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
380                let completion_tokens = resp.eval_count.unwrap_or(0);
381
382                let raw_response = CompletionResponse {
383                    model: resp.model,
384                    created_at: resp.created_at,
385                    done: resp.done,
386                    done_reason: resp.done_reason,
387                    total_duration: resp.total_duration,
388                    load_duration: resp.load_duration,
389                    prompt_eval_count: resp.prompt_eval_count,
390                    prompt_eval_duration: resp.prompt_eval_duration,
391                    eval_count: resp.eval_count,
392                    eval_duration: resp.eval_duration,
393                    message: Message::Assistant {
394                        content,
395                        thinking,
396                        images: None,
397                        name: None,
398                        tool_calls,
399                    },
400                };
401
402                Ok(completion::CompletionResponse {
403                    choice,
404                    usage: Usage {
405                        input_tokens: prompt_tokens,
406                        output_tokens: completion_tokens,
407                        total_tokens: prompt_tokens + completion_tokens,
408                        cached_input_tokens: 0,
409                        cache_creation_input_tokens: 0,
410                        tool_use_prompt_tokens: 0,
411                        reasoning_tokens: 0,
412                    },
413                    raw_response,
414                    message_id: None,
415                })
416            }
417            _ => Err(CompletionError::ResponseError(
418                "Chat response does not include an assistant message".into(),
419            )),
420        }
421    }
422}
423
424#[derive(Debug, Serialize, Deserialize)]
425pub(super) struct OllamaCompletionRequest {
426    model: String,
427    pub messages: Vec<Message>,
428    #[serde(skip_serializing_if = "Option::is_none")]
429    temperature: Option<f64>,
430    #[serde(skip_serializing_if = "Vec::is_empty")]
431    tools: Vec<ToolDefinition>,
432    pub stream: bool,
433    think: Think,
434    #[serde(skip_serializing_if = "Option::is_none")]
435    max_tokens: Option<u64>,
436    #[serde(skip_serializing_if = "Option::is_none")]
437    keep_alive: Option<String>,
438    #[serde(skip_serializing_if = "Option::is_none")]
439    format: Option<schemars::Schema>,
440    options: serde_json::Value,
441}
442
443impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
444    type Error = CompletionError;
445
446    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
447        let chat_history = req.chat_history_with_documents();
448        let model = req.model.clone().unwrap_or_else(|| model.to_string());
449        if req.tool_choice.is_some() {
450            tracing::warn!("WARNING: `tool_choice` not supported for Ollama");
451        }
452        // Build up the order of messages.
453        let mut partial_history = vec![];
454        partial_history.extend(chat_history);
455
456        // Add preamble to chat history (if available)
457        let mut full_history: Vec<Message> = match &req.preamble {
458            Some(preamble) => vec![Message::system(preamble)],
459            None => vec![],
460        };
461
462        // Convert and extend the rest of the history
463        full_history.extend(
464            partial_history
465                .into_iter()
466                .map(message::Message::try_into)
467                .collect::<Result<Vec<Vec<Message>>, _>>()?
468                .into_iter()
469                .flatten()
470                .collect::<Vec<_>>(),
471        );
472
473        let mut think = Think::Bool(false);
474        let mut keep_alive: Option<String> = None;
475
476        let options = if let Some(mut extra) = req.additional_params {
477            // Extract top-level parameters that should not be in `options`
478            if let Some(obj) = extra.as_object_mut() {
479                // Extract `think` parameter
480                if let Some(think_val) = obj.remove("think") {
481                    think = match think_val {
482                        Value::Bool(think) => Think::Bool(think),
483                        Value::String(think) => Think::Level(match think.to_lowercase().as_str() {
484                            "low" => Level::Low,
485                            "medium" => Level::Medium,
486                            "high" => Level::High,
487                            _ => {
488                                return Err(CompletionError::RequestError(
489                                    "`think` must be a 'low', 'medium', 'high', or bool".into(),
490                                ));
491                            }
492                        }),
493                        _ => {
494                            return Err(CompletionError::RequestError(
495                                "`think` must be a 'low', 'medium', 'high', or bool".into(),
496                            ));
497                        }
498                    };
499                }
500
501                // Extract `keep_alive` parameter
502                if let Some(keep_alive_val) = obj.remove("keep_alive") {
503                    keep_alive = Some(
504                        keep_alive_val
505                            .as_str()
506                            .ok_or_else(|| {
507                                CompletionError::RequestError(
508                                    "`keep_alive` must be a string".into(),
509                                )
510                            })?
511                            .to_string(),
512                    );
513                }
514            }
515
516            json_utils::merge(json!({ "temperature": req.temperature }), extra)
517        } else {
518            json!({ "temperature": req.temperature })
519        };
520
521        Ok(Self {
522            model: model.to_string(),
523            messages: full_history,
524            temperature: req.temperature,
525            max_tokens: req.max_tokens,
526            stream: false,
527            think,
528            keep_alive,
529            format: req.output_schema,
530            tools: req
531                .tools
532                .clone()
533                .into_iter()
534                .map(ToolDefinition::from)
535                .collect::<Vec<_>>(),
536            options,
537        })
538    }
539}
540
541#[derive(Clone)]
542pub struct CompletionModel<T = reqwest::Client> {
543    client: Client<T>,
544    pub model: String,
545}
546
547impl<T> CompletionModel<T> {
548    pub fn new(client: Client<T>, model: &str) -> Self {
549        Self {
550            client,
551            model: model.to_owned(),
552        }
553    }
554}
555
556#[derive(Debug, Clone, Serialize, Deserialize)]
557#[serde(untagged)]
558enum Think {
559    Bool(bool),
560    Level(Level),
561}
562
563#[derive(Debug, Clone, Serialize, Deserialize)]
564#[serde(rename_all = "lowercase")]
565enum Level {
566    Low,
567    Medium,
568    High,
569}
570
571// ---------- CompletionModel Implementation ----------
572
573#[derive(Clone, Serialize, Deserialize, Debug)]
574pub struct StreamingCompletionResponse {
575    pub done_reason: Option<String>,
576    pub total_duration: Option<u64>,
577    pub load_duration: Option<u64>,
578    pub prompt_eval_count: Option<u64>,
579    pub prompt_eval_duration: Option<u64>,
580    pub eval_count: Option<u64>,
581    pub eval_duration: Option<u64>,
582}
583
584impl GetTokenUsage for StreamingCompletionResponse {
585    fn token_usage(&self) -> crate::completion::Usage {
586        let mut usage = crate::completion::Usage::new();
587        let input_tokens = self.prompt_eval_count.unwrap_or_default();
588        let output_tokens = self.eval_count.unwrap_or_default();
589        usage.input_tokens = input_tokens;
590        usage.output_tokens = output_tokens;
591        usage.total_tokens = input_tokens + output_tokens;
592
593        usage
594    }
595}
596
/// Reassembles newline-delimited JSON lines from a chunked HTTP byte stream.
///
/// `bytes_stream` makes no promises about chunk boundaries, so a single NDJSON
/// line can be split across multiple chunks. `NdjsonBuffer` holds the trailing
/// fragment between calls and yields only fully terminated lines.
#[derive(Default)]
struct NdjsonBuffer {
    /// Bytes of the current, not yet terminated line. Never contains `\n`.
    buf: Vec<u8>,
}

impl NdjsonBuffer {
    /// Creates an empty buffer.
    fn new() -> Self {
        Self::default()
    }

    /// Consumes `chunk` and returns any newly completed lines, without their terminator.
    ///
    /// Both `\n` and `\r\n` terminators are accepted. Empty lines are skipped;
    /// trailing partial data is retained for the next call.
    fn decode(&mut self, chunk: &[u8]) -> Vec<Vec<u8>> {
        let mut lines = Vec::new();
        let mut rest = chunk;

        // The retained fragment never holds a newline, so only the incoming
        // chunk has to be scanned. Each byte is looked at once.
        while let Some(pos) = rest.iter().position(|&b| b == b'\n') {
            let (head, tail) = rest.split_at(pos);

            // The first completed line starts with whatever was retained from
            // earlier chunks; `take` leaves the buffer empty for later lines.
            let mut line = std::mem::take(&mut self.buf);
            line.extend_from_slice(head);

            // Drop the `\r` of a CRLF terminator so a blank CRLF line counts as empty.
            if line.last() == Some(&b'\r') {
                line.pop();
            }
            if !line.is_empty() {
                lines.push(line);
            }

            // `tail` starts with the newline that was just found; skip it.
            rest = &tail[1..];
        }

        self.buf.extend_from_slice(rest);
        lines
    }
}
628
629impl<T> completion::CompletionModel for CompletionModel<T>
630where
631    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
632{
633    type Response = CompletionResponse;
634    type StreamingResponse = StreamingCompletionResponse;
635
636    type Client = Client<T>;
637
638    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
639        Self::new(client.clone(), model.into().as_str())
640    }
641
642    async fn completion(
643        &self,
644        completion_request: CompletionRequest,
645    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
646        let span = if tracing::Span::current().is_disabled() {
647            info_span!(
648                target: "rig::completions",
649                "chat",
650                gen_ai.operation.name = "chat",
651                gen_ai.provider.name = "ollama",
652                gen_ai.request.model = self.model,
653                gen_ai.system_instructions = tracing::field::Empty,
654                gen_ai.response.id = tracing::field::Empty,
655                gen_ai.response.model = tracing::field::Empty,
656                gen_ai.usage.output_tokens = tracing::field::Empty,
657                gen_ai.usage.input_tokens = tracing::field::Empty,
658                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
659            )
660        } else {
661            tracing::Span::current()
662        };
663
664        span.record("gen_ai.system_instructions", &completion_request.preamble);
665        let request = OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
666
667        if tracing::enabled!(tracing::Level::TRACE) {
668            tracing::trace!(target: "rig::completions",
669                "Ollama completion request: {}",
670                serde_json::to_string_pretty(&request)?
671            );
672        }
673
674        let body = serde_json::to_vec(&request)?;
675
676        let req = self
677            .client
678            .post("api/chat")?
679            .body(body)
680            .map_err(http_client::Error::from)?;
681
682        let async_block = async move {
683            let response = self.client.send::<_, Bytes>(req).await?;
684            let status = response.status();
685            let response_body = response.into_body().into_future().await?.to_vec();
686
687            if !status.is_success() {
688                return Err(CompletionError::ProviderError(
689                    String::from_utf8_lossy(&response_body).to_string(),
690                ));
691            }
692
693            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
694            let span = tracing::Span::current();
695            span.record("gen_ai.response.model", &response.model);
696            span.record(
697                "gen_ai.usage.input_tokens",
698                response.prompt_eval_count.unwrap_or_default(),
699            );
700            span.record(
701                "gen_ai.usage.output_tokens",
702                response.eval_count.unwrap_or_default(),
703            );
704
705            if tracing::enabled!(tracing::Level::TRACE) {
706                tracing::trace!(target: "rig::completions",
707                    "Ollama completion response: {}",
708                    serde_json::to_string_pretty(&response)?
709                );
710            }
711
712            let response: completion::CompletionResponse<CompletionResponse> =
713                response.try_into()?;
714
715            Ok(response)
716        };
717
718        tracing::Instrument::instrument(async_block, span).await
719    }
720
721    async fn stream(
722        &self,
723        request: CompletionRequest,
724    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
725    {
726        let span = if tracing::Span::current().is_disabled() {
727            info_span!(
728                target: "rig::completions",
729                "chat_streaming",
730                gen_ai.operation.name = "chat_streaming",
731                gen_ai.provider.name = "ollama",
732                gen_ai.request.model = self.model,
733                gen_ai.system_instructions = tracing::field::Empty,
734                gen_ai.response.id = tracing::field::Empty,
735                gen_ai.response.model = self.model,
736                gen_ai.usage.output_tokens = tracing::field::Empty,
737                gen_ai.usage.input_tokens = tracing::field::Empty,
738                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
739            )
740        } else {
741            tracing::Span::current()
742        };
743
744        span.record("gen_ai.system_instructions", &request.preamble);
745
746        let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
747        request.stream = true;
748
749        if tracing::enabled!(tracing::Level::TRACE) {
750            tracing::trace!(target: "rig::completions",
751                "Ollama streaming completion request: {}",
752                serde_json::to_string_pretty(&request)?
753            );
754        }
755
756        let body = serde_json::to_vec(&request)?;
757
758        let req = self
759            .client
760            .post("api/chat")?
761            .body(body)
762            .map_err(http_client::Error::from)?;
763
764        let response = self.client.send_streaming(req).await?;
765        let status = response.status();
766        let mut byte_stream = response.into_body();
767
768        if !status.is_success() {
769            return Err(CompletionError::ProviderError(format!(
770                "Got error status code trying to send a request to Ollama: {status}"
771            )));
772        }
773
774        let stream = try_stream! {
775            let span = tracing::Span::current();
776            let mut tool_calls_final = Vec::new();
777            let mut text_response = String::new();
778            let mut thinking_response = String::new();
779            let mut line_buf = NdjsonBuffer::new();
780
781            while let Some(chunk) = byte_stream.next().await {
782                let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;
783
784                for line in line_buf.decode(&bytes) {
785                    tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(&line));
786
787                    let response: CompletionResponse = serde_json::from_slice(&line)?;
788
789                    if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
790                        if let Some(thinking_content) = thinking && !thinking_content.is_empty() {
791                            thinking_response += &thinking_content;
792                            yield RawStreamingChoice::ReasoningDelta {
793                                id: None,
794                                reasoning: thinking_content,
795                            };
796                        }
797
798                        if !content.is_empty() {
799                            text_response += &content;
800                            yield RawStreamingChoice::Message(content);
801                        }
802
803                        for tool_call in tool_calls {
804                            tool_calls_final.push(tool_call.clone());
805                            yield RawStreamingChoice::ToolCall(
806                                crate::streaming::RawStreamingToolCall::new(String::new(), tool_call.function.name, tool_call.function.arguments)
807                            );
808                        }
809                    }
810
811                    if response.done {
812                        span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
813                        span.record("gen_ai.usage.output_tokens", response.eval_count);
814                        let message = Message::Assistant {
815                            content: text_response.clone(),
816                            thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
817                            images: None,
818                            name: None,
819                            tool_calls: tool_calls_final.clone()
820                        };
821                        if let Ok(serialized_message) = serde_json::to_string(&vec![message]) {
822                            span.record("gen_ai.output.messages", serialized_message);
823                        }
824                        yield RawStreamingChoice::FinalResponse(
825                            StreamingCompletionResponse {
826                                total_duration: response.total_duration,
827                                load_duration: response.load_duration,
828                                prompt_eval_count: response.prompt_eval_count,
829                                prompt_eval_duration: response.prompt_eval_duration,
830                                eval_count: response.eval_count,
831                                eval_duration: response.eval_duration,
832                                done_reason: response.done_reason,
833                            }
834                        );
835                        break;
836                    }
837                }
838            }
839        }.instrument(span);
840
841        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
842            stream,
843        )))
844    }
845}
846
847// ---------- Model Listing  ----------
848
849#[derive(Debug, Deserialize)]
850struct ListModelsResponse {
851    models: Vec<ListModelEntry>,
852}
853
854#[derive(Debug, Deserialize)]
855struct ListModelEntry {
856    name: String,
857    model: String,
858}
859
860impl From<ListModelEntry> for Model {
861    fn from(value: ListModelEntry) -> Self {
862        Model::new(value.model, value.name)
863    }
864}
865
866/// [`ModelLister`] implementation for the Ollama API (`GET /api/tags`).
867#[derive(Clone)]
868pub struct OllamaModelLister<H = reqwest::Client> {
869    client: Client<H>,
870}
871
872impl<H> ModelLister<H> for OllamaModelLister<H>
873where
874    H: HttpClientExt + WasmCompatSend + WasmCompatSync + 'static,
875{
876    type Client = Client<H>;
877
878    fn new(client: Self::Client) -> Self {
879        Self { client }
880    }
881
882    async fn list_all(&self) -> Result<ModelList, ModelListingError> {
883        let path = "/api/tags";
884        let req = self.client.get(path)?.body(http_client::NoBody)?;
885        let response = self.client.send::<_, Vec<u8>>(req).await?;
886
887        if !response.status().is_success() {
888            let status_code = response.status().as_u16();
889            let body = response.into_body().await?;
890            return Err(ModelListingError::api_error_with_context(
891                "Ollama",
892                path,
893                status_code,
894                &body,
895            ));
896        }
897
898        let body = response.into_body().await?;
899        let api_resp: ListModelsResponse = serde_json::from_slice(&body).map_err(|error| {
900            ModelListingError::parse_error_with_context("Ollama", path, &error, &body)
901        })?;
902        let models = api_resp.models.into_iter().map(Model::from).collect();
903
904        Ok(ModelList::new(models))
905    }
906}
907
908// ---------- Tool Definition Conversion ----------
909
910/// Ollama-required tool definition format.
911#[derive(Clone, Debug, Deserialize, Serialize)]
912pub struct ToolDefinition {
913    #[serde(rename = "type")]
914    pub type_field: String, // Fixed as "function"
915    pub function: completion::ToolDefinition,
916}
917
918/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
919impl From<crate::completion::ToolDefinition> for ToolDefinition {
920    fn from(tool: crate::completion::ToolDefinition) -> Self {
921        ToolDefinition {
922            type_field: "function".to_owned(),
923            function: completion::ToolDefinition {
924                name: tool.name,
925                description: tool.description,
926                parameters: tool.parameters,
927            },
928        }
929    }
930}
931
932#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
933pub struct ToolCall {
934    #[serde(default, rename = "type")]
935    pub r#type: ToolType,
936    pub function: Function,
937}
938#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
939#[serde(rename_all = "lowercase")]
940pub enum ToolType {
941    #[default]
942    Function,
943}
944#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
945pub struct Function {
946    pub name: String,
947    pub arguments: Value,
948}
949
950// ---------- Provider Message Definition ----------
951
952#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
953#[serde(tag = "role", rename_all = "lowercase")]
954pub enum Message {
955    User {
956        content: String,
957        #[serde(skip_serializing_if = "Option::is_none")]
958        images: Option<Vec<String>>,
959        #[serde(skip_serializing_if = "Option::is_none")]
960        name: Option<String>,
961    },
962    Assistant {
963        #[serde(default)]
964        content: String,
965        #[serde(skip_serializing_if = "Option::is_none")]
966        thinking: Option<String>,
967        #[serde(skip_serializing_if = "Option::is_none")]
968        images: Option<Vec<String>>,
969        #[serde(skip_serializing_if = "Option::is_none")]
970        name: Option<String>,
971        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
972        tool_calls: Vec<ToolCall>,
973    },
974    System {
975        content: String,
976        #[serde(skip_serializing_if = "Option::is_none")]
977        images: Option<Vec<String>>,
978        #[serde(skip_serializing_if = "Option::is_none")]
979        name: Option<String>,
980    },
981    #[serde(rename = "tool")]
982    ToolResult {
983        #[serde(rename = "tool_name")]
984        name: String,
985        content: String,
986    },
987}
988
989/// -----------------------------
990/// Provider Message Conversions
991/// -----------------------------
992/// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
993/// (Only User and Assistant variants are supported.)
994impl TryFrom<crate::message::Message> for Vec<Message> {
995    type Error = crate::message::MessageError;
996    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
997        use crate::message::Message as InternalMessage;
998        match internal_msg {
999            InternalMessage::System { content } => Ok(vec![Message::System {
1000                content,
1001                images: None,
1002                name: None,
1003            }]),
1004            InternalMessage::User { content, .. } => {
1005                let (tool_results, other_content): (Vec<_>, Vec<_>) =
1006                    content.into_iter().partition(|content| {
1007                        matches!(content, crate::message::UserContent::ToolResult(_))
1008                    });
1009
1010                if !tool_results.is_empty() {
1011                    tool_results
1012                        .into_iter()
1013                        .map(|content| match content {
1014                            crate::message::UserContent::ToolResult(
1015                                crate::message::ToolResult { id, content, .. },
1016                            ) => {
1017                                // Ollama expects a single string for tool results, so we concatenate
1018                                let content_string = content
1019                                    .into_iter()
1020                                    .map(|content| match content {
1021                                        crate::message::ToolResultContent::Text(text) => text.text,
1022                                        _ => "[Non-text content]".to_string(),
1023                                    })
1024                                    .collect::<Vec<_>>()
1025                                    .join("\n");
1026
1027                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
1028                                    name: id,
1029                                    content: content_string,
1030                                })
1031                            }
1032                            _ => Err(crate::message::MessageError::ConversionError(
1033                                "expected tool result content while converting Ollama input".into(),
1034                            )),
1035                        })
1036                        .collect::<Result<Vec<_>, _>>()
1037                } else {
1038                    // Ollama requires separate text content and images array
1039                    let (texts, images) = other_content.into_iter().fold(
1040                        (Vec::new(), Vec::new()),
1041                        |(mut texts, mut images), content| {
1042                            match content {
1043                                crate::message::UserContent::Text(crate::message::Text {
1044                                    text,
1045                                    ..
1046                                }) => texts.push(text),
1047                                crate::message::UserContent::Image(crate::message::Image {
1048                                    data: DocumentSourceKind::Base64(data),
1049                                    ..
1050                                }) => images.push(data),
1051                                crate::message::UserContent::Document(
1052                                    crate::message::Document {
1053                                        data:
1054                                            DocumentSourceKind::Base64(data)
1055                                            | DocumentSourceKind::String(data),
1056                                        ..
1057                                    },
1058                                ) => texts.push(data),
1059                                _ => {} // Audio not supported by Ollama
1060                            }
1061                            (texts, images)
1062                        },
1063                    );
1064
1065                    Ok(vec![Message::User {
1066                        content: texts.join(" "),
1067                        images: if images.is_empty() {
1068                            None
1069                        } else {
1070                            Some(
1071                                images
1072                                    .into_iter()
1073                                    .map(|x| x.to_string())
1074                                    .collect::<Vec<String>>(),
1075                            )
1076                        },
1077                        name: None,
1078                    }])
1079                }
1080            }
1081            InternalMessage::Assistant { content, .. } => {
1082                let mut thinking: Option<String> = None;
1083                let mut text_content = Vec::new();
1084                let mut tool_calls = Vec::new();
1085
1086                for content in content.into_iter() {
1087                    match content {
1088                        crate::message::AssistantContent::Text(text) => {
1089                            text_content.push(text.text)
1090                        }
1091                        crate::message::AssistantContent::ToolCall(tool_call) => {
1092                            tool_calls.push(tool_call)
1093                        }
1094                        crate::message::AssistantContent::Reasoning(reasoning) => {
1095                            let display = reasoning.display_text();
1096                            if !display.is_empty() {
1097                                thinking = Some(display);
1098                            }
1099                        }
1100                        crate::message::AssistantContent::Image(_) => {
1101                            return Err(crate::message::MessageError::ConversionError(
1102                                "Ollama currently doesn't support images.".into(),
1103                            ));
1104                        }
1105                    }
1106                }
1107
1108                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
1109                //  so either `content` or `tool_calls` will have some content.
1110                Ok(vec![Message::Assistant {
1111                    content: text_content.join(" "),
1112                    thinking,
1113                    images: None,
1114                    name: None,
1115                    tool_calls: tool_calls
1116                        .into_iter()
1117                        .map(|tool_call| tool_call.into())
1118                        .collect::<Vec<_>>(),
1119                }])
1120            }
1121        }
1122    }
1123}
1124
1125/// Conversion from provider Message to a completion message.
1126/// This is needed so that responses can be converted back into chat history.
1127impl From<Message> for crate::completion::Message {
1128    fn from(msg: Message) -> Self {
1129        match msg {
1130            Message::User { content, .. } => crate::completion::Message::User {
1131                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text::new(
1132                    content,
1133                ))),
1134            },
1135            Message::Assistant {
1136                content,
1137                tool_calls,
1138                ..
1139            } => {
1140                let mut assistant_contents =
1141                    vec![crate::completion::message::AssistantContent::Text(
1142                        Text::new(content),
1143                    )];
1144                for tc in tool_calls {
1145                    assistant_contents.push(
1146                        crate::completion::message::AssistantContent::tool_call(
1147                            tc.function.name.clone(),
1148                            tc.function.name,
1149                            tc.function.arguments,
1150                        ),
1151                    );
1152                }
1153                let content =
1154                    OneOrMany::from_iter_optional(assistant_contents).unwrap_or_else(|| {
1155                        OneOrMany::one(crate::completion::message::AssistantContent::Text(
1156                            Text::new(String::new()),
1157                        ))
1158                    });
1159
1160                crate::completion::Message::Assistant { id: None, content }
1161            }
1162            // System and ToolResult are converted to User message as needed.
1163            Message::System { content, .. } => crate::completion::Message::User {
1164                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text::new(
1165                    content,
1166                ))),
1167            },
1168            Message::ToolResult { name, content } => crate::completion::Message::User {
1169                content: OneOrMany::one(message::UserContent::tool_result(
1170                    name,
1171                    OneOrMany::one(message::ToolResultContent::text(content)),
1172                )),
1173            },
1174        }
1175    }
1176}
1177
1178impl Message {
1179    /// Constructs a system message.
1180    pub fn system(content: &str) -> Self {
1181        Message::System {
1182            content: content.to_owned(),
1183            images: None,
1184            name: None,
1185        }
1186    }
1187}
1188
1189// ---------- Additional Message Types ----------
1190
1191impl From<crate::message::ToolCall> for ToolCall {
1192    fn from(tool_call: crate::message::ToolCall) -> Self {
1193        Self {
1194            r#type: ToolType::Function,
1195            function: Function {
1196                name: tool_call.function.name,
1197                arguments: tool_call.function.arguments,
1198            },
1199        }
1200    }
1201}
1202
1203#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1204pub struct SystemContent {
1205    #[serde(default)]
1206    r#type: SystemContentType,
1207    text: String,
1208}
1209
1210#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
1211#[serde(rename_all = "lowercase")]
1212pub enum SystemContentType {
1213    #[default]
1214    Text,
1215}
1216
1217impl From<String> for SystemContent {
1218    fn from(s: String) -> Self {
1219        SystemContent {
1220            r#type: SystemContentType::default(),
1221            text: s,
1222        }
1223    }
1224}
1225
1226impl FromStr for SystemContent {
1227    type Err = std::convert::Infallible;
1228    fn from_str(s: &str) -> Result<Self, Self::Err> {
1229        Ok(SystemContent {
1230            r#type: SystemContentType::default(),
1231            text: s.to_string(),
1232        })
1233    }
1234}
1235
1236#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1237pub struct AssistantContent {
1238    pub text: String,
1239}
1240
1241impl FromStr for AssistantContent {
1242    type Err = std::convert::Infallible;
1243    fn from_str(s: &str) -> Result<Self, Self::Err> {
1244        Ok(AssistantContent { text: s.to_owned() })
1245    }
1246}
1247
1248#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1249#[serde(tag = "type", rename_all = "lowercase")]
1250pub enum UserContent {
1251    Text { text: String },
1252    Image { image_url: ImageUrl },
1253    // Audio variant removed as Ollama API does not support audio input.
1254}
1255
1256impl FromStr for UserContent {
1257    type Err = std::convert::Infallible;
1258    fn from_str(s: &str) -> Result<Self, Self::Err> {
1259        Ok(UserContent::Text { text: s.to_owned() })
1260    }
1261}
1262
1263#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1264pub struct ImageUrl {
1265    pub url: String,
1266    #[serde(default)]
1267    pub detail: ImageDetail,
1268}
1269
1270// =================================================================
1271// Tests
1272// =================================================================
1273
1274#[cfg(test)]
1275mod tests {
1276    use super::*;
1277    use serde_json::json;
1278
1279    // Test deserialization and conversion for the /api/chat endpoint.
1280    #[tokio::test]
1281    async fn test_chat_completion() {
1282        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
1283        let sample_chat_response = json!({
1284            "model": "llama3.2",
1285            "created_at": "2023-08-04T19:22:45.499127Z",
1286            "message": {
1287                "role": "assistant",
1288                "content": "The sky is blue because of Rayleigh scattering.",
1289                "images": null,
1290                "tool_calls": [
1291                    {
1292                        "type": "function",
1293                        "function": {
1294                            "name": "get_current_weather",
1295                            "arguments": {
1296                                "location": "San Francisco, CA",
1297                                "format": "celsius"
1298                            }
1299                        }
1300                    }
1301                ]
1302            },
1303            "done": true,
1304            "total_duration": 8000000000u64,
1305            "load_duration": 6000000u64,
1306            "prompt_eval_count": 61u64,
1307            "prompt_eval_duration": 400000000u64,
1308            "eval_count": 468u64,
1309            "eval_duration": 7700000000u64
1310        });
1311        let sample_text = sample_chat_response.to_string();
1312
1313        let chat_resp: CompletionResponse =
1314            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
1315        let conv: completion::CompletionResponse<CompletionResponse> =
1316            chat_resp.try_into().unwrap();
1317        assert!(
1318            !conv.choice.is_empty(),
1319            "Expected non-empty choice in chat response"
1320        );
1321    }
1322
1323    // Test conversion from provider Message to completion Message.
1324    #[test]
1325    fn test_message_conversion() {
1326        // Construct a provider Message (User variant with String content).
1327        let provider_msg = Message::User {
1328            content: "Test message".to_owned(),
1329            images: None,
1330            name: None,
1331        };
1332        // Convert it into a completion::Message.
1333        let comp_msg: crate::completion::Message = provider_msg.into();
1334        match comp_msg {
1335            crate::completion::Message::User { content } => {
1336                // Assume OneOrMany<T> has a method first() to access the first element.
1337                let first_content = content.first();
1338                // The expected type is crate::completion::message::UserContent::Text wrapping a Text struct.
1339                match first_content {
1340                    crate::completion::message::UserContent::Text(text_struct) => {
1341                        assert_eq!(text_struct.text, "Test message");
1342                    }
1343                    _ => panic!("Expected text content in conversion"),
1344                }
1345            }
1346            _ => panic!("Conversion from provider Message to completion Message failed"),
1347        }
1348    }
1349
1350    // Test conversion of internal tool definition to Ollama's ToolDefinition format.
1351    #[test]
1352    fn test_tool_definition_conversion() {
1353        // Internal tool definition from the completion module.
1354        let internal_tool = crate::completion::ToolDefinition {
1355            name: "get_current_weather".to_owned(),
1356            description: "Get the current weather for a location".to_owned(),
1357            parameters: json!({
1358                "type": "object",
1359                "properties": {
1360                    "location": {
1361                        "type": "string",
1362                        "description": "The location to get the weather for, e.g. San Francisco, CA"
1363                    },
1364                    "format": {
1365                        "type": "string",
1366                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
1367                        "enum": ["celsius", "fahrenheit"]
1368                    }
1369                },
1370                "required": ["location", "format"]
1371            }),
1372        };
1373        // Convert internal tool to Ollama's tool definition.
1374        let ollama_tool: ToolDefinition = internal_tool.into();
1375        assert_eq!(ollama_tool.type_field, "function");
1376        assert_eq!(ollama_tool.function.name, "get_current_weather");
1377        assert_eq!(
1378            ollama_tool.function.description,
1379            "Get the current weather for a location"
1380        );
1381        // Check JSON fields in parameters.
1382        let params = &ollama_tool.function.parameters;
1383        assert_eq!(params["properties"]["location"]["type"], "string");
1384    }
1385
1386    // Test deserialization of chat response with thinking content
1387    #[tokio::test]
1388    async fn test_chat_completion_with_thinking() {
1389        let sample_response = json!({
1390            "model": "qwen-thinking",
1391            "created_at": "2023-08-04T19:22:45.499127Z",
1392            "message": {
1393                "role": "assistant",
1394                "content": "The answer is 42.",
1395                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
1396                "images": null,
1397                "tool_calls": []
1398            },
1399            "done": true,
1400            "total_duration": 8000000000u64,
1401            "load_duration": 6000000u64,
1402            "prompt_eval_count": 61u64,
1403            "prompt_eval_duration": 400000000u64,
1404            "eval_count": 468u64,
1405            "eval_duration": 7700000000u64
1406        });
1407
1408        let chat_resp: CompletionResponse =
1409            serde_json::from_value(sample_response).expect("Failed to deserialize");
1410
1411        // Verify thinking field is present
1412        if let Message::Assistant {
1413            thinking, content, ..
1414        } = &chat_resp.message
1415        {
1416            assert_eq!(
1417                thinking.as_ref().unwrap(),
1418                "Let me think about this carefully. The question asks for the meaning of life..."
1419            );
1420            assert_eq!(content, "The answer is 42.");
1421        } else {
1422            panic!("Expected Assistant message");
1423        }
1424    }
1425
1426    // Test deserialization of chat response without thinking content
1427    #[tokio::test]
1428    async fn test_chat_completion_without_thinking() {
1429        let sample_response = json!({
1430            "model": "llama3.2",
1431            "created_at": "2023-08-04T19:22:45.499127Z",
1432            "message": {
1433                "role": "assistant",
1434                "content": "Hello!",
1435                "images": null,
1436                "tool_calls": []
1437            },
1438            "done": true,
1439            "total_duration": 8000000000u64,
1440            "load_duration": 6000000u64,
1441            "prompt_eval_count": 10u64,
1442            "prompt_eval_duration": 400000000u64,
1443            "eval_count": 5u64,
1444            "eval_duration": 7700000000u64
1445        });
1446
1447        let chat_resp: CompletionResponse =
1448            serde_json::from_value(sample_response).expect("Failed to deserialize");
1449
1450        // Verify thinking field is None when not provided
1451        if let Message::Assistant {
1452            thinking, content, ..
1453        } = &chat_resp.message
1454        {
1455            assert!(thinking.is_none());
1456            assert_eq!(content, "Hello!");
1457        } else {
1458            panic!("Expected Assistant message");
1459        }
1460    }
1461
1462    // Test deserialization of streaming response with thinking content
1463    #[test]
1464    fn test_streaming_response_with_thinking() {
1465        let sample_chunk = json!({
1466            "model": "qwen-thinking",
1467            "created_at": "2023-08-04T19:22:45.499127Z",
1468            "message": {
1469                "role": "assistant",
1470                "content": "",
1471                "thinking": "Analyzing the problem...",
1472                "images": null,
1473                "tool_calls": []
1474            },
1475            "done": false
1476        });
1477
1478        let chunk: CompletionResponse =
1479            serde_json::from_value(sample_chunk).expect("Failed to deserialize");
1480
1481        if let Message::Assistant {
1482            thinking, content, ..
1483        } = &chunk.message
1484        {
1485            assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
1486            assert_eq!(content, "");
1487        } else {
1488            panic!("Expected Assistant message");
1489        }
1490    }
1491
1492    // Test message conversion with thinking content
1493    #[test]
1494    fn test_message_conversion_with_thinking() {
1495        // Create an internal message with reasoning content
1496        let reasoning_content = crate::message::Reasoning::new("Step 1: Consider the problem");
1497
1498        let internal_msg = crate::message::Message::Assistant {
1499            id: None,
1500            content: crate::OneOrMany::many(vec![
1501                crate::message::AssistantContent::Reasoning(reasoning_content),
1502                crate::message::AssistantContent::Text(crate::message::Text::new(
1503                    "The answer is X".to_string(),
1504                )),
1505            ])
1506            .unwrap(),
1507        };
1508
1509        // Convert to provider Message
1510        let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();
1511        assert_eq!(provider_msgs.len(), 1);
1512
1513        if let Message::Assistant {
1514            thinking, content, ..
1515        } = &provider_msgs[0]
1516        {
1517            assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
1518            assert_eq!(content, "The answer is X");
1519        } else {
1520            panic!("Expected Assistant message with thinking");
1521        }
1522    }
1523
1524    // Test empty thinking content is handled correctly
1525    #[test]
1526    fn test_empty_thinking_content() {
1527        let sample_response = json!({
1528            "model": "llama3.2",
1529            "created_at": "2023-08-04T19:22:45.499127Z",
1530            "message": {
1531                "role": "assistant",
1532                "content": "Response",
1533                "thinking": "",
1534                "images": null,
1535                "tool_calls": []
1536            },
1537            "done": true,
1538            "total_duration": 8000000000u64,
1539            "load_duration": 6000000u64,
1540            "prompt_eval_count": 10u64,
1541            "prompt_eval_duration": 400000000u64,
1542            "eval_count": 5u64,
1543            "eval_duration": 7700000000u64
1544        });
1545
1546        let chat_resp: CompletionResponse =
1547            serde_json::from_value(sample_response).expect("Failed to deserialize");
1548
1549        if let Message::Assistant {
1550            thinking, content, ..
1551        } = &chat_resp.message
1552        {
1553            // Empty string should still deserialize as Some("")
1554            assert_eq!(thinking.as_ref().unwrap(), "");
1555            assert_eq!(content, "Response");
1556        } else {
1557            panic!("Expected Assistant message");
1558        }
1559    }
1560
1561    // Test thinking with tool calls
1562    #[test]
1563    fn test_thinking_with_tool_calls() {
1564        let sample_response = json!({
1565            "model": "qwen-thinking",
1566            "created_at": "2023-08-04T19:22:45.499127Z",
1567            "message": {
1568                "role": "assistant",
1569                "content": "Let me check the weather.",
1570                "thinking": "User wants weather info, I should use the weather tool",
1571                "images": null,
1572                "tool_calls": [
1573                    {
1574                        "type": "function",
1575                        "function": {
1576                            "name": "get_weather",
1577                            "arguments": {
1578                                "location": "San Francisco"
1579                            }
1580                        }
1581                    }
1582                ]
1583            },
1584            "done": true,
1585            "total_duration": 8000000000u64,
1586            "load_duration": 6000000u64,
1587            "prompt_eval_count": 30u64,
1588            "prompt_eval_duration": 400000000u64,
1589            "eval_count": 50u64,
1590            "eval_duration": 7700000000u64
1591        });
1592
1593        let chat_resp: CompletionResponse =
1594            serde_json::from_value(sample_response).expect("Failed to deserialize");
1595
1596        if let Message::Assistant {
1597            thinking,
1598            content,
1599            tool_calls,
1600            ..
1601        } = &chat_resp.message
1602        {
1603            assert_eq!(
1604                thinking.as_ref().unwrap(),
1605                "User wants weather info, I should use the weather tool"
1606            );
1607            assert_eq!(content, "Let me check the weather.");
1608            assert_eq!(tool_calls.len(), 1);
1609            assert_eq!(tool_calls[0].function.name, "get_weather");
1610        } else {
1611            panic!("Expected Assistant message with thinking and tool calls");
1612        }
1613    }
1614
1615    // Test that `think` and `keep_alive` are extracted as top-level params, not in `options`
1616    #[test]
1617    fn test_completion_request_with_think_param() {
1618        use crate::OneOrMany;
1619        use crate::completion::Message as CompletionMessage;
1620        use crate::message::{Text, UserContent};
1621
1622        // Create a CompletionRequest with "think": true, "keep_alive", and "num_ctx" in additional_params
1623        let completion_request = CompletionRequest {
1624            model: None,
1625            preamble: Some("You are a helpful assistant.".to_string()),
1626            chat_history: OneOrMany::one(CompletionMessage::User {
1627                content: OneOrMany::one(UserContent::Text(Text::new("What is 2 + 2?".to_string()))),
1628            }),
1629            documents: vec![],
1630            tools: vec![],
1631            temperature: Some(0.7),
1632            max_tokens: Some(1024),
1633            tool_choice: None,
1634            additional_params: Some(json!({
1635                "think": true,
1636                "keep_alive": "-1m",
1637                "num_ctx": 4096
1638            })),
1639            output_schema: None,
1640        };
1641
1642        // Convert to OllamaCompletionRequest
1643        let ollama_request = OllamaCompletionRequest::try_from(("qwen3:8b", completion_request))
1644            .expect("Failed to create Ollama request");
1645
1646        // Serialize to JSON
1647        let serialized =
1648            serde_json::to_value(&ollama_request).expect("Failed to serialize request");
1649
1650        // Assert equality with expected JSON
1651        // - "tools" is skipped when empty (skip_serializing_if)
1652        // - "think" should be a top-level boolean, NOT in options
1653        // - "keep_alive" should be a top-level string, NOT in options
1654        // - "num_ctx" should be in options (it's a model parameter)
1655        let expected = json!({
1656            "model": "qwen3:8b",
1657            "messages": [
1658                {
1659                    "role": "system",
1660                    "content": "You are a helpful assistant."
1661                },
1662                {
1663                    "role": "user",
1664                    "content": "What is 2 + 2?"
1665                }
1666            ],
1667            "temperature": 0.7,
1668            "stream": false,
1669            "think": true,
1670            "max_tokens": 1024,
1671            "keep_alive": "-1m",
1672            "options": {
1673                "temperature": 0.7,
1674                "num_ctx": 4096
1675            }
1676        });
1677
1678        assert_eq!(serialized, expected);
1679    }
1680
1681    // Test that `think` and `keep_alive` are extracted as top-level params, not in `options`
1682    #[test]
1683    fn test_completion_request_with_level_low_think_param() {
1684        use crate::OneOrMany;
1685        use crate::completion::Message as CompletionMessage;
1686        use crate::message::{Text, UserContent};
1687
1688        // Create a CompletionRequest with "think": true, "keep_alive", and "num_ctx" in additional_params
1689        let completion_request = CompletionRequest {
1690            model: None,
1691            preamble: Some("You are a helpful assistant.".to_string()),
1692            chat_history: OneOrMany::one(CompletionMessage::User {
1693                content: OneOrMany::one(UserContent::Text(Text::new("What is 2 + 2?".to_string()))),
1694            }),
1695            documents: vec![],
1696            tools: vec![],
1697            temperature: Some(0.7),
1698            max_tokens: Some(1024),
1699            tool_choice: None,
1700            additional_params: Some(json!({
1701                "think": "low",
1702                "keep_alive": "-1m",
1703                "num_ctx": 4096
1704            })),
1705            output_schema: None,
1706        };
1707
1708        // Convert to OllamaCompletionRequest
1709        let ollama_request = OllamaCompletionRequest::try_from(("qwen3:8b", completion_request))
1710            .expect("Failed to create Ollama request");
1711
1712        // Serialize to JSON
1713        let serialized =
1714            serde_json::to_value(&ollama_request).expect("Failed to serialize request");
1715
1716        // Assert equality with expected JSON
1717        // - "tools" is skipped when empty (skip_serializing_if)
1718        // - "think" should be a top-level boolean, NOT in options
1719        // - "keep_alive" should be a top-level string, NOT in options
1720        // - "num_ctx" should be in options (it's a model parameter)
1721        let expected = json!({
1722            "model": "qwen3:8b",
1723            "messages": [
1724                {
1725                    "role": "system",
1726                    "content": "You are a helpful assistant."
1727                },
1728                {
1729                    "role": "user",
1730                    "content": "What is 2 + 2?"
1731                }
1732            ],
1733            "temperature": 0.7,
1734            "stream": false,
1735            "think": "low",
1736            "max_tokens": 1024,
1737            "keep_alive": "-1m",
1738            "options": {
1739                "temperature": 0.7,
1740                "num_ctx": 4096
1741            }
1742        });
1743
1744        assert_eq!(serialized, expected);
1745    }
1746
1747    // Test that `think` and `keep_alive` are extracted as top-level params, not in `options`
1748    #[test]
1749    fn test_completion_request_with_level_medium_think_param() {
1750        use crate::OneOrMany;
1751        use crate::completion::Message as CompletionMessage;
1752        use crate::message::{Text, UserContent};
1753
1754        // Create a CompletionRequest with "think": true, "keep_alive", and "num_ctx" in additional_params
1755        let completion_request = CompletionRequest {
1756            model: None,
1757            preamble: Some("You are a helpful assistant.".to_string()),
1758            chat_history: OneOrMany::one(CompletionMessage::User {
1759                content: OneOrMany::one(UserContent::Text(Text::new("What is 2 + 2?".to_string()))),
1760            }),
1761            documents: vec![],
1762            tools: vec![],
1763            temperature: Some(0.7),
1764            max_tokens: Some(1024),
1765            tool_choice: None,
1766            additional_params: Some(json!({
1767                "think": "medium",
1768                "keep_alive": "-1m",
1769                "num_ctx": 4096
1770            })),
1771            output_schema: None,
1772        };
1773
1774        // Convert to OllamaCompletionRequest
1775        let ollama_request = OllamaCompletionRequest::try_from(("qwen3:8b", completion_request))
1776            .expect("Failed to create Ollama request");
1777
1778        // Serialize to JSON
1779        let serialized =
1780            serde_json::to_value(&ollama_request).expect("Failed to serialize request");
1781
1782        // Assert equality with expected JSON
1783        // - "tools" is skipped when empty (skip_serializing_if)
1784        // - "think" should be a top-level boolean, NOT in options
1785        // - "keep_alive" should be a top-level string, NOT in options
1786        // - "num_ctx" should be in options (it's a model parameter)
1787        let expected = json!({
1788            "model": "qwen3:8b",
1789            "messages": [
1790                {
1791                    "role": "system",
1792                    "content": "You are a helpful assistant."
1793                },
1794                {
1795                    "role": "user",
1796                    "content": "What is 2 + 2?"
1797                }
1798            ],
1799            "temperature": 0.7,
1800            "stream": false,
1801            "think": "medium",
1802            "max_tokens": 1024,
1803            "keep_alive": "-1m",
1804            "options": {
1805                "temperature": 0.7,
1806                "num_ctx": 4096
1807            }
1808        });
1809
1810        assert_eq!(serialized, expected);
1811    }
1812
1813    // Test that `think` and `keep_alive` are extracted as top-level params, not in `options`
1814    #[test]
1815    fn test_completion_request_with_level_high_think_param() {
1816        use crate::OneOrMany;
1817        use crate::completion::Message as CompletionMessage;
1818        use crate::message::{Text, UserContent};
1819
1820        // Create a CompletionRequest with "think": true, "keep_alive", and "num_ctx" in additional_params
1821        let completion_request = CompletionRequest {
1822            model: None,
1823            preamble: Some("You are a helpful assistant.".to_string()),
1824            chat_history: OneOrMany::one(CompletionMessage::User {
1825                content: OneOrMany::one(UserContent::Text(Text::new("What is 2 + 2?".to_string()))),
1826            }),
1827            documents: vec![],
1828            tools: vec![],
1829            temperature: Some(0.7),
1830            max_tokens: Some(1024),
1831            tool_choice: None,
1832            additional_params: Some(json!({
1833                "think": "high",
1834                "keep_alive": "-1m",
1835                "num_ctx": 4096
1836            })),
1837            output_schema: None,
1838        };
1839
1840        // Convert to OllamaCompletionRequest
1841        let ollama_request = OllamaCompletionRequest::try_from(("qwen3:8b", completion_request))
1842            .expect("Failed to create Ollama request");
1843
1844        // Serialize to JSON
1845        let serialized =
1846            serde_json::to_value(&ollama_request).expect("Failed to serialize request");
1847
1848        // Assert equality with expected JSON
1849        // - "tools" is skipped when empty (skip_serializing_if)
1850        // - "think" should be a top-level boolean, NOT in options
1851        // - "keep_alive" should be a top-level string, NOT in options
1852        // - "num_ctx" should be in options (it's a model parameter)
1853        let expected = json!({
1854            "model": "qwen3:8b",
1855            "messages": [
1856                {
1857                    "role": "system",
1858                    "content": "You are a helpful assistant."
1859                },
1860                {
1861                    "role": "user",
1862                    "content": "What is 2 + 2?"
1863                }
1864            ],
1865            "temperature": 0.7,
1866            "stream": false,
1867            "think": "high",
1868            "max_tokens": 1024,
1869            "keep_alive": "-1m",
1870            "options": {
1871                "temperature": 0.7,
1872                "num_ctx": 4096
1873            }
1874        });
1875
1876        assert_eq!(serialized, expected);
1877    }
1878
1879    // Test that `think` and `keep_alive` are extracted as top-level params, not in `options`
1880    #[test]
1881    fn test_completion_request_with_level_invalid_think_param() {
1882        use crate::OneOrMany;
1883        use crate::completion::Message as CompletionMessage;
1884        use crate::message::{Text, UserContent};
1885
1886        // Create a CompletionRequest with "think": true, "keep_alive", and "num_ctx" in additional_params
1887        let completion_request = CompletionRequest {
1888            model: None,
1889            preamble: Some("You are a helpful assistant.".to_string()),
1890            chat_history: OneOrMany::one(CompletionMessage::User {
1891                content: OneOrMany::one(UserContent::Text(Text::new("What is 2 + 2?".to_string()))),
1892            }),
1893            documents: vec![],
1894            tools: vec![],
1895            temperature: Some(0.7),
1896            max_tokens: Some(1024),
1897            tool_choice: None,
1898            additional_params: Some(json!({
1899                "think": "invalid",
1900                "keep_alive": "-1m",
1901                "num_ctx": 4096
1902            })),
1903            output_schema: None,
1904        };
1905
1906        // Convert to OllamaCompletionRequest
1907        let ollama_request = OllamaCompletionRequest::try_from(("qwen3:8b", completion_request));
1908
1909        assert!(ollama_request.is_err())
1910    }
1911
1912    // Test that `think` defaults to false when not specified
1913    #[test]
1914    fn test_completion_request_with_think_false_default() {
1915        use crate::OneOrMany;
1916        use crate::completion::Message as CompletionMessage;
1917        use crate::message::{Text, UserContent};
1918
1919        // Create a CompletionRequest WITHOUT "think" in additional_params
1920        let completion_request = CompletionRequest {
1921            model: None,
1922            preamble: Some("You are a helpful assistant.".to_string()),
1923            chat_history: OneOrMany::one(CompletionMessage::User {
1924                content: OneOrMany::one(UserContent::Text(Text::new("Hello!".to_string()))),
1925            }),
1926            documents: vec![],
1927            tools: vec![],
1928            temperature: Some(0.5),
1929            max_tokens: None,
1930            tool_choice: None,
1931            additional_params: None,
1932            output_schema: None,
1933        };
1934
1935        // Convert to OllamaCompletionRequest
1936        let ollama_request = OllamaCompletionRequest::try_from(("llama3.2", completion_request))
1937            .expect("Failed to create Ollama request");
1938
1939        // Serialize to JSON
1940        let serialized =
1941            serde_json::to_value(&ollama_request).expect("Failed to serialize request");
1942
1943        // Assert that "think" defaults to false and "keep_alive" is not present
1944        let expected = json!({
1945            "model": "llama3.2",
1946            "messages": [
1947                {
1948                    "role": "system",
1949                    "content": "You are a helpful assistant."
1950                },
1951                {
1952                    "role": "user",
1953                    "content": "Hello!"
1954                }
1955            ],
1956            "temperature": 0.5,
1957            "stream": false,
1958            "think": false,
1959            "options": {
1960                "temperature": 0.5
1961            }
1962        });
1963
1964        assert_eq!(serialized, expected);
1965    }
1966
1967    #[test]
1968    fn test_completion_request_with_output_schema() {
1969        use crate::OneOrMany;
1970        use crate::completion::Message as CompletionMessage;
1971        use crate::message::{Text, UserContent};
1972
1973        let schema: schemars::Schema = serde_json::from_value(json!({
1974            "type": "object",
1975            "properties": {
1976                "age": { "type": "integer" },
1977                "available": { "type": "boolean" }
1978            },
1979            "required": ["age", "available"]
1980        }))
1981        .expect("Failed to parse schema");
1982
1983        let completion_request = CompletionRequest {
1984            model: Some("llama3.1".to_string()),
1985            preamble: None,
1986            chat_history: OneOrMany::one(CompletionMessage::User {
1987                content: OneOrMany::one(UserContent::Text(Text::new(
1988                    "How old is Ollama?".to_string(),
1989                ))),
1990            }),
1991            documents: vec![],
1992            tools: vec![],
1993            temperature: None,
1994            max_tokens: None,
1995            tool_choice: None,
1996            additional_params: None,
1997            output_schema: Some(schema),
1998        };
1999
2000        let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
2001            .expect("Failed to create Ollama request");
2002
2003        let serialized =
2004            serde_json::to_value(&ollama_request).expect("Failed to serialize request");
2005
2006        let format = serialized
2007            .get("format")
2008            .expect("format field should be present");
2009        assert_eq!(
2010            *format,
2011            json!({
2012                "type": "object",
2013                "properties": {
2014                    "age": { "type": "integer" },
2015                    "available": { "type": "boolean" }
2016                },
2017                "required": ["age", "available"]
2018            })
2019        );
2020    }
2021
2022    #[test]
2023    fn test_completion_request_without_output_schema() {
2024        use crate::OneOrMany;
2025        use crate::completion::Message as CompletionMessage;
2026        use crate::message::{Text, UserContent};
2027
2028        let completion_request = CompletionRequest {
2029            model: Some("llama3.1".to_string()),
2030            preamble: None,
2031            chat_history: OneOrMany::one(CompletionMessage::User {
2032                content: OneOrMany::one(UserContent::Text(Text::new("Hello!".to_string()))),
2033            }),
2034            documents: vec![],
2035            tools: vec![],
2036            temperature: None,
2037            max_tokens: None,
2038            tool_choice: None,
2039            additional_params: None,
2040            output_schema: None,
2041        };
2042
2043        let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
2044            .expect("Failed to create Ollama request");
2045
2046        let serialized =
2047            serde_json::to_value(&ollama_request).expect("Failed to serialize request");
2048
2049        assert!(
2050            serialized.get("format").is_none(),
2051            "format field should be absent when output_schema is None"
2052        );
2053    }
2054
2055    #[test]
2056    fn test_client_initialization() {
2057        let _client = crate::providers::ollama::Client::new(Nothing).expect("Client::new() failed");
2058        let _client_from_builder = crate::providers::ollama::Client::builder()
2059            .api_key(Nothing)
2060            .build()
2061            .expect("Client::builder() failed");
2062    }
2063
2064    #[test]
2065    fn ndjson_buffer_returns_complete_lines_in_single_chunk() {
2066        let mut buf = NdjsonBuffer::new();
2067        let lines = buf.decode(b"{\"a\":1}\n{\"b\":2}\n");
2068        assert_eq!(lines, vec![b"{\"a\":1}".to_vec(), b"{\"b\":2}".to_vec()]);
2069    }
2070
2071    #[test]
2072    fn ndjson_buffer_reassembles_line_split_across_chunks() {
2073        let mut buf = NdjsonBuffer::new();
2074
2075        assert!(buf.decode(b"{\"model\":\"llama\",\"mes").is_empty());
2076
2077        let lines = buf.decode(b"sage\":\"hi\"}\n{\"done\"");
2078        assert_eq!(
2079            lines,
2080            vec![b"{\"model\":\"llama\",\"message\":\"hi\"}".to_vec()]
2081        );
2082
2083        let lines = buf.decode(b":true}\n");
2084        assert_eq!(lines, vec![b"{\"done\":true}".to_vec()]);
2085    }
2086
2087    #[test]
2088    fn ndjson_buffer_skips_blank_lines() {
2089        let mut buf = NdjsonBuffer::new();
2090        let lines = buf.decode(b"\n{\"a\":1}\n\n");
2091        assert_eq!(lines, vec![b"{\"a\":1}".to_vec()]);
2092    }
2093
2094    #[test]
2095    fn ndjson_buffer_retains_unterminated_trailing_data() {
2096        let mut buf = NdjsonBuffer::new();
2097        let lines = buf.decode(b"{\"a\":1}\n{\"b\":2");
2098        assert_eq!(lines, vec![b"{\"a\":1}".to_vec()]);
2099        let lines = buf.decode(b"}\n");
2100        assert_eq!(lines, vec![b"{\"b\":2}".to_vec()]);
2101    }
2102
2103    #[test]
2104    fn ndjson_buffer_handles_empty_chunk() {
2105        let mut buf = NdjsonBuffer::new();
2106        assert!(buf.decode(b"").is_empty());
2107
2108        buf.decode(b"{\"a\":1");
2109        assert!(buf.decode(b"").is_empty());
2110
2111        let lines = buf.decode(b"}\n");
2112        assert_eq!(lines, vec![b"{\"a\":1}".to_vec()]);
2113    }
2114
2115    #[test]
2116    fn ndjson_buffer_handles_multi_byte_utf8_split_across_chunks() {
2117        // `\n` (0x0A) cannot appear inside any UTF-8 continuation byte, so a
2118        // byte-wise newline scan is always safe — but verify explicitly that a
2119        // multi-byte sequence reassembles correctly when split across chunks.
2120        let mut buf = NdjsonBuffer::new();
2121        assert!(buf.decode(&[0xd0]).is_empty());
2122        assert!(buf.decode(&[0xb8, 0xd0, 0xb7, 0xd0]).is_empty());
2123        assert!(
2124            buf.decode(&[
2125                0xb2, 0xd0, 0xb5, 0xd1, 0x81, 0xd1, 0x82, 0xd0, 0xbd, 0xd0, 0xb8
2126            ])
2127            .is_empty()
2128        );
2129
2130        let lines = buf.decode(b"\n");
2131        assert_eq!(lines.len(), 1);
2132        assert_eq!(std::str::from_utf8(&lines[0]).unwrap(), "известни");
2133    }
2134
2135    #[test]
2136    fn ndjson_buffer_yields_parseable_chunks_when_split_arbitrarily() {
2137        let original = concat!(
2138            "{\"model\":\"llama3.2\",\"message\":{\"role\":\"assistant\",\"content\":\"hi\"},\"done\":false}\n",
2139            "{\"model\":\"llama3.2\",\"message\":{\"role\":\"assistant\",\"content\":\"\"},\"done\":true}\n",
2140        );
2141
2142        let mut buf = NdjsonBuffer::new();
2143        let mut received = Vec::new();
2144        for byte in original.as_bytes() {
2145            for line in buf.decode(std::slice::from_ref(byte)) {
2146                let parsed: serde_json::Value =
2147                    serde_json::from_slice(&line).expect("each drained line must be valid JSON");
2148                received.push(parsed);
2149            }
2150        }
2151
2152        assert_eq!(received.len(), 2);
2153        assert_eq!(received[0]["message"]["content"], "hi");
2154        assert_eq!(received[1]["done"], true);
2155    }
2156}