rig/providers/
ollama.rs

//! Ollama API client and Rig integration
//!
//! # Example
//! ```rust,ignore
//! use rig::providers::ollama;
//!
//! // Create a new Ollama client (defaults to http://localhost:11434)
//! let client = ollama::Client::new();
//!
//! // Create a completion model interface using, for example, the "llama3.2" model
//! let comp_model = client.completion_model("llama3.2");
//!
//! let req = rig::completion::CompletionRequest {
//!     preamble: Some("You are now a humorous AI assistant.".to_owned()),
//!     chat_history: vec![],  // internal messages (if any)
//!     prompt: rig::message::Message::User {
//!         content: rig::one_or_many::OneOrMany::one(rig::message::UserContent::text("Please tell me why the sky is blue.")),
//!         name: None
//!     },
//!     temperature: 0.7,
//!     additional_params: None,
//!     tools: vec![],
//! };
//!
//! let response = comp_model.completion(req).await.unwrap();
//! println!("Ollama completion response: {:?}", response.choice);
//!
//! // Create an embedding interface using the "all-minilm" model
//! let emb_model = ollama::Client::new().embedding_model("all-minilm");
//! let docs = vec![
//!     "Why is the sky blue?".to_owned(),
//!     "Why is the grass green?".to_owned()
//! ];
//! let embeddings = emb_model.embed_texts(docs).await.unwrap();
//! println!("Embedding response: {:?}", embeddings);
//!
//! // Also create an agent and extractor if needed
//! let agent = client.agent("llama3.2");
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
//! ```
41use crate::client::{
42    CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError,
43};
44use crate::completion::{GetTokenUsage, Usage};
45use crate::http_client::{self, HttpClientExt};
46use crate::json_utils::merge_inplace;
47use crate::message::DocumentSourceKind;
48use crate::streaming::RawStreamingChoice;
49use crate::{
50    Embed, OneOrMany,
51    completion::{self, CompletionError, CompletionRequest},
52    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
53    impl_conversion_traits, json_utils, message,
54    message::{ImageDetail, Text},
55    streaming,
56};
57use async_stream::try_stream;
58use futures::StreamExt;
59use reqwest;
60// use reqwest_eventsource::{Event, RequestBuilderExt}; // (Not used currently as Ollama does not support SSE)
61use serde::{Deserialize, Serialize};
62use serde_json::{Value, json};
63use std::{convert::TryFrom, str::FromStr};
64use tracing::info_span;
65use tracing_futures::Instrument;
66// ---------- Main Client ----------
67
68const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
69
/// Builder for [`Client`]; `T` is the HTTP transport type
/// (defaults to `reqwest::Client`).
pub struct ClientBuilder<'a, T = reqwest::Client> {
    // Borrowed so the default URL constant can be used without allocating.
    base_url: &'a str,
    http_client: T,
}
74
75impl<'a, T> ClientBuilder<'a, T>
76where
77    T: Default,
78{
79    #[allow(clippy::new_without_default)]
80    pub fn new() -> Self {
81        Self {
82            base_url: OLLAMA_API_BASE_URL,
83            http_client: Default::default(),
84        }
85    }
86}
87
88impl<'a, T> ClientBuilder<'a, T> {
89    pub fn base_url(mut self, base_url: &'a str) -> Self {
90        self.base_url = base_url;
91        self
92    }
93
94    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
95        ClientBuilder {
96            base_url: self.base_url,
97            http_client,
98        }
99    }
100
101    pub fn build(self) -> Client<T> {
102        Client {
103            base_url: self.base_url.into(),
104            http_client: self.http_client,
105        }
106    }
107}
108
/// Ollama API client, generic over the HTTP transport `T`
/// (defaults to `reqwest::Client`).
#[derive(Clone, Debug)]
pub struct Client<T = reqwest::Client> {
    // Base URL of the Ollama server; request paths are joined onto it with '/'.
    base_url: String,
    // Transport used to issue requests.
    http_client: T,
}
114
impl<T> Default for Client<T>
where
    T: Default,
{
    /// Equivalent to [`Client::new`]: default transport, default base URL.
    fn default() -> Self {
        Self::new()
    }
}
123
impl<T> Client<T>
where
    T: Default,
{
    /// Create a new Ollama client builder.
    ///
    /// # Example
    /// ```no_run
    /// use rig::providers::ollama::Client;
    ///
    /// // Initialize the Ollama client (defaults to http://localhost:11434)
    /// let client: Client = Client::builder().build();
    /// ```
    pub fn builder<'a>() -> ClientBuilder<'a, T> {
        ClientBuilder::new()
    }

    /// Create a new Ollama client. For more control, use the `builder` method.
    ///
    /// # Panics
    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
    pub fn new() -> Self {
        Self::builder().build()
    }
}
150
151impl<T> Client<T> {
152    fn req(&self, method: http_client::Method, path: &str) -> http_client::Builder {
153        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
154        http_client::Builder::new().method(method).uri(url)
155    }
156
157    pub(crate) fn post(&self, path: &str) -> http_client::Builder {
158        self.req(http_client::Method::POST, path)
159    }
160
161    pub(crate) fn get(&self, path: &str) -> http_client::Builder {
162        self.req(http_client::Method::GET, path)
163    }
164}
165
166impl Client<reqwest::Client> {
167    fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
168        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
169        self.http_client.post(url)
170    }
171}
172
173impl ProviderClient for Client<reqwest::Client> {
174    fn from_env() -> Self {
175        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
176        Self::builder().base_url(&api_base).build()
177    }
178
179    fn from_val(input: crate::client::ProviderValue) -> Self {
180        let crate::client::ProviderValue::Simple(_) = input else {
181            panic!("Incorrect provider value type")
182        };
183
184        Self::new()
185    }
186}
187
impl CompletionClient for Client<reqwest::Client> {
    type CompletionModel = CompletionModel<reqwest::Client>;

    /// Create a completion model handle for `model` (e.g. "llama3.2").
    fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
        CompletionModel::new(self.clone(), model)
    }
}
195
impl EmbeddingsClient for Client<reqwest::Client> {
    type EmbeddingModel = EmbeddingModel<reqwest::Client>;
    /// Create an embedding model handle; dimensionality defaults to 0
    /// (unknown) — use `embedding_model_with_ndims` to set it explicitly.
    fn embedding_model(&self, model: &str) -> EmbeddingModel<reqwest::Client> {
        EmbeddingModel::new(self.clone(), model, 0)
    }
    /// Create an embedding model handle with an explicit dimensionality.
    fn embedding_model_with_ndims(
        &self,
        model: &str,
        ndims: usize,
    ) -> EmbeddingModel<reqwest::Client> {
        EmbeddingModel::new(self.clone(), model, ndims)
    }
    /// Start an embeddings builder targeting `model`.
    fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<Self::EmbeddingModel, D> {
        EmbeddingsBuilder::new(self.embedding_model(model))
    }
}
212
213impl VerifyClient for Client<reqwest::Client> {
214    #[cfg_attr(feature = "worker", worker::send)]
215    async fn verify(&self) -> Result<(), VerifyError> {
216        let req = self
217            .get("api/tags")
218            .body(http_client::NoBody)
219            .map_err(http_client::Error::from)?;
220
221        let response = HttpClientExt::send(&self.http_client, req).await?;
222
223        match response.status() {
224            reqwest::StatusCode::OK => Ok(()),
225            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
226            reqwest::StatusCode::INTERNAL_SERVER_ERROR
227            | reqwest::StatusCode::SERVICE_UNAVAILABLE
228            | reqwest::StatusCode::BAD_GATEWAY => {
229                let text = http_client::text(response).await?;
230                Err(VerifyError::ProviderError(text))
231            }
232            _ => {
233                //response.error_for_status()?;
234                Ok(())
235            }
236        }
237    }
238}
239
// Ollama exposes no transcription, image-generation, or audio-generation
// APIs; this macro presumably emits the corresponding capability-conversion
// impls for Client<T> (macro body defined elsewhere — see impl_conversion_traits).
impl_conversion_traits!(
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
245
246// ---------- API Error and Response Structures ----------
247
/// Error payload returned by the Ollama API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
252
/// Untagged response envelope: deserialization first tries the success
/// shape `T`, then falls back to the error payload.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
259
260// ---------- Embedding API ----------
261
262pub const ALL_MINILM: &str = "all-minilm";
263pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
264
/// Response shape of Ollama's `api/embed` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    /// One embedding vector per input document, in request order
    /// (length is validated against the inputs in `embed_texts`).
    pub embeddings: Vec<Vec<f64>>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
276
277impl From<ApiErrorResponse> for EmbeddingError {
278    fn from(err: ApiErrorResponse) -> Self {
279        EmbeddingError::ProviderError(err.message)
280    }
281}
282
283impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
284    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
285        match value {
286            ApiResponse::Ok(response) => Ok(response),
287            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
288        }
289    }
290}
291
292// ---------- Embedding Model ----------
293
/// Handle to an Ollama embedding model, generic over the HTTP transport.
#[derive(Clone)]
pub struct EmbeddingModel<T> {
    client: Client<T>,
    /// Model name as understood by the Ollama server (e.g. "all-minilm").
    pub model: String,
    // Dimensionality reported by `ndims()`; 0 when unspecified.
    ndims: usize,
}
300
301impl<T> EmbeddingModel<T> {
302    pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
303        Self {
304            client,
305            model: model.to_owned(),
306            ndims,
307        }
308    }
309}
310
impl embeddings::EmbeddingModel for EmbeddingModel<reqwest::Client> {
    /// Maximum number of documents accepted per `embed_texts` call.
    const MAX_DOCUMENTS: usize = 1024;
    /// Dimensionality as configured at construction time (0 if unspecified).
    fn ndims(&self) -> usize {
        self.ndims
    }
    /// POST the documents to Ollama's `api/embed` endpoint and pair each
    /// returned vector with its source document.
    ///
    /// # Errors
    /// - `EmbeddingError::ProviderError` on a non-success HTTP status.
    /// - `EmbeddingError::ResponseError` when the server returns a different
    ///   number of embeddings than documents sent.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        // Collect up front: the count is needed to validate the response.
        let docs: Vec<String> = documents.into_iter().collect();

        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": docs
        }))?;

        let req = self
            .client
            .post("api/embed")
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = HttpClientExt::send(&self.client.http_client, req).await?;

        if !response.status().is_success() {
            let text = http_client::text(response).await?;
            return Err(EmbeddingError::ProviderError(text));
        }

        let bytes: Vec<u8> = response.into_body().await?;

        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;

        if api_resp.embeddings.len() != docs.len() {
            return Err(EmbeddingError::ResponseError(
                "Number of returned embeddings does not match input".into(),
            ));
        }
        // Zip assumes the server returns embeddings in request order.
        Ok(api_resp
            .embeddings
            .into_iter()
            .zip(docs.into_iter())
            .map(|(vec, document)| embeddings::Embedding { document, vec })
            .collect())
    }
}
359
360// ---------- Completion API ----------
361
362pub const LLAMA3_2: &str = "llama3.2";
363pub const LLAVA: &str = "llava";
364pub const MISTRAL: &str = "mistral";
365
/// Response shape of Ollama's `api/chat` endpoint. Also used to parse each
/// NDJSON frame in streaming mode, where `done` stays `false` until the
/// final frame carrying the eval counters.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    /// `true` on the final (or only) response object.
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    /// Prompt tokens evaluated — mapped to input token usage.
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    /// Tokens generated — mapped to output token usage.
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
/// Convert Ollama's raw chat response into Rig's provider-agnostic
/// completion response, extracting text and tool calls from the assistant
/// message and mapping the eval counters onto token usage.
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;
    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            // Process only if an assistant message is present.
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                // Add the assistant's text content if any.
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                // Process tool_calls following Ollama's chat response definition.
                // Ollama supplies no tool-call ids, so the function name doubles
                // as the call id (first argument).
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                // Empty text AND no tool calls => nothing to return.
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                // Rebuild the raw response: the assistant message was moved
                // out of `resp` by the match above.
                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                    },
                    raw_response,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}
454
455// ---------- Completion Model ----------
456
/// Handle to an Ollama chat/completion model, generic over the HTTP transport.
#[derive(Clone)]
pub struct CompletionModel<T> {
    client: Client<T>,
    /// Model name as understood by the Ollama server (e.g. "llama3.2").
    pub model: String,
}
462
impl<T> CompletionModel<T> {
    /// Create a completion model handle for `model` on the given client.
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_owned(),
        }
    }

    /// Assemble the JSON payload for Ollama's `api/chat` endpoint:
    /// preamble + normalized documents + chat history as `messages`,
    /// temperature (merged with any `additional_params`) under `options`,
    /// plus tool definitions when present. `stream` is set to `false`;
    /// the streaming path flips it afterwards.
    fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Ollama's chat API has no tool_choice parameter; warn instead of failing.
        if completion_request.tool_choice.is_some() {
            tracing::warn!("WARNING: `tool_choice` not supported for Ollama");
        }

        // Build up the order of messages (context, chat_history)
        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);

        // Initialize full history with preamble (or empty if non-existent)
        let mut full_history: Vec<Message> = completion_request
            .preamble
            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);

        // Convert and extend the rest of the history; each internal message
        // may expand into several provider messages (e.g. one per tool result).
        full_history.extend(
            partial_history
                .into_iter()
                .map(|msg| msg.try_into())
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<Message>>(),
        );

        // Merge the temperature with any caller-supplied additional params
        // (the caller's values win on key collisions — see json_utils::merge).
        let options = if let Some(extra) = completion_request.additional_params {
            json_utils::merge(
                json!({ "temperature": completion_request.temperature }),
                extra,
            )
        } else {
            json!({ "temperature": completion_request.temperature })
        };

        let mut request_payload = json!({
            "model": self.model,
            "messages": full_history,
            "options": options,
            "stream": false,
        });
        // Ollama rejects an empty tools array, so only attach it when non-empty.
        // NOTE(review): the "rejects" part is an assumption — confirm against the API.
        if !completion_request.tools.is_empty() {
            request_payload["tools"] = json!(
                completion_request
                    .tools
                    .into_iter()
                    .map(|tool| tool.into())
                    .collect::<Vec<ToolDefinition>>()
            );
        }

        tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);

        Ok(request_payload)
    }
}
533
534// ---------- CompletionModel Implementation ----------
535
/// Metadata extracted from the final (`done == true`) frame of a
/// streaming chat response.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    /// Prompt tokens evaluated — mapped to input token usage.
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    /// Tokens generated — mapped to output token usage.
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
546
547impl GetTokenUsage for StreamingCompletionResponse {
548    fn token_usage(&self) -> Option<crate::completion::Usage> {
549        let mut usage = crate::completion::Usage::new();
550        let input_tokens = self.prompt_eval_count.unwrap_or_default();
551        let output_tokens = self.eval_count.unwrap_or_default();
552        usage.input_tokens = input_tokens;
553        usage.output_tokens = output_tokens;
554        usage.total_tokens = input_tokens + output_tokens;
555
556        Some(usage)
557    }
558}
559
560impl completion::CompletionModel for CompletionModel<reqwest::Client> {
561    type Response = CompletionResponse;
562    type StreamingResponse = StreamingCompletionResponse;
563
564    #[cfg_attr(feature = "worker", worker::send)]
565    async fn completion(
566        &self,
567        completion_request: CompletionRequest,
568    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
569        let preamble = completion_request.preamble.clone();
570        let request = self.create_completion_request(completion_request)?;
571
572        let span = if tracing::Span::current().is_disabled() {
573            info_span!(
574                target: "rig::completions",
575                "chat",
576                gen_ai.operation.name = "chat",
577                gen_ai.provider.name = "ollama",
578                gen_ai.request.model = self.model,
579                gen_ai.system_instructions = preamble,
580                gen_ai.response.id = tracing::field::Empty,
581                gen_ai.response.model = tracing::field::Empty,
582                gen_ai.usage.output_tokens = tracing::field::Empty,
583                gen_ai.usage.input_tokens = tracing::field::Empty,
584                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
585                gen_ai.output.messages = tracing::field::Empty,
586            )
587        } else {
588            tracing::Span::current()
589        };
590
591        let async_block = async move {
592            let response = self
593                .client
594                .reqwest_post("api/chat")
595                .json(&request)
596                .send()
597                .await
598                .map_err(|e| http_client::Error::Instance(e.into()))?;
599
600            if !response.status().is_success() {
601                return Err(CompletionError::ProviderError(
602                    response
603                        .text()
604                        .await
605                        .map_err(|e| http_client::Error::Instance(e.into()))?,
606                ));
607            }
608
609            let bytes = response
610                .bytes()
611                .await
612                .map_err(|e| http_client::Error::Instance(e.into()))?;
613
614            tracing::debug!(target: "rig", "Received response from Ollama: {}", String::from_utf8_lossy(&bytes));
615
616            let response: CompletionResponse = serde_json::from_slice(&bytes)?;
617            let span = tracing::Span::current();
618            span.record("gen_ai.response.model_name", &response.model);
619            span.record(
620                "gen_ai.output.messages",
621                serde_json::to_string(&vec![&response.message]).unwrap(),
622            );
623            span.record(
624                "gen_ai.usage.input_tokens",
625                response.prompt_eval_count.unwrap_or_default(),
626            );
627            span.record(
628                "gen_ai.usage.output_tokens",
629                response.eval_count.unwrap_or_default(),
630            );
631
632            let response: completion::CompletionResponse<CompletionResponse> =
633                response.try_into()?;
634
635            Ok(response)
636        };
637
638        tracing::Instrument::instrument(async_block, span).await
639    }
640
641    #[cfg_attr(feature = "worker", worker::send)]
642    async fn stream(
643        &self,
644        request: CompletionRequest,
645    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
646    {
647        let preamble = request.preamble.clone();
648        let mut request = self.create_completion_request(request)?;
649        merge_inplace(&mut request, json!({"stream": true}));
650
651        let span = if tracing::Span::current().is_disabled() {
652            info_span!(
653                target: "rig::completions",
654                "chat_streaming",
655                gen_ai.operation.name = "chat_streaming",
656                gen_ai.provider.name = "ollama",
657                gen_ai.request.model = self.model,
658                gen_ai.system_instructions = preamble,
659                gen_ai.response.id = tracing::field::Empty,
660                gen_ai.response.model = self.model,
661                gen_ai.usage.output_tokens = tracing::field::Empty,
662                gen_ai.usage.input_tokens = tracing::field::Empty,
663                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
664                gen_ai.output.messages = tracing::field::Empty,
665            )
666        } else {
667            tracing::Span::current()
668        };
669
670        let response = self
671            .client
672            .reqwest_post("api/chat")
673            .json(&request)
674            .send()
675            .await
676            .map_err(|e| http_client::Error::Instance(e.into()))?;
677
678        if !response.status().is_success() {
679            return Err(CompletionError::ProviderError(
680                response
681                    .text()
682                    .await
683                    .map_err(|e| http_client::Error::Instance(e.into()))?,
684            ));
685        }
686
687        let stream = try_stream! {
688            let span = tracing::Span::current();
689            let mut byte_stream = response.bytes_stream();
690            let mut tool_calls_final = Vec::new();
691            let mut text_response = String::new();
692
693            while let Some(chunk) = byte_stream.next().await {
694                let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;
695
696                for line in bytes.split(|&b| b == b'\n') {
697                    if line.is_empty() {
698                        continue;
699                    }
700
701                    tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));
702
703                    let response: CompletionResponse = serde_json::from_slice(line)?;
704
705                    if response.done {
706                        span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
707                        span.record("gen_ai.usage.output_tokens", response.eval_count);
708                        let message = Message::Assistant {
709                            content: text_response.clone(),
710                            thinking: None,
711                            images: None,
712                            name: None,
713                            tool_calls: tool_calls_final.clone()
714                        };
715                        span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
716                        yield RawStreamingChoice::FinalResponse(
717                            StreamingCompletionResponse {
718                                total_duration: response.total_duration,
719                                load_duration: response.load_duration,
720                                prompt_eval_count: response.prompt_eval_count,
721                                prompt_eval_duration: response.prompt_eval_duration,
722                                eval_count: response.eval_count,
723                                eval_duration: response.eval_duration,
724                                done_reason: response.done_reason,
725                            }
726                        );
727                        break;
728                    }
729
730                    if let Message::Assistant { content, tool_calls, .. } = response.message {
731                        if !content.is_empty() {
732                            text_response += &content;
733                            yield RawStreamingChoice::Message(content);
734                        }
735                        for tool_call in tool_calls {
736                            tool_calls_final.push(tool_call.clone());
737                            yield RawStreamingChoice::ToolCall {
738                                id: String::new(),
739                                name: tool_call.function.name,
740                                arguments: tool_call.function.arguments,
741                                call_id: None,
742                            };
743                        }
744                    }
745                }
746            }
747        }.instrument(span);
748
749        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
750            stream,
751        )))
752    }
753}
754
755// ---------- Tool Definition Conversion ----------
756
/// Ollama-required tool definition format: a `"function"` type tag wrapping
/// the internal tool definition.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Always set to `"function"` (see the `From` conversion below).
    #[serde(rename = "type")]
    pub type_field: String, // Fixed as "function"
    pub function: completion::ToolDefinition,
}
764
765/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
766impl From<crate::completion::ToolDefinition> for ToolDefinition {
767    fn from(tool: crate::completion::ToolDefinition) -> Self {
768        ToolDefinition {
769            type_field: "function".to_owned(),
770            function: completion::ToolDefinition {
771                name: tool.name,
772                description: tool.description,
773                parameters: tool.parameters,
774            },
775        }
776    }
777}
778
/// Tool invocation emitted by the model in an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Missing `type` in the payload defaults to `function`.
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Kind of tool call; Ollama currently only defines `function`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// Function name and JSON arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    /// Arguments as an arbitrary JSON value.
    pub arguments: Value,
}
796
797// ---------- Provider Message Definition ----------
798
/// Message shape expected by Ollama's chat API (`role`-tagged JSON).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// End-user message; images travel as a separate base64 array.
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Model reply; may carry text, "thinking" output, and tool calls.
    Assistant {
        #[serde(default)]
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // `tool_calls` may arrive as JSON null; deserialize that as an empty vec.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// System/preamble message.
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Tool execution result, serialized with role "tool"; `name` maps to
    /// the wire field `tool_name`.
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
835
836/// -----------------------------
837/// Provider Message Conversions
838/// -----------------------------
839/// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
840/// (Only User and Assistant variants are supported.)
841impl TryFrom<crate::message::Message> for Vec<Message> {
842    type Error = crate::message::MessageError;
843    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
844        use crate::message::Message as InternalMessage;
845        match internal_msg {
846            InternalMessage::User { content, .. } => {
847                let (tool_results, other_content): (Vec<_>, Vec<_>) =
848                    content.into_iter().partition(|content| {
849                        matches!(content, crate::message::UserContent::ToolResult(_))
850                    });
851
852                if !tool_results.is_empty() {
853                    tool_results
854                        .into_iter()
855                        .map(|content| match content {
856                            crate::message::UserContent::ToolResult(
857                                crate::message::ToolResult { id, content, .. },
858                            ) => {
859                                // Ollama expects a single string for tool results, so we concatenate
860                                let content_string = content
861                                    .into_iter()
862                                    .map(|content| match content {
863                                        crate::message::ToolResultContent::Text(text) => text.text,
864                                        _ => "[Non-text content]".to_string(),
865                                    })
866                                    .collect::<Vec<_>>()
867                                    .join("\n");
868
869                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
870                                    name: id,
871                                    content: content_string,
872                                })
873                            }
874                            _ => unreachable!(),
875                        })
876                        .collect::<Result<Vec<_>, _>>()
877                } else {
878                    // Ollama requires separate text content and images array
879                    let (texts, images) = other_content.into_iter().fold(
880                        (Vec::new(), Vec::new()),
881                        |(mut texts, mut images), content| {
882                            match content {
883                                crate::message::UserContent::Text(crate::message::Text {
884                                    text,
885                                }) => texts.push(text),
886                                crate::message::UserContent::Image(crate::message::Image {
887                                    data: DocumentSourceKind::Base64(data),
888                                    ..
889                                }) => images.push(data),
890                                crate::message::UserContent::Document(
891                                    crate::message::Document {
892                                        data:
893                                            DocumentSourceKind::Base64(data)
894                                            | DocumentSourceKind::String(data),
895                                        ..
896                                    },
897                                ) => texts.push(data),
898                                _ => {} // Audio not supported by Ollama
899                            }
900                            (texts, images)
901                        },
902                    );
903
904                    Ok(vec![Message::User {
905                        content: texts.join(" "),
906                        images: if images.is_empty() {
907                            None
908                        } else {
909                            Some(
910                                images
911                                    .into_iter()
912                                    .map(|x| x.to_string())
913                                    .collect::<Vec<String>>(),
914                            )
915                        },
916                        name: None,
917                    }])
918                }
919            }
920            InternalMessage::Assistant { content, .. } => {
921                let mut thinking: Option<String> = None;
922                let (text_content, tool_calls) = content.into_iter().fold(
923                    (Vec::new(), Vec::new()),
924                    |(mut texts, mut tools), content| {
925                        match content {
926                            crate::message::AssistantContent::Text(text) => texts.push(text.text),
927                            crate::message::AssistantContent::ToolCall(tool_call) => {
928                                tools.push(tool_call)
929                            }
930                            crate::message::AssistantContent::Reasoning(
931                                crate::message::Reasoning { reasoning, .. },
932                            ) => {
933                                thinking =
934                                    Some(reasoning.first().cloned().unwrap_or(String::new()));
935                            }
936                        }
937                        (texts, tools)
938                    },
939                );
940
941                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
942                //  so either `content` or `tool_calls` will have some content.
943                Ok(vec![Message::Assistant {
944                    content: text_content.join(" "),
945                    thinking,
946                    images: None,
947                    name: None,
948                    tool_calls: tool_calls
949                        .into_iter()
950                        .map(|tool_call| tool_call.into())
951                        .collect::<Vec<_>>(),
952                }])
953            }
954        }
955    }
956}
957
958/// Conversion from provider Message to a completion message.
959/// This is needed so that responses can be converted back into chat history.
960impl From<Message> for crate::completion::Message {
961    fn from(msg: Message) -> Self {
962        match msg {
963            Message::User { content, .. } => crate::completion::Message::User {
964                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
965                    text: content,
966                })),
967            },
968            Message::Assistant {
969                content,
970                tool_calls,
971                ..
972            } => {
973                let mut assistant_contents =
974                    vec![crate::completion::message::AssistantContent::Text(Text {
975                        text: content,
976                    })];
977                for tc in tool_calls {
978                    assistant_contents.push(
979                        crate::completion::message::AssistantContent::tool_call(
980                            tc.function.name.clone(),
981                            tc.function.name,
982                            tc.function.arguments,
983                        ),
984                    );
985                }
986                crate::completion::Message::Assistant {
987                    id: None,
988                    content: OneOrMany::many(assistant_contents).unwrap(),
989                }
990            }
991            // System and ToolResult are converted to User message as needed.
992            Message::System { content, .. } => crate::completion::Message::User {
993                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
994                    text: content,
995                })),
996            },
997            Message::ToolResult { name, content } => crate::completion::Message::User {
998                content: OneOrMany::one(message::UserContent::tool_result(
999                    name,
1000                    OneOrMany::one(message::ToolResultContent::text(content)),
1001                )),
1002            },
1003        }
1004    }
1005}
1006
1007impl Message {
1008    /// Constructs a system message.
1009    pub fn system(content: &str) -> Self {
1010        Message::System {
1011            content: content.to_owned(),
1012            images: None,
1013            name: None,
1014        }
1015    }
1016}
1017
1018// ---------- Additional Message Types ----------
1019
1020impl From<crate::message::ToolCall> for ToolCall {
1021    fn from(tool_call: crate::message::ToolCall) -> Self {
1022        Self {
1023            r#type: ToolType::Function,
1024            function: Function {
1025                name: tool_call.function.name,
1026                arguments: tool_call.function.arguments,
1027            },
1028        }
1029    }
1030}
1031
/// A single system-content item: a typed wrapper around a text payload.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    // Content kind; defaults (via serde) when absent in the serialized form.
    #[serde(default)]
    r#type: SystemContentType,
    // The actual system prompt text.
    text: String,
}
1038
/// Discriminator for [`SystemContent`]; only plain text is supported.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    /// Serialized as `"text"`.
    #[default]
    Text,
}
1045
1046impl From<String> for SystemContent {
1047    fn from(s: String) -> Self {
1048        SystemContent {
1049            r#type: SystemContentType::default(),
1050            text: s,
1051        }
1052    }
1053}
1054
1055impl FromStr for SystemContent {
1056    type Err = std::convert::Infallible;
1057    fn from_str(s: &str) -> Result<Self, Self::Err> {
1058        Ok(SystemContent {
1059            r#type: SystemContentType::default(),
1060            text: s.to_string(),
1061        })
1062    }
1063}
1064
/// Text payload of an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}
1069
1070impl FromStr for AssistantContent {
1071    type Err = std::convert::Infallible;
1072    fn from_str(s: &str) -> Result<Self, Self::Err> {
1073        Ok(AssistantContent { text: s.to_owned() })
1074    }
1075}
1076
/// User-supplied content, internally tagged by `type` in the serialized form.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text content.
    Text { text: String },
    /// An image referenced by URL.
    Image { image_url: ImageUrl },
    // Audio variant removed as Ollama API does not support audio input.
}
1084
1085impl FromStr for UserContent {
1086    type Err = std::convert::Infallible;
1087    fn from_str(s: &str) -> Result<Self, Self::Err> {
1088        Ok(UserContent::Text { text: s.to_owned() })
1089    }
1090}
1091
/// A URL-referenced image together with a requested detail level.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    // Falls back to `ImageDetail::default()` when omitted in the payload.
    #[serde(default)]
    pub detail: ImageDetail,
}
1098
1099// =================================================================
1100// Tests
1101// =================================================================
1102
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Test deserialization and conversion for the /api/chat endpoint.
    // Plain #[test] instead of #[tokio::test]: the body awaits nothing, so
    // spinning up an async runtime was pure overhead.
    #[test]
    fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    // Test conversion from provider Message to completion Message.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // The first element must be text content wrapping the original string.
                match content.first() {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    // Test conversion of internal tool definition to Ollama's ToolDefinition format.
    #[test]
    fn test_tool_definition_conversion() {
        // Internal tool definition from the completion module.
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        // Convert internal tool to Ollama's tool definition.
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        // Check JSON fields in parameters.
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }
}