1use crate::client::{
42 ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient,
43 VerifyError,
44};
45use crate::completion::{GetTokenUsage, Usage};
46use crate::json_utils::merge_inplace;
47use crate::streaming::RawStreamingChoice;
48use crate::{
49 Embed, OneOrMany,
50 completion::{self, CompletionError, CompletionRequest},
51 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
52 impl_conversion_traits, json_utils, message,
53 message::{ImageDetail, Text},
54 streaming,
55};
56use async_stream::stream;
57use futures::StreamExt;
58use reqwest;
59use reqwest_eventsource::{Event, RequestBuilderExt};
60use serde::{Deserialize, Serialize};
61use serde_json::{Value, json};
62use std::{convert::TryFrom, str::FromStr};
63use url::Url;
/// Base URL used by [`ClientBuilder::new`] when the caller does not supply one.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
67
68pub struct ClientBuilder<'a> {
69 base_url: &'a str,
70 http_client: Option<reqwest::Client>,
71}
72
73impl<'a> ClientBuilder<'a> {
74 #[allow(clippy::new_without_default)]
75 pub fn new() -> Self {
76 Self {
77 base_url: OLLAMA_API_BASE_URL,
78 http_client: None,
79 }
80 }
81
82 pub fn base_url(mut self, base_url: &'a str) -> Self {
83 self.base_url = base_url;
84 self
85 }
86
87 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
88 self.http_client = Some(client);
89 self
90 }
91
92 pub fn build(self) -> Result<Client, ClientBuilderError> {
93 let http_client = if let Some(http_client) = self.http_client {
94 http_client
95 } else {
96 reqwest::Client::builder().build()?
97 };
98
99 Ok(Client {
100 base_url: Url::parse(self.base_url)
101 .map_err(|_| ClientBuilderError::InvalidProperty("base_url"))?,
102 http_client,
103 })
104 }
105}
106
/// Ollama API client: a parsed base URL plus the `reqwest::Client` used to send requests.
#[derive(Clone, Debug)]
pub struct Client {
    // Request paths are resolved against this URL with `Url::join`.
    base_url: Url,
    http_client: reqwest::Client,
}
112
113impl Default for Client {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
119impl Client {
120 pub fn builder() -> ClientBuilder<'static> {
131 ClientBuilder::new()
132 }
133
134 pub fn new() -> Self {
139 Self::builder().build().expect("Ollama client should build")
140 }
141
142 pub(crate) fn post(&self, path: &str) -> Result<reqwest::RequestBuilder, url::ParseError> {
143 let url = self.base_url.join(path)?;
144 Ok(self.http_client.post(url))
145 }
146
147 pub(crate) fn get(&self, path: &str) -> Result<reqwest::RequestBuilder, url::ParseError> {
148 let url = self.base_url.join(path)?;
149 Ok(self.http_client.get(url))
150 }
151}
152
153impl ProviderClient for Client {
154 fn from_env() -> Self
155 where
156 Self: Sized,
157 {
158 let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
159 Self::builder().base_url(&api_base).build().unwrap()
160 }
161
162 fn from_val(input: crate::client::ProviderValue) -> Self {
163 let crate::client::ProviderValue::Simple(_) = input else {
164 panic!("Incorrect provider value type")
165 };
166
167 Self::new()
168 }
169}
170
171impl CompletionClient for Client {
172 type CompletionModel = CompletionModel;
173
174 fn completion_model(&self, model: &str) -> CompletionModel {
175 CompletionModel::new(self.clone(), model)
176 }
177}
178
179impl EmbeddingsClient for Client {
180 type EmbeddingModel = EmbeddingModel;
181 fn embedding_model(&self, model: &str) -> EmbeddingModel {
182 EmbeddingModel::new(self.clone(), model, 0)
183 }
184 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
185 EmbeddingModel::new(self.clone(), model, ndims)
186 }
187 fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
188 EmbeddingsBuilder::new(self.embedding_model(model))
189 }
190}
191
192impl VerifyClient for Client {
193 #[cfg_attr(feature = "worker", worker::send)]
194 async fn verify(&self) -> Result<(), VerifyError> {
195 let response = self
196 .get("api/tags")
197 .expect("Failed to build request")
198 .send()
199 .await?;
200 match response.status() {
201 reqwest::StatusCode::OK => Ok(()),
202 _ => {
203 response.error_for_status()?;
204 Ok(())
205 }
206 }
207 }
208}
209
// NOTE(review): presumably generates the transcription, image-generation and
// audio-generation conversion-trait implementations for `Client`; confirm
// against the definition of `impl_conversion_traits!`.
impl_conversion_traits!(
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
215
/// Error payload carrying a single `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
222
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// Deserialized untagged: serde tries the variants in declaration order, so
/// `Ok(T)` is attempted before `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
229
/// Model name `all-minilm`.
pub const ALL_MINILM: &str = "all-minilm";
/// Model name `nomic-embed-text`.
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
234
/// Body deserialized from a successful `api/embed` response.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    /// One vector per input document; `embed_texts` rejects a count mismatch.
    pub embeddings: Vec<Vec<f64>>,
    // The remaining fields are optional in the payload and default to `None`.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
246
247impl From<ApiErrorResponse> for EmbeddingError {
248 fn from(err: ApiErrorResponse) -> Self {
249 EmbeddingError::ProviderError(err.message)
250 }
251}
252
253impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
254 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
255 match value {
256 ApiResponse::Ok(response) => Ok(response),
257 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
258 }
259 }
260}
261
/// Handle for one embedding model, bound to a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    /// Model name sent as `"model"` in the embed request.
    pub model: String,
    // Value reported by `ndims()`; it is not sent to the server.
    ndims: usize,
}
270
271impl EmbeddingModel {
272 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
273 Self {
274 client,
275 model: model.to_owned(),
276 ndims,
277 }
278 }
279}
280
281impl embeddings::EmbeddingModel for EmbeddingModel {
282 const MAX_DOCUMENTS: usize = 1024;
283 fn ndims(&self) -> usize {
284 self.ndims
285 }
286 #[cfg_attr(feature = "worker", worker::send)]
287 async fn embed_texts(
288 &self,
289 documents: impl IntoIterator<Item = String>,
290 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
291 let docs: Vec<String> = documents.into_iter().collect();
292 let payload = json!({
293 "model": self.model,
294 "input": docs,
295 });
296 let response = self
297 .client
298 .post("api/embed")?
299 .json(&payload)
300 .send()
301 .await
302 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
303 if response.status().is_success() {
304 let api_resp: EmbeddingResponse = response
305 .json()
306 .await
307 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
308 if api_resp.embeddings.len() != docs.len() {
309 return Err(EmbeddingError::ResponseError(
310 "Number of returned embeddings does not match input".into(),
311 ));
312 }
313 Ok(api_resp
314 .embeddings
315 .into_iter()
316 .zip(docs.into_iter())
317 .map(|(vec, document)| embeddings::Embedding { document, vec })
318 .collect())
319 } else {
320 Err(EmbeddingError::ProviderError(response.text().await?))
321 }
322 }
323}
324
/// Model name `llama3.2`.
pub const LLAMA3_2: &str = "llama3.2";
/// Model name `llava`.
pub const LLAVA: &str = "llava";
/// Model name `mistral`.
pub const MISTRAL: &str = "mistral";
330
/// Body deserialized from an `api/chat` response.
///
/// The same type is used for the non-streaming response and for each streamed chunk.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    /// Streaming emits the final response once a chunk arrives with this set.
    pub done: bool,
    // Everything below is optional in the payload and defaults to `None`.
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    /// Reported as the input token count in usage.
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    /// Reported as the output token count in usage.
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
352impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
353 type Error = CompletionError;
354 fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
355 match resp.message {
356 Message::Assistant {
358 content,
359 thinking,
360 tool_calls,
361 ..
362 } => {
363 let mut assistant_contents = Vec::new();
364 if !content.is_empty() {
366 assistant_contents.push(completion::AssistantContent::text(&content));
367 }
368 for tc in tool_calls.iter() {
371 assistant_contents.push(completion::AssistantContent::tool_call(
372 tc.function.name.clone(),
373 tc.function.name.clone(),
374 tc.function.arguments.clone(),
375 ));
376 }
377 let choice = OneOrMany::many(assistant_contents).map_err(|_| {
378 CompletionError::ResponseError("No content provided".to_owned())
379 })?;
380 let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
381 let completion_tokens = resp.eval_count.unwrap_or(0);
382
383 let raw_response = CompletionResponse {
384 model: resp.model,
385 created_at: resp.created_at,
386 done: resp.done,
387 done_reason: resp.done_reason,
388 total_duration: resp.total_duration,
389 load_duration: resp.load_duration,
390 prompt_eval_count: resp.prompt_eval_count,
391 prompt_eval_duration: resp.prompt_eval_duration,
392 eval_count: resp.eval_count,
393 eval_duration: resp.eval_duration,
394 message: Message::Assistant {
395 content,
396 thinking,
397 images: None,
398 name: None,
399 tool_calls,
400 },
401 };
402
403 Ok(completion::CompletionResponse {
404 choice,
405 usage: Usage {
406 input_tokens: prompt_tokens,
407 output_tokens: completion_tokens,
408 total_tokens: prompt_tokens + completion_tokens,
409 },
410 raw_response,
411 })
412 }
413 _ => Err(CompletionError::ResponseError(
414 "Chat response does not include an assistant message".into(),
415 )),
416 }
417 }
418}
419
/// Handle for one chat/completion model, bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Model name sent as `"model"` in the chat request.
    pub model: String,
}
427
428impl CompletionModel {
429 pub fn new(client: Client, model: &str) -> Self {
430 Self {
431 client,
432 model: model.to_owned(),
433 }
434 }
435
436 fn create_completion_request(
437 &self,
438 completion_request: CompletionRequest,
439 ) -> Result<Value, CompletionError> {
440 let mut partial_history = vec![];
442 if let Some(docs) = completion_request.normalized_documents() {
443 partial_history.push(docs);
444 }
445 partial_history.extend(completion_request.chat_history);
446
447 let mut full_history: Vec<Message> = completion_request
449 .preamble
450 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
451
452 full_history.extend(
454 partial_history
455 .into_iter()
456 .map(|msg| msg.try_into())
457 .collect::<Result<Vec<Vec<Message>>, _>>()?
458 .into_iter()
459 .flatten()
460 .collect::<Vec<Message>>(),
461 );
462
463 let options = if let Some(extra) = completion_request.additional_params {
465 json_utils::merge(
466 json!({ "temperature": completion_request.temperature }),
467 extra,
468 )
469 } else {
470 json!({ "temperature": completion_request.temperature })
471 };
472
473 let mut request_payload = json!({
474 "model": self.model,
475 "messages": full_history,
476 "options": options,
477 "stream": false,
478 });
479 if !completion_request.tools.is_empty() {
480 request_payload["tools"] = json!(
481 completion_request
482 .tools
483 .into_iter()
484 .map(|tool| tool.into())
485 .collect::<Vec<ToolDefinition>>()
486 );
487 }
488
489 tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
490
491 Ok(request_payload)
492 }
493}
494
/// Final payload of a streamed chat, copied from the chunk that has `done` set.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    /// Reported as the input token count by `token_usage`.
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    /// Reported as the output token count by `token_usage`.
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
507
508impl GetTokenUsage for StreamingCompletionResponse {
509 fn token_usage(&self) -> Option<crate::completion::Usage> {
510 let mut usage = crate::completion::Usage::new();
511 let input_tokens = self.prompt_eval_count.unwrap_or_default();
512 let output_tokens = self.eval_count.unwrap_or_default();
513 usage.input_tokens = input_tokens;
514 usage.output_tokens = output_tokens;
515 usage.total_tokens = input_tokens + output_tokens;
516
517 Some(usage)
518 }
519}
520
521impl completion::CompletionModel for CompletionModel {
522 type Response = CompletionResponse;
523 type StreamingResponse = StreamingCompletionResponse;
524
525 #[cfg_attr(feature = "worker", worker::send)]
526 async fn completion(
527 &self,
528 completion_request: CompletionRequest,
529 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
530 let request_payload = self.create_completion_request(completion_request)?;
531
532 let response = self
533 .client
534 .post("api/chat")?
535 .json(&request_payload)
536 .send()
537 .await
538 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
539 if response.status().is_success() {
540 let text = response
541 .text()
542 .await
543 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
544 tracing::debug!(target: "rig", "Ollama chat response: {}", text);
545 let chat_resp: CompletionResponse = serde_json::from_str(&text)
546 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
547 let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
548 Ok(conv)
549 } else {
550 let err_text = response
551 .text()
552 .await
553 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
554 Err(CompletionError::ProviderError(err_text))
555 }
556 }
557
558 #[cfg_attr(feature = "worker", worker::send)]
559 async fn stream(
560 &self,
561 request: CompletionRequest,
562 ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
563 {
564 let mut request_payload = self.create_completion_request(request)?;
565 merge_inplace(&mut request_payload, json!({"stream": true}));
566
567 let mut event_source = self
568 .client
569 .post("api/chat")?
570 .json(&request_payload)
571 .eventsource()
572 .expect("Cloning request must succeed");
573
574 let stream = Box::pin(stream! {
575 while let Some(event_result) = event_source.next().await {
576 match event_result {
577 Ok(Event::Open) => {
578 tracing::trace!("SSE connection opened");
579 continue;
580 }
581
582 Ok(Event::Message(message)) => {
583 let data_str = message.data.trim();
584
585 let parsed = serde_json::from_str::<CompletionResponse>(data_str);
586 let Ok(response) = parsed else {
587 tracing::debug!("Couldn't parse SSE payload as CompletionResponse");
588 continue;
589 };
590
591 match response.message {
592 Message::Assistant { content, tool_calls, .. } => {
593 if !content.is_empty() {
594 yield Ok(RawStreamingChoice::Message(content));
595 }
596
597 for tool_call in tool_calls {
598 let function = tool_call.function.clone();
599 yield Ok(RawStreamingChoice::ToolCall {
600 id: "".to_string(),
601 name: function.name,
602 arguments: function.arguments,
603 call_id: None,
604 });
605 }
606 }
607 _ => continue,
608 }
609
610 if response.done {
611 yield Ok(RawStreamingChoice::FinalResponse(
612 StreamingCompletionResponse {
613 total_duration: response.total_duration,
614 load_duration: response.load_duration,
615 prompt_eval_count: response.prompt_eval_count,
616 prompt_eval_duration: response.prompt_eval_duration,
617 eval_count: response.eval_count,
618 eval_duration: response.eval_duration,
619 done_reason: response.done_reason,
620 }
621 ));
622 }
623 }
624
625 Err(reqwest_eventsource::Error::StreamEnded) => break,
626
627 Err(err) => {
628 tracing::error!(?err, "SSE error");
629 yield Err(CompletionError::ResponseError(err.to_string()));
630 break;
631 }
632 };
633 }});
634
635 Ok(streaming::StreamingCompletionResponse::stream(stream))
636 }
637}
638
/// Tool definition as placed in the `"tools"` array of a chat request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Serialized as `"type"`; the `From` conversion sets it to `"function"`.
    #[serde(rename = "type")]
    pub type_field: String,
    /// The wrapped tool definition (name, description, parameters).
    pub function: completion::ToolDefinition,
}
648
649impl From<crate::completion::ToolDefinition> for ToolDefinition {
651 fn from(tool: crate::completion::ToolDefinition) -> Self {
652 ToolDefinition {
653 type_field: "function".to_owned(),
654 function: completion::ToolDefinition {
655 name: tool.name,
656 description: tool.description,
657 parameters: tool.parameters,
658 },
659 }
660 }
661}
662
/// A tool call attached to an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Serialized as `"type"`; falls back to the default variant when absent.
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Kind of tool call; serialized in lowercase (`"function"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// Function name and JSON arguments of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
681
/// A chat message in the Ollama wire format.
///
/// Internally tagged by `"role"`, with variant names in lowercase
/// (`"user"`, `"assistant"`, `"system"`); `ToolResult` uses the role `"tool"`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        // `None` values are omitted from the serialized message.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        // A missing `content` deserializes as an empty string.
        #[serde(default)]
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Missing `tool_calls` default to empty; deserialization goes through
        // `json_utils::null_or_vec` (presumably mapping `null` to an empty
        // vector — confirm against that helper).
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        // Serialized as `"tool_name"`.
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
720
721impl TryFrom<crate::message::Message> for Vec<Message> {
727 type Error = crate::message::MessageError;
728 fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
729 use crate::message::Message as InternalMessage;
730 match internal_msg {
731 InternalMessage::User { content, .. } => {
732 let (tool_results, other_content): (Vec<_>, Vec<_>) =
733 content.into_iter().partition(|content| {
734 matches!(content, crate::message::UserContent::ToolResult(_))
735 });
736
737 if !tool_results.is_empty() {
738 tool_results
739 .into_iter()
740 .map(|content| match content {
741 crate::message::UserContent::ToolResult(
742 crate::message::ToolResult { id, content, .. },
743 ) => {
744 let content_string = content
746 .into_iter()
747 .map(|content| match content {
748 crate::message::ToolResultContent::Text(text) => text.text,
749 _ => "[Non-text content]".to_string(),
750 })
751 .collect::<Vec<_>>()
752 .join("\n");
753
754 Ok::<_, crate::message::MessageError>(Message::ToolResult {
755 name: id,
756 content: content_string,
757 })
758 }
759 _ => unreachable!(),
760 })
761 .collect::<Result<Vec<_>, _>>()
762 } else {
763 let (texts, images) = other_content.into_iter().fold(
765 (Vec::new(), Vec::new()),
766 |(mut texts, mut images), content| {
767 match content {
768 crate::message::UserContent::Text(crate::message::Text {
769 text,
770 }) => texts.push(text),
771 crate::message::UserContent::Image(crate::message::Image {
772 data,
773 ..
774 }) => images.push(data),
775 crate::message::UserContent::Document(
776 crate::message::Document { data, .. },
777 ) => texts.push(data),
778 _ => {} }
780 (texts, images)
781 },
782 );
783
784 Ok(vec![Message::User {
785 content: texts.join(" "),
786 images: if images.is_empty() {
787 None
788 } else {
789 Some(
790 images
791 .into_iter()
792 .map(|x| x.to_string())
793 .collect::<Vec<String>>(),
794 )
795 },
796 name: None,
797 }])
798 }
799 }
800 InternalMessage::Assistant { content, .. } => {
801 let mut thinking: Option<String> = None;
802 let (text_content, tool_calls) = content.into_iter().fold(
803 (Vec::new(), Vec::new()),
804 |(mut texts, mut tools), content| {
805 match content {
806 crate::message::AssistantContent::Text(text) => texts.push(text.text),
807 crate::message::AssistantContent::ToolCall(tool_call) => {
808 tools.push(tool_call)
809 }
810 crate::message::AssistantContent::Reasoning(
811 crate::message::Reasoning { reasoning, .. },
812 ) => {
813 thinking =
814 Some(reasoning.first().cloned().unwrap_or(String::new()));
815 }
816 }
817 (texts, tools)
818 },
819 );
820
821 Ok(vec![Message::Assistant {
824 content: text_content.join(" "),
825 thinking,
826 images: None,
827 name: None,
828 tool_calls: tool_calls
829 .into_iter()
830 .map(|tool_call| tool_call.into())
831 .collect::<Vec<_>>(),
832 }])
833 }
834 }
835 }
836}
837
838impl From<Message> for crate::completion::Message {
841 fn from(msg: Message) -> Self {
842 match msg {
843 Message::User { content, .. } => crate::completion::Message::User {
844 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
845 text: content,
846 })),
847 },
848 Message::Assistant {
849 content,
850 tool_calls,
851 ..
852 } => {
853 let mut assistant_contents =
854 vec![crate::completion::message::AssistantContent::Text(Text {
855 text: content,
856 })];
857 for tc in tool_calls {
858 assistant_contents.push(
859 crate::completion::message::AssistantContent::tool_call(
860 tc.function.name.clone(),
861 tc.function.name,
862 tc.function.arguments,
863 ),
864 );
865 }
866 crate::completion::Message::Assistant {
867 id: None,
868 content: OneOrMany::many(assistant_contents).unwrap(),
869 }
870 }
871 Message::System { content, .. } => crate::completion::Message::User {
873 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
874 text: content,
875 })),
876 },
877 Message::ToolResult { name, content } => crate::completion::Message::User {
878 content: OneOrMany::one(message::UserContent::tool_result(
879 name,
880 OneOrMany::one(message::ToolResultContent::text(content)),
881 )),
882 },
883 }
884 }
885}
886
887impl Message {
888 pub fn system(content: &str) -> Self {
890 Message::System {
891 content: content.to_owned(),
892 images: None,
893 name: None,
894 }
895 }
896}
897
898impl From<crate::message::ToolCall> for ToolCall {
901 fn from(tool_call: crate::message::ToolCall) -> Self {
902 Self {
903 r#type: ToolType::Function,
904 function: Function {
905 name: tool_call.function.name,
906 arguments: tool_call.function.arguments,
907 },
908 }
909 }
910}
911
/// A typed piece of system content: a content type plus its text.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    // Falls back to the default content type when absent from the input.
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}
918
/// Content type of a [`SystemContent`]; serialized in lowercase (`"text"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
925
926impl From<String> for SystemContent {
927 fn from(s: String) -> Self {
928 SystemContent {
929 r#type: SystemContentType::default(),
930 text: s,
931 }
932 }
933}
934
935impl FromStr for SystemContent {
936 type Err = std::convert::Infallible;
937 fn from_str(s: &str) -> Result<Self, Self::Err> {
938 Ok(SystemContent {
939 r#type: SystemContentType::default(),
940 text: s.to_string(),
941 })
942 }
943}
944
/// Plain-text assistant content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}
949
950impl FromStr for AssistantContent {
951 type Err = std::convert::Infallible;
952 fn from_str(s: &str) -> Result<Self, Self::Err> {
953 Ok(AssistantContent { text: s.to_owned() })
954 }
955}
956
/// User content, internally tagged by `"type"` in lowercase (`"text"` or `"image"`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
}
964
965impl FromStr for UserContent {
966 type Err = std::convert::Infallible;
967 fn from_str(s: &str) -> Result<Self, Self::Err> {
968 Ok(UserContent::Text { text: s.to_owned() })
969 }
970}
971
/// Image reference used by `UserContent::Image`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    /// Falls back to `ImageDetail::default()` when absent from the input.
    #[serde(default)]
    pub detail: ImageDetail,
}
978
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A sample chat response deserializes and converts to a non-empty choice.
    #[tokio::test]
    async fn test_chat_completion() {
        let sample_text = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        })
        .to_string();

        let parsed: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let converted =
            completion::CompletionResponse::<CompletionResponse>::try_from(parsed).unwrap();

        assert!(
            !converted.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    /// A provider user message converts to a generic user message with the same text.
    #[test]
    fn test_message_conversion() {
        let converted = crate::completion::Message::from(Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        });

        let crate::completion::Message::User { content } = converted else {
            panic!("Conversion from provider Message to completion Message failed")
        };
        let crate::completion::message::UserContent::Text(text_struct) = content.first() else {
            panic!("Expected text content in conversion")
        };
        assert_eq!(text_struct.text, "Test message");
    }

    /// A generic tool definition converts to an Ollama `"function"` tool with
    /// name, description and parameters intact.
    #[test]
    fn test_tool_definition_conversion() {
        let parameters = json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The location to get the weather for, e.g. San Francisco, CA"
                },
                "format": {
                    "type": "string",
                    "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["location", "format"]
        });

        let ollama_tool = ToolDefinition::from(crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters,
        });

        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        assert_eq!(
            ollama_tool.function.parameters["properties"]["location"]["type"],
            "string"
        );
    }
}