1use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient};
42use crate::completion::Usage;
43use crate::json_utils::merge_inplace;
44use crate::streaming::RawStreamingChoice;
45use crate::{
46 Embed, OneOrMany,
47 completion::{self, CompletionError, CompletionRequest},
48 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
49 impl_conversion_traits, json_utils, message,
50 message::{ImageDetail, Text},
51 streaming,
52};
53use async_stream::stream;
54use futures::StreamExt;
55use reqwest;
56use serde::{Deserialize, Serialize};
57use serde_json::{Value, json};
58use std::{convert::TryFrom, str::FromStr};
/// Base URL that [`Client::new`] passes to [`Client::from_url`] when no explicit URL is given.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
62
/// Client for an Ollama server: a base URL plus the `reqwest` client used to send requests.
#[derive(Clone, Debug)]
pub struct Client {
    /// Base URL that request paths are appended to.
    base_url: String,
    /// HTTP client used to issue the requests.
    http_client: reqwest::Client,
}
68
69impl Default for Client {
70 fn default() -> Self {
71 Self::new()
72 }
73}
74
75impl Client {
76 pub fn new() -> Self {
77 Self::from_url(OLLAMA_API_BASE_URL)
78 }
79 pub fn from_url(base_url: &str) -> Self {
80 Self {
81 base_url: base_url.to_owned(),
82 http_client: reqwest::Client::builder()
83 .build()
84 .expect("Ollama reqwest client should build"),
85 }
86 }
87
88 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
91 self.http_client = client;
92
93 self
94 }
95
96 fn post(&self, path: &str) -> reqwest::RequestBuilder {
97 let url = format!("{}/{}", self.base_url, path);
98 self.http_client.post(url)
99 }
100}
101
102impl ProviderClient for Client {
103 fn from_env() -> Self
104 where
105 Self: Sized,
106 {
107 let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
108 Self::from_url(&api_base)
109 }
110
111 fn from_val(input: crate::client::ProviderValue) -> Self {
112 let crate::client::ProviderValue::Simple(_) = input else {
113 panic!("Incorrect provider value type")
114 };
115
116 Self::new()
117 }
118}
119
120impl CompletionClient for Client {
121 type CompletionModel = CompletionModel;
122
123 fn completion_model(&self, model: &str) -> CompletionModel {
124 CompletionModel::new(self.clone(), model)
125 }
126}
127
128impl EmbeddingsClient for Client {
129 type EmbeddingModel = EmbeddingModel;
130 fn embedding_model(&self, model: &str) -> EmbeddingModel {
131 EmbeddingModel::new(self.clone(), model, 0)
132 }
133 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
134 EmbeddingModel::new(self.clone(), model, ndims)
135 }
136 fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
137 EmbeddingsBuilder::new(self.embedding_model(model))
138 }
139}
140
// Invokes the crate's `impl_conversion_traits!` macro for `Client` with the
// `AsTranscription`, `AsImageGeneration` and `AsAudioGeneration` traits.
// NOTE(review): presumably this generates the conversions for capabilities
// this provider does not implement — confirm against the macro definition.
impl_conversion_traits!(
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
146
/// Error payload deserialized from a response body; carries a single `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text reported by the server.
    message: String,
}
153
/// Response body that is either a successful payload `T` or an [`ApiErrorResponse`].
///
/// The enum is `untagged`, so deserialization tries the variants in
/// declaration order: `Ok` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// Successful payload.
    Ok(T),
    /// Error payload.
    Err(ApiErrorResponse),
}
160
/// Model name `all-minilm` (presumably an embedding model — confirm against the Ollama model library).
pub const ALL_MINILM: &str = "all-minilm";
/// Model name `nomic-embed-text` (presumably an embedding model — confirm against the Ollama model library).
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
165
/// Response body of the `api/embed` endpoint as deserialized by `embed_texts`.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    /// Model name reported in the response.
    pub model: String,
    /// One embedding vector per input document; `embed_texts` pairs them
    /// with the inputs by position.
    pub embeddings: Vec<Vec<f64>>,
    /// Optional; `None` when absent. NOTE(review): unit not visible here — confirm with the API docs.
    #[serde(default)]
    pub total_duration: Option<u64>,
    /// Optional; `None` when absent. NOTE(review): unit not visible here — confirm with the API docs.
    #[serde(default)]
    pub load_duration: Option<u64>,
    /// Optional; `None` when absent. Presumably the number of prompt tokens evaluated — confirm.
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
177
178impl From<ApiErrorResponse> for EmbeddingError {
179 fn from(err: ApiErrorResponse) -> Self {
180 EmbeddingError::ProviderError(err.message)
181 }
182}
183
184impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
185 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
186 match value {
187 ApiResponse::Ok(response) => Ok(response),
188 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
189 }
190 }
191}
192
/// Embedding model handle: a [`Client`], a model name and a dimension count.
#[derive(Clone)]
pub struct EmbeddingModel {
    /// Client used to send embedding requests.
    client: Client,
    /// Model name sent in the request payload.
    pub model: String,
    /// Value returned by `ndims()`; it is stored as given and not checked against the server.
    ndims: usize,
}
201
202impl EmbeddingModel {
203 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
204 Self {
205 client,
206 model: model.to_owned(),
207 ndims,
208 }
209 }
210}
211
212impl embeddings::EmbeddingModel for EmbeddingModel {
213 const MAX_DOCUMENTS: usize = 1024;
214 fn ndims(&self) -> usize {
215 self.ndims
216 }
217 #[cfg_attr(feature = "worker", worker::send)]
218 async fn embed_texts(
219 &self,
220 documents: impl IntoIterator<Item = String>,
221 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
222 let docs: Vec<String> = documents.into_iter().collect();
223 let payload = json!({
224 "model": self.model,
225 "input": docs,
226 });
227 let response = self
228 .client
229 .post("api/embed")
230 .json(&payload)
231 .send()
232 .await
233 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
234 if response.status().is_success() {
235 let api_resp: EmbeddingResponse = response
236 .json()
237 .await
238 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
239 if api_resp.embeddings.len() != docs.len() {
240 return Err(EmbeddingError::ResponseError(
241 "Number of returned embeddings does not match input".into(),
242 ));
243 }
244 Ok(api_resp
245 .embeddings
246 .into_iter()
247 .zip(docs.into_iter())
248 .map(|(vec, document)| embeddings::Embedding { document, vec })
249 .collect())
250 } else {
251 Err(EmbeddingError::ProviderError(response.text().await?))
252 }
253 }
254}
255
/// Model name `llama3.2`.
pub const LLAMA3_2: &str = "llama3.2";
/// Model name `llava`.
pub const LLAVA: &str = "llava";
/// Model name `mistral`.
pub const MISTRAL: &str = "mistral";
261
/// Chat response object as deserialized from the `api/chat` endpoint, both
/// for whole responses and for each line of a streamed response.
///
/// NOTE(review): units of the `*_duration` fields are not visible in this
/// file — confirm with the API docs.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Model name reported in the response.
    pub model: String,
    /// Creation timestamp, kept as the raw string from the response.
    pub created_at: String,
    /// Message carried by this response.
    pub message: Message,
    /// When set on a streamed line, the stream emits its final-response item.
    pub done: bool,
    /// Optional; `None` when absent.
    #[serde(default)]
    pub done_reason: Option<String>,
    /// Optional; `None` when absent.
    #[serde(default)]
    pub total_duration: Option<u64>,
    /// Optional; `None` when absent.
    #[serde(default)]
    pub load_duration: Option<u64>,
    /// Optional; reported as the input token count in `Usage` (0 when absent).
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    /// Optional; `None` when absent.
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    /// Optional; reported as the output token count in `Usage` (0 when absent).
    #[serde(default)]
    pub eval_count: Option<u64>,
    /// Optional; `None` when absent.
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
283impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
284 type Error = CompletionError;
285 fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
286 match resp.message {
287 Message::Assistant {
289 content,
290 tool_calls,
291 ..
292 } => {
293 let mut assistant_contents = Vec::new();
294 if !content.is_empty() {
296 assistant_contents.push(completion::AssistantContent::text(&content));
297 }
298 for tc in tool_calls.iter() {
301 assistant_contents.push(completion::AssistantContent::tool_call(
302 tc.function.name.clone(),
303 tc.function.name.clone(),
304 tc.function.arguments.clone(),
305 ));
306 }
307 let choice = OneOrMany::many(assistant_contents).map_err(|_| {
308 CompletionError::ResponseError("No content provided".to_owned())
309 })?;
310 let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
311 let completion_tokens = resp.eval_count.unwrap_or(0);
312
313 let raw_response = CompletionResponse {
314 model: resp.model,
315 created_at: resp.created_at,
316 done: resp.done,
317 done_reason: resp.done_reason,
318 total_duration: resp.total_duration,
319 load_duration: resp.load_duration,
320 prompt_eval_count: resp.prompt_eval_count,
321 prompt_eval_duration: resp.prompt_eval_duration,
322 eval_count: resp.eval_count,
323 eval_duration: resp.eval_duration,
324 message: Message::Assistant {
325 content,
326 images: None,
327 name: None,
328 tool_calls,
329 },
330 };
331
332 Ok(completion::CompletionResponse {
333 choice,
334 usage: Usage {
335 input_tokens: prompt_tokens,
336 output_tokens: completion_tokens,
337 total_tokens: prompt_tokens + completion_tokens,
338 },
339 raw_response,
340 })
341 }
342 _ => Err(CompletionError::ResponseError(
343 "Chat response does not include an assistant message".into(),
344 )),
345 }
346 }
347}
348
/// Completion model handle: a [`Client`] plus the model name sent with each request.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send chat requests.
    client: Client,
    /// Model name placed in the request payload.
    pub model: String,
}
356
357impl CompletionModel {
358 pub fn new(client: Client, model: &str) -> Self {
359 Self {
360 client,
361 model: model.to_owned(),
362 }
363 }
364
365 fn create_completion_request(
366 &self,
367 completion_request: CompletionRequest,
368 ) -> Result<Value, CompletionError> {
369 let mut partial_history = vec![];
371 if let Some(docs) = completion_request.normalized_documents() {
372 partial_history.push(docs);
373 }
374 partial_history.extend(completion_request.chat_history);
375
376 let mut full_history: Vec<Message> = completion_request
378 .preamble
379 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
380
381 full_history.extend(
383 partial_history
384 .into_iter()
385 .map(|msg| msg.try_into())
386 .collect::<Result<Vec<Vec<Message>>, _>>()?
387 .into_iter()
388 .flatten()
389 .collect::<Vec<Message>>(),
390 );
391
392 let options = if let Some(extra) = completion_request.additional_params {
394 json_utils::merge(
395 json!({ "temperature": completion_request.temperature }),
396 extra,
397 )
398 } else {
399 json!({ "temperature": completion_request.temperature })
400 };
401
402 let mut request_payload = json!({
403 "model": self.model,
404 "messages": full_history,
405 "options": options,
406 "stream": false,
407 });
408 if !completion_request.tools.is_empty() {
409 request_payload["tools"] = json!(
410 completion_request
411 .tools
412 .into_iter()
413 .map(|tool| tool.into())
414 .collect::<Vec<ToolDefinition>>()
415 );
416 }
417
418 tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
419
420 Ok(request_payload)
421 }
422}
423
/// Final metadata of a streamed chat, copied from the streamed response line
/// whose `done` flag is set.
#[derive(Clone)]
pub struct StreamingCompletionResponse {
    /// Copied from the final line's `done_reason`.
    pub done_reason: Option<String>,
    /// Copied from the final line's `total_duration`.
    pub total_duration: Option<u64>,
    /// Copied from the final line's `load_duration`.
    pub load_duration: Option<u64>,
    /// Copied from the final line's `prompt_eval_count`.
    pub prompt_eval_count: Option<u64>,
    /// Copied from the final line's `prompt_eval_duration`.
    pub prompt_eval_duration: Option<u64>,
    /// Copied from the final line's `eval_count`.
    pub eval_count: Option<u64>,
    /// Copied from the final line's `eval_duration`.
    pub eval_duration: Option<u64>,
}
436impl completion::CompletionModel for CompletionModel {
437 type Response = CompletionResponse;
438 type StreamingResponse = StreamingCompletionResponse;
439
440 #[cfg_attr(feature = "worker", worker::send)]
441 async fn completion(
442 &self,
443 completion_request: CompletionRequest,
444 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
445 let request_payload = self.create_completion_request(completion_request)?;
446
447 let response = self
448 .client
449 .post("api/chat")
450 .json(&request_payload)
451 .send()
452 .await
453 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
454 if response.status().is_success() {
455 let text = response
456 .text()
457 .await
458 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
459 tracing::debug!(target: "rig", "Ollama chat response: {}", text);
460 let chat_resp: CompletionResponse = serde_json::from_str(&text)
461 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
462 let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
463 Ok(conv)
464 } else {
465 let err_text = response
466 .text()
467 .await
468 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
469 Err(CompletionError::ProviderError(err_text))
470 }
471 }
472
473 #[cfg_attr(feature = "worker", worker::send)]
474 async fn stream(
475 &self,
476 request: CompletionRequest,
477 ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
478 {
479 let mut request_payload = self.create_completion_request(request)?;
480 merge_inplace(&mut request_payload, json!({"stream": true}));
481
482 let response = self
483 .client
484 .post("api/chat")
485 .json(&request_payload)
486 .send()
487 .await
488 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
489
490 if !response.status().is_success() {
491 let err_text = response
492 .text()
493 .await
494 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
495 return Err(CompletionError::ProviderError(err_text));
496 }
497
498 let stream = Box::pin(stream! {
499 let mut stream = response.bytes_stream();
500 while let Some(chunk_result) = stream.next().await {
501 let chunk = match chunk_result {
502 Ok(c) => c,
503 Err(e) => {
504 yield Err(CompletionError::from(e));
505 break;
506 }
507 };
508
509 let text = match String::from_utf8(chunk.to_vec()) {
510 Ok(t) => t,
511 Err(e) => {
512 yield Err(CompletionError::ResponseError(e.to_string()));
513 break;
514 }
515 };
516
517
518 for line in text.lines() {
519 let line = line.to_string();
520
521 let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
522 continue;
523 };
524
525 match response.message {
526 Message::Assistant{ content, tool_calls, .. } => {
527 if !content.is_empty() {
528 yield Ok(RawStreamingChoice::Message(content))
529 }
530
531 for tool_call in tool_calls.iter() {
532 let function = tool_call.function.clone();
533
534 yield Ok(RawStreamingChoice::ToolCall {
535 id: "".to_string(),
536 name: function.name,
537 arguments: function.arguments,
538 call_id: None
539 });
540 }
541 }
542 _ => {
543 continue;
544 }
545 }
546
547 if response.done {
548 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
549 total_duration: response.total_duration,
550 load_duration: response.load_duration,
551 prompt_eval_count: response.prompt_eval_count,
552 prompt_eval_duration: response.prompt_eval_duration,
553 eval_count: response.eval_count,
554 eval_duration: response.eval_duration,
555 done_reason: response.done_reason,
556 }));
557 }
558 }
559 }
560 });
561
562 Ok(streaming::StreamingCompletionResponse::stream(stream))
563 }
564}
565
/// Tool definition in the shape sent in the `tools` array of a chat request:
/// a `type` string plus the crate-level tool definition under `function`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Serialized as `type`; the `From` conversion sets it to `"function"`.
    #[serde(rename = "type")]
    pub type_field: String,
    /// Name, description and parameters of the tool.
    pub function: completion::ToolDefinition,
}
575
576impl From<crate::completion::ToolDefinition> for ToolDefinition {
578 fn from(tool: crate::completion::ToolDefinition) -> Self {
579 ToolDefinition {
580 type_field: "function".to_owned(),
581 function: completion::ToolDefinition {
582 name: tool.name,
583 description: tool.description,
584 parameters: tool.parameters,
585 },
586 }
587 }
588}
589
/// A tool call carried by an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Serialized as `type`; falls back to the default (`function`) when missing.
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    /// The function being called and its arguments.
    pub function: Function,
}
/// Kind of tool call; serialized in lowercase, so `Function` is `"function"`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// The only (and default) kind.
    #[default]
    Function,
}
/// Function name and arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function to call.
    pub name: String,
    /// Arguments as an arbitrary JSON value.
    pub arguments: Value,
}
608
/// Provider chat message, internally tagged by a lowercase `role` field
/// (`user`, `assistant`, `system`, `tool`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// Message with role `user`.
    User {
        content: String,
        // Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        // Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Message with role `assistant`.
    Assistant {
        // Empty string when the field is missing.
        #[serde(default)]
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Empty when missing; deserialized through `json_utils::null_or_vec`
        // (presumably so that `null` is accepted as empty — confirm).
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// Message with role `system`.
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Message with role `tool`, carrying the result of a tool call.
    #[serde(rename = "tool")]
    ToolResult {
        // Serialized as `tool_name`.
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
645
646impl TryFrom<crate::message::Message> for Vec<Message> {
652 type Error = crate::message::MessageError;
653 fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
654 use crate::message::Message as InternalMessage;
655 match internal_msg {
656 InternalMessage::User { content, .. } => {
657 let (tool_results, other_content): (Vec<_>, Vec<_>) =
658 content.into_iter().partition(|content| {
659 matches!(content, crate::message::UserContent::ToolResult(_))
660 });
661
662 if !tool_results.is_empty() {
663 tool_results
664 .into_iter()
665 .map(|content| match content {
666 crate::message::UserContent::ToolResult(
667 crate::message::ToolResult { id, content, .. },
668 ) => {
669 let content_string = content
671 .into_iter()
672 .map(|content| match content {
673 crate::message::ToolResultContent::Text(text) => text.text,
674 _ => "[Non-text content]".to_string(),
675 })
676 .collect::<Vec<_>>()
677 .join("\n");
678
679 Ok::<_, crate::message::MessageError>(Message::ToolResult {
680 name: id,
681 content: content_string,
682 })
683 }
684 _ => unreachable!(),
685 })
686 .collect::<Result<Vec<_>, _>>()
687 } else {
688 let (texts, images) = other_content.into_iter().fold(
690 (Vec::new(), Vec::new()),
691 |(mut texts, mut images), content| {
692 match content {
693 crate::message::UserContent::Text(crate::message::Text {
694 text,
695 }) => texts.push(text),
696 crate::message::UserContent::Image(crate::message::Image {
697 data,
698 ..
699 }) => images.push(data),
700 _ => {} }
702 (texts, images)
703 },
704 );
705
706 Ok(vec![Message::User {
707 content: texts.join(" "),
708 images: if images.is_empty() {
709 None
710 } else {
711 Some(images)
712 },
713 name: None,
714 }])
715 }
716 }
717 InternalMessage::Assistant { content, .. } => {
718 let (text_content, tool_calls) = content.into_iter().fold(
719 (Vec::new(), Vec::new()),
720 |(mut texts, mut tools), content| {
721 match content {
722 crate::message::AssistantContent::Text(text) => texts.push(text.text),
723 crate::message::AssistantContent::ToolCall(tool_call) => {
724 tools.push(tool_call)
725 }
726 }
727 (texts, tools)
728 },
729 );
730
731 Ok(vec![Message::Assistant {
734 content: text_content.join(" "),
735 images: None,
736 name: None,
737 tool_calls: tool_calls
738 .into_iter()
739 .map(|tool_call| tool_call.into())
740 .collect::<Vec<_>>(),
741 }])
742 }
743 }
744 }
745}
746
747impl From<Message> for crate::completion::Message {
750 fn from(msg: Message) -> Self {
751 match msg {
752 Message::User { content, .. } => crate::completion::Message::User {
753 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
754 text: content,
755 })),
756 },
757 Message::Assistant {
758 content,
759 tool_calls,
760 ..
761 } => {
762 let mut assistant_contents =
763 vec![crate::completion::message::AssistantContent::Text(Text {
764 text: content,
765 })];
766 for tc in tool_calls {
767 assistant_contents.push(
768 crate::completion::message::AssistantContent::tool_call(
769 tc.function.name.clone(),
770 tc.function.name,
771 tc.function.arguments,
772 ),
773 );
774 }
775 crate::completion::Message::Assistant {
776 id: None,
777 content: OneOrMany::many(assistant_contents).unwrap(),
778 }
779 }
780 Message::System { content, .. } => crate::completion::Message::User {
782 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
783 text: content,
784 })),
785 },
786 Message::ToolResult { name, content } => crate::completion::Message::User {
787 content: OneOrMany::one(message::UserContent::tool_result(
788 name,
789 OneOrMany::one(message::ToolResultContent::text(content)),
790 )),
791 },
792 }
793 }
794}
795
796impl Message {
797 pub fn system(content: &str) -> Self {
799 Message::System {
800 content: content.to_owned(),
801 images: None,
802 name: None,
803 }
804 }
805}
806
807impl From<crate::message::ToolCall> for ToolCall {
810 fn from(tool_call: crate::message::ToolCall) -> Self {
811 Self {
812 r#type: ToolType::Function,
813 function: Function {
814 name: tool_call.function.name,
815 arguments: tool_call.function.arguments,
816 },
817 }
818 }
819}
820
/// A piece of system content: a content type plus its text.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    /// Falls back to the default type (`text`) when missing from the input.
    #[serde(default)]
    r#type: SystemContentType,
    /// The content text.
    text: String,
}
827
/// Kind of system content; serialized in lowercase, so `Text` is `"text"`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    /// The only (and default) kind.
    #[default]
    Text,
}
834
835impl From<String> for SystemContent {
836 fn from(s: String) -> Self {
837 SystemContent {
838 r#type: SystemContentType::default(),
839 text: s,
840 }
841 }
842}
843
844impl FromStr for SystemContent {
845 type Err = std::convert::Infallible;
846 fn from_str(s: &str) -> Result<Self, Self::Err> {
847 Ok(SystemContent {
848 r#type: SystemContentType::default(),
849 text: s.to_string(),
850 })
851 }
852}
853
/// Assistant content consisting of a single text field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    /// The content text.
    pub text: String,
}
858
859impl FromStr for AssistantContent {
860 type Err = std::convert::Infallible;
861 fn from_str(s: &str) -> Result<Self, Self::Err> {
862 Ok(AssistantContent { text: s.to_owned() })
863 }
864}
865
/// User content, internally tagged by a lowercase `type` field (`text` or `image`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text content.
    Text { text: String },
    /// Image content referenced by URL.
    Image { image_url: ImageUrl },
}
873
874impl FromStr for UserContent {
875 type Err = std::convert::Infallible;
876 fn from_str(s: &str) -> Result<Self, Self::Err> {
877 Ok(UserContent::Text { text: s.to_owned() })
878 }
879}
880
/// Image reference: a URL plus a detail setting.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// Location of the image.
    pub url: String,
    /// Falls back to `ImageDetail::default()` when missing from the input.
    #[serde(default)]
    pub detail: ImageDetail,
}
887
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A sample chat response with text and one tool call deserializes and
    /// converts into a non-empty completion choice.
    #[tokio::test]
    async fn test_chat_completion() {
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        // Round-trip through text so the same parsing path as production is used.
        let raw_json = sample_chat_response.to_string();
        let parsed: CompletionResponse =
            serde_json::from_str(&raw_json).expect("Invalid JSON structure");

        let converted: completion::CompletionResponse<CompletionResponse> =
            parsed.try_into().unwrap();

        assert!(
            !converted.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    /// A provider user message converts to a crate-level user message whose
    /// first content item is the same text.
    #[test]
    fn test_message_conversion() {
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };

        let converted: crate::completion::Message = provider_msg.into();

        let crate::completion::Message::User { content } = converted else {
            panic!("Conversion from provider Message to completion Message failed");
        };

        let crate::completion::message::UserContent::Text(text_struct) = content.first() else {
            panic!("Expected text content in conversion");
        };

        assert_eq!(text_struct.text, "Test message");
    }

    /// A crate-level tool definition converts to the provider shape with
    /// type `function` and unchanged name, description and parameters.
    #[test]
    fn test_tool_definition_conversion() {
        let parameters = json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The location to get the weather for, e.g. San Francisco, CA"
                },
                "format": {
                    "type": "string",
                    "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["location", "format"]
        });
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters,
        };

        let ollama_tool = ToolDefinition::from(internal_tool);

        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );

        let location_type = &ollama_tool.function.parameters["properties"]["location"]["type"];
        assert_eq!(*location_type, "string");
    }
}