use crate::client::{
    self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
};
use crate::completion::{GetTokenUsage, Usage};
use crate::http_client::{self, HttpClientExt};
use crate::message::DocumentSourceKind;
use crate::streaming::RawStreamingChoice;
use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils, message,
    message::{ImageDetail, Text},
    streaming,
};
use async_stream::try_stream;
use bytes::Bytes;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::{convert::TryFrom, str::FromStr};
use tracing::info_span;
use tracing_futures::Instrument;

const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";

#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;

#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaBuilder;

impl Provider for OllamaExt {
    type Builder = OllamaBuilder;

    const VERIFY_PATH: &'static str = "api/tags";

    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}

impl<H> Capabilities<H> for OllamaExt {
    type Completion = Capable<CompletionModel<H>>;
    type Transcription = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for OllamaExt {}

impl ProviderBuilder for OllamaBuilder {
    type Output = OllamaExt;
    type ApiKey = Nothing;

    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;
}

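/// Ollama [`Client`] over the default HTTP backend.
///
/// A minimal construction sketch (assumes a local Ollama daemon on the
/// default port, `http://localhost:11434`; no API key is needed, so the key
/// slot is filled with [`Nothing`]):
///
/// ```ignore
/// let client = Client::builder().api_key(Nothing).build().unwrap();
///
/// // Or point at a non-default server:
/// let client = Client::builder()
///     .api_key(Nothing)
///     .base_url("http://localhost:11434")
///     .build()
///     .unwrap();
/// ```
///
/// [`ProviderClient::from_env`] instead reads the base URL from the
/// `OLLAMA_API_BASE_URL` environment variable and panics when it is unset.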
pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, Nothing, H>;

impl ProviderClient for Client {
    type Input = Nothing;

    fn from_env() -> Self {
        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");

        Self::builder()
            .api_key(Nothing)
            .base_url(&api_base)
            .build()
            .unwrap()
    }

    fn from_val(_: Self::Input) -> Self {
        Self::builder().api_key(Nothing).build().unwrap()
    }
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

pub const ALL_MINILM: &str = "all-minilm";
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    pub embeddings: Vec<Vec<f64>>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

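/// An Ollama embedding model bound to a client and a fixed vector size.
///
/// A usage sketch (assumes an embedding model such as `nomic-embed-text` has
/// been pulled locally; the dimension count is supplied by the caller and is
/// not verified against the model):
///
/// ```ignore
/// let model = EmbeddingModel::new(client, NOMIC_EMBED_TEXT, 768);
/// let embeddings = model
///     .embed_texts(vec!["Why is the sky blue?".to_string()])
///     .await?;
/// ```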
#[derive(Clone)]
pub struct EmbeddingModel<T> {
    client: Client<T>,
    pub model: String,
    ndims: usize,
}

impl<T> EmbeddingModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }

    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        Self::new(
            client.clone(),
            model,
            dims.expect("Ollama embedding models require an explicit dimension count"),
        )
    }

    const MAX_DOCUMENTS: usize = 1024;

    fn ndims(&self) -> usize {
        self.ndims
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let docs: Vec<String> = documents.into_iter().collect();

        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": docs
        }))?;

        let req = self
            .client
            .post("api/embed")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if !response.status().is_success() {
            let text = http_client::text(response).await?;
            return Err(EmbeddingError::ProviderError(text));
        }

        let bytes: Vec<u8> = response.into_body().await?;

        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;

        if api_resp.embeddings.len() != docs.len() {
            return Err(EmbeddingError::ResponseError(
                "Number of returned embeddings does not match input".into(),
            ));
        }

        Ok(api_resp
            .embeddings
            .into_iter()
            .zip(docs.into_iter())
            .map(|(vec, document)| embeddings::Embedding { document, vec })
            .collect())
    }
}

pub const LLAMA3_2: &str = "llama3.2";
pub const LLAVA: &str = "llava";
pub const MISTRAL: &str = "mistral";

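/// A single response object from Ollama's `/api/chat` endpoint.
///
/// The wire shape this mirrors, abridged (the tests at the bottom of this
/// module contain full samples):
///
/// ```json
/// {
///   "model": "llama3.2",
///   "created_at": "2023-08-04T19:22:45.499127Z",
///   "message": { "role": "assistant", "content": "...", "tool_calls": [] },
///   "done": true,
///   "prompt_eval_count": 61,
///   "eval_count": 468
/// }
/// ```
///
/// In streaming mode the same shape arrives as NDJSON chunks with
/// `done: false` until a final chunk reports the usage statistics.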
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();

                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }

                // Ollama does not return tool-call ids, so the function name
                // doubles as the id.
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }

                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;

                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                    },
                    raw_response,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}

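/// The request body sent to Ollama's `/api/chat` endpoint.
///
/// Most fields map directly from [`CompletionRequest`]; the exceptions are
/// `think`, which is lifted out of `additional_params` into a top-level
/// field, and the remaining `additional_params`, which are merged into
/// `options` alongside the temperature. A sketch of opting into thinking
/// mode (assumes a model that supports it, and a `request` built elsewhere):
///
/// ```ignore
/// request.additional_params = Some(serde_json::json!({ "think": true }));
/// let body = OllamaCompletionRequest::try_from(("qwen3", request))?;
/// assert!(body.think);
/// ```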
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OllamaCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    pub stream: bool,
    think: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    options: serde_json::Value,
}

impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.tool_choice.is_some() {
            tracing::warn!("`tool_choice` is not supported by Ollama and will be ignored");
        }

        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let mut think = false;

        let options = if let Some(mut extra) = req.additional_params {
            // `think` is a top-level chat field rather than a model option, so
            // remove it from the extra parameters before merging them into `options`.
            if let Some(value) = extra.as_object_mut().and_then(|params| params.remove("think")) {
                think = value.as_bool().ok_or_else(|| {
                    CompletionError::RequestError("`think` must be a bool".into())
                })?;
            }
            json_utils::merge(json!({ "temperature": req.temperature }), extra)
        } else {
            json!({ "temperature": req.temperature })
        };

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            max_tokens: req.max_tokens,
            stream: false,
            think,
            tools: req
                .tools
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            options,
        })
    }
}

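/// An Ollama chat-completion model bound to a client.
///
/// A usage sketch (assumes `llama3.2` has been pulled locally and that a
/// [`CompletionRequest`] was built elsewhere):
///
/// ```ignore
/// // With the `completion::CompletionModel` trait in scope:
/// let model = CompletionModel::new(client, LLAMA3_2);
/// let response = model.completion(request).await?;
/// ```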
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_owned(),
        }
    }
}

#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        let input_tokens = self.prompt_eval_count.unwrap_or_default();
        let output_tokens = self.eval_count.unwrap_or_default();
        usage.input_tokens = input_tokens;
        usage.output_tokens = output_tokens;
        usage.total_tokens = input_tokens + output_tokens;

        Some(usage)
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), &model.into())
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "ollama",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request = OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Ollama completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("api/chat")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let async_block = async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ));
            }

            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
            let span = tracing::Span::current();
            // Record against the `gen_ai.response.model` field declared on the span above.
            span.record("gen_ai.response.model", &response.model);
            span.record(
                "gen_ai.usage.input_tokens",
                response.prompt_eval_count.unwrap_or_default(),
            );
            span.record(
                "gen_ai.usage.output_tokens",
                response.eval_count.unwrap_or_default(),
            );

            if tracing::enabled!(tracing::Level::TRACE) {
                tracing::trace!(target: "rig::completions",
                    "Ollama completion response: {}",
                    serde_json::to_string_pretty(&response)?
                );
            }

            let response: completion::CompletionResponse<CompletionResponse> =
                response.try_into()?;

            Ok(response)
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
    {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "ollama",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &request.preamble);

        let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
        request.stream = true;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Ollama streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("api/chat")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let response = self.client.send_streaming(req).await?;
        let status = response.status();
        let mut byte_stream = response.into_body();

        if !status.is_success() {
            return Err(CompletionError::ProviderError(format!(
                "Got error status code trying to send a request to Ollama: {status}"
            )));
        }

        let stream = try_stream! {
            let span = tracing::Span::current();
            let mut tool_calls_final = Vec::new();
            let mut text_response = String::new();
            let mut thinking_response = String::new();
            // Buffer for NDJSON data, since a JSON object may be split across
            // chunk boundaries.
            let mut pending: Vec<u8> = Vec::new();

            while let Some(chunk) = byte_stream.next().await {
                let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;
                pending.extend_from_slice(&bytes);

                // Parse only complete (newline-terminated) lines; the
                // remainder stays buffered until the next chunk arrives.
                while let Some(newline) = pending.iter().position(|&b| b == b'\n') {
                    let line: Vec<u8> = pending.drain(..=newline).collect();
                    let line = &line[..line.len() - 1];
                    if line.is_empty() {
                        continue;
                    }

                    tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));

                    let response: CompletionResponse = serde_json::from_slice(line)?;

                    if response.done {
                        span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
                        span.record("gen_ai.usage.output_tokens", response.eval_count);
                        let message = Message::Assistant {
                            content: text_response.clone(),
                            thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
                            images: None,
                            name: None,
                            tool_calls: tool_calls_final.clone()
                        };
                        span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
                        yield RawStreamingChoice::FinalResponse(
                            StreamingCompletionResponse {
                                total_duration: response.total_duration,
                                load_duration: response.load_duration,
                                prompt_eval_count: response.prompt_eval_count,
                                prompt_eval_duration: response.prompt_eval_duration,
                                eval_count: response.eval_count,
                                eval_duration: response.eval_duration,
                                done_reason: response.done_reason,
                            }
                        );
                        break;
                    }

                    if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
                        if let Some(thinking_content) = thinking
                            && !thinking_content.is_empty() {
                            thinking_response += &thinking_content;
                            yield RawStreamingChoice::Reasoning {
                                reasoning: thinking_content,
                                id: None,
                                signature: None,
                            };
                        }

                        if !content.is_empty() {
                            text_response += &content;
                            yield RawStreamingChoice::Message(content);
                        }

                        // Ollama streams tool calls without ids, so an empty id is used.
                        for tool_call in tool_calls {
                            tool_calls_final.push(tool_call.clone());
                            yield RawStreamingChoice::ToolCall {
                                id: String::new(),
                                name: tool_call.function.name,
                                arguments: tool_call.function.arguments,
                                call_id: None,
                            };
                        }
                    }
                }
            }
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}

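/// A tool definition in the envelope Ollama expects: the generic
/// [`completion::ToolDefinition`] nested under a `function` key with an
/// explicit `type` tag, e.g.:
///
/// ```json
/// {
///   "type": "function",
///   "function": { "name": "...", "description": "...", "parameters": { ... } }
/// }
/// ```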
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    #[serde(rename = "type")]
    pub type_field: String,
    pub function: completion::ToolDefinition,
}

impl From<crate::completion::ToolDefinition> for ToolDefinition {
    fn from(tool: crate::completion::ToolDefinition) -> Self {
        ToolDefinition {
            type_field: "function".to_owned(),
            function: completion::ToolDefinition {
                name: tool.name,
                description: tool.description,
                parameters: tool.parameters,
            },
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}

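/// A chat message in Ollama's wire format, tagged by `role`.
///
/// Roles serialize as `"user"`, `"assistant"`, `"system"`, and `"tool"`.
/// Assistant messages may additionally carry `thinking` text and
/// `tool_calls`; tool results are flattened to a `tool_name` plus a plain
/// string `content`.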
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        #[serde(default)]
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}

impl TryFrom<crate::message::Message> for Vec<Message> {
    type Error = crate::message::MessageError;

    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;

        match internal_msg {
            InternalMessage::User { content, .. } => {
                let (tool_results, other_content): (Vec<_>, Vec<_>) =
                    content.into_iter().partition(|content| {
                        matches!(content, crate::message::UserContent::ToolResult(_))
                    });

                if !tool_results.is_empty() {
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            crate::message::UserContent::ToolResult(
                                crate::message::ToolResult { id, content, .. },
                            ) => {
                                let content_string = content
                                    .into_iter()
                                    .map(|content| match content {
                                        crate::message::ToolResultContent::Text(text) => text.text,
                                        _ => "[Non-text content]".to_string(),
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n");

                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
                                    name: id,
                                    content: content_string,
                                })
                            }
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    let (texts, images) = other_content.into_iter().fold(
                        (Vec::new(), Vec::new()),
                        |(mut texts, mut images), content| {
                            match content {
                                crate::message::UserContent::Text(crate::message::Text {
                                    text,
                                }) => texts.push(text),
                                crate::message::UserContent::Image(crate::message::Image {
                                    data: DocumentSourceKind::Base64(data),
                                    ..
                                }) => images.push(data),
                                crate::message::UserContent::Document(
                                    crate::message::Document {
                                        data:
                                            DocumentSourceKind::Base64(data)
                                            | DocumentSourceKind::String(data),
                                        ..
                                    },
                                ) => texts.push(data),
                                _ => {}
                            }
                            (texts, images)
                        },
                    );

                    Ok(vec![Message::User {
                        content: texts.join(" "),
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(
                                images
                                    .into_iter()
                                    .map(|x| x.to_string())
                                    .collect::<Vec<String>>(),
                            )
                        },
                        name: None,
                    }])
                }
            }
            InternalMessage::Assistant { content, .. } => {
                let mut thinking: Option<String> = None;
                let mut text_content = Vec::new();
                let mut tool_calls = Vec::new();

                for content in content.into_iter() {
                    match content {
                        crate::message::AssistantContent::Text(text) => {
                            text_content.push(text.text)
                        }
                        crate::message::AssistantContent::ToolCall(tool_call) => {
                            tool_calls.push(tool_call)
                        }
                        crate::message::AssistantContent::Reasoning(
                            crate::message::Reasoning { reasoning, .. },
                        ) => {
                            // Ollama's `thinking` field is a single string, so only
                            // the first reasoning segment is carried over.
                            thinking = Some(reasoning.first().cloned().unwrap_or_default());
                        }
                        crate::message::AssistantContent::Image(_) => {
                            return Err(crate::message::MessageError::ConversionError(
                                "Ollama currently doesn't support images.".into(),
                            ));
                        }
                    }
                }

                Ok(vec![Message::Assistant {
                    content: text_content.join(" "),
                    thinking,
                    images: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}

impl From<Message> for crate::completion::Message {
    fn from(msg: Message) -> Self {
        match msg {
            Message::User { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut assistant_contents =
                    vec![crate::completion::message::AssistantContent::Text(Text {
                        text: content,
                    })];

                for tc in tool_calls {
                    assistant_contents.push(
                        crate::completion::message::AssistantContent::tool_call(
                            tc.function.name.clone(),
                            tc.function.name,
                            tc.function.arguments,
                        ),
                    );
                }

                crate::completion::Message::Assistant {
                    id: None,
                    content: OneOrMany::many(assistant_contents).unwrap(),
                }
            }
            // The generic message type has no system variant, so system
            // messages are mapped to user messages.
            Message::System { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::ToolResult { name, content } => crate::completion::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    name,
                    OneOrMany::one(message::ToolResultContent::text(content)),
                )),
            },
        }
    }
}

impl Message {
    pub fn system(content: &str) -> Self {
        Message::System {
            content: content.to_owned(),
            images: None,
            name: None,
        }
    }
}

impl From<crate::message::ToolCall> for ToolCall {
    fn from(tool_call: crate::message::ToolCall) -> Self {
        Self {
            r#type: ToolType::Function,
            function: Function {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}

impl From<String> for SystemContent {
    fn from(s: String) -> Self {
        SystemContent {
            r#type: SystemContentType::default(),
            text: s,
        }
    }
}

impl FromStr for SystemContent {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(SystemContent {
            r#type: SystemContentType::default(),
            text: s.to_string(),
        })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}

impl FromStr for AssistantContent {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(AssistantContent { text: s.to_owned() })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
}

impl FromStr for UserContent {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(UserContent::Text { text: s.to_owned() })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_chat_completion() {
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    #[test]
    fn test_message_conversion() {
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        let comp_msg: crate::completion::Message = provider_msg.into();

        match comp_msg {
            crate::completion::Message::User { content } => {
                let first_content = content.first();
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    #[test]
    fn test_tool_definition_conversion() {
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        let ollama_tool: ToolDefinition = internal_tool.into();

        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );

        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }

    #[tokio::test]
    async fn test_chat_completion_with_thinking() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The answer is 42.",
                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "Let me think about this carefully. The question asks for the meaning of life..."
            );
            assert_eq!(content, "The answer is 42.");
        } else {
            panic!("Expected Assistant message");
        }
    }

    #[tokio::test]
    async fn test_chat_completion_without_thinking() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Hello!",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert!(thinking.is_none());
            assert_eq!(content, "Hello!");
        } else {
            panic!("Expected Assistant message");
        }
    }

    #[test]
    fn test_streaming_response_with_thinking() {
        let sample_chunk = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "",
                "thinking": "Analyzing the problem...",
                "images": null,
                "tool_calls": []
            },
            "done": false
        });

        let chunk: CompletionResponse =
            serde_json::from_value(sample_chunk).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chunk.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
            assert_eq!(content, "");
        } else {
            panic!("Expected Assistant message");
        }
    }

    #[test]
    fn test_message_conversion_with_thinking() {
        let reasoning_content = crate::message::Reasoning {
            id: None,
            reasoning: vec!["Step 1: Consider the problem".to_string()],
            signature: None,
        };

        let internal_msg = crate::message::Message::Assistant {
            id: None,
            content: crate::OneOrMany::many(vec![
                crate::message::AssistantContent::Reasoning(reasoning_content),
                crate::message::AssistantContent::Text(crate::message::Text {
                    text: "The answer is X".to_string(),
                }),
            ])
            .unwrap(),
        };

        let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();

        assert_eq!(provider_msgs.len(), 1);

        if let Message::Assistant {
            thinking, content, ..
        } = &provider_msgs[0]
        {
            assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
            assert_eq!(content, "The answer is X");
        } else {
            panic!("Expected Assistant message with thinking");
        }
    }

    #[test]
    fn test_empty_thinking_content() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Response",
                "thinking": "",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "");
            assert_eq!(content, "Response");
        } else {
            panic!("Expected Assistant message");
        }
    }

    #[test]
    fn test_thinking_with_tool_calls() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Let me check the weather.",
                "thinking": "User wants weather info, I should use the weather tool",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {
                                "location": "San Francisco"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 30u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 50u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking,
            content,
            tool_calls,
            ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "User wants weather info, I should use the weather tool"
            );
            assert_eq!(content, "Let me check the weather.");
            assert_eq!(tool_calls.len(), 1);
            assert_eq!(tool_calls[0].function.name, "get_weather");
        } else {
            panic!("Expected Assistant message with thinking and tool calls");
        }
    }
}