1use crate::client::{
35 self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
36};
37use crate::completion::{GetTokenUsage, Usage};
38use crate::http_client::{self, HttpClientExt};
39use crate::message::DocumentSourceKind;
40use crate::streaming::RawStreamingChoice;
41use crate::{
42 OneOrMany,
43 completion::{self, CompletionError, CompletionRequest},
44 embeddings::{self, EmbeddingError},
45 json_utils, message,
46 message::{ImageDetail, Text},
47 streaming,
48};
49use async_stream::try_stream;
50use bytes::Bytes;
51use futures::StreamExt;
52use serde::{Deserialize, Serialize};
53use serde_json::{Value, json};
54use std::{convert::TryFrom, str::FromStr};
55use tracing::info_span;
56use tracing_futures::Instrument;
/// Default base URL of a locally running Ollama server.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
60
/// Marker type implementing the Ollama provider extension.
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;

/// Marker type used by the generic client builder to construct [`OllamaExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaBuilder;

impl Provider for OllamaExt {
    type Builder = OllamaBuilder;
    // Lightweight endpoint used to verify the server is reachable.
    const VERIFY_PATH: &'static str = "api/tags";
}

// Ollama supports chat completions and embeddings; all other capabilities are `Nothing`.
impl<H> Capabilities<H> for OllamaExt {
    type Completion = Capable<CompletionModel<H>>;
    type Transcription = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for OllamaExt {}
85
impl ProviderBuilder for OllamaBuilder {
    type Extension<H>
        = OllamaExt
    where
        H: HttpClientExt;
    // Ollama runs locally and requires no API key.
    type ApiKey = Nothing;

    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;

    /// Produces the (stateless) extension; the builder carries no configuration of its own.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(OllamaExt)
    }
}
104
/// Ollama client, generic over the HTTP backend (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
/// Builder for [`Client`]; the API-key slot is `Nothing` since Ollama needs none.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, Nothing, H>;
107
108impl ProviderClient for Client {
109 type Input = Nothing;
110
111 fn from_env() -> Self {
112 let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
113
114 Self::builder()
115 .api_key(Nothing)
116 .base_url(&api_base)
117 .build()
118 .unwrap()
119 }
120
121 fn from_val(_: Self::Input) -> Self {
122 Self::builder().api_key(Nothing).build().unwrap()
123 }
124}
125
/// Error body returned by the Ollama API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Untagged envelope: deserialization tries the success shape first, then the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
139
pub const ALL_MINILM: &str = "all-minilm";
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

/// Returns the embedding dimensionality of embedding models whose output size
/// is known ahead of time, or `None` for unrecognized identifiers.
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    const KNOWN_DIMENSIONS: [(&str, usize); 2] = [(ALL_MINILM, 384), (NOMIC_EMBED_TEXT, 768)];
    KNOWN_DIMENSIONS
        .into_iter()
        .find_map(|(model, dims)| (model == identifier).then_some(dims))
}
152
/// Response payload of Ollama's `/api/embed` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    /// One embedding vector per input document, in input order.
    pub embeddings: Vec<Vec<f64>>,
    // Timing/usage fields are optional; servers may omit them.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
164
165impl From<ApiErrorResponse> for EmbeddingError {
166 fn from(err: ApiErrorResponse) -> Self {
167 EmbeddingError::ProviderError(err.message)
168 }
169}
170
171impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
172 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
173 match value {
174 ApiResponse::Ok(response) => Ok(response),
175 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
176 }
177 }
178}
179
/// An Ollama embedding model bound to a client.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    // Number of dimensions the model is expected to produce (0 if unknown).
    ndims: usize,
}
188
189impl<T> EmbeddingModel<T> {
190 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
191 Self {
192 client,
193 model: model.into(),
194 ndims,
195 }
196 }
197
198 pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
199 Self {
200 client,
201 model: model.into(),
202 ndims,
203 }
204 }
205}
206
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Client = Client<T>;

    /// Builds the model, inferring dimensions from well-known model names
    /// when the caller does not supply them (falls back to 0 if unknown).
    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        let model = model.into();
        let dims = dims
            .or(model_dimensions_from_identifier(&model))
            .unwrap_or_default();
        Self::new(client.clone(), model, dims)
    }

    const MAX_DOCUMENTS: usize = 1024;
    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embeds all `documents` in a single `/api/embed` call.
    ///
    /// # Errors
    /// Returns a provider error for non-success HTTP statuses, and a response
    /// error if the server returns a different number of embeddings than inputs.
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let docs: Vec<String> = documents.into_iter().collect();

        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": docs
        }))?;

        let req = self
            .client
            .post("api/embed")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if !response.status().is_success() {
            let text = http_client::text(response).await?;
            return Err(EmbeddingError::ProviderError(text));
        }

        let bytes: Vec<u8> = response.into_body().await?;

        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;

        // Guard against a short/long response before zipping embeddings with inputs.
        if api_resp.embeddings.len() != docs.len() {
            return Err(EmbeddingError::ResponseError(
                "Number of returned embeddings does not match input".into(),
            ));
        }
        Ok(api_resp
            .embeddings
            .into_iter()
            .zip(docs.into_iter())
            .map(|(vec, document)| embeddings::Embedding { document, vec })
            .collect())
    }
}
267
// Well-known Ollama chat model identifiers.
pub const LLAMA3_2: &str = "llama3.2";
pub const LLAVA: &str = "llava";
pub const MISTRAL: &str = "mistral";
273
/// Response payload of Ollama's `/api/chat` endpoint — both the non-streaming
/// reply and each NDJSON chunk of a streaming reply.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    /// True on the final (or only) chunk.
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    // Timing/usage fields are typically only present on the final chunk.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;
    /// Lifts a raw Ollama chat response into the generic completion response.
    ///
    /// # Errors
    /// Fails if the message is not assistant-role, or if it carries neither
    /// text content nor tool calls (`OneOrMany::many` rejects an empty vec).
    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                for tc in tool_calls.iter() {
                    // Ollama has no tool-call ids; the function name is reused as the id.
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                // Rebuild the raw response: the message fields were partially
                // moved out above, so the struct is reassembled field by field.
                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                        cached_input_tokens: 0,
                    },
                    raw_response,
                    message_id: None,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}
364
/// Request payload for Ollama's `/api/chat` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OllamaCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    pub stream: bool,
    /// Enables "thinking" output on models that support it.
    think: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    /// How long the model should stay loaded after the request (e.g. "5m").
    #[serde(skip_serializing_if = "Option::is_none")]
    keep_alive: Option<String>,
    /// Structured-output JSON schema, when requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<schemars::Schema>,
    /// Model options: temperature plus passthrough of extra parameters.
    options: serde_json::Value,
}
383
384impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
385 type Error = CompletionError;
386
387 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
388 let model = req.model.clone().unwrap_or_else(|| model.to_string());
389 if req.tool_choice.is_some() {
390 tracing::warn!("WARNING: `tool_choice` not supported for Ollama");
391 }
392 let mut partial_history = vec![];
394 if let Some(docs) = req.normalized_documents() {
395 partial_history.push(docs);
396 }
397 partial_history.extend(req.chat_history);
398
399 let mut full_history: Vec<Message> = match &req.preamble {
401 Some(preamble) => vec![Message::system(preamble)],
402 None => vec![],
403 };
404
405 full_history.extend(
407 partial_history
408 .into_iter()
409 .map(message::Message::try_into)
410 .collect::<Result<Vec<Vec<Message>>, _>>()?
411 .into_iter()
412 .flatten()
413 .collect::<Vec<_>>(),
414 );
415
416 let mut think = false;
417 let mut keep_alive: Option<String> = None;
418
419 let options = if let Some(mut extra) = req.additional_params {
420 if let Some(obj) = extra.as_object_mut() {
422 if let Some(think_val) = obj.remove("think") {
424 think = think_val.as_bool().ok_or_else(|| {
425 CompletionError::RequestError("`think` must be a bool".into())
426 })?;
427 }
428
429 if let Some(keep_alive_val) = obj.remove("keep_alive") {
431 keep_alive = Some(
432 keep_alive_val
433 .as_str()
434 .ok_or_else(|| {
435 CompletionError::RequestError(
436 "`keep_alive` must be a string".into(),
437 )
438 })?
439 .to_string(),
440 );
441 }
442 }
443
444 json_utils::merge(json!({ "temperature": req.temperature }), extra)
445 } else {
446 json!({ "temperature": req.temperature })
447 };
448
449 Ok(Self {
450 model: model.to_string(),
451 messages: full_history,
452 temperature: req.temperature,
453 max_tokens: req.max_tokens,
454 stream: false,
455 think,
456 keep_alive,
457 format: req.output_schema,
458 tools: req
459 .tools
460 .clone()
461 .into_iter()
462 .map(ToolDefinition::from)
463 .collect::<Vec<_>>(),
464 options,
465 })
466 }
467}
468
/// An Ollama chat-completion model bound to a client.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}
474
475impl<T> CompletionModel<T> {
476 pub fn new(client: Client<T>, model: &str) -> Self {
477 Self {
478 client,
479 model: model.to_owned(),
480 }
481 }
482}
483
/// Final-chunk metadata surfaced at the end of a streaming completion.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
496
497impl GetTokenUsage for StreamingCompletionResponse {
498 fn token_usage(&self) -> Option<crate::completion::Usage> {
499 let mut usage = crate::completion::Usage::new();
500 let input_tokens = self.prompt_eval_count.unwrap_or_default();
501 let output_tokens = self.eval_count.unwrap_or_default();
502 usage.input_tokens = input_tokens;
503 usage.output_tokens = output_tokens;
504 usage.total_tokens = input_tokens + output_tokens;
505
506 Some(usage)
507 }
508}
509
510impl<T> completion::CompletionModel for CompletionModel<T>
511where
512 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
513{
514 type Response = CompletionResponse;
515 type StreamingResponse = StreamingCompletionResponse;
516
517 type Client = Client<T>;
518
519 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
520 Self::new(client.clone(), model.into().as_str())
521 }
522
523 async fn completion(
524 &self,
525 completion_request: CompletionRequest,
526 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
527 let span = if tracing::Span::current().is_disabled() {
528 info_span!(
529 target: "rig::completions",
530 "chat",
531 gen_ai.operation.name = "chat",
532 gen_ai.provider.name = "ollama",
533 gen_ai.request.model = self.model,
534 gen_ai.system_instructions = tracing::field::Empty,
535 gen_ai.response.id = tracing::field::Empty,
536 gen_ai.response.model = tracing::field::Empty,
537 gen_ai.usage.output_tokens = tracing::field::Empty,
538 gen_ai.usage.input_tokens = tracing::field::Empty,
539 )
540 } else {
541 tracing::Span::current()
542 };
543
544 span.record("gen_ai.system_instructions", &completion_request.preamble);
545 let request = OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
546
547 if tracing::enabled!(tracing::Level::TRACE) {
548 tracing::trace!(target: "rig::completions",
549 "Ollama completion request: {}",
550 serde_json::to_string_pretty(&request)?
551 );
552 }
553
554 let body = serde_json::to_vec(&request)?;
555
556 let req = self
557 .client
558 .post("api/chat")?
559 .body(body)
560 .map_err(http_client::Error::from)?;
561
562 let async_block = async move {
563 let response = self.client.send::<_, Bytes>(req).await?;
564 let status = response.status();
565 let response_body = response.into_body().into_future().await?.to_vec();
566
567 if !status.is_success() {
568 return Err(CompletionError::ProviderError(
569 String::from_utf8_lossy(&response_body).to_string(),
570 ));
571 }
572
573 let response: CompletionResponse = serde_json::from_slice(&response_body)?;
574 let span = tracing::Span::current();
575 span.record("gen_ai.response.model_name", &response.model);
576 span.record(
577 "gen_ai.usage.input_tokens",
578 response.prompt_eval_count.unwrap_or_default(),
579 );
580 span.record(
581 "gen_ai.usage.output_tokens",
582 response.eval_count.unwrap_or_default(),
583 );
584
585 if tracing::enabled!(tracing::Level::TRACE) {
586 tracing::trace!(target: "rig::completions",
587 "Ollama completion response: {}",
588 serde_json::to_string_pretty(&response)?
589 );
590 }
591
592 let response: completion::CompletionResponse<CompletionResponse> =
593 response.try_into()?;
594
595 Ok(response)
596 };
597
598 tracing::Instrument::instrument(async_block, span).await
599 }
600
601 async fn stream(
602 &self,
603 request: CompletionRequest,
604 ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
605 {
606 let span = if tracing::Span::current().is_disabled() {
607 info_span!(
608 target: "rig::completions",
609 "chat_streaming",
610 gen_ai.operation.name = "chat_streaming",
611 gen_ai.provider.name = "ollama",
612 gen_ai.request.model = self.model,
613 gen_ai.system_instructions = tracing::field::Empty,
614 gen_ai.response.id = tracing::field::Empty,
615 gen_ai.response.model = self.model,
616 gen_ai.usage.output_tokens = tracing::field::Empty,
617 gen_ai.usage.input_tokens = tracing::field::Empty,
618 )
619 } else {
620 tracing::Span::current()
621 };
622
623 span.record("gen_ai.system_instructions", &request.preamble);
624
625 let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
626 request.stream = true;
627
628 if tracing::enabled!(tracing::Level::TRACE) {
629 tracing::trace!(target: "rig::completions",
630 "Ollama streaming completion request: {}",
631 serde_json::to_string_pretty(&request)?
632 );
633 }
634
635 let body = serde_json::to_vec(&request)?;
636
637 let req = self
638 .client
639 .post("api/chat")?
640 .body(body)
641 .map_err(http_client::Error::from)?;
642
643 let response = self.client.send_streaming(req).await?;
644 let status = response.status();
645 let mut byte_stream = response.into_body();
646
647 if !status.is_success() {
648 return Err(CompletionError::ProviderError(format!(
649 "Got error status code trying to send a request to Ollama: {status}"
650 )));
651 }
652
653 let stream = try_stream! {
654 let span = tracing::Span::current();
655 let mut tool_calls_final = Vec::new();
656 let mut text_response = String::new();
657 let mut thinking_response = String::new();
658
659 while let Some(chunk) = byte_stream.next().await {
660 let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;
661
662 for line in bytes.split(|&b| b == b'\n') {
663 if line.is_empty() {
664 continue;
665 }
666
667 tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));
668
669 let response: CompletionResponse = serde_json::from_slice(line)?;
670
671 if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
672 if let Some(thinking_content) = thinking && !thinking_content.is_empty() {
673 thinking_response += &thinking_content;
674 yield RawStreamingChoice::ReasoningDelta {
675 id: None,
676 reasoning: thinking_content,
677 };
678 }
679
680 if !content.is_empty() {
681 text_response += &content;
682 yield RawStreamingChoice::Message(content);
683 }
684
685 for tool_call in tool_calls {
686 tool_calls_final.push(tool_call.clone());
687 yield RawStreamingChoice::ToolCall(
688 crate::streaming::RawStreamingToolCall::new(String::new(), tool_call.function.name, tool_call.function.arguments)
689 );
690 }
691 }
692
693 if response.done {
694 span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
695 span.record("gen_ai.usage.output_tokens", response.eval_count);
696 let message = Message::Assistant {
697 content: text_response.clone(),
698 thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
699 images: None,
700 name: None,
701 tool_calls: tool_calls_final.clone()
702 };
703 span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
704 yield RawStreamingChoice::FinalResponse(
705 StreamingCompletionResponse {
706 total_duration: response.total_duration,
707 load_duration: response.load_duration,
708 prompt_eval_count: response.prompt_eval_count,
709 prompt_eval_duration: response.prompt_eval_duration,
710 eval_count: response.eval_count,
711 eval_duration: response.eval_duration,
712 done_reason: response.done_reason,
713 }
714 );
715 break;
716 }
717 }
718 }
719 }.instrument(span);
720
721 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
722 stream,
723 )))
724 }
725}
726
727#[derive(Clone, Debug, Deserialize, Serialize)]
731pub struct ToolDefinition {
732 #[serde(rename = "type")]
733 pub type_field: String, pub function: completion::ToolDefinition,
735}
736
737impl From<crate::completion::ToolDefinition> for ToolDefinition {
739 fn from(tool: crate::completion::ToolDefinition) -> Self {
740 ToolDefinition {
741 type_field: "function".to_owned(),
742 function: completion::ToolDefinition {
743 name: tool.name,
744 description: tool.description,
745 parameters: tool.parameters,
746 },
747 }
748 }
749}
750
/// A tool invocation as returned by the model.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Ollama only emits function-style tool calls.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// The function name and (already-parsed JSON) arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
768
/// Chat message in Ollama's wire format; `role` is the serde tag.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        /// Base64-encoded images for multimodal models.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        #[serde(default)]
        content: String,
        /// Reasoning text emitted by "thinking" models.
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Ollama may send `null` instead of an empty array here.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// A tool's output, sent back to the model with role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
807
impl TryFrom<crate::message::Message> for Vec<Message> {
    type Error = crate::message::MessageError;
    /// Converts one internal message into one *or more* Ollama messages:
    /// each tool result becomes a separate `tool`-role message, while
    /// text/images/documents collapse into a single user message.
    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;
        match internal_msg {
            InternalMessage::User { content, .. } => {
                // Split tool results from the rest; Ollama represents them as
                // distinct `tool`-role messages.
                let (tool_results, other_content): (Vec<_>, Vec<_>) =
                    content.into_iter().partition(|content| {
                        matches!(content, crate::message::UserContent::ToolResult(_))
                    });

                if !tool_results.is_empty() {
                    // NOTE(review): when a user message mixes tool results with
                    // other content, the non-tool-result content is dropped here.
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            crate::message::UserContent::ToolResult(
                                crate::message::ToolResult { id, content, .. },
                            ) => {
                                // Flatten the result parts; non-text parts are
                                // replaced with a placeholder string.
                                let content_string = content
                                    .into_iter()
                                    .map(|content| match content {
                                        crate::message::ToolResultContent::Text(text) => text.text,
                                        _ => "[Non-text content]".to_string(),
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n");

                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
                                    name: id,
                                    content: content_string,
                                })
                            }
                            // The partition above guarantees only ToolResult variants here.
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    // Collect text (including textual documents) and base64 images.
                    let (texts, images) = other_content.into_iter().fold(
                        (Vec::new(), Vec::new()),
                        |(mut texts, mut images), content| {
                            match content {
                                crate::message::UserContent::Text(crate::message::Text {
                                    text,
                                }) => texts.push(text),
                                crate::message::UserContent::Image(crate::message::Image {
                                    data: DocumentSourceKind::Base64(data),
                                    ..
                                }) => images.push(data),
                                crate::message::UserContent::Document(
                                    crate::message::Document {
                                        data:
                                            DocumentSourceKind::Base64(data)
                                            | DocumentSourceKind::String(data),
                                        ..
                                    },
                                ) => texts.push(data),
                                // URL-based media and other kinds are unsupported and skipped.
                                _ => {}
                            }
                            (texts, images)
                        },
                    );

                    Ok(vec![Message::User {
                        content: texts.join(" "),
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(
                                images
                                    .into_iter()
                                    .map(|x| x.to_string())
                                    .collect::<Vec<String>>(),
                            )
                        },
                        name: None,
                    }])
                }
            }
            InternalMessage::Assistant { content, .. } => {
                let mut thinking: Option<String> = None;
                let mut text_content = Vec::new();
                let mut tool_calls = Vec::new();

                for content in content.into_iter() {
                    match content {
                        crate::message::AssistantContent::Text(text) => {
                            text_content.push(text.text)
                        }
                        crate::message::AssistantContent::ToolCall(tool_call) => {
                            tool_calls.push(tool_call)
                        }
                        crate::message::AssistantContent::Reasoning(reasoning) => {
                            // Only the last non-empty reasoning block is kept.
                            let display = reasoning.display_text();
                            if !display.is_empty() {
                                thinking = Some(display);
                            }
                        }
                        crate::message::AssistantContent::Image(_) => {
                            return Err(crate::message::MessageError::ConversionError(
                                "Ollama currently doesn't support images.".into(),
                            ));
                        }
                    }
                }

                Ok(vec![Message::Assistant {
                    content: text_content.join(" "),
                    thinking,
                    images: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}
935
impl From<Message> for crate::completion::Message {
    /// Maps an Ollama message back into the internal message type.
    ///
    /// System and tool-result messages are folded into user messages, since
    /// the internal history only distinguishes user and assistant roles.
    fn from(msg: Message) -> Self {
        match msg {
            Message::User { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // The text element is always pushed first, so `many` below cannot fail.
                let mut assistant_contents =
                    vec![crate::completion::message::AssistantContent::Text(Text {
                        text: content,
                    })];
                for tc in tool_calls {
                    // Ollama has no call ids, so the function name doubles as the id.
                    assistant_contents.push(
                        crate::completion::message::AssistantContent::tool_call(
                            tc.function.name.clone(),
                            tc.function.name,
                            tc.function.arguments,
                        ),
                    );
                }
                crate::completion::Message::Assistant {
                    id: None,
                    content: OneOrMany::many(assistant_contents).unwrap(),
                }
            }
            Message::System { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::ToolResult { name, content } => crate::completion::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    name,
                    OneOrMany::one(message::ToolResultContent::text(content)),
                )),
            },
        }
    }
}
984
985impl Message {
986 pub fn system(content: &str) -> Self {
988 Message::System {
989 content: content.to_owned(),
990 images: None,
991 name: None,
992 }
993 }
994}
995
impl From<crate::message::ToolCall> for ToolCall {
    /// Drops the internal call id: Ollama identifies tool calls by function name only.
    fn from(tool_call: crate::message::ToolCall) -> Self {
        Self {
            r#type: ToolType::Function,
            function: Function {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}
1009
/// A typed system-content part (OpenAI-style shape).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    /// Plain text; currently the only supported kind.
    #[default]
    Text,
}
1023
1024impl From<String> for SystemContent {
1025 fn from(s: String) -> Self {
1026 SystemContent {
1027 r#type: SystemContentType::default(),
1028 text: s,
1029 }
1030 }
1031}
1032
1033impl FromStr for SystemContent {
1034 type Err = std::convert::Infallible;
1035 fn from_str(s: &str) -> Result<Self, Self::Err> {
1036 Ok(SystemContent {
1037 r#type: SystemContentType::default(),
1038 text: s.to_string(),
1039 })
1040 }
1041}
1042
/// Plain-text assistant content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}

impl FromStr for AssistantContent {
    type Err = std::convert::Infallible;
    // Infallible parse: any string is valid assistant text.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(AssistantContent { text: s.to_owned() })
    }
}
1054
1055#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1056#[serde(tag = "type", rename_all = "lowercase")]
1057pub enum UserContent {
1058 Text { text: String },
1059 Image { image_url: ImageUrl },
1060 }
1062
1063impl FromStr for UserContent {
1064 type Err = std::convert::Infallible;
1065 fn from_str(s: &str) -> Result<Self, Self::Err> {
1066 Ok(UserContent::Text { text: s.to_owned() })
1067 }
1068}
1069
1070#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1071pub struct ImageUrl {
1072 pub url: String,
1073 #[serde(default)]
1074 pub detail: ImageDetail,
1075}
1076
1077#[cfg(test)]
1082mod tests {
1083 use super::*;
1084 use serde_json::json;
1085
1086 #[tokio::test]
1088 async fn test_chat_completion() {
1089 let sample_chat_response = json!({
1091 "model": "llama3.2",
1092 "created_at": "2023-08-04T19:22:45.499127Z",
1093 "message": {
1094 "role": "assistant",
1095 "content": "The sky is blue because of Rayleigh scattering.",
1096 "images": null,
1097 "tool_calls": [
1098 {
1099 "type": "function",
1100 "function": {
1101 "name": "get_current_weather",
1102 "arguments": {
1103 "location": "San Francisco, CA",
1104 "format": "celsius"
1105 }
1106 }
1107 }
1108 ]
1109 },
1110 "done": true,
1111 "total_duration": 8000000000u64,
1112 "load_duration": 6000000u64,
1113 "prompt_eval_count": 61u64,
1114 "prompt_eval_duration": 400000000u64,
1115 "eval_count": 468u64,
1116 "eval_duration": 7700000000u64
1117 });
1118 let sample_text = sample_chat_response.to_string();
1119
1120 let chat_resp: CompletionResponse =
1121 serde_json::from_str(&sample_text).expect("Invalid JSON structure");
1122 let conv: completion::CompletionResponse<CompletionResponse> =
1123 chat_resp.try_into().unwrap();
1124 assert!(
1125 !conv.choice.is_empty(),
1126 "Expected non-empty choice in chat response"
1127 );
1128 }
1129
1130 #[test]
1132 fn test_message_conversion() {
1133 let provider_msg = Message::User {
1135 content: "Test message".to_owned(),
1136 images: None,
1137 name: None,
1138 };
1139 let comp_msg: crate::completion::Message = provider_msg.into();
1141 match comp_msg {
1142 crate::completion::Message::User { content } => {
1143 let first_content = content.first();
1145 match first_content {
1147 crate::completion::message::UserContent::Text(text_struct) => {
1148 assert_eq!(text_struct.text, "Test message");
1149 }
1150 _ => panic!("Expected text content in conversion"),
1151 }
1152 }
1153 _ => panic!("Conversion from provider Message to completion Message failed"),
1154 }
1155 }
1156
1157 #[test]
1159 fn test_tool_definition_conversion() {
1160 let internal_tool = crate::completion::ToolDefinition {
1162 name: "get_current_weather".to_owned(),
1163 description: "Get the current weather for a location".to_owned(),
1164 parameters: json!({
1165 "type": "object",
1166 "properties": {
1167 "location": {
1168 "type": "string",
1169 "description": "The location to get the weather for, e.g. San Francisco, CA"
1170 },
1171 "format": {
1172 "type": "string",
1173 "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
1174 "enum": ["celsius", "fahrenheit"]
1175 }
1176 },
1177 "required": ["location", "format"]
1178 }),
1179 };
1180 let ollama_tool: ToolDefinition = internal_tool.into();
1182 assert_eq!(ollama_tool.type_field, "function");
1183 assert_eq!(ollama_tool.function.name, "get_current_weather");
1184 assert_eq!(
1185 ollama_tool.function.description,
1186 "Get the current weather for a location"
1187 );
1188 let params = &ollama_tool.function.parameters;
1190 assert_eq!(params["properties"]["location"]["type"], "string");
1191 }
1192
1193 #[tokio::test]
1195 async fn test_chat_completion_with_thinking() {
1196 let sample_response = json!({
1197 "model": "qwen-thinking",
1198 "created_at": "2023-08-04T19:22:45.499127Z",
1199 "message": {
1200 "role": "assistant",
1201 "content": "The answer is 42.",
1202 "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
1203 "images": null,
1204 "tool_calls": []
1205 },
1206 "done": true,
1207 "total_duration": 8000000000u64,
1208 "load_duration": 6000000u64,
1209 "prompt_eval_count": 61u64,
1210 "prompt_eval_duration": 400000000u64,
1211 "eval_count": 468u64,
1212 "eval_duration": 7700000000u64
1213 });
1214
1215 let chat_resp: CompletionResponse =
1216 serde_json::from_value(sample_response).expect("Failed to deserialize");
1217
1218 if let Message::Assistant {
1220 thinking, content, ..
1221 } = &chat_resp.message
1222 {
1223 assert_eq!(
1224 thinking.as_ref().unwrap(),
1225 "Let me think about this carefully. The question asks for the meaning of life..."
1226 );
1227 assert_eq!(content, "The answer is 42.");
1228 } else {
1229 panic!("Expected Assistant message");
1230 }
1231 }
1232
1233 #[tokio::test]
1235 async fn test_chat_completion_without_thinking() {
1236 let sample_response = json!({
1237 "model": "llama3.2",
1238 "created_at": "2023-08-04T19:22:45.499127Z",
1239 "message": {
1240 "role": "assistant",
1241 "content": "Hello!",
1242 "images": null,
1243 "tool_calls": []
1244 },
1245 "done": true,
1246 "total_duration": 8000000000u64,
1247 "load_duration": 6000000u64,
1248 "prompt_eval_count": 10u64,
1249 "prompt_eval_duration": 400000000u64,
1250 "eval_count": 5u64,
1251 "eval_duration": 7700000000u64
1252 });
1253
1254 let chat_resp: CompletionResponse =
1255 serde_json::from_value(sample_response).expect("Failed to deserialize");
1256
1257 if let Message::Assistant {
1259 thinking, content, ..
1260 } = &chat_resp.message
1261 {
1262 assert!(thinking.is_none());
1263 assert_eq!(content, "Hello!");
1264 } else {
1265 panic!("Expected Assistant message");
1266 }
1267 }
1268
1269 #[test]
1271 fn test_streaming_response_with_thinking() {
1272 let sample_chunk = json!({
1273 "model": "qwen-thinking",
1274 "created_at": "2023-08-04T19:22:45.499127Z",
1275 "message": {
1276 "role": "assistant",
1277 "content": "",
1278 "thinking": "Analyzing the problem...",
1279 "images": null,
1280 "tool_calls": []
1281 },
1282 "done": false
1283 });
1284
1285 let chunk: CompletionResponse =
1286 serde_json::from_value(sample_chunk).expect("Failed to deserialize");
1287
1288 if let Message::Assistant {
1289 thinking, content, ..
1290 } = &chunk.message
1291 {
1292 assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
1293 assert_eq!(content, "");
1294 } else {
1295 panic!("Expected Assistant message");
1296 }
1297 }
1298
1299 #[test]
1301 fn test_message_conversion_with_thinking() {
1302 let reasoning_content = crate::message::Reasoning::new("Step 1: Consider the problem");
1304
1305 let internal_msg = crate::message::Message::Assistant {
1306 id: None,
1307 content: crate::OneOrMany::many(vec![
1308 crate::message::AssistantContent::Reasoning(reasoning_content),
1309 crate::message::AssistantContent::Text(crate::message::Text {
1310 text: "The answer is X".to_string(),
1311 }),
1312 ])
1313 .unwrap(),
1314 };
1315
1316 let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();
1318 assert_eq!(provider_msgs.len(), 1);
1319
1320 if let Message::Assistant {
1321 thinking, content, ..
1322 } = &provider_msgs[0]
1323 {
1324 assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
1325 assert_eq!(content, "The answer is X");
1326 } else {
1327 panic!("Expected Assistant message with thinking");
1328 }
1329 }
1330
1331 #[test]
1333 fn test_empty_thinking_content() {
1334 let sample_response = json!({
1335 "model": "llama3.2",
1336 "created_at": "2023-08-04T19:22:45.499127Z",
1337 "message": {
1338 "role": "assistant",
1339 "content": "Response",
1340 "thinking": "",
1341 "images": null,
1342 "tool_calls": []
1343 },
1344 "done": true,
1345 "total_duration": 8000000000u64,
1346 "load_duration": 6000000u64,
1347 "prompt_eval_count": 10u64,
1348 "prompt_eval_duration": 400000000u64,
1349 "eval_count": 5u64,
1350 "eval_duration": 7700000000u64
1351 });
1352
1353 let chat_resp: CompletionResponse =
1354 serde_json::from_value(sample_response).expect("Failed to deserialize");
1355
1356 if let Message::Assistant {
1357 thinking, content, ..
1358 } = &chat_resp.message
1359 {
1360 assert_eq!(thinking.as_ref().unwrap(), "");
1362 assert_eq!(content, "Response");
1363 } else {
1364 panic!("Expected Assistant message");
1365 }
1366 }
1367
1368 #[test]
1370 fn test_thinking_with_tool_calls() {
1371 let sample_response = json!({
1372 "model": "qwen-thinking",
1373 "created_at": "2023-08-04T19:22:45.499127Z",
1374 "message": {
1375 "role": "assistant",
1376 "content": "Let me check the weather.",
1377 "thinking": "User wants weather info, I should use the weather tool",
1378 "images": null,
1379 "tool_calls": [
1380 {
1381 "type": "function",
1382 "function": {
1383 "name": "get_weather",
1384 "arguments": {
1385 "location": "San Francisco"
1386 }
1387 }
1388 }
1389 ]
1390 },
1391 "done": true,
1392 "total_duration": 8000000000u64,
1393 "load_duration": 6000000u64,
1394 "prompt_eval_count": 30u64,
1395 "prompt_eval_duration": 400000000u64,
1396 "eval_count": 50u64,
1397 "eval_duration": 7700000000u64
1398 });
1399
1400 let chat_resp: CompletionResponse =
1401 serde_json::from_value(sample_response).expect("Failed to deserialize");
1402
1403 if let Message::Assistant {
1404 thinking,
1405 content,
1406 tool_calls,
1407 ..
1408 } = &chat_resp.message
1409 {
1410 assert_eq!(
1411 thinking.as_ref().unwrap(),
1412 "User wants weather info, I should use the weather tool"
1413 );
1414 assert_eq!(content, "Let me check the weather.");
1415 assert_eq!(tool_calls.len(), 1);
1416 assert_eq!(tool_calls[0].function.name, "get_weather");
1417 } else {
1418 panic!("Expected Assistant message with thinking and tool calls");
1419 }
1420 }
1421
1422 #[test]
1424 fn test_completion_request_with_think_param() {
1425 use crate::OneOrMany;
1426 use crate::completion::Message as CompletionMessage;
1427 use crate::message::{Text, UserContent};
1428
1429 let completion_request = CompletionRequest {
1431 model: None,
1432 preamble: Some("You are a helpful assistant.".to_string()),
1433 chat_history: OneOrMany::one(CompletionMessage::User {
1434 content: OneOrMany::one(UserContent::Text(Text {
1435 text: "What is 2 + 2?".to_string(),
1436 })),
1437 }),
1438 documents: vec![],
1439 tools: vec![],
1440 temperature: Some(0.7),
1441 max_tokens: Some(1024),
1442 tool_choice: None,
1443 additional_params: Some(json!({
1444 "think": true,
1445 "keep_alive": "-1m",
1446 "num_ctx": 4096
1447 })),
1448 output_schema: None,
1449 };
1450
1451 let ollama_request = OllamaCompletionRequest::try_from(("qwen3:8b", completion_request))
1453 .expect("Failed to create Ollama request");
1454
1455 let serialized =
1457 serde_json::to_value(&ollama_request).expect("Failed to serialize request");
1458
1459 let expected = json!({
1465 "model": "qwen3:8b",
1466 "messages": [
1467 {
1468 "role": "system",
1469 "content": "You are a helpful assistant."
1470 },
1471 {
1472 "role": "user",
1473 "content": "What is 2 + 2?"
1474 }
1475 ],
1476 "temperature": 0.7,
1477 "stream": false,
1478 "think": true,
1479 "max_tokens": 1024,
1480 "keep_alive": "-1m",
1481 "options": {
1482 "temperature": 0.7,
1483 "num_ctx": 4096
1484 }
1485 });
1486
1487 assert_eq!(serialized, expected);
1488 }
1489
1490 #[test]
1492 fn test_completion_request_with_think_false_default() {
1493 use crate::OneOrMany;
1494 use crate::completion::Message as CompletionMessage;
1495 use crate::message::{Text, UserContent};
1496
1497 let completion_request = CompletionRequest {
1499 model: None,
1500 preamble: Some("You are a helpful assistant.".to_string()),
1501 chat_history: OneOrMany::one(CompletionMessage::User {
1502 content: OneOrMany::one(UserContent::Text(Text {
1503 text: "Hello!".to_string(),
1504 })),
1505 }),
1506 documents: vec![],
1507 tools: vec![],
1508 temperature: Some(0.5),
1509 max_tokens: None,
1510 tool_choice: None,
1511 additional_params: None,
1512 output_schema: None,
1513 };
1514
1515 let ollama_request = OllamaCompletionRequest::try_from(("llama3.2", completion_request))
1517 .expect("Failed to create Ollama request");
1518
1519 let serialized =
1521 serde_json::to_value(&ollama_request).expect("Failed to serialize request");
1522
1523 let expected = json!({
1525 "model": "llama3.2",
1526 "messages": [
1527 {
1528 "role": "system",
1529 "content": "You are a helpful assistant."
1530 },
1531 {
1532 "role": "user",
1533 "content": "Hello!"
1534 }
1535 ],
1536 "temperature": 0.5,
1537 "stream": false,
1538 "think": false,
1539 "options": {
1540 "temperature": 0.5
1541 }
1542 });
1543
1544 assert_eq!(serialized, expected);
1545 }
1546
1547 #[test]
1548 fn test_completion_request_with_output_schema() {
1549 use crate::OneOrMany;
1550 use crate::completion::Message as CompletionMessage;
1551 use crate::message::{Text, UserContent};
1552
1553 let schema: schemars::Schema = serde_json::from_value(json!({
1554 "type": "object",
1555 "properties": {
1556 "age": { "type": "integer" },
1557 "available": { "type": "boolean" }
1558 },
1559 "required": ["age", "available"]
1560 }))
1561 .expect("Failed to parse schema");
1562
1563 let completion_request = CompletionRequest {
1564 model: Some("llama3.1".to_string()),
1565 preamble: None,
1566 chat_history: OneOrMany::one(CompletionMessage::User {
1567 content: OneOrMany::one(UserContent::Text(Text {
1568 text: "How old is Ollama?".to_string(),
1569 })),
1570 }),
1571 documents: vec![],
1572 tools: vec![],
1573 temperature: None,
1574 max_tokens: None,
1575 tool_choice: None,
1576 additional_params: None,
1577 output_schema: Some(schema),
1578 };
1579
1580 let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
1581 .expect("Failed to create Ollama request");
1582
1583 let serialized =
1584 serde_json::to_value(&ollama_request).expect("Failed to serialize request");
1585
1586 let format = serialized
1587 .get("format")
1588 .expect("format field should be present");
1589 assert_eq!(
1590 *format,
1591 json!({
1592 "type": "object",
1593 "properties": {
1594 "age": { "type": "integer" },
1595 "available": { "type": "boolean" }
1596 },
1597 "required": ["age", "available"]
1598 })
1599 );
1600 }
1601
1602 #[test]
1603 fn test_completion_request_without_output_schema() {
1604 use crate::OneOrMany;
1605 use crate::completion::Message as CompletionMessage;
1606 use crate::message::{Text, UserContent};
1607
1608 let completion_request = CompletionRequest {
1609 model: Some("llama3.1".to_string()),
1610 preamble: None,
1611 chat_history: OneOrMany::one(CompletionMessage::User {
1612 content: OneOrMany::one(UserContent::Text(Text {
1613 text: "Hello!".to_string(),
1614 })),
1615 }),
1616 documents: vec![],
1617 tools: vec![],
1618 temperature: None,
1619 max_tokens: None,
1620 tool_choice: None,
1621 additional_params: None,
1622 output_schema: None,
1623 };
1624
1625 let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
1626 .expect("Failed to create Ollama request");
1627
1628 let serialized =
1629 serde_json::to_value(&ollama_request).expect("Failed to serialize request");
1630
1631 assert!(
1632 serialized.get("format").is_none(),
1633 "format field should be absent when output_schema is None"
1634 );
1635 }
1636
1637 #[test]
1638 fn test_client_initialization() {
1639 let _client = crate::providers::ollama::Client::new(Nothing).expect("Client::new() failed");
1640 let _client_from_builder = crate::providers::ollama::Client::builder()
1641 .api_key(Nothing)
1642 .build()
1643 .expect("Client::builder() failed");
1644 }
1645}