rig/providers/
ollama.rs

1//! Ollama API client and Rig integration
2//!
3//! # Example
4//! ```rust
5//! use rig::providers::ollama;
6//!
7//! // Create a new Ollama client (defaults to http://localhost:11434)
8//! let client = ollama::Client::new();
9//!
10//! // Create a completion model interface using, for example, the "llama3.2" model
11//! let comp_model = client.completion_model("llama3.2");
12//!
13//! let req = rig::completion::CompletionRequest {
14//!     preamble: Some("You are now a humorous AI assistant.".to_owned()),
15//!     chat_history: vec![],  // internal messages (if any)
16//!     prompt: rig::message::Message::User {
//!         content: rig::one_or_many::OneOrMany::one(rig::message::UserContent::text("Please tell me why the sky is blue.")),
19//!     },
20//!     temperature: 0.7,
21//!     additional_params: None,
22//!     tools: vec![],
23//! };
24//!
25//! let response = comp_model.completion(req).await.unwrap();
26//! println!("Ollama completion response: {:?}", response.choice);
27//!
28//! // Create an embedding interface using the "all-minilm" model
29//! let emb_model = ollama::Client::new().embedding_model("all-minilm");
30//! let docs = vec![
31//!     "Why is the sky blue?".to_owned(),
32//!     "Why is the grass green?".to_owned()
33//! ];
34//! let embeddings = emb_model.embed_texts(docs).await.unwrap();
35//! println!("Embedding response: {:?}", embeddings);
36//!
37//! // Also create an agent and extractor if needed
38//! let agent = client.agent("llama3.2");
39//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
40//! ```
41use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient};
42use crate::json_utils::merge_inplace;
43use crate::message::MessageError;
44use crate::streaming::RawStreamingChoice;
45use crate::{
46    Embed, OneOrMany,
47    completion::{self, CompletionError, CompletionRequest},
48    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
49    impl_conversion_traits, json_utils, message,
50    message::{ImageDetail, Text},
51    streaming,
52};
53use async_stream::stream;
54use futures::StreamExt;
55use reqwest;
56use serde::{Deserialize, Serialize};
57use serde_json::{Value, json};
58use std::convert::Infallible;
59use std::{convert::TryFrom, str::FromStr};
60// ---------- Main Client ----------
61
62const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
63
64#[derive(Clone, Debug)]
65pub struct Client {
66    base_url: String,
67    http_client: reqwest::Client,
68}
69
70impl Default for Client {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76impl Client {
77    pub fn new() -> Self {
78        Self::from_url(OLLAMA_API_BASE_URL)
79    }
80    pub fn from_url(base_url: &str) -> Self {
81        Self {
82            base_url: base_url.to_owned(),
83            http_client: reqwest::Client::builder()
84                .build()
85                .expect("Ollama reqwest client should build"),
86        }
87    }
88
89    /// Use your own `reqwest::Client`.
90    /// The required headers will be automatically attached upon trying to make a request.
91    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
92        self.http_client = client;
93
94        self
95    }
96
97    fn post(&self, path: &str) -> reqwest::RequestBuilder {
98        let url = format!("{}/{}", self.base_url, path);
99        self.http_client.post(url)
100    }
101}
102
103impl ProviderClient for Client {
104    fn from_env() -> Self
105    where
106        Self: Sized,
107    {
108        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
109        Self::from_url(&api_base)
110    }
111}
112
113impl CompletionClient for Client {
114    type CompletionModel = CompletionModel;
115
116    fn completion_model(&self, model: &str) -> CompletionModel {
117        CompletionModel::new(self.clone(), model)
118    }
119}
120
121impl EmbeddingsClient for Client {
122    type EmbeddingModel = EmbeddingModel;
123    fn embedding_model(&self, model: &str) -> EmbeddingModel {
124        EmbeddingModel::new(self.clone(), model, 0)
125    }
126    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
127        EmbeddingModel::new(self.clone(), model, ndims)
128    }
129    fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
130        EmbeddingsBuilder::new(self.embedding_model(model))
131    }
132}
133
134impl_conversion_traits!(
135    AsTranscription,
136    AsImageGeneration,
137    AsAudioGeneration for Client
138);
139
140// ---------- API Error and Response Structures ----------
141
142#[derive(Debug, Deserialize)]
143struct ApiErrorResponse {
144    message: String,
145}
146
147#[derive(Debug, Deserialize)]
148#[serde(untagged)]
149enum ApiResponse<T> {
150    Ok(T),
151    Err(ApiErrorResponse),
152}
153
154// ---------- Embedding API ----------
155
156pub const ALL_MINILM: &str = "all-minilm";
157pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
158
159#[derive(Debug, Serialize, Deserialize)]
160pub struct EmbeddingResponse {
161    pub model: String,
162    pub embeddings: Vec<Vec<f64>>,
163    #[serde(default)]
164    pub total_duration: Option<u64>,
165    #[serde(default)]
166    pub load_duration: Option<u64>,
167    #[serde(default)]
168    pub prompt_eval_count: Option<u64>,
169}
170
171impl From<ApiErrorResponse> for EmbeddingError {
172    fn from(err: ApiErrorResponse) -> Self {
173        EmbeddingError::ProviderError(err.message)
174    }
175}
176
177impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
178    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
179        match value {
180            ApiResponse::Ok(response) => Ok(response),
181            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
182        }
183    }
184}
185
186// ---------- Embedding Model ----------
187
188#[derive(Clone)]
189pub struct EmbeddingModel {
190    client: Client,
191    pub model: String,
192    ndims: usize,
193}
194
195impl EmbeddingModel {
196    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
197        Self {
198            client,
199            model: model.to_owned(),
200            ndims,
201        }
202    }
203}
204
205impl embeddings::EmbeddingModel for EmbeddingModel {
206    const MAX_DOCUMENTS: usize = 1024;
207    fn ndims(&self) -> usize {
208        self.ndims
209    }
210    #[cfg_attr(feature = "worker", worker::send)]
211    async fn embed_texts(
212        &self,
213        documents: impl IntoIterator<Item = String>,
214    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
215        let docs: Vec<String> = documents.into_iter().collect();
216        let payload = json!({
217            "model": self.model,
218            "input": docs,
219        });
220        let response = self
221            .client
222            .post("api/embed")
223            .json(&payload)
224            .send()
225            .await
226            .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
227        if response.status().is_success() {
228            let api_resp: EmbeddingResponse = response
229                .json()
230                .await
231                .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
232            if api_resp.embeddings.len() != docs.len() {
233                return Err(EmbeddingError::ResponseError(
234                    "Number of returned embeddings does not match input".into(),
235                ));
236            }
237            Ok(api_resp
238                .embeddings
239                .into_iter()
240                .zip(docs.into_iter())
241                .map(|(vec, document)| embeddings::Embedding { document, vec })
242                .collect())
243        } else {
244            Err(EmbeddingError::ProviderError(response.text().await?))
245        }
246    }
247}
248
249// ---------- Completion API ----------
250
251pub const LLAMA3_2: &str = "llama3.2";
252pub const LLAVA: &str = "llava";
253pub const MISTRAL: &str = "mistral";
254
255#[derive(Debug, Serialize, Deserialize)]
256pub struct CompletionResponse {
257    pub model: String,
258    pub created_at: String,
259    pub message: Message,
260    pub done: bool,
261    #[serde(default)]
262    pub done_reason: Option<String>,
263    #[serde(default)]
264    pub total_duration: Option<u64>,
265    #[serde(default)]
266    pub load_duration: Option<u64>,
267    #[serde(default)]
268    pub prompt_eval_count: Option<u64>,
269    #[serde(default)]
270    pub prompt_eval_duration: Option<u64>,
271    #[serde(default)]
272    pub eval_count: Option<u64>,
273    #[serde(default)]
274    pub eval_duration: Option<u64>,
275}
276impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
277    type Error = CompletionError;
278    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
279        match resp.message {
280            // Process only if an assistant message is present.
281            Message::Assistant {
282                content,
283                tool_calls,
284                ..
285            } => {
286                let mut assistant_contents = Vec::new();
287                // Add the assistant's text content if any.
288                if !content.is_empty() {
289                    assistant_contents.push(completion::AssistantContent::text(&content));
290                }
291                // Process tool_calls following Ollama's chat response definition.
292                // Each ToolCall has an id, a type, and a function field.
293                for tc in tool_calls.iter() {
294                    assistant_contents.push(completion::AssistantContent::tool_call(
295                        tc.function.name.clone(),
296                        tc.function.name.clone(),
297                        tc.function.arguments.clone(),
298                    ));
299                }
300                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
301                    CompletionError::ResponseError("No content provided".to_owned())
302                })?;
303                let raw_response = CompletionResponse {
304                    model: resp.model,
305                    created_at: resp.created_at,
306                    done: resp.done,
307                    done_reason: resp.done_reason,
308                    total_duration: resp.total_duration,
309                    load_duration: resp.load_duration,
310                    prompt_eval_count: resp.prompt_eval_count,
311                    prompt_eval_duration: resp.prompt_eval_duration,
312                    eval_count: resp.eval_count,
313                    eval_duration: resp.eval_duration,
314                    message: Message::Assistant {
315                        content,
316                        images: None,
317                        name: None,
318                        tool_calls,
319                    },
320                };
321                Ok(completion::CompletionResponse {
322                    choice,
323                    raw_response,
324                })
325            }
326            _ => Err(CompletionError::ResponseError(
327                "Chat response does not include an assistant message".into(),
328            )),
329        }
330    }
331}
332
333// ---------- Completion Model ----------
334
335#[derive(Clone)]
336pub struct CompletionModel {
337    client: Client,
338    pub model: String,
339}
340
341impl CompletionModel {
342    pub fn new(client: Client, model: &str) -> Self {
343        Self {
344            client,
345            model: model.to_owned(),
346        }
347    }
348
349    fn create_completion_request(
350        &self,
351        completion_request: CompletionRequest,
352    ) -> Result<Value, CompletionError> {
353        // Build up the order of messages (context, chat_history)
354        let mut partial_history = vec![];
355        if let Some(docs) = completion_request.normalized_documents() {
356            partial_history.push(docs);
357        }
358        partial_history.extend(completion_request.chat_history);
359
360        // Initialize full history with preamble (or empty if non-existent)
361        let mut full_history: Vec<Message> = completion_request
362            .preamble
363            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
364
365        // Convert and extend the rest of the history
366        full_history.extend(
367            partial_history
368                .into_iter()
369                .map(|msg| msg.try_into())
370                .collect::<Result<Vec<Message>, _>>()?,
371        );
372
373        // Convert internal prompt into a provider Message
374        let options = if let Some(extra) = completion_request.additional_params {
375            json_utils::merge(
376                json!({ "temperature": completion_request.temperature }),
377                extra,
378            )
379        } else {
380            json!({ "temperature": completion_request.temperature })
381        };
382
383        let mut request_payload = json!({
384            "model": self.model,
385            "messages": full_history,
386            "options": options,
387            "stream": false,
388        });
389        if !completion_request.tools.is_empty() {
390            request_payload["tools"] = json!(
391                completion_request
392                    .tools
393                    .into_iter()
394                    .map(|tool| tool.into())
395                    .collect::<Vec<ToolDefinition>>()
396            );
397        }
398
399        tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
400
401        Ok(request_payload)
402    }
403}
404
405// ---------- CompletionModel Implementation ----------
406
/// Final statistics from the `done` line of a streamed Ollama chat response.
///
/// Every field is copied from that line and is `None` when Ollama omits it.
#[derive(Clone, Debug)]
pub struct StreamingCompletionResponse {
    /// Why generation stopped, as reported by Ollama.
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
417impl completion::CompletionModel for CompletionModel {
418    type Response = CompletionResponse;
419    type StreamingResponse = StreamingCompletionResponse;
420
421    #[cfg_attr(feature = "worker", worker::send)]
422    async fn completion(
423        &self,
424        completion_request: CompletionRequest,
425    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
426        let request_payload = self.create_completion_request(completion_request)?;
427
428        let response = self
429            .client
430            .post("api/chat")
431            .json(&request_payload)
432            .send()
433            .await
434            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
435        if response.status().is_success() {
436            let text = response
437                .text()
438                .await
439                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
440            tracing::debug!(target: "rig", "Ollama chat response: {}", text);
441            let chat_resp: CompletionResponse = serde_json::from_str(&text)
442                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
443            let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
444            Ok(conv)
445        } else {
446            let err_text = response
447                .text()
448                .await
449                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
450            Err(CompletionError::ProviderError(err_text))
451        }
452    }
453
454    #[cfg_attr(feature = "worker", worker::send)]
455    async fn stream(
456        &self,
457        request: CompletionRequest,
458    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
459    {
460        let mut request_payload = self.create_completion_request(request)?;
461        merge_inplace(&mut request_payload, json!({"stream": true}));
462
463        let response = self
464            .client
465            .post("api/chat")
466            .json(&request_payload)
467            .send()
468            .await
469            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
470
471        if !response.status().is_success() {
472            let err_text = response
473                .text()
474                .await
475                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
476            return Err(CompletionError::ProviderError(err_text));
477        }
478
479        let stream = Box::pin(stream! {
480            let mut stream = response.bytes_stream();
481            while let Some(chunk_result) = stream.next().await {
482                let chunk = match chunk_result {
483                    Ok(c) => c,
484                    Err(e) => {
485                        yield Err(CompletionError::from(e));
486                        break;
487                    }
488                };
489
490                let text = match String::from_utf8(chunk.to_vec()) {
491                    Ok(t) => t,
492                    Err(e) => {
493                        yield Err(CompletionError::ResponseError(e.to_string()));
494                        break;
495                    }
496                };
497
498
499                for line in text.lines() {
500                    let line = line.to_string();
501
502                    let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
503                        continue;
504                    };
505
506                    match response.message {
507                        Message::Assistant{ content, tool_calls, .. } => {
508                            if !content.is_empty() {
509                                yield Ok(RawStreamingChoice::Message(content))
510                            }
511
512                            for tool_call in tool_calls.iter() {
513                                let function = tool_call.function.clone();
514
515                                yield Ok(RawStreamingChoice::ToolCall {
516                                    id: "".to_string(),
517                                    name: function.name,
518                                    arguments: function.arguments,
519                                    call_id: None
520                                });
521                            }
522                        }
523                        _ => {
524                            continue;
525                        }
526                    }
527
528                    if response.done {
529                        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
530                            total_duration: response.total_duration,
531                            load_duration: response.load_duration,
532                            prompt_eval_count: response.prompt_eval_count,
533                            prompt_eval_duration: response.prompt_eval_duration,
534                            eval_count: response.eval_count,
535                            eval_duration: response.eval_duration,
536                            done_reason: response.done_reason,
537                        }));
538                    }
539                }
540            }
541        });
542
543        Ok(streaming::StreamingCompletionResponse::stream(stream))
544    }
545}
546
547// ---------- Tool Definition Conversion ----------
548
549/// Ollama-required tool definition format.
550#[derive(Clone, Debug, Deserialize, Serialize)]
551pub struct ToolDefinition {
552    #[serde(rename = "type")]
553    pub type_field: String, // Fixed as "function"
554    pub function: completion::ToolDefinition,
555}
556
557/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
558impl From<crate::completion::ToolDefinition> for ToolDefinition {
559    fn from(tool: crate::completion::ToolDefinition) -> Self {
560        ToolDefinition {
561            type_field: "function".to_owned(),
562            function: completion::ToolDefinition {
563                name: tool.name,
564                description: tool.description,
565                parameters: tool.parameters,
566            },
567        }
568    }
569}
570
571#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
572pub struct ToolCall {
573    // pub id: String,
574    #[serde(default, rename = "type")]
575    pub r#type: ToolType,
576    pub function: Function,
577}
578#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
579#[serde(rename_all = "lowercase")]
580pub enum ToolType {
581    #[default]
582    Function,
583}
584#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
585pub struct Function {
586    pub name: String,
587    pub arguments: Value,
588}
589
590// ---------- Provider Message Definition ----------
591
592#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
593#[serde(tag = "role", rename_all = "lowercase")]
594pub enum Message {
595    User {
596        content: String,
597        #[serde(skip_serializing_if = "Option::is_none")]
598        images: Option<Vec<String>>,
599        #[serde(skip_serializing_if = "Option::is_none")]
600        name: Option<String>,
601    },
602    Assistant {
603        #[serde(default)]
604        content: String,
605        #[serde(skip_serializing_if = "Option::is_none")]
606        images: Option<Vec<String>>,
607        #[serde(skip_serializing_if = "Option::is_none")]
608        name: Option<String>,
609        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
610        tool_calls: Vec<ToolCall>,
611    },
612    System {
613        content: String,
614        #[serde(skip_serializing_if = "Option::is_none")]
615        images: Option<Vec<String>>,
616        #[serde(skip_serializing_if = "Option::is_none")]
617        name: Option<String>,
618    },
619    #[serde(rename = "tool")]
620    ToolResult { name: String, content: String },
621}
622
623/// -----------------------------
624/// Provider Message Conversions
625/// -----------------------------
626/// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
627/// (Only User and Assistant variants are supported.)
628impl TryFrom<crate::message::Message> for Message {
629    type Error = crate::message::MessageError;
630    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
631        use crate::message::Message as InternalMessage;
632        match internal_msg {
633            InternalMessage::User { content, .. } => {
634                let mut texts = Vec::new();
635                let mut images = Vec::new();
636                for uc in content.into_iter() {
637                    match uc {
638                        crate::message::UserContent::Text(t) => texts.push(t.text),
639                        crate::message::UserContent::Image(img) => images.push(img.data),
640                        crate::message::UserContent::ToolResult(result) => {
641                            let content = result
642                                .content
643                                .into_iter()
644                                .map(ToolResultContent::try_from)
645                                .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
646
647                            let content = OneOrMany::many(content).map_err(|x| {
648                                MessageError::ConversionError(format!(
649                                    "Couldn't make a OneOrMany from a list of tool results: {x}"
650                                ))
651                            })?;
652
653                            return Ok(Message::ToolResult {
654                                name: result.id,
655                                content: content.first().text,
656                            });
657                        }
658                        _ => {} // Audio variant removed since Ollama API does not support it.
659                    }
660                }
661                let content_str = texts.join(" ");
662                let images_opt = if images.is_empty() {
663                    None
664                } else {
665                    Some(images)
666                };
667                Ok(Message::User {
668                    content: content_str,
669                    images: images_opt,
670                    name: None,
671                })
672            }
673            InternalMessage::Assistant { content, .. } => {
674                let mut texts = Vec::new();
675                let mut tool_calls = Vec::new();
676                for ac in content.into_iter() {
677                    match ac {
678                        crate::message::AssistantContent::Text(t) => texts.push(t.text),
679                        crate::message::AssistantContent::ToolCall(tc) => {
680                            tool_calls.push(ToolCall {
681                                r#type: ToolType::Function, // Assuming internal tool call provides these fields
682                                function: Function {
683                                    name: tc.function.name,
684                                    arguments: tc.function.arguments,
685                                },
686                            });
687                        }
688                    }
689                }
690                let content_str = texts.join(" ");
691                Ok(Message::Assistant {
692                    content: content_str,
693                    images: None,
694                    name: None,
695                    tool_calls,
696                })
697            }
698        }
699    }
700}
701
702/// Conversion from provider Message to a completion message.
703/// This is needed so that responses can be converted back into chat history.
704impl From<Message> for crate::completion::Message {
705    fn from(msg: Message) -> Self {
706        match msg {
707            Message::User { content, .. } => crate::completion::Message::User {
708                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
709                    text: content,
710                })),
711            },
712            Message::Assistant {
713                content,
714                tool_calls,
715                ..
716            } => {
717                let mut assistant_contents =
718                    vec![crate::completion::message::AssistantContent::Text(Text {
719                        text: content,
720                    })];
721                for tc in tool_calls {
722                    assistant_contents.push(
723                        crate::completion::message::AssistantContent::tool_call(
724                            tc.function.name.clone(),
725                            tc.function.name,
726                            tc.function.arguments,
727                        ),
728                    );
729                }
730                crate::completion::Message::Assistant {
731                    id: None,
732                    content: OneOrMany::many(assistant_contents).unwrap(),
733                }
734            }
735            // System and ToolResult are converted to User message as needed.
736            Message::System { content, .. } => crate::completion::Message::User {
737                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
738                    text: content,
739                })),
740            },
741            Message::ToolResult { name, content } => crate::completion::Message::User {
742                content: OneOrMany::one(message::UserContent::tool_result(
743                    name,
744                    OneOrMany::one(message::ToolResultContent::Text(Text { text: content })),
745                )),
746            },
747        }
748    }
749}
750
751impl Message {
752    /// Constructs a system message.
753    pub fn system(content: &str) -> Self {
754        Message::System {
755            content: content.to_owned(),
756            images: None,
757            name: None,
758        }
759    }
760}
761
762// ---------- Additional Message Types ----------
763
764#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
765pub struct ToolResultContent {
766    text: String,
767}
768
769impl TryFrom<crate::message::ToolResultContent> for ToolResultContent {
770    type Error = MessageError;
771    fn try_from(value: crate::message::ToolResultContent) -> Result<Self, Self::Error> {
772        let crate::message::ToolResultContent::Text(Text { text }) = value else {
773            return Err(MessageError::ConversionError(
774                "Non-text tool results not supported".into(),
775            ));
776        };
777
778        Ok(Self { text })
779    }
780}
781
782impl FromStr for ToolResultContent {
783    type Err = Infallible;
784
785    fn from_str(s: &str) -> Result<Self, Self::Err> {
786        Ok(s.to_owned().into())
787    }
788}
789
790impl From<String> for ToolResultContent {
791    fn from(s: String) -> Self {
792        ToolResultContent { text: s }
793    }
794}
795
796#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
797pub struct SystemContent {
798    #[serde(default)]
799    r#type: SystemContentType,
800    text: String,
801}
802
803#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
804#[serde(rename_all = "lowercase")]
805pub enum SystemContentType {
806    #[default]
807    Text,
808}
809
810impl From<String> for SystemContent {
811    fn from(s: String) -> Self {
812        SystemContent {
813            r#type: SystemContentType::default(),
814            text: s,
815        }
816    }
817}
818
819impl FromStr for SystemContent {
820    type Err = std::convert::Infallible;
821    fn from_str(s: &str) -> Result<Self, Self::Err> {
822        Ok(SystemContent {
823            r#type: SystemContentType::default(),
824            text: s.to_string(),
825        })
826    }
827}
828
829#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
830pub struct AssistantContent {
831    pub text: String,
832}
833
834impl FromStr for AssistantContent {
835    type Err = std::convert::Infallible;
836    fn from_str(s: &str) -> Result<Self, Self::Err> {
837        Ok(AssistantContent { text: s.to_owned() })
838    }
839}
840
841#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
842#[serde(tag = "type", rename_all = "lowercase")]
843pub enum UserContent {
844    Text { text: String },
845    Image { image_url: ImageUrl },
846    // Audio variant removed as Ollama API does not support audio input.
847}
848
849impl FromStr for UserContent {
850    type Err = std::convert::Infallible;
851    fn from_str(s: &str) -> Result<Self, Self::Err> {
852        Ok(UserContent::Text { text: s.to_owned() })
853    }
854}
855
856#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
857pub struct ImageUrl {
858    pub url: String,
859    #[serde(default)]
860    pub detail: ImageDetail,
861}
862
863// =================================================================
864// Tests
865// =================================================================
866
867#[cfg(test)]
868mod tests {
869    use super::*;
870    use serde_json::json;
871
872    // Test deserialization and conversion for the /api/chat endpoint.
873    #[tokio::test]
874    async fn test_chat_completion() {
875        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
876        let sample_chat_response = json!({
877            "model": "llama3.2",
878            "created_at": "2023-08-04T19:22:45.499127Z",
879            "message": {
880                "role": "assistant",
881                "content": "The sky is blue because of Rayleigh scattering.",
882                "images": null,
883                "tool_calls": [
884                    {
885                        "type": "function",
886                        "function": {
887                            "name": "get_current_weather",
888                            "arguments": {
889                                "location": "San Francisco, CA",
890                                "format": "celsius"
891                            }
892                        }
893                    }
894                ]
895            },
896            "done": true,
897            "total_duration": 8000000000u64,
898            "load_duration": 6000000u64,
899            "prompt_eval_count": 61u64,
900            "prompt_eval_duration": 400000000u64,
901            "eval_count": 468u64,
902            "eval_duration": 7700000000u64
903        });
904        let sample_text = sample_chat_response.to_string();
905
906        let chat_resp: CompletionResponse =
907            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
908        let conv: completion::CompletionResponse<CompletionResponse> =
909            chat_resp.try_into().unwrap();
910        assert!(
911            !conv.choice.is_empty(),
912            "Expected non-empty choice in chat response"
913        );
914    }
915
916    // Test conversion from provider Message to completion Message.
917    #[test]
918    fn test_message_conversion() {
919        // Construct a provider Message (User variant with String content).
920        let provider_msg = Message::User {
921            content: "Test message".to_owned(),
922            images: None,
923            name: None,
924        };
925        // Convert it into a completion::Message.
926        let comp_msg: crate::completion::Message = provider_msg.into();
927        match comp_msg {
928            crate::completion::Message::User { content } => {
929                // Assume OneOrMany<T> has a method first() to access the first element.
930                let first_content = content.first();
931                // The expected type is crate::completion::message::UserContent::Text wrapping a Text struct.
932                match first_content {
933                    crate::completion::message::UserContent::Text(text_struct) => {
934                        assert_eq!(text_struct.text, "Test message");
935                    }
936                    _ => panic!("Expected text content in conversion"),
937                }
938            }
939            _ => panic!("Conversion from provider Message to completion Message failed"),
940        }
941    }
942
943    // Test conversion of internal tool definition to Ollama's ToolDefinition format.
944    #[test]
945    fn test_tool_definition_conversion() {
946        // Internal tool definition from the completion module.
947        let internal_tool = crate::completion::ToolDefinition {
948            name: "get_current_weather".to_owned(),
949            description: "Get the current weather for a location".to_owned(),
950            parameters: json!({
951                "type": "object",
952                "properties": {
953                    "location": {
954                        "type": "string",
955                        "description": "The location to get the weather for, e.g. San Francisco, CA"
956                    },
957                    "format": {
958                        "type": "string",
959                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
960                        "enum": ["celsius", "fahrenheit"]
961                    }
962                },
963                "required": ["location", "format"]
964            }),
965        };
966        // Convert internal tool to Ollama's tool definition.
967        let ollama_tool: ToolDefinition = internal_tool.into();
968        assert_eq!(ollama_tool.type_field, "function");
969        assert_eq!(ollama_tool.function.name, "get_current_weather");
970        assert_eq!(
971            ollama_tool.function.description,
972            "Get the current weather for a location"
973        );
974        // Check JSON fields in parameters.
975        let params = &ollama_tool.function.parameters;
976        assert_eq!(params["properties"]["location"]["type"], "string");
977    }
978}