rig/providers/
ollama.rs

1//! Ollama API client and Rig integration
2//!
3//! # Example
4//! ```rust
5//! use rig::providers::ollama;
6//!
7//! // Create a new Ollama client (defaults to http://localhost:11434)
8//! let client = ollama::Client::new();
9//!
10//! // Create a completion model interface using, for example, the "llama3.2" model
11//! let comp_model = client.completion_model("llama3.2");
12//!
13//! let req = rig::completion::CompletionRequest {
14//!     preamble: Some("You are now a humorous AI assistant.".to_owned()),
15//!     chat_history: vec![],  // internal messages (if any)
16//!     prompt: rig::message::Message::User {
17//!         content: rig::one_or_many::OneOrMany::one(rig::message::UserContent::text("Please tell me why the sky is blue.")),
18//!         name: None
19//!     },
20//!     temperature: 0.7,
21//!     additional_params: None,
22//!     tools: vec![],
23//! };
24//!
25//! let response = comp_model.completion(req).await.unwrap();
26//! println!("Ollama completion response: {:?}", response.choice);
27//!
28//! // Create an embedding interface using the "all-minilm" model
29//! let emb_model = ollama::Client::new().embedding_model("all-minilm");
30//! let docs = vec![
31//!     "Why is the sky blue?".to_owned(),
32//!     "Why is the grass green?".to_owned()
33//! ];
34//! let embeddings = emb_model.embed_texts(docs).await.unwrap();
35//! println!("Embedding response: {:?}", embeddings);
36//!
37//! // Also create an agent and extractor if needed
38//! let agent = client.agent("llama3.2");
39//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
40//! ```
41use crate::client::{
42    CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError,
43};
44use crate::completion::{GetTokenUsage, Usage};
45use crate::http_client::{self, HttpClientExt};
46use crate::json_utils::merge_inplace;
47use crate::message::DocumentSourceKind;
48use crate::streaming::RawStreamingChoice;
49use crate::{
50    Embed, OneOrMany,
51    completion::{self, CompletionError, CompletionRequest},
52    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
53    impl_conversion_traits, json_utils, message,
54    message::{ImageDetail, Text},
55    streaming,
56};
57use async_stream::try_stream;
58use bytes::Bytes;
59use futures::StreamExt;
60use reqwest;
61// use reqwest_eventsource::{Event, RequestBuilderExt}; // (Not used currently as Ollama does not support SSE)
62use serde::{Deserialize, Serialize};
63use serde_json::{Value, json};
64use std::{convert::TryFrom, str::FromStr};
65use tracing::info_span;
66use tracing_futures::Instrument;
67// ---------- Main Client ----------
68
69const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
70
/// Builder for [`Client`], allowing the base URL and the underlying HTTP
/// transport to be customized before construction.
pub struct ClientBuilder<'a, T = reqwest::Client> {
    /// Base URL of the Ollama server (defaults to `http://localhost:11434`).
    base_url: &'a str,
    /// HTTP transport used for all requests.
    http_client: T,
}
75
76impl<'a, T> ClientBuilder<'a, T>
77where
78    T: Default,
79{
80    #[allow(clippy::new_without_default)]
81    pub fn new() -> Self {
82        Self {
83            base_url: OLLAMA_API_BASE_URL,
84            http_client: Default::default(),
85        }
86    }
87}
88
89impl<'a, T> ClientBuilder<'a, T> {
90    pub fn new_with_client(http_client: T) -> Self {
91        Self {
92            base_url: OLLAMA_API_BASE_URL,
93            http_client,
94        }
95    }
96
97    pub fn base_url(mut self, base_url: &'a str) -> Self {
98        self.base_url = base_url;
99        self
100    }
101
102    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
103        ClientBuilder {
104            base_url: self.base_url,
105            http_client,
106        }
107    }
108
109    pub fn build(self) -> Client<T> {
110        Client {
111            base_url: self.base_url.into(),
112            http_client: self.http_client,
113        }
114    }
115}
116
/// Ollama API client: holds the server base URL and the HTTP transport used
/// to issue requests.
#[derive(Clone, Debug)]
pub struct Client<T = reqwest::Client> {
    /// Base URL of the Ollama server, e.g. `http://localhost:11434`.
    base_url: String,
    /// HTTP transport used to issue requests.
    http_client: T,
}
122
123impl<T> Client<T> {
124    fn req(&self, method: http_client::Method, path: &str) -> http_client::Builder {
125        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
126        http_client::Builder::new().method(method).uri(url)
127    }
128
129    pub(crate) fn post(&self, path: &str) -> http_client::Builder {
130        self.req(http_client::Method::POST, path)
131    }
132
133    pub(crate) fn get(&self, path: &str) -> http_client::Builder {
134        self.req(http_client::Method::GET, path)
135    }
136}
137
138impl Client<reqwest::Client> {
139    pub fn builder<'a>() -> ClientBuilder<'a, reqwest::Client> {
140        ClientBuilder::new()
141    }
142
143    pub fn new() -> Self {
144        Self::builder().build()
145    }
146
147    pub fn from_env() -> Self {
148        <Self as ProviderClient>::from_env()
149    }
150}
151
/// Defaults to a client targeting the local Ollama instance.
impl Default for Client<reqwest::Client> {
    fn default() -> Self {
        Self::new()
    }
}
157
impl<T> ProviderClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    /// Builds a client from the `OLLAMA_API_BASE_URL` environment variable.
    ///
    /// # Panics
    /// Panics if `OLLAMA_API_BASE_URL` is not set.
    fn from_env() -> Self {
        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
        ClientBuilder::new().base_url(&api_base).build()
    }

    /// Builds a client from a provider value. Ollama requires no API key, so
    /// the `Simple` payload is only checked for shape and otherwise ignored;
    /// the resulting client uses the default base URL.
    ///
    /// # Panics
    /// Panics if `input` is not the `Simple` variant.
    fn from_val(input: crate::client::ProviderValue) -> Self {
        let crate::client::ProviderValue::Simple(_) = input else {
            panic!("Incorrect provider value type")
        };

        ClientBuilder::new().build()
    }
}
175
impl<T> CompletionClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    type CompletionModel = CompletionModel<T>;

    /// Creates a completion model handle for the given Ollama model name
    /// (e.g. `"llama3.2"`).
    fn completion_model(&self, model: &str) -> Self::CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
186
impl<T> EmbeddingsClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    type EmbeddingModel = EmbeddingModel<T>;
    /// Creates an embedding model handle for `model`.
    /// NOTE(review): ndims is set to 0 here, presumably meaning
    /// "dimensionality unknown" — confirm against how `ndims()` is consumed.
    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
        EmbeddingModel::new(self.clone(), model, 0)
    }
    /// Creates an embedding model handle with an explicit dimensionality.
    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
        EmbeddingModel::new(self.clone(), model, ndims)
    }
    /// Starts an [`EmbeddingsBuilder`] for batching documents against `model`.
    fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<Self::EmbeddingModel, D> {
        EmbeddingsBuilder::new(self.embedding_model(model))
    }
}
202
impl<T> VerifyClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    /// Verifies connectivity by hitting Ollama's `GET /api/tags` endpoint.
    ///
    /// # Errors
    /// - `InvalidAuthentication` on HTTP 401.
    /// - `ProviderError` (with the response body) on 500/502/503.
    /// - Transport failures bubble up as HTTP errors.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn verify(&self) -> Result<(), VerifyError> {
        let req = self
            .get("api/tags")
            .body(http_client::NoBody)
            .map_err(http_client::Error::from)?;

        let response = HttpClientExt::send(&self.http_client, req).await?;

        match response.status() {
            reqwest::StatusCode::OK => Ok(()),
            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
            reqwest::StatusCode::INTERNAL_SERVER_ERROR
            | reqwest::StatusCode::SERVICE_UNAVAILABLE
            | reqwest::StatusCode::BAD_GATEWAY => {
                let text = http_client::text(response).await?;
                Err(VerifyError::ProviderError(text))
            }
            _ => {
                // Any other status (e.g. 404) is deliberately treated as a
                // successful verification rather than an error.
                Ok(())
            }
        }
    }
}
232
// Ollama exposes no transcription, image-generation, or audio-generation
// endpoints; generate the stub conversions for those capability traits.
impl_conversion_traits!(
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
238
239// ---------- API Error and Response Structures ----------
240
/// Error payload returned by the Ollama API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Untagged wrapper: deserializes either a successful payload of type `T` or
/// an [`ApiErrorResponse`], whichever shape matches first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
252
253// ---------- Embedding API ----------
254
/// Common Ollama embedding model identifiers.
pub const ALL_MINILM: &str = "all-minilm";
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

/// Response body of Ollama's `POST /api/embed` endpoint: one embedding vector
/// per input document, plus optional timing/usage metadata.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    pub embeddings: Vec<Vec<f64>>,
    // Timing/usage fields are optional; older servers may omit them.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
269
270impl From<ApiErrorResponse> for EmbeddingError {
271    fn from(err: ApiErrorResponse) -> Self {
272        EmbeddingError::ProviderError(err.message)
273    }
274}
275
276impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
277    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
278        match value {
279            ApiResponse::Ok(response) => Ok(response),
280            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
281        }
282    }
283}
284
285// ---------- Embedding Model ----------
286
/// Embedding model handle bound to a [`Client`] and a model name.
#[derive(Clone)]
pub struct EmbeddingModel<T> {
    client: Client<T>,
    /// Ollama model identifier, e.g. `"all-minilm"`.
    pub model: String,
    // Embedding dimensionality reported by `ndims()`; 0 when unspecified.
    ndims: usize,
}

impl<T> EmbeddingModel<T> {
    /// Creates an embedding model handle for `model` with dimensionality `ndims`.
    pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.to_owned(),
            ndims,
        }
    }
}
303
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    const MAX_DOCUMENTS: usize = 1024;
    /// Returns the configured embedding dimensionality (0 if unspecified).
    fn ndims(&self) -> usize {
        self.ndims
    }
    /// Embeds a batch of documents via Ollama's `POST /api/embed` endpoint.
    ///
    /// # Errors
    /// Returns an [`EmbeddingError`] if the request fails, the server replies
    /// with a non-success status, the body cannot be parsed, or the number of
    /// returned embeddings does not match the number of input documents.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        // Materialize the input so we can both send it and pair it back up
        // with the returned vectors below.
        let docs: Vec<String> = documents.into_iter().collect();

        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": docs
        }))?;

        let req = self
            .client
            .post("api/embed")
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = HttpClientExt::send(&self.client.http_client, req).await?;

        // Surface the raw body text on non-2xx statuses.
        if !response.status().is_success() {
            let text = http_client::text(response).await?;
            return Err(EmbeddingError::ProviderError(text));
        }

        let bytes: Vec<u8> = response.into_body().await?;

        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;

        // Guard against a malformed server reply before zipping.
        if api_resp.embeddings.len() != docs.len() {
            return Err(EmbeddingError::ResponseError(
                "Number of returned embeddings does not match input".into(),
            ));
        }
        Ok(api_resp
            .embeddings
            .into_iter()
            .zip(docs.into_iter())
            .map(|(vec, document)| embeddings::Embedding { document, vec })
            .collect())
    }
}
355
356// ---------- Completion API ----------
357
/// Common Ollama chat model identifiers.
pub const LLAMA3_2: &str = "llama3.2";
pub const LLAVA: &str = "llava";
pub const MISTRAL: &str = "mistral";

/// Response body of Ollama's `POST /api/chat` endpoint. For streaming
/// requests each NDJSON line deserializes into one of these; the chunk with
/// `done == true` carries the timing/usage metadata.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    /// `true` on the final (or only) response object.
    pub done: bool,
    // Timing/usage fields are only present on the final response object.
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
/// Converts Ollama's chat response into Rig's generic completion response,
/// extracting assistant text, tool calls, and token usage.
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;
    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            // Process only if an assistant message is present.
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                // Add the assistant's text content if any.
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                // Process tool_calls following Ollama's chat response definition.
                // Each ToolCall has an id, a type, and a function field.
                // NOTE(review): the function name is passed for both the id and
                // name arguments — presumably because Ollama returns no call
                // ids; confirm downstream consumers tolerate id == name.
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                // Errors when there is neither text nor any tool call.
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                // Rebuild the raw response; the message was moved out above.
                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                    },
                    raw_response,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}
450
451// ---------- Completion Model ----------
452
/// Completion model handle bound to a [`Client`] and a model name.
#[derive(Clone)]
pub struct CompletionModel<T> {
    client: Client<T>,
    /// Ollama model identifier, e.g. `"llama3.2"`.
    pub model: String,
}
458
459impl<T> CompletionModel<T> {
460    pub fn new(client: Client<T>, model: &str) -> Self {
461        Self {
462            client,
463            model: model.to_owned(),
464        }
465    }
466
467    fn create_completion_request(
468        &self,
469        completion_request: CompletionRequest,
470    ) -> Result<Value, CompletionError> {
471        if completion_request.tool_choice.is_some() {
472            tracing::warn!("WARNING: `tool_choice` not supported for Ollama");
473        }
474
475        // Build up the order of messages (context, chat_history)
476        let mut partial_history = vec![];
477        if let Some(docs) = completion_request.normalized_documents() {
478            partial_history.push(docs);
479        }
480        partial_history.extend(completion_request.chat_history);
481
482        // Initialize full history with preamble (or empty if non-existent)
483        let mut full_history: Vec<Message> = completion_request
484            .preamble
485            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
486
487        // Convert and extend the rest of the history
488        full_history.extend(
489            partial_history
490                .into_iter()
491                .map(|msg| msg.try_into())
492                .collect::<Result<Vec<Vec<Message>>, _>>()?
493                .into_iter()
494                .flatten()
495                .collect::<Vec<Message>>(),
496        );
497
498        let mut request_payload = json!({
499            "model": self.model,
500            "messages": full_history,
501            "stream": false,
502        });
503
504        // Convert internal prompt into a provider Message
505        let options = if let Some(mut extra) = completion_request.additional_params {
506            if extra.get("think").is_some() {
507                request_payload["think"] = extra["think"].take();
508            }
509            json_utils::merge(
510                json!({ "temperature": completion_request.temperature }),
511                extra,
512            )
513        } else {
514            json!({ "temperature": completion_request.temperature })
515        };
516
517        request_payload["options"] = options;
518
519        if !completion_request.tools.is_empty() {
520            request_payload["tools"] = json!(
521                completion_request
522                    .tools
523                    .into_iter()
524                    .map(|tool| tool.into())
525                    .collect::<Vec<ToolDefinition>>()
526            );
527        }
528
529        tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
530
531        Ok(request_payload)
532    }
533}
534
535// ---------- CompletionModel Implementation ----------
536
/// Final metadata emitted at the end of a streaming completion, taken from
/// the terminal (`done == true`) NDJSON chunk of Ollama's chat stream.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
547
548impl GetTokenUsage for StreamingCompletionResponse {
549    fn token_usage(&self) -> Option<crate::completion::Usage> {
550        let mut usage = crate::completion::Usage::new();
551        let input_tokens = self.prompt_eval_count.unwrap_or_default();
552        let output_tokens = self.eval_count.unwrap_or_default();
553        usage.input_tokens = input_tokens;
554        usage.output_tokens = output_tokens;
555        usage.total_tokens = input_tokens + output_tokens;
556
557        Some(usage)
558    }
559}
560
561impl<T> completion::CompletionModel for CompletionModel<T>
562where
563    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
564{
565    type Response = CompletionResponse;
566    type StreamingResponse = StreamingCompletionResponse;
567
568    #[cfg_attr(feature = "worker", worker::send)]
569    async fn completion(
570        &self,
571        completion_request: CompletionRequest,
572    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
573        let preamble = completion_request.preamble.clone();
574        let request = self.create_completion_request(completion_request)?;
575
576        let span = if tracing::Span::current().is_disabled() {
577            info_span!(
578                target: "rig::completions",
579                "chat",
580                gen_ai.operation.name = "chat",
581                gen_ai.provider.name = "ollama",
582                gen_ai.request.model = self.model,
583                gen_ai.system_instructions = preamble,
584                gen_ai.response.id = tracing::field::Empty,
585                gen_ai.response.model = tracing::field::Empty,
586                gen_ai.usage.output_tokens = tracing::field::Empty,
587                gen_ai.usage.input_tokens = tracing::field::Empty,
588                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
589                gen_ai.output.messages = tracing::field::Empty,
590            )
591        } else {
592            tracing::Span::current()
593        };
594
595        let body = serde_json::to_vec(&request)?;
596
597        let req = self
598            .client
599            .post("api/chat")
600            .header("Content-Type", "application/json")
601            .body(body)
602            .map_err(http_client::Error::from)?;
603
604        let async_block = async move {
605            let response = self.client.http_client.send::<_, Bytes>(req).await?;
606            let status = response.status();
607            let response_body = response.into_body().into_future().await?.to_vec();
608
609            if !status.is_success() {
610                return Err(CompletionError::ProviderError(
611                    String::from_utf8_lossy(&response_body).to_string(),
612                ));
613            }
614
615            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
616            let span = tracing::Span::current();
617            span.record("gen_ai.response.model_name", &response.model);
618            span.record(
619                "gen_ai.output.messages",
620                serde_json::to_string(&vec![&response.message]).unwrap(),
621            );
622            span.record(
623                "gen_ai.usage.input_tokens",
624                response.prompt_eval_count.unwrap_or_default(),
625            );
626            span.record(
627                "gen_ai.usage.output_tokens",
628                response.eval_count.unwrap_or_default(),
629            );
630
631            tracing::debug!(target: "rig", "Received response from Ollama: {}", serde_json::to_string_pretty(&response)?);
632
633            let response: completion::CompletionResponse<CompletionResponse> =
634                response.try_into()?;
635
636            Ok(response)
637        };
638
639        tracing::Instrument::instrument(async_block, span).await
640    }
641
642    #[cfg_attr(feature = "worker", worker::send)]
643    async fn stream(
644        &self,
645        request: CompletionRequest,
646    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
647    {
648        let preamble = request.preamble.clone();
649        let mut request = self.create_completion_request(request)?;
650        merge_inplace(&mut request, json!({"stream": true}));
651
652        let span = if tracing::Span::current().is_disabled() {
653            info_span!(
654                target: "rig::completions",
655                "chat_streaming",
656                gen_ai.operation.name = "chat_streaming",
657                gen_ai.provider.name = "ollama",
658                gen_ai.request.model = self.model,
659                gen_ai.system_instructions = preamble,
660                gen_ai.response.id = tracing::field::Empty,
661                gen_ai.response.model = self.model,
662                gen_ai.usage.output_tokens = tracing::field::Empty,
663                gen_ai.usage.input_tokens = tracing::field::Empty,
664                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
665                gen_ai.output.messages = tracing::field::Empty,
666            )
667        } else {
668            tracing::Span::current()
669        };
670
671        let body = serde_json::to_vec(&request)?;
672
673        let req = self
674            .client
675            .post("api/chat")
676            .header("Content-Type", "application/json")
677            .body(body)
678            .map_err(http_client::Error::from)?;
679
680        let response = self.client.http_client.send_streaming(req).await?;
681        let status = response.status();
682        let mut byte_stream = response.into_body();
683
684        if !status.is_success() {
685            return Err(CompletionError::ProviderError(format!(
686                "Got error status code trying to send a request to Ollama: {status}"
687            )));
688        }
689
690        let stream = try_stream! {
691            let span = tracing::Span::current();
692            let mut tool_calls_final = Vec::new();
693            let mut text_response = String::new();
694            let mut thinking_response = String::new();
695
696            while let Some(chunk) = byte_stream.next().await {
697                let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;
698
699                for line in bytes.split(|&b| b == b'\n') {
700                    if line.is_empty() {
701                        continue;
702                    }
703
704                    tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));
705
706                    let response: CompletionResponse = serde_json::from_slice(line)?;
707
708                    if response.done {
709                        span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
710                        span.record("gen_ai.usage.output_tokens", response.eval_count);
711                        let message = Message::Assistant {
712                            content: text_response.clone(),
713                            thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
714                            images: None,
715                            name: None,
716                            tool_calls: tool_calls_final.clone()
717                        };
718                        span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
719                        yield RawStreamingChoice::FinalResponse(
720                            StreamingCompletionResponse {
721                                total_duration: response.total_duration,
722                                load_duration: response.load_duration,
723                                prompt_eval_count: response.prompt_eval_count,
724                                prompt_eval_duration: response.prompt_eval_duration,
725                                eval_count: response.eval_count,
726                                eval_duration: response.eval_duration,
727                                done_reason: response.done_reason,
728                            }
729                        );
730                        break;
731                    }
732
733                    if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
734                        if let Some(thinking_content) = thinking
735                            && !thinking_content.is_empty() {
736                            thinking_response += &thinking_content;
737                            yield RawStreamingChoice::Reasoning {
738                                reasoning: thinking_content,
739                                id: None,
740                                signature: None,
741                            };
742                        }
743
744                        if !content.is_empty() {
745                            text_response += &content;
746                            yield RawStreamingChoice::Message(content);
747                        }
748
749                        for tool_call in tool_calls {
750                            tool_calls_final.push(tool_call.clone());
751                            yield RawStreamingChoice::ToolCall {
752                                id: String::new(),
753                                name: tool_call.function.name,
754                                arguments: tool_call.function.arguments,
755                                call_id: None,
756                            };
757                        }
758                    }
759                }
760            }
761        }.instrument(span);
762
763        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
764            stream,
765        )))
766    }
767}
768
769// ---------- Tool Definition Conversion ----------
770
/// Ollama-required tool definition format:
/// `{"type": "function", "function": {...}}`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Always `"function"`.
    #[serde(rename = "type")]
    pub type_field: String, // Fixed as "function"
    pub function: completion::ToolDefinition,
}
778
779/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
780impl From<crate::completion::ToolDefinition> for ToolDefinition {
781    fn from(tool: crate::completion::ToolDefinition) -> Self {
782        ToolDefinition {
783            type_field: "function".to_owned(),
784            function: completion::ToolDefinition {
785                name: tool.name,
786                description: tool.description,
787                parameters: tool.parameters,
788            },
789        }
790    }
791}
792
/// A tool invocation emitted by the model inside an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    // Defaulted so payloads that omit `type` still deserialize.
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Kind of tool call; Ollama only defines `function`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// Function name plus its JSON arguments as produced by the model.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
810
811// ---------- Provider Message Definition ----------
812
/// Provider-side chat message, tagged by the `role` field in JSON
/// (`user` / `assistant` / `system` / `tool`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        /// Base64-encoded images, per Ollama's multimodal message format.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        // Defaulted: streaming chunks may omit `content` entirely.
        #[serde(default)]
        content: String,
        /// Model "thinking" text when think mode is enabled.
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Tolerates both `null` and a missing field in responses.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Result of a tool invocation, sent back with role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
849
850/// -----------------------------
851/// Provider Message Conversions
852/// -----------------------------
853/// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
854/// (Only User and Assistant variants are supported.)
855impl TryFrom<crate::message::Message> for Vec<Message> {
856    type Error = crate::message::MessageError;
857    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
858        use crate::message::Message as InternalMessage;
859        match internal_msg {
860            InternalMessage::User { content, .. } => {
861                let (tool_results, other_content): (Vec<_>, Vec<_>) =
862                    content.into_iter().partition(|content| {
863                        matches!(content, crate::message::UserContent::ToolResult(_))
864                    });
865
866                if !tool_results.is_empty() {
867                    tool_results
868                        .into_iter()
869                        .map(|content| match content {
870                            crate::message::UserContent::ToolResult(
871                                crate::message::ToolResult { id, content, .. },
872                            ) => {
873                                // Ollama expects a single string for tool results, so we concatenate
874                                let content_string = content
875                                    .into_iter()
876                                    .map(|content| match content {
877                                        crate::message::ToolResultContent::Text(text) => text.text,
878                                        _ => "[Non-text content]".to_string(),
879                                    })
880                                    .collect::<Vec<_>>()
881                                    .join("\n");
882
883                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
884                                    name: id,
885                                    content: content_string,
886                                })
887                            }
888                            _ => unreachable!(),
889                        })
890                        .collect::<Result<Vec<_>, _>>()
891                } else {
892                    // Ollama requires separate text content and images array
893                    let (texts, images) = other_content.into_iter().fold(
894                        (Vec::new(), Vec::new()),
895                        |(mut texts, mut images), content| {
896                            match content {
897                                crate::message::UserContent::Text(crate::message::Text {
898                                    text,
899                                }) => texts.push(text),
900                                crate::message::UserContent::Image(crate::message::Image {
901                                    data: DocumentSourceKind::Base64(data),
902                                    ..
903                                }) => images.push(data),
904                                crate::message::UserContent::Document(
905                                    crate::message::Document {
906                                        data:
907                                            DocumentSourceKind::Base64(data)
908                                            | DocumentSourceKind::String(data),
909                                        ..
910                                    },
911                                ) => texts.push(data),
912                                _ => {} // Audio not supported by Ollama
913                            }
914                            (texts, images)
915                        },
916                    );
917
918                    Ok(vec![Message::User {
919                        content: texts.join(" "),
920                        images: if images.is_empty() {
921                            None
922                        } else {
923                            Some(
924                                images
925                                    .into_iter()
926                                    .map(|x| x.to_string())
927                                    .collect::<Vec<String>>(),
928                            )
929                        },
930                        name: None,
931                    }])
932                }
933            }
934            InternalMessage::Assistant { content, .. } => {
935                let mut thinking: Option<String> = None;
936                let (text_content, tool_calls) = content.into_iter().fold(
937                    (Vec::new(), Vec::new()),
938                    |(mut texts, mut tools), content| {
939                        match content {
940                            crate::message::AssistantContent::Text(text) => texts.push(text.text),
941                            crate::message::AssistantContent::ToolCall(tool_call) => {
942                                tools.push(tool_call)
943                            }
944                            crate::message::AssistantContent::Reasoning(
945                                crate::message::Reasoning { reasoning, .. },
946                            ) => {
947                                thinking =
948                                    Some(reasoning.first().cloned().unwrap_or(String::new()));
949                            }
950                        }
951                        (texts, tools)
952                    },
953                );
954
955                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
956                //  so either `content` or `tool_calls` will have some content.
957                Ok(vec![Message::Assistant {
958                    content: text_content.join(" "),
959                    thinking,
960                    images: None,
961                    name: None,
962                    tool_calls: tool_calls
963                        .into_iter()
964                        .map(|tool_call| tool_call.into())
965                        .collect::<Vec<_>>(),
966                }])
967            }
968        }
969    }
970}
971
972/// Conversion from provider Message to a completion message.
973/// This is needed so that responses can be converted back into chat history.
974impl From<Message> for crate::completion::Message {
975    fn from(msg: Message) -> Self {
976        match msg {
977            Message::User { content, .. } => crate::completion::Message::User {
978                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
979                    text: content,
980                })),
981            },
982            Message::Assistant {
983                content,
984                tool_calls,
985                ..
986            } => {
987                let mut assistant_contents =
988                    vec![crate::completion::message::AssistantContent::Text(Text {
989                        text: content,
990                    })];
991                for tc in tool_calls {
992                    assistant_contents.push(
993                        crate::completion::message::AssistantContent::tool_call(
994                            tc.function.name.clone(),
995                            tc.function.name,
996                            tc.function.arguments,
997                        ),
998                    );
999                }
1000                crate::completion::Message::Assistant {
1001                    id: None,
1002                    content: OneOrMany::many(assistant_contents).unwrap(),
1003                }
1004            }
1005            // System and ToolResult are converted to User message as needed.
1006            Message::System { content, .. } => crate::completion::Message::User {
1007                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
1008                    text: content,
1009                })),
1010            },
1011            Message::ToolResult { name, content } => crate::completion::Message::User {
1012                content: OneOrMany::one(message::UserContent::tool_result(
1013                    name,
1014                    OneOrMany::one(message::ToolResultContent::text(content)),
1015                )),
1016            },
1017        }
1018    }
1019}
1020
1021impl Message {
1022    /// Constructs a system message.
1023    pub fn system(content: &str) -> Self {
1024        Message::System {
1025            content: content.to_owned(),
1026            images: None,
1027            name: None,
1028        }
1029    }
1030}
1031
1032// ---------- Additional Message Types ----------
1033
1034impl From<crate::message::ToolCall> for ToolCall {
1035    fn from(tool_call: crate::message::ToolCall) -> Self {
1036        Self {
1037            r#type: ToolType::Function,
1038            function: Function {
1039                name: tool_call.function.name,
1040                arguments: tool_call.function.arguments,
1041            },
1042        }
1043    }
1044}
1045
/// Typed text payload for system content.
///
/// NOTE(review): not referenced by the provider `Message` enum above, which
/// uses plain `String` content — confirm this type is still needed.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    // Defaults to `text` when the field is absent from the JSON.
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}
1052
/// Discriminator for [`SystemContent`]; only plain text is supported.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
1059
1060impl From<String> for SystemContent {
1061    fn from(s: String) -> Self {
1062        SystemContent {
1063            r#type: SystemContentType::default(),
1064            text: s,
1065        }
1066    }
1067}
1068
1069impl FromStr for SystemContent {
1070    type Err = std::convert::Infallible;
1071    fn from_str(s: &str) -> Result<Self, Self::Err> {
1072        Ok(SystemContent {
1073            r#type: SystemContentType::default(),
1074            text: s.to_string(),
1075        })
1076    }
1077}
1078
/// Plain-text assistant content.
///
/// NOTE(review): not referenced by the provider `Message` enum above —
/// confirm this type is still needed.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}
1083
1084impl FromStr for AssistantContent {
1085    type Err = std::convert::Infallible;
1086    fn from_str(s: &str) -> Result<Self, Self::Err> {
1087        Ok(AssistantContent { text: s.to_owned() })
1088    }
1089}
1090
/// User-side content parts, tagged by `type` in JSON.
///
/// NOTE(review): not referenced by the provider `Message` enum above —
/// confirm this type is still needed.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
    // Audio variant removed as Ollama API does not support audio input.
}
1098
1099impl FromStr for UserContent {
1100    type Err = std::convert::Infallible;
1101    fn from_str(s: &str) -> Result<Self, Self::Err> {
1102        Ok(UserContent::Text { text: s.to_owned() })
1103    }
1104}
1105
/// An image reference with an optional detail level.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    // Falls back to `ImageDetail::default()` when absent from the JSON.
    #[serde(default)]
    pub detail: ImageDetail,
}
1112
1113// =================================================================
1114// Tests
1115// =================================================================
1116
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // All deserialization/conversion tests here are purely synchronous, so
    // they use plain `#[test]`. Several were previously `#[tokio::test]
    // async fn` without a single `.await`, spinning up a Tokio runtime for
    // no reason.

    /// Deserializes a sample non-streaming `/api/chat` response and checks
    /// that conversion into the generic completion response yields a choice.
    #[test]
    fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    /// Converts a provider `Message::User` into a completion message and
    /// checks the text survives the round trip.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // `OneOrMany::first` yields the guaranteed first element.
                let first_content = content.first();
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    /// Converts an internal tool definition into Ollama's tool format and
    /// checks type, name, description, and parameter schema.
    #[test]
    fn test_tool_definition_conversion() {
        // Internal tool definition from the completion module.
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        // Convert internal tool to Ollama's tool definition.
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        // Check JSON fields in parameters.
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }

    /// Deserializes a chat response that includes `thinking` content.
    #[test]
    fn test_chat_completion_with_thinking() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The answer is 42.",
                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        // Verify thinking field is present
        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "Let me think about this carefully. The question asks for the meaning of life..."
            );
            assert_eq!(content, "The answer is 42.");
        } else {
            panic!("Expected Assistant message");
        }
    }

    /// Deserializes a chat response without `thinking`; the field must be
    /// `None`, not an empty string.
    #[test]
    fn test_chat_completion_without_thinking() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Hello!",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        // Verify thinking field is None when not provided
        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert!(thinking.is_none());
            assert_eq!(content, "Hello!");
        } else {
            panic!("Expected Assistant message");
        }
    }

    /// Deserializes a streaming chunk (`done: false`) that carries only
    /// thinking content and no text yet.
    #[test]
    fn test_streaming_response_with_thinking() {
        let sample_chunk = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "",
                "thinking": "Analyzing the problem...",
                "images": null,
                "tool_calls": []
            },
            "done": false
        });

        let chunk: CompletionResponse =
            serde_json::from_value(sample_chunk).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chunk.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
            assert_eq!(content, "");
        } else {
            panic!("Expected Assistant message");
        }
    }

    /// Converts an internal assistant message containing reasoning content
    /// into a provider message; reasoning maps to `thinking`.
    #[test]
    fn test_message_conversion_with_thinking() {
        // Create an internal message with reasoning content
        let reasoning_content = crate::message::Reasoning {
            id: None,
            reasoning: vec!["Step 1: Consider the problem".to_string()],
            signature: None,
        };

        let internal_msg = crate::message::Message::Assistant {
            id: None,
            content: crate::OneOrMany::many(vec![
                crate::message::AssistantContent::Reasoning(reasoning_content),
                crate::message::AssistantContent::Text(crate::message::Text {
                    text: "The answer is X".to_string(),
                }),
            ])
            .unwrap(),
        };

        // Convert to provider Message
        let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();
        assert_eq!(provider_msgs.len(), 1);

        if let Message::Assistant {
            thinking, content, ..
        } = &provider_msgs[0]
        {
            assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
            assert_eq!(content, "The answer is X");
        } else {
            panic!("Expected Assistant message with thinking");
        }
    }

    /// An explicit empty-string `thinking` must deserialize as `Some("")`,
    /// distinct from the field being absent.
    #[test]
    fn test_empty_thinking_content() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Response",
                "thinking": "",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            // Empty string should still deserialize as Some("")
            assert_eq!(thinking.as_ref().unwrap(), "");
            assert_eq!(content, "Response");
        } else {
            panic!("Expected Assistant message");
        }
    }

    /// Thinking content and tool calls may appear together in one response.
    #[test]
    fn test_thinking_with_tool_calls() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Let me check the weather.",
                "thinking": "User wants weather info, I should use the weather tool",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {
                                "location": "San Francisco"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 30u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 50u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking,
            content,
            tool_calls,
            ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "User wants weather info, I should use the weather tool"
            );
            assert_eq!(content, "Let me check the weather.");
            assert_eq!(tool_calls.len(), 1);
            assert_eq!(tool_calls[0].function.name, "get_weather");
        } else {
            panic!("Expected Assistant message with thinking and tool calls");
        }
    }
}