rig/providers/
ollama.rs

1//! Ollama API client and Rig integration
2//!
3//! # Example
4//! ```rust
5//! use rig::providers::ollama;
6//!
7//! // Create a new Ollama client (defaults to http://localhost:11434)
8//! let client = ollama::Client::new();
9//!
10//! // Create a completion model interface using, for example, the "llama3.2" model
11//! let comp_model = client.completion_model("llama3.2");
12//!
13//! let req = rig::completion::CompletionRequest {
14//!     preamble: Some("You are now a humorous AI assistant.".to_owned()),
15//!     chat_history: vec![],  // internal messages (if any)
16//!     prompt: rig::message::Message::User {
17//!         content: rig::one_or_many::OneOrMany::one(rig::message::UserContent::text("Please tell me why the sky is blue.")),
18//!         name: None
19//!     },
20//!     temperature: 0.7,
21//!     additional_params: None,
22//!     tools: vec![],
23//! };
24//!
25//! let response = comp_model.completion(req).await.unwrap();
26//! println!("Ollama completion response: {:?}", response.choice);
27//!
28//! // Create an embedding interface using the "all-minilm" model
29//! let emb_model = ollama::Client::new().embedding_model("all-minilm");
30//! let docs = vec![
31//!     "Why is the sky blue?".to_owned(),
32//!     "Why is the grass green?".to_owned()
33//! ];
34//! let embeddings = emb_model.embed_texts(docs).await.unwrap();
35//! println!("Embedding response: {:?}", embeddings);
36//!
37//! // Also create an agent and extractor if needed
38//! let agent = client.agent("llama3.2");
39//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
40//! ```
41use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient};
42use crate::completion::Usage;
43use crate::json_utils::merge_inplace;
44use crate::streaming::RawStreamingChoice;
45use crate::{
46    Embed, OneOrMany,
47    completion::{self, CompletionError, CompletionRequest},
48    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
49    impl_conversion_traits, json_utils, message,
50    message::{ImageDetail, Text},
51    streaming,
52};
53use async_stream::stream;
54use futures::StreamExt;
55use reqwest;
56use serde::{Deserialize, Serialize};
57use serde_json::{Value, json};
58use std::{convert::TryFrom, str::FromStr};
59// ---------- Main Client ----------
60
61const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
62
63#[derive(Clone, Debug)]
64pub struct Client {
65    base_url: String,
66    http_client: reqwest::Client,
67}
68
69impl Default for Client {
70    fn default() -> Self {
71        Self::new()
72    }
73}
74
75impl Client {
76    pub fn new() -> Self {
77        Self::from_url(OLLAMA_API_BASE_URL)
78    }
79    pub fn from_url(base_url: &str) -> Self {
80        Self {
81            base_url: base_url.to_owned(),
82            http_client: reqwest::Client::builder()
83                .build()
84                .expect("Ollama reqwest client should build"),
85        }
86    }
87
88    /// Use your own `reqwest::Client`.
89    /// The required headers will be automatically attached upon trying to make a request.
90    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
91        self.http_client = client;
92
93        self
94    }
95
96    fn post(&self, path: &str) -> reqwest::RequestBuilder {
97        let url = format!("{}/{}", self.base_url, path);
98        self.http_client.post(url)
99    }
100}
101
102impl ProviderClient for Client {
103    fn from_env() -> Self
104    where
105        Self: Sized,
106    {
107        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
108        Self::from_url(&api_base)
109    }
110
111    fn from_val(input: crate::client::ProviderValue) -> Self {
112        let crate::client::ProviderValue::Simple(_) = input else {
113            panic!("Incorrect provider value type")
114        };
115
116        Self::new()
117    }
118}
119
120impl CompletionClient for Client {
121    type CompletionModel = CompletionModel;
122
123    fn completion_model(&self, model: &str) -> CompletionModel {
124        CompletionModel::new(self.clone(), model)
125    }
126}
127
128impl EmbeddingsClient for Client {
129    type EmbeddingModel = EmbeddingModel;
130    fn embedding_model(&self, model: &str) -> EmbeddingModel {
131        EmbeddingModel::new(self.clone(), model, 0)
132    }
133    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
134        EmbeddingModel::new(self.clone(), model, ndims)
135    }
136    fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
137        EmbeddingsBuilder::new(self.embedding_model(model))
138    }
139}
140
141impl_conversion_traits!(
142    AsTranscription,
143    AsImageGeneration,
144    AsAudioGeneration for Client
145);
146
147// ---------- API Error and Response Structures ----------
148
149#[derive(Debug, Deserialize)]
150struct ApiErrorResponse {
151    message: String,
152}
153
154#[derive(Debug, Deserialize)]
155#[serde(untagged)]
156enum ApiResponse<T> {
157    Ok(T),
158    Err(ApiErrorResponse),
159}
160
161// ---------- Embedding API ----------
162
163pub const ALL_MINILM: &str = "all-minilm";
164pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
165
166#[derive(Debug, Serialize, Deserialize)]
167pub struct EmbeddingResponse {
168    pub model: String,
169    pub embeddings: Vec<Vec<f64>>,
170    #[serde(default)]
171    pub total_duration: Option<u64>,
172    #[serde(default)]
173    pub load_duration: Option<u64>,
174    #[serde(default)]
175    pub prompt_eval_count: Option<u64>,
176}
177
178impl From<ApiErrorResponse> for EmbeddingError {
179    fn from(err: ApiErrorResponse) -> Self {
180        EmbeddingError::ProviderError(err.message)
181    }
182}
183
184impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
185    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
186        match value {
187            ApiResponse::Ok(response) => Ok(response),
188            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
189        }
190    }
191}
192
193// ---------- Embedding Model ----------
194
195#[derive(Clone)]
196pub struct EmbeddingModel {
197    client: Client,
198    pub model: String,
199    ndims: usize,
200}
201
202impl EmbeddingModel {
203    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
204        Self {
205            client,
206            model: model.to_owned(),
207            ndims,
208        }
209    }
210}
211
212impl embeddings::EmbeddingModel for EmbeddingModel {
213    const MAX_DOCUMENTS: usize = 1024;
214    fn ndims(&self) -> usize {
215        self.ndims
216    }
217    #[cfg_attr(feature = "worker", worker::send)]
218    async fn embed_texts(
219        &self,
220        documents: impl IntoIterator<Item = String>,
221    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
222        let docs: Vec<String> = documents.into_iter().collect();
223        let payload = json!({
224            "model": self.model,
225            "input": docs,
226        });
227        let response = self
228            .client
229            .post("api/embed")
230            .json(&payload)
231            .send()
232            .await
233            .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
234        if response.status().is_success() {
235            let api_resp: EmbeddingResponse = response
236                .json()
237                .await
238                .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
239            if api_resp.embeddings.len() != docs.len() {
240                return Err(EmbeddingError::ResponseError(
241                    "Number of returned embeddings does not match input".into(),
242                ));
243            }
244            Ok(api_resp
245                .embeddings
246                .into_iter()
247                .zip(docs.into_iter())
248                .map(|(vec, document)| embeddings::Embedding { document, vec })
249                .collect())
250        } else {
251            Err(EmbeddingError::ProviderError(response.text().await?))
252        }
253    }
254}
255
256// ---------- Completion API ----------
257
258pub const LLAMA3_2: &str = "llama3.2";
259pub const LLAVA: &str = "llava";
260pub const MISTRAL: &str = "mistral";
261
262#[derive(Debug, Serialize, Deserialize)]
263pub struct CompletionResponse {
264    pub model: String,
265    pub created_at: String,
266    pub message: Message,
267    pub done: bool,
268    #[serde(default)]
269    pub done_reason: Option<String>,
270    #[serde(default)]
271    pub total_duration: Option<u64>,
272    #[serde(default)]
273    pub load_duration: Option<u64>,
274    #[serde(default)]
275    pub prompt_eval_count: Option<u64>,
276    #[serde(default)]
277    pub prompt_eval_duration: Option<u64>,
278    #[serde(default)]
279    pub eval_count: Option<u64>,
280    #[serde(default)]
281    pub eval_duration: Option<u64>,
282}
283impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
284    type Error = CompletionError;
285    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
286        match resp.message {
287            // Process only if an assistant message is present.
288            Message::Assistant {
289                content,
290                tool_calls,
291                ..
292            } => {
293                let mut assistant_contents = Vec::new();
294                // Add the assistant's text content if any.
295                if !content.is_empty() {
296                    assistant_contents.push(completion::AssistantContent::text(&content));
297                }
298                // Process tool_calls following Ollama's chat response definition.
299                // Each ToolCall has an id, a type, and a function field.
300                for tc in tool_calls.iter() {
301                    assistant_contents.push(completion::AssistantContent::tool_call(
302                        tc.function.name.clone(),
303                        tc.function.name.clone(),
304                        tc.function.arguments.clone(),
305                    ));
306                }
307                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
308                    CompletionError::ResponseError("No content provided".to_owned())
309                })?;
310                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
311                let completion_tokens = resp.eval_count.unwrap_or(0);
312
313                let raw_response = CompletionResponse {
314                    model: resp.model,
315                    created_at: resp.created_at,
316                    done: resp.done,
317                    done_reason: resp.done_reason,
318                    total_duration: resp.total_duration,
319                    load_duration: resp.load_duration,
320                    prompt_eval_count: resp.prompt_eval_count,
321                    prompt_eval_duration: resp.prompt_eval_duration,
322                    eval_count: resp.eval_count,
323                    eval_duration: resp.eval_duration,
324                    message: Message::Assistant {
325                        content,
326                        images: None,
327                        name: None,
328                        tool_calls,
329                    },
330                };
331
332                Ok(completion::CompletionResponse {
333                    choice,
334                    usage: Usage {
335                        input_tokens: prompt_tokens,
336                        output_tokens: completion_tokens,
337                        total_tokens: prompt_tokens + completion_tokens,
338                    },
339                    raw_response,
340                })
341            }
342            _ => Err(CompletionError::ResponseError(
343                "Chat response does not include an assistant message".into(),
344            )),
345        }
346    }
347}
348
349// ---------- Completion Model ----------
350
351#[derive(Clone)]
352pub struct CompletionModel {
353    client: Client,
354    pub model: String,
355}
356
357impl CompletionModel {
358    pub fn new(client: Client, model: &str) -> Self {
359        Self {
360            client,
361            model: model.to_owned(),
362        }
363    }
364
365    fn create_completion_request(
366        &self,
367        completion_request: CompletionRequest,
368    ) -> Result<Value, CompletionError> {
369        // Build up the order of messages (context, chat_history)
370        let mut partial_history = vec![];
371        if let Some(docs) = completion_request.normalized_documents() {
372            partial_history.push(docs);
373        }
374        partial_history.extend(completion_request.chat_history);
375
376        // Initialize full history with preamble (or empty if non-existent)
377        let mut full_history: Vec<Message> = completion_request
378            .preamble
379            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
380
381        // Convert and extend the rest of the history
382        full_history.extend(
383            partial_history
384                .into_iter()
385                .map(|msg| msg.try_into())
386                .collect::<Result<Vec<Vec<Message>>, _>>()?
387                .into_iter()
388                .flatten()
389                .collect::<Vec<Message>>(),
390        );
391
392        // Convert internal prompt into a provider Message
393        let options = if let Some(extra) = completion_request.additional_params {
394            json_utils::merge(
395                json!({ "temperature": completion_request.temperature }),
396                extra,
397            )
398        } else {
399            json!({ "temperature": completion_request.temperature })
400        };
401
402        let mut request_payload = json!({
403            "model": self.model,
404            "messages": full_history,
405            "options": options,
406            "stream": false,
407        });
408        if !completion_request.tools.is_empty() {
409            request_payload["tools"] = json!(
410                completion_request
411                    .tools
412                    .into_iter()
413                    .map(|tool| tool.into())
414                    .collect::<Vec<ToolDefinition>>()
415            );
416        }
417
418        tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
419
420        Ok(request_payload)
421    }
422}
423
424// ---------- CompletionModel Implementation ----------
425
426#[derive(Clone)]
427pub struct StreamingCompletionResponse {
428    pub done_reason: Option<String>,
429    pub total_duration: Option<u64>,
430    pub load_duration: Option<u64>,
431    pub prompt_eval_count: Option<u64>,
432    pub prompt_eval_duration: Option<u64>,
433    pub eval_count: Option<u64>,
434    pub eval_duration: Option<u64>,
435}
436impl completion::CompletionModel for CompletionModel {
437    type Response = CompletionResponse;
438    type StreamingResponse = StreamingCompletionResponse;
439
440    #[cfg_attr(feature = "worker", worker::send)]
441    async fn completion(
442        &self,
443        completion_request: CompletionRequest,
444    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
445        let request_payload = self.create_completion_request(completion_request)?;
446
447        let response = self
448            .client
449            .post("api/chat")
450            .json(&request_payload)
451            .send()
452            .await
453            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
454        if response.status().is_success() {
455            let text = response
456                .text()
457                .await
458                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
459            tracing::debug!(target: "rig", "Ollama chat response: {}", text);
460            let chat_resp: CompletionResponse = serde_json::from_str(&text)
461                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
462            let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
463            Ok(conv)
464        } else {
465            let err_text = response
466                .text()
467                .await
468                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
469            Err(CompletionError::ProviderError(err_text))
470        }
471    }
472
473    #[cfg_attr(feature = "worker", worker::send)]
474    async fn stream(
475        &self,
476        request: CompletionRequest,
477    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
478    {
479        let mut request_payload = self.create_completion_request(request)?;
480        merge_inplace(&mut request_payload, json!({"stream": true}));
481
482        let response = self
483            .client
484            .post("api/chat")
485            .json(&request_payload)
486            .send()
487            .await
488            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
489
490        if !response.status().is_success() {
491            let err_text = response
492                .text()
493                .await
494                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
495            return Err(CompletionError::ProviderError(err_text));
496        }
497
498        let stream = Box::pin(stream! {
499            let mut stream = response.bytes_stream();
500            while let Some(chunk_result) = stream.next().await {
501                let chunk = match chunk_result {
502                    Ok(c) => c,
503                    Err(e) => {
504                        yield Err(CompletionError::from(e));
505                        break;
506                    }
507                };
508
509                let text = match String::from_utf8(chunk.to_vec()) {
510                    Ok(t) => t,
511                    Err(e) => {
512                        yield Err(CompletionError::ResponseError(e.to_string()));
513                        break;
514                    }
515                };
516
517
518                for line in text.lines() {
519                    let line = line.to_string();
520
521                    let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
522                        continue;
523                    };
524
525                    match response.message {
526                        Message::Assistant{ content, tool_calls, .. } => {
527                            if !content.is_empty() {
528                                yield Ok(RawStreamingChoice::Message(content))
529                            }
530
531                            for tool_call in tool_calls.iter() {
532                                let function = tool_call.function.clone();
533
534                                yield Ok(RawStreamingChoice::ToolCall {
535                                    id: "".to_string(),
536                                    name: function.name,
537                                    arguments: function.arguments,
538                                    call_id: None
539                                });
540                            }
541                        }
542                        _ => {
543                            continue;
544                        }
545                    }
546
547                    if response.done {
548                        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
549                            total_duration: response.total_duration,
550                            load_duration: response.load_duration,
551                            prompt_eval_count: response.prompt_eval_count,
552                            prompt_eval_duration: response.prompt_eval_duration,
553                            eval_count: response.eval_count,
554                            eval_duration: response.eval_duration,
555                            done_reason: response.done_reason,
556                        }));
557                    }
558                }
559            }
560        });
561
562        Ok(streaming::StreamingCompletionResponse::stream(stream))
563    }
564}
565
566// ---------- Tool Definition Conversion ----------
567
568/// Ollama-required tool definition format.
569#[derive(Clone, Debug, Deserialize, Serialize)]
570pub struct ToolDefinition {
571    #[serde(rename = "type")]
572    pub type_field: String, // Fixed as "function"
573    pub function: completion::ToolDefinition,
574}
575
576/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
577impl From<crate::completion::ToolDefinition> for ToolDefinition {
578    fn from(tool: crate::completion::ToolDefinition) -> Self {
579        ToolDefinition {
580            type_field: "function".to_owned(),
581            function: completion::ToolDefinition {
582                name: tool.name,
583                description: tool.description,
584                parameters: tool.parameters,
585            },
586        }
587    }
588}
589
590#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
591pub struct ToolCall {
592    // pub id: String,
593    #[serde(default, rename = "type")]
594    pub r#type: ToolType,
595    pub function: Function,
596}
597#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
598#[serde(rename_all = "lowercase")]
599pub enum ToolType {
600    #[default]
601    Function,
602}
603#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
604pub struct Function {
605    pub name: String,
606    pub arguments: Value,
607}
608
609// ---------- Provider Message Definition ----------
610
611#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
612#[serde(tag = "role", rename_all = "lowercase")]
613pub enum Message {
614    User {
615        content: String,
616        #[serde(skip_serializing_if = "Option::is_none")]
617        images: Option<Vec<String>>,
618        #[serde(skip_serializing_if = "Option::is_none")]
619        name: Option<String>,
620    },
621    Assistant {
622        #[serde(default)]
623        content: String,
624        #[serde(skip_serializing_if = "Option::is_none")]
625        images: Option<Vec<String>>,
626        #[serde(skip_serializing_if = "Option::is_none")]
627        name: Option<String>,
628        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
629        tool_calls: Vec<ToolCall>,
630    },
631    System {
632        content: String,
633        #[serde(skip_serializing_if = "Option::is_none")]
634        images: Option<Vec<String>>,
635        #[serde(skip_serializing_if = "Option::is_none")]
636        name: Option<String>,
637    },
638    #[serde(rename = "tool")]
639    ToolResult {
640        #[serde(rename = "tool_name")]
641        name: String,
642        content: String,
643    },
644}
645
646/// -----------------------------
647/// Provider Message Conversions
648/// -----------------------------
649/// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
650/// (Only User and Assistant variants are supported.)
651impl TryFrom<crate::message::Message> for Vec<Message> {
652    type Error = crate::message::MessageError;
653    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
654        use crate::message::Message as InternalMessage;
655        match internal_msg {
656            InternalMessage::User { content, .. } => {
657                let (tool_results, other_content): (Vec<_>, Vec<_>) =
658                    content.into_iter().partition(|content| {
659                        matches!(content, crate::message::UserContent::ToolResult(_))
660                    });
661
662                if !tool_results.is_empty() {
663                    tool_results
664                        .into_iter()
665                        .map(|content| match content {
666                            crate::message::UserContent::ToolResult(
667                                crate::message::ToolResult { id, content, .. },
668                            ) => {
669                                // Ollama expects a single string for tool results, so we concatenate
670                                let content_string = content
671                                    .into_iter()
672                                    .map(|content| match content {
673                                        crate::message::ToolResultContent::Text(text) => text.text,
674                                        _ => "[Non-text content]".to_string(),
675                                    })
676                                    .collect::<Vec<_>>()
677                                    .join("\n");
678
679                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
680                                    name: id,
681                                    content: content_string,
682                                })
683                            }
684                            _ => unreachable!(),
685                        })
686                        .collect::<Result<Vec<_>, _>>()
687                } else {
688                    // Ollama requires separate text content and images array
689                    let (texts, images) = other_content.into_iter().fold(
690                        (Vec::new(), Vec::new()),
691                        |(mut texts, mut images), content| {
692                            match content {
693                                crate::message::UserContent::Text(crate::message::Text {
694                                    text,
695                                }) => texts.push(text),
696                                crate::message::UserContent::Image(crate::message::Image {
697                                    data,
698                                    ..
699                                }) => images.push(data),
700                                _ => {} // Audio/Document not supported by Ollama
701                            }
702                            (texts, images)
703                        },
704                    );
705
706                    Ok(vec![Message::User {
707                        content: texts.join(" "),
708                        images: if images.is_empty() {
709                            None
710                        } else {
711                            Some(images)
712                        },
713                        name: None,
714                    }])
715                }
716            }
717            InternalMessage::Assistant { content, .. } => {
718                let (text_content, tool_calls) = content.into_iter().fold(
719                    (Vec::new(), Vec::new()),
720                    |(mut texts, mut tools), content| {
721                        match content {
722                            crate::message::AssistantContent::Text(text) => texts.push(text.text),
723                            crate::message::AssistantContent::ToolCall(tool_call) => {
724                                tools.push(tool_call)
725                            }
726                        }
727                        (texts, tools)
728                    },
729                );
730
731                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
732                //  so either `content` or `tool_calls` will have some content.
733                Ok(vec![Message::Assistant {
734                    content: text_content.join(" "),
735                    images: None,
736                    name: None,
737                    tool_calls: tool_calls
738                        .into_iter()
739                        .map(|tool_call| tool_call.into())
740                        .collect::<Vec<_>>(),
741                }])
742            }
743        }
744    }
745}
746
747/// Conversion from provider Message to a completion message.
748/// This is needed so that responses can be converted back into chat history.
749impl From<Message> for crate::completion::Message {
750    fn from(msg: Message) -> Self {
751        match msg {
752            Message::User { content, .. } => crate::completion::Message::User {
753                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
754                    text: content,
755                })),
756            },
757            Message::Assistant {
758                content,
759                tool_calls,
760                ..
761            } => {
762                let mut assistant_contents =
763                    vec![crate::completion::message::AssistantContent::Text(Text {
764                        text: content,
765                    })];
766                for tc in tool_calls {
767                    assistant_contents.push(
768                        crate::completion::message::AssistantContent::tool_call(
769                            tc.function.name.clone(),
770                            tc.function.name,
771                            tc.function.arguments,
772                        ),
773                    );
774                }
775                crate::completion::Message::Assistant {
776                    id: None,
777                    content: OneOrMany::many(assistant_contents).unwrap(),
778                }
779            }
780            // System and ToolResult are converted to User message as needed.
781            Message::System { content, .. } => crate::completion::Message::User {
782                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
783                    text: content,
784                })),
785            },
786            Message::ToolResult { name, content } => crate::completion::Message::User {
787                content: OneOrMany::one(message::UserContent::tool_result(
788                    name,
789                    OneOrMany::one(message::ToolResultContent::text(content)),
790                )),
791            },
792        }
793    }
794}
795
796impl Message {
797    /// Constructs a system message.
798    pub fn system(content: &str) -> Self {
799        Message::System {
800            content: content.to_owned(),
801            images: None,
802            name: None,
803        }
804    }
805}
806
807// ---------- Additional Message Types ----------
808
809impl From<crate::message::ToolCall> for ToolCall {
810    fn from(tool_call: crate::message::ToolCall) -> Self {
811        Self {
812            r#type: ToolType::Function,
813            function: Function {
814                name: tool_call.function.name,
815                arguments: tool_call.function.arguments,
816            },
817        }
818    }
819}
820
821#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
822pub struct SystemContent {
823    #[serde(default)]
824    r#type: SystemContentType,
825    text: String,
826}
827
828#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
829#[serde(rename_all = "lowercase")]
830pub enum SystemContentType {
831    #[default]
832    Text,
833}
834
835impl From<String> for SystemContent {
836    fn from(s: String) -> Self {
837        SystemContent {
838            r#type: SystemContentType::default(),
839            text: s,
840        }
841    }
842}
843
844impl FromStr for SystemContent {
845    type Err = std::convert::Infallible;
846    fn from_str(s: &str) -> Result<Self, Self::Err> {
847        Ok(SystemContent {
848            r#type: SystemContentType::default(),
849            text: s.to_string(),
850        })
851    }
852}
853
/// Text content attributed to the assistant, serialized as `{"text": ...}`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    /// The assistant's text.
    pub text: String,
}
858
859impl FromStr for AssistantContent {
860    type Err = std::convert::Infallible;
861    fn from_str(s: &str) -> Result<Self, Self::Err> {
862        Ok(AssistantContent { text: s.to_owned() })
863    }
864}
865
/// One part of a user message's content.
///
/// Internally tagged by a lowercase `"type"` field, e.g.
/// `{"type": "text", "text": ...}` or `{"type": "image", "image_url": {...}}`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text input.
    Text { text: String },
    /// Image input described by an [`ImageUrl`].
    Image { image_url: ImageUrl },
    // There is intentionally no `Audio` variant: per the original author,
    // the Ollama API does not accept audio input.
}
873
874impl FromStr for UserContent {
875    type Err = std::convert::Infallible;
876    fn from_str(s: &str) -> Result<Self, Self::Err> {
877        Ok(UserContent::Text { text: s.to_owned() })
878    }
879}
880
/// Image reference carried by [`UserContent::Image`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// Where the image lives. NOTE(review): presumably an http(s) URL or a
    /// data URI — confirm against callers.
    pub url: String,
    /// Requested detail level; `ImageDetail::default()` is used when the
    /// field is absent during deserialization.
    #[serde(default)]
    pub detail: ImageDetail,
}
887
888// =================================================================
889// Tests
890// =================================================================
891
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes a non-streaming `/api/chat` response and converts it into
    /// Rig's generic completion response.
    ///
    /// Nothing here is awaited, so this is a plain `#[test]`; spinning up an
    /// async runtime (as `#[tokio::test]` would) is unnecessary.
    #[test]
    fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> = chat_resp
            .try_into()
            .expect("a well-formed chat response should convert to a completion response");
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    /// Converts a provider `Message::User` into a `completion::Message` and
    /// checks that the text survives the conversion.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // `first()` yields the first content part of the `OneOrMany`.
                let first_content = content.first();
                // A plain-text provider message should map to `UserContent::Text`.
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    /// Converts Rig's internal tool definition into Ollama's `ToolDefinition`
    /// and checks the type tag, name, description and parameter schema.
    #[test]
    fn test_tool_definition_conversion() {
        // Internal tool definition from the completion module.
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        // Convert internal tool to Ollama's tool definition.
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        // Check JSON fields in parameters.
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }
}