1use crate::client::{
42 ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient,
43 VerifyError,
44};
45use crate::completion::{GetTokenUsage, Usage};
46use crate::json_utils::merge_inplace;
47use crate::streaming::RawStreamingChoice;
48use crate::{
49 Embed, OneOrMany,
50 completion::{self, CompletionError, CompletionRequest},
51 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
52 impl_conversion_traits, json_utils, message,
53 message::{ImageDetail, Text},
54 streaming,
55};
56use async_stream::stream;
57use futures::StreamExt;
58use reqwest;
59use serde::{Deserialize, Serialize};
60use serde_json::{Value, json};
61use std::{convert::TryFrom, str::FromStr};
62use url::Url;
/// Base URL used by [`ClientBuilder::new`] when no other address is configured
/// (a server on `localhost:11434`).
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
66
67pub struct ClientBuilder<'a> {
68 base_url: &'a str,
69 http_client: Option<reqwest::Client>,
70}
71
72impl<'a> ClientBuilder<'a> {
73 #[allow(clippy::new_without_default)]
74 pub fn new() -> Self {
75 Self {
76 base_url: OLLAMA_API_BASE_URL,
77 http_client: None,
78 }
79 }
80
81 pub fn base_url(mut self, base_url: &'a str) -> Self {
82 self.base_url = base_url;
83 self
84 }
85
86 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
87 self.http_client = Some(client);
88 self
89 }
90
91 pub fn build(self) -> Result<Client, ClientBuilderError> {
92 let http_client = if let Some(http_client) = self.http_client {
93 http_client
94 } else {
95 reqwest::Client::builder().build()?
96 };
97
98 Ok(Client {
99 base_url: Url::parse(self.base_url)
100 .map_err(|_| ClientBuilderError::InvalidProperty("base_url"))?,
101 http_client,
102 })
103 }
104}
105
106#[derive(Clone, Debug)]
107pub struct Client {
108 base_url: Url,
109 http_client: reqwest::Client,
110}
111
112impl Default for Client {
113 fn default() -> Self {
114 Self::new()
115 }
116}
117
118impl Client {
119 pub fn builder() -> ClientBuilder<'static> {
130 ClientBuilder::new()
131 }
132
133 pub fn new() -> Self {
138 Self::builder().build().expect("Ollama client should build")
139 }
140
141 pub(crate) fn post(&self, path: &str) -> Result<reqwest::RequestBuilder, url::ParseError> {
142 let url = self.base_url.join(path)?;
143 Ok(self.http_client.post(url))
144 }
145
146 pub(crate) fn get(&self, path: &str) -> Result<reqwest::RequestBuilder, url::ParseError> {
147 let url = self.base_url.join(path)?;
148 Ok(self.http_client.get(url))
149 }
150}
151
152impl ProviderClient for Client {
153 fn from_env() -> Self
154 where
155 Self: Sized,
156 {
157 let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
158 Self::builder().base_url(&api_base).build().unwrap()
159 }
160
161 fn from_val(input: crate::client::ProviderValue) -> Self {
162 let crate::client::ProviderValue::Simple(_) = input else {
163 panic!("Incorrect provider value type")
164 };
165
166 Self::new()
167 }
168}
169
170impl CompletionClient for Client {
171 type CompletionModel = CompletionModel;
172
173 fn completion_model(&self, model: &str) -> CompletionModel {
174 CompletionModel::new(self.clone(), model)
175 }
176}
177
178impl EmbeddingsClient for Client {
179 type EmbeddingModel = EmbeddingModel;
180 fn embedding_model(&self, model: &str) -> EmbeddingModel {
181 EmbeddingModel::new(self.clone(), model, 0)
182 }
183 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
184 EmbeddingModel::new(self.clone(), model, ndims)
185 }
186 fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
187 EmbeddingsBuilder::new(self.embedding_model(model))
188 }
189}
190
191impl VerifyClient for Client {
192 #[cfg_attr(feature = "worker", worker::send)]
193 async fn verify(&self) -> Result<(), VerifyError> {
194 let response = self
195 .get("api/tags")
196 .expect("Failed to build request")
197 .send()
198 .await?;
199 match response.status() {
200 reqwest::StatusCode::OK => Ok(()),
201 _ => {
202 response.error_for_status()?;
203 Ok(())
204 }
205 }
206 }
207}
208
// Conversion-trait impls for `Client` covering transcription, image generation and audio
// generation, which this module does not implement directly.
// NOTE(review): presumably these report the capability as unavailable — confirm against the
// `impl_conversion_traits!` definition.
impl_conversion_traits!(
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
214
/// Error payload carrying a human-readable `message`.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful body `T` or an [`ApiErrorResponse`].
///
/// Untagged: serde tries `Ok(T)` first and falls back to the error shape, so the variant
/// order matters.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
228
/// Ollama model name `all-minilm`.
pub const ALL_MINILM: &str = "all-minilm";
/// Ollama model name `nomic-embed-text`.
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

/// Successful body of an `api/embed` call.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    /// One vector per input document; `embed_texts` rejects a count mismatch.
    pub embeddings: Vec<Vec<f64>>,
    // Optional server timing/usage metadata; absent fields deserialize as `None`.
    // NOTE(review): units of the durations are not established here — confirm with the API docs.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
245
246impl From<ApiErrorResponse> for EmbeddingError {
247 fn from(err: ApiErrorResponse) -> Self {
248 EmbeddingError::ProviderError(err.message)
249 }
250}
251
252impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
253 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
254 match value {
255 ApiResponse::Ok(response) => Ok(response),
256 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
257 }
258 }
259}
260
261#[derive(Clone)]
264pub struct EmbeddingModel {
265 client: Client,
266 pub model: String,
267 ndims: usize,
268}
269
270impl EmbeddingModel {
271 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
272 Self {
273 client,
274 model: model.to_owned(),
275 ndims,
276 }
277 }
278}
279
280impl embeddings::EmbeddingModel for EmbeddingModel {
281 const MAX_DOCUMENTS: usize = 1024;
282 fn ndims(&self) -> usize {
283 self.ndims
284 }
285 #[cfg_attr(feature = "worker", worker::send)]
286 async fn embed_texts(
287 &self,
288 documents: impl IntoIterator<Item = String>,
289 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
290 let docs: Vec<String> = documents.into_iter().collect();
291 let payload = json!({
292 "model": self.model,
293 "input": docs,
294 });
295 let response = self
296 .client
297 .post("api/embed")?
298 .json(&payload)
299 .send()
300 .await
301 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
302 if response.status().is_success() {
303 let api_resp: EmbeddingResponse = response
304 .json()
305 .await
306 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
307 if api_resp.embeddings.len() != docs.len() {
308 return Err(EmbeddingError::ResponseError(
309 "Number of returned embeddings does not match input".into(),
310 ));
311 }
312 Ok(api_resp
313 .embeddings
314 .into_iter()
315 .zip(docs.into_iter())
316 .map(|(vec, document)| embeddings::Embedding { document, vec })
317 .collect())
318 } else {
319 Err(EmbeddingError::ProviderError(response.text().await?))
320 }
321 }
322}
323
/// Ollama model name `llama3.2`.
pub const LLAMA3_2: &str = "llama3.2";
/// Ollama model name `llava`.
pub const LLAVA: &str = "llava";
/// Ollama model name `mistral`.
pub const MISTRAL: &str = "mistral";

/// One response object from the `api/chat` endpoint.
///
/// Used both for the single non-streaming body and for each line of a streamed reply; in a
/// stream, the object with `done == true` supplies the final usage summary.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    pub done: bool,
    // The remaining fields are optional metadata; absent fields deserialize as `None`.
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    /// Reported as the input token count in rig's `Usage`.
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    /// Reported as the output token count in rig's `Usage`.
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
351impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
352 type Error = CompletionError;
353 fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
354 match resp.message {
355 Message::Assistant {
357 content,
358 thinking,
359 tool_calls,
360 ..
361 } => {
362 let mut assistant_contents = Vec::new();
363 if !content.is_empty() {
365 assistant_contents.push(completion::AssistantContent::text(&content));
366 }
367 for tc in tool_calls.iter() {
370 assistant_contents.push(completion::AssistantContent::tool_call(
371 tc.function.name.clone(),
372 tc.function.name.clone(),
373 tc.function.arguments.clone(),
374 ));
375 }
376 let choice = OneOrMany::many(assistant_contents).map_err(|_| {
377 CompletionError::ResponseError("No content provided".to_owned())
378 })?;
379 let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
380 let completion_tokens = resp.eval_count.unwrap_or(0);
381
382 let raw_response = CompletionResponse {
383 model: resp.model,
384 created_at: resp.created_at,
385 done: resp.done,
386 done_reason: resp.done_reason,
387 total_duration: resp.total_duration,
388 load_duration: resp.load_duration,
389 prompt_eval_count: resp.prompt_eval_count,
390 prompt_eval_duration: resp.prompt_eval_duration,
391 eval_count: resp.eval_count,
392 eval_duration: resp.eval_duration,
393 message: Message::Assistant {
394 content,
395 thinking,
396 images: None,
397 name: None,
398 tool_calls,
399 },
400 };
401
402 Ok(completion::CompletionResponse {
403 choice,
404 usage: Usage {
405 input_tokens: prompt_tokens,
406 output_tokens: completion_tokens,
407 total_tokens: prompt_tokens + completion_tokens,
408 },
409 raw_response,
410 })
411 }
412 _ => Err(CompletionError::ResponseError(
413 "Chat response does not include an assistant message".into(),
414 )),
415 }
416 }
417}
418
419#[derive(Clone)]
422pub struct CompletionModel {
423 client: Client,
424 pub model: String,
425}
426
427impl CompletionModel {
428 pub fn new(client: Client, model: &str) -> Self {
429 Self {
430 client,
431 model: model.to_owned(),
432 }
433 }
434
435 fn create_completion_request(
436 &self,
437 completion_request: CompletionRequest,
438 ) -> Result<Value, CompletionError> {
439 let mut partial_history = vec![];
441 if let Some(docs) = completion_request.normalized_documents() {
442 partial_history.push(docs);
443 }
444 partial_history.extend(completion_request.chat_history);
445
446 let mut full_history: Vec<Message> = completion_request
448 .preamble
449 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
450
451 full_history.extend(
453 partial_history
454 .into_iter()
455 .map(|msg| msg.try_into())
456 .collect::<Result<Vec<Vec<Message>>, _>>()?
457 .into_iter()
458 .flatten()
459 .collect::<Vec<Message>>(),
460 );
461
462 let options = if let Some(extra) = completion_request.additional_params {
464 json_utils::merge(
465 json!({ "temperature": completion_request.temperature }),
466 extra,
467 )
468 } else {
469 json!({ "temperature": completion_request.temperature })
470 };
471
472 let mut request_payload = json!({
473 "model": self.model,
474 "messages": full_history,
475 "options": options,
476 "stream": false,
477 });
478 if !completion_request.tools.is_empty() {
479 request_payload["tools"] = json!(
480 completion_request
481 .tools
482 .into_iter()
483 .map(|tool| tool.into())
484 .collect::<Vec<ToolDefinition>>()
485 );
486 }
487
488 tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
489
490 Ok(request_payload)
491 }
492}
493
/// Summary yielded as `RawStreamingChoice::FinalResponse` at the end of a streamed chat.
///
/// The fields are copied from the stream object whose `done` flag is set; any may be `None`
/// if the server omitted them.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
506
507impl GetTokenUsage for StreamingCompletionResponse {
508 fn token_usage(&self) -> Option<crate::completion::Usage> {
509 let mut usage = crate::completion::Usage::new();
510 let input_tokens = self.prompt_eval_count.unwrap_or_default();
511 let output_tokens = self.eval_count.unwrap_or_default();
512 usage.input_tokens = input_tokens;
513 usage.output_tokens = output_tokens;
514 usage.total_tokens = input_tokens + output_tokens;
515
516 Some(usage)
517 }
518}
519
520impl completion::CompletionModel for CompletionModel {
521 type Response = CompletionResponse;
522 type StreamingResponse = StreamingCompletionResponse;
523
524 #[cfg_attr(feature = "worker", worker::send)]
525 async fn completion(
526 &self,
527 completion_request: CompletionRequest,
528 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
529 let request_payload = self.create_completion_request(completion_request)?;
530
531 let response = self
532 .client
533 .post("api/chat")?
534 .json(&request_payload)
535 .send()
536 .await
537 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
538 if response.status().is_success() {
539 let text = response
540 .text()
541 .await
542 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
543 tracing::debug!(target: "rig", "Ollama chat response: {}", text);
544 let chat_resp: CompletionResponse = serde_json::from_str(&text)
545 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
546 let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
547 Ok(conv)
548 } else {
549 let err_text = response
550 .text()
551 .await
552 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
553 Err(CompletionError::ProviderError(err_text))
554 }
555 }
556
557 #[cfg_attr(feature = "worker", worker::send)]
558 async fn stream(
559 &self,
560 request: CompletionRequest,
561 ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
562 {
563 let mut request_payload = self.create_completion_request(request)?;
564 merge_inplace(&mut request_payload, json!({"stream": true}));
565
566 let response = self
567 .client
568 .post("api/chat")?
569 .json(&request_payload)
570 .send()
571 .await
572 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
573
574 if !response.status().is_success() {
575 let err_text = response
576 .text()
577 .await
578 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
579 return Err(CompletionError::ProviderError(err_text));
580 }
581
582 let stream = Box::pin(stream! {
583 let mut stream = response.bytes_stream();
584 while let Some(chunk_result) = stream.next().await {
585 let chunk = match chunk_result {
586 Ok(c) => c,
587 Err(e) => {
588 yield Err(CompletionError::from(e));
589 break;
590 }
591 };
592
593 let text = match String::from_utf8(chunk.to_vec()) {
594 Ok(t) => t,
595 Err(e) => {
596 yield Err(CompletionError::ResponseError(e.to_string()));
597 break;
598 }
599 };
600
601
602 for line in text.lines() {
603 let line = line.to_string();
604
605 let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
606 continue;
607 };
608
609 match response.message {
610 Message::Assistant{ content, tool_calls, .. } => {
611 if !content.is_empty() {
612 yield Ok(RawStreamingChoice::Message(content))
613 }
614
615 for tool_call in tool_calls.iter() {
616 let function = tool_call.function.clone();
617
618 yield Ok(RawStreamingChoice::ToolCall {
619 id: "".to_string(),
620 name: function.name,
621 arguments: function.arguments,
622 call_id: None
623 });
624 }
625 }
626 _ => {
627 continue;
628 }
629 }
630
631 if response.done {
632 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
633 total_duration: response.total_duration,
634 load_duration: response.load_duration,
635 prompt_eval_count: response.prompt_eval_count,
636 prompt_eval_duration: response.prompt_eval_duration,
637 eval_count: response.eval_count,
638 eval_duration: response.eval_duration,
639 done_reason: response.done_reason,
640 }));
641 }
642 }
643 }
644 });
645
646 Ok(streaming::StreamingCompletionResponse::stream(stream))
647 }
648}
649
/// Tool declaration in the shape sent under `tools` in an `api/chat` request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Serialized as `type`; set to `"function"` by the `From` conversion.
    #[serde(rename = "type")]
    pub type_field: String,
    pub function: completion::ToolDefinition,
}
659
660impl From<crate::completion::ToolDefinition> for ToolDefinition {
662 fn from(tool: crate::completion::ToolDefinition) -> Self {
663 ToolDefinition {
664 type_field: "function".to_owned(),
665 function: completion::ToolDefinition {
666 name: tool.name,
667 description: tool.description,
668 parameters: tool.parameters,
669 },
670 }
671 }
672}
673
/// A tool call as it appears in an assistant message's `tool_calls`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Serialized as `type`; defaults to [`ToolType::Function`] when absent.
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Kind of tool call; `function` is the only (and default) variant.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// The function invoked by a [`ToolCall`]: its name and arguments as arbitrary JSON.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
692
/// A chat message in Ollama's wire format, discriminated by a `role` field
/// (`user`, `assistant`, `system` or `tool`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        // Defaults to "" when the server omits it (e.g. tool-call-only replies).
        #[serde(default)]
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Empty when absent; `json_utils::null_or_vec` presumably also maps `null` to empty.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Result of a tool call; uses role `tool`, and `name` is serialized as `tool_name`.
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
731
732impl TryFrom<crate::message::Message> for Vec<Message> {
738 type Error = crate::message::MessageError;
739 fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
740 use crate::message::Message as InternalMessage;
741 match internal_msg {
742 InternalMessage::User { content, .. } => {
743 let (tool_results, other_content): (Vec<_>, Vec<_>) =
744 content.into_iter().partition(|content| {
745 matches!(content, crate::message::UserContent::ToolResult(_))
746 });
747
748 if !tool_results.is_empty() {
749 tool_results
750 .into_iter()
751 .map(|content| match content {
752 crate::message::UserContent::ToolResult(
753 crate::message::ToolResult { id, content, .. },
754 ) => {
755 let content_string = content
757 .into_iter()
758 .map(|content| match content {
759 crate::message::ToolResultContent::Text(text) => text.text,
760 _ => "[Non-text content]".to_string(),
761 })
762 .collect::<Vec<_>>()
763 .join("\n");
764
765 Ok::<_, crate::message::MessageError>(Message::ToolResult {
766 name: id,
767 content: content_string,
768 })
769 }
770 _ => unreachable!(),
771 })
772 .collect::<Result<Vec<_>, _>>()
773 } else {
774 let (texts, images) = other_content.into_iter().fold(
776 (Vec::new(), Vec::new()),
777 |(mut texts, mut images), content| {
778 match content {
779 crate::message::UserContent::Text(crate::message::Text {
780 text,
781 }) => texts.push(text),
782 crate::message::UserContent::Image(crate::message::Image {
783 data,
784 ..
785 }) => images.push(data),
786 crate::message::UserContent::Document(
787 crate::message::Document { data, .. },
788 ) => texts.push(data),
789 _ => {} }
791 (texts, images)
792 },
793 );
794
795 Ok(vec![Message::User {
796 content: texts.join(" "),
797 images: if images.is_empty() {
798 None
799 } else {
800 Some(images)
801 },
802 name: None,
803 }])
804 }
805 }
806 InternalMessage::Assistant { content, .. } => {
807 let mut thinking: Option<String> = None;
808 let (text_content, tool_calls) = content.into_iter().fold(
809 (Vec::new(), Vec::new()),
810 |(mut texts, mut tools), content| {
811 match content {
812 crate::message::AssistantContent::Text(text) => texts.push(text.text),
813 crate::message::AssistantContent::ToolCall(tool_call) => {
814 tools.push(tool_call)
815 }
816 crate::message::AssistantContent::Reasoning(
817 crate::message::Reasoning { reasoning, .. },
818 ) => {
819 thinking =
820 Some(reasoning.first().cloned().unwrap_or(String::new()));
821 }
822 }
823 (texts, tools)
824 },
825 );
826
827 Ok(vec![Message::Assistant {
830 content: text_content.join(" "),
831 thinking,
832 images: None,
833 name: None,
834 tool_calls: tool_calls
835 .into_iter()
836 .map(|tool_call| tool_call.into())
837 .collect::<Vec<_>>(),
838 }])
839 }
840 }
841 }
842}
843
844impl From<Message> for crate::completion::Message {
847 fn from(msg: Message) -> Self {
848 match msg {
849 Message::User { content, .. } => crate::completion::Message::User {
850 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
851 text: content,
852 })),
853 },
854 Message::Assistant {
855 content,
856 tool_calls,
857 ..
858 } => {
859 let mut assistant_contents =
860 vec![crate::completion::message::AssistantContent::Text(Text {
861 text: content,
862 })];
863 for tc in tool_calls {
864 assistant_contents.push(
865 crate::completion::message::AssistantContent::tool_call(
866 tc.function.name.clone(),
867 tc.function.name,
868 tc.function.arguments,
869 ),
870 );
871 }
872 crate::completion::Message::Assistant {
873 id: None,
874 content: OneOrMany::many(assistant_contents).unwrap(),
875 }
876 }
877 Message::System { content, .. } => crate::completion::Message::User {
879 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
880 text: content,
881 })),
882 },
883 Message::ToolResult { name, content } => crate::completion::Message::User {
884 content: OneOrMany::one(message::UserContent::tool_result(
885 name,
886 OneOrMany::one(message::ToolResultContent::text(content)),
887 )),
888 },
889 }
890 }
891}
892
893impl Message {
894 pub fn system(content: &str) -> Self {
896 Message::System {
897 content: content.to_owned(),
898 images: None,
899 name: None,
900 }
901 }
902}
903
904impl From<crate::message::ToolCall> for ToolCall {
907 fn from(tool_call: crate::message::ToolCall) -> Self {
908 Self {
909 r#type: ToolType::Function,
910 function: Function {
911 name: tool_call.function.name,
912 arguments: tool_call.function.arguments,
913 },
914 }
915 }
916}
917
918#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
919pub struct SystemContent {
920 #[serde(default)]
921 r#type: SystemContentType,
922 text: String,
923}
924
925#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
926#[serde(rename_all = "lowercase")]
927pub enum SystemContentType {
928 #[default]
929 Text,
930}
931
932impl From<String> for SystemContent {
933 fn from(s: String) -> Self {
934 SystemContent {
935 r#type: SystemContentType::default(),
936 text: s,
937 }
938 }
939}
940
941impl FromStr for SystemContent {
942 type Err = std::convert::Infallible;
943 fn from_str(s: &str) -> Result<Self, Self::Err> {
944 Ok(SystemContent {
945 r#type: SystemContentType::default(),
946 text: s.to_string(),
947 })
948 }
949}
950
951#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
952pub struct AssistantContent {
953 pub text: String,
954}
955
956impl FromStr for AssistantContent {
957 type Err = std::convert::Infallible;
958 fn from_str(s: &str) -> Result<Self, Self::Err> {
959 Ok(AssistantContent { text: s.to_owned() })
960 }
961}
962
963#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
964#[serde(tag = "type", rename_all = "lowercase")]
965pub enum UserContent {
966 Text { text: String },
967 Image { image_url: ImageUrl },
968 }
970
971impl FromStr for UserContent {
972 type Err = std::convert::Infallible;
973 fn from_str(s: &str) -> Result<Self, Self::Err> {
974 Ok(UserContent::Text { text: s.to_owned() })
975 }
976}
977
978#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
979pub struct ImageUrl {
980 pub url: String,
981 #[serde(default)]
982 pub detail: ImageDetail,
983}
984
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A full non-streaming chat response (text plus one tool call) parses and converts into a
    /// non-empty rig completion response.
    #[tokio::test]
    async fn test_chat_completion() {
        let raw = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        })
        .to_string();

        let parsed: CompletionResponse =
            serde_json::from_str(&raw).expect("Invalid JSON structure");
        let converted: completion::CompletionResponse<CompletionResponse> =
            parsed.try_into().unwrap();

        assert!(
            !converted.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    /// A provider user message converts into a rig user message with the same text.
    #[test]
    fn test_message_conversion() {
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };

        let crate::completion::Message::User { content } =
            crate::completion::Message::from(provider_msg)
        else {
            panic!("Conversion from provider Message to completion Message failed");
        };
        let crate::completion::message::UserContent::Text(text_struct) = content.first() else {
            panic!("Expected text content in conversion");
        };
        assert_eq!(text_struct.text, "Test message");
    }

    /// A rig tool definition converts into an Ollama `function` tool with its fields intact.
    #[test]
    fn test_tool_definition_conversion() {
        let parameters = json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The location to get the weather for, e.g. San Francisco, CA"
                },
                "format": {
                    "type": "string",
                    "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["location", "format"]
        });
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters,
        };

        let ollama_tool = ToolDefinition::from(internal_tool);

        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        assert_eq!(
            ollama_tool.function.parameters["properties"]["location"]["type"],
            "string"
        );
    }
}