use crate::client::{
    self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
};
use crate::completion::{GetTokenUsage, Usage};
use crate::http_client::{self, HttpClientExt};
use crate::message::DocumentSourceKind;
use crate::streaming::RawStreamingChoice;
use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils, message,
    message::{ImageDetail, Text},
    streaming,
};
use async_stream::try_stream;
use bytes::Bytes;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::{convert::TryFrom, str::FromStr};
use tracing::info_span;
use tracing_futures::Instrument;

const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";

#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;

#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaBuilder;

impl Provider for OllamaExt {
    type Builder = OllamaBuilder;

    const VERIFY_PATH: &'static str = "api/tags";

    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}

impl<H> Capabilities<H> for OllamaExt {
    type Completion = Capable<CompletionModel<H>>;
    type Transcription = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for OllamaExt {}

impl ProviderBuilder for OllamaBuilder {
    type Output = OllamaExt;
    type ApiKey = Nothing;

    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;
}

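/// Ollama client over the default `reqwest` HTTP backend.
///
/// A minimal construction sketch, mirroring `from_env` below (Ollama takes no
/// API key, so `Nothing` is passed; the base URL defaults to
/// `http://localhost:11434` unless overridden):
///
/// ```ignore
/// let client: Client = Client::builder()
///     .api_key(Nothing)
///     .base_url("http://localhost:11434")
///     .build()
///     .unwrap();
/// ```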
pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, Nothing, H>;

impl ProviderClient for Client {
    type Input = Nothing;

    fn from_env() -> Self {
        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");

        Self::builder()
            .api_key(Nothing)
            .base_url(&api_base)
            .build()
            .unwrap()
    }

    fn from_val(_: Self::Input) -> Self {
        Self::builder().api_key(Nothing).build().unwrap()
    }
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

pub const ALL_MINILM: &str = "all-minilm";
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

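/// Response body returned by Ollama's `/api/embed` endpoint.
///
/// A representative payload (a sketch inferred from the fields below; the
/// counter and duration fields are optional, and durations, when present,
/// are reported in nanoseconds):
///
/// ```json
/// {
///   "model": "all-minilm",
///   "embeddings": [[0.1, -0.2, 0.3]],
///   "prompt_eval_count": 4
/// }
/// ```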
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    pub embeddings: Vec<Vec<f64>>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

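/// An Ollama embedding model bound to a client.
///
/// A minimal usage sketch (assumes a local Ollama server with `all-minilm`
/// pulled; 384 dimensions is typical for that model, but verify against your
/// installation):
///
/// ```ignore
/// let model = EmbeddingModel::new(client, ALL_MINILM, 384);
/// let embeddings = model.embed_texts(vec!["hello".to_string()]).await?;
/// ```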
#[derive(Clone)]
pub struct EmbeddingModel<T> {
    client: Client<T>,
    pub model: String,
    ndims: usize,
}

impl<T> EmbeddingModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }

    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        Self::new(
            client.clone(),
            model,
            dims.expect("an explicit embedding dimension count is required for Ollama"),
        )
    }

    const MAX_DOCUMENTS: usize = 1024;

    fn ndims(&self) -> usize {
        self.ndims
    }

    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let docs: Vec<String> = documents.into_iter().collect();

        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": docs
        }))?;

        let req = self
            .client
            .post("api/embed")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if !response.status().is_success() {
            let text = http_client::text(response).await?;
            return Err(EmbeddingError::ProviderError(text));
        }

        let bytes: Vec<u8> = response.into_body().await?;
        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;

        if api_resp.embeddings.len() != docs.len() {
            return Err(EmbeddingError::ResponseError(
                "Number of returned embeddings does not match input".into(),
            ));
        }

        Ok(api_resp
            .embeddings
            .into_iter()
            .zip(docs)
            .map(|(vec, document)| embeddings::Embedding { document, vec })
            .collect())
    }
}

pub const LLAMA3_2: &str = "llama3.2";
pub const LLAVA: &str = "llava";
pub const MISTRAL: &str = "mistral";

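/// Non-streaming response body from Ollama's `/api/chat` endpoint.
///
/// A trimmed example payload (shape taken from the tests at the bottom of
/// this module; the optional counters and durations may be omitted by the
/// server):
///
/// ```json
/// {
///   "model": "llama3.2",
///   "created_at": "2023-08-04T19:22:45.499127Z",
///   "message": { "role": "assistant", "content": "Hi!" },
///   "done": true,
///   "prompt_eval_count": 10,
///   "eval_count": 5
/// }
/// ```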
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();

                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }

                // Ollama does not return tool-call ids, so the function name
                // doubles as the id here.
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }

                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;

                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                    },
                    raw_response,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}

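/// Request body sent to Ollama's `/api/chat` endpoint.
///
/// Extended thinking is not a field on rig's generic request; as handled in
/// the `TryFrom` impl below, it is passed through `additional_params`. A
/// call-site sketch:
///
/// ```ignore
/// request.additional_params = Some(serde_json::json!({ "think": true }));
/// ```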
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OllamaCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    pub stream: bool,
    think: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    options: serde_json::Value,
}

impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.tool_choice.is_some() {
            tracing::warn!("`tool_choice` is not supported by Ollama and will be ignored");
        }

        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten(),
        );

        let mut think = false;

        let options = if let Some(mut extra) = req.additional_params {
            // `think` is not an Ollama option; pull it out of the extra
            // parameters before merging them into `options`.
            if let Some(value) = extra.get_mut("think") {
                think = value.take().as_bool().ok_or_else(|| {
                    CompletionError::RequestError("`think` must be a bool".into())
                })?;
            }
            json_utils::merge(json!({ "temperature": req.temperature }), extra)
        } else {
            json!({ "temperature": req.temperature })
        };

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            max_tokens: req.max_tokens,
            stream: false,
            think,
            tools: req
                .tools
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            options,
        })
    }
}

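/// An Ollama chat-completion model bound to a client.
///
/// A minimal construction sketch (assumes `llama3.2` has been pulled
/// locally):
///
/// ```ignore
/// let model = CompletionModel::new(client, LLAMA3_2);
/// ```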
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_owned(),
        }
    }
}

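/// Aggregate metadata carried by the final (`"done": true`) chunk of an
/// Ollama NDJSON stream; earlier chunks only carry message deltas. Token
/// usage is derived from `prompt_eval_count` and `eval_count` via the
/// `GetTokenUsage` impl below (see the usage test at the end of this module).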
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        let input_tokens = self.prompt_eval_count.unwrap_or_default();
        let output_tokens = self.eval_count.unwrap_or_default();
        usage.input_tokens = input_tokens;
        usage.output_tokens = output_tokens;
        usage.total_tokens = input_tokens + output_tokens;

        Some(usage)
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into().as_str())
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "ollama",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request =
            OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Ollama completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("api/chat")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let async_block = async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ));
            }

            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
            let span = tracing::Span::current();
            span.record("gen_ai.response.model", &response.model);
            span.record(
                "gen_ai.usage.input_tokens",
                response.prompt_eval_count.unwrap_or_default(),
            );
            span.record(
                "gen_ai.usage.output_tokens",
                response.eval_count.unwrap_or_default(),
            );

            if tracing::enabled!(tracing::Level::TRACE) {
                tracing::trace!(target: "rig::completions",
                    "Ollama completion response: {}",
                    serde_json::to_string_pretty(&response)?
                );
            }

            let response: completion::CompletionResponse<CompletionResponse> =
                response.try_into()?;

            Ok(response)
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
    {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "ollama",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.output.messages = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &request.preamble);

        let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
        request.stream = true;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Ollama streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("api/chat")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let response = self.client.send_streaming(req).await?;
        let status = response.status();
        let mut byte_stream = response.into_body();

        if !status.is_success() {
            return Err(CompletionError::ProviderError(format!(
                "Got error status code trying to send a request to Ollama: {status}"
            )));
        }

        let stream = try_stream! {
            let span = tracing::Span::current();
            let mut tool_calls_final = Vec::new();
            let mut text_response = String::new();
            let mut thinking_response = String::new();

            while let Some(chunk) = byte_stream.next().await {
                let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;

                // Ollama streams NDJSON: one complete JSON object per line.
                for line in bytes.split(|&b| b == b'\n') {
                    if line.is_empty() {
                        continue;
                    }

                    tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));

                    let response: CompletionResponse = serde_json::from_slice(line)?;

                    if response.done {
                        span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
                        span.record("gen_ai.usage.output_tokens", response.eval_count);

                        let message = Message::Assistant {
                            content: text_response.clone(),
                            thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
                            images: None,
                            name: None,
                            tool_calls: tool_calls_final.clone(),
                        };
                        span.record(
                            "gen_ai.output.messages",
                            serde_json::to_string(&vec![message]).unwrap(),
                        );

                        yield RawStreamingChoice::FinalResponse(
                            StreamingCompletionResponse {
                                total_duration: response.total_duration,
                                load_duration: response.load_duration,
                                prompt_eval_count: response.prompt_eval_count,
                                prompt_eval_duration: response.prompt_eval_duration,
                                eval_count: response.eval_count,
                                eval_duration: response.eval_duration,
                                done_reason: response.done_reason,
                            }
                        );
                        break;
                    }

                    if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
                        if let Some(thinking_content) = thinking
                            && !thinking_content.is_empty() {
                            thinking_response += &thinking_content;
                            yield RawStreamingChoice::ReasoningDelta {
                                id: None,
                                reasoning: thinking_content,
                            };
                        }

                        if !content.is_empty() {
                            text_response += &content;
                            yield RawStreamingChoice::Message(content);
                        }

                        for tool_call in tool_calls {
                            tool_calls_final.push(tool_call.clone());
                            // Ollama does not supply tool-call ids in stream
                            // chunks, so an empty id is passed through.
                            yield RawStreamingChoice::ToolCall(
                                crate::streaming::RawStreamingToolCall::new(String::new(), tool_call.function.name, tool_call.function.arguments)
                            );
                        }
                    }
                }
            }
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}

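/// Tool definition in Ollama's wire format.
///
/// Serializes as the usual function-tool envelope (a sketch; the `function`
/// body is rig's generic tool schema):
///
/// ```json
/// {
///   "type": "function",
///   "function": { "name": "get_weather", "description": "...", "parameters": {} }
/// }
/// ```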
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    #[serde(rename = "type")]
    pub type_field: String,
    pub function: completion::ToolDefinition,
}

impl From<crate::completion::ToolDefinition> for ToolDefinition {
    fn from(tool: crate::completion::ToolDefinition) -> Self {
        ToolDefinition {
            type_field: "function".to_owned(),
            function: completion::ToolDefinition {
                name: tool.name,
                description: tool.description,
                parameters: tool.parameters,
            },
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}

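/// Chat message in Ollama's wire format, tagged by `role`.
///
/// Example shapes consistent with the serde attributes below (a sketch, not
/// an exhaustive list):
///
/// ```json
/// { "role": "user", "content": "Hello" }
/// { "role": "assistant", "content": "Hi!", "tool_calls": [] }
/// { "role": "tool", "tool_name": "get_weather", "content": "{\"temp\": 20}" }
/// ```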
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        #[serde(default)]
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}

impl TryFrom<crate::message::Message> for Vec<Message> {
    type Error = crate::message::MessageError;

    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;

        match internal_msg {
            InternalMessage::User { content, .. } => {
                let (tool_results, other_content): (Vec<_>, Vec<_>) =
                    content.into_iter().partition(|content| {
                        matches!(content, crate::message::UserContent::ToolResult(_))
                    });

                if !tool_results.is_empty() {
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            crate::message::UserContent::ToolResult(
                                crate::message::ToolResult { id, content, .. },
                            ) => {
                                let content_string = content
                                    .into_iter()
                                    .map(|content| match content {
                                        crate::message::ToolResultContent::Text(text) => text.text,
                                        _ => "[Non-text content]".to_string(),
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n");

                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
                                    name: id,
                                    content: content_string,
                                })
                            }
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    let (texts, images) = other_content.into_iter().fold(
                        (Vec::new(), Vec::new()),
                        |(mut texts, mut images), content| {
                            match content {
                                crate::message::UserContent::Text(crate::message::Text {
                                    text,
                                }) => texts.push(text),
                                crate::message::UserContent::Image(crate::message::Image {
                                    data: DocumentSourceKind::Base64(data),
                                    ..
                                }) => images.push(data),
                                crate::message::UserContent::Document(
                                    crate::message::Document {
                                        data:
                                            DocumentSourceKind::Base64(data)
                                            | DocumentSourceKind::String(data),
                                        ..
                                    },
                                ) => texts.push(data),
                                _ => {}
                            }
                            (texts, images)
                        },
                    );

                    Ok(vec![Message::User {
                        content: texts.join(" "),
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(
                                images
                                    .into_iter()
                                    .map(|x| x.to_string())
                                    .collect::<Vec<String>>(),
                            )
                        },
                        name: None,
                    }])
                }
            }
            InternalMessage::Assistant { content, .. } => {
                let mut thinking: Option<String> = None;
                let mut text_content = Vec::new();
                let mut tool_calls = Vec::new();

                for content in content.into_iter() {
                    match content {
                        crate::message::AssistantContent::Text(text) => {
                            text_content.push(text.text)
                        }
                        crate::message::AssistantContent::ToolCall(tool_call) => {
                            tool_calls.push(tool_call)
                        }
                        crate::message::AssistantContent::Reasoning(
                            crate::message::Reasoning { reasoning, .. },
                        ) => {
                            thinking = Some(reasoning.first().cloned().unwrap_or_default());
                        }
                        crate::message::AssistantContent::Image(_) => {
                            return Err(crate::message::MessageError::ConversionError(
                                "Ollama currently doesn't support image content in assistant messages.".into(),
                            ));
                        }
                    }
                }

                Ok(vec![Message::Assistant {
                    content: text_content.join(" "),
                    thinking,
                    images: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}

impl From<Message> for crate::completion::Message {
    fn from(msg: Message) -> Self {
        match msg {
            Message::User { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut assistant_contents =
                    vec![crate::completion::message::AssistantContent::Text(Text {
                        text: content,
                    })];

                for tc in tool_calls {
                    assistant_contents.push(
                        crate::completion::message::AssistantContent::tool_call(
                            tc.function.name.clone(),
                            tc.function.name,
                            tc.function.arguments,
                        ),
                    );
                }

                crate::completion::Message::Assistant {
                    id: None,
                    content: OneOrMany::many(assistant_contents).unwrap(),
                }
            }
            // rig's completion `Message` has no system variant, so system
            // messages are folded into the user role.
            Message::System { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::ToolResult { name, content } => crate::completion::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    name,
                    OneOrMany::one(message::ToolResultContent::text(content)),
                )),
            },
        }
    }
}

impl Message {
    pub fn system(content: &str) -> Self {
        Message::System {
            content: content.to_owned(),
            images: None,
            name: None,
        }
    }
}

impl From<crate::message::ToolCall> for ToolCall {
    fn from(tool_call: crate::message::ToolCall) -> Self {
        Self {
            r#type: ToolType::Function,
            function: Function {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}

impl From<String> for SystemContent {
    fn from(s: String) -> Self {
        SystemContent {
            r#type: SystemContentType::default(),
            text: s,
        }
    }
}

impl FromStr for SystemContent {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(SystemContent {
            r#type: SystemContentType::default(),
            text: s.to_string(),
        })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}

impl FromStr for AssistantContent {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(AssistantContent { text: s.to_owned() })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
}

impl FromStr for UserContent {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(UserContent::Text { text: s.to_owned() })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_chat_completion() {
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();

        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    #[test]
    fn test_message_conversion() {
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        let comp_msg: crate::completion::Message = provider_msg.into();

        match comp_msg {
            crate::completion::Message::User { content } => {
                let first_content = content.first();
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    #[test]
    fn test_tool_definition_conversion() {
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        let ollama_tool: ToolDefinition = internal_tool.into();

        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );

        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }

    #[tokio::test]
    async fn test_chat_completion_with_thinking() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The answer is 42.",
                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "Let me think about this carefully. The question asks for the meaning of life..."
            );
            assert_eq!(content, "The answer is 42.");
        } else {
            panic!("Expected Assistant message");
        }
    }

    #[tokio::test]
    async fn test_chat_completion_without_thinking() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Hello!",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert!(thinking.is_none());
            assert_eq!(content, "Hello!");
        } else {
            panic!("Expected Assistant message");
        }
    }

    #[test]
    fn test_streaming_response_with_thinking() {
        let sample_chunk = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "",
                "thinking": "Analyzing the problem...",
                "images": null,
                "tool_calls": []
            },
            "done": false
        });

        let chunk: CompletionResponse =
            serde_json::from_value(sample_chunk).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chunk.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
            assert_eq!(content, "");
        } else {
            panic!("Expected Assistant message");
        }
    }

    #[test]
    fn test_message_conversion_with_thinking() {
        let reasoning_content = crate::message::Reasoning {
            id: None,
            reasoning: vec!["Step 1: Consider the problem".to_string()],
            signature: None,
        };

        let internal_msg = crate::message::Message::Assistant {
            id: None,
            content: crate::OneOrMany::many(vec![
                crate::message::AssistantContent::Reasoning(reasoning_content),
                crate::message::AssistantContent::Text(crate::message::Text {
                    text: "The answer is X".to_string(),
                }),
            ])
            .unwrap(),
        };

        let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();
        assert_eq!(provider_msgs.len(), 1);

        if let Message::Assistant {
            thinking, content, ..
        } = &provider_msgs[0]
        {
            assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
            assert_eq!(content, "The answer is X");
        } else {
            panic!("Expected Assistant message with thinking");
        }
    }

    #[test]
    fn test_empty_thinking_content() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Response",
                "thinking": "",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "");
            assert_eq!(content, "Response");
        } else {
            panic!("Expected Assistant message");
        }
    }

    #[test]
    fn test_thinking_with_tool_calls() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Let me check the weather.",
                "thinking": "User wants weather info, I should use the weather tool",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {
                                "location": "San Francisco"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 30u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 50u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking,
            content,
            tool_calls,
            ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "User wants weather info, I should use the weather tool"
            );
            assert_eq!(content, "Let me check the weather.");
            assert_eq!(tool_calls.len(), 1);
            assert_eq!(tool_calls[0].function.name, "get_weather");
        } else {
            panic!("Expected Assistant message with thinking and tool calls");
        }
    }
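
    // A minimal sketch-level check of the `GetTokenUsage` arithmetic: the
    // prompt and eval counts from a final stream chunk should sum into
    // `total_tokens`.
    #[test]
    fn test_streaming_token_usage() {
        let final_chunk = StreamingCompletionResponse {
            done_reason: Some("stop".to_owned()),
            total_duration: None,
            load_duration: None,
            prompt_eval_count: Some(10),
            prompt_eval_duration: None,
            eval_count: Some(5),
            eval_duration: None,
        };

        let usage = final_chunk.token_usage().expect("usage should be present");
        assert_eq!(usage.input_tokens, 10);
        assert_eq!(usage.output_tokens, 5);
        assert_eq!(usage.total_tokens, 15);
    }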
}