1use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient};
42use crate::json_utils::merge_inplace;
43use crate::message::MessageError;
44use crate::streaming::RawStreamingChoice;
45use crate::{
46 Embed, OneOrMany,
47 completion::{self, CompletionError, CompletionRequest},
48 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
49 impl_conversion_traits, json_utils, message,
50 message::{ImageDetail, Text},
51 streaming,
52};
53use async_stream::stream;
54use futures::StreamExt;
55use reqwest;
56use serde::{Deserialize, Serialize};
57use serde_json::{Value, json};
58use std::convert::Infallible;
59use std::{convert::TryFrom, str::FromStr};
60const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
63
64#[derive(Clone, Debug)]
65pub struct Client {
66 base_url: String,
67 http_client: reqwest::Client,
68}
69
70impl Default for Client {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl Client {
77 pub fn new() -> Self {
78 Self::from_url(OLLAMA_API_BASE_URL)
79 }
80 pub fn from_url(base_url: &str) -> Self {
81 Self {
82 base_url: base_url.to_owned(),
83 http_client: reqwest::Client::builder()
84 .build()
85 .expect("Ollama reqwest client should build"),
86 }
87 }
88
89 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
92 self.http_client = client;
93
94 self
95 }
96
97 fn post(&self, path: &str) -> reqwest::RequestBuilder {
98 let url = format!("{}/{}", self.base_url, path);
99 self.http_client.post(url)
100 }
101}
102
103impl ProviderClient for Client {
104 fn from_env() -> Self
105 where
106 Self: Sized,
107 {
108 let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
109 Self::from_url(&api_base)
110 }
111}
112
113impl CompletionClient for Client {
114 type CompletionModel = CompletionModel;
115
116 fn completion_model(&self, model: &str) -> CompletionModel {
117 CompletionModel::new(self.clone(), model)
118 }
119}
120
121impl EmbeddingsClient for Client {
122 type EmbeddingModel = EmbeddingModel;
123 fn embedding_model(&self, model: &str) -> EmbeddingModel {
124 EmbeddingModel::new(self.clone(), model, 0)
125 }
126 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
127 EmbeddingModel::new(self.clone(), model, ndims)
128 }
129 fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
130 EmbeddingsBuilder::new(self.embedding_model(model))
131 }
132}
133
// Generates the `AsTranscription`, `AsImageGeneration` and `AsAudioGeneration`
// conversion-trait impls for `Client` via the crate's `impl_conversion_traits!` macro
// (see the macro definition for exactly what is generated).
impl_conversion_traits!(
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
139
/// Error body shape with a human-readable `message`.
// NOTE(review): confirm the server actually reports errors under a `message` key.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Untagged response envelope: serde tries the `Ok(T)` shape first and only falls back
/// to `Err(ApiErrorResponse)` when the body does not deserialize as `T`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
153
/// Embedding model name: `all-minilm`.
pub const ALL_MINILM: &str = "all-minilm";
/// Embedding model name: `nomic-embed-text`.
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
158
/// Success body of `POST api/embed`.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    /// Model that produced the embeddings.
    pub model: String,
    /// One vector per input document, in input order (checked by count in `embed_texts`).
    pub embeddings: Vec<Vec<f64>>,
    // The timing/count fields below are optional: absent keys deserialize as `None`.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
170
171impl From<ApiErrorResponse> for EmbeddingError {
172 fn from(err: ApiErrorResponse) -> Self {
173 EmbeddingError::ProviderError(err.message)
174 }
175}
176
177impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
178 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
179 match value {
180 ApiResponse::Ok(response) => Ok(response),
181 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
182 }
183 }
184}
185
186#[derive(Clone)]
189pub struct EmbeddingModel {
190 client: Client,
191 pub model: String,
192 ndims: usize,
193}
194
195impl EmbeddingModel {
196 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
197 Self {
198 client,
199 model: model.to_owned(),
200 ndims,
201 }
202 }
203}
204
205impl embeddings::EmbeddingModel for EmbeddingModel {
206 const MAX_DOCUMENTS: usize = 1024;
207 fn ndims(&self) -> usize {
208 self.ndims
209 }
210 #[cfg_attr(feature = "worker", worker::send)]
211 async fn embed_texts(
212 &self,
213 documents: impl IntoIterator<Item = String>,
214 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
215 let docs: Vec<String> = documents.into_iter().collect();
216 let payload = json!({
217 "model": self.model,
218 "input": docs,
219 });
220 let response = self
221 .client
222 .post("api/embed")
223 .json(&payload)
224 .send()
225 .await
226 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
227 if response.status().is_success() {
228 let api_resp: EmbeddingResponse = response
229 .json()
230 .await
231 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
232 if api_resp.embeddings.len() != docs.len() {
233 return Err(EmbeddingError::ResponseError(
234 "Number of returned embeddings does not match input".into(),
235 ));
236 }
237 Ok(api_resp
238 .embeddings
239 .into_iter()
240 .zip(docs.into_iter())
241 .map(|(vec, document)| embeddings::Embedding { document, vec })
242 .collect())
243 } else {
244 Err(EmbeddingError::ProviderError(response.text().await?))
245 }
246 }
247}
248
/// Chat model name: `llama3.2`.
pub const LLAMA3_2: &str = "llama3.2";
/// Chat model name: `llava`.
pub const LLAVA: &str = "llava";
/// Chat model name: `mistral`.
pub const MISTRAL: &str = "mistral";
254
/// One chat response object from `POST api/chat`: the whole reply in non-streaming
/// mode, or a single line of the stream in streaming mode.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    /// `true` on the final object; the stream handler emits the final statistics then.
    pub done: bool,
    // Everything below is optional: absent keys deserialize as `None`.
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
276impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
277 type Error = CompletionError;
278 fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
279 match resp.message {
280 Message::Assistant {
282 content,
283 tool_calls,
284 ..
285 } => {
286 let mut assistant_contents = Vec::new();
287 if !content.is_empty() {
289 assistant_contents.push(completion::AssistantContent::text(&content));
290 }
291 for tc in tool_calls.iter() {
294 assistant_contents.push(completion::AssistantContent::tool_call(
295 tc.function.name.clone(),
296 tc.function.name.clone(),
297 tc.function.arguments.clone(),
298 ));
299 }
300 let choice = OneOrMany::many(assistant_contents).map_err(|_| {
301 CompletionError::ResponseError("No content provided".to_owned())
302 })?;
303 let raw_response = CompletionResponse {
304 model: resp.model,
305 created_at: resp.created_at,
306 done: resp.done,
307 done_reason: resp.done_reason,
308 total_duration: resp.total_duration,
309 load_duration: resp.load_duration,
310 prompt_eval_count: resp.prompt_eval_count,
311 prompt_eval_duration: resp.prompt_eval_duration,
312 eval_count: resp.eval_count,
313 eval_duration: resp.eval_duration,
314 message: Message::Assistant {
315 content,
316 images: None,
317 name: None,
318 tool_calls,
319 },
320 };
321 Ok(completion::CompletionResponse {
322 choice,
323 raw_response,
324 })
325 }
326 _ => Err(CompletionError::ResponseError(
327 "Chat response does not include an assistant message".into(),
328 )),
329 }
330 }
331}
332
333#[derive(Clone)]
336pub struct CompletionModel {
337 client: Client,
338 pub model: String,
339}
340
341impl CompletionModel {
342 pub fn new(client: Client, model: &str) -> Self {
343 Self {
344 client,
345 model: model.to_owned(),
346 }
347 }
348
349 fn create_completion_request(
350 &self,
351 completion_request: CompletionRequest,
352 ) -> Result<Value, CompletionError> {
353 let mut partial_history = vec![];
355 if let Some(docs) = completion_request.normalized_documents() {
356 partial_history.push(docs);
357 }
358 partial_history.extend(completion_request.chat_history);
359
360 let mut full_history: Vec<Message> = completion_request
362 .preamble
363 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
364
365 full_history.extend(
367 partial_history
368 .into_iter()
369 .map(|msg| msg.try_into())
370 .collect::<Result<Vec<Message>, _>>()?,
371 );
372
373 let options = if let Some(extra) = completion_request.additional_params {
375 json_utils::merge(
376 json!({ "temperature": completion_request.temperature }),
377 extra,
378 )
379 } else {
380 json!({ "temperature": completion_request.temperature })
381 };
382
383 let mut request_payload = json!({
384 "model": self.model,
385 "messages": full_history,
386 "options": options,
387 "stream": false,
388 });
389 if !completion_request.tools.is_empty() {
390 request_payload["tools"] = json!(
391 completion_request
392 .tools
393 .into_iter()
394 .map(|tool| tool.into())
395 .collect::<Vec<ToolDefinition>>()
396 );
397 }
398
399 tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
400
401 Ok(request_payload)
402 }
403}
404
405#[derive(Clone)]
408pub struct StreamingCompletionResponse {
409 pub done_reason: Option<String>,
410 pub total_duration: Option<u64>,
411 pub load_duration: Option<u64>,
412 pub prompt_eval_count: Option<u64>,
413 pub prompt_eval_duration: Option<u64>,
414 pub eval_count: Option<u64>,
415 pub eval_duration: Option<u64>,
416}
417impl completion::CompletionModel for CompletionModel {
418 type Response = CompletionResponse;
419 type StreamingResponse = StreamingCompletionResponse;
420
421 #[cfg_attr(feature = "worker", worker::send)]
422 async fn completion(
423 &self,
424 completion_request: CompletionRequest,
425 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
426 let request_payload = self.create_completion_request(completion_request)?;
427
428 let response = self
429 .client
430 .post("api/chat")
431 .json(&request_payload)
432 .send()
433 .await
434 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
435 if response.status().is_success() {
436 let text = response
437 .text()
438 .await
439 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
440 tracing::debug!(target: "rig", "Ollama chat response: {}", text);
441 let chat_resp: CompletionResponse = serde_json::from_str(&text)
442 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
443 let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
444 Ok(conv)
445 } else {
446 let err_text = response
447 .text()
448 .await
449 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
450 Err(CompletionError::ProviderError(err_text))
451 }
452 }
453
454 #[cfg_attr(feature = "worker", worker::send)]
455 async fn stream(
456 &self,
457 request: CompletionRequest,
458 ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
459 {
460 let mut request_payload = self.create_completion_request(request)?;
461 merge_inplace(&mut request_payload, json!({"stream": true}));
462
463 let response = self
464 .client
465 .post("api/chat")
466 .json(&request_payload)
467 .send()
468 .await
469 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
470
471 if !response.status().is_success() {
472 let err_text = response
473 .text()
474 .await
475 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
476 return Err(CompletionError::ProviderError(err_text));
477 }
478
479 let stream = Box::pin(stream! {
480 let mut stream = response.bytes_stream();
481 while let Some(chunk_result) = stream.next().await {
482 let chunk = match chunk_result {
483 Ok(c) => c,
484 Err(e) => {
485 yield Err(CompletionError::from(e));
486 break;
487 }
488 };
489
490 let text = match String::from_utf8(chunk.to_vec()) {
491 Ok(t) => t,
492 Err(e) => {
493 yield Err(CompletionError::ResponseError(e.to_string()));
494 break;
495 }
496 };
497
498
499 for line in text.lines() {
500 let line = line.to_string();
501
502 let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
503 continue;
504 };
505
506 match response.message {
507 Message::Assistant{ content, tool_calls, .. } => {
508 if !content.is_empty() {
509 yield Ok(RawStreamingChoice::Message(content))
510 }
511
512 for tool_call in tool_calls.iter() {
513 let function = tool_call.function.clone();
514
515 yield Ok(RawStreamingChoice::ToolCall {
516 id: "".to_string(),
517 name: function.name,
518 arguments: function.arguments,
519 call_id: None
520 });
521 }
522 }
523 _ => {
524 continue;
525 }
526 }
527
528 if response.done {
529 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
530 total_duration: response.total_duration,
531 load_duration: response.load_duration,
532 prompt_eval_count: response.prompt_eval_count,
533 prompt_eval_duration: response.prompt_eval_duration,
534 eval_count: response.eval_count,
535 eval_duration: response.eval_duration,
536 done_reason: response.done_reason,
537 }));
538 }
539 }
540 }
541 });
542
543 Ok(streaming::StreamingCompletionResponse::stream(stream))
544 }
545}
546
547#[derive(Clone, Debug, Deserialize, Serialize)]
551pub struct ToolDefinition {
552 #[serde(rename = "type")]
553 pub type_field: String, pub function: completion::ToolDefinition,
555}
556
557impl From<crate::completion::ToolDefinition> for ToolDefinition {
559 fn from(tool: crate::completion::ToolDefinition) -> Self {
560 ToolDefinition {
561 type_field: "function".to_owned(),
562 function: completion::ToolDefinition {
563 name: tool.name,
564 description: tool.description,
565 parameters: tool.parameters,
566 },
567 }
568 }
569}
570
/// A tool call as it appears in an assistant message's `tool_calls` array.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Serialized as `type`; defaults to `function` when the key is absent.
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Kind of tool call; `function` is the only (and default) kind.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// The function invoked by a tool call: its name and its JSON arguments.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
589
/// A chat message in Ollama's wire format, discriminated by the `role` key
/// (`user`, `assistant`, `system` or `tool`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        // `None` fields are omitted from the serialized JSON.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        // May be missing (e.g. for tool-call-only replies); defaults to "".
        #[serde(default)]
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Accepts a missing key or `null` as an empty list.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// A tool's output, sent with role `tool`.
    #[serde(rename = "tool")]
    ToolResult { name: String, content: String },
}
622
623impl TryFrom<crate::message::Message> for Message {
629 type Error = crate::message::MessageError;
630 fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
631 use crate::message::Message as InternalMessage;
632 match internal_msg {
633 InternalMessage::User { content, .. } => {
634 let mut texts = Vec::new();
635 let mut images = Vec::new();
636 for uc in content.into_iter() {
637 match uc {
638 crate::message::UserContent::Text(t) => texts.push(t.text),
639 crate::message::UserContent::Image(img) => images.push(img.data),
640 crate::message::UserContent::ToolResult(result) => {
641 let content = result
642 .content
643 .into_iter()
644 .map(ToolResultContent::try_from)
645 .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
646
647 let content = OneOrMany::many(content).map_err(|x| {
648 MessageError::ConversionError(format!(
649 "Couldn't make a OneOrMany from a list of tool results: {x}"
650 ))
651 })?;
652
653 return Ok(Message::ToolResult {
654 name: result.id,
655 content: content.first().text,
656 });
657 }
658 _ => {} }
660 }
661 let content_str = texts.join(" ");
662 let images_opt = if images.is_empty() {
663 None
664 } else {
665 Some(images)
666 };
667 Ok(Message::User {
668 content: content_str,
669 images: images_opt,
670 name: None,
671 })
672 }
673 InternalMessage::Assistant { content, .. } => {
674 let mut texts = Vec::new();
675 let mut tool_calls = Vec::new();
676 for ac in content.into_iter() {
677 match ac {
678 crate::message::AssistantContent::Text(t) => texts.push(t.text),
679 crate::message::AssistantContent::ToolCall(tc) => {
680 tool_calls.push(ToolCall {
681 r#type: ToolType::Function, function: Function {
683 name: tc.function.name,
684 arguments: tc.function.arguments,
685 },
686 });
687 }
688 }
689 }
690 let content_str = texts.join(" ");
691 Ok(Message::Assistant {
692 content: content_str,
693 images: None,
694 name: None,
695 tool_calls,
696 })
697 }
698 }
699 }
700}
701
702impl From<Message> for crate::completion::Message {
705 fn from(msg: Message) -> Self {
706 match msg {
707 Message::User { content, .. } => crate::completion::Message::User {
708 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
709 text: content,
710 })),
711 },
712 Message::Assistant {
713 content,
714 tool_calls,
715 ..
716 } => {
717 let mut assistant_contents =
718 vec![crate::completion::message::AssistantContent::Text(Text {
719 text: content,
720 })];
721 for tc in tool_calls {
722 assistant_contents.push(
723 crate::completion::message::AssistantContent::tool_call(
724 tc.function.name.clone(),
725 tc.function.name,
726 tc.function.arguments,
727 ),
728 );
729 }
730 crate::completion::Message::Assistant {
731 id: None,
732 content: OneOrMany::many(assistant_contents).unwrap(),
733 }
734 }
735 Message::System { content, .. } => crate::completion::Message::User {
737 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
738 text: content,
739 })),
740 },
741 Message::ToolResult { name, content } => crate::completion::Message::User {
742 content: OneOrMany::one(message::UserContent::tool_result(
743 name,
744 OneOrMany::one(message::ToolResultContent::Text(Text { text: content })),
745 )),
746 },
747 }
748 }
749}
750
751impl Message {
752 pub fn system(content: &str) -> Self {
754 Message::System {
755 content: content.to_owned(),
756 images: None,
757 name: None,
758 }
759 }
760}
761
762#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
765pub struct ToolResultContent {
766 text: String,
767}
768
769impl TryFrom<crate::message::ToolResultContent> for ToolResultContent {
770 type Error = MessageError;
771 fn try_from(value: crate::message::ToolResultContent) -> Result<Self, Self::Error> {
772 let crate::message::ToolResultContent::Text(Text { text }) = value else {
773 return Err(MessageError::ConversionError(
774 "Non-text tool results not supported".into(),
775 ));
776 };
777
778 Ok(Self { text })
779 }
780}
781
782impl FromStr for ToolResultContent {
783 type Err = Infallible;
784
785 fn from_str(s: &str) -> Result<Self, Self::Err> {
786 Ok(s.to_owned().into())
787 }
788}
789
790impl From<String> for ToolResultContent {
791 fn from(s: String) -> Self {
792 ToolResultContent { text: s }
793 }
794}
795
796#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
797pub struct SystemContent {
798 #[serde(default)]
799 r#type: SystemContentType,
800 text: String,
801}
802
803#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
804#[serde(rename_all = "lowercase")]
805pub enum SystemContentType {
806 #[default]
807 Text,
808}
809
810impl From<String> for SystemContent {
811 fn from(s: String) -> Self {
812 SystemContent {
813 r#type: SystemContentType::default(),
814 text: s,
815 }
816 }
817}
818
819impl FromStr for SystemContent {
820 type Err = std::convert::Infallible;
821 fn from_str(s: &str) -> Result<Self, Self::Err> {
822 Ok(SystemContent {
823 r#type: SystemContentType::default(),
824 text: s.to_string(),
825 })
826 }
827}
828
829#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
830pub struct AssistantContent {
831 pub text: String,
832}
833
834impl FromStr for AssistantContent {
835 type Err = std::convert::Infallible;
836 fn from_str(s: &str) -> Result<Self, Self::Err> {
837 Ok(AssistantContent { text: s.to_owned() })
838 }
839}
840
841#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
842#[serde(tag = "type", rename_all = "lowercase")]
843pub enum UserContent {
844 Text { text: String },
845 Image { image_url: ImageUrl },
846 }
848
849impl FromStr for UserContent {
850 type Err = std::convert::Infallible;
851 fn from_str(s: &str) -> Result<Self, Self::Err> {
852 Ok(UserContent::Text { text: s.to_owned() })
853 }
854}
855
856#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
857pub struct ImageUrl {
858 pub url: String,
859 #[serde(default)]
860 pub detail: ImageDetail,
861}
862
// Unit tests for response deserialization/conversion and for the message and tool
// definition conversions. None of them contacts a server.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes a sample non-streaming chat response (text plus one tool call)
    /// and checks that it converts into a completion response with a non-empty choice.
    #[tokio::test]
    async fn test_chat_completion() {
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    /// A provider `user` message converts into a rig user message with one text part.
    #[test]
    fn test_message_conversion() {
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                let first_content = content.first();
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    /// A rig tool definition converts into a `"function"` tool with its name,
    /// description and parameter schema preserved.
    #[test]
    fn test_tool_definition_conversion() {
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }
}