// rig/providers/ollama.rs

//! Ollama API client and Rig integration
//!
//! # Example
//! ```rust
//! use rig::providers::ollama;
//!
//! // Create a new Ollama client (defaults to http://localhost:11434)
//! let client = ollama::Client::new();
//!
//! // Create a completion model interface using, for example, the "llama3.2" model
//! let comp_model = client.completion_model("llama3.2");
//!
//! let req = rig::completion::CompletionRequest {
//!     preamble: Some("You are now a humorous AI assistant.".to_owned()),
//!     chat_history: vec![],  // internal messages (if any)
//!     prompt: rig::message::Message::User {
//!         content: rig::one_or_many::OneOrMany::one(rig::message::UserContent::text("Please tell me why the sky is blue.")),
//!         name: None
//!     },
//!     temperature: 0.7,
//!     additional_params: None,
//!     tools: vec![],
//! };
//!
//! let response = comp_model.completion(req).await.unwrap();
//! println!("Ollama completion response: {:?}", response.choice);
//!
//! // Create an embedding interface using the "all-minilm" model
//! let emb_model = ollama::Client::new().embedding_model("all-minilm");
//! let docs = vec![
//!     "Why is the sky blue?".to_owned(),
//!     "Why is the grass green?".to_owned()
//! ];
//! let embeddings = emb_model.embed_texts(docs).await.unwrap();
//! println!("Embedding response: {:?}", embeddings);
//!
//! // Also create an agent and extractor if needed
//! let agent = client.agent("llama3.2");
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
//! ```
use crate::json_utils::merge_inplace;
use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
use crate::{
    agent::AgentBuilder,
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
    extractor::ExtractorBuilder,
    json_utils, message,
    message::{ImageDetail, Text},
    Embed, OneOrMany,
};
use async_stream::stream;
use futures::StreamExt;
use reqwest;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::convert::Infallible;
use std::{convert::TryFrom, str::FromStr};
60// ---------- Main Client ----------
61
62const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
63
64#[derive(Clone)]
65pub struct Client {
66    base_url: String,
67    http_client: reqwest::Client,
68}
69
70impl Default for Client {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76impl Client {
77    pub fn new() -> Self {
78        Self::from_url(OLLAMA_API_BASE_URL)
79    }
80    pub fn from_url(base_url: &str) -> Self {
81        Self {
82            base_url: base_url.to_owned(),
83            http_client: reqwest::Client::builder()
84                .build()
85                .expect("Ollama reqwest client should build"),
86        }
87    }
88    fn post(&self, path: &str) -> reqwest::RequestBuilder {
89        let url = format!("{}/{}", self.base_url, path);
90        self.http_client.post(url)
91    }
92    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
93        EmbeddingModel::new(self.clone(), model, 0)
94    }
95    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
96        EmbeddingModel::new(self.clone(), model, ndims)
97    }
98    pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
99        EmbeddingsBuilder::new(self.embedding_model(model))
100    }
101    pub fn completion_model(&self, model: &str) -> CompletionModel {
102        CompletionModel::new(self.clone(), model)
103    }
104    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
105        AgentBuilder::new(self.completion_model(model))
106    }
107    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
108        &self,
109        model: &str,
110    ) -> ExtractorBuilder<T, CompletionModel> {
111        ExtractorBuilder::new(self.completion_model(model))
112    }
113}
114
// ---------- API Error and Response Structures ----------

/// Error payload returned by the Ollama API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload of type `T` or an API error.
/// `untagged`: serde tries each variant in order when deserializing.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
// ---------- Embedding API ----------

/// Well-known Ollama embedding model names.
pub const ALL_MINILM: &str = "all-minilm";
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

/// Response body of Ollama's `/api/embed` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    /// One embedding vector per input document, in input order.
    pub embeddings: Vec<Vec<f64>>,
    // Timing/usage fields are optional; Ollama may omit them.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
/// Maps an Ollama API error payload onto the provider error variant.
impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

/// Flattens the untagged API response into a `Result`.
impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}
// ---------- Embedding Model ----------

/// Embedding model handle bound to a [`Client`] and a model name.
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    pub model: String,
    // Advertised embedding dimensionality; 0 when unknown.
    ndims: usize,
}

impl EmbeddingModel {
    /// Creates a handle for `model`; `ndims` is the expected vector size
    /// (pass 0 when unknown).
    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.to_owned(),
            ndims,
        }
    }
}
180impl embeddings::EmbeddingModel for EmbeddingModel {
181    const MAX_DOCUMENTS: usize = 1024;
182    fn ndims(&self) -> usize {
183        self.ndims
184    }
185    #[cfg_attr(feature = "worker", worker::send)]
186    async fn embed_texts(
187        &self,
188        documents: impl IntoIterator<Item = String>,
189    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
190        let docs: Vec<String> = documents.into_iter().collect();
191        let payload = json!({
192            "model": self.model,
193            "input": docs,
194        });
195        let response = self
196            .client
197            .post("api/embed")
198            .json(&payload)
199            .send()
200            .await
201            .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
202        if response.status().is_success() {
203            let api_resp: EmbeddingResponse = response
204                .json()
205                .await
206                .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
207            if api_resp.embeddings.len() != docs.len() {
208                return Err(EmbeddingError::ResponseError(
209                    "Number of returned embeddings does not match input".into(),
210                ));
211            }
212            Ok(api_resp
213                .embeddings
214                .into_iter()
215                .zip(docs.into_iter())
216                .map(|(vec, document)| embeddings::Embedding { document, vec })
217                .collect())
218        } else {
219            Err(EmbeddingError::ProviderError(response.text().await?))
220        }
221    }
222}
223
// ---------- Completion API ----------

/// Well-known Ollama chat model names.
pub const LLAMA3_2: &str = "llama3.2";
pub const LLAVA: &str = "llava";
pub const MISTRAL: &str = "mistral";

/// Response body of Ollama's `/api/chat` endpoint (one object per line when
/// streaming; a single object otherwise).
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    /// The generated message (assistant role on success).
    pub message: Message,
    /// True when generation has finished (always true for non-streaming calls).
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    // Timing/usage statistics; Ollama may omit any of them.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
/// Converts a raw Ollama chat response into Rig's generic completion response.
///
/// Fails when the response does not carry an assistant message, or when the
/// assistant message has neither text content nor tool calls.
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;
    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            // Process only if an assistant message is present.
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                // Add the assistant's text content if any.
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                // Process tool_calls following Ollama's chat response definition.
                // Ollama supplies no tool-call id, so the function name is
                // reused as both the id and name arguments.
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                // `many` rejects an empty vec, which maps to "no content" here.
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                // Rebuild the raw response; the message was moved out above, so
                // it is reconstructed from the parts we still own.
                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };
                Ok(completion::CompletionResponse {
                    choice,
                    raw_response,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}
// ---------- Completion Model ----------

/// Chat-completion model handle bound to a [`Client`] and a model name.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    pub model: String,
}

impl CompletionModel {
    /// Creates a handle for `model` using `client` for transport.
    pub fn new(client: Client, model: &str) -> Self {
        Self {
            client,
            model: model.to_owned(),
        }
    }

    /// Builds the JSON payload for `/api/chat` from a Rig completion request.
    ///
    /// The payload defaults to non-streaming (`"stream": false`); the
    /// streaming path overwrites that flag afterwards.
    fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Convert internal prompt into a provider Message
        let prompt: Message = completion_request.prompt_with_context().try_into()?;
        // `options` always carries the temperature; any extra params are
        // merged on top and may override it.
        let options = if let Some(extra) = completion_request.additional_params {
            json_utils::merge(
                json!({ "temperature": completion_request.temperature }),
                extra,
            )
        } else {
            json!({ "temperature": completion_request.temperature })
        };

        // Chat mode: assemble full conversation history including preamble and chat history
        let mut full_history = Vec::new();
        if let Some(preamble) = completion_request.preamble {
            full_history.push(Message::system(&preamble));
        }
        for msg in completion_request.chat_history.into_iter() {
            full_history.push(Message::try_from(msg)?);
        }
        // The current prompt always goes last.
        full_history.push(prompt);

        let mut request_payload = json!({
            "model": self.model,
            "messages": full_history,
            "options": options,
            "stream": false,
        });
        // Only attach "tools" when present to keep the payload minimal.
        if !completion_request.tools.is_empty() {
            request_payload["tools"] = json!(completion_request
                .tools
                .into_iter()
                .map(|tool| tool.into())
                .collect::<Vec<ToolDefinition>>());
        }

        tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);

        Ok(request_payload)
    }
}
// ---------- CompletionModel Implementation ----------

impl completion::CompletionModel for CompletionModel {
    type Response = CompletionResponse;

    /// Sends a non-streaming chat request to `/api/chat` and converts the
    /// response into Rig's generic completion response.
    ///
    /// # Errors
    /// `ProviderError` on transport, HTTP, or JSON failures;
    /// `ResponseError` (via `try_into`) when the response lacks an
    /// assistant message or has no content.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
        let request_payload = self.create_completion_request(completion_request)?;

        let response = self
            .client
            .post("api/chat")
            .json(&request_payload)
            .send()
            .await
            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
        if response.status().is_success() {
            // Read the body as text first so it can be logged before parsing.
            let text = response
                .text()
                .await
                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
            tracing::debug!(target: "rig", "Ollama chat response: {}", text);
            let chat_resp: CompletionResponse = serde_json::from_str(&text)
                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
            let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
            Ok(conv)
        } else {
            let err_text = response
                .text()
                .await
                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
            Err(CompletionError::ProviderError(err_text))
        }
    }
}
408impl StreamingCompletionModel for CompletionModel {
409    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
410        let mut request_payload = self.create_completion_request(request)?;
411        merge_inplace(&mut request_payload, json!({"stream": true}));
412
413        let response = self
414            .client
415            .post("api/chat")
416            .json(&request_payload)
417            .send()
418            .await
419            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
420
421        if !response.status().is_success() {
422            let err_text = response
423                .text()
424                .await
425                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
426            return Err(CompletionError::ProviderError(err_text));
427        }
428
429        Ok(Box::pin(stream! {
430            let mut stream = response.bytes_stream();
431            while let Some(chunk_result) = stream.next().await {
432                let chunk = match chunk_result {
433                    Ok(c) => c,
434                    Err(e) => {
435                        yield Err(CompletionError::from(e));
436                        break;
437                    }
438                };
439
440                let text = match String::from_utf8(chunk.to_vec()) {
441                    Ok(t) => t,
442                    Err(e) => {
443                        yield Err(CompletionError::ResponseError(e.to_string()));
444                        break;
445                    }
446                };
447
448
449                for line in text.lines() {
450                    let line = line.to_string();
451
452                    let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
453                        continue;
454                    };
455
456                    match response.message {
457                        Message::Assistant{ content, tool_calls, .. } => {
458                            if !content.is_empty() {
459                                yield Ok(StreamingChoice::Message(content))
460                            }
461
462                            for tool_call in tool_calls.iter() {
463                                let function = tool_call.function.clone();
464
465                                yield Ok(StreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
466                            }
467                        }
468                        _ => {
469                            continue;
470                        }
471                    }
472                }
473            }
474        }))
475    }
476}
477
478// ---------- Tool Definition Conversion ----------
479
480/// Ollama-required tool definition format.
481#[derive(Clone, Debug, Deserialize, Serialize)]
482pub struct ToolDefinition {
483    #[serde(rename = "type")]
484    pub type_field: String, // Fixed as "function"
485    pub function: completion::ToolDefinition,
486}
487
488/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
489impl From<crate::completion::ToolDefinition> for ToolDefinition {
490    fn from(tool: crate::completion::ToolDefinition) -> Self {
491        ToolDefinition {
492            type_field: "function".to_owned(),
493            function: completion::ToolDefinition {
494                name: tool.name,
495                description: tool.description,
496                parameters: tool.parameters,
497            },
498        }
499    }
500}
501
/// A tool call as reported by Ollama inside an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    // Ollama does not return a call id, so no `id` field is modeled here.
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}

/// Kind of tool being invoked; Ollama currently only supports functions.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

/// Function name plus its JSON arguments object.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
// ---------- Provider Message Definition ----------

/// A chat message in Ollama's wire format; the `role` JSON field selects the variant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        // Images for multimodal models (e.g. llava); omitted when absent.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        // Defaults to "" when the model returns only tool calls.
        #[serde(default)]
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // `null_or_vec` accepts JSON null as an empty list.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    // NOTE(review): this serializes the role as capitalized "Tool" while all
    // other roles are lowercase; Ollama's API uses lowercase "tool" — confirm
    // whether this rename is intentional.
    #[serde(rename = "Tool")]
    ToolResult {
        tool_call_id: String,
        content: OneOrMany<ToolResultContent>,
    },
}
/// -----------------------------
/// Provider Message Conversions
/// -----------------------------
/// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
/// (Only User and Assistant variants are supported.)
impl TryFrom<crate::message::Message> for Message {
    type Error = crate::message::MessageError;
    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;
        match internal_msg {
            InternalMessage::User { content, .. } => {
                // Split multi-part content into text pieces and image data.
                let mut texts = Vec::new();
                let mut images = Vec::new();
                for uc in content.into_iter() {
                    match uc {
                        crate::message::UserContent::Text(t) => texts.push(t.text),
                        crate::message::UserContent::Image(img) => images.push(img.data),
                        _ => {} // Audio variant removed since Ollama API does not support it.
                    }
                }
                // Ollama takes a single content string, so text parts are joined.
                let content_str = texts.join(" ");
                let images_opt = if images.is_empty() {
                    None
                } else {
                    Some(images)
                };
                Ok(Message::User {
                    content: content_str,
                    images: images_opt,
                    name: None,
                })
            }
            InternalMessage::Assistant { content, .. } => {
                let mut texts = Vec::new();
                let mut tool_calls = Vec::new();
                for ac in content.into_iter() {
                    match ac {
                        crate::message::AssistantContent::Text(t) => texts.push(t.text),
                        crate::message::AssistantContent::ToolCall(tc) => {
                            tool_calls.push(ToolCall {
                                r#type: ToolType::Function, // Assuming internal tool call provides these fields
                                function: Function {
                                    name: tc.function.name,
                                    arguments: tc.function.arguments,
                                },
                            });
                        }
                    }
                }
                let content_str = texts.join(" ");
                Ok(Message::Assistant {
                    content: content_str,
                    images: None,
                    name: None,
                    tool_calls,
                })
            }
        }
    }
}
/// Conversion from provider Message to a completion message.
/// This is needed so that responses can be converted back into chat history.
impl From<Message> for crate::completion::Message {
    fn from(msg: Message) -> Self {
        match msg {
            Message::User { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // The text element is pushed first, so the vec is never empty
                // and the `many(...).unwrap()` below cannot panic.
                let mut assistant_contents =
                    vec![crate::completion::message::AssistantContent::Text(Text {
                        text: content,
                    })];
                // Ollama supplies no tool-call id; the function name fills
                // both the id and name slots.
                for tc in tool_calls {
                    assistant_contents.push(
                        crate::completion::message::AssistantContent::tool_call(
                            tc.function.name.clone(),
                            tc.function.name,
                            tc.function.arguments,
                        ),
                    );
                }
                crate::completion::Message::Assistant {
                    content: OneOrMany::many(assistant_contents).unwrap(),
                }
            }
            // System and ToolResult are converted to User message as needed.
            Message::System { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::ToolResult {
                tool_call_id,
                content,
            } => crate::completion::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    tool_call_id,
                    content.map(|content| message::ToolResultContent::text(content.text)),
                )),
            },
        }
    }
}
669impl Message {
670    /// Constructs a system message.
671    pub fn system(content: &str) -> Self {
672        Message::System {
673            content: content.to_owned(),
674            images: None,
675            name: None,
676        }
677    }
678}
679
680// ---------- Additional Message Types ----------
681
682#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
683pub struct ToolResultContent {
684    text: String,
685}
686
687impl FromStr for ToolResultContent {
688    type Err = Infallible;
689
690    fn from_str(s: &str) -> Result<Self, Self::Err> {
691        Ok(s.to_owned().into())
692    }
693}
694
695impl From<String> for ToolResultContent {
696    fn from(s: String) -> Self {
697        ToolResultContent { text: s }
698    }
699}
700
/// System message content item (Ollama only uses plain text).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}

/// Content-type tag for system content; only `text` exists today.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}

impl From<String> for SystemContent {
    fn from(s: String) -> Self {
        SystemContent {
            r#type: SystemContentType::default(),
            text: s,
        }
    }
}

impl FromStr for SystemContent {
    type Err = std::convert::Infallible;
    // Infallible: any string is valid text content.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(SystemContent {
            r#type: SystemContentType::default(),
            text: s.to_string(),
        })
    }
}
734#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
735pub struct AssistantContent {
736    pub text: String,
737}
738
739impl FromStr for AssistantContent {
740    type Err = std::convert::Infallible;
741    fn from_str(s: &str) -> Result<Self, Self::Err> {
742        Ok(AssistantContent { text: s.to_owned() })
743    }
744}
745
/// User message content item: text or an image reference.
/// (Audio variant removed as Ollama API does not support audio input.)
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
}

impl FromStr for UserContent {
    type Err = std::convert::Infallible;
    // Infallible: any string becomes a text content item.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(UserContent::Text { text: s.to_owned() })
    }
}

/// Image reference with an optional detail hint.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}
// =================================================================
// Tests
// =================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Test deserialization and conversion for the /api/chat endpoint.
    #[tokio::test]
    async fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        // Round-trip: raw JSON -> CompletionResponse -> generic response.
        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    // Test conversion from provider Message to completion Message.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // OneOrMany::first() accesses the first element.
                let first_content = content.first();
                // The expected type is crate::completion::message::UserContent::Text wrapping a Text struct.
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    // Test conversion of internal tool definition to Ollama's ToolDefinition format.
    #[test]
    fn test_tool_definition_conversion() {
        // Internal tool definition from the completion module.
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        // Convert internal tool to Ollama's tool definition.
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        // Check JSON fields in parameters.
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }
}