rig/providers/
ollama.rs

//! Ollama API client and Rig integration
//!
//! # Example
//! ```rust
//! use rig::providers::ollama;
//!
//! // Create a new Ollama client (defaults to http://localhost:11434)
//! let client = ollama::Client::new();
//!
//! // Create a completion model interface using, for example, the "llama3.2" model
//! let comp_model = client.completion_model("llama3.2");
//!
//! let req = rig::completion::CompletionRequest {
//!     preamble: Some("You are now a humorous AI assistant.".to_owned()),
//!     chat_history: vec![],  // internal messages (if any)
//!     prompt: rig::message::Message::User {
//!         content: rig::one_or_many::OneOrMany::one(rig::message::UserContent::text("Please tell me why the sky is blue.")),
//!         name: None
//!     },
//!     temperature: 0.7,
//!     additional_params: None,
//!     tools: vec![],
//! };
//!
//! let response = comp_model.completion(req).await.unwrap();
//! println!("Ollama completion response: {:?}", response.choice);
//!
//! // Create an embedding interface using the "all-minilm" model
//! let emb_model = ollama::Client::new().embedding_model("all-minilm");
//! let docs = vec![
//!     "Why is the sky blue?".to_owned(),
//!     "Why is the grass green?".to_owned()
//! ];
//! let embeddings = emb_model.embed_texts(docs).await.unwrap();
//! println!("Embedding response: {:?}", embeddings);
//!
//! // Also create an agent and extractor if needed
//! let agent = client.agent("llama3.2");
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
//! ```
41use crate::json_utils::merge_inplace;
42use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
43use crate::{
44    agent::AgentBuilder,
45    completion::{self, CompletionError, CompletionRequest},
46    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
47    extractor::ExtractorBuilder,
48    json_utils, message,
49    message::{ImageDetail, Text},
50    Embed, OneOrMany,
51};
52use async_stream::stream;
53use futures::StreamExt;
54use reqwest;
55use schemars::JsonSchema;
56use serde::{Deserialize, Serialize};
57use serde_json::{json, Value};
58use std::convert::Infallible;
59use std::{convert::TryFrom, str::FromStr};
60// ---------- Main Client ----------
61
62const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
63
/// HTTP client handle for an Ollama server.
///
/// Cloneable; clones share the same underlying `reqwest::Client`.
#[derive(Clone)]
pub struct Client {
    // Base URL of the Ollama server, e.g. "http://localhost:11434".
    base_url: String,
    http_client: reqwest::Client,
}
69
impl Default for Client {
    /// Equivalent to [`Client::new`]: a client for the default local server.
    fn default() -> Self {
        Self::new()
    }
}
75
76impl Client {
77    pub fn new() -> Self {
78        Self::from_url(OLLAMA_API_BASE_URL)
79    }
80    pub fn from_url(base_url: &str) -> Self {
81        Self {
82            base_url: base_url.to_owned(),
83            http_client: reqwest::Client::builder()
84                .build()
85                .expect("Ollama reqwest client should build"),
86        }
87    }
88    fn post(&self, path: &str) -> reqwest::RequestBuilder {
89        let url = format!("{}/{}", self.base_url, path);
90        self.http_client.post(url)
91    }
92    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
93        EmbeddingModel::new(self.clone(), model, 0)
94    }
95    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
96        EmbeddingModel::new(self.clone(), model, ndims)
97    }
98    pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
99        EmbeddingsBuilder::new(self.embedding_model(model))
100    }
101    pub fn completion_model(&self, model: &str) -> CompletionModel {
102        CompletionModel::new(self.clone(), model)
103    }
104    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
105        AgentBuilder::new(self.completion_model(model))
106    }
107    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
108        &self,
109        model: &str,
110    ) -> ExtractorBuilder<T, CompletionModel> {
111        ExtractorBuilder::new(self.completion_model(model))
112    }
113}
114
115// ---------- API Error and Response Structures ----------
116
/// Error body returned by the Ollama API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
121
/// Untagged wrapper: deserialization first tries the success payload `T`,
/// then falls back to the error body.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
128
129// ---------- Embedding API ----------
130
/// `all-minilm` embedding model.
pub const ALL_MINILM: &str = "all-minilm";
/// `nomic-embed-text` embedding model.
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
133
/// Response body of Ollama's `/api/embed` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    /// One vector per input document; assumed to be in input order
    /// (the conversion below zips these with the inputs).
    pub embeddings: Vec<Vec<f64>>,
    // Timing/usage metadata; may be absent, hence the defaults.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
145
146impl From<ApiErrorResponse> for EmbeddingError {
147    fn from(err: ApiErrorResponse) -> Self {
148        EmbeddingError::ProviderError(err.message)
149    }
150}
151
152impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
153    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
154        match value {
155            ApiResponse::Ok(response) => Ok(response),
156            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
157        }
158    }
159}
160
161// ---------- Embedding Model ----------
162
/// Rig embedding model backed by an Ollama model.
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    pub model: String,
    // Advertised embedding dimensionality; 0 when unspecified.
    ndims: usize,
}
169
170impl EmbeddingModel {
171    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
172        Self {
173            client,
174            model: model.to_owned(),
175            ndims,
176        }
177    }
178}
179
180impl embeddings::EmbeddingModel for EmbeddingModel {
181    const MAX_DOCUMENTS: usize = 1024;
182    fn ndims(&self) -> usize {
183        self.ndims
184    }
185    #[cfg_attr(feature = "worker", worker::send)]
186    async fn embed_texts(
187        &self,
188        documents: impl IntoIterator<Item = String>,
189    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
190        let docs: Vec<String> = documents.into_iter().collect();
191        let payload = json!({
192            "model": self.model,
193            "input": docs,
194        });
195        let response = self
196            .client
197            .post("api/embed")
198            .json(&payload)
199            .send()
200            .await
201            .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
202        if response.status().is_success() {
203            let api_resp: EmbeddingResponse = response
204                .json()
205                .await
206                .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
207            if api_resp.embeddings.len() != docs.len() {
208                return Err(EmbeddingError::ResponseError(
209                    "Number of returned embeddings does not match input".into(),
210                ));
211            }
212            Ok(api_resp
213                .embeddings
214                .into_iter()
215                .zip(docs.into_iter())
216                .map(|(vec, document)| embeddings::Embedding { document, vec })
217                .collect())
218        } else {
219            Err(EmbeddingError::ProviderError(response.text().await?))
220        }
221    }
222}
223
224// ---------- Completion API ----------
225
/// `llama3.2` chat model.
pub const LLAMA3_2: &str = "llama3.2";
/// `llava` multimodal model.
pub const LLAVA: &str = "llava";
/// `mistral` chat model.
pub const MISTRAL: &str = "mistral";
229
/// Response body of Ollama's `/api/chat` endpoint.
///
/// In streaming mode each NDJSON line deserializes into one of these;
/// in non-streaming mode the whole body is a single instance.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    /// True on the final (or only) response of a generation.
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    // Timing/usage metadata; may be absent, hence the defaults.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
/// Converts a raw Ollama chat response into Rig's generic completion response.
///
/// Only responses carrying an assistant message can be converted; anything
/// else yields a `ResponseError`.
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;
    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            // Process only if an assistant message is present.
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                // Add the assistant's text content if any.
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                // Ollama tool calls carry no call id, so the function name is
                // used for both the id and the name arguments here.
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                // `many` rejects an empty vec, i.e. an assistant message with
                // neither text nor tool calls.
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                // Rebuild the raw response since `resp.message` was moved out
                // above; images/name are not preserved by this round trip.
                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };
                Ok(completion::CompletionResponse {
                    choice,
                    raw_response,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}
307
308// ---------- Completion Model ----------
309
/// Rig completion model backed by an Ollama chat model.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    pub model: String,
}
315
316impl CompletionModel {
317    pub fn new(client: Client, model: &str) -> Self {
318        Self {
319            client,
320            model: model.to_owned(),
321        }
322    }
323
324    fn create_completion_request(
325        &self,
326        completion_request: CompletionRequest,
327    ) -> Result<Value, CompletionError> {
328        // Build up the order of messages (context, chat_history)
329        let mut partial_history = vec![];
330        if let Some(docs) = completion_request.normalized_documents() {
331            partial_history.push(docs);
332        }
333        partial_history.extend(completion_request.chat_history);
334
335        // Initialize full history with preamble (or empty if non-existent)
336        let mut full_history: Vec<Message> = completion_request
337            .preamble
338            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
339
340        // Convert and extend the rest of the history
341        full_history.extend(
342            partial_history
343                .into_iter()
344                .map(|msg| msg.try_into())
345                .collect::<Result<Vec<Message>, _>>()?,
346        );
347
348        // Convert internal prompt into a provider Message
349        let options = if let Some(extra) = completion_request.additional_params {
350            json_utils::merge(
351                json!({ "temperature": completion_request.temperature }),
352                extra,
353            )
354        } else {
355            json!({ "temperature": completion_request.temperature })
356        };
357
358        let mut request_payload = json!({
359            "model": self.model,
360            "messages": full_history,
361            "options": options,
362            "stream": false,
363        });
364        if !completion_request.tools.is_empty() {
365            request_payload["tools"] = json!(completion_request
366                .tools
367                .into_iter()
368                .map(|tool| tool.into())
369                .collect::<Vec<ToolDefinition>>());
370        }
371
372        tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
373
374        Ok(request_payload)
375    }
376}
377
378// ---------- CompletionModel Implementation ----------
379
380impl completion::CompletionModel for CompletionModel {
381    type Response = CompletionResponse;
382
383    #[cfg_attr(feature = "worker", worker::send)]
384    async fn completion(
385        &self,
386        completion_request: CompletionRequest,
387    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
388        let request_payload = self.create_completion_request(completion_request)?;
389
390        let response = self
391            .client
392            .post("api/chat")
393            .json(&request_payload)
394            .send()
395            .await
396            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
397        if response.status().is_success() {
398            let text = response
399                .text()
400                .await
401                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
402            tracing::debug!(target: "rig", "Ollama chat response: {}", text);
403            let chat_resp: CompletionResponse = serde_json::from_str(&text)
404                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
405            let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
406            Ok(conv)
407        } else {
408            let err_text = response
409                .text()
410                .await
411                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
412            Err(CompletionError::ProviderError(err_text))
413        }
414    }
415}
416
417impl StreamingCompletionModel for CompletionModel {
418    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
419        let mut request_payload = self.create_completion_request(request)?;
420        merge_inplace(&mut request_payload, json!({"stream": true}));
421
422        let response = self
423            .client
424            .post("api/chat")
425            .json(&request_payload)
426            .send()
427            .await
428            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
429
430        if !response.status().is_success() {
431            let err_text = response
432                .text()
433                .await
434                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
435            return Err(CompletionError::ProviderError(err_text));
436        }
437
438        Ok(Box::pin(stream! {
439            let mut stream = response.bytes_stream();
440            while let Some(chunk_result) = stream.next().await {
441                let chunk = match chunk_result {
442                    Ok(c) => c,
443                    Err(e) => {
444                        yield Err(CompletionError::from(e));
445                        break;
446                    }
447                };
448
449                let text = match String::from_utf8(chunk.to_vec()) {
450                    Ok(t) => t,
451                    Err(e) => {
452                        yield Err(CompletionError::ResponseError(e.to_string()));
453                        break;
454                    }
455                };
456
457
458                for line in text.lines() {
459                    let line = line.to_string();
460
461                    let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
462                        continue;
463                    };
464
465                    match response.message {
466                        Message::Assistant{ content, tool_calls, .. } => {
467                            if !content.is_empty() {
468                                yield Ok(StreamingChoice::Message(content))
469                            }
470
471                            for tool_call in tool_calls.iter() {
472                                let function = tool_call.function.clone();
473
474                                yield Ok(StreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
475                            }
476                        }
477                        _ => {
478                            continue;
479                        }
480                    }
481                }
482            }
483        }))
484    }
485}
486
487// ---------- Tool Definition Conversion ----------
488
/// Ollama-required tool definition format: a `"function"`-typed wrapper
/// around the shared completion-module tool definition.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    // Always "function"; see the From impl below.
    #[serde(rename = "type")]
    pub type_field: String, // Fixed as "function"
    pub function: completion::ToolDefinition,
}
496
497/// Convert internal ToolDefinition (from the completion module) into Ollama's tool definition.
498impl From<crate::completion::ToolDefinition> for ToolDefinition {
499    fn from(tool: crate::completion::ToolDefinition) -> Self {
500        ToolDefinition {
501            type_field: "function".to_owned(),
502            function: completion::ToolDefinition {
503                name: tool.name,
504                description: tool.description,
505                parameters: tool.parameters,
506            },
507        }
508    }
509}
510
/// Tool call as emitted by Ollama's chat API.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    // pub id: String, // Ollama does not return a call id.
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Tool-call type tag; Ollama currently only defines "function".
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// Function name and JSON arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
529
530// ---------- Provider Message Definition ----------
531
532#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
533#[serde(tag = "role", rename_all = "lowercase")]
534pub enum Message {
535    User {
536        content: String,
537        #[serde(skip_serializing_if = "Option::is_none")]
538        images: Option<Vec<String>>,
539        #[serde(skip_serializing_if = "Option::is_none")]
540        name: Option<String>,
541    },
542    Assistant {
543        #[serde(default)]
544        content: String,
545        #[serde(skip_serializing_if = "Option::is_none")]
546        images: Option<Vec<String>>,
547        #[serde(skip_serializing_if = "Option::is_none")]
548        name: Option<String>,
549        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
550        tool_calls: Vec<ToolCall>,
551    },
552    System {
553        content: String,
554        #[serde(skip_serializing_if = "Option::is_none")]
555        images: Option<Vec<String>>,
556        #[serde(skip_serializing_if = "Option::is_none")]
557        name: Option<String>,
558    },
559    #[serde(rename = "Tool")]
560    ToolResult {
561        tool_call_id: String,
562        content: OneOrMany<ToolResultContent>,
563    },
564}
565
/// -----------------------------
/// Provider Message Conversions
/// -----------------------------
/// Conversion from an internal Rig message (crate::message::Message) to a provider Message.
/// (Only User and Assistant variants are supported; audio content is dropped.)
impl TryFrom<crate::message::Message> for Message {
    type Error = crate::message::MessageError;
    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;
        match internal_msg {
            InternalMessage::User { content, .. } => {
                let mut texts = Vec::new();
                let mut images = Vec::new();
                for uc in content.into_iter() {
                    match uc {
                        crate::message::UserContent::Text(t) => texts.push(t.text),
                        // NOTE(review): assumes img.data is already in the form
                        // Ollama expects (base64) — confirm against callers.
                        crate::message::UserContent::Image(img) => images.push(img.data),
                        _ => {} // Audio variant removed since Ollama API does not support it.
                    }
                }
                // Multiple text parts collapse into one space-joined string.
                let content_str = texts.join(" ");
                let images_opt = if images.is_empty() {
                    None
                } else {
                    Some(images)
                };
                Ok(Message::User {
                    content: content_str,
                    images: images_opt,
                    name: None,
                })
            }
            InternalMessage::Assistant { content, .. } => {
                let mut texts = Vec::new();
                let mut tool_calls = Vec::new();
                for ac in content.into_iter() {
                    match ac {
                        crate::message::AssistantContent::Text(t) => texts.push(t.text),
                        crate::message::AssistantContent::ToolCall(tc) => {
                            // The internal call id (if any) is not carried over;
                            // Ollama's format has no id field.
                            tool_calls.push(ToolCall {
                                r#type: ToolType::Function, // Assuming internal tool call provides these fields
                                function: Function {
                                    name: tc.function.name,
                                    arguments: tc.function.arguments,
                                },
                            });
                        }
                    }
                }
                let content_str = texts.join(" ");
                Ok(Message::Assistant {
                    content: content_str,
                    images: None,
                    name: None,
                    tool_calls,
                })
            }
        }
    }
}
626
/// Conversion from provider Message to a completion message.
/// This is needed so that responses can be converted back into chat history.
/// System and ToolResult variants are mapped onto User messages, since the
/// internal completion history has no dedicated roles for them.
impl From<Message> for crate::completion::Message {
    fn from(msg: Message) -> Self {
        match msg {
            // Images/name are dropped in this direction.
            Message::User { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // Text first (possibly empty), then one entry per tool call.
                let mut assistant_contents =
                    vec![crate::completion::message::AssistantContent::Text(Text {
                        text: content,
                    })];
                for tc in tool_calls {
                    // Ollama has no call id, so the function name doubles as id.
                    assistant_contents.push(
                        crate::completion::message::AssistantContent::tool_call(
                            tc.function.name.clone(),
                            tc.function.name,
                            tc.function.arguments,
                        ),
                    );
                }
                crate::completion::Message::Assistant {
                    // Safe: the vec above always holds at least the text entry.
                    content: OneOrMany::many(assistant_contents).unwrap(),
                }
            }
            // System and ToolResult are converted to User message as needed.
            Message::System { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::ToolResult {
                tool_call_id,
                content,
            } => crate::completion::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    tool_call_id,
                    content.map(|content| message::ToolResultContent::text(content.text)),
                )),
            },
        }
    }
}
677
678impl Message {
679    /// Constructs a system message.
680    pub fn system(content: &str) -> Self {
681        Message::System {
682            content: content.to_owned(),
683            images: None,
684            name: None,
685        }
686    }
687}
688
689// ---------- Additional Message Types ----------
690
/// Textual payload of a tool result message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    text: String,
}
695
696impl FromStr for ToolResultContent {
697    type Err = Infallible;
698
699    fn from_str(s: &str) -> Result<Self, Self::Err> {
700        Ok(s.to_owned().into())
701    }
702}
703
704impl From<String> for ToolResultContent {
705    fn from(s: String) -> Self {
706        ToolResultContent { text: s }
707    }
708}
709
/// Typed text payload for system content.
// NOTE(review): not referenced by the visible code in this file — possibly
// kept for API parity with other providers; confirm before removing.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}
716
/// Content type tag for [`SystemContent`]; only "text" is defined.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
723
724impl From<String> for SystemContent {
725    fn from(s: String) -> Self {
726        SystemContent {
727            r#type: SystemContentType::default(),
728            text: s,
729        }
730    }
731}
732
733impl FromStr for SystemContent {
734    type Err = std::convert::Infallible;
735    fn from_str(s: &str) -> Result<Self, Self::Err> {
736        Ok(SystemContent {
737            r#type: SystemContentType::default(),
738            text: s.to_string(),
739        })
740    }
741}
742
/// Plain-text assistant content.
// NOTE(review): not referenced by the visible code in this file — possibly
// kept for API parity with other providers; confirm before removing.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}
747
748impl FromStr for AssistantContent {
749    type Err = std::convert::Infallible;
750    fn from_str(s: &str) -> Result<Self, Self::Err> {
751        Ok(AssistantContent { text: s.to_owned() })
752    }
753}
754
/// Typed user content: plain text or an image reference.
// NOTE(review): not referenced by the visible code in this file — possibly
// kept for API parity with other providers; confirm before removing.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
    // Audio variant removed as Ollama API does not support audio input.
}
762
763impl FromStr for UserContent {
764    type Err = std::convert::Infallible;
765    fn from_str(s: &str) -> Result<Self, Self::Err> {
766        Ok(UserContent::Text { text: s.to_owned() })
767    }
768}
769
/// Image reference with an optional detail hint.
// NOTE(review): not referenced by the visible code in this file — possibly
// kept for API parity with other providers; confirm before removing.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}
776
777// =================================================================
778// Tests
779// =================================================================
780
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Test deserialization and conversion for the /api/chat endpoint.
    #[tokio::test]
    async fn test_chat_completion() {
        // Sample JSON response from /api/chat (non-streaming) based on Ollama docs.
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        // Round-trip: raw JSON -> provider response -> Rig completion response.
        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    // Test conversion from provider Message to completion Message.
    #[test]
    fn test_message_conversion() {
        // Construct a provider Message (User variant with String content).
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        // Convert it into a completion::Message.
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                // The first (and only) element must carry the original text.
                let first_content = content.first();
                // The expected type is crate::completion::message::UserContent::Text wrapping a Text struct.
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    // Test conversion of internal tool definition to Ollama's ToolDefinition format.
    #[test]
    fn test_tool_definition_conversion() {
        // Internal tool definition from the completion module.
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        // Convert internal tool to Ollama's tool definition.
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        // Check JSON fields in parameters.
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }
}