rig/providers/
ollama.rs

1//! Ollama API client and Rig integration
2//!
3//! # Example
4//! ```rust
5//! use rig::providers::ollama;
6//!
7//! // Create a new Ollama client (defaults to http://localhost:11434)
8//! let client = ollama::Client::new();
9//!
10//! // Create a completion model interface using, for example, the "llama3.2" model
11//! let comp_model = client.completion_model("llama3.2");
12//!
13//! let req = rig::completion::CompletionRequest {
14//!     preamble: Some("You are now a humorous AI assistant.".to_owned()),
15//!     chat_history: vec![],  // internal messages (if any)
16//!     prompt: rig::message::Message::User {
17//!         content: rig::one_or_many::OneOrMany::one(rig::message::UserContent::text("Please tell me why the sky is blue.")),
18//!         name: None
19//!     },
20//!     temperature: 0.7,
21//!     additional_params: None,
22//!     tools: vec![],
23//! };
24//!
25//! let response = comp_model.completion(req).await.unwrap();
26//! println!("Ollama completion response: {:?}", response.choice);
27//!
28//! // Create an embedding interface using the "all-minilm" model
29//! let emb_model = ollama::Client::new().embedding_model("all-minilm");
30//! let docs = vec![
31//!     "Why is the sky blue?".to_owned(),
32//!     "Why is the grass green?".to_owned()
33//! ];
34//! let embeddings = emb_model.embed_texts(docs).await.unwrap();
35//! println!("Embedding response: {:?}", embeddings);
36//!
37//! // Also create an agent and extractor if needed
38//! let agent = client.agent("llama3.2");
39//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
40//! ```
41use crate::client::{
42    self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
43};
44use crate::completion::{GetTokenUsage, Usage};
45use crate::http_client::{self, HttpClientExt};
46use crate::message::DocumentSourceKind;
47use crate::streaming::RawStreamingChoice;
48use crate::{
49    OneOrMany,
50    completion::{self, CompletionError, CompletionRequest},
51    embeddings::{self, EmbeddingError},
52    json_utils, message,
53    message::{ImageDetail, Text},
54    streaming,
55};
56use async_stream::try_stream;
57use bytes::Bytes;
58use futures::StreamExt;
59use serde::{Deserialize, Serialize};
60use serde_json::{Value, json};
61use std::{convert::TryFrom, str::FromStr};
62use tracing::info_span;
63use tracing_futures::Instrument;
64// ---------- Main Client ----------
65
66const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
67
/// Marker type implementing the Ollama provider extension.
/// Zero-sized: Ollama requires no API key or per-provider state.
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;
70
/// Zero-sized builder marker used to construct Ollama clients (see
/// [`ClientBuilder`]).
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaBuilder;
73
impl Provider for OllamaExt {
    type Builder = OllamaBuilder;

    /// Endpoint probed to verify the server is reachable (`api/tags` lists
    /// locally available models).
    const VERIFY_PATH: &'static str = "api/tags";

    /// Builds the provider extension. The builder argument is ignored because
    /// Ollama carries no provider-level configuration.
    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}
89
/// Capability matrix for Ollama: chat completions and embeddings are
/// supported; transcription (and image/audio generation, when those cargo
/// features are enabled) are not offered through this provider.
impl<H> Capabilities<H> for OllamaExt {
    type Completion = Capable<CompletionModel<H>>;
    type Transcription = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
100
101impl DebugExt for OllamaExt {}
102
impl ProviderBuilder for OllamaBuilder {
    type Output = OllamaExt;
    // `Nothing`: Ollama does not use an API key.
    type ApiKey = Nothing;

    /// Default base URL (`http://localhost:11434`) used when the caller does
    /// not override it.
    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;
}
109
/// Ollama client, generic over the HTTP backend (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
/// Builder for [`Client`]; the `Nothing` slot reflects that no API key is needed.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, Nothing, H>;
112
113impl ProviderClient for Client {
114    type Input = Nothing;
115
116    fn from_env() -> Self {
117        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
118
119        Self::builder()
120            .api_key(Nothing)
121            .base_url(&api_base)
122            .build()
123            .unwrap()
124    }
125
126    fn from_val(_: Self::Input) -> Self {
127        Self::builder().api_key(Nothing).build().unwrap()
128    }
129}
130
131// ---------- API Error and Response Structures ----------
132
/// Error payload returned by the Ollama API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
137
/// Untagged response wrapper: a body deserializes either as the expected
/// payload `T` or as an [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
144
145// ---------- Embedding API ----------
146
/// Identifier for the `all-minilm` embedding model.
pub const ALL_MINILM: &str = "all-minilm";
/// Identifier for the `nomic-embed-text` embedding model.
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
149
/// Response body from Ollama's `api/embed` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    /// One embedding vector per input document; this crate pairs them with
    /// the inputs positionally.
    pub embeddings: Vec<Vec<f64>>,
    // Timing/token accounting fields; may be absent from the response.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
161
impl From<ApiErrorResponse> for EmbeddingError {
    /// Maps an Ollama API error message onto the provider-error variant.
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}
167
168impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
169    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
170        match value {
171            ApiResponse::Ok(response) => Ok(response),
172            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
173        }
174    }
175}
176
177// ---------- Embedding Model ----------
178
/// Handle for an Ollama embedding model, bound to a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel<T> {
    client: Client<T>,
    /// Model identifier, e.g. [`ALL_MINILM`] or [`NOMIC_EMBED_TEXT`].
    pub model: String,
    // Expected dimensionality of the embeddings this model produces.
    ndims: usize,
}
185
186impl<T> EmbeddingModel<T> {
187    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
188        Self {
189            client,
190            model: model.into(),
191            ndims,
192        }
193    }
194
195    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
196        Self {
197            client,
198            model: model.into(),
199            ndims,
200        }
201    }
202}
203
204impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
205where
206    T: HttpClientExt + Clone + 'static,
207{
208    type Client = Client<T>;
209
210    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
211        Self::new(client.clone(), model, dims.unwrap())
212    }
213
214    const MAX_DOCUMENTS: usize = 1024;
215    fn ndims(&self) -> usize {
216        self.ndims
217    }
218    #[cfg_attr(feature = "worker", worker::send)]
219    async fn embed_texts(
220        &self,
221        documents: impl IntoIterator<Item = String>,
222    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
223        let docs: Vec<String> = documents.into_iter().collect();
224
225        let body = serde_json::to_vec(&json!({
226            "model": self.model,
227            "input": docs
228        }))?;
229
230        let req = self
231            .client
232            .post("api/embed")?
233            .body(body)
234            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
235
236        let response = HttpClientExt::send(self.client.http_client(), req).await?;
237
238        if !response.status().is_success() {
239            let text = http_client::text(response).await?;
240            return Err(EmbeddingError::ProviderError(text));
241        }
242
243        let bytes: Vec<u8> = response.into_body().await?;
244
245        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;
246
247        if api_resp.embeddings.len() != docs.len() {
248            return Err(EmbeddingError::ResponseError(
249                "Number of returned embeddings does not match input".into(),
250            ));
251        }
252        Ok(api_resp
253            .embeddings
254            .into_iter()
255            .zip(docs.into_iter())
256            .map(|(vec, document)| embeddings::Embedding { document, vec })
257            .collect())
258    }
259}
260
261// ---------- Completion API ----------
262
/// Identifier for the `llama3.2` chat model.
pub const LLAMA3_2: &str = "llama3.2";
/// Identifier for the `llava` multimodal model.
pub const LLAVA: &str = "llava";
/// Identifier for the `mistral` chat model.
pub const MISTRAL: &str = "mistral";
266
/// Raw response body from Ollama's `api/chat` endpoint — a single object for
/// non-streaming requests, or one object per NDJSON line when streaming.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    /// Creation timestamp string as returned by Ollama.
    pub created_at: String,
    pub message: Message,
    /// `true` on the final (or only) chunk of a generation.
    pub done: bool,
    /// Why generation stopped; may be absent.
    #[serde(default)]
    pub done_reason: Option<String>,
    // Timing and token-accounting fields; typically absent on intermediate
    // streaming chunks.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
/// Converts a raw Ollama chat response into Rig's generic completion
/// response. Fails unless the response carries an assistant message with at
/// least one piece of content (text or tool call).
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;
    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            // Process only if an assistant message is present.
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                // Add the assistant's text content if any.
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                // Process tool_calls following Ollama's chat response definition.
                // Each ToolCall has an id, a type, and a function field.
                // Ollama's `ToolCall` struct (below) carries no separate id, so
                // the function name is passed twice: once as the call id and
                // once as the tool name.
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                // `OneOrMany::many` rejects an empty vec — i.e. a response with
                // neither text nor tool calls is an error.
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                // Rebuild the raw response so callers can still inspect every
                // provider field; `images`/`name` were discarded by the match
                // above and are reset to None here.
                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                    },
                    raw_response,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}
355
/// Serialized request body for Ollama's `api/chat` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OllamaCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// When true, Ollama streams the response as NDJSON chunks.
    pub stream: bool,
    /// Requests the model's "thinking" (reasoning) output; extracted from
    /// `additional_params.think` during conversion.
    think: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    /// Generation options (at minimum `temperature`, merged with any extra
    /// `additional_params`).
    options: serde_json::Value,
}
370
371impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
372    type Error = CompletionError;
373
374    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
375        if req.tool_choice.is_some() {
376            tracing::warn!("WARNING: `tool_choice` not supported for Ollama");
377        }
378        // Build up the order of messages (context, chat_history, prompt)
379        let mut partial_history = vec![];
380        if let Some(docs) = req.normalized_documents() {
381            partial_history.push(docs);
382        }
383        partial_history.extend(req.chat_history);
384
385        // Add preamble to chat history (if available)
386        let mut full_history: Vec<Message> = match &req.preamble {
387            Some(preamble) => vec![Message::system(preamble)],
388            None => vec![],
389        };
390
391        // Convert and extend the rest of the history
392        full_history.extend(
393            partial_history
394                .into_iter()
395                .map(message::Message::try_into)
396                .collect::<Result<Vec<Vec<Message>>, _>>()?
397                .into_iter()
398                .flatten()
399                .collect::<Vec<_>>(),
400        );
401
402        let mut think = false;
403
404        // TODO: Fix this up to include the full range of ollama options
405        let options = if let Some(mut extra) = req.additional_params {
406            if extra.get("think").is_some() {
407                think = extra["think"].take().as_bool().ok_or_else(|| {
408                    CompletionError::RequestError("`think` must be a bool".into())
409                })?;
410            }
411            json_utils::merge(json!({ "temperature": req.temperature }), extra)
412        } else {
413            json!({ "temperature": req.temperature })
414        };
415
416        Ok(Self {
417            model: model.to_string(),
418            messages: full_history,
419            temperature: req.temperature,
420            max_tokens: req.max_tokens,
421            stream: false,
422            think,
423            tools: req
424                .tools
425                .clone()
426                .into_iter()
427                .map(ToolDefinition::from)
428                .collect::<Vec<_>>(),
429            options,
430        })
431    }
432}
433
/// Handle for an Ollama chat-completion model, bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Model identifier, e.g. [`LLAMA3_2`].
    pub model: String,
}
439
440impl<T> CompletionModel<T> {
441    pub fn new(client: Client<T>, model: &str) -> Self {
442        Self {
443            client,
444            model: model.to_owned(),
445        }
446    }
447}
448
449// ---------- CompletionModel Implementation ----------
450
/// Final summary emitted at the end of a streamed Ollama completion,
/// carrying the timing and token-accounting fields of the `done` chunk.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    /// Input (prompt) token count.
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    /// Output (completion) token count.
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
461
462impl GetTokenUsage for StreamingCompletionResponse {
463    fn token_usage(&self) -> Option<crate::completion::Usage> {
464        let mut usage = crate::completion::Usage::new();
465        let input_tokens = self.prompt_eval_count.unwrap_or_default();
466        let output_tokens = self.eval_count.unwrap_or_default();
467        usage.input_tokens = input_tokens;
468        usage.output_tokens = output_tokens;
469        usage.total_tokens = input_tokens + output_tokens;
470
471        Some(usage)
472    }
473}
474
475impl<T> completion::CompletionModel for CompletionModel<T>
476where
477    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
478{
479    type Response = CompletionResponse;
480    type StreamingResponse = StreamingCompletionResponse;
481
482    type Client = Client<T>;
483
484    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
485        Self::new(client.clone(), model.into().as_str())
486    }
487
488    #[cfg_attr(feature = "worker", worker::send)]
489    async fn completion(
490        &self,
491        completion_request: CompletionRequest,
492    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
493        let preamble = completion_request.preamble.clone();
494        let request = OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
495
496        let span = if tracing::Span::current().is_disabled() {
497            info_span!(
498                target: "rig::completions",
499                "chat",
500                gen_ai.operation.name = "chat",
501                gen_ai.provider.name = "ollama",
502                gen_ai.request.model = self.model,
503                gen_ai.system_instructions = preamble,
504                gen_ai.response.id = tracing::field::Empty,
505                gen_ai.response.model = tracing::field::Empty,
506                gen_ai.usage.output_tokens = tracing::field::Empty,
507                gen_ai.usage.input_tokens = tracing::field::Empty,
508                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
509                gen_ai.output.messages = tracing::field::Empty,
510            )
511        } else {
512            tracing::Span::current()
513        };
514
515        let body = serde_json::to_vec(&request)?;
516
517        let req = self
518            .client
519            .post("api/chat")?
520            .body(body)
521            .map_err(http_client::Error::from)?;
522
523        let async_block = async move {
524            let response = self.client.send::<_, Bytes>(req).await?;
525            let status = response.status();
526            let response_body = response.into_body().into_future().await?.to_vec();
527
528            if !status.is_success() {
529                return Err(CompletionError::ProviderError(
530                    String::from_utf8_lossy(&response_body).to_string(),
531                ));
532            }
533
534            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
535            let span = tracing::Span::current();
536            span.record("gen_ai.response.model_name", &response.model);
537            span.record(
538                "gen_ai.output.messages",
539                serde_json::to_string(&vec![&response.message]).unwrap(),
540            );
541            span.record(
542                "gen_ai.usage.input_tokens",
543                response.prompt_eval_count.unwrap_or_default(),
544            );
545            span.record(
546                "gen_ai.usage.output_tokens",
547                response.eval_count.unwrap_or_default(),
548            );
549
550            tracing::trace!(
551                target: "rig::completions",
552                "Ollama completion response: {}",
553                serde_json::to_string_pretty(&response)?
554            );
555
556            let response: completion::CompletionResponse<CompletionResponse> =
557                response.try_into()?;
558
559            Ok(response)
560        };
561
562        tracing::Instrument::instrument(async_block, span).await
563    }
564
565    #[cfg_attr(feature = "worker", worker::send)]
566    async fn stream(
567        &self,
568        request: CompletionRequest,
569    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
570    {
571        let preamble = request.preamble.clone();
572        let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
573        request.stream = true;
574
575        let span = if tracing::Span::current().is_disabled() {
576            info_span!(
577                target: "rig::completions",
578                "chat_streaming",
579                gen_ai.operation.name = "chat_streaming",
580                gen_ai.provider.name = "ollama",
581                gen_ai.request.model = self.model,
582                gen_ai.system_instructions = preamble,
583                gen_ai.response.id = tracing::field::Empty,
584                gen_ai.response.model = self.model,
585                gen_ai.usage.output_tokens = tracing::field::Empty,
586                gen_ai.usage.input_tokens = tracing::field::Empty,
587                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
588                gen_ai.output.messages = tracing::field::Empty,
589            )
590        } else {
591            tracing::Span::current()
592        };
593
594        let body = serde_json::to_vec(&request)?;
595
596        let req = self
597            .client
598            .post("api/chat")?
599            .body(body)
600            .map_err(http_client::Error::from)?;
601
602        let response = self.client.http_client().send_streaming(req).await?;
603        let status = response.status();
604        let mut byte_stream = response.into_body();
605
606        if !status.is_success() {
607            return Err(CompletionError::ProviderError(format!(
608                "Got error status code trying to send a request to Ollama: {status}"
609            )));
610        }
611
612        let stream = try_stream! {
613            let span = tracing::Span::current();
614            let mut tool_calls_final = Vec::new();
615            let mut text_response = String::new();
616            let mut thinking_response = String::new();
617
618            while let Some(chunk) = byte_stream.next().await {
619                let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;
620
621                for line in bytes.split(|&b| b == b'\n') {
622                    if line.is_empty() {
623                        continue;
624                    }
625
626                    tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));
627
628                    let response: CompletionResponse = serde_json::from_slice(line)?;
629
630                    if response.done {
631                        span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
632                        span.record("gen_ai.usage.output_tokens", response.eval_count);
633                        let message = Message::Assistant {
634                            content: text_response.clone(),
635                            thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
636                            images: None,
637                            name: None,
638                            tool_calls: tool_calls_final.clone()
639                        };
640                        span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
641                        yield RawStreamingChoice::FinalResponse(
642                            StreamingCompletionResponse {
643                                total_duration: response.total_duration,
644                                load_duration: response.load_duration,
645                                prompt_eval_count: response.prompt_eval_count,
646                                prompt_eval_duration: response.prompt_eval_duration,
647                                eval_count: response.eval_count,
648                                eval_duration: response.eval_duration,
649                                done_reason: response.done_reason,
650                            }
651                        );
652                        break;
653                    }
654
655                    if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
656                        if let Some(thinking_content) = thinking
657                            && !thinking_content.is_empty() {
658                            thinking_response += &thinking_content;
659                            yield RawStreamingChoice::Reasoning {
660                                reasoning: thinking_content,
661                                id: None,
662                                signature: None,
663                            };
664                        }
665
666                        if !content.is_empty() {
667                            text_response += &content;
668                            yield RawStreamingChoice::Message(content);
669                        }
670
671                        for tool_call in tool_calls {
672                            tool_calls_final.push(tool_call.clone());
673                            yield RawStreamingChoice::ToolCall {
674                                id: String::new(),
675                                name: tool_call.function.name,
676                                arguments: tool_call.function.arguments,
677                                call_id: None,
678                            };
679                        }
680                    }
681                }
682            }
683        }.instrument(span);
684
685        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
686            stream,
687        )))
688    }
689}
690
691// ---------- Tool Definition Conversion ----------
692
/// Ollama-required tool definition format: an envelope whose `type` is always
/// the string `"function"` around the shared function schema.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    #[serde(rename = "type")]
    pub type_field: String, // Fixed as "function"
    pub function: completion::ToolDefinition,
}
700
701/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
702impl From<crate::completion::ToolDefinition> for ToolDefinition {
703    fn from(tool: crate::completion::ToolDefinition) -> Self {
704        ToolDefinition {
705            type_field: "function".to_owned(),
706            function: completion::ToolDefinition {
707                name: tool.name,
708                description: tool.description,
709                parameters: tool.parameters,
710            },
711        }
712    }
713}
714
/// A tool invocation emitted by the model. Note: carries no call id — only a
/// type tag and the function payload.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Tool-call kind; `function` is the only variant Ollama defines here.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// The function name and JSON arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
732
733// ---------- Provider Message Definition ----------
734
/// Ollama wire-format chat message, discriminated by the `role` field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// A user turn; `images` holds base64-encoded image strings.
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// An assistant turn; `thinking` carries reasoning text when the model
    /// emits it.
    Assistant {
        #[serde(default)]
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // May be `null` in responses; deserialized as an empty vec.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// A system instruction message.
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Result of a tool invocation, serialized with role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
771
772/// -----------------------------
773/// Provider Message Conversions
774/// -----------------------------
775/// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
776/// (Only User and Assistant variants are supported.)
777impl TryFrom<crate::message::Message> for Vec<Message> {
778    type Error = crate::message::MessageError;
779    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
780        use crate::message::Message as InternalMessage;
781        match internal_msg {
782            InternalMessage::User { content, .. } => {
783                let (tool_results, other_content): (Vec<_>, Vec<_>) =
784                    content.into_iter().partition(|content| {
785                        matches!(content, crate::message::UserContent::ToolResult(_))
786                    });
787
788                if !tool_results.is_empty() {
789                    tool_results
790                        .into_iter()
791                        .map(|content| match content {
792                            crate::message::UserContent::ToolResult(
793                                crate::message::ToolResult { id, content, .. },
794                            ) => {
795                                // Ollama expects a single string for tool results, so we concatenate
796                                let content_string = content
797                                    .into_iter()
798                                    .map(|content| match content {
799                                        crate::message::ToolResultContent::Text(text) => text.text,
800                                        _ => "[Non-text content]".to_string(),
801                                    })
802                                    .collect::<Vec<_>>()
803                                    .join("\n");
804
805                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
806                                    name: id,
807                                    content: content_string,
808                                })
809                            }
810                            _ => unreachable!(),
811                        })
812                        .collect::<Result<Vec<_>, _>>()
813                } else {
814                    // Ollama requires separate text content and images array
815                    let (texts, images) = other_content.into_iter().fold(
816                        (Vec::new(), Vec::new()),
817                        |(mut texts, mut images), content| {
818                            match content {
819                                crate::message::UserContent::Text(crate::message::Text {
820                                    text,
821                                }) => texts.push(text),
822                                crate::message::UserContent::Image(crate::message::Image {
823                                    data: DocumentSourceKind::Base64(data),
824                                    ..
825                                }) => images.push(data),
826                                crate::message::UserContent::Document(
827                                    crate::message::Document {
828                                        data:
829                                            DocumentSourceKind::Base64(data)
830                                            | DocumentSourceKind::String(data),
831                                        ..
832                                    },
833                                ) => texts.push(data),
834                                _ => {} // Audio not supported by Ollama
835                            }
836                            (texts, images)
837                        },
838                    );
839
840                    Ok(vec![Message::User {
841                        content: texts.join(" "),
842                        images: if images.is_empty() {
843                            None
844                        } else {
845                            Some(
846                                images
847                                    .into_iter()
848                                    .map(|x| x.to_string())
849                                    .collect::<Vec<String>>(),
850                            )
851                        },
852                        name: None,
853                    }])
854                }
855            }
856            InternalMessage::Assistant { content, .. } => {
857                let mut thinking: Option<String> = None;
858                let mut text_content = Vec::new();
859                let mut tool_calls = Vec::new();
860
861                for content in content.into_iter() {
862                    match content {
863                        crate::message::AssistantContent::Text(text) => {
864                            text_content.push(text.text)
865                        }
866                        crate::message::AssistantContent::ToolCall(tool_call) => {
867                            tool_calls.push(tool_call)
868                        }
869                        crate::message::AssistantContent::Reasoning(
870                            crate::message::Reasoning { reasoning, .. },
871                        ) => {
872                            thinking = Some(reasoning.first().cloned().unwrap_or(String::new()));
873                        }
874                        crate::message::AssistantContent::Image(_) => {
875                            return Err(crate::message::MessageError::ConversionError(
876                                "Ollama currently doesn't support images.".into(),
877                            ));
878                        }
879                    }
880                }
881
882                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
883                //  so either `content` or `tool_calls` will have some content.
884                Ok(vec![Message::Assistant {
885                    content: text_content.join(" "),
886                    thinking,
887                    images: None,
888                    name: None,
889                    tool_calls: tool_calls
890                        .into_iter()
891                        .map(|tool_call| tool_call.into())
892                        .collect::<Vec<_>>(),
893                }])
894            }
895        }
896    }
897}
898
899/// Conversion from provider Message to a completion message.
900/// This is needed so that responses can be converted back into chat history.
901impl From<Message> for crate::completion::Message {
902    fn from(msg: Message) -> Self {
903        match msg {
904            Message::User { content, .. } => crate::completion::Message::User {
905                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
906                    text: content,
907                })),
908            },
909            Message::Assistant {
910                content,
911                tool_calls,
912                ..
913            } => {
914                let mut assistant_contents =
915                    vec![crate::completion::message::AssistantContent::Text(Text {
916                        text: content,
917                    })];
918                for tc in tool_calls {
919                    assistant_contents.push(
920                        crate::completion::message::AssistantContent::tool_call(
921                            tc.function.name.clone(),
922                            tc.function.name,
923                            tc.function.arguments,
924                        ),
925                    );
926                }
927                crate::completion::Message::Assistant {
928                    id: None,
929                    content: OneOrMany::many(assistant_contents).unwrap(),
930                }
931            }
932            // System and ToolResult are converted to User message as needed.
933            Message::System { content, .. } => crate::completion::Message::User {
934                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
935                    text: content,
936                })),
937            },
938            Message::ToolResult { name, content } => crate::completion::Message::User {
939                content: OneOrMany::one(message::UserContent::tool_result(
940                    name,
941                    OneOrMany::one(message::ToolResultContent::text(content)),
942                )),
943            },
944        }
945    }
946}
947
948impl Message {
949    /// Constructs a system message.
950    pub fn system(content: &str) -> Self {
951        Message::System {
952            content: content.to_owned(),
953            images: None,
954            name: None,
955        }
956    }
957}
958
959// ---------- Additional Message Types ----------
960
961impl From<crate::message::ToolCall> for ToolCall {
962    fn from(tool_call: crate::message::ToolCall) -> Self {
963        Self {
964            r#type: ToolType::Function,
965            function: Function {
966                name: tool_call.function.name,
967                arguments: tool_call.function.arguments,
968            },
969        }
970    }
971}
972
/// A single system-message content part as exchanged with the Ollama API.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    // Content kind; defaults to `text` when absent from the payload.
    #[serde(default)]
    r#type: SystemContentType,
    // The literal system prompt text.
    text: String,
}
979
/// Discriminator for [`SystemContent`]; `text` is currently the only kind.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
986
987impl From<String> for SystemContent {
988    fn from(s: String) -> Self {
989        SystemContent {
990            r#type: SystemContentType::default(),
991            text: s,
992        }
993    }
994}
995
996impl FromStr for SystemContent {
997    type Err = std::convert::Infallible;
998    fn from_str(s: &str) -> Result<Self, Self::Err> {
999        Ok(SystemContent {
1000            r#type: SystemContentType::default(),
1001            text: s.to_string(),
1002        })
1003    }
1004}
1005
/// Plain-text assistant content as exchanged with the Ollama API.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}
1010
1011impl FromStr for AssistantContent {
1012    type Err = std::convert::Infallible;
1013    fn from_str(s: &str) -> Result<Self, Self::Err> {
1014        Ok(AssistantContent { text: s.to_owned() })
1015    }
1016}
1017
/// User-supplied content parts accepted by the Ollama API, tagged by `type`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
    // Audio variant removed as Ollama API does not support audio input.
}
1025
1026impl FromStr for UserContent {
1027    type Err = std::convert::Infallible;
1028    fn from_str(s: &str) -> Result<Self, Self::Err> {
1029        Ok(UserContent::Text { text: s.to_owned() })
1030    }
1031}
1032
/// An image reference attached to user content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    // Rendering detail hint; falls back to `ImageDetail`'s default when the
    // field is missing from the payload.
    #[serde(default)]
    pub detail: ImageDetail,
}
1039
1040// =================================================================
1041// Tests
1042// =================================================================
1043
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Test deserialization and conversion for the /api/chat endpoint.
    // NOTE: this test (and the other two below that were `#[tokio::test]`)
    // performs no awaits — it is pure serde work — so a plain `#[test]`
    // avoids spinning up a Tokio runtime and matches the rest of the module.
    #[test]
    fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    // Test conversion from provider Message to completion Message.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // Assume OneOrMany<T> has a method first() to access the first element.
                let first_content = content.first();
                // The expected type is crate::completion::message::UserContent::Text wrapping a Text struct.
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    // Test conversion of internal tool definition to Ollama's ToolDefinition format.
    #[test]
    fn test_tool_definition_conversion() {
        // Internal tool definition from the completion module.
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        // Convert internal tool to Ollama's tool definition.
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        // Check JSON fields in parameters.
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }

    // Test deserialization of chat response with thinking content
    #[test]
    fn test_chat_completion_with_thinking() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The answer is 42.",
                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        // Verify thinking field is present
        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "Let me think about this carefully. The question asks for the meaning of life..."
            );
            assert_eq!(content, "The answer is 42.");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test deserialization of chat response without thinking content
    #[test]
    fn test_chat_completion_without_thinking() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Hello!",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        // Verify thinking field is None when not provided
        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert!(thinking.is_none());
            assert_eq!(content, "Hello!");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test deserialization of streaming response with thinking content
    #[test]
    fn test_streaming_response_with_thinking() {
        let sample_chunk = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "",
                "thinking": "Analyzing the problem...",
                "images": null,
                "tool_calls": []
            },
            "done": false
        });

        let chunk: CompletionResponse =
            serde_json::from_value(sample_chunk).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chunk.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
            assert_eq!(content, "");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test message conversion with thinking content
    #[test]
    fn test_message_conversion_with_thinking() {
        // Create an internal message with reasoning content
        let reasoning_content = crate::message::Reasoning {
            id: None,
            reasoning: vec!["Step 1: Consider the problem".to_string()],
            signature: None,
        };

        let internal_msg = crate::message::Message::Assistant {
            id: None,
            content: crate::OneOrMany::many(vec![
                crate::message::AssistantContent::Reasoning(reasoning_content),
                crate::message::AssistantContent::Text(crate::message::Text {
                    text: "The answer is X".to_string(),
                }),
            ])
            .unwrap(),
        };

        // Convert to provider Message
        let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();
        assert_eq!(provider_msgs.len(), 1);

        if let Message::Assistant {
            thinking, content, ..
        } = &provider_msgs[0]
        {
            assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
            assert_eq!(content, "The answer is X");
        } else {
            panic!("Expected Assistant message with thinking");
        }
    }

    // Test empty thinking content is handled correctly
    #[test]
    fn test_empty_thinking_content() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Response",
                "thinking": "",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            // Empty string should still deserialize as Some("")
            assert_eq!(thinking.as_ref().unwrap(), "");
            assert_eq!(content, "Response");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test thinking with tool calls
    #[test]
    fn test_thinking_with_tool_calls() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Let me check the weather.",
                "thinking": "User wants weather info, I should use the weather tool",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {
                                "location": "San Francisco"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 30u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 50u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking,
            content,
            tool_calls,
            ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "User wants weather info, I should use the weather tool"
            );
            assert_eq!(content, "Let me check the weather.");
            assert_eq!(tool_calls.len(), 1);
            assert_eq!(tool_calls[0].function.name, "get_weather");
        } else {
            panic!("Expected Assistant message with thinking and tool calls");
        }
    }
}