rig/providers/
ollama.rs

1//! Ollama API client and Rig integration
2//!
3//! # Example
4//! ```rust
5//! use rig::providers::ollama;
6//!
7//! // Create a new Ollama client (defaults to http://localhost:11434)
8//! let client = ollama::Client::new();
9//!
10//! // Create a completion model interface using, for example, the "llama3.2" model
11//! let comp_model = client.completion_model("llama3.2");
12//!
//! let req = rig::completion::CompletionRequest {
//!     preamble: Some("You are now a humorous AI assistant.".to_owned()),
//!     // The prompt is sent as the last message of the chat history.
//!     chat_history: rig::one_or_many::OneOrMany::one(rig::message::Message::user(
//!         "Please tell me why the sky is blue.",
//!     )),
//!     documents: vec![],
//!     tools: vec![],
//!     temperature: Some(0.7),
//!     max_tokens: None,
//!     tool_choice: None,
//!     additional_params: None,
//! };
24//!
25//! let response = comp_model.completion(req).await.unwrap();
26//! println!("Ollama completion response: {:?}", response.choice);
27//!
28//! // Create an embedding interface using the "all-minilm" model
29//! let emb_model = ollama::Client::new().embedding_model("all-minilm");
30//! let docs = vec![
31//!     "Why is the sky blue?".to_owned(),
32//!     "Why is the grass green?".to_owned()
33//! ];
34//! let embeddings = emb_model.embed_texts(docs).await.unwrap();
35//! println!("Embedding response: {:?}", embeddings);
36//!
37//! // Also create an agent and extractor if needed
38//! let agent = client.agent("llama3.2");
39//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
40//! ```
41use crate::client::{
42    self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
43};
44use crate::completion::{GetTokenUsage, Usage};
45use crate::http_client::{self, HttpClientExt};
46use crate::message::DocumentSourceKind;
47use crate::streaming::RawStreamingChoice;
48use crate::{
49    OneOrMany,
50    completion::{self, CompletionError, CompletionRequest},
51    embeddings::{self, EmbeddingError},
52    json_utils, message,
53    message::{ImageDetail, Text},
54    streaming,
55};
56use async_stream::try_stream;
57use bytes::Bytes;
58use futures::StreamExt;
59use serde::{Deserialize, Serialize};
60use serde_json::{Value, json};
61use std::{convert::TryFrom, str::FromStr};
62use tracing::info_span;
63use tracing_futures::Instrument;
64// ---------- Main Client ----------
65
/// Default base URL of the Ollama server (a local instance on port 11434).
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";

/// Provider marker type for Ollama. It is a unit struct and carries no state.
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;

/// Builder marker type used to construct Ollama clients (see [`ProviderBuilder`]).
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaBuilder;
73
74impl Provider for OllamaExt {
75    type Builder = OllamaBuilder;
76
77    const VERIFY_PATH: &'static str = "api/tags";
78
79    fn build<H>(
80        _: &crate::client::ClientBuilder<
81            Self::Builder,
82            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
83            H,
84        >,
85    ) -> http_client::Result<Self> {
86        Ok(Self)
87    }
88}
89
/// Declares what an Ollama client can do: chat completions and embeddings are
/// supported, while transcription (and, when the features are enabled, image and
/// audio generation) are not.
impl<H> Capabilities<H> for OllamaExt {
    type Completion = Capable<CompletionModel<H>>;
    type Transcription = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
100
// Relies entirely on `DebugExt`'s default items.
impl DebugExt for OllamaExt {}

/// Ollama needs no API key, so the key type is `Nothing`. Clients target
/// [`OLLAMA_API_BASE_URL`] unless another base URL is set on the builder.
impl ProviderBuilder for OllamaBuilder {
    type Output = OllamaExt;
    type ApiKey = Nothing;

    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;
}
109
/// Ollama client, generic over the HTTP backend (defaults to `reqwest::Client`).
pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
/// Builder for an Ollama [`Client`]; the API-key slot is `Nothing` because Ollama needs none.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, Nothing, H>;
112
113impl ProviderClient for Client {
114    type Input = Nothing;
115
116    fn from_env() -> Self {
117        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
118
119        Self::builder()
120            .api_key(Nothing)
121            .base_url(&api_base)
122            .build()
123            .unwrap()
124    }
125
126    fn from_val(_: Self::Input) -> Self {
127        Self::builder().api_key(Nothing).build().unwrap()
128    }
129}
130
131// ---------- API Error and Response Structures ----------
132
/// Error body carrying a human-readable `message`.
// NOTE(review): Ollama's error responses appear to use an `error` key rather than
// `message`; if so, this struct would never match them — confirm against the API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload `T` or an [`ApiErrorResponse`]. The variant is chosen
/// structurally (serde `untagged`: the first variant that deserializes wins).
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
144
145// ---------- Embedding API ----------
146
/// Ollama model name of the `all-minilm` embedding model.
pub const ALL_MINILM: &str = "all-minilm";
/// Ollama model name of the `nomic-embed-text` embedding model.
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

/// Response body of Ollama's `api/embed` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    /// Model that produced the embeddings.
    pub model: String,
    /// One vector per input document. The embedding code pairs these with the
    /// inputs positionally, so they are expected in input order.
    pub embeddings: Vec<Vec<f64>>,
    // NOTE(review): Ollama documents the `*_duration` fields in nanoseconds — confirm.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    /// Number of prompt tokens evaluated, if reported.
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
161
162impl From<ApiErrorResponse> for EmbeddingError {
163    fn from(err: ApiErrorResponse) -> Self {
164        EmbeddingError::ProviderError(err.message)
165    }
166}
167
168impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
169    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
170        match value {
171            ApiResponse::Ok(response) => Ok(response),
172            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
173        }
174    }
175}
176
177// ---------- Embedding Model ----------
178
179#[derive(Clone)]
180pub struct EmbeddingModel<T> {
181    client: Client<T>,
182    pub model: String,
183    ndims: usize,
184}
185
186impl<T> EmbeddingModel<T> {
187    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
188        Self {
189            client,
190            model: model.into(),
191            ndims,
192        }
193    }
194
195    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
196        Self {
197            client,
198            model: model.into(),
199            ndims,
200        }
201    }
202}
203
204impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
205where
206    T: HttpClientExt + Clone + 'static,
207{
208    type Client = Client<T>;
209
210    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
211        Self::new(client.clone(), model, dims.unwrap())
212    }
213
214    const MAX_DOCUMENTS: usize = 1024;
215    fn ndims(&self) -> usize {
216        self.ndims
217    }
218
219    async fn embed_texts(
220        &self,
221        documents: impl IntoIterator<Item = String>,
222    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
223        let docs: Vec<String> = documents.into_iter().collect();
224
225        let body = serde_json::to_vec(&json!({
226            "model": self.model,
227            "input": docs
228        }))?;
229
230        let req = self
231            .client
232            .post("api/embed")?
233            .body(body)
234            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
235
236        let response = self.client.send(req).await?;
237
238        if !response.status().is_success() {
239            let text = http_client::text(response).await?;
240            return Err(EmbeddingError::ProviderError(text));
241        }
242
243        let bytes: Vec<u8> = response.into_body().await?;
244
245        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;
246
247        if api_resp.embeddings.len() != docs.len() {
248            return Err(EmbeddingError::ResponseError(
249                "Number of returned embeddings does not match input".into(),
250            ));
251        }
252        Ok(api_resp
253            .embeddings
254            .into_iter()
255            .zip(docs.into_iter())
256            .map(|(vec, document)| embeddings::Embedding { document, vec })
257            .collect())
258    }
259}
260
261// ---------- Completion API ----------
262
/// Ollama model name `llama3.2`.
pub const LLAMA3_2: &str = "llama3.2";
/// Ollama model name `llava`.
pub const LLAVA: &str = "llava";
/// Ollama model name `mistral`.
pub const MISTRAL: &str = "mistral";

/// A response object from Ollama's `api/chat` endpoint.
///
/// Used both for non-streaming responses and for each NDJSON line of a streamed
/// response. The stream handler reads the statistics fields only from the line
/// with `done == true`.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Model that produced the response.
    pub model: String,
    /// Creation timestamp, kept as the string Ollama sent.
    pub created_at: String,
    /// The generated message (an assistant message on success).
    pub message: Message,
    /// Whether this is the final object of the response.
    pub done: bool,
    /// Why generation stopped, if reported.
    #[serde(default)]
    pub done_reason: Option<String>,
    // NOTE(review): Ollama documents the `*_duration` fields in nanoseconds — confirm.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    /// Prompt tokens evaluated; reported as input token usage.
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    /// Tokens generated; reported as output token usage.
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
288impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
289    type Error = CompletionError;
290    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
291        match resp.message {
292            // Process only if an assistant message is present.
293            Message::Assistant {
294                content,
295                thinking,
296                tool_calls,
297                ..
298            } => {
299                let mut assistant_contents = Vec::new();
300                // Add the assistant's text content if any.
301                if !content.is_empty() {
302                    assistant_contents.push(completion::AssistantContent::text(&content));
303                }
304                // Process tool_calls following Ollama's chat response definition.
305                // Each ToolCall has an id, a type, and a function field.
306                for tc in tool_calls.iter() {
307                    assistant_contents.push(completion::AssistantContent::tool_call(
308                        tc.function.name.clone(),
309                        tc.function.name.clone(),
310                        tc.function.arguments.clone(),
311                    ));
312                }
313                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
314                    CompletionError::ResponseError("No content provided".to_owned())
315                })?;
316                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
317                let completion_tokens = resp.eval_count.unwrap_or(0);
318
319                let raw_response = CompletionResponse {
320                    model: resp.model,
321                    created_at: resp.created_at,
322                    done: resp.done,
323                    done_reason: resp.done_reason,
324                    total_duration: resp.total_duration,
325                    load_duration: resp.load_duration,
326                    prompt_eval_count: resp.prompt_eval_count,
327                    prompt_eval_duration: resp.prompt_eval_duration,
328                    eval_count: resp.eval_count,
329                    eval_duration: resp.eval_duration,
330                    message: Message::Assistant {
331                        content,
332                        thinking,
333                        images: None,
334                        name: None,
335                        tool_calls,
336                    },
337                };
338
339                Ok(completion::CompletionResponse {
340                    choice,
341                    usage: Usage {
342                        input_tokens: prompt_tokens,
343                        output_tokens: completion_tokens,
344                        total_tokens: prompt_tokens + completion_tokens,
345                    },
346                    raw_response,
347                })
348            }
349            _ => Err(CompletionError::ResponseError(
350                "Chat response does not include an assistant message".into(),
351            )),
352        }
353    }
354}
355
/// Request body for Ollama's `api/chat` endpoint, built from a Rig [`CompletionRequest`].
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OllamaCompletionRequest {
    model: String,
    /// Chat messages in send order (system preamble first, when present).
    pub messages: Vec<Message>,
    // NOTE(review): Ollama's chat API appears to read sampling parameters from
    // `options` only, so this top-level copy is likely ignored — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Tool definitions offered to the model; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// `true` when the response should be streamed as NDJSON.
    pub stream: bool,
    /// Top-level thinking flag, taken from `additional_params["think"]`.
    think: bool,
    // NOTE(review): as with `temperature`, Ollama likely ignores a top-level
    // `max_tokens`; the generation limit is presumably `options.num_predict` — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    /// Model options: `temperature` merged with the request's `additional_params`.
    options: serde_json::Value,
}
370
371impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
372    type Error = CompletionError;
373
374    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
375        if req.tool_choice.is_some() {
376            tracing::warn!("WARNING: `tool_choice` not supported for Ollama");
377        }
378        // Build up the order of messages (context, chat_history, prompt)
379        let mut partial_history = vec![];
380        if let Some(docs) = req.normalized_documents() {
381            partial_history.push(docs);
382        }
383        partial_history.extend(req.chat_history);
384
385        // Add preamble to chat history (if available)
386        let mut full_history: Vec<Message> = match &req.preamble {
387            Some(preamble) => vec![Message::system(preamble)],
388            None => vec![],
389        };
390
391        // Convert and extend the rest of the history
392        full_history.extend(
393            partial_history
394                .into_iter()
395                .map(message::Message::try_into)
396                .collect::<Result<Vec<Vec<Message>>, _>>()?
397                .into_iter()
398                .flatten()
399                .collect::<Vec<_>>(),
400        );
401
402        let mut think = false;
403
404        // TODO: Fix this up to include the full range of ollama options
405        let options = if let Some(mut extra) = req.additional_params {
406            if extra.get("think").is_some() {
407                think = extra["think"].take().as_bool().ok_or_else(|| {
408                    CompletionError::RequestError("`think` must be a bool".into())
409                })?;
410            }
411            json_utils::merge(json!({ "temperature": req.temperature }), extra)
412        } else {
413            json!({ "temperature": req.temperature })
414        };
415
416        Ok(Self {
417            model: model.to_string(),
418            messages: full_history,
419            temperature: req.temperature,
420            max_tokens: req.max_tokens,
421            stream: false,
422            think,
423            tools: req
424                .tools
425                .clone()
426                .into_iter()
427                .map(ToolDefinition::from)
428                .collect::<Vec<_>>(),
429            options,
430        })
431    }
432}
433
434#[derive(Clone)]
435pub struct CompletionModel<T = reqwest::Client> {
436    client: Client<T>,
437    pub model: String,
438}
439
440impl<T> CompletionModel<T> {
441    pub fn new(client: Client<T>, model: &str) -> Self {
442        Self {
443            client,
444            model: model.to_owned(),
445        }
446    }
447}
448
449// ---------- CompletionModel Implementation ----------
450
/// Statistics taken from the final (`done`) line of a streamed Ollama chat response.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    /// Why generation stopped, if reported.
    pub done_reason: Option<String>,
    // NOTE(review): Ollama documents the `*_duration` fields in nanoseconds — confirm.
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    /// Prompt tokens evaluated; reported as input tokens by `token_usage`.
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    /// Tokens generated; reported as output tokens by `token_usage`.
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
461
462impl GetTokenUsage for StreamingCompletionResponse {
463    fn token_usage(&self) -> Option<crate::completion::Usage> {
464        let mut usage = crate::completion::Usage::new();
465        let input_tokens = self.prompt_eval_count.unwrap_or_default();
466        let output_tokens = self.eval_count.unwrap_or_default();
467        usage.input_tokens = input_tokens;
468        usage.output_tokens = output_tokens;
469        usage.total_tokens = input_tokens + output_tokens;
470
471        Some(usage)
472    }
473}
474
475impl<T> completion::CompletionModel for CompletionModel<T>
476where
477    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
478{
479    type Response = CompletionResponse;
480    type StreamingResponse = StreamingCompletionResponse;
481
482    type Client = Client<T>;
483
484    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
485        Self::new(client.clone(), model.into().as_str())
486    }
487
488    async fn completion(
489        &self,
490        completion_request: CompletionRequest,
491    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
492        let span = if tracing::Span::current().is_disabled() {
493            info_span!(
494                target: "rig::completions",
495                "chat",
496                gen_ai.operation.name = "chat",
497                gen_ai.provider.name = "ollama",
498                gen_ai.request.model = self.model,
499                gen_ai.system_instructions = tracing::field::Empty,
500                gen_ai.response.id = tracing::field::Empty,
501                gen_ai.response.model = tracing::field::Empty,
502                gen_ai.usage.output_tokens = tracing::field::Empty,
503                gen_ai.usage.input_tokens = tracing::field::Empty,
504            )
505        } else {
506            tracing::Span::current()
507        };
508
509        span.record("gen_ai.system_instructions", &completion_request.preamble);
510        let request = OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
511
512        if tracing::enabled!(tracing::Level::TRACE) {
513            tracing::trace!(target: "rig::completions",
514                "Ollama completion request: {}",
515                serde_json::to_string_pretty(&request)?
516            );
517        }
518
519        let body = serde_json::to_vec(&request)?;
520
521        let req = self
522            .client
523            .post("api/chat")?
524            .body(body)
525            .map_err(http_client::Error::from)?;
526
527        let async_block = async move {
528            let response = self.client.send::<_, Bytes>(req).await?;
529            let status = response.status();
530            let response_body = response.into_body().into_future().await?.to_vec();
531
532            if !status.is_success() {
533                return Err(CompletionError::ProviderError(
534                    String::from_utf8_lossy(&response_body).to_string(),
535                ));
536            }
537
538            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
539            let span = tracing::Span::current();
540            span.record("gen_ai.response.model_name", &response.model);
541            span.record(
542                "gen_ai.usage.input_tokens",
543                response.prompt_eval_count.unwrap_or_default(),
544            );
545            span.record(
546                "gen_ai.usage.output_tokens",
547                response.eval_count.unwrap_or_default(),
548            );
549
550            if tracing::enabled!(tracing::Level::TRACE) {
551                tracing::trace!(target: "rig::completions",
552                    "Ollama completion response: {}",
553                    serde_json::to_string_pretty(&response)?
554                );
555            }
556
557            let response: completion::CompletionResponse<CompletionResponse> =
558                response.try_into()?;
559
560            Ok(response)
561        };
562
563        tracing::Instrument::instrument(async_block, span).await
564    }
565
566    async fn stream(
567        &self,
568        request: CompletionRequest,
569    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
570    {
571        let span = if tracing::Span::current().is_disabled() {
572            info_span!(
573                target: "rig::completions",
574                "chat_streaming",
575                gen_ai.operation.name = "chat_streaming",
576                gen_ai.provider.name = "ollama",
577                gen_ai.request.model = self.model,
578                gen_ai.system_instructions = tracing::field::Empty,
579                gen_ai.response.id = tracing::field::Empty,
580                gen_ai.response.model = self.model,
581                gen_ai.usage.output_tokens = tracing::field::Empty,
582                gen_ai.usage.input_tokens = tracing::field::Empty,
583            )
584        } else {
585            tracing::Span::current()
586        };
587
588        span.record("gen_ai.system_instructions", &request.preamble);
589
590        let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
591        request.stream = true;
592
593        if tracing::enabled!(tracing::Level::TRACE) {
594            tracing::trace!(target: "rig::completions",
595                "Ollama streaming completion request: {}",
596                serde_json::to_string_pretty(&request)?
597            );
598        }
599
600        let body = serde_json::to_vec(&request)?;
601
602        let req = self
603            .client
604            .post("api/chat")?
605            .body(body)
606            .map_err(http_client::Error::from)?;
607
608        let response = self.client.send_streaming(req).await?;
609        let status = response.status();
610        let mut byte_stream = response.into_body();
611
612        if !status.is_success() {
613            return Err(CompletionError::ProviderError(format!(
614                "Got error status code trying to send a request to Ollama: {status}"
615            )));
616        }
617
618        let stream = try_stream! {
619            let span = tracing::Span::current();
620            let mut tool_calls_final = Vec::new();
621            let mut text_response = String::new();
622            let mut thinking_response = String::new();
623
624            while let Some(chunk) = byte_stream.next().await {
625                let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;
626
627                for line in bytes.split(|&b| b == b'\n') {
628                    if line.is_empty() {
629                        continue;
630                    }
631
632                    tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));
633
634                    let response: CompletionResponse = serde_json::from_slice(line)?;
635
636                    if response.done {
637                        span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
638                        span.record("gen_ai.usage.output_tokens", response.eval_count);
639                        let message = Message::Assistant {
640                            content: text_response.clone(),
641                            thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
642                            images: None,
643                            name: None,
644                            tool_calls: tool_calls_final.clone()
645                        };
646                        span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
647                        yield RawStreamingChoice::FinalResponse(
648                            StreamingCompletionResponse {
649                                total_duration: response.total_duration,
650                                load_duration: response.load_duration,
651                                prompt_eval_count: response.prompt_eval_count,
652                                prompt_eval_duration: response.prompt_eval_duration,
653                                eval_count: response.eval_count,
654                                eval_duration: response.eval_duration,
655                                done_reason: response.done_reason,
656                            }
657                        );
658                        break;
659                    }
660
661                    if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
662                        if let Some(thinking_content) = thinking
663                            && !thinking_content.is_empty() {
664                            thinking_response += &thinking_content;
665                            yield RawStreamingChoice::ReasoningDelta {
666                                id: None,
667                                reasoning: thinking_content,
668                            };
669                        }
670
671                        if !content.is_empty() {
672                            text_response += &content;
673                            yield RawStreamingChoice::Message(content);
674                        }
675
676                        for tool_call in tool_calls {
677                            tool_calls_final.push(tool_call.clone());
678                            yield RawStreamingChoice::ToolCall(
679                                crate::streaming::RawStreamingToolCall::new(String::new(), tool_call.function.name, tool_call.function.arguments)
680                            );
681                        }
682                    }
683                }
684            }
685        }.instrument(span);
686
687        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
688            stream,
689        )))
690    }
691}
692
693// ---------- Tool Definition Conversion ----------
694
695/// Ollama-required tool definition format.
696#[derive(Clone, Debug, Deserialize, Serialize)]
697pub struct ToolDefinition {
698    #[serde(rename = "type")]
699    pub type_field: String, // Fixed as "function"
700    pub function: completion::ToolDefinition,
701}
702
703/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
704impl From<crate::completion::ToolDefinition> for ToolDefinition {
705    fn from(tool: crate::completion::ToolDefinition) -> Self {
706        ToolDefinition {
707            type_field: "function".to_owned(),
708            function: completion::ToolDefinition {
709                name: tool.name,
710                description: tool.description,
711                parameters: tool.parameters,
712            },
713        }
714    }
715}
716
/// A tool call requested by the model in an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Serialized as `"type"`; defaults to `function` when absent from the payload.
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    /// The function to invoke and its arguments.
    pub function: Function,
}
/// Kind of tool call; only `function` (serialized lowercase) exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// Function name and JSON arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
734
735// ---------- Provider Message Definition ----------
736
/// A chat message in Ollama's wire format, tagged by its lowercase `role` field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// A user turn. Text goes in `content`; images travel separately in `images`
    /// (base64 data, as produced by the Rig-to-Ollama conversion).
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// An assistant turn, with optional reasoning text (`thinking`) and tool calls.
    Assistant {
        // Defaults to an empty string when the payload omits `content`.
        #[serde(default)]
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // NOTE(review): `json_utils::null_or_vec` presumably maps a JSON `null` to an
        // empty vector — confirm.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// A system prompt (the completion preamble is sent as one of these).
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// The output of a tool invocation, sent with role `tool`. `name` is serialized
    /// as `tool_name`.
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
773
774/// -----------------------------
775/// Provider Message Conversions
776/// -----------------------------
777/// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
778/// (Only User and Assistant variants are supported.)
779impl TryFrom<crate::message::Message> for Vec<Message> {
780    type Error = crate::message::MessageError;
781    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
782        use crate::message::Message as InternalMessage;
783        match internal_msg {
784            InternalMessage::User { content, .. } => {
785                let (tool_results, other_content): (Vec<_>, Vec<_>) =
786                    content.into_iter().partition(|content| {
787                        matches!(content, crate::message::UserContent::ToolResult(_))
788                    });
789
790                if !tool_results.is_empty() {
791                    tool_results
792                        .into_iter()
793                        .map(|content| match content {
794                            crate::message::UserContent::ToolResult(
795                                crate::message::ToolResult { id, content, .. },
796                            ) => {
797                                // Ollama expects a single string for tool results, so we concatenate
798                                let content_string = content
799                                    .into_iter()
800                                    .map(|content| match content {
801                                        crate::message::ToolResultContent::Text(text) => text.text,
802                                        _ => "[Non-text content]".to_string(),
803                                    })
804                                    .collect::<Vec<_>>()
805                                    .join("\n");
806
807                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
808                                    name: id,
809                                    content: content_string,
810                                })
811                            }
812                            _ => unreachable!(),
813                        })
814                        .collect::<Result<Vec<_>, _>>()
815                } else {
816                    // Ollama requires separate text content and images array
817                    let (texts, images) = other_content.into_iter().fold(
818                        (Vec::new(), Vec::new()),
819                        |(mut texts, mut images), content| {
820                            match content {
821                                crate::message::UserContent::Text(crate::message::Text {
822                                    text,
823                                }) => texts.push(text),
824                                crate::message::UserContent::Image(crate::message::Image {
825                                    data: DocumentSourceKind::Base64(data),
826                                    ..
827                                }) => images.push(data),
828                                crate::message::UserContent::Document(
829                                    crate::message::Document {
830                                        data:
831                                            DocumentSourceKind::Base64(data)
832                                            | DocumentSourceKind::String(data),
833                                        ..
834                                    },
835                                ) => texts.push(data),
836                                _ => {} // Audio not supported by Ollama
837                            }
838                            (texts, images)
839                        },
840                    );
841
842                    Ok(vec![Message::User {
843                        content: texts.join(" "),
844                        images: if images.is_empty() {
845                            None
846                        } else {
847                            Some(
848                                images
849                                    .into_iter()
850                                    .map(|x| x.to_string())
851                                    .collect::<Vec<String>>(),
852                            )
853                        },
854                        name: None,
855                    }])
856                }
857            }
858            InternalMessage::Assistant { content, .. } => {
859                let mut thinking: Option<String> = None;
860                let mut text_content = Vec::new();
861                let mut tool_calls = Vec::new();
862
863                for content in content.into_iter() {
864                    match content {
865                        crate::message::AssistantContent::Text(text) => {
866                            text_content.push(text.text)
867                        }
868                        crate::message::AssistantContent::ToolCall(tool_call) => {
869                            tool_calls.push(tool_call)
870                        }
871                        crate::message::AssistantContent::Reasoning(
872                            crate::message::Reasoning { reasoning, .. },
873                        ) => {
874                            thinking = Some(reasoning.first().cloned().unwrap_or(String::new()));
875                        }
876                        crate::message::AssistantContent::Image(_) => {
877                            return Err(crate::message::MessageError::ConversionError(
878                                "Ollama currently doesn't support images.".into(),
879                            ));
880                        }
881                    }
882                }
883
884                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
885                //  so either `content` or `tool_calls` will have some content.
886                Ok(vec![Message::Assistant {
887                    content: text_content.join(" "),
888                    thinking,
889                    images: None,
890                    name: None,
891                    tool_calls: tool_calls
892                        .into_iter()
893                        .map(|tool_call| tool_call.into())
894                        .collect::<Vec<_>>(),
895                }])
896            }
897        }
898    }
899}
900
901/// Conversion from provider Message to a completion message.
902/// This is needed so that responses can be converted back into chat history.
903impl From<Message> for crate::completion::Message {
904    fn from(msg: Message) -> Self {
905        match msg {
906            Message::User { content, .. } => crate::completion::Message::User {
907                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
908                    text: content,
909                })),
910            },
911            Message::Assistant {
912                content,
913                tool_calls,
914                ..
915            } => {
916                let mut assistant_contents =
917                    vec![crate::completion::message::AssistantContent::Text(Text {
918                        text: content,
919                    })];
920                for tc in tool_calls {
921                    assistant_contents.push(
922                        crate::completion::message::AssistantContent::tool_call(
923                            tc.function.name.clone(),
924                            tc.function.name,
925                            tc.function.arguments,
926                        ),
927                    );
928                }
929                crate::completion::Message::Assistant {
930                    id: None,
931                    content: OneOrMany::many(assistant_contents).unwrap(),
932                }
933            }
934            // System and ToolResult are converted to User message as needed.
935            Message::System { content, .. } => crate::completion::Message::User {
936                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
937                    text: content,
938                })),
939            },
940            Message::ToolResult { name, content } => crate::completion::Message::User {
941                content: OneOrMany::one(message::UserContent::tool_result(
942                    name,
943                    OneOrMany::one(message::ToolResultContent::text(content)),
944                )),
945            },
946        }
947    }
948}
949
950impl Message {
951    /// Constructs a system message.
952    pub fn system(content: &str) -> Self {
953        Message::System {
954            content: content.to_owned(),
955            images: None,
956            name: None,
957        }
958    }
959}
960
961// ---------- Additional Message Types ----------
962
963impl From<crate::message::ToolCall> for ToolCall {
964    fn from(tool_call: crate::message::ToolCall) -> Self {
965        Self {
966            r#type: ToolType::Function,
967            function: Function {
968                name: tool_call.function.name,
969                arguments: tool_call.function.arguments,
970            },
971        }
972    }
973}
974
975#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
976pub struct SystemContent {
977    #[serde(default)]
978    r#type: SystemContentType,
979    text: String,
980}
981
982#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
983#[serde(rename_all = "lowercase")]
984pub enum SystemContentType {
985    #[default]
986    Text,
987}
988
989impl From<String> for SystemContent {
990    fn from(s: String) -> Self {
991        SystemContent {
992            r#type: SystemContentType::default(),
993            text: s,
994        }
995    }
996}
997
998impl FromStr for SystemContent {
999    type Err = std::convert::Infallible;
1000    fn from_str(s: &str) -> Result<Self, Self::Err> {
1001        Ok(SystemContent {
1002            r#type: SystemContentType::default(),
1003            text: s.to_string(),
1004        })
1005    }
1006}
1007
1008#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1009pub struct AssistantContent {
1010    pub text: String,
1011}
1012
1013impl FromStr for AssistantContent {
1014    type Err = std::convert::Infallible;
1015    fn from_str(s: &str) -> Result<Self, Self::Err> {
1016        Ok(AssistantContent { text: s.to_owned() })
1017    }
1018}
1019
1020#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1021#[serde(tag = "type", rename_all = "lowercase")]
1022pub enum UserContent {
1023    Text { text: String },
1024    Image { image_url: ImageUrl },
1025    // Audio variant removed as Ollama API does not support audio input.
1026}
1027
1028impl FromStr for UserContent {
1029    type Err = std::convert::Infallible;
1030    fn from_str(s: &str) -> Result<Self, Self::Err> {
1031        Ok(UserContent::Text { text: s.to_owned() })
1032    }
1033}
1034
1035#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1036pub struct ImageUrl {
1037    pub url: String,
1038    #[serde(default)]
1039    pub detail: ImageDetail,
1040}
1041
1042// =================================================================
1043// Tests
1044// =================================================================
1045
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Note: none of these tests perform async work, so they are plain `#[test]` functions
    // rather than `#[tokio::test]`; no runtime needs to be started.

    // Test deserialization and conversion for the /api/chat endpoint.
    #[test]
    fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    // Test conversion from provider Message to completion Message.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // `OneOrMany::first` yields the first (here, the only) content part.
                let first_content = content.first();
                // The expected type is crate::completion::message::UserContent::Text wrapping a Text struct.
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    // Test conversion of internal tool definition to Ollama's ToolDefinition format.
    #[test]
    fn test_tool_definition_conversion() {
        // Internal tool definition from the completion module.
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        // Convert internal tool to Ollama's tool definition.
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        // Check JSON fields in parameters.
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }

    // Test deserialization of chat response with thinking content
    #[test]
    fn test_chat_completion_with_thinking() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The answer is 42.",
                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        // Verify thinking field is present
        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "Let me think about this carefully. The question asks for the meaning of life..."
            );
            assert_eq!(content, "The answer is 42.");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test deserialization of chat response without thinking content
    #[test]
    fn test_chat_completion_without_thinking() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Hello!",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        // Verify thinking field is None when not provided
        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert!(thinking.is_none());
            assert_eq!(content, "Hello!");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test deserialization of streaming response with thinking content
    #[test]
    fn test_streaming_response_with_thinking() {
        let sample_chunk = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "",
                "thinking": "Analyzing the problem...",
                "images": null,
                "tool_calls": []
            },
            "done": false
        });

        let chunk: CompletionResponse =
            serde_json::from_value(sample_chunk).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chunk.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
            assert_eq!(content, "");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test message conversion with thinking content
    #[test]
    fn test_message_conversion_with_thinking() {
        // Create an internal message with reasoning content
        let reasoning_content = crate::message::Reasoning {
            id: None,
            reasoning: vec!["Step 1: Consider the problem".to_string()],
            signature: None,
        };

        let internal_msg = crate::message::Message::Assistant {
            id: None,
            content: crate::OneOrMany::many(vec![
                crate::message::AssistantContent::Reasoning(reasoning_content),
                crate::message::AssistantContent::Text(crate::message::Text {
                    text: "The answer is X".to_string(),
                }),
            ])
            .unwrap(),
        };

        // Convert to provider Message
        let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();
        assert_eq!(provider_msgs.len(), 1);

        if let Message::Assistant {
            thinking, content, ..
        } = &provider_msgs[0]
        {
            assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
            assert_eq!(content, "The answer is X");
        } else {
            panic!("Expected Assistant message with thinking");
        }
    }

    // Test empty thinking content is handled correctly
    #[test]
    fn test_empty_thinking_content() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Response",
                "thinking": "",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            // Empty string should still deserialize as Some("")
            assert_eq!(thinking.as_ref().unwrap(), "");
            assert_eq!(content, "Response");
        } else {
            panic!("Expected Assistant message");
        }
    }

    // Test thinking with tool calls
    #[test]
    fn test_thinking_with_tool_calls() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Let me check the weather.",
                "thinking": "User wants weather info, I should use the weather tool",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {
                                "location": "San Francisco"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 30u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 50u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking,
            content,
            tool_calls,
            ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "User wants weather info, I should use the weather tool"
            );
            assert_eq!(content, "Let me check the weather.");
            assert_eq!(tool_calls.len(), 1);
            assert_eq!(tool_calls[0].function.name, "get_weather");
        } else {
            panic!("Expected Assistant message with thinking and tool calls");
        }
    }
}