1use crate::client::{
42 CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError,
43};
44use crate::completion::{GetTokenUsage, Usage};
45use crate::http_client::{self, HttpClientExt};
46use crate::json_utils::merge_inplace;
47use crate::message::DocumentSourceKind;
48use crate::streaming::RawStreamingChoice;
49use crate::{
50 Embed, OneOrMany,
51 completion::{self, CompletionError, CompletionRequest},
52 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
53 impl_conversion_traits, json_utils, message,
54 message::{ImageDetail, Text},
55 streaming,
56};
57use async_stream::try_stream;
58use bytes::Bytes;
59use futures::StreamExt;
60use reqwest;
61use serde::{Deserialize, Serialize};
63use serde_json::{Value, json};
64use std::{convert::TryFrom, str::FromStr};
65use tracing::info_span;
66use tracing_futures::Instrument;
/// Default base URL of a locally running Ollama server.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
70
/// Builder for [`Client`], allowing the base URL and HTTP client to be
/// customized before construction. Defaults to the local Ollama endpoint.
pub struct ClientBuilder<'a, T = reqwest::Client> {
    // Base URL of the Ollama server; defaults to OLLAMA_API_BASE_URL.
    base_url: &'a str,
    // HTTP transport used for all requests.
    http_client: T,
}
75
76impl<'a, T> ClientBuilder<'a, T>
77where
78 T: Default,
79{
80 #[allow(clippy::new_without_default)]
81 pub fn new() -> Self {
82 Self {
83 base_url: OLLAMA_API_BASE_URL,
84 http_client: Default::default(),
85 }
86 }
87}
88
89impl<'a, T> ClientBuilder<'a, T> {
90 pub fn new_with_client(http_client: T) -> Self {
91 Self {
92 base_url: OLLAMA_API_BASE_URL,
93 http_client,
94 }
95 }
96
97 pub fn base_url(mut self, base_url: &'a str) -> Self {
98 self.base_url = base_url;
99 self
100 }
101
102 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
103 ClientBuilder {
104 base_url: self.base_url,
105 http_client,
106 }
107 }
108
109 pub fn build(self) -> Client<T> {
110 Client {
111 base_url: self.base_url.into(),
112 http_client: self.http_client,
113 }
114 }
115}
116
/// Client for the Ollama HTTP API.
#[derive(Clone, Debug)]
pub struct Client<T = reqwest::Client> {
    // Server base URL, e.g. "http://localhost:11434".
    base_url: String,
    // HTTP transport used for every request.
    http_client: T,
}
122
123impl<T> Client<T> {
124 fn req(&self, method: http_client::Method, path: &str) -> http_client::Builder {
125 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
126 http_client::Builder::new().method(method).uri(url)
127 }
128
129 pub(crate) fn post(&self, path: &str) -> http_client::Builder {
130 self.req(http_client::Method::POST, path)
131 }
132
133 pub(crate) fn get(&self, path: &str) -> http_client::Builder {
134 self.req(http_client::Method::GET, path)
135 }
136}
137
impl Client<reqwest::Client> {
    /// Returns a builder using the default `reqwest` HTTP client.
    pub fn builder<'a>() -> ClientBuilder<'a, reqwest::Client> {
        ClientBuilder::new()
    }

    /// Creates a client targeting the default local Ollama endpoint.
    pub fn new() -> Self {
        Self::builder().build()
    }

    /// Creates a client from the `OLLAMA_API_BASE_URL` environment variable;
    /// panics if it is unset (see [`ProviderClient::from_env`]).
    pub fn from_env() -> Self {
        <Self as ProviderClient>::from_env()
    }
}
151
// The default client targets the local Ollama endpoint.
impl Default for Client<reqwest::Client> {
    fn default() -> Self {
        Self::new()
    }
}
157
impl<T> ProviderClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    /// Builds a client from the `OLLAMA_API_BASE_URL` environment variable.
    ///
    /// # Panics
    /// Panics if `OLLAMA_API_BASE_URL` is unset.
    fn from_env() -> Self {
        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
        ClientBuilder::new().base_url(&api_base).build()
    }

    /// Builds a client from a provider value. Ollama needs no API key, so the
    /// carried value is ignored; only the variant is validated.
    ///
    /// # Panics
    /// Panics if `input` is not the `Simple` variant.
    fn from_val(input: crate::client::ProviderValue) -> Self {
        let crate::client::ProviderValue::Simple(_) = input else {
            panic!("Incorrect provider value type")
        };

        ClientBuilder::new().build()
    }
}
175
impl<T> CompletionClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    type CompletionModel = CompletionModel<T>;

    /// Creates a completion model handle for the given Ollama model name.
    fn completion_model(&self, model: &str) -> Self::CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
186
impl<T> EmbeddingsClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    type EmbeddingModel = EmbeddingModel<T>;

    /// Creates an embedding model handle; dimensionality defaults to 0 (unknown).
    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
        EmbeddingModel::new(self.clone(), model, 0)
    }

    /// Creates an embedding model handle with an explicit dimensionality.
    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
        EmbeddingModel::new(self.clone(), model, ndims)
    }

    /// Starts an embeddings builder backed by this client.
    fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<Self::EmbeddingModel, D> {
        EmbeddingsBuilder::new(self.embedding_model(model))
    }
}
202
impl<T> VerifyClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    /// Checks reachability of the server by calling `GET api/tags`.
    ///
    /// 401 maps to `InvalidAuthentication`; 500/502/503 map to
    /// `ProviderError` carrying the response body. Any other status
    /// (including, e.g., 404) is treated as "server reachable" and returns
    /// `Ok(())`.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn verify(&self) -> Result<(), VerifyError> {
        let req = self
            .get("api/tags")
            .body(http_client::NoBody)
            .map_err(http_client::Error::from)?;

        let response = HttpClientExt::send(&self.http_client, req).await?;

        match response.status() {
            reqwest::StatusCode::OK => Ok(()),
            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
            reqwest::StatusCode::INTERNAL_SERVER_ERROR
            | reqwest::StatusCode::SERVICE_UNAVAILABLE
            | reqwest::StatusCode::BAD_GATEWAY => {
                let text = http_client::text(response).await?;
                Err(VerifyError::ProviderError(text))
            }
            _ => {
                Ok(())
            }
        }
    }
}
232
// This integration has no transcription, image-generation, or
// audio-generation support; generate the corresponding conversion impls.
impl_conversion_traits!(
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
238
/// Error payload returned by the Ollama API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
245
/// Untagged API result: either a successful payload of type `T` or an error
/// payload. Serde tries `Ok` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
252
/// The `all-minilm` embedding model.
pub const ALL_MINILM: &str = "all-minilm";
/// The `nomic-embed-text` embedding model.
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
257
/// Response body of Ollama's `api/embed` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    /// One embedding vector per input document, in input order.
    pub embeddings: Vec<Vec<f64>>,
    // Timing/usage statistics; defaults tolerate their absence.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
269
270impl From<ApiErrorResponse> for EmbeddingError {
271 fn from(err: ApiErrorResponse) -> Self {
272 EmbeddingError::ProviderError(err.message)
273 }
274}
275
276impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
277 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
278 match value {
279 ApiResponse::Ok(response) => Ok(response),
280 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
281 }
282 }
283}
284
/// Handle to a named Ollama embedding model bound to a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel<T> {
    client: Client<T>,
    pub model: String,
    // Declared dimensionality; 0 means unknown.
    ndims: usize,
}
293
294impl<T> EmbeddingModel<T> {
295 pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
296 Self {
297 client,
298 model: model.to_owned(),
299 ndims,
300 }
301 }
302}
303
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    /// Maximum number of documents accepted in a single `embed_texts` call.
    const MAX_DOCUMENTS: usize = 1024;

    /// Declared embedding dimensionality (0 when unknown).
    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embeds a batch of documents via `POST api/embed`.
    ///
    /// # Errors
    /// Fails when the request cannot be sent, the server replies with a
    /// non-success status, the body cannot be parsed, or the number of
    /// returned embeddings differs from the number of inputs.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let docs: Vec<String> = documents.into_iter().collect();

        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": docs
        }))?;

        let req = self
            .client
            .post("api/embed")
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = HttpClientExt::send(&self.client.http_client, req).await?;

        if !response.status().is_success() {
            let text = http_client::text(response).await?;
            return Err(EmbeddingError::ProviderError(text));
        }

        let bytes: Vec<u8> = response.into_body().await?;

        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;

        // The server must return exactly one embedding per input document.
        if api_resp.embeddings.len() != docs.len() {
            return Err(EmbeddingError::ResponseError(
                "Number of returned embeddings does not match input".into(),
            ));
        }
        // Pair each embedding with its source document, preserving order.
        Ok(api_resp
            .embeddings
            .into_iter()
            .zip(docs.into_iter())
            .map(|(vec, document)| embeddings::Embedding { document, vec })
            .collect())
    }
}
355
/// The `llama3.2` chat model.
pub const LLAMA3_2: &str = "llama3.2";
/// The `llava` multimodal model.
pub const LLAVA: &str = "llava";
/// The `mistral` chat model.
pub const MISTRAL: &str = "mistral";
361
/// Response body of Ollama's `api/chat` endpoint. The same shape is used for
/// each NDJSON chunk when streaming.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    /// `true` on the final message of a generation.
    pub done: bool,
    // Stats below carry defaults because intermediate streaming chunks omit
    // them; they are read when `done` is true.
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    /// Converts Ollama's chat response into the provider-agnostic type.
    ///
    /// Requires an assistant message; non-empty text and each tool call
    /// become choice entries, and the original response (reassembled) is
    /// kept as `raw_response`. Token usage comes from the eval counts.
    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                // Skip empty text so the choice list only carries real content.
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                // Ollama tool calls carry no id, so the function name is used
                // for both the id and the name.
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                // Errors when both text and tool calls are absent.
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                // Rebuild the raw response; the message fields were moved out
                // above, so the assistant message is reconstructed here.
                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                    },
                    raw_response,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}
450
/// Handle to a named Ollama chat model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel<T> {
    client: Client<T>,
    pub model: String,
}
458
impl<T> CompletionModel<T> {
    /// Creates a completion model handle for the named Ollama model.
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_owned(),
        }
    }

    /// Builds the JSON payload for `api/chat` from a generic completion
    /// request: normalized documents plus chat history (prefixed by the
    /// system preamble when present), sampling options, and tool definitions.
    ///
    /// `stream` is set to `false` here; the streaming path overrides it.
    fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Ollama's chat API has no tool_choice equivalent; warn and ignore.
        if completion_request.tool_choice.is_some() {
            tracing::warn!("WARNING: `tool_choice` not supported for Ollama");
        }

        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);

        // The system preamble, if any, must come first in the message list.
        let mut full_history: Vec<Message> = completion_request
            .preamble
            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);

        // Each rig message may expand into several Ollama wire messages.
        full_history.extend(
            partial_history
                .into_iter()
                .map(|msg| msg.try_into())
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<Message>>(),
        );

        let mut request_payload = json!({
            "model": self.model,
            "messages": full_history,
            "stream": false,
        });

        let options = if let Some(mut extra) = completion_request.additional_params {
            // `think` is a top-level request field, not a sampling option:
            // hoist it out of the extras before merging them into `options`.
            if extra.get("think").is_some() {
                request_payload["think"] = extra["think"].take();
            }
            json_utils::merge(
                json!({ "temperature": completion_request.temperature }),
                extra,
            )
        } else {
            json!({ "temperature": completion_request.temperature })
        };

        request_payload["options"] = options;

        if !completion_request.tools.is_empty() {
            request_payload["tools"] = json!(
                completion_request
                    .tools
                    .into_iter()
                    .map(|tool| tool.into())
                    .collect::<Vec<ToolDefinition>>()
            );
        }

        tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);

        Ok(request_payload)
    }
}
534
/// Final-chunk statistics surfaced at the end of a streamed completion.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
547
548impl GetTokenUsage for StreamingCompletionResponse {
549 fn token_usage(&self) -> Option<crate::completion::Usage> {
550 let mut usage = crate::completion::Usage::new();
551 let input_tokens = self.prompt_eval_count.unwrap_or_default();
552 let output_tokens = self.eval_count.unwrap_or_default();
553 usage.input_tokens = input_tokens;
554 usage.output_tokens = output_tokens;
555 usage.total_tokens = input_tokens + output_tokens;
556
557 Some(usage)
558 }
559}
560
561impl<T> completion::CompletionModel for CompletionModel<T>
562where
563 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
564{
565 type Response = CompletionResponse;
566 type StreamingResponse = StreamingCompletionResponse;
567
568 #[cfg_attr(feature = "worker", worker::send)]
569 async fn completion(
570 &self,
571 completion_request: CompletionRequest,
572 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
573 let preamble = completion_request.preamble.clone();
574 let request = self.create_completion_request(completion_request)?;
575
576 let span = if tracing::Span::current().is_disabled() {
577 info_span!(
578 target: "rig::completions",
579 "chat",
580 gen_ai.operation.name = "chat",
581 gen_ai.provider.name = "ollama",
582 gen_ai.request.model = self.model,
583 gen_ai.system_instructions = preamble,
584 gen_ai.response.id = tracing::field::Empty,
585 gen_ai.response.model = tracing::field::Empty,
586 gen_ai.usage.output_tokens = tracing::field::Empty,
587 gen_ai.usage.input_tokens = tracing::field::Empty,
588 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
589 gen_ai.output.messages = tracing::field::Empty,
590 )
591 } else {
592 tracing::Span::current()
593 };
594
595 let body = serde_json::to_vec(&request)?;
596
597 let req = self
598 .client
599 .post("api/chat")
600 .header("Content-Type", "application/json")
601 .body(body)
602 .map_err(http_client::Error::from)?;
603
604 let async_block = async move {
605 let response = self.client.http_client.send::<_, Bytes>(req).await?;
606 let status = response.status();
607 let response_body = response.into_body().into_future().await?.to_vec();
608
609 if !status.is_success() {
610 return Err(CompletionError::ProviderError(
611 String::from_utf8_lossy(&response_body).to_string(),
612 ));
613 }
614
615 let response: CompletionResponse = serde_json::from_slice(&response_body)?;
616 let span = tracing::Span::current();
617 span.record("gen_ai.response.model_name", &response.model);
618 span.record(
619 "gen_ai.output.messages",
620 serde_json::to_string(&vec![&response.message]).unwrap(),
621 );
622 span.record(
623 "gen_ai.usage.input_tokens",
624 response.prompt_eval_count.unwrap_or_default(),
625 );
626 span.record(
627 "gen_ai.usage.output_tokens",
628 response.eval_count.unwrap_or_default(),
629 );
630
631 tracing::debug!(target: "rig", "Received response from Ollama: {}", serde_json::to_string_pretty(&response)?);
632
633 let response: completion::CompletionResponse<CompletionResponse> =
634 response.try_into()?;
635
636 Ok(response)
637 };
638
639 tracing::Instrument::instrument(async_block, span).await
640 }
641
642 #[cfg_attr(feature = "worker", worker::send)]
643 async fn stream(
644 &self,
645 request: CompletionRequest,
646 ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
647 {
648 let preamble = request.preamble.clone();
649 let mut request = self.create_completion_request(request)?;
650 merge_inplace(&mut request, json!({"stream": true}));
651
652 let span = if tracing::Span::current().is_disabled() {
653 info_span!(
654 target: "rig::completions",
655 "chat_streaming",
656 gen_ai.operation.name = "chat_streaming",
657 gen_ai.provider.name = "ollama",
658 gen_ai.request.model = self.model,
659 gen_ai.system_instructions = preamble,
660 gen_ai.response.id = tracing::field::Empty,
661 gen_ai.response.model = self.model,
662 gen_ai.usage.output_tokens = tracing::field::Empty,
663 gen_ai.usage.input_tokens = tracing::field::Empty,
664 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
665 gen_ai.output.messages = tracing::field::Empty,
666 )
667 } else {
668 tracing::Span::current()
669 };
670
671 let body = serde_json::to_vec(&request)?;
672
673 let req = self
674 .client
675 .post("api/chat")
676 .header("Content-Type", "application/json")
677 .body(body)
678 .map_err(http_client::Error::from)?;
679
680 let response = self.client.http_client.send_streaming(req).await?;
681 let status = response.status();
682 let mut byte_stream = response.into_body();
683
684 if !status.is_success() {
685 return Err(CompletionError::ProviderError(format!(
686 "Got error status code trying to send a request to Ollama: {status}"
687 )));
688 }
689
690 let stream = try_stream! {
691 let span = tracing::Span::current();
692 let mut tool_calls_final = Vec::new();
693 let mut text_response = String::new();
694 let mut thinking_response = String::new();
695
696 while let Some(chunk) = byte_stream.next().await {
697 let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;
698
699 for line in bytes.split(|&b| b == b'\n') {
700 if line.is_empty() {
701 continue;
702 }
703
704 tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));
705
706 let response: CompletionResponse = serde_json::from_slice(line)?;
707
708 if response.done {
709 span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
710 span.record("gen_ai.usage.output_tokens", response.eval_count);
711 let message = Message::Assistant {
712 content: text_response.clone(),
713 thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
714 images: None,
715 name: None,
716 tool_calls: tool_calls_final.clone()
717 };
718 span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
719 yield RawStreamingChoice::FinalResponse(
720 StreamingCompletionResponse {
721 total_duration: response.total_duration,
722 load_duration: response.load_duration,
723 prompt_eval_count: response.prompt_eval_count,
724 prompt_eval_duration: response.prompt_eval_duration,
725 eval_count: response.eval_count,
726 eval_duration: response.eval_duration,
727 done_reason: response.done_reason,
728 }
729 );
730 break;
731 }
732
733 if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
734 if let Some(thinking_content) = thinking
735 && !thinking_content.is_empty() {
736 thinking_response += &thinking_content;
737 yield RawStreamingChoice::Reasoning {
738 reasoning: thinking_content,
739 id: None,
740 signature: None,
741 };
742 }
743
744 if !content.is_empty() {
745 text_response += &content;
746 yield RawStreamingChoice::Message(content);
747 }
748
749 for tool_call in tool_calls {
750 tool_calls_final.push(tool_call.clone());
751 yield RawStreamingChoice::ToolCall {
752 id: String::new(),
753 name: tool_call.function.name,
754 arguments: tool_call.function.arguments,
755 call_id: None,
756 };
757 }
758 }
759 }
760 }
761 }.instrument(span);
762
763 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
764 stream,
765 )))
766 }
767}
768
769#[derive(Clone, Debug, Deserialize, Serialize)]
773pub struct ToolDefinition {
774 #[serde(rename = "type")]
775 pub type_field: String, pub function: completion::ToolDefinition,
777}
778
779impl From<crate::completion::ToolDefinition> for ToolDefinition {
781 fn from(tool: crate::completion::ToolDefinition) -> Self {
782 ToolDefinition {
783 type_field: "function".to_owned(),
784 function: completion::ToolDefinition {
785 name: tool.name,
786 description: tool.description,
787 parameters: tool.parameters,
788 },
789 }
790 }
791}
792
/// Tool call emitted by the model; `type` defaults when absent on the wire.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Kind of tool call; Ollama only emits `function`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// Function name and JSON arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
810
/// Ollama chat message in wire format; the `role` field selects the variant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// End-user turn; `images` carries base64-encoded image data.
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Model turn; `thinking` holds reasoning text from thinking-capable
    /// models, and `tool_calls` deserializes JSON `null` as an empty list.
    Assistant {
        #[serde(default)]
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// System instruction turn.
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Tool output turn (wire role `tool`), keyed by tool name rather than a
    /// call id.
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
849
impl TryFrom<crate::message::Message> for Vec<Message> {
    type Error = crate::message::MessageError;

    /// Lowers a rig message into one or more Ollama wire messages.
    ///
    /// A user message containing any tool results becomes one `ToolResult`
    /// message per result (any other content in that message is dropped);
    /// otherwise its text and document parts are joined into a single `User`
    /// message with base64 images split out. An assistant message becomes a
    /// single `Assistant` message.
    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;
        match internal_msg {
            InternalMessage::User { content, .. } => {
                // Split tool results from the rest; their presence changes
                // the resulting message shape entirely.
                let (tool_results, other_content): (Vec<_>, Vec<_>) =
                    content.into_iter().partition(|content| {
                        matches!(content, crate::message::UserContent::ToolResult(_))
                    });

                if !tool_results.is_empty() {
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            crate::message::UserContent::ToolResult(
                                crate::message::ToolResult { id, content, .. },
                            ) => {
                                // Flatten result parts to plain text;
                                // non-text parts become a placeholder.
                                let content_string = content
                                    .into_iter()
                                    .map(|content| match content {
                                        crate::message::ToolResultContent::Text(text) => text.text,
                                        _ => "[Non-text content]".to_string(),
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n");

                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
                                    name: id,
                                    content: content_string,
                                })
                            }
                            // The partition above guarantees only ToolResult
                            // items reach this branch.
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    // Gather text (including inline document bodies) and
                    // base64 image payloads.
                    let (texts, images) = other_content.into_iter().fold(
                        (Vec::new(), Vec::new()),
                        |(mut texts, mut images), content| {
                            match content {
                                crate::message::UserContent::Text(crate::message::Text {
                                    text,
                                }) => texts.push(text),
                                crate::message::UserContent::Image(crate::message::Image {
                                    data: DocumentSourceKind::Base64(data),
                                    ..
                                }) => images.push(data),
                                crate::message::UserContent::Document(
                                    crate::message::Document {
                                        data:
                                            DocumentSourceKind::Base64(data)
                                            | DocumentSourceKind::String(data),
                                        ..
                                    },
                                ) => texts.push(data),
                                // Any other content kind is silently dropped.
                                _ => {} }
                            (texts, images)
                        },
                    );

                    Ok(vec![Message::User {
                        content: texts.join(" "),
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(
                                images
                                    .into_iter()
                                    .map(|x| x.to_string())
                                    .collect::<Vec<String>>(),
                            )
                        },
                        name: None,
                    }])
                }
            }
            InternalMessage::Assistant { content, .. } => {
                // NOTE(review): only one reasoning string survives — each
                // Reasoning item overwrites `thinking` with the first entry
                // of its reasoning list.
                let mut thinking: Option<String> = None;
                let (text_content, tool_calls) = content.into_iter().fold(
                    (Vec::new(), Vec::new()),
                    |(mut texts, mut tools), content| {
                        match content {
                            crate::message::AssistantContent::Text(text) => texts.push(text.text),
                            crate::message::AssistantContent::ToolCall(tool_call) => {
                                tools.push(tool_call)
                            }
                            crate::message::AssistantContent::Reasoning(
                                crate::message::Reasoning { reasoning, .. },
                            ) => {
                                thinking =
                                    Some(reasoning.first().cloned().unwrap_or(String::new()));
                            }
                        }
                        (texts, tools)
                    },
                );

                Ok(vec![Message::Assistant {
                    content: text_content.join(" "),
                    thinking,
                    images: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}
971
impl From<Message> for crate::completion::Message {
    /// Raises an Ollama wire message back into rig's message type.
    ///
    /// System and tool-result messages are represented as user messages;
    /// assistant `thinking` and user images are not carried over.
    fn from(msg: Message) -> Self {
        match msg {
            Message::User { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // NOTE(review): the text element is always included, even when
                // `content` is empty — this keeps the `OneOrMany::many(...)`
                // below from failing when there are no tool calls.
                let mut assistant_contents =
                    vec![crate::completion::message::AssistantContent::Text(Text {
                        text: content,
                    })];
                for tc in tool_calls {
                    // The function name doubles as the call id (Ollama has none).
                    assistant_contents.push(
                        crate::completion::message::AssistantContent::tool_call(
                            tc.function.name.clone(),
                            tc.function.name,
                            tc.function.arguments,
                        ),
                    );
                }
                crate::completion::Message::Assistant {
                    id: None,
                    content: OneOrMany::many(assistant_contents).unwrap(),
                }
            }
            Message::System { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::ToolResult { name, content } => crate::completion::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    name,
                    OneOrMany::one(message::ToolResultContent::text(content)),
                )),
            },
        }
    }
}
1020
1021impl Message {
1022 pub fn system(content: &str) -> Self {
1024 Message::System {
1025 content: content.to_owned(),
1026 images: None,
1027 name: None,
1028 }
1029 }
1030}
1031
1032impl From<crate::message::ToolCall> for ToolCall {
1035 fn from(tool_call: crate::message::ToolCall) -> Self {
1036 Self {
1037 r#type: ToolType::Function,
1038 function: Function {
1039 name: tool_call.function.name,
1040 arguments: tool_call.function.arguments,
1041 },
1042 }
1043 }
1044}
1045
/// Typed system-message content; `type` defaults to `text` when absent.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}
1052
/// Kind of system content; only `text` is supported.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
1059
1060impl From<String> for SystemContent {
1061 fn from(s: String) -> Self {
1062 SystemContent {
1063 r#type: SystemContentType::default(),
1064 text: s,
1065 }
1066 }
1067}
1068
1069impl FromStr for SystemContent {
1070 type Err = std::convert::Infallible;
1071 fn from_str(s: &str) -> Result<Self, Self::Err> {
1072 Ok(SystemContent {
1073 r#type: SystemContentType::default(),
1074 text: s.to_string(),
1075 })
1076 }
1077}
1078
/// Plain-text assistant content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}
1083
1084impl FromStr for AssistantContent {
1085 type Err = std::convert::Infallible;
1086 fn from_str(s: &str) -> Result<Self, Self::Err> {
1087 Ok(AssistantContent { text: s.to_owned() })
1088 }
1089}
1090
/// Typed user content: plain text or an image reference.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
}
1098
1099impl FromStr for UserContent {
1100 type Err = std::convert::Infallible;
1101 fn from_str(s: &str) -> Result<Self, Self::Err> {
1102 Ok(UserContent::Text { text: s.to_owned() })
1103 }
1104}
1105
/// Image reference with an optional detail level (defaults when absent).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}
1112
1113#[cfg(test)]
1118mod tests {
1119 use super::*;
1120 use serde_json::json;
1121
1122 #[tokio::test]
1124 async fn test_chat_completion() {
1125 let sample_chat_response = json!({
1127 "model": "llama3.2",
1128 "created_at": "2023-08-04T19:22:45.499127Z",
1129 "message": {
1130 "role": "assistant",
1131 "content": "The sky is blue because of Rayleigh scattering.",
1132 "images": null,
1133 "tool_calls": [
1134 {
1135 "type": "function",
1136 "function": {
1137 "name": "get_current_weather",
1138 "arguments": {
1139 "location": "San Francisco, CA",
1140 "format": "celsius"
1141 }
1142 }
1143 }
1144 ]
1145 },
1146 "done": true,
1147 "total_duration": 8000000000u64,
1148 "load_duration": 6000000u64,
1149 "prompt_eval_count": 61u64,
1150 "prompt_eval_duration": 400000000u64,
1151 "eval_count": 468u64,
1152 "eval_duration": 7700000000u64
1153 });
1154 let sample_text = sample_chat_response.to_string();
1155
1156 let chat_resp: CompletionResponse =
1157 serde_json::from_str(&sample_text).expect("Invalid JSON structure");
1158 let conv: completion::CompletionResponse<CompletionResponse> =
1159 chat_resp.try_into().unwrap();
1160 assert!(
1161 !conv.choice.is_empty(),
1162 "Expected non-empty choice in chat response"
1163 );
1164 }
1165
1166 #[test]
1168 fn test_message_conversion() {
1169 let provider_msg = Message::User {
1171 content: "Test message".to_owned(),
1172 images: None,
1173 name: None,
1174 };
1175 let comp_msg: crate::completion::Message = provider_msg.into();
1177 match comp_msg {
1178 crate::completion::Message::User { content } => {
1179 let first_content = content.first();
1181 match first_content {
1183 crate::completion::message::UserContent::Text(text_struct) => {
1184 assert_eq!(text_struct.text, "Test message");
1185 }
1186 _ => panic!("Expected text content in conversion"),
1187 }
1188 }
1189 _ => panic!("Conversion from provider Message to completion Message failed"),
1190 }
1191 }
1192
1193 #[test]
1195 fn test_tool_definition_conversion() {
1196 let internal_tool = crate::completion::ToolDefinition {
1198 name: "get_current_weather".to_owned(),
1199 description: "Get the current weather for a location".to_owned(),
1200 parameters: json!({
1201 "type": "object",
1202 "properties": {
1203 "location": {
1204 "type": "string",
1205 "description": "The location to get the weather for, e.g. San Francisco, CA"
1206 },
1207 "format": {
1208 "type": "string",
1209 "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
1210 "enum": ["celsius", "fahrenheit"]
1211 }
1212 },
1213 "required": ["location", "format"]
1214 }),
1215 };
1216 let ollama_tool: ToolDefinition = internal_tool.into();
1218 assert_eq!(ollama_tool.type_field, "function");
1219 assert_eq!(ollama_tool.function.name, "get_current_weather");
1220 assert_eq!(
1221 ollama_tool.function.description,
1222 "Get the current weather for a location"
1223 );
1224 let params = &ollama_tool.function.parameters;
1226 assert_eq!(params["properties"]["location"]["type"], "string");
1227 }
1228
1229 #[tokio::test]
1231 async fn test_chat_completion_with_thinking() {
1232 let sample_response = json!({
1233 "model": "qwen-thinking",
1234 "created_at": "2023-08-04T19:22:45.499127Z",
1235 "message": {
1236 "role": "assistant",
1237 "content": "The answer is 42.",
1238 "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
1239 "images": null,
1240 "tool_calls": []
1241 },
1242 "done": true,
1243 "total_duration": 8000000000u64,
1244 "load_duration": 6000000u64,
1245 "prompt_eval_count": 61u64,
1246 "prompt_eval_duration": 400000000u64,
1247 "eval_count": 468u64,
1248 "eval_duration": 7700000000u64
1249 });
1250
1251 let chat_resp: CompletionResponse =
1252 serde_json::from_value(sample_response).expect("Failed to deserialize");
1253
1254 if let Message::Assistant {
1256 thinking, content, ..
1257 } = &chat_resp.message
1258 {
1259 assert_eq!(
1260 thinking.as_ref().unwrap(),
1261 "Let me think about this carefully. The question asks for the meaning of life..."
1262 );
1263 assert_eq!(content, "The answer is 42.");
1264 } else {
1265 panic!("Expected Assistant message");
1266 }
1267 }
1268
1269 #[tokio::test]
1271 async fn test_chat_completion_without_thinking() {
1272 let sample_response = json!({
1273 "model": "llama3.2",
1274 "created_at": "2023-08-04T19:22:45.499127Z",
1275 "message": {
1276 "role": "assistant",
1277 "content": "Hello!",
1278 "images": null,
1279 "tool_calls": []
1280 },
1281 "done": true,
1282 "total_duration": 8000000000u64,
1283 "load_duration": 6000000u64,
1284 "prompt_eval_count": 10u64,
1285 "prompt_eval_duration": 400000000u64,
1286 "eval_count": 5u64,
1287 "eval_duration": 7700000000u64
1288 });
1289
1290 let chat_resp: CompletionResponse =
1291 serde_json::from_value(sample_response).expect("Failed to deserialize");
1292
1293 if let Message::Assistant {
1295 thinking, content, ..
1296 } = &chat_resp.message
1297 {
1298 assert!(thinking.is_none());
1299 assert_eq!(content, "Hello!");
1300 } else {
1301 panic!("Expected Assistant message");
1302 }
1303 }
1304
1305 #[test]
1307 fn test_streaming_response_with_thinking() {
1308 let sample_chunk = json!({
1309 "model": "qwen-thinking",
1310 "created_at": "2023-08-04T19:22:45.499127Z",
1311 "message": {
1312 "role": "assistant",
1313 "content": "",
1314 "thinking": "Analyzing the problem...",
1315 "images": null,
1316 "tool_calls": []
1317 },
1318 "done": false
1319 });
1320
1321 let chunk: CompletionResponse =
1322 serde_json::from_value(sample_chunk).expect("Failed to deserialize");
1323
1324 if let Message::Assistant {
1325 thinking, content, ..
1326 } = &chunk.message
1327 {
1328 assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
1329 assert_eq!(content, "");
1330 } else {
1331 panic!("Expected Assistant message");
1332 }
1333 }
1334
1335 #[test]
1337 fn test_message_conversion_with_thinking() {
1338 let reasoning_content = crate::message::Reasoning {
1340 id: None,
1341 reasoning: vec!["Step 1: Consider the problem".to_string()],
1342 signature: None,
1343 };
1344
1345 let internal_msg = crate::message::Message::Assistant {
1346 id: None,
1347 content: crate::OneOrMany::many(vec![
1348 crate::message::AssistantContent::Reasoning(reasoning_content),
1349 crate::message::AssistantContent::Text(crate::message::Text {
1350 text: "The answer is X".to_string(),
1351 }),
1352 ])
1353 .unwrap(),
1354 };
1355
1356 let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();
1358 assert_eq!(provider_msgs.len(), 1);
1359
1360 if let Message::Assistant {
1361 thinking, content, ..
1362 } = &provider_msgs[0]
1363 {
1364 assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
1365 assert_eq!(content, "The answer is X");
1366 } else {
1367 panic!("Expected Assistant message with thinking");
1368 }
1369 }
1370
1371 #[test]
1373 fn test_empty_thinking_content() {
1374 let sample_response = json!({
1375 "model": "llama3.2",
1376 "created_at": "2023-08-04T19:22:45.499127Z",
1377 "message": {
1378 "role": "assistant",
1379 "content": "Response",
1380 "thinking": "",
1381 "images": null,
1382 "tool_calls": []
1383 },
1384 "done": true,
1385 "total_duration": 8000000000u64,
1386 "load_duration": 6000000u64,
1387 "prompt_eval_count": 10u64,
1388 "prompt_eval_duration": 400000000u64,
1389 "eval_count": 5u64,
1390 "eval_duration": 7700000000u64
1391 });
1392
1393 let chat_resp: CompletionResponse =
1394 serde_json::from_value(sample_response).expect("Failed to deserialize");
1395
1396 if let Message::Assistant {
1397 thinking, content, ..
1398 } = &chat_resp.message
1399 {
1400 assert_eq!(thinking.as_ref().unwrap(), "");
1402 assert_eq!(content, "Response");
1403 } else {
1404 panic!("Expected Assistant message");
1405 }
1406 }
1407
1408 #[test]
1410 fn test_thinking_with_tool_calls() {
1411 let sample_response = json!({
1412 "model": "qwen-thinking",
1413 "created_at": "2023-08-04T19:22:45.499127Z",
1414 "message": {
1415 "role": "assistant",
1416 "content": "Let me check the weather.",
1417 "thinking": "User wants weather info, I should use the weather tool",
1418 "images": null,
1419 "tool_calls": [
1420 {
1421 "type": "function",
1422 "function": {
1423 "name": "get_weather",
1424 "arguments": {
1425 "location": "San Francisco"
1426 }
1427 }
1428 }
1429 ]
1430 },
1431 "done": true,
1432 "total_duration": 8000000000u64,
1433 "load_duration": 6000000u64,
1434 "prompt_eval_count": 30u64,
1435 "prompt_eval_duration": 400000000u64,
1436 "eval_count": 50u64,
1437 "eval_duration": 7700000000u64
1438 });
1439
1440 let chat_resp: CompletionResponse =
1441 serde_json::from_value(sample_response).expect("Failed to deserialize");
1442
1443 if let Message::Assistant {
1444 thinking,
1445 content,
1446 tool_calls,
1447 ..
1448 } = &chat_resp.message
1449 {
1450 assert_eq!(
1451 thinking.as_ref().unwrap(),
1452 "User wants weather info, I should use the weather tool"
1453 );
1454 assert_eq!(content, "Let me check the weather.");
1455 assert_eq!(tool_calls.len(), 1);
1456 assert_eq!(tool_calls[0].function.name, "get_weather");
1457 } else {
1458 panic!("Expected Assistant message with thinking and tool calls");
1459 }
1460 }
1461}