1use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient};
42use crate::json_utils::merge_inplace;
43use crate::message::MessageError;
44use crate::streaming::RawStreamingChoice;
45use crate::{
46 completion::{self, CompletionError, CompletionRequest},
47 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
48 impl_conversion_traits, json_utils, message,
49 message::{ImageDetail, Text},
50 streaming, Embed, OneOrMany,
51};
52use async_stream::stream;
53use futures::StreamExt;
54use reqwest;
55use serde::{Deserialize, Serialize};
56use serde_json::{json, Value};
57use std::convert::Infallible;
58use std::{convert::TryFrom, str::FromStr};
/// Base URL used by [`Client::new`]: an Ollama server on `localhost`, port 11434.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
62
63#[derive(Clone, Debug)]
64pub struct Client {
65 base_url: String,
66 http_client: reqwest::Client,
67}
68
69impl Default for Client {
70 fn default() -> Self {
71 Self::new()
72 }
73}
74
75impl Client {
76 pub fn new() -> Self {
77 Self::from_url(OLLAMA_API_BASE_URL)
78 }
79 pub fn from_url(base_url: &str) -> Self {
80 Self {
81 base_url: base_url.to_owned(),
82 http_client: reqwest::Client::builder()
83 .build()
84 .expect("Ollama reqwest client should build"),
85 }
86 }
87 fn post(&self, path: &str) -> reqwest::RequestBuilder {
88 let url = format!("{}/{}", self.base_url, path);
89 self.http_client.post(url)
90 }
91}
92
93impl ProviderClient for Client {
94 fn from_env() -> Self
95 where
96 Self: Sized,
97 {
98 Client::default()
99 }
100}
101
102impl CompletionClient for Client {
103 type CompletionModel = CompletionModel;
104
105 fn completion_model(&self, model: &str) -> CompletionModel {
106 CompletionModel::new(self.clone(), model)
107 }
108}
109
110impl EmbeddingsClient for Client {
111 type EmbeddingModel = EmbeddingModel;
112 fn embedding_model(&self, model: &str) -> EmbeddingModel {
113 EmbeddingModel::new(self.clone(), model, 0)
114 }
115 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
116 EmbeddingModel::new(self.clone(), model, ndims)
117 }
118 fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
119 EmbeddingsBuilder::new(self.embedding_model(model))
120 }
121}
122
123impl_conversion_traits!(
124 AsTranscription,
125 AsImageGeneration,
126 AsAudioGeneration for Client
127);
128
129#[derive(Debug, Deserialize)]
132struct ApiErrorResponse {
133 message: String,
134}
135
136#[derive(Debug, Deserialize)]
137#[serde(untagged)]
138enum ApiResponse<T> {
139 Ok(T),
140 Err(ApiErrorResponse),
141}
142
/// Model identifier `all-minilm`.
pub const ALL_MINILM: &str = "all-minilm";
/// Model identifier `nomic-embed-text`.
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
147
148#[derive(Debug, Serialize, Deserialize)]
149pub struct EmbeddingResponse {
150 pub model: String,
151 pub embeddings: Vec<Vec<f64>>,
152 #[serde(default)]
153 pub total_duration: Option<u64>,
154 #[serde(default)]
155 pub load_duration: Option<u64>,
156 #[serde(default)]
157 pub prompt_eval_count: Option<u64>,
158}
159
160impl From<ApiErrorResponse> for EmbeddingError {
161 fn from(err: ApiErrorResponse) -> Self {
162 EmbeddingError::ProviderError(err.message)
163 }
164}
165
166impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
167 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
168 match value {
169 ApiResponse::Ok(response) => Ok(response),
170 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
171 }
172 }
173}
174
175#[derive(Clone)]
178pub struct EmbeddingModel {
179 client: Client,
180 pub model: String,
181 ndims: usize,
182}
183
184impl EmbeddingModel {
185 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
186 Self {
187 client,
188 model: model.to_owned(),
189 ndims,
190 }
191 }
192}
193
194impl embeddings::EmbeddingModel for EmbeddingModel {
195 const MAX_DOCUMENTS: usize = 1024;
196 fn ndims(&self) -> usize {
197 self.ndims
198 }
199 #[cfg_attr(feature = "worker", worker::send)]
200 async fn embed_texts(
201 &self,
202 documents: impl IntoIterator<Item = String>,
203 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
204 let docs: Vec<String> = documents.into_iter().collect();
205 let payload = json!({
206 "model": self.model,
207 "input": docs,
208 });
209 let response = self
210 .client
211 .post("api/embed")
212 .json(&payload)
213 .send()
214 .await
215 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
216 if response.status().is_success() {
217 let api_resp: EmbeddingResponse = response
218 .json()
219 .await
220 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
221 if api_resp.embeddings.len() != docs.len() {
222 return Err(EmbeddingError::ResponseError(
223 "Number of returned embeddings does not match input".into(),
224 ));
225 }
226 Ok(api_resp
227 .embeddings
228 .into_iter()
229 .zip(docs.into_iter())
230 .map(|(vec, document)| embeddings::Embedding { document, vec })
231 .collect())
232 } else {
233 Err(EmbeddingError::ProviderError(response.text().await?))
234 }
235 }
236}
237
/// Model identifier `llama3.2`.
pub const LLAMA3_2: &str = "llama3.2";
/// Model identifier `llava`.
pub const LLAVA: &str = "llava";
/// Model identifier `mistral`.
pub const MISTRAL: &str = "mistral";
243
244#[derive(Debug, Serialize, Deserialize)]
245pub struct CompletionResponse {
246 pub model: String,
247 pub created_at: String,
248 pub message: Message,
249 pub done: bool,
250 #[serde(default)]
251 pub done_reason: Option<String>,
252 #[serde(default)]
253 pub total_duration: Option<u64>,
254 #[serde(default)]
255 pub load_duration: Option<u64>,
256 #[serde(default)]
257 pub prompt_eval_count: Option<u64>,
258 #[serde(default)]
259 pub prompt_eval_duration: Option<u64>,
260 #[serde(default)]
261 pub eval_count: Option<u64>,
262 #[serde(default)]
263 pub eval_duration: Option<u64>,
264}
265impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
266 type Error = CompletionError;
267 fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
268 match resp.message {
269 Message::Assistant {
271 content,
272 tool_calls,
273 ..
274 } => {
275 let mut assistant_contents = Vec::new();
276 if !content.is_empty() {
278 assistant_contents.push(completion::AssistantContent::text(&content));
279 }
280 for tc in tool_calls.iter() {
283 assistant_contents.push(completion::AssistantContent::tool_call(
284 tc.function.name.clone(),
285 tc.function.name.clone(),
286 tc.function.arguments.clone(),
287 ));
288 }
289 let choice = OneOrMany::many(assistant_contents).map_err(|_| {
290 CompletionError::ResponseError("No content provided".to_owned())
291 })?;
292 let raw_response = CompletionResponse {
293 model: resp.model,
294 created_at: resp.created_at,
295 done: resp.done,
296 done_reason: resp.done_reason,
297 total_duration: resp.total_duration,
298 load_duration: resp.load_duration,
299 prompt_eval_count: resp.prompt_eval_count,
300 prompt_eval_duration: resp.prompt_eval_duration,
301 eval_count: resp.eval_count,
302 eval_duration: resp.eval_duration,
303 message: Message::Assistant {
304 content,
305 images: None,
306 name: None,
307 tool_calls,
308 },
309 };
310 Ok(completion::CompletionResponse {
311 choice,
312 raw_response,
313 })
314 }
315 _ => Err(CompletionError::ResponseError(
316 "Chat response does not include an assistant message".into(),
317 )),
318 }
319 }
320}
321
322#[derive(Clone)]
325pub struct CompletionModel {
326 client: Client,
327 pub model: String,
328}
329
330impl CompletionModel {
331 pub fn new(client: Client, model: &str) -> Self {
332 Self {
333 client,
334 model: model.to_owned(),
335 }
336 }
337
338 fn create_completion_request(
339 &self,
340 completion_request: CompletionRequest,
341 ) -> Result<Value, CompletionError> {
342 let mut partial_history = vec![];
344 if let Some(docs) = completion_request.normalized_documents() {
345 partial_history.push(docs);
346 }
347 partial_history.extend(completion_request.chat_history);
348
349 let mut full_history: Vec<Message> = completion_request
351 .preamble
352 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
353
354 full_history.extend(
356 partial_history
357 .into_iter()
358 .map(|msg| msg.try_into())
359 .collect::<Result<Vec<Message>, _>>()?,
360 );
361
362 let options = if let Some(extra) = completion_request.additional_params {
364 json_utils::merge(
365 json!({ "temperature": completion_request.temperature }),
366 extra,
367 )
368 } else {
369 json!({ "temperature": completion_request.temperature })
370 };
371
372 let mut request_payload = json!({
373 "model": self.model,
374 "messages": full_history,
375 "options": options,
376 "stream": false,
377 });
378 if !completion_request.tools.is_empty() {
379 request_payload["tools"] = json!(completion_request
380 .tools
381 .into_iter()
382 .map(|tool| tool.into())
383 .collect::<Vec<ToolDefinition>>());
384 }
385
386 tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
387
388 Ok(request_payload)
389 }
390}
391
/// Final payload of a streamed chat completion.
///
/// Filled from the fields of the chunk whose `done` flag is set.
#[derive(Clone)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
404impl completion::CompletionModel for CompletionModel {
405 type Response = CompletionResponse;
406 type StreamingResponse = StreamingCompletionResponse;
407
408 #[cfg_attr(feature = "worker", worker::send)]
409 async fn completion(
410 &self,
411 completion_request: CompletionRequest,
412 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
413 let request_payload = self.create_completion_request(completion_request)?;
414
415 let response = self
416 .client
417 .post("api/chat")
418 .json(&request_payload)
419 .send()
420 .await
421 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
422 if response.status().is_success() {
423 let text = response
424 .text()
425 .await
426 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
427 tracing::debug!(target: "rig", "Ollama chat response: {}", text);
428 let chat_resp: CompletionResponse = serde_json::from_str(&text)
429 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
430 let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
431 Ok(conv)
432 } else {
433 let err_text = response
434 .text()
435 .await
436 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
437 Err(CompletionError::ProviderError(err_text))
438 }
439 }
440
441 #[cfg_attr(feature = "worker", worker::send)]
442 async fn stream(
443 &self,
444 request: CompletionRequest,
445 ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
446 {
447 let mut request_payload = self.create_completion_request(request)?;
448 merge_inplace(&mut request_payload, json!({"stream": true}));
449
450 let response = self
451 .client
452 .post("api/chat")
453 .json(&request_payload)
454 .send()
455 .await
456 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
457
458 if !response.status().is_success() {
459 let err_text = response
460 .text()
461 .await
462 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
463 return Err(CompletionError::ProviderError(err_text));
464 }
465
466 let stream = Box::pin(stream! {
467 let mut stream = response.bytes_stream();
468 while let Some(chunk_result) = stream.next().await {
469 let chunk = match chunk_result {
470 Ok(c) => c,
471 Err(e) => {
472 yield Err(CompletionError::from(e));
473 break;
474 }
475 };
476
477 let text = match String::from_utf8(chunk.to_vec()) {
478 Ok(t) => t,
479 Err(e) => {
480 yield Err(CompletionError::ResponseError(e.to_string()));
481 break;
482 }
483 };
484
485
486 for line in text.lines() {
487 let line = line.to_string();
488
489 let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
490 continue;
491 };
492
493 match response.message {
494 Message::Assistant{ content, tool_calls, .. } => {
495 if !content.is_empty() {
496 yield Ok(RawStreamingChoice::Message(content))
497 }
498
499 for tool_call in tool_calls.iter() {
500 let function = tool_call.function.clone();
501
502 yield Ok(RawStreamingChoice::ToolCall {
503 id: "".to_string(),
504 name: function.name,
505 arguments: function.arguments
506 });
507 }
508 }
509 _ => {
510 continue;
511 }
512 }
513
514 if response.done {
515 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
516 total_duration: response.total_duration,
517 load_duration: response.load_duration,
518 prompt_eval_count: response.prompt_eval_count,
519 prompt_eval_duration: response.prompt_eval_duration,
520 eval_count: response.eval_count,
521 eval_duration: response.eval_duration,
522 done_reason: response.done_reason,
523 }));
524 }
525 }
526 }
527 });
528
529 Ok(streaming::StreamingCompletionResponse::stream(stream))
530 }
531}
532
533#[derive(Clone, Debug, Deserialize, Serialize)]
537pub struct ToolDefinition {
538 #[serde(rename = "type")]
539 pub type_field: String, pub function: completion::ToolDefinition,
541}
542
543impl From<crate::completion::ToolDefinition> for ToolDefinition {
545 fn from(tool: crate::completion::ToolDefinition) -> Self {
546 ToolDefinition {
547 type_field: "function".to_owned(),
548 function: completion::ToolDefinition {
549 name: tool.name,
550 description: tool.description,
551 parameters: tool.parameters,
552 },
553 }
554 }
555}
556
557#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
558pub struct ToolCall {
559 #[serde(default, rename = "type")]
561 pub r#type: ToolType,
562 pub function: Function,
563}
564#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
565#[serde(rename_all = "lowercase")]
566pub enum ToolType {
567 #[default]
568 Function,
569}
570#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
571pub struct Function {
572 pub name: String,
573 pub arguments: Value,
574}
575
576#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
579#[serde(tag = "role", rename_all = "lowercase")]
580pub enum Message {
581 User {
582 content: String,
583 #[serde(skip_serializing_if = "Option::is_none")]
584 images: Option<Vec<String>>,
585 #[serde(skip_serializing_if = "Option::is_none")]
586 name: Option<String>,
587 },
588 Assistant {
589 #[serde(default)]
590 content: String,
591 #[serde(skip_serializing_if = "Option::is_none")]
592 images: Option<Vec<String>>,
593 #[serde(skip_serializing_if = "Option::is_none")]
594 name: Option<String>,
595 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
596 tool_calls: Vec<ToolCall>,
597 },
598 System {
599 content: String,
600 #[serde(skip_serializing_if = "Option::is_none")]
601 images: Option<Vec<String>>,
602 #[serde(skip_serializing_if = "Option::is_none")]
603 name: Option<String>,
604 },
605 #[serde(rename = "Tool")]
606 ToolResult { name: String, content: String },
607}
608
609impl TryFrom<crate::message::Message> for Message {
615 type Error = crate::message::MessageError;
616 fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
617 use crate::message::Message as InternalMessage;
618 match internal_msg {
619 InternalMessage::User { content, .. } => {
620 let mut texts = Vec::new();
621 let mut images = Vec::new();
622 for uc in content.into_iter() {
623 match uc {
624 crate::message::UserContent::Text(t) => texts.push(t.text),
625 crate::message::UserContent::Image(img) => images.push(img.data),
626 crate::message::UserContent::ToolResult(result) => {
627 let content = result
628 .content
629 .into_iter()
630 .map(ToolResultContent::try_from)
631 .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
632
633 let content = OneOrMany::many(content).map_err(|x| {
634 MessageError::ConversionError(format!(
635 "Couldn't make a OneOrMany from a list of tool results: {x}"
636 ))
637 })?;
638
639 return Ok(Message::ToolResult {
640 name: result.id,
641 content: content.first().text,
642 });
643 }
644 _ => {} }
646 }
647 let content_str = texts.join(" ");
648 let images_opt = if images.is_empty() {
649 None
650 } else {
651 Some(images)
652 };
653 Ok(Message::User {
654 content: content_str,
655 images: images_opt,
656 name: None,
657 })
658 }
659 InternalMessage::Assistant { content, .. } => {
660 let mut texts = Vec::new();
661 let mut tool_calls = Vec::new();
662 for ac in content.into_iter() {
663 match ac {
664 crate::message::AssistantContent::Text(t) => texts.push(t.text),
665 crate::message::AssistantContent::ToolCall(tc) => {
666 tool_calls.push(ToolCall {
667 r#type: ToolType::Function, function: Function {
669 name: tc.function.name,
670 arguments: tc.function.arguments,
671 },
672 });
673 }
674 }
675 }
676 let content_str = texts.join(" ");
677 Ok(Message::Assistant {
678 content: content_str,
679 images: None,
680 name: None,
681 tool_calls,
682 })
683 }
684 }
685 }
686}
687
688impl From<Message> for crate::completion::Message {
691 fn from(msg: Message) -> Self {
692 match msg {
693 Message::User { content, .. } => crate::completion::Message::User {
694 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
695 text: content,
696 })),
697 },
698 Message::Assistant {
699 content,
700 tool_calls,
701 ..
702 } => {
703 let mut assistant_contents =
704 vec![crate::completion::message::AssistantContent::Text(Text {
705 text: content,
706 })];
707 for tc in tool_calls {
708 assistant_contents.push(
709 crate::completion::message::AssistantContent::tool_call(
710 tc.function.name.clone(),
711 tc.function.name,
712 tc.function.arguments,
713 ),
714 );
715 }
716 crate::completion::Message::Assistant {
717 content: OneOrMany::many(assistant_contents).unwrap(),
718 }
719 }
720 Message::System { content, .. } => crate::completion::Message::User {
722 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
723 text: content,
724 })),
725 },
726 Message::ToolResult { name, content } => crate::completion::Message::User {
727 content: OneOrMany::one(message::UserContent::tool_result(
728 name,
729 OneOrMany::one(message::ToolResultContent::Text(Text { text: content })),
730 )),
731 },
732 }
733 }
734}
735
736impl Message {
737 pub fn system(content: &str) -> Self {
739 Message::System {
740 content: content.to_owned(),
741 images: None,
742 name: None,
743 }
744 }
745}
746
747#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
750pub struct ToolResultContent {
751 text: String,
752}
753
754impl TryFrom<crate::message::ToolResultContent> for ToolResultContent {
755 type Error = MessageError;
756 fn try_from(value: crate::message::ToolResultContent) -> Result<Self, Self::Error> {
757 let crate::message::ToolResultContent::Text(Text { text }) = value else {
758 return Err(MessageError::ConversionError(
759 "Non-text tool results not supported".into(),
760 ));
761 };
762
763 Ok(Self { text })
764 }
765}
766
767impl FromStr for ToolResultContent {
768 type Err = Infallible;
769
770 fn from_str(s: &str) -> Result<Self, Self::Err> {
771 Ok(s.to_owned().into())
772 }
773}
774
775impl From<String> for ToolResultContent {
776 fn from(s: String) -> Self {
777 ToolResultContent { text: s }
778 }
779}
780
781#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
782pub struct SystemContent {
783 #[serde(default)]
784 r#type: SystemContentType,
785 text: String,
786}
787
788#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
789#[serde(rename_all = "lowercase")]
790pub enum SystemContentType {
791 #[default]
792 Text,
793}
794
795impl From<String> for SystemContent {
796 fn from(s: String) -> Self {
797 SystemContent {
798 r#type: SystemContentType::default(),
799 text: s,
800 }
801 }
802}
803
804impl FromStr for SystemContent {
805 type Err = std::convert::Infallible;
806 fn from_str(s: &str) -> Result<Self, Self::Err> {
807 Ok(SystemContent {
808 r#type: SystemContentType::default(),
809 text: s.to_string(),
810 })
811 }
812}
813
814#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
815pub struct AssistantContent {
816 pub text: String,
817}
818
819impl FromStr for AssistantContent {
820 type Err = std::convert::Infallible;
821 fn from_str(s: &str) -> Result<Self, Self::Err> {
822 Ok(AssistantContent { text: s.to_owned() })
823 }
824}
825
826#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
827#[serde(tag = "type", rename_all = "lowercase")]
828pub enum UserContent {
829 Text { text: String },
830 Image { image_url: ImageUrl },
831 }
833
834impl FromStr for UserContent {
835 type Err = std::convert::Infallible;
836 fn from_str(s: &str) -> Result<Self, Self::Err> {
837 Ok(UserContent::Text { text: s.to_owned() })
838 }
839}
840
841#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
842pub struct ImageUrl {
843 pub url: String,
844 #[serde(default)]
845 pub detail: ImageDetail,
846}
847
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A chat response with text and a tool call deserializes and converts
    /// into a completion response with a non-empty choice.
    #[tokio::test]
    async fn test_chat_completion() {
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let serialized = sample_chat_response.to_string();

        let parsed: CompletionResponse =
            serde_json::from_str(&serialized).expect("Invalid JSON structure");
        let converted: completion::CompletionResponse<CompletionResponse> =
            parsed.try_into().unwrap();

        assert!(
            !converted.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    /// A provider user message converts into a crate-level user message with
    /// the same text.
    #[test]
    fn test_message_conversion() {
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };

        let converted: crate::completion::Message = provider_msg.into();
        let crate::completion::Message::User { content } = converted else {
            panic!("Conversion from provider Message to completion Message failed");
        };

        match content.first() {
            crate::completion::message::UserContent::Text(text_struct) => {
                assert_eq!(text_struct.text, "Test message");
            }
            _ => panic!("Expected text content in conversion"),
        }
    }

    /// A crate-level tool definition converts into the provider form with
    /// type `function` and unchanged name, description and parameters.
    #[test]
    fn test_tool_definition_conversion() {
        let parameters = json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The location to get the weather for, e.g. San Francisco, CA"
                },
                "format": {
                    "type": "string",
                    "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["location", "format"]
        });
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters,
        };

        let ollama_tool = ToolDefinition::from(internal_tool);

        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        let location_type = &ollama_tool.function.parameters["properties"]["location"]["type"];
        assert_eq!(*location_type, "string");
    }
}