1use crate::client::{ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient};
42use crate::completion::Usage;
43use crate::json_utils::merge_inplace;
44use crate::streaming::RawStreamingChoice;
45use crate::{
46 Embed, OneOrMany,
47 completion::{self, CompletionError, CompletionRequest},
48 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
49 impl_conversion_traits, json_utils, message,
50 message::{ImageDetail, Text},
51 streaming,
52};
53use async_stream::stream;
54use futures::StreamExt;
55use reqwest;
56use serde::{Deserialize, Serialize};
57use serde_json::{Value, json};
58use std::{convert::TryFrom, str::FromStr};
/// Default base URL of a locally running Ollama server; `ClientBuilder::new` starts from it.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
62
/// Builder for [`Client`], holding the base URL and an optional custom HTTP client.
pub struct ClientBuilder<'a> {
    // Server base URL; borrowed until `build` copies it into the `Client`.
    base_url: &'a str,
    // Custom `reqwest` client; when `None`, `build` creates a default one.
    http_client: Option<reqwest::Client>,
}
67
68impl<'a> ClientBuilder<'a> {
69 #[allow(clippy::new_without_default)]
70 pub fn new() -> Self {
71 Self {
72 base_url: OLLAMA_API_BASE_URL,
73 http_client: None,
74 }
75 }
76
77 pub fn base_url(mut self, base_url: &'a str) -> Self {
78 self.base_url = base_url;
79 self
80 }
81
82 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
83 self.http_client = Some(client);
84 self
85 }
86
87 pub fn build(self) -> Result<Client, ClientBuilderError> {
88 let http_client = if let Some(http_client) = self.http_client {
89 http_client
90 } else {
91 reqwest::Client::builder().build()?
92 };
93
94 Ok(Client {
95 base_url: self.base_url.to_string(),
96 http_client,
97 })
98 }
99}
100
/// HTTP client for an Ollama server.
///
/// Holds the server base URL and the `reqwest` client used for every request.
/// Completion and embedding models each receive a clone of it.
#[derive(Clone, Debug)]
pub struct Client {
    // Base URL that request paths are appended to (see `post`).
    base_url: String,
    http_client: reqwest::Client,
}
106
107impl Default for Client {
108 fn default() -> Self {
109 Self::new()
110 }
111}
112
113impl Client {
114 pub fn builder() -> ClientBuilder<'static> {
125 ClientBuilder::new()
126 }
127
128 pub fn new() -> Self {
133 Self::builder().build().expect("Ollama client should build")
134 }
135
136 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
137 let url = format!("{}/{}", self.base_url, path);
138 self.http_client.post(url)
139 }
140}
141
142impl ProviderClient for Client {
143 fn from_env() -> Self
144 where
145 Self: Sized,
146 {
147 let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
148 Self::builder().base_url(&api_base).build().unwrap()
149 }
150
151 fn from_val(input: crate::client::ProviderValue) -> Self {
152 let crate::client::ProviderValue::Simple(_) = input else {
153 panic!("Incorrect provider value type")
154 };
155
156 Self::new()
157 }
158}
159
160impl CompletionClient for Client {
161 type CompletionModel = CompletionModel;
162
163 fn completion_model(&self, model: &str) -> CompletionModel {
164 CompletionModel::new(self.clone(), model)
165 }
166}
167
168impl EmbeddingsClient for Client {
169 type EmbeddingModel = EmbeddingModel;
170 fn embedding_model(&self, model: &str) -> EmbeddingModel {
171 EmbeddingModel::new(self.clone(), model, 0)
172 }
173 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
174 EmbeddingModel::new(self.clone(), model, ndims)
175 }
176 fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
177 EmbeddingsBuilder::new(self.embedding_model(model))
178 }
179}
180
// Generates the `AsTranscription`, `AsImageGeneration` and `AsAudioGeneration`
// conversion-trait impls for `Client`. NOTE(review): presumably these report the
// capability as unsupported, since this module defines no such models — confirm
// against the macro definition.
impl_conversion_traits!(
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
186
/// Error body returned by the server, reduced to its `message` field.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
193
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// `#[serde(untagged)]` makes serde try `Ok(T)` first and fall back to `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
200
/// Model name `all-minilm`, intended for use with the embedding model.
pub const ALL_MINILM: &str = "all-minilm";
/// Model name `nomic-embed-text`, intended for use with the embedding model.
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
205
/// Successful response body of the `api/embed` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    /// One vector per input document, in request order.
    pub embeddings: Vec<Vec<f64>>,
    // Optional timing/count metadata; missing fields deserialize as `None`.
    // NOTE(review): durations are presumably nanoseconds — confirm against Ollama docs.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
217
218impl From<ApiErrorResponse> for EmbeddingError {
219 fn from(err: ApiErrorResponse) -> Self {
220 EmbeddingError::ProviderError(err.message)
221 }
222}
223
224impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
225 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
226 match value {
227 ApiResponse::Ok(response) => Ok(response),
228 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
229 }
230 }
231}
232
/// Ollama embedding model handle bound to a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    /// Model name sent as `model` in `api/embed` requests.
    pub model: String,
    // Dimension count reported by `ndims()`; not checked against server output.
    ndims: usize,
}
241
242impl EmbeddingModel {
243 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
244 Self {
245 client,
246 model: model.to_owned(),
247 ndims,
248 }
249 }
250}
251
252impl embeddings::EmbeddingModel for EmbeddingModel {
253 const MAX_DOCUMENTS: usize = 1024;
254 fn ndims(&self) -> usize {
255 self.ndims
256 }
257 #[cfg_attr(feature = "worker", worker::send)]
258 async fn embed_texts(
259 &self,
260 documents: impl IntoIterator<Item = String>,
261 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
262 let docs: Vec<String> = documents.into_iter().collect();
263 let payload = json!({
264 "model": self.model,
265 "input": docs,
266 });
267 let response = self
268 .client
269 .post("api/embed")
270 .json(&payload)
271 .send()
272 .await
273 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
274 if response.status().is_success() {
275 let api_resp: EmbeddingResponse = response
276 .json()
277 .await
278 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
279 if api_resp.embeddings.len() != docs.len() {
280 return Err(EmbeddingError::ResponseError(
281 "Number of returned embeddings does not match input".into(),
282 ));
283 }
284 Ok(api_resp
285 .embeddings
286 .into_iter()
287 .zip(docs.into_iter())
288 .map(|(vec, document)| embeddings::Embedding { document, vec })
289 .collect())
290 } else {
291 Err(EmbeddingError::ProviderError(response.text().await?))
292 }
293 }
294}
295
/// Model name `llama3.2`, intended for use with the completion model.
pub const LLAMA3_2: &str = "llama3.2";
/// Model name `llava`, intended for use with the completion model.
pub const LLAVA: &str = "llava";
/// Model name `mistral`, intended for use with the completion model.
pub const MISTRAL: &str = "mistral";
301
/// Response object of the `api/chat` endpoint.
///
/// Used both for the full non-streaming reply and for each line of a streaming
/// reply. Optional metadata fields deserialize as `None` when absent.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    /// `true` on the last object of a reply.
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    // NOTE(review): durations are presumably nanoseconds — confirm against Ollama docs.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    /// Used as the input-token count in rig's `Usage`.
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    /// Used as the output-token count in rig's `Usage`.
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
323impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
324 type Error = CompletionError;
325 fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
326 match resp.message {
327 Message::Assistant {
329 content,
330 thinking,
331 tool_calls,
332 ..
333 } => {
334 let mut assistant_contents = Vec::new();
335 if !content.is_empty() {
337 assistant_contents.push(completion::AssistantContent::text(&content));
338 }
339 for tc in tool_calls.iter() {
342 assistant_contents.push(completion::AssistantContent::tool_call(
343 tc.function.name.clone(),
344 tc.function.name.clone(),
345 tc.function.arguments.clone(),
346 ));
347 }
348 let choice = OneOrMany::many(assistant_contents).map_err(|_| {
349 CompletionError::ResponseError("No content provided".to_owned())
350 })?;
351 let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
352 let completion_tokens = resp.eval_count.unwrap_or(0);
353
354 let raw_response = CompletionResponse {
355 model: resp.model,
356 created_at: resp.created_at,
357 done: resp.done,
358 done_reason: resp.done_reason,
359 total_duration: resp.total_duration,
360 load_duration: resp.load_duration,
361 prompt_eval_count: resp.prompt_eval_count,
362 prompt_eval_duration: resp.prompt_eval_duration,
363 eval_count: resp.eval_count,
364 eval_duration: resp.eval_duration,
365 message: Message::Assistant {
366 content,
367 thinking,
368 images: None,
369 name: None,
370 tool_calls,
371 },
372 };
373
374 Ok(completion::CompletionResponse {
375 choice,
376 usage: Usage {
377 input_tokens: prompt_tokens,
378 output_tokens: completion_tokens,
379 total_tokens: prompt_tokens + completion_tokens,
380 },
381 raw_response,
382 })
383 }
384 _ => Err(CompletionError::ResponseError(
385 "Chat response does not include an assistant message".into(),
386 )),
387 }
388 }
389}
390
/// Ollama chat-completion model handle bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Model name sent as `model` in `api/chat` requests.
    pub model: String,
}
398
399impl CompletionModel {
400 pub fn new(client: Client, model: &str) -> Self {
401 Self {
402 client,
403 model: model.to_owned(),
404 }
405 }
406
407 fn create_completion_request(
408 &self,
409 completion_request: CompletionRequest,
410 ) -> Result<Value, CompletionError> {
411 let mut partial_history = vec![];
413 if let Some(docs) = completion_request.normalized_documents() {
414 partial_history.push(docs);
415 }
416 partial_history.extend(completion_request.chat_history);
417
418 let mut full_history: Vec<Message> = completion_request
420 .preamble
421 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
422
423 full_history.extend(
425 partial_history
426 .into_iter()
427 .map(|msg| msg.try_into())
428 .collect::<Result<Vec<Vec<Message>>, _>>()?
429 .into_iter()
430 .flatten()
431 .collect::<Vec<Message>>(),
432 );
433
434 let options = if let Some(extra) = completion_request.additional_params {
436 json_utils::merge(
437 json!({ "temperature": completion_request.temperature }),
438 extra,
439 )
440 } else {
441 json!({ "temperature": completion_request.temperature })
442 };
443
444 let mut request_payload = json!({
445 "model": self.model,
446 "messages": full_history,
447 "options": options,
448 "stream": false,
449 });
450 if !completion_request.tools.is_empty() {
451 request_payload["tools"] = json!(
452 completion_request
453 .tools
454 .into_iter()
455 .map(|tool| tool.into())
456 .collect::<Vec<ToolDefinition>>()
457 );
458 }
459
460 tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
461
462 Ok(request_payload)
463 }
464}
465
/// Final metadata of a streamed chat reply.
///
/// It is yielded once, as the stream's final response, when a line with
/// `done: true` arrives. Every field is copied from that line.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
478impl completion::CompletionModel for CompletionModel {
479 type Response = CompletionResponse;
480 type StreamingResponse = StreamingCompletionResponse;
481
482 #[cfg_attr(feature = "worker", worker::send)]
483 async fn completion(
484 &self,
485 completion_request: CompletionRequest,
486 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
487 let request_payload = self.create_completion_request(completion_request)?;
488
489 let response = self
490 .client
491 .post("api/chat")
492 .json(&request_payload)
493 .send()
494 .await
495 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
496 if response.status().is_success() {
497 let text = response
498 .text()
499 .await
500 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
501 tracing::debug!(target: "rig", "Ollama chat response: {}", text);
502 let chat_resp: CompletionResponse = serde_json::from_str(&text)
503 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
504 let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
505 Ok(conv)
506 } else {
507 let err_text = response
508 .text()
509 .await
510 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
511 Err(CompletionError::ProviderError(err_text))
512 }
513 }
514
515 #[cfg_attr(feature = "worker", worker::send)]
516 async fn stream(
517 &self,
518 request: CompletionRequest,
519 ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
520 {
521 let mut request_payload = self.create_completion_request(request)?;
522 merge_inplace(&mut request_payload, json!({"stream": true}));
523
524 let response = self
525 .client
526 .post("api/chat")
527 .json(&request_payload)
528 .send()
529 .await
530 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
531
532 if !response.status().is_success() {
533 let err_text = response
534 .text()
535 .await
536 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
537 return Err(CompletionError::ProviderError(err_text));
538 }
539
540 let stream = Box::pin(stream! {
541 let mut stream = response.bytes_stream();
542 while let Some(chunk_result) = stream.next().await {
543 let chunk = match chunk_result {
544 Ok(c) => c,
545 Err(e) => {
546 yield Err(CompletionError::from(e));
547 break;
548 }
549 };
550
551 let text = match String::from_utf8(chunk.to_vec()) {
552 Ok(t) => t,
553 Err(e) => {
554 yield Err(CompletionError::ResponseError(e.to_string()));
555 break;
556 }
557 };
558
559
560 for line in text.lines() {
561 let line = line.to_string();
562
563 let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
564 continue;
565 };
566
567 match response.message {
568 Message::Assistant{ content, tool_calls, .. } => {
569 if !content.is_empty() {
570 yield Ok(RawStreamingChoice::Message(content))
571 }
572
573 for tool_call in tool_calls.iter() {
574 let function = tool_call.function.clone();
575
576 yield Ok(RawStreamingChoice::ToolCall {
577 id: "".to_string(),
578 name: function.name,
579 arguments: function.arguments,
580 call_id: None
581 });
582 }
583 }
584 _ => {
585 continue;
586 }
587 }
588
589 if response.done {
590 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
591 total_duration: response.total_duration,
592 load_duration: response.load_duration,
593 prompt_eval_count: response.prompt_eval_count,
594 prompt_eval_duration: response.prompt_eval_duration,
595 eval_count: response.eval_count,
596 eval_duration: response.eval_duration,
597 done_reason: response.done_reason,
598 }));
599 }
600 }
601 }
602 });
603
604 Ok(streaming::StreamingCompletionResponse::stream(stream))
605 }
606}
607
608#[derive(Clone, Debug, Deserialize, Serialize)]
612pub struct ToolDefinition {
613 #[serde(rename = "type")]
614 pub type_field: String, pub function: completion::ToolDefinition,
616}
617
618impl From<crate::completion::ToolDefinition> for ToolDefinition {
620 fn from(tool: crate::completion::ToolDefinition) -> Self {
621 ToolDefinition {
622 type_field: "function".to_owned(),
623 function: completion::ToolDefinition {
624 name: tool.name,
625 description: tool.description,
626 parameters: tool.parameters,
627 },
628 }
629 }
630}
631
/// Tool call as it appears in an assistant message's `tool_calls`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    // Serialized as `type`; defaults to `function` when absent.
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Kind of a tool call; only `function` (lowercase on the wire) exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// Function name and arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    /// Arguments as an arbitrary JSON value.
    pub arguments: Value,
}
650
/// Chat message in Ollama's wire format.
///
/// Serialized as an object tagged by `role`. The role is the lowercased variant
/// name, except that `ToolResult` uses the role `tool`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// User turn; `images` holds image payloads attached to the turn.
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Assistant turn, possibly with reasoning text and tool calls.
    Assistant {
        // A missing `content` field deserializes as an empty string.
        #[serde(default)]
        content: String,
        // Reasoning text, when the response includes a `thinking` field.
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Missing field -> empty list; `null_or_vec` presumably also maps `null` to
        // an empty list (NOTE(review): confirm in json_utils).
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// System instruction.
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Result of a tool call; `name` is serialized as `tool_name`.
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
689
690impl TryFrom<crate::message::Message> for Vec<Message> {
696 type Error = crate::message::MessageError;
697 fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
698 use crate::message::Message as InternalMessage;
699 match internal_msg {
700 InternalMessage::User { content, .. } => {
701 let (tool_results, other_content): (Vec<_>, Vec<_>) =
702 content.into_iter().partition(|content| {
703 matches!(content, crate::message::UserContent::ToolResult(_))
704 });
705
706 if !tool_results.is_empty() {
707 tool_results
708 .into_iter()
709 .map(|content| match content {
710 crate::message::UserContent::ToolResult(
711 crate::message::ToolResult { id, content, .. },
712 ) => {
713 let content_string = content
715 .into_iter()
716 .map(|content| match content {
717 crate::message::ToolResultContent::Text(text) => text.text,
718 _ => "[Non-text content]".to_string(),
719 })
720 .collect::<Vec<_>>()
721 .join("\n");
722
723 Ok::<_, crate::message::MessageError>(Message::ToolResult {
724 name: id,
725 content: content_string,
726 })
727 }
728 _ => unreachable!(),
729 })
730 .collect::<Result<Vec<_>, _>>()
731 } else {
732 let (texts, images) = other_content.into_iter().fold(
734 (Vec::new(), Vec::new()),
735 |(mut texts, mut images), content| {
736 match content {
737 crate::message::UserContent::Text(crate::message::Text {
738 text,
739 }) => texts.push(text),
740 crate::message::UserContent::Image(crate::message::Image {
741 data,
742 ..
743 }) => images.push(data),
744 _ => {} }
746 (texts, images)
747 },
748 );
749
750 Ok(vec![Message::User {
751 content: texts.join(" "),
752 images: if images.is_empty() {
753 None
754 } else {
755 Some(images)
756 },
757 name: None,
758 }])
759 }
760 }
761 InternalMessage::Assistant { content, .. } => {
762 let mut thinking: Option<String> = None;
763 let (text_content, tool_calls) = content.into_iter().fold(
764 (Vec::new(), Vec::new()),
765 |(mut texts, mut tools), content| {
766 match content {
767 crate::message::AssistantContent::Text(text) => texts.push(text.text),
768 crate::message::AssistantContent::ToolCall(tool_call) => {
769 tools.push(tool_call)
770 }
771 crate::message::AssistantContent::Reasoning(
772 crate::message::Reasoning { reasoning },
773 ) => {
774 thinking = Some(reasoning);
775 }
776 }
777 (texts, tools)
778 },
779 );
780
781 Ok(vec![Message::Assistant {
784 content: text_content.join(" "),
785 thinking,
786 images: None,
787 name: None,
788 tool_calls: tool_calls
789 .into_iter()
790 .map(|tool_call| tool_call.into())
791 .collect::<Vec<_>>(),
792 }])
793 }
794 }
795 }
796}
797
798impl From<Message> for crate::completion::Message {
801 fn from(msg: Message) -> Self {
802 match msg {
803 Message::User { content, .. } => crate::completion::Message::User {
804 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
805 text: content,
806 })),
807 },
808 Message::Assistant {
809 content,
810 tool_calls,
811 ..
812 } => {
813 let mut assistant_contents =
814 vec![crate::completion::message::AssistantContent::Text(Text {
815 text: content,
816 })];
817 for tc in tool_calls {
818 assistant_contents.push(
819 crate::completion::message::AssistantContent::tool_call(
820 tc.function.name.clone(),
821 tc.function.name,
822 tc.function.arguments,
823 ),
824 );
825 }
826 crate::completion::Message::Assistant {
827 id: None,
828 content: OneOrMany::many(assistant_contents).unwrap(),
829 }
830 }
831 Message::System { content, .. } => crate::completion::Message::User {
833 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
834 text: content,
835 })),
836 },
837 Message::ToolResult { name, content } => crate::completion::Message::User {
838 content: OneOrMany::one(message::UserContent::tool_result(
839 name,
840 OneOrMany::one(message::ToolResultContent::text(content)),
841 )),
842 },
843 }
844 }
845}
846
847impl Message {
848 pub fn system(content: &str) -> Self {
850 Message::System {
851 content: content.to_owned(),
852 images: None,
853 name: None,
854 }
855 }
856}
857
858impl From<crate::message::ToolCall> for ToolCall {
861 fn from(tool_call: crate::message::ToolCall) -> Self {
862 Self {
863 r#type: ToolType::Function,
864 function: Function {
865 name: tool_call.function.name,
866 arguments: tool_call.function.arguments,
867 },
868 }
869 }
870}
871
/// Typed text content block for system messages.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    // Defaults to `text` when absent.
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}
878
/// Kind of a system content block; only `text` (lowercase on the wire) exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
885
886impl From<String> for SystemContent {
887 fn from(s: String) -> Self {
888 SystemContent {
889 r#type: SystemContentType::default(),
890 text: s,
891 }
892 }
893}
894
895impl FromStr for SystemContent {
896 type Err = std::convert::Infallible;
897 fn from_str(s: &str) -> Result<Self, Self::Err> {
898 Ok(SystemContent {
899 r#type: SystemContentType::default(),
900 text: s.to_string(),
901 })
902 }
903}
904
/// Plain-text assistant content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}
909
910impl FromStr for AssistantContent {
911 type Err = std::convert::Infallible;
912 fn from_str(s: &str) -> Result<Self, Self::Err> {
913 Ok(AssistantContent { text: s.to_owned() })
914 }
915}
916
/// User content block, tagged by a lowercase `type` field (`text` or `image`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
}
924
925impl FromStr for UserContent {
926 type Err = std::convert::Infallible;
927 fn from_str(s: &str) -> Result<Self, Self::Err> {
928 Ok(UserContent::Text { text: s.to_owned() })
929 }
930}
931
/// Image reference used by [`UserContent::Image`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    // Falls back to `ImageDetail::default()` when absent.
    #[serde(default)]
    pub detail: ImageDetail,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A chat reply with text and a tool call converts into a rig response whose
    /// first part is the text and whose usage mirrors `prompt_eval_count`/`eval_count`.
    #[test]
    fn test_chat_completion() {
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_chat_response).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();

        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
        match conv.choice.first() {
            completion::AssistantContent::Text(text) => {
                assert_eq!(text.text, "The sky is blue because of Rayleigh scattering.");
            }
            _ => panic!("Expected the text part to come first"),
        }
        assert_eq!(conv.usage.input_tokens, 61);
        assert_eq!(conv.usage.output_tokens, 468);
        assert_eq!(conv.usage.total_tokens, 529);
    }

    /// An assistant reply with neither text nor tool calls is rejected.
    #[test]
    fn test_empty_assistant_message_is_rejected() {
        let chat_resp: CompletionResponse = serde_json::from_value(json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": { "role": "assistant", "content": "" },
            "done": true
        }))
        .expect("Invalid JSON structure");

        let result: Result<completion::CompletionResponse<CompletionResponse>, CompletionError> =
            chat_resp.try_into();
        assert!(result.is_err(), "Expected an error for an empty assistant message");
    }

    /// A reply whose message is not from the assistant is rejected.
    #[test]
    fn test_non_assistant_message_is_rejected() {
        let chat_resp: CompletionResponse = serde_json::from_value(json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": { "role": "user", "content": "hello" },
            "done": true
        }))
        .expect("Invalid JSON structure");

        let result: Result<completion::CompletionResponse<CompletionResponse>, CompletionError> =
            chat_resp.try_into();
        assert!(result.is_err(), "Expected an error for a non-assistant message");
    }

    /// Tool results serialize with role `tool` and the name under `tool_name`.
    #[test]
    fn test_tool_result_message_serialization() {
        let msg = Message::ToolResult {
            name: "get_current_weather".to_owned(),
            content: "22C".to_owned(),
        };
        assert_eq!(
            serde_json::to_value(&msg).unwrap(),
            json!({ "role": "tool", "tool_name": "get_current_weather", "content": "22C" })
        );
    }

    /// A provider user message converts into a single-text rig user message.
    #[test]
    fn test_message_conversion() {
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                let first_content = content.first();
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    /// A rig tool definition is wrapped as an Ollama `function` tool unchanged.
    #[test]
    fn test_tool_definition_conversion() {
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");

        let serialized = serde_json::to_value(&ollama_tool).unwrap();
        assert_eq!(serialized["type"], "function");
    }
}