1use crate::client::{
42 CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError,
43};
44use crate::completion::{GetTokenUsage, Usage};
45use crate::http_client::{self, HttpClientExt};
46use crate::json_utils::merge_inplace;
47use crate::message::DocumentSourceKind;
48use crate::streaming::RawStreamingChoice;
49use crate::{
50 Embed, OneOrMany,
51 completion::{self, CompletionError, CompletionRequest},
52 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
53 impl_conversion_traits, json_utils, message,
54 message::{ImageDetail, Text},
55 streaming,
56};
57use async_stream::try_stream;
58use futures::StreamExt;
59use reqwest;
60use serde::{Deserialize, Serialize};
62use serde_json::{Value, json};
63use std::{convert::TryFrom, str::FromStr};
64use tracing::info_span;
65use tracing_futures::Instrument;
/// Base URL that [`ClientBuilder::new`] uses until `base_url` overrides it.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
69
/// Builder for [`Client`].
///
/// Holds a borrowed base URL and the HTTP client instance that `build` moves
/// into the resulting [`Client`].
pub struct ClientBuilder<'a, T = reqwest::Client> {
    /// Base URL to send requests to; copied into an owned `String` by `build`.
    base_url: &'a str,
    /// HTTP client handed over to the built [`Client`].
    http_client: T,
}
74
75impl<'a, T> ClientBuilder<'a, T>
76where
77 T: Default,
78{
79 #[allow(clippy::new_without_default)]
80 pub fn new() -> Self {
81 Self {
82 base_url: OLLAMA_API_BASE_URL,
83 http_client: Default::default(),
84 }
85 }
86}
87
88impl<'a, T> ClientBuilder<'a, T> {
89 pub fn base_url(mut self, base_url: &'a str) -> Self {
90 self.base_url = base_url;
91 self
92 }
93
94 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
95 ClientBuilder {
96 base_url: self.base_url,
97 http_client,
98 }
99 }
100
101 pub fn build(self) -> Client<T> {
102 Client {
103 base_url: self.base_url.into(),
104 http_client: self.http_client,
105 }
106 }
107}
108
/// Ollama API client: an owned base URL plus the HTTP client used to send
/// requests.
#[derive(Clone, Debug)]
pub struct Client<T = reqwest::Client> {
    /// Base URL that request paths are appended to.
    base_url: String,
    /// Underlying HTTP client.
    http_client: T,
}
114
115impl<T> Default for Client<T>
116where
117 T: Default,
118{
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
124impl<T> Client<T>
125where
126 T: Default,
127{
128 pub fn builder<'a>() -> ClientBuilder<'a, T> {
139 ClientBuilder::new()
140 }
141
142 pub fn new() -> Self {
147 Self::builder().build()
148 }
149}
150
151impl<T> Client<T> {
152 fn req(&self, method: http_client::Method, path: &str) -> http_client::Builder {
153 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
154 http_client::Builder::new().method(method).uri(url)
155 }
156
157 pub(crate) fn post(&self, path: &str) -> http_client::Builder {
158 self.req(http_client::Method::POST, path)
159 }
160
161 pub(crate) fn get(&self, path: &str) -> http_client::Builder {
162 self.req(http_client::Method::GET, path)
163 }
164}
165
166impl Client<reqwest::Client> {
167 fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
168 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
169 self.http_client.post(url)
170 }
171}
172
173impl ProviderClient for Client<reqwest::Client> {
174 fn from_env() -> Self {
175 let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
176 Self::builder().base_url(&api_base).build()
177 }
178
179 fn from_val(input: crate::client::ProviderValue) -> Self {
180 let crate::client::ProviderValue::Simple(_) = input else {
181 panic!("Incorrect provider value type")
182 };
183
184 Self::new()
185 }
186}
187
188impl CompletionClient for Client<reqwest::Client> {
189 type CompletionModel = CompletionModel<reqwest::Client>;
190
191 fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
192 CompletionModel::new(self.clone(), model)
193 }
194}
195
196impl EmbeddingsClient for Client<reqwest::Client> {
197 type EmbeddingModel = EmbeddingModel<reqwest::Client>;
198 fn embedding_model(&self, model: &str) -> EmbeddingModel<reqwest::Client> {
199 EmbeddingModel::new(self.clone(), model, 0)
200 }
201 fn embedding_model_with_ndims(
202 &self,
203 model: &str,
204 ndims: usize,
205 ) -> EmbeddingModel<reqwest::Client> {
206 EmbeddingModel::new(self.clone(), model, ndims)
207 }
208 fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<Self::EmbeddingModel, D> {
209 EmbeddingsBuilder::new(self.embedding_model(model))
210 }
211}
212
213impl VerifyClient for Client<reqwest::Client> {
214 #[cfg_attr(feature = "worker", worker::send)]
215 async fn verify(&self) -> Result<(), VerifyError> {
216 let req = self
217 .get("api/tags")
218 .body(http_client::NoBody)
219 .map_err(http_client::Error::from)?;
220
221 let response = HttpClientExt::send(&self.http_client, req).await?;
222
223 match response.status() {
224 reqwest::StatusCode::OK => Ok(()),
225 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
226 reqwest::StatusCode::INTERNAL_SERVER_ERROR
227 | reqwest::StatusCode::SERVICE_UNAVAILABLE
228 | reqwest::StatusCode::BAD_GATEWAY => {
229 let text = http_client::text(response).await?;
230 Err(VerifyError::ProviderError(text))
231 }
232 _ => {
233 Ok(())
235 }
236 }
237 }
238}
239
// Invokes the crate's `impl_conversion_traits!` macro for `Client<T>` with the
// `AsTranscription`, `AsImageGeneration` and `AsAudioGeneration` traits.
// NOTE(review): presumably this marks those capabilities as unsupported for
// this provider — confirm against the macro definition.
impl_conversion_traits!(
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
245
/// Error payload shape: a JSON object with a `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text; forwarded as `EmbeddingError::ProviderError` by the
    /// conversions in this file.
    message: String,
}
252
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// Untagged: serde tries `Ok(T)` first and falls back to `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body deserialized as the expected payload.
    Ok(T),
    /// The body deserialized as an error object.
    Err(ApiErrorResponse),
}
259
/// Model name `all-minilm` (presumably an embedding model — confirm against
/// the Ollama model library).
pub const ALL_MINILM: &str = "all-minilm";
/// Model name `nomic-embed-text` (presumably an embedding model — confirm
/// against the Ollama model library).
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
264
/// Body of a successful `api/embed` response, as parsed by `embed_texts`.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    /// Model name reported by the server.
    pub model: String,
    /// One embedding vector per input; `embed_texts` checks that the count
    /// matches the number of submitted documents.
    pub embeddings: Vec<Vec<f64>>,
    /// Optional; `None` when absent from the response.
    #[serde(default)]
    pub total_duration: Option<u64>,
    /// Optional; `None` when absent from the response.
    #[serde(default)]
    pub load_duration: Option<u64>,
    /// Optional; `None` when absent from the response.
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
276
277impl From<ApiErrorResponse> for EmbeddingError {
278 fn from(err: ApiErrorResponse) -> Self {
279 EmbeddingError::ProviderError(err.message)
280 }
281}
282
283impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
284 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
285 match value {
286 ApiResponse::Ok(response) => Ok(response),
287 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
288 }
289 }
290}
291
/// Handle for requesting embeddings from one named model.
#[derive(Clone)]
pub struct EmbeddingModel<T> {
    /// Client used to send the embedding requests.
    client: Client<T>,
    /// Model name sent in the `model` field of each request.
    pub model: String,
    /// Value returned by `ndims()`; it is not sent to the server.
    ndims: usize,
}
300
301impl<T> EmbeddingModel<T> {
302 pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
303 Self {
304 client,
305 model: model.to_owned(),
306 ndims,
307 }
308 }
309}
310
311impl embeddings::EmbeddingModel for EmbeddingModel<reqwest::Client> {
312 const MAX_DOCUMENTS: usize = 1024;
313 fn ndims(&self) -> usize {
314 self.ndims
315 }
316 #[cfg_attr(feature = "worker", worker::send)]
317 async fn embed_texts(
318 &self,
319 documents: impl IntoIterator<Item = String>,
320 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
321 let docs: Vec<String> = documents.into_iter().collect();
322
323 let body = serde_json::to_vec(&json!({
324 "model": self.model,
325 "input": docs
326 }))?;
327
328 let req = self
329 .client
330 .post("api/embed")
331 .header("Content-Type", "application/json")
332 .body(body)
333 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
334
335 let response = HttpClientExt::send(&self.client.http_client, req).await?;
336
337 if !response.status().is_success() {
338 let text = http_client::text(response).await?;
339 return Err(EmbeddingError::ProviderError(text));
340 }
341
342 let bytes: Vec<u8> = response.into_body().await?;
343
344 let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;
345
346 if api_resp.embeddings.len() != docs.len() {
347 return Err(EmbeddingError::ResponseError(
348 "Number of returned embeddings does not match input".into(),
349 ));
350 }
351 Ok(api_resp
352 .embeddings
353 .into_iter()
354 .zip(docs.into_iter())
355 .map(|(vec, document)| embeddings::Embedding { document, vec })
356 .collect())
357 }
358}
359
/// Model name `llama3.2`.
pub const LLAMA3_2: &str = "llama3.2";
/// Model name `llava`.
pub const LLAVA: &str = "llava";
/// Model name `mistral`.
pub const MISTRAL: &str = "mistral";
365
/// One `api/chat` response object.
///
/// Parsed from the full body in non-streaming mode and from each NDJSON line
/// in streaming mode. All `Option` fields default to `None` when absent.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Model name reported by the server.
    pub model: String,
    /// Creation timestamp, kept as the raw string the server sent.
    pub created_at: String,
    /// The message carried by this response.
    pub message: Message,
    /// `true` on the final object; the streaming loop emits its final
    /// response when it sees this.
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    /// Used as the input-token count for usage reporting.
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    /// Used as the output-token count for usage reporting.
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
387impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
388 type Error = CompletionError;
389 fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
390 match resp.message {
391 Message::Assistant {
393 content,
394 thinking,
395 tool_calls,
396 ..
397 } => {
398 let mut assistant_contents = Vec::new();
399 if !content.is_empty() {
401 assistant_contents.push(completion::AssistantContent::text(&content));
402 }
403 for tc in tool_calls.iter() {
406 assistant_contents.push(completion::AssistantContent::tool_call(
407 tc.function.name.clone(),
408 tc.function.name.clone(),
409 tc.function.arguments.clone(),
410 ));
411 }
412 let choice = OneOrMany::many(assistant_contents).map_err(|_| {
413 CompletionError::ResponseError("No content provided".to_owned())
414 })?;
415 let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
416 let completion_tokens = resp.eval_count.unwrap_or(0);
417
418 let raw_response = CompletionResponse {
419 model: resp.model,
420 created_at: resp.created_at,
421 done: resp.done,
422 done_reason: resp.done_reason,
423 total_duration: resp.total_duration,
424 load_duration: resp.load_duration,
425 prompt_eval_count: resp.prompt_eval_count,
426 prompt_eval_duration: resp.prompt_eval_duration,
427 eval_count: resp.eval_count,
428 eval_duration: resp.eval_duration,
429 message: Message::Assistant {
430 content,
431 thinking,
432 images: None,
433 name: None,
434 tool_calls,
435 },
436 };
437
438 Ok(completion::CompletionResponse {
439 choice,
440 usage: Usage {
441 input_tokens: prompt_tokens,
442 output_tokens: completion_tokens,
443 total_tokens: prompt_tokens + completion_tokens,
444 },
445 raw_response,
446 })
447 }
448 _ => Err(CompletionError::ResponseError(
449 "Chat response does not include an assistant message".into(),
450 )),
451 }
452 }
453}
454
/// Handle for running chat completions against one named model.
#[derive(Clone)]
pub struct CompletionModel<T> {
    /// Client used to send the chat requests.
    client: Client<T>,
    /// Model name sent in the `model` field of each request.
    pub model: String,
}
462
463impl<T> CompletionModel<T> {
464 pub fn new(client: Client<T>, model: &str) -> Self {
465 Self {
466 client,
467 model: model.to_owned(),
468 }
469 }
470
471 fn create_completion_request(
472 &self,
473 completion_request: CompletionRequest,
474 ) -> Result<Value, CompletionError> {
475 if completion_request.tool_choice.is_some() {
476 tracing::warn!("WARNING: `tool_choice` not supported for Ollama");
477 }
478
479 let mut partial_history = vec![];
481 if let Some(docs) = completion_request.normalized_documents() {
482 partial_history.push(docs);
483 }
484 partial_history.extend(completion_request.chat_history);
485
486 let mut full_history: Vec<Message> = completion_request
488 .preamble
489 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
490
491 full_history.extend(
493 partial_history
494 .into_iter()
495 .map(|msg| msg.try_into())
496 .collect::<Result<Vec<Vec<Message>>, _>>()?
497 .into_iter()
498 .flatten()
499 .collect::<Vec<Message>>(),
500 );
501
502 let options = if let Some(extra) = completion_request.additional_params {
504 json_utils::merge(
505 json!({ "temperature": completion_request.temperature }),
506 extra,
507 )
508 } else {
509 json!({ "temperature": completion_request.temperature })
510 };
511
512 let mut request_payload = json!({
513 "model": self.model,
514 "messages": full_history,
515 "options": options,
516 "stream": false,
517 });
518 if !completion_request.tools.is_empty() {
519 request_payload["tools"] = json!(
520 completion_request
521 .tools
522 .into_iter()
523 .map(|tool| tool.into())
524 .collect::<Vec<ToolDefinition>>()
525 );
526 }
527
528 tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
529
530 Ok(request_payload)
531 }
532}
533
/// Final metadata of a streamed chat completion, copied from the NDJSON
/// object whose `done` flag is set.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    /// Reported as input tokens by `token_usage`.
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    /// Reported as output tokens by `token_usage`.
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
546
547impl GetTokenUsage for StreamingCompletionResponse {
548 fn token_usage(&self) -> Option<crate::completion::Usage> {
549 let mut usage = crate::completion::Usage::new();
550 let input_tokens = self.prompt_eval_count.unwrap_or_default();
551 let output_tokens = self.eval_count.unwrap_or_default();
552 usage.input_tokens = input_tokens;
553 usage.output_tokens = output_tokens;
554 usage.total_tokens = input_tokens + output_tokens;
555
556 Some(usage)
557 }
558}
559
560impl completion::CompletionModel for CompletionModel<reqwest::Client> {
561 type Response = CompletionResponse;
562 type StreamingResponse = StreamingCompletionResponse;
563
564 #[cfg_attr(feature = "worker", worker::send)]
565 async fn completion(
566 &self,
567 completion_request: CompletionRequest,
568 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
569 let preamble = completion_request.preamble.clone();
570 let request = self.create_completion_request(completion_request)?;
571
572 let span = if tracing::Span::current().is_disabled() {
573 info_span!(
574 target: "rig::completions",
575 "chat",
576 gen_ai.operation.name = "chat",
577 gen_ai.provider.name = "ollama",
578 gen_ai.request.model = self.model,
579 gen_ai.system_instructions = preamble,
580 gen_ai.response.id = tracing::field::Empty,
581 gen_ai.response.model = tracing::field::Empty,
582 gen_ai.usage.output_tokens = tracing::field::Empty,
583 gen_ai.usage.input_tokens = tracing::field::Empty,
584 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
585 gen_ai.output.messages = tracing::field::Empty,
586 )
587 } else {
588 tracing::Span::current()
589 };
590
591 let async_block = async move {
592 let response = self
593 .client
594 .reqwest_post("api/chat")
595 .json(&request)
596 .send()
597 .await
598 .map_err(|e| http_client::Error::Instance(e.into()))?;
599
600 if !response.status().is_success() {
601 return Err(CompletionError::ProviderError(
602 response
603 .text()
604 .await
605 .map_err(|e| http_client::Error::Instance(e.into()))?,
606 ));
607 }
608
609 let bytes = response
610 .bytes()
611 .await
612 .map_err(|e| http_client::Error::Instance(e.into()))?;
613
614 tracing::debug!(target: "rig", "Received response from Ollama: {}", String::from_utf8_lossy(&bytes));
615
616 let response: CompletionResponse = serde_json::from_slice(&bytes)?;
617 let span = tracing::Span::current();
618 span.record("gen_ai.response.model_name", &response.model);
619 span.record(
620 "gen_ai.output.messages",
621 serde_json::to_string(&vec![&response.message]).unwrap(),
622 );
623 span.record(
624 "gen_ai.usage.input_tokens",
625 response.prompt_eval_count.unwrap_or_default(),
626 );
627 span.record(
628 "gen_ai.usage.output_tokens",
629 response.eval_count.unwrap_or_default(),
630 );
631
632 let response: completion::CompletionResponse<CompletionResponse> =
633 response.try_into()?;
634
635 Ok(response)
636 };
637
638 tracing::Instrument::instrument(async_block, span).await
639 }
640
641 #[cfg_attr(feature = "worker", worker::send)]
642 async fn stream(
643 &self,
644 request: CompletionRequest,
645 ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
646 {
647 let preamble = request.preamble.clone();
648 let mut request = self.create_completion_request(request)?;
649 merge_inplace(&mut request, json!({"stream": true}));
650
651 let span = if tracing::Span::current().is_disabled() {
652 info_span!(
653 target: "rig::completions",
654 "chat_streaming",
655 gen_ai.operation.name = "chat_streaming",
656 gen_ai.provider.name = "ollama",
657 gen_ai.request.model = self.model,
658 gen_ai.system_instructions = preamble,
659 gen_ai.response.id = tracing::field::Empty,
660 gen_ai.response.model = self.model,
661 gen_ai.usage.output_tokens = tracing::field::Empty,
662 gen_ai.usage.input_tokens = tracing::field::Empty,
663 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
664 gen_ai.output.messages = tracing::field::Empty,
665 )
666 } else {
667 tracing::Span::current()
668 };
669
670 let response = self
671 .client
672 .reqwest_post("api/chat")
673 .json(&request)
674 .send()
675 .await
676 .map_err(|e| http_client::Error::Instance(e.into()))?;
677
678 if !response.status().is_success() {
679 return Err(CompletionError::ProviderError(
680 response
681 .text()
682 .await
683 .map_err(|e| http_client::Error::Instance(e.into()))?,
684 ));
685 }
686
687 let stream = try_stream! {
688 let span = tracing::Span::current();
689 let mut byte_stream = response.bytes_stream();
690 let mut tool_calls_final = Vec::new();
691 let mut text_response = String::new();
692
693 while let Some(chunk) = byte_stream.next().await {
694 let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;
695
696 for line in bytes.split(|&b| b == b'\n') {
697 if line.is_empty() {
698 continue;
699 }
700
701 tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));
702
703 let response: CompletionResponse = serde_json::from_slice(line)?;
704
705 if response.done {
706 span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
707 span.record("gen_ai.usage.output_tokens", response.eval_count);
708 let message = Message::Assistant {
709 content: text_response.clone(),
710 thinking: None,
711 images: None,
712 name: None,
713 tool_calls: tool_calls_final.clone()
714 };
715 span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
716 yield RawStreamingChoice::FinalResponse(
717 StreamingCompletionResponse {
718 total_duration: response.total_duration,
719 load_duration: response.load_duration,
720 prompt_eval_count: response.prompt_eval_count,
721 prompt_eval_duration: response.prompt_eval_duration,
722 eval_count: response.eval_count,
723 eval_duration: response.eval_duration,
724 done_reason: response.done_reason,
725 }
726 );
727 break;
728 }
729
730 if let Message::Assistant { content, tool_calls, .. } = response.message {
731 if !content.is_empty() {
732 text_response += &content;
733 yield RawStreamingChoice::Message(content);
734 }
735 for tool_call in tool_calls {
736 tool_calls_final.push(tool_call.clone());
737 yield RawStreamingChoice::ToolCall {
738 id: String::new(),
739 name: tool_call.function.name,
740 arguments: tool_call.function.arguments,
741 call_id: None,
742 };
743 }
744 }
745 }
746 }
747 }.instrument(span);
748
749 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
750 stream,
751 )))
752 }
753}
754
/// Tool definition in the shape sent in the `tools` array of a chat request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Serialized as `type`; the `From` conversion in this file sets it to
    /// `"function"`.
    #[serde(rename = "type")]
    pub type_field: String,
    /// The crate-level tool definition (name, description, parameters).
    pub function: completion::ToolDefinition,
}
764
765impl From<crate::completion::ToolDefinition> for ToolDefinition {
767 fn from(tool: crate::completion::ToolDefinition) -> Self {
768 ToolDefinition {
769 type_field: "function".to_owned(),
770 function: completion::ToolDefinition {
771 name: tool.name,
772 description: tool.description,
773 parameters: tool.parameters,
774 },
775 }
776 }
777}
778
/// A tool call attached to an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Serialized as `type`; falls back to the default variant when absent.
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    /// The function being called and its arguments.
    pub function: Function,
}
/// Kind of tool call; serialized in lowercase (`"function"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// The only variant, and the default.
    #[default]
    Function,
}
/// Function invocation inside a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function to call.
    pub name: String,
    /// Arguments as an arbitrary JSON value.
    pub arguments: Value,
}
796
/// Provider-level chat message.
///
/// Internally tagged by `role`, with variant names in lowercase (`user`,
/// `assistant`, `system`); `ToolResult` is renamed to `tool`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// Message with role `user`.
    User {
        content: String,
        /// Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        /// Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Message with role `assistant`.
    Assistant {
        /// Defaults to an empty string when absent.
        #[serde(default)]
        content: String,
        /// Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        /// Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        /// Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Defaults to empty when absent; deserialized through
        /// `json_utils::null_or_vec` (presumably so an explicit `null` is
        /// accepted as well — confirm against that helper).
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// Message with role `system`.
    System {
        content: String,
        /// Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        /// Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Tool result; serialized with role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        /// Serialized as `tool_name`.
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
835
836impl TryFrom<crate::message::Message> for Vec<Message> {
842 type Error = crate::message::MessageError;
843 fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
844 use crate::message::Message as InternalMessage;
845 match internal_msg {
846 InternalMessage::User { content, .. } => {
847 let (tool_results, other_content): (Vec<_>, Vec<_>) =
848 content.into_iter().partition(|content| {
849 matches!(content, crate::message::UserContent::ToolResult(_))
850 });
851
852 if !tool_results.is_empty() {
853 tool_results
854 .into_iter()
855 .map(|content| match content {
856 crate::message::UserContent::ToolResult(
857 crate::message::ToolResult { id, content, .. },
858 ) => {
859 let content_string = content
861 .into_iter()
862 .map(|content| match content {
863 crate::message::ToolResultContent::Text(text) => text.text,
864 _ => "[Non-text content]".to_string(),
865 })
866 .collect::<Vec<_>>()
867 .join("\n");
868
869 Ok::<_, crate::message::MessageError>(Message::ToolResult {
870 name: id,
871 content: content_string,
872 })
873 }
874 _ => unreachable!(),
875 })
876 .collect::<Result<Vec<_>, _>>()
877 } else {
878 let (texts, images) = other_content.into_iter().fold(
880 (Vec::new(), Vec::new()),
881 |(mut texts, mut images), content| {
882 match content {
883 crate::message::UserContent::Text(crate::message::Text {
884 text,
885 }) => texts.push(text),
886 crate::message::UserContent::Image(crate::message::Image {
887 data: DocumentSourceKind::Base64(data),
888 ..
889 }) => images.push(data),
890 crate::message::UserContent::Document(
891 crate::message::Document {
892 data:
893 DocumentSourceKind::Base64(data)
894 | DocumentSourceKind::String(data),
895 ..
896 },
897 ) => texts.push(data),
898 _ => {} }
900 (texts, images)
901 },
902 );
903
904 Ok(vec![Message::User {
905 content: texts.join(" "),
906 images: if images.is_empty() {
907 None
908 } else {
909 Some(
910 images
911 .into_iter()
912 .map(|x| x.to_string())
913 .collect::<Vec<String>>(),
914 )
915 },
916 name: None,
917 }])
918 }
919 }
920 InternalMessage::Assistant { content, .. } => {
921 let mut thinking: Option<String> = None;
922 let (text_content, tool_calls) = content.into_iter().fold(
923 (Vec::new(), Vec::new()),
924 |(mut texts, mut tools), content| {
925 match content {
926 crate::message::AssistantContent::Text(text) => texts.push(text.text),
927 crate::message::AssistantContent::ToolCall(tool_call) => {
928 tools.push(tool_call)
929 }
930 crate::message::AssistantContent::Reasoning(
931 crate::message::Reasoning { reasoning, .. },
932 ) => {
933 thinking =
934 Some(reasoning.first().cloned().unwrap_or(String::new()));
935 }
936 }
937 (texts, tools)
938 },
939 );
940
941 Ok(vec![Message::Assistant {
944 content: text_content.join(" "),
945 thinking,
946 images: None,
947 name: None,
948 tool_calls: tool_calls
949 .into_iter()
950 .map(|tool_call| tool_call.into())
951 .collect::<Vec<_>>(),
952 }])
953 }
954 }
955 }
956}
957
958impl From<Message> for crate::completion::Message {
961 fn from(msg: Message) -> Self {
962 match msg {
963 Message::User { content, .. } => crate::completion::Message::User {
964 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
965 text: content,
966 })),
967 },
968 Message::Assistant {
969 content,
970 tool_calls,
971 ..
972 } => {
973 let mut assistant_contents =
974 vec![crate::completion::message::AssistantContent::Text(Text {
975 text: content,
976 })];
977 for tc in tool_calls {
978 assistant_contents.push(
979 crate::completion::message::AssistantContent::tool_call(
980 tc.function.name.clone(),
981 tc.function.name,
982 tc.function.arguments,
983 ),
984 );
985 }
986 crate::completion::Message::Assistant {
987 id: None,
988 content: OneOrMany::many(assistant_contents).unwrap(),
989 }
990 }
991 Message::System { content, .. } => crate::completion::Message::User {
993 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
994 text: content,
995 })),
996 },
997 Message::ToolResult { name, content } => crate::completion::Message::User {
998 content: OneOrMany::one(message::UserContent::tool_result(
999 name,
1000 OneOrMany::one(message::ToolResultContent::text(content)),
1001 )),
1002 },
1003 }
1004 }
1005}
1006
1007impl Message {
1008 pub fn system(content: &str) -> Self {
1010 Message::System {
1011 content: content.to_owned(),
1012 images: None,
1013 name: None,
1014 }
1015 }
1016}
1017
1018impl From<crate::message::ToolCall> for ToolCall {
1021 fn from(tool_call: crate::message::ToolCall) -> Self {
1022 Self {
1023 r#type: ToolType::Function,
1024 function: Function {
1025 name: tool_call.function.name,
1026 arguments: tool_call.function.arguments,
1027 },
1028 }
1029 }
1030}
1031
/// A typed piece of system content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    /// Content kind; falls back to the default variant when absent.
    #[serde(default)]
    r#type: SystemContentType,
    /// The text itself.
    text: String,
}
1038
/// Kind of [`SystemContent`]; serialized in lowercase (`"text"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    /// The only variant, and the default.
    #[default]
    Text,
}
1045
1046impl From<String> for SystemContent {
1047 fn from(s: String) -> Self {
1048 SystemContent {
1049 r#type: SystemContentType::default(),
1050 text: s,
1051 }
1052 }
1053}
1054
1055impl FromStr for SystemContent {
1056 type Err = std::convert::Infallible;
1057 fn from_str(s: &str) -> Result<Self, Self::Err> {
1058 Ok(SystemContent {
1059 r#type: SystemContentType::default(),
1060 text: s.to_string(),
1061 })
1062 }
1063}
1064
/// Text-only assistant content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    /// The text itself.
    pub text: String,
}
1069
1070impl FromStr for AssistantContent {
1071 type Err = std::convert::Infallible;
1072 fn from_str(s: &str) -> Result<Self, Self::Err> {
1073 Ok(AssistantContent { text: s.to_owned() })
1074 }
1075}
1076
/// User content part, internally tagged by `type` in lowercase (`text`,
/// `image`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text.
    Text { text: String },
    /// Image referenced through an [`ImageUrl`].
    Image { image_url: ImageUrl },
}
1084
1085impl FromStr for UserContent {
1086 type Err = std::convert::Infallible;
1087 fn from_str(s: &str) -> Result<Self, Self::Err> {
1088 Ok(UserContent::Text { text: s.to_owned() })
1089 }
1090}
1091
/// Image reference used by [`UserContent::Image`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// Location of the image.
    pub url: String,
    /// Detail setting; falls back to `ImageDetail::default()` when absent.
    #[serde(default)]
    pub detail: ImageDetail,
}
1098
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A sample chat response with text and one tool call deserializes and
    /// converts into a non-empty crate-level completion response.
    #[tokio::test]
    async fn test_chat_completion() {
        let tool_call = json!({
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "arguments": {
                    "location": "San Francisco, CA",
                    "format": "celsius"
                }
            }
        });
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [tool_call]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        // Round-trip through text so the string deserialization path is used.
        let sample_text = sample_chat_response.to_string();

        let chat_resp =
            serde_json::from_str::<CompletionResponse>(&sample_text).expect("Invalid JSON structure");
        let conv = completion::CompletionResponse::<CompletionResponse>::try_from(chat_resp)
            .unwrap();

        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    /// A provider `User` message converts into a crate-level user message
    /// carrying the same text.
    #[test]
    fn test_message_conversion() {
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };

        let crate::completion::Message::User { content } =
            crate::completion::Message::from(provider_msg)
        else {
            panic!("Conversion from provider Message to completion Message failed")
        };

        match content.first() {
            crate::completion::message::UserContent::Text(text_struct) => {
                assert_eq!(text_struct.text, "Test message");
            }
            _ => panic!("Expected text content in conversion"),
        }
    }

    /// A crate-level tool definition converts into a `"function"`-typed
    /// provider tool with name, description and parameters preserved.
    #[test]
    fn test_tool_definition_conversion() {
        let parameters = json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The location to get the weather for, e.g. San Francisco, CA"
                },
                "format": {
                    "type": "string",
                    "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["location", "format"]
        });
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters,
        };

        let ollama_tool = ToolDefinition::from(internal_tool);

        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }
}