rig/providers/
ollama.rs

//! Ollama API client and Rig integration
//!
//! # Example
//! ```rust
//! use rig::providers::ollama;
//!
//! // Create a new Ollama client (defaults to http://localhost:11434)
//! let client = ollama::Client::new();
//!
//! // Create a completion model interface using, for example, the "llama3.2" model
//! let comp_model = client.completion_model("llama3.2");
//!
//! let req = rig::completion::CompletionRequest {
//!     preamble: Some("You are now a humorous AI assistant.".to_owned()),
//!     chat_history: vec![],  // internal messages (if any)
//!     prompt: rig::message::Message::User {
//!         content: rig::one_or_many::OneOrMany::one(rig::message::UserContent::text("Please tell me why the sky is blue.")),
//!     },
//!     temperature: 0.7,
//!     additional_params: None,
//!     tools: vec![],
//! };
//!
//! let response = comp_model.completion(req).await.unwrap();
//! println!("Ollama completion response: {:?}", response.choice);
//!
//! // Create an embedding interface using the "all-minilm" model
//! let emb_model = ollama::Client::new().embedding_model("all-minilm");
//! let docs = vec![
//!     "Why is the sky blue?".to_owned(),
//!     "Why is the grass green?".to_owned()
//! ];
//! let embeddings = emb_model.embed_texts(docs).await.unwrap();
//! println!("Embedding response: {:?}", embeddings);
//!
//! // Also create an agent and extractor if needed
//! let agent = client.agent("llama3.2");
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
//! ```
41use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient};
42use crate::json_utils::merge_inplace;
43use crate::message::MessageError;
44use crate::streaming::RawStreamingChoice;
45use crate::{
46    completion::{self, CompletionError, CompletionRequest},
47    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
48    impl_conversion_traits, json_utils, message,
49    message::{ImageDetail, Text},
50    streaming, Embed, OneOrMany,
51};
52use async_stream::stream;
53use futures::StreamExt;
54use reqwest;
55use serde::{Deserialize, Serialize};
56use serde_json::{json, Value};
57use std::convert::Infallible;
58use std::{convert::TryFrom, str::FromStr};
59// ---------- Main Client ----------
60
61const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
62
63#[derive(Clone, Debug)]
64pub struct Client {
65    base_url: String,
66    http_client: reqwest::Client,
67}
68
69impl Default for Client {
70    fn default() -> Self {
71        Self::new()
72    }
73}
74
75impl Client {
76    pub fn new() -> Self {
77        Self::from_url(OLLAMA_API_BASE_URL)
78    }
79    pub fn from_url(base_url: &str) -> Self {
80        Self {
81            base_url: base_url.to_owned(),
82            http_client: reqwest::Client::builder()
83                .build()
84                .expect("Ollama reqwest client should build"),
85        }
86    }
87    fn post(&self, path: &str) -> reqwest::RequestBuilder {
88        let url = format!("{}/{}", self.base_url, path);
89        self.http_client.post(url)
90    }
91}
92
93impl ProviderClient for Client {
94    fn from_env() -> Self
95    where
96        Self: Sized,
97    {
98        Client::default()
99    }
100}
101
102impl CompletionClient for Client {
103    type CompletionModel = CompletionModel;
104
105    fn completion_model(&self, model: &str) -> CompletionModel {
106        CompletionModel::new(self.clone(), model)
107    }
108}
109
110impl EmbeddingsClient for Client {
111    type EmbeddingModel = EmbeddingModel;
112    fn embedding_model(&self, model: &str) -> EmbeddingModel {
113        EmbeddingModel::new(self.clone(), model, 0)
114    }
115    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
116        EmbeddingModel::new(self.clone(), model, ndims)
117    }
118    fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
119        EmbeddingsBuilder::new(self.embedding_model(model))
120    }
121}
122
123impl_conversion_traits!(
124    AsTranscription,
125    AsImageGeneration,
126    AsAudioGeneration for Client
127);
128
129// ---------- API Error and Response Structures ----------
130
131#[derive(Debug, Deserialize)]
132struct ApiErrorResponse {
133    message: String,
134}
135
136#[derive(Debug, Deserialize)]
137#[serde(untagged)]
138enum ApiResponse<T> {
139    Ok(T),
140    Err(ApiErrorResponse),
141}
142
143// ---------- Embedding API ----------
144
145pub const ALL_MINILM: &str = "all-minilm";
146pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
147
148#[derive(Debug, Serialize, Deserialize)]
149pub struct EmbeddingResponse {
150    pub model: String,
151    pub embeddings: Vec<Vec<f64>>,
152    #[serde(default)]
153    pub total_duration: Option<u64>,
154    #[serde(default)]
155    pub load_duration: Option<u64>,
156    #[serde(default)]
157    pub prompt_eval_count: Option<u64>,
158}
159
160impl From<ApiErrorResponse> for EmbeddingError {
161    fn from(err: ApiErrorResponse) -> Self {
162        EmbeddingError::ProviderError(err.message)
163    }
164}
165
166impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
167    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
168        match value {
169            ApiResponse::Ok(response) => Ok(response),
170            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
171        }
172    }
173}
174
175// ---------- Embedding Model ----------
176
177#[derive(Clone)]
178pub struct EmbeddingModel {
179    client: Client,
180    pub model: String,
181    ndims: usize,
182}
183
184impl EmbeddingModel {
185    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
186        Self {
187            client,
188            model: model.to_owned(),
189            ndims,
190        }
191    }
192}
193
194impl embeddings::EmbeddingModel for EmbeddingModel {
195    const MAX_DOCUMENTS: usize = 1024;
196    fn ndims(&self) -> usize {
197        self.ndims
198    }
199    #[cfg_attr(feature = "worker", worker::send)]
200    async fn embed_texts(
201        &self,
202        documents: impl IntoIterator<Item = String>,
203    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
204        let docs: Vec<String> = documents.into_iter().collect();
205        let payload = json!({
206            "model": self.model,
207            "input": docs,
208        });
209        let response = self
210            .client
211            .post("api/embed")
212            .json(&payload)
213            .send()
214            .await
215            .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
216        if response.status().is_success() {
217            let api_resp: EmbeddingResponse = response
218                .json()
219                .await
220                .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
221            if api_resp.embeddings.len() != docs.len() {
222                return Err(EmbeddingError::ResponseError(
223                    "Number of returned embeddings does not match input".into(),
224                ));
225            }
226            Ok(api_resp
227                .embeddings
228                .into_iter()
229                .zip(docs.into_iter())
230                .map(|(vec, document)| embeddings::Embedding { document, vec })
231                .collect())
232        } else {
233            Err(EmbeddingError::ProviderError(response.text().await?))
234        }
235    }
236}
237
238// ---------- Completion API ----------
239
240pub const LLAMA3_2: &str = "llama3.2";
241pub const LLAVA: &str = "llava";
242pub const MISTRAL: &str = "mistral";
243
244#[derive(Debug, Serialize, Deserialize)]
245pub struct CompletionResponse {
246    pub model: String,
247    pub created_at: String,
248    pub message: Message,
249    pub done: bool,
250    #[serde(default)]
251    pub done_reason: Option<String>,
252    #[serde(default)]
253    pub total_duration: Option<u64>,
254    #[serde(default)]
255    pub load_duration: Option<u64>,
256    #[serde(default)]
257    pub prompt_eval_count: Option<u64>,
258    #[serde(default)]
259    pub prompt_eval_duration: Option<u64>,
260    #[serde(default)]
261    pub eval_count: Option<u64>,
262    #[serde(default)]
263    pub eval_duration: Option<u64>,
264}
265impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
266    type Error = CompletionError;
267    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
268        match resp.message {
269            // Process only if an assistant message is present.
270            Message::Assistant {
271                content,
272                tool_calls,
273                ..
274            } => {
275                let mut assistant_contents = Vec::new();
276                // Add the assistant's text content if any.
277                if !content.is_empty() {
278                    assistant_contents.push(completion::AssistantContent::text(&content));
279                }
280                // Process tool_calls following Ollama's chat response definition.
281                // Each ToolCall has an id, a type, and a function field.
282                for tc in tool_calls.iter() {
283                    assistant_contents.push(completion::AssistantContent::tool_call(
284                        tc.function.name.clone(),
285                        tc.function.name.clone(),
286                        tc.function.arguments.clone(),
287                    ));
288                }
289                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
290                    CompletionError::ResponseError("No content provided".to_owned())
291                })?;
292                let raw_response = CompletionResponse {
293                    model: resp.model,
294                    created_at: resp.created_at,
295                    done: resp.done,
296                    done_reason: resp.done_reason,
297                    total_duration: resp.total_duration,
298                    load_duration: resp.load_duration,
299                    prompt_eval_count: resp.prompt_eval_count,
300                    prompt_eval_duration: resp.prompt_eval_duration,
301                    eval_count: resp.eval_count,
302                    eval_duration: resp.eval_duration,
303                    message: Message::Assistant {
304                        content,
305                        images: None,
306                        name: None,
307                        tool_calls,
308                    },
309                };
310                Ok(completion::CompletionResponse {
311                    choice,
312                    raw_response,
313                })
314            }
315            _ => Err(CompletionError::ResponseError(
316                "Chat response does not include an assistant message".into(),
317            )),
318        }
319    }
320}
321
322// ---------- Completion Model ----------
323
324#[derive(Clone)]
325pub struct CompletionModel {
326    client: Client,
327    pub model: String,
328}
329
330impl CompletionModel {
331    pub fn new(client: Client, model: &str) -> Self {
332        Self {
333            client,
334            model: model.to_owned(),
335        }
336    }
337
338    fn create_completion_request(
339        &self,
340        completion_request: CompletionRequest,
341    ) -> Result<Value, CompletionError> {
342        // Build up the order of messages (context, chat_history)
343        let mut partial_history = vec![];
344        if let Some(docs) = completion_request.normalized_documents() {
345            partial_history.push(docs);
346        }
347        partial_history.extend(completion_request.chat_history);
348
349        // Initialize full history with preamble (or empty if non-existent)
350        let mut full_history: Vec<Message> = completion_request
351            .preamble
352            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
353
354        // Convert and extend the rest of the history
355        full_history.extend(
356            partial_history
357                .into_iter()
358                .map(|msg| msg.try_into())
359                .collect::<Result<Vec<Message>, _>>()?,
360        );
361
362        // Convert internal prompt into a provider Message
363        let options = if let Some(extra) = completion_request.additional_params {
364            json_utils::merge(
365                json!({ "temperature": completion_request.temperature }),
366                extra,
367            )
368        } else {
369            json!({ "temperature": completion_request.temperature })
370        };
371
372        let mut request_payload = json!({
373            "model": self.model,
374            "messages": full_history,
375            "options": options,
376            "stream": false,
377        });
378        if !completion_request.tools.is_empty() {
379            request_payload["tools"] = json!(completion_request
380                .tools
381                .into_iter()
382                .map(|tool| tool.into())
383                .collect::<Vec<ToolDefinition>>());
384        }
385
386        tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
387
388        Ok(request_payload)
389    }
390}
391
// ---------- CompletionModel Implementation ----------

/// Metadata yielded as the final item of a streamed chat completion.
///
/// The values are copied from the stream object that has `done: true`. `Debug` and
/// `Default` are derived so the type can be logged and built in tests like other public types.
#[derive(Clone, Debug, Default)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
404impl completion::CompletionModel for CompletionModel {
405    type Response = CompletionResponse;
406    type StreamingResponse = StreamingCompletionResponse;
407
408    #[cfg_attr(feature = "worker", worker::send)]
409    async fn completion(
410        &self,
411        completion_request: CompletionRequest,
412    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
413        let request_payload = self.create_completion_request(completion_request)?;
414
415        let response = self
416            .client
417            .post("api/chat")
418            .json(&request_payload)
419            .send()
420            .await
421            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
422        if response.status().is_success() {
423            let text = response
424                .text()
425                .await
426                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
427            tracing::debug!(target: "rig", "Ollama chat response: {}", text);
428            let chat_resp: CompletionResponse = serde_json::from_str(&text)
429                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
430            let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
431            Ok(conv)
432        } else {
433            let err_text = response
434                .text()
435                .await
436                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
437            Err(CompletionError::ProviderError(err_text))
438        }
439    }
440
441    #[cfg_attr(feature = "worker", worker::send)]
442    async fn stream(
443        &self,
444        request: CompletionRequest,
445    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
446    {
447        let mut request_payload = self.create_completion_request(request)?;
448        merge_inplace(&mut request_payload, json!({"stream": true}));
449
450        let response = self
451            .client
452            .post("api/chat")
453            .json(&request_payload)
454            .send()
455            .await
456            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
457
458        if !response.status().is_success() {
459            let err_text = response
460                .text()
461                .await
462                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
463            return Err(CompletionError::ProviderError(err_text));
464        }
465
466        let stream = Box::pin(stream! {
467            let mut stream = response.bytes_stream();
468            while let Some(chunk_result) = stream.next().await {
469                let chunk = match chunk_result {
470                    Ok(c) => c,
471                    Err(e) => {
472                        yield Err(CompletionError::from(e));
473                        break;
474                    }
475                };
476
477                let text = match String::from_utf8(chunk.to_vec()) {
478                    Ok(t) => t,
479                    Err(e) => {
480                        yield Err(CompletionError::ResponseError(e.to_string()));
481                        break;
482                    }
483                };
484
485
486                for line in text.lines() {
487                    let line = line.to_string();
488
489                    let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
490                        continue;
491                    };
492
493                    match response.message {
494                        Message::Assistant{ content, tool_calls, .. } => {
495                            if !content.is_empty() {
496                                yield Ok(RawStreamingChoice::Message(content))
497                            }
498
499                            for tool_call in tool_calls.iter() {
500                                let function = tool_call.function.clone();
501
502                                yield Ok(RawStreamingChoice::ToolCall {
503                                    id: "".to_string(),
504                                    name: function.name,
505                                    arguments: function.arguments
506                                });
507                            }
508                        }
509                        _ => {
510                            continue;
511                        }
512                    }
513
514                    if response.done {
515                        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
516                            total_duration: response.total_duration,
517                            load_duration: response.load_duration,
518                            prompt_eval_count: response.prompt_eval_count,
519                            prompt_eval_duration: response.prompt_eval_duration,
520                            eval_count: response.eval_count,
521                            eval_duration: response.eval_duration,
522                            done_reason: response.done_reason,
523                        }));
524                    }
525                }
526            }
527        });
528
529        Ok(streaming::StreamingCompletionResponse::stream(stream))
530    }
531}
532
533// ---------- Tool Definition Conversion ----------
534
535/// Ollama-required tool definition format.
536#[derive(Clone, Debug, Deserialize, Serialize)]
537pub struct ToolDefinition {
538    #[serde(rename = "type")]
539    pub type_field: String, // Fixed as "function"
540    pub function: completion::ToolDefinition,
541}
542
543/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
544impl From<crate::completion::ToolDefinition> for ToolDefinition {
545    fn from(tool: crate::completion::ToolDefinition) -> Self {
546        ToolDefinition {
547            type_field: "function".to_owned(),
548            function: completion::ToolDefinition {
549                name: tool.name,
550                description: tool.description,
551                parameters: tool.parameters,
552            },
553        }
554    }
555}
556
557#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
558pub struct ToolCall {
559    // pub id: String,
560    #[serde(default, rename = "type")]
561    pub r#type: ToolType,
562    pub function: Function,
563}
564#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
565#[serde(rename_all = "lowercase")]
566pub enum ToolType {
567    #[default]
568    Function,
569}
570#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
571pub struct Function {
572    pub name: String,
573    pub arguments: Value,
574}
575
576// ---------- Provider Message Definition ----------
577
578#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
579#[serde(tag = "role", rename_all = "lowercase")]
580pub enum Message {
581    User {
582        content: String,
583        #[serde(skip_serializing_if = "Option::is_none")]
584        images: Option<Vec<String>>,
585        #[serde(skip_serializing_if = "Option::is_none")]
586        name: Option<String>,
587    },
588    Assistant {
589        #[serde(default)]
590        content: String,
591        #[serde(skip_serializing_if = "Option::is_none")]
592        images: Option<Vec<String>>,
593        #[serde(skip_serializing_if = "Option::is_none")]
594        name: Option<String>,
595        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
596        tool_calls: Vec<ToolCall>,
597    },
598    System {
599        content: String,
600        #[serde(skip_serializing_if = "Option::is_none")]
601        images: Option<Vec<String>>,
602        #[serde(skip_serializing_if = "Option::is_none")]
603        name: Option<String>,
604    },
605    #[serde(rename = "Tool")]
606    ToolResult { name: String, content: String },
607}
608
609/// -----------------------------
610/// Provider Message Conversions
611/// -----------------------------
612/// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
613/// (Only User and Assistant variants are supported.)
614impl TryFrom<crate::message::Message> for Message {
615    type Error = crate::message::MessageError;
616    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
617        use crate::message::Message as InternalMessage;
618        match internal_msg {
619            InternalMessage::User { content, .. } => {
620                let mut texts = Vec::new();
621                let mut images = Vec::new();
622                for uc in content.into_iter() {
623                    match uc {
624                        crate::message::UserContent::Text(t) => texts.push(t.text),
625                        crate::message::UserContent::Image(img) => images.push(img.data),
626                        crate::message::UserContent::ToolResult(result) => {
627                            let content = result
628                                .content
629                                .into_iter()
630                                .map(ToolResultContent::try_from)
631                                .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
632
633                            let content = OneOrMany::many(content).map_err(|x| {
634                                MessageError::ConversionError(format!(
635                                    "Couldn't make a OneOrMany from a list of tool results: {x}"
636                                ))
637                            })?;
638
639                            return Ok(Message::ToolResult {
640                                name: result.id,
641                                content: content.first().text,
642                            });
643                        }
644                        _ => {} // Audio variant removed since Ollama API does not support it.
645                    }
646                }
647                let content_str = texts.join(" ");
648                let images_opt = if images.is_empty() {
649                    None
650                } else {
651                    Some(images)
652                };
653                Ok(Message::User {
654                    content: content_str,
655                    images: images_opt,
656                    name: None,
657                })
658            }
659            InternalMessage::Assistant { content, .. } => {
660                let mut texts = Vec::new();
661                let mut tool_calls = Vec::new();
662                for ac in content.into_iter() {
663                    match ac {
664                        crate::message::AssistantContent::Text(t) => texts.push(t.text),
665                        crate::message::AssistantContent::ToolCall(tc) => {
666                            tool_calls.push(ToolCall {
667                                r#type: ToolType::Function, // Assuming internal tool call provides these fields
668                                function: Function {
669                                    name: tc.function.name,
670                                    arguments: tc.function.arguments,
671                                },
672                            });
673                        }
674                    }
675                }
676                let content_str = texts.join(" ");
677                Ok(Message::Assistant {
678                    content: content_str,
679                    images: None,
680                    name: None,
681                    tool_calls,
682                })
683            }
684        }
685    }
686}
687
688/// Conversion from provider Message to a completion message.
689/// This is needed so that responses can be converted back into chat history.
690impl From<Message> for crate::completion::Message {
691    fn from(msg: Message) -> Self {
692        match msg {
693            Message::User { content, .. } => crate::completion::Message::User {
694                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
695                    text: content,
696                })),
697            },
698            Message::Assistant {
699                content,
700                tool_calls,
701                ..
702            } => {
703                let mut assistant_contents =
704                    vec![crate::completion::message::AssistantContent::Text(Text {
705                        text: content,
706                    })];
707                for tc in tool_calls {
708                    assistant_contents.push(
709                        crate::completion::message::AssistantContent::tool_call(
710                            tc.function.name.clone(),
711                            tc.function.name,
712                            tc.function.arguments,
713                        ),
714                    );
715                }
716                crate::completion::Message::Assistant {
717                    content: OneOrMany::many(assistant_contents).unwrap(),
718                }
719            }
720            // System and ToolResult are converted to User message as needed.
721            Message::System { content, .. } => crate::completion::Message::User {
722                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
723                    text: content,
724                })),
725            },
726            Message::ToolResult { name, content } => crate::completion::Message::User {
727                content: OneOrMany::one(message::UserContent::tool_result(
728                    name,
729                    OneOrMany::one(message::ToolResultContent::Text(Text { text: content })),
730                )),
731            },
732        }
733    }
734}
735
736impl Message {
737    /// Constructs a system message.
738    pub fn system(content: &str) -> Self {
739        Message::System {
740            content: content.to_owned(),
741            images: None,
742            name: None,
743        }
744    }
745}
746
747// ---------- Additional Message Types ----------
748
749#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
750pub struct ToolResultContent {
751    text: String,
752}
753
754impl TryFrom<crate::message::ToolResultContent> for ToolResultContent {
755    type Error = MessageError;
756    fn try_from(value: crate::message::ToolResultContent) -> Result<Self, Self::Error> {
757        let crate::message::ToolResultContent::Text(Text { text }) = value else {
758            return Err(MessageError::ConversionError(
759                "Non-text tool results not supported".into(),
760            ));
761        };
762
763        Ok(Self { text })
764    }
765}
766
767impl FromStr for ToolResultContent {
768    type Err = Infallible;
769
770    fn from_str(s: &str) -> Result<Self, Self::Err> {
771        Ok(s.to_owned().into())
772    }
773}
774
775impl From<String> for ToolResultContent {
776    fn from(s: String) -> Self {
777        ToolResultContent { text: s }
778    }
779}
780
781#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
782pub struct SystemContent {
783    #[serde(default)]
784    r#type: SystemContentType,
785    text: String,
786}
787
788#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
789#[serde(rename_all = "lowercase")]
790pub enum SystemContentType {
791    #[default]
792    Text,
793}
794
795impl From<String> for SystemContent {
796    fn from(s: String) -> Self {
797        SystemContent {
798            r#type: SystemContentType::default(),
799            text: s,
800        }
801    }
802}
803
804impl FromStr for SystemContent {
805    type Err = std::convert::Infallible;
806    fn from_str(s: &str) -> Result<Self, Self::Err> {
807        Ok(SystemContent {
808            r#type: SystemContentType::default(),
809            text: s.to_string(),
810        })
811    }
812}
813
814#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
815pub struct AssistantContent {
816    pub text: String,
817}
818
819impl FromStr for AssistantContent {
820    type Err = std::convert::Infallible;
821    fn from_str(s: &str) -> Result<Self, Self::Err> {
822        Ok(AssistantContent { text: s.to_owned() })
823    }
824}
825
826#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
827#[serde(tag = "type", rename_all = "lowercase")]
828pub enum UserContent {
829    Text { text: String },
830    Image { image_url: ImageUrl },
831    // Audio variant removed as Ollama API does not support audio input.
832}
833
834impl FromStr for UserContent {
835    type Err = std::convert::Infallible;
836    fn from_str(s: &str) -> Result<Self, Self::Err> {
837        Ok(UserContent::Text { text: s.to_owned() })
838    }
839}
840
841#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
842pub struct ImageUrl {
843    pub url: String,
844    #[serde(default)]
845    pub detail: ImageDetail,
846}
847
848// =================================================================
849// Tests
850// =================================================================
851
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes a non-streaming `/api/chat` response and converts it into
    /// rig's generic `completion::CompletionResponse`.
    ///
    /// Nothing here is awaited (the conversion is a synchronous `TryFrom`), so a
    /// plain `#[test]` is used instead of spinning up a tokio runtime.
    #[test]
    fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        // Round-trip through a string so the textual parsing path is exercised.
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    /// Converts a provider `Message::User` into a `completion::Message` and
    /// checks that the text survives the conversion.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // A single text input should land in the first content slot.
                let first_content = content.first();
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    /// Converts rig's internal tool definition into Ollama's `ToolDefinition`.
    #[test]
    fn test_tool_definition_conversion() {
        // Internal tool definition from the completion module.
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        // Convert internal tool to Ollama's tool definition.
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        // Check JSON fields in parameters.
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }

    /// The `FromStr` impls on the provider content types wrap the input as
    /// plain text and never fail.
    #[test]
    fn test_content_from_str() {
        let assistant: AssistantContent = "Hello".parse().unwrap();
        assert_eq!(assistant.text, "Hello");

        let user: UserContent = "Hello".parse().unwrap();
        assert_eq!(
            user,
            UserContent::Text {
                text: "Hello".to_owned()
            }
        );
    }

    /// `UserContent` is internally tagged by a lowercase `type` field, and an
    /// image's `detail` may be omitted thanks to `#[serde(default)]`.
    #[test]
    fn test_user_content_serde_tagging() {
        let text = UserContent::Text {
            text: "hi".to_owned(),
        };
        assert_eq!(
            serde_json::to_value(&text).unwrap(),
            json!({ "type": "text", "text": "hi" })
        );

        let image: UserContent = serde_json::from_value(json!({
            "type": "image",
            "image_url": { "url": "https://example.com/cat.png" }
        }))
        .expect("image content should deserialize without an explicit detail");
        match image {
            UserContent::Image { image_url } => {
                assert_eq!(image_url.url, "https://example.com/cat.png");
            }
            other => panic!("Expected image content, got {other:?}"),
        }
    }
}