1use crate::client::{
42 self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
43};
44use crate::completion::{GetTokenUsage, Usage};
45use crate::http_client::{self, HttpClientExt};
46use crate::message::DocumentSourceKind;
47use crate::streaming::RawStreamingChoice;
48use crate::{
49 OneOrMany,
50 completion::{self, CompletionError, CompletionRequest},
51 embeddings::{self, EmbeddingError},
52 json_utils, message,
53 message::{ImageDetail, Text},
54 streaming,
55};
56use async_stream::try_stream;
57use bytes::Bytes;
58use futures::StreamExt;
59use serde::{Deserialize, Serialize};
60use serde_json::{Value, json};
61use std::{convert::TryFrom, str::FromStr};
62use tracing::info_span;
63use tracing_futures::Instrument;
/// Default base URL of a locally running Ollama server.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
67
/// Marker type carrying the Ollama-specific provider capabilities.
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;
70
/// Marker builder type for Ollama clients (no API key, fixed base URL).
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaBuilder;
73
impl Provider for OllamaExt {
    type Builder = OllamaBuilder;

    /// Path used by connectivity checks; Ollama's model-listing endpoint.
    const VERIFY_PATH: &'static str = "api/tags";

    /// Construct the provider extension. Ollama keeps no per-client state,
    /// so the builder argument is ignored and construction never fails.
    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}
89
// Capability matrix: Ollama supports chat completions and embeddings;
// transcription and (feature-gated) image/audio generation are unsupported.
impl<H> Capabilities<H> for OllamaExt {
    type Completion = Capable<CompletionModel<H>>;
    type Transcription = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
100
// Empty impl: OllamaExt relies entirely on DebugExt's default methods.
impl DebugExt for OllamaExt {}
102
impl ProviderBuilder for OllamaBuilder {
    type Output = OllamaExt;
    // Ollama requires no authentication, hence the `Nothing` key type.
    type ApiKey = Nothing;

    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;
}
109
/// An Ollama client, generic over the HTTP backend (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
/// Builder for [`Client`]; the API-key slot is always [`Nothing`].
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, Nothing, H>;
112
impl ProviderClient for Client {
    type Input = Nothing;

    /// Build a client from the `OLLAMA_API_BASE_URL` environment variable.
    ///
    /// # Panics
    /// Panics if `OLLAMA_API_BASE_URL` is unset or the builder fails.
    fn from_env() -> Self {
        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");

        Self::builder()
            .api_key(Nothing)
            .base_url(&api_base)
            .build()
            .unwrap()
    }

    /// Build a client against the default local base URL; the input is
    /// ignored because Ollama takes no credentials.
    fn from_val(_: Self::Input) -> Self {
        Self::builder().api_key(Nothing).build().unwrap()
    }
}
130
/// Error payload returned by the Ollama HTTP API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
137
/// Untagged wrapper: a response body is either the expected payload `T`
/// or an [`ApiErrorResponse`]; serde tries the variants in order.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
144
/// `all-minilm` embedding model identifier.
pub const ALL_MINILM: &str = "all-minilm";
/// `nomic-embed-text` embedding model identifier.
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
149
/// Successful response body of Ollama's `api/embed` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    // One vector per input document, in request order.
    pub embeddings: Vec<Vec<f64>>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
161
162impl From<ApiErrorResponse> for EmbeddingError {
163 fn from(err: ApiErrorResponse) -> Self {
164 EmbeddingError::ProviderError(err.message)
165 }
166}
167
168impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
169 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
170 match value {
171 ApiResponse::Ok(response) => Ok(response),
172 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
173 }
174 }
175}
176
/// Handle for generating embeddings via an Ollama model.
#[derive(Clone)]
pub struct EmbeddingModel<T> {
    client: Client<T>,
    // Model identifier, e.g. [`ALL_MINILM`] or [`NOMIC_EMBED_TEXT`].
    pub model: String,
    // Dimension count reported by `ndims()`; supplied by the caller since
    // Ollama does not advertise it.
    ndims: usize,
}
185
186impl<T> EmbeddingModel<T> {
187 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
188 Self {
189 client,
190 model: model.into(),
191 ndims,
192 }
193 }
194
195 pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
196 Self {
197 client,
198 model: model.into(),
199 ndims,
200 }
201 }
202}
203
204impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
205where
206 T: HttpClientExt + Clone + 'static,
207{
208 type Client = Client<T>;
209
210 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
211 Self::new(client.clone(), model, dims.unwrap())
212 }
213
214 const MAX_DOCUMENTS: usize = 1024;
215 fn ndims(&self) -> usize {
216 self.ndims
217 }
218 #[cfg_attr(feature = "worker", worker::send)]
219 async fn embed_texts(
220 &self,
221 documents: impl IntoIterator<Item = String>,
222 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
223 let docs: Vec<String> = documents.into_iter().collect();
224
225 let body = serde_json::to_vec(&json!({
226 "model": self.model,
227 "input": docs
228 }))?;
229
230 let req = self
231 .client
232 .post("api/embed")?
233 .body(body)
234 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
235
236 let response = HttpClientExt::send(self.client.http_client(), req).await?;
237
238 if !response.status().is_success() {
239 let text = http_client::text(response).await?;
240 return Err(EmbeddingError::ProviderError(text));
241 }
242
243 let bytes: Vec<u8> = response.into_body().await?;
244
245 let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;
246
247 if api_resp.embeddings.len() != docs.len() {
248 return Err(EmbeddingError::ResponseError(
249 "Number of returned embeddings does not match input".into(),
250 ));
251 }
252 Ok(api_resp
253 .embeddings
254 .into_iter()
255 .zip(docs.into_iter())
256 .map(|(vec, document)| embeddings::Embedding { document, vec })
257 .collect())
258 }
259}
260
/// `llama3.2` chat model identifier.
pub const LLAMA3_2: &str = "llama3.2";
/// `llava` multimodal model identifier.
pub const LLAVA: &str = "llava";
/// `mistral` chat model identifier.
pub const MISTRAL: &str = "mistral";
266
/// Response body of Ollama's `api/chat` endpoint. Also used for each NDJSON
/// chunk when streaming (the final chunk has `done == true`).
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    // `true` on the last (or only) response; streaming stops on it.
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    // Input (prompt) token count, when reported.
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    // Output (generated) token count, when reported.
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
/// Convert a raw Ollama chat response into the crate-level completion
/// response, extracting text/tool-call contents and token usage.
///
/// # Errors
/// Fails when the message is not an assistant message, or when it carries
/// neither text nor tool calls (empty choice).
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;
    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                for tc in tool_calls.iter() {
                    // Ollama tool calls carry no call id (see `ToolCall`), so
                    // the function name is reused as the id.
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                // `many` rejects an empty vec, which yields the error below.
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                // Rebuild the raw response (the assistant message was moved
                // out of `resp` by the match above).
                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                    },
                    raw_response,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}
355
/// Request body for Ollama's `api/chat` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OllamaCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    // `false` for `completion()`, flipped to `true` by `stream()`.
    pub stream: bool,
    // Enables "thinking" output on models that support it.
    think: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    // Free-form model options (temperature merged with caller extras).
    options: serde_json::Value,
}
370
/// Build an Ollama chat request from a model name plus a generic
/// [`CompletionRequest`]: documents and chat history are flattened into the
/// message list (preamble first), and a `think` flag may be pulled out of
/// `additional_params`.
///
/// # Errors
/// Fails when a history message cannot be converted or when `think` in
/// `additional_params` is not a boolean.
impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.tool_choice.is_some() {
            tracing::warn!("WARNING: `tool_choice` not supported for Ollama");
        }
        // Normalized documents (if any) come before the chat history.
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        // The system preamble, when present, is always the first message.
        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        // Each internal message may expand into several provider messages
        // (e.g. one per tool result), hence the Vec<Vec<_>> flatten.
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let mut think = false;

        let options = if let Some(mut extra) = req.additional_params {
            if extra.get("think").is_some() {
                // NOTE(review): `take()` nulls the value in place, so the
                // merged options still contain `"think": null` — confirm the
                // server ignores null option entries.
                think = extra["think"].take().as_bool().ok_or_else(|| {
                    CompletionError::RequestError("`think` must be a bool".into())
                })?;
            }
            json_utils::merge(json!({ "temperature": req.temperature }), extra)
        } else {
            json!({ "temperature": req.temperature })
        };

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            max_tokens: req.max_tokens,
            // Non-streaming by default; `stream()` overrides this.
            stream: false,
            think,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            options,
        })
    }
}
433
/// Handle for chat completions against a specific Ollama model.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    // Model identifier, e.g. [`LLAMA3_2`].
    pub model: String,
}
439
440impl<T> CompletionModel<T> {
441 pub fn new(client: Client<T>, model: &str) -> Self {
442 Self {
443 client,
444 model: model.to_owned(),
445 }
446 }
447}
448
/// Final-chunk metadata emitted at the end of a streaming completion
/// (mirrors the `done == true` fields of [`CompletionResponse`]).
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
461
462impl GetTokenUsage for StreamingCompletionResponse {
463 fn token_usage(&self) -> Option<crate::completion::Usage> {
464 let mut usage = crate::completion::Usage::new();
465 let input_tokens = self.prompt_eval_count.unwrap_or_default();
466 let output_tokens = self.eval_count.unwrap_or_default();
467 usage.input_tokens = input_tokens;
468 usage.output_tokens = output_tokens;
469 usage.total_tokens = input_tokens + output_tokens;
470
471 Some(usage)
472 }
473}
474
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into().as_str())
    }

    /// Send a single (non-streaming) chat request to `api/chat` and convert
    /// the reply into a crate-level completion response.
    ///
    /// # Errors
    /// Fails on request construction/serialization errors, non-success HTTP
    /// status, unparseable bodies, or a reply without an assistant message.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
        let preamble = completion_request.preamble.clone();
        let request = OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Only open a fresh span when the caller hasn't already instrumented us.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "ollama",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("api/chat")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let async_block = async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ));
            }

            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
            let span = tracing::Span::current();
            // NOTE(review): records `gen_ai.response.model_name`, but the span
            // above declares `gen_ai.response.model`; tracing drops records
            // for undeclared fields — confirm which key is intended.
            span.record("gen_ai.response.model_name", &response.model);
            span.record(
                "gen_ai.output.messages",
                serde_json::to_string(&vec![&response.message]).unwrap(),
            );
            span.record(
                "gen_ai.usage.input_tokens",
                response.prompt_eval_count.unwrap_or_default(),
            );
            span.record(
                "gen_ai.usage.output_tokens",
                response.eval_count.unwrap_or_default(),
            );

            tracing::trace!(
                target: "rig::completions",
                "Ollama completion response: {}",
                serde_json::to_string_pretty(&response)?
            );

            let response: completion::CompletionResponse<CompletionResponse> =
                response.try_into()?;

            Ok(response)
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    /// Send a streaming chat request and surface Ollama's NDJSON chunks as
    /// raw streaming choices (text deltas, reasoning, tool calls, then a
    /// final usage/metadata chunk when `done == true`).
    ///
    /// # Errors
    /// Fails on request construction/serialization errors or a non-success
    /// HTTP status; per-chunk parse errors are surfaced through the stream.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
    {
        let preamble = request.preamble.clone();
        let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
        request.stream = true;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "ollama",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("api/chat")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let response = self.client.http_client().send_streaming(req).await?;
        let status = response.status();
        let mut byte_stream = response.into_body();

        if !status.is_success() {
            return Err(CompletionError::ProviderError(format!(
                "Got error status code trying to send a request to Ollama: {status}"
            )));
        }

        let stream = try_stream! {
            let span = tracing::Span::current();
            // Accumulators for reconstructing the full assistant message
            // recorded on the span when the stream finishes.
            let mut tool_calls_final = Vec::new();
            let mut text_response = String::new();
            let mut thinking_response = String::new();

            while let Some(chunk) = byte_stream.next().await {
                let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;

                // NOTE(review): splits each chunk on '\n' independently; this
                // assumes no JSON object straddles a chunk boundary — confirm
                // the HTTP layer buffers whole NDJSON lines.
                for line in bytes.split(|&b| b == b'\n') {
                    if line.is_empty() {
                        continue;
                    }

                    tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));

                    let response: CompletionResponse = serde_json::from_slice(line)?;

                    // The terminal chunk: record usage and emit the final
                    // metadata, then stop reading this chunk's lines.
                    if response.done {
                        span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
                        span.record("gen_ai.usage.output_tokens", response.eval_count);
                        let message = Message::Assistant {
                            content: text_response.clone(),
                            thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
                            images: None,
                            name: None,
                            tool_calls: tool_calls_final.clone()
                        };
                        span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
                        yield RawStreamingChoice::FinalResponse(
                            StreamingCompletionResponse {
                                total_duration: response.total_duration,
                                load_duration: response.load_duration,
                                prompt_eval_count: response.prompt_eval_count,
                                prompt_eval_duration: response.prompt_eval_duration,
                                eval_count: response.eval_count,
                                eval_duration: response.eval_duration,
                                done_reason: response.done_reason,
                            }
                        );
                        break;
                    }

                    if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
                        if let Some(thinking_content) = thinking
                            && !thinking_content.is_empty() {
                            thinking_response += &thinking_content;
                            yield RawStreamingChoice::Reasoning {
                                reasoning: thinking_content,
                                id: None,
                                signature: None,
                            };
                        }

                        if !content.is_empty() {
                            text_response += &content;
                            yield RawStreamingChoice::Message(content);
                        }

                        for tool_call in tool_calls {
                            tool_calls_final.push(tool_call.clone());
                            // Ollama provides no call id; emit an empty one.
                            yield RawStreamingChoice::ToolCall {
                                id: String::new(),
                                name: tool_call.function.name,
                                arguments: tool_call.function.arguments,
                                call_id: None,
                            };
                        }
                    }
                }
            }
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
690
691#[derive(Clone, Debug, Deserialize, Serialize)]
695pub struct ToolDefinition {
696 #[serde(rename = "type")]
697 pub type_field: String, pub function: completion::ToolDefinition,
699}
700
701impl From<crate::completion::ToolDefinition> for ToolDefinition {
703 fn from(tool: crate::completion::ToolDefinition) -> Self {
704 ToolDefinition {
705 type_field: "function".to_owned(),
706 function: completion::ToolDefinition {
707 name: tool.name,
708 description: tool.description,
709 parameters: tool.parameters,
710 },
711 }
712 }
713}
714
/// A tool call emitted by the model. Note: Ollama supplies no call id.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Kind of tool call; Ollama only defines `function`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// The function invoked by a [`ToolCall`], with its JSON arguments.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
732
/// A chat message in Ollama's wire format, tagged by `role`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        // Base64-encoded images for multimodal models.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        #[serde(default)]
        content: String,
        // Reasoning text from "thinking"-capable models.
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // `null` in responses deserializes to an empty vec.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    // Serialized with role "tool"; carries a tool's output back to the model.
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
771
/// Convert an internal crate message into one or more Ollama wire messages.
///
/// A user message containing tool results expands into one `ToolResult`
/// message per result; otherwise its texts/images are collapsed into a
/// single `User` message. An assistant message becomes a single `Assistant`
/// message.
///
/// # Errors
/// Fails when an assistant message contains image content, which Ollama
/// does not accept on the assistant side.
impl TryFrom<crate::message::Message> for Vec<Message> {
    type Error = crate::message::MessageError;
    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;
        match internal_msg {
            InternalMessage::User { content, .. } => {
                let (tool_results, other_content): (Vec<_>, Vec<_>) =
                    content.into_iter().partition(|content| {
                        matches!(content, crate::message::UserContent::ToolResult(_))
                    });

                // NOTE(review): when tool results are present, any other
                // user content in the same message is dropped — confirm
                // callers never mix tool results with text/images.
                if !tool_results.is_empty() {
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            crate::message::UserContent::ToolResult(
                                crate::message::ToolResult { id, content, .. },
                            ) => {
                                // Flatten multi-part results to newline-joined
                                // text; non-text parts become a placeholder.
                                let content_string = content
                                    .into_iter()
                                    .map(|content| match content {
                                        crate::message::ToolResultContent::Text(text) => text.text,
                                        _ => "[Non-text content]".to_string(),
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n");

                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
                                    name: id,
                                    content: content_string,
                                })
                            }
                            // Unreachable: the partition above kept only
                            // ToolResult variants in this vec.
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    // Gather plain text (including string/base64 documents)
                    // and base64 images; other content kinds are ignored.
                    let (texts, images) = other_content.into_iter().fold(
                        (Vec::new(), Vec::new()),
                        |(mut texts, mut images), content| {
                            match content {
                                crate::message::UserContent::Text(crate::message::Text {
                                    text,
                                }) => texts.push(text),
                                crate::message::UserContent::Image(crate::message::Image {
                                    data: DocumentSourceKind::Base64(data),
                                    ..
                                }) => images.push(data),
                                crate::message::UserContent::Document(
                                    crate::message::Document {
                                        data:
                                            DocumentSourceKind::Base64(data)
                                            | DocumentSourceKind::String(data),
                                        ..
                                    },
                                ) => texts.push(data),
                                // Non-base64 images and other kinds are
                                // silently skipped.
                                _ => {}
                            }
                            (texts, images)
                        },
                    );

                    Ok(vec![Message::User {
                        content: texts.join(" "),
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(
                                images
                                    .into_iter()
                                    .map(|x| x.to_string())
                                    .collect::<Vec<String>>(),
                            )
                        },
                        name: None,
                    }])
                }
            }
            InternalMessage::Assistant { content, .. } => {
                let mut thinking: Option<String> = None;
                let mut text_content = Vec::new();
                let mut tool_calls = Vec::new();

                for content in content.into_iter() {
                    match content {
                        crate::message::AssistantContent::Text(text) => {
                            text_content.push(text.text)
                        }
                        crate::message::AssistantContent::ToolCall(tool_call) => {
                            tool_calls.push(tool_call)
                        }
                        crate::message::AssistantContent::Reasoning(
                            crate::message::Reasoning { reasoning, .. },
                        ) => {
                            // Only the first reasoning entry is kept; later
                            // Reasoning blocks overwrite earlier ones.
                            thinking = Some(reasoning.first().cloned().unwrap_or(String::new()));
                        }
                        crate::message::AssistantContent::Image(_) => {
                            return Err(crate::message::MessageError::ConversionError(
                                "Ollama currently doesn't support images.".into(),
                            ));
                        }
                    }
                }

                Ok(vec![Message::Assistant {
                    content: text_content.join(" "),
                    thinking,
                    images: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}
898
/// Convert an Ollama wire message back into the crate-level message type.
/// System and tool-result messages are mapped onto user messages, since the
/// crate message type models only user/assistant roles here.
impl From<Message> for crate::completion::Message {
    fn from(msg: Message) -> Self {
        match msg {
            Message::User { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // The text element is always pushed (possibly empty), so the
                // vec is never empty and `many(...).unwrap()` cannot fail.
                let mut assistant_contents =
                    vec![crate::completion::message::AssistantContent::Text(Text {
                        text: content,
                    })];
                for tc in tool_calls {
                    // Function name doubles as the call id (Ollama has none).
                    assistant_contents.push(
                        crate::completion::message::AssistantContent::tool_call(
                            tc.function.name.clone(),
                            tc.function.name,
                            tc.function.arguments,
                        ),
                    );
                }
                crate::completion::Message::Assistant {
                    id: None,
                    content: OneOrMany::many(assistant_contents).unwrap(),
                }
            }
            Message::System { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::ToolResult { name, content } => crate::completion::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    name,
                    OneOrMany::one(message::ToolResultContent::text(content)),
                )),
            },
        }
    }
}
947
948impl Message {
949 pub fn system(content: &str) -> Self {
951 Message::System {
952 content: content.to_owned(),
953 images: None,
954 name: None,
955 }
956 }
957}
958
959impl From<crate::message::ToolCall> for ToolCall {
962 fn from(tool_call: crate::message::ToolCall) -> Self {
963 Self {
964 r#type: ToolType::Function,
965 function: Function {
966 name: tool_call.function.name,
967 arguments: tool_call.function.arguments,
968 },
969 }
970 }
971}
972
/// Typed system-message content fragment (always text for Ollama).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}
979
/// Content-type tag for [`SystemContent`]; only `text` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
986
987impl From<String> for SystemContent {
988 fn from(s: String) -> Self {
989 SystemContent {
990 r#type: SystemContentType::default(),
991 text: s,
992 }
993 }
994}
995
996impl FromStr for SystemContent {
997 type Err = std::convert::Infallible;
998 fn from_str(s: &str) -> Result<Self, Self::Err> {
999 Ok(SystemContent {
1000 r#type: SystemContentType::default(),
1001 text: s.to_string(),
1002 })
1003 }
1004}
1005
/// Plain-text assistant content fragment.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}
1010
1011impl FromStr for AssistantContent {
1012 type Err = std::convert::Infallible;
1013 fn from_str(s: &str) -> Result<Self, Self::Err> {
1014 Ok(AssistantContent { text: s.to_owned() })
1015 }
1016}
1017
1018#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1019#[serde(tag = "type", rename_all = "lowercase")]
1020pub enum UserContent {
1021 Text { text: String },
1022 Image { image_url: ImageUrl },
1023 }
1025
1026impl FromStr for UserContent {
1027 type Err = std::convert::Infallible;
1028 fn from_str(s: &str) -> Result<Self, Self::Err> {
1029 Ok(UserContent::Text { text: s.to_owned() })
1030 }
1031}
1032
/// An image reference with an optional detail hint.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}
1039
1040#[cfg(test)]
1045mod tests {
1046 use super::*;
1047 use serde_json::json;
1048
    // A realistic `api/chat` payload (text plus one tool call) must
    // deserialize and convert into a non-empty completion choice.
    #[tokio::test]
    async fn test_chat_completion() {
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }
1092
    // A provider user message converts to a crate message whose first
    // content element is the same text.
    #[test]
    fn test_message_conversion() {
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                let first_content = content.first();
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }
1119
    // A core ToolDefinition converts into Ollama's `{"type": "function"}`
    // envelope with name/description/parameters preserved.
    #[test]
    fn test_tool_definition_conversion() {
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }
1155
    // A response carrying a `thinking` field deserializes with both the
    // reasoning text and the regular content intact.
    #[tokio::test]
    async fn test_chat_completion_with_thinking() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The answer is 42.",
                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "Let me think about this carefully. The question asks for the meaning of life..."
            );
            assert_eq!(content, "The answer is 42.");
        } else {
            panic!("Expected Assistant message");
        }
    }
1195
    // A response without a `thinking` field deserializes with
    // `thinking == None` (the field is optional).
    #[tokio::test]
    async fn test_chat_completion_without_thinking() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Hello!",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert!(thinking.is_none());
            assert_eq!(content, "Hello!");
        } else {
            panic!("Expected Assistant message");
        }
    }
1231
1232 #[test]
1234 fn test_streaming_response_with_thinking() {
1235 let sample_chunk = json!({
1236 "model": "qwen-thinking",
1237 "created_at": "2023-08-04T19:22:45.499127Z",
1238 "message": {
1239 "role": "assistant",
1240 "content": "",
1241 "thinking": "Analyzing the problem...",
1242 "images": null,
1243 "tool_calls": []
1244 },
1245 "done": false
1246 });
1247
1248 let chunk: CompletionResponse =
1249 serde_json::from_value(sample_chunk).expect("Failed to deserialize");
1250
1251 if let Message::Assistant {
1252 thinking, content, ..
1253 } = &chunk.message
1254 {
1255 assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
1256 assert_eq!(content, "");
1257 } else {
1258 panic!("Expected Assistant message");
1259 }
1260 }
1261
1262 #[test]
1264 fn test_message_conversion_with_thinking() {
1265 let reasoning_content = crate::message::Reasoning {
1267 id: None,
1268 reasoning: vec!["Step 1: Consider the problem".to_string()],
1269 signature: None,
1270 };
1271
1272 let internal_msg = crate::message::Message::Assistant {
1273 id: None,
1274 content: crate::OneOrMany::many(vec![
1275 crate::message::AssistantContent::Reasoning(reasoning_content),
1276 crate::message::AssistantContent::Text(crate::message::Text {
1277 text: "The answer is X".to_string(),
1278 }),
1279 ])
1280 .unwrap(),
1281 };
1282
1283 let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();
1285 assert_eq!(provider_msgs.len(), 1);
1286
1287 if let Message::Assistant {
1288 thinking, content, ..
1289 } = &provider_msgs[0]
1290 {
1291 assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
1292 assert_eq!(content, "The answer is X");
1293 } else {
1294 panic!("Expected Assistant message with thinking");
1295 }
1296 }
1297
1298 #[test]
1300 fn test_empty_thinking_content() {
1301 let sample_response = json!({
1302 "model": "llama3.2",
1303 "created_at": "2023-08-04T19:22:45.499127Z",
1304 "message": {
1305 "role": "assistant",
1306 "content": "Response",
1307 "thinking": "",
1308 "images": null,
1309 "tool_calls": []
1310 },
1311 "done": true,
1312 "total_duration": 8000000000u64,
1313 "load_duration": 6000000u64,
1314 "prompt_eval_count": 10u64,
1315 "prompt_eval_duration": 400000000u64,
1316 "eval_count": 5u64,
1317 "eval_duration": 7700000000u64
1318 });
1319
1320 let chat_resp: CompletionResponse =
1321 serde_json::from_value(sample_response).expect("Failed to deserialize");
1322
1323 if let Message::Assistant {
1324 thinking, content, ..
1325 } = &chat_resp.message
1326 {
1327 assert_eq!(thinking.as_ref().unwrap(), "");
1329 assert_eq!(content, "Response");
1330 } else {
1331 panic!("Expected Assistant message");
1332 }
1333 }
1334
1335 #[test]
1337 fn test_thinking_with_tool_calls() {
1338 let sample_response = json!({
1339 "model": "qwen-thinking",
1340 "created_at": "2023-08-04T19:22:45.499127Z",
1341 "message": {
1342 "role": "assistant",
1343 "content": "Let me check the weather.",
1344 "thinking": "User wants weather info, I should use the weather tool",
1345 "images": null,
1346 "tool_calls": [
1347 {
1348 "type": "function",
1349 "function": {
1350 "name": "get_weather",
1351 "arguments": {
1352 "location": "San Francisco"
1353 }
1354 }
1355 }
1356 ]
1357 },
1358 "done": true,
1359 "total_duration": 8000000000u64,
1360 "load_duration": 6000000u64,
1361 "prompt_eval_count": 30u64,
1362 "prompt_eval_duration": 400000000u64,
1363 "eval_count": 50u64,
1364 "eval_duration": 7700000000u64
1365 });
1366
1367 let chat_resp: CompletionResponse =
1368 serde_json::from_value(sample_response).expect("Failed to deserialize");
1369
1370 if let Message::Assistant {
1371 thinking,
1372 content,
1373 tool_calls,
1374 ..
1375 } = &chat_resp.message
1376 {
1377 assert_eq!(
1378 thinking.as_ref().unwrap(),
1379 "User wants weather info, I should use the weather tool"
1380 );
1381 assert_eq!(content, "Let me check the weather.");
1382 assert_eq!(tool_calls.len(), 1);
1383 assert_eq!(tool_calls[0].function.name, "get_weather");
1384 } else {
1385 panic!("Expected Assistant message with thinking and tool calls");
1386 }
1387 }
1388}