1use crate::client::{ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient};
42use crate::completion::{GetTokenUsage, Usage};
43use crate::json_utils::merge_inplace;
44use crate::streaming::RawStreamingChoice;
45use crate::{
46 Embed, OneOrMany,
47 completion::{self, CompletionError, CompletionRequest},
48 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
49 impl_conversion_traits, json_utils, message,
50 message::{ImageDetail, Text},
51 streaming,
52};
53use async_stream::stream;
54use futures::StreamExt;
55use reqwest;
56use serde::{Deserialize, Serialize};
57use serde_json::{Value, json};
58use std::{convert::TryFrom, str::FromStr};
59use url::Url;
/// Base URL used by [`ClientBuilder::new`] when the caller does not supply one.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
63
64pub struct ClientBuilder<'a> {
65 base_url: &'a str,
66 http_client: Option<reqwest::Client>,
67}
68
69impl<'a> ClientBuilder<'a> {
70 #[allow(clippy::new_without_default)]
71 pub fn new() -> Self {
72 Self {
73 base_url: OLLAMA_API_BASE_URL,
74 http_client: None,
75 }
76 }
77
78 pub fn base_url(mut self, base_url: &'a str) -> Self {
79 self.base_url = base_url;
80 self
81 }
82
83 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
84 self.http_client = Some(client);
85 self
86 }
87
88 pub fn build(self) -> Result<Client, ClientBuilderError> {
89 let http_client = if let Some(http_client) = self.http_client {
90 http_client
91 } else {
92 reqwest::Client::builder().build()?
93 };
94
95 Ok(Client {
96 base_url: Url::parse(self.base_url)
97 .map_err(|_| ClientBuilderError::InvalidProperty("base_url"))?,
98 http_client,
99 })
100 }
101}
102
103#[derive(Clone, Debug)]
104pub struct Client {
105 base_url: Url,
106 http_client: reqwest::Client,
107}
108
109impl Default for Client {
110 fn default() -> Self {
111 Self::new()
112 }
113}
114
115impl Client {
116 pub fn builder() -> ClientBuilder<'static> {
127 ClientBuilder::new()
128 }
129
130 pub fn new() -> Self {
135 Self::builder().build().expect("Ollama client should build")
136 }
137
138 pub fn post(&self, path: &str) -> Result<reqwest::RequestBuilder, url::ParseError> {
139 let url = self.base_url.join(path)?;
140 Ok(self.http_client.post(url))
141 }
142}
143
144impl ProviderClient for Client {
145 fn from_env() -> Self
146 where
147 Self: Sized,
148 {
149 let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
150 Self::builder().base_url(&api_base).build().unwrap()
151 }
152
153 fn from_val(input: crate::client::ProviderValue) -> Self {
154 let crate::client::ProviderValue::Simple(_) = input else {
155 panic!("Incorrect provider value type")
156 };
157
158 Self::new()
159 }
160}
161
162impl CompletionClient for Client {
163 type CompletionModel = CompletionModel;
164
165 fn completion_model(&self, model: &str) -> CompletionModel {
166 CompletionModel::new(self.clone(), model)
167 }
168}
169
170impl EmbeddingsClient for Client {
171 type EmbeddingModel = EmbeddingModel;
172 fn embedding_model(&self, model: &str) -> EmbeddingModel {
173 EmbeddingModel::new(self.clone(), model, 0)
174 }
175 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
176 EmbeddingModel::new(self.clone(), model, ndims)
177 }
178 fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
179 EmbeddingsBuilder::new(self.embedding_model(model))
180 }
181}
182
// NOTE(review): presumably generates the `AsTranscription`, `AsImageGeneration`
// and `AsAudioGeneration` conversion-trait impls for `Client`; the macro is
// defined elsewhere in the crate, so confirm its exact effect there.
impl_conversion_traits!(
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
188
189#[derive(Debug, Deserialize)]
192struct ApiErrorResponse {
193 message: String,
194}
195
196#[derive(Debug, Deserialize)]
197#[serde(untagged)]
198enum ApiResponse<T> {
199 Ok(T),
200 Err(ApiErrorResponse),
201}
202
/// Model identifier string `all-minilm`, for use with the embedding API.
pub const ALL_MINILM: &str = "all-minilm";
/// Model identifier string `nomic-embed-text`, for use with the embedding API.
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
207
208#[derive(Debug, Serialize, Deserialize)]
209pub struct EmbeddingResponse {
210 pub model: String,
211 pub embeddings: Vec<Vec<f64>>,
212 #[serde(default)]
213 pub total_duration: Option<u64>,
214 #[serde(default)]
215 pub load_duration: Option<u64>,
216 #[serde(default)]
217 pub prompt_eval_count: Option<u64>,
218}
219
220impl From<ApiErrorResponse> for EmbeddingError {
221 fn from(err: ApiErrorResponse) -> Self {
222 EmbeddingError::ProviderError(err.message)
223 }
224}
225
226impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
227 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
228 match value {
229 ApiResponse::Ok(response) => Ok(response),
230 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
231 }
232 }
233}
234
235#[derive(Clone)]
238pub struct EmbeddingModel {
239 client: Client,
240 pub model: String,
241 ndims: usize,
242}
243
244impl EmbeddingModel {
245 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
246 Self {
247 client,
248 model: model.to_owned(),
249 ndims,
250 }
251 }
252}
253
254impl embeddings::EmbeddingModel for EmbeddingModel {
255 const MAX_DOCUMENTS: usize = 1024;
256 fn ndims(&self) -> usize {
257 self.ndims
258 }
259 #[cfg_attr(feature = "worker", worker::send)]
260 async fn embed_texts(
261 &self,
262 documents: impl IntoIterator<Item = String>,
263 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
264 let docs: Vec<String> = documents.into_iter().collect();
265 let payload = json!({
266 "model": self.model,
267 "input": docs,
268 });
269 let response = self
270 .client
271 .post("api/embed")?
272 .json(&payload)
273 .send()
274 .await
275 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
276 if response.status().is_success() {
277 let api_resp: EmbeddingResponse = response
278 .json()
279 .await
280 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
281 if api_resp.embeddings.len() != docs.len() {
282 return Err(EmbeddingError::ResponseError(
283 "Number of returned embeddings does not match input".into(),
284 ));
285 }
286 Ok(api_resp
287 .embeddings
288 .into_iter()
289 .zip(docs.into_iter())
290 .map(|(vec, document)| embeddings::Embedding { document, vec })
291 .collect())
292 } else {
293 Err(EmbeddingError::ProviderError(response.text().await?))
294 }
295 }
296}
297
/// Model identifier string `llama3.2`.
pub const LLAMA3_2: &str = "llama3.2";
/// Model identifier string `llava`.
pub const LLAVA: &str = "llava";
/// Model identifier string `mistral`.
pub const MISTRAL: &str = "mistral";
303
304#[derive(Debug, Serialize, Deserialize)]
305pub struct CompletionResponse {
306 pub model: String,
307 pub created_at: String,
308 pub message: Message,
309 pub done: bool,
310 #[serde(default)]
311 pub done_reason: Option<String>,
312 #[serde(default)]
313 pub total_duration: Option<u64>,
314 #[serde(default)]
315 pub load_duration: Option<u64>,
316 #[serde(default)]
317 pub prompt_eval_count: Option<u64>,
318 #[serde(default)]
319 pub prompt_eval_duration: Option<u64>,
320 #[serde(default)]
321 pub eval_count: Option<u64>,
322 #[serde(default)]
323 pub eval_duration: Option<u64>,
324}
325impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
326 type Error = CompletionError;
327 fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
328 match resp.message {
329 Message::Assistant {
331 content,
332 thinking,
333 tool_calls,
334 ..
335 } => {
336 let mut assistant_contents = Vec::new();
337 if !content.is_empty() {
339 assistant_contents.push(completion::AssistantContent::text(&content));
340 }
341 for tc in tool_calls.iter() {
344 assistant_contents.push(completion::AssistantContent::tool_call(
345 tc.function.name.clone(),
346 tc.function.name.clone(),
347 tc.function.arguments.clone(),
348 ));
349 }
350 let choice = OneOrMany::many(assistant_contents).map_err(|_| {
351 CompletionError::ResponseError("No content provided".to_owned())
352 })?;
353 let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
354 let completion_tokens = resp.eval_count.unwrap_or(0);
355
356 let raw_response = CompletionResponse {
357 model: resp.model,
358 created_at: resp.created_at,
359 done: resp.done,
360 done_reason: resp.done_reason,
361 total_duration: resp.total_duration,
362 load_duration: resp.load_duration,
363 prompt_eval_count: resp.prompt_eval_count,
364 prompt_eval_duration: resp.prompt_eval_duration,
365 eval_count: resp.eval_count,
366 eval_duration: resp.eval_duration,
367 message: Message::Assistant {
368 content,
369 thinking,
370 images: None,
371 name: None,
372 tool_calls,
373 },
374 };
375
376 Ok(completion::CompletionResponse {
377 choice,
378 usage: Usage {
379 input_tokens: prompt_tokens,
380 output_tokens: completion_tokens,
381 total_tokens: prompt_tokens + completion_tokens,
382 },
383 raw_response,
384 })
385 }
386 _ => Err(CompletionError::ResponseError(
387 "Chat response does not include an assistant message".into(),
388 )),
389 }
390 }
391}
392
393#[derive(Clone)]
396pub struct CompletionModel {
397 client: Client,
398 pub model: String,
399}
400
401impl CompletionModel {
402 pub fn new(client: Client, model: &str) -> Self {
403 Self {
404 client,
405 model: model.to_owned(),
406 }
407 }
408
409 fn create_completion_request(
410 &self,
411 completion_request: CompletionRequest,
412 ) -> Result<Value, CompletionError> {
413 let mut partial_history = vec![];
415 if let Some(docs) = completion_request.normalized_documents() {
416 partial_history.push(docs);
417 }
418 partial_history.extend(completion_request.chat_history);
419
420 let mut full_history: Vec<Message> = completion_request
422 .preamble
423 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
424
425 full_history.extend(
427 partial_history
428 .into_iter()
429 .map(|msg| msg.try_into())
430 .collect::<Result<Vec<Vec<Message>>, _>>()?
431 .into_iter()
432 .flatten()
433 .collect::<Vec<Message>>(),
434 );
435
436 let options = if let Some(extra) = completion_request.additional_params {
438 json_utils::merge(
439 json!({ "temperature": completion_request.temperature }),
440 extra,
441 )
442 } else {
443 json!({ "temperature": completion_request.temperature })
444 };
445
446 let mut request_payload = json!({
447 "model": self.model,
448 "messages": full_history,
449 "options": options,
450 "stream": false,
451 });
452 if !completion_request.tools.is_empty() {
453 request_payload["tools"] = json!(
454 completion_request
455 .tools
456 .into_iter()
457 .map(|tool| tool.into())
458 .collect::<Vec<ToolDefinition>>()
459 );
460 }
461
462 tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
463
464 Ok(request_payload)
465 }
466}
467
468#[derive(Clone, Serialize, Deserialize, Debug)]
471pub struct StreamingCompletionResponse {
472 pub done_reason: Option<String>,
473 pub total_duration: Option<u64>,
474 pub load_duration: Option<u64>,
475 pub prompt_eval_count: Option<u64>,
476 pub prompt_eval_duration: Option<u64>,
477 pub eval_count: Option<u64>,
478 pub eval_duration: Option<u64>,
479}
480
481impl GetTokenUsage for StreamingCompletionResponse {
482 fn token_usage(&self) -> Option<crate::completion::Usage> {
483 let mut usage = crate::completion::Usage::new();
484 let input_tokens = self.prompt_eval_count.unwrap_or_default();
485 let output_tokens = self.eval_count.unwrap_or_default();
486 usage.input_tokens = input_tokens;
487 usage.output_tokens = output_tokens;
488 usage.total_tokens = input_tokens + output_tokens;
489
490 Some(usage)
491 }
492}
493
494impl completion::CompletionModel for CompletionModel {
495 type Response = CompletionResponse;
496 type StreamingResponse = StreamingCompletionResponse;
497
498 #[cfg_attr(feature = "worker", worker::send)]
499 async fn completion(
500 &self,
501 completion_request: CompletionRequest,
502 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
503 let request_payload = self.create_completion_request(completion_request)?;
504
505 let response = self
506 .client
507 .post("api/chat")?
508 .json(&request_payload)
509 .send()
510 .await
511 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
512 if response.status().is_success() {
513 let text = response
514 .text()
515 .await
516 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
517 tracing::debug!(target: "rig", "Ollama chat response: {}", text);
518 let chat_resp: CompletionResponse = serde_json::from_str(&text)
519 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
520 let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
521 Ok(conv)
522 } else {
523 let err_text = response
524 .text()
525 .await
526 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
527 Err(CompletionError::ProviderError(err_text))
528 }
529 }
530
531 #[cfg_attr(feature = "worker", worker::send)]
532 async fn stream(
533 &self,
534 request: CompletionRequest,
535 ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
536 {
537 let mut request_payload = self.create_completion_request(request)?;
538 merge_inplace(&mut request_payload, json!({"stream": true}));
539
540 let response = self
541 .client
542 .post("api/chat")?
543 .json(&request_payload)
544 .send()
545 .await
546 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
547
548 if !response.status().is_success() {
549 let err_text = response
550 .text()
551 .await
552 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
553 return Err(CompletionError::ProviderError(err_text));
554 }
555
556 let stream = Box::pin(stream! {
557 let mut stream = response.bytes_stream();
558 while let Some(chunk_result) = stream.next().await {
559 let chunk = match chunk_result {
560 Ok(c) => c,
561 Err(e) => {
562 yield Err(CompletionError::from(e));
563 break;
564 }
565 };
566
567 let text = match String::from_utf8(chunk.to_vec()) {
568 Ok(t) => t,
569 Err(e) => {
570 yield Err(CompletionError::ResponseError(e.to_string()));
571 break;
572 }
573 };
574
575
576 for line in text.lines() {
577 let line = line.to_string();
578
579 let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
580 continue;
581 };
582
583 match response.message {
584 Message::Assistant{ content, tool_calls, .. } => {
585 if !content.is_empty() {
586 yield Ok(RawStreamingChoice::Message(content))
587 }
588
589 for tool_call in tool_calls.iter() {
590 let function = tool_call.function.clone();
591
592 yield Ok(RawStreamingChoice::ToolCall {
593 id: "".to_string(),
594 name: function.name,
595 arguments: function.arguments,
596 call_id: None
597 });
598 }
599 }
600 _ => {
601 continue;
602 }
603 }
604
605 if response.done {
606 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
607 total_duration: response.total_duration,
608 load_duration: response.load_duration,
609 prompt_eval_count: response.prompt_eval_count,
610 prompt_eval_duration: response.prompt_eval_duration,
611 eval_count: response.eval_count,
612 eval_duration: response.eval_duration,
613 done_reason: response.done_reason,
614 }));
615 }
616 }
617 }
618 });
619
620 Ok(streaming::StreamingCompletionResponse::stream(stream))
621 }
622}
623
624#[derive(Clone, Debug, Deserialize, Serialize)]
628pub struct ToolDefinition {
629 #[serde(rename = "type")]
630 pub type_field: String, pub function: completion::ToolDefinition,
632}
633
634impl From<crate::completion::ToolDefinition> for ToolDefinition {
636 fn from(tool: crate::completion::ToolDefinition) -> Self {
637 ToolDefinition {
638 type_field: "function".to_owned(),
639 function: completion::ToolDefinition {
640 name: tool.name,
641 description: tool.description,
642 parameters: tool.parameters,
643 },
644 }
645 }
646}
647
648#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
649pub struct ToolCall {
650 #[serde(default, rename = "type")]
652 pub r#type: ToolType,
653 pub function: Function,
654}
655#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
656#[serde(rename_all = "lowercase")]
657pub enum ToolType {
658 #[default]
659 Function,
660}
661#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
662pub struct Function {
663 pub name: String,
664 pub arguments: Value,
665}
666
667#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
670#[serde(tag = "role", rename_all = "lowercase")]
671pub enum Message {
672 User {
673 content: String,
674 #[serde(skip_serializing_if = "Option::is_none")]
675 images: Option<Vec<String>>,
676 #[serde(skip_serializing_if = "Option::is_none")]
677 name: Option<String>,
678 },
679 Assistant {
680 #[serde(default)]
681 content: String,
682 #[serde(skip_serializing_if = "Option::is_none")]
683 thinking: Option<String>,
684 #[serde(skip_serializing_if = "Option::is_none")]
685 images: Option<Vec<String>>,
686 #[serde(skip_serializing_if = "Option::is_none")]
687 name: Option<String>,
688 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
689 tool_calls: Vec<ToolCall>,
690 },
691 System {
692 content: String,
693 #[serde(skip_serializing_if = "Option::is_none")]
694 images: Option<Vec<String>>,
695 #[serde(skip_serializing_if = "Option::is_none")]
696 name: Option<String>,
697 },
698 #[serde(rename = "tool")]
699 ToolResult {
700 #[serde(rename = "tool_name")]
701 name: String,
702 content: String,
703 },
704}
705
706impl TryFrom<crate::message::Message> for Vec<Message> {
712 type Error = crate::message::MessageError;
713 fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
714 use crate::message::Message as InternalMessage;
715 match internal_msg {
716 InternalMessage::User { content, .. } => {
717 let (tool_results, other_content): (Vec<_>, Vec<_>) =
718 content.into_iter().partition(|content| {
719 matches!(content, crate::message::UserContent::ToolResult(_))
720 });
721
722 if !tool_results.is_empty() {
723 tool_results
724 .into_iter()
725 .map(|content| match content {
726 crate::message::UserContent::ToolResult(
727 crate::message::ToolResult { id, content, .. },
728 ) => {
729 let content_string = content
731 .into_iter()
732 .map(|content| match content {
733 crate::message::ToolResultContent::Text(text) => text.text,
734 _ => "[Non-text content]".to_string(),
735 })
736 .collect::<Vec<_>>()
737 .join("\n");
738
739 Ok::<_, crate::message::MessageError>(Message::ToolResult {
740 name: id,
741 content: content_string,
742 })
743 }
744 _ => unreachable!(),
745 })
746 .collect::<Result<Vec<_>, _>>()
747 } else {
748 let (texts, images) = other_content.into_iter().fold(
750 (Vec::new(), Vec::new()),
751 |(mut texts, mut images), content| {
752 match content {
753 crate::message::UserContent::Text(crate::message::Text {
754 text,
755 }) => texts.push(text),
756 crate::message::UserContent::Image(crate::message::Image {
757 data,
758 ..
759 }) => images.push(data),
760 crate::message::UserContent::Document(
761 crate::message::Document { data, .. },
762 ) => texts.push(data),
763 _ => {} }
765 (texts, images)
766 },
767 );
768
769 Ok(vec![Message::User {
770 content: texts.join(" "),
771 images: if images.is_empty() {
772 None
773 } else {
774 Some(images)
775 },
776 name: None,
777 }])
778 }
779 }
780 InternalMessage::Assistant { content, .. } => {
781 let mut thinking: Option<String> = None;
782 let (text_content, tool_calls) = content.into_iter().fold(
783 (Vec::new(), Vec::new()),
784 |(mut texts, mut tools), content| {
785 match content {
786 crate::message::AssistantContent::Text(text) => texts.push(text.text),
787 crate::message::AssistantContent::ToolCall(tool_call) => {
788 tools.push(tool_call)
789 }
790 crate::message::AssistantContent::Reasoning(
791 crate::message::Reasoning { reasoning, .. },
792 ) => {
793 thinking =
794 Some(reasoning.first().cloned().unwrap_or(String::new()));
795 }
796 }
797 (texts, tools)
798 },
799 );
800
801 Ok(vec![Message::Assistant {
804 content: text_content.join(" "),
805 thinking,
806 images: None,
807 name: None,
808 tool_calls: tool_calls
809 .into_iter()
810 .map(|tool_call| tool_call.into())
811 .collect::<Vec<_>>(),
812 }])
813 }
814 }
815 }
816}
817
818impl From<Message> for crate::completion::Message {
821 fn from(msg: Message) -> Self {
822 match msg {
823 Message::User { content, .. } => crate::completion::Message::User {
824 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
825 text: content,
826 })),
827 },
828 Message::Assistant {
829 content,
830 tool_calls,
831 ..
832 } => {
833 let mut assistant_contents =
834 vec![crate::completion::message::AssistantContent::Text(Text {
835 text: content,
836 })];
837 for tc in tool_calls {
838 assistant_contents.push(
839 crate::completion::message::AssistantContent::tool_call(
840 tc.function.name.clone(),
841 tc.function.name,
842 tc.function.arguments,
843 ),
844 );
845 }
846 crate::completion::Message::Assistant {
847 id: None,
848 content: OneOrMany::many(assistant_contents).unwrap(),
849 }
850 }
851 Message::System { content, .. } => crate::completion::Message::User {
853 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
854 text: content,
855 })),
856 },
857 Message::ToolResult { name, content } => crate::completion::Message::User {
858 content: OneOrMany::one(message::UserContent::tool_result(
859 name,
860 OneOrMany::one(message::ToolResultContent::text(content)),
861 )),
862 },
863 }
864 }
865}
866
867impl Message {
868 pub fn system(content: &str) -> Self {
870 Message::System {
871 content: content.to_owned(),
872 images: None,
873 name: None,
874 }
875 }
876}
877
878impl From<crate::message::ToolCall> for ToolCall {
881 fn from(tool_call: crate::message::ToolCall) -> Self {
882 Self {
883 r#type: ToolType::Function,
884 function: Function {
885 name: tool_call.function.name,
886 arguments: tool_call.function.arguments,
887 },
888 }
889 }
890}
891
892#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
893pub struct SystemContent {
894 #[serde(default)]
895 r#type: SystemContentType,
896 text: String,
897}
898
899#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
900#[serde(rename_all = "lowercase")]
901pub enum SystemContentType {
902 #[default]
903 Text,
904}
905
906impl From<String> for SystemContent {
907 fn from(s: String) -> Self {
908 SystemContent {
909 r#type: SystemContentType::default(),
910 text: s,
911 }
912 }
913}
914
915impl FromStr for SystemContent {
916 type Err = std::convert::Infallible;
917 fn from_str(s: &str) -> Result<Self, Self::Err> {
918 Ok(SystemContent {
919 r#type: SystemContentType::default(),
920 text: s.to_string(),
921 })
922 }
923}
924
925#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
926pub struct AssistantContent {
927 pub text: String,
928}
929
930impl FromStr for AssistantContent {
931 type Err = std::convert::Infallible;
932 fn from_str(s: &str) -> Result<Self, Self::Err> {
933 Ok(AssistantContent { text: s.to_owned() })
934 }
935}
936
937#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
938#[serde(tag = "type", rename_all = "lowercase")]
939pub enum UserContent {
940 Text { text: String },
941 Image { image_url: ImageUrl },
942 }
944
945impl FromStr for UserContent {
946 type Err = std::convert::Infallible;
947 fn from_str(s: &str) -> Result<Self, Self::Err> {
948 Ok(UserContent::Text { text: s.to_owned() })
949 }
950}
951
952#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
953pub struct ImageUrl {
954 pub url: String,
955 #[serde(default)]
956 pub detail: ImageDetail,
957}
958
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A sample chat response deserializes and converts, with token usage
    /// taken from `prompt_eval_count` and `eval_count`.
    ///
    /// This is a plain synchronous test: nothing here awaits, so no async
    /// runtime is needed.
    #[test]
    fn test_chat_completion() {
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );

        // Pin the usage mapping: prompt_eval_count -> input, eval_count -> output.
        assert_eq!(conv.usage.input_tokens, 61);
        assert_eq!(conv.usage.output_tokens, 468);
        assert_eq!(conv.usage.total_tokens, 61 + 468);

        // The raw response keeps the provider's own fields.
        assert_eq!(conv.raw_response.model, "llama3.2");
        assert!(conv.raw_response.done);
    }

    /// A provider user message converts to an internal user message with the same text.
    #[test]
    fn test_message_conversion() {
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                let first_content = content.first();
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    /// An internal tool definition converts to a tool of type `"function"`
    /// with its name, description and parameters preserved.
    #[test]
    fn test_tool_definition_conversion() {
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }
}