1use crate::client::{
35 self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
36};
37use crate::completion::{GetTokenUsage, Usage};
38use crate::http_client::{self, HttpClientExt};
39use crate::message::DocumentSourceKind;
40use crate::streaming::RawStreamingChoice;
41use crate::{
42 OneOrMany,
43 completion::{self, CompletionError, CompletionRequest},
44 embeddings::{self, EmbeddingError},
45 json_utils, message,
46 message::{ImageDetail, Text},
47 streaming,
48};
49use async_stream::try_stream;
50use bytes::Bytes;
51use futures::StreamExt;
52use serde::{Deserialize, Serialize};
53use serde_json::{Value, json};
54use std::{convert::TryFrom, str::FromStr};
55use tracing::info_span;
56use tracing_futures::Instrument;
57const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
60
/// Marker type implementing the Ollama provider extension.
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;
63
/// Marker type implementing the Ollama provider builder.
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaBuilder;
66
impl Provider for OllamaExt {
    type Builder = OllamaBuilder;
    // Endpoint used to verify the server is reachable (Ollama's model list).
    const VERIFY_PATH: &'static str = "api/tags";
}
71
impl<H> Capabilities<H> for OllamaExt {
    // Ollama supports chat completions and embeddings; every other
    // capability is marked unsupported via `Nothing`.
    type Completion = Capable<CompletionModel<H>>;
    type Transcription = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
83
// Default `DebugExt` behavior suffices for this provider.
impl DebugExt for OllamaExt {}
85
impl ProviderBuilder for OllamaBuilder {
    // The extension type is the same regardless of the HTTP client used.
    type Extension<H>
        = OllamaExt
    where
        H: HttpClientExt;
    // Ollama requires no API key.
    type ApiKey = Nothing;

    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;

    /// Produces the stateless Ollama extension; building cannot fail.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(OllamaExt)
    }
}
104
/// Ollama client specialized over an HTTP implementation (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
/// Builder for [`Client`]; Ollama takes no API key, hence `Nothing`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, Nothing, H>;
107
108impl ProviderClient for Client {
109 type Input = Nothing;
110
111 fn from_env() -> Self {
112 let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
113
114 Self::builder()
115 .api_key(Nothing)
116 .base_url(&api_base)
117 .build()
118 .unwrap()
119 }
120
121 fn from_val(_: Self::Input) -> Self {
122 Self::builder().api_key(Nothing).build().unwrap()
123 }
124}
125
/// Error payload returned by the Ollama API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
132
/// Either a successful payload or an Ollama error, distinguished by shape
/// (untagged deserialization tries `Ok` first).
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
139
/// Well-known Ollama embedding model identifiers.
pub const ALL_MINILM: &str = "all-minilm";
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

/// Returns the embedding dimensionality for well-known Ollama embedding
/// models, or `None` for unrecognized identifiers.
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    let dims = match identifier {
        ALL_MINILM => 384,
        NOMIC_EMBED_TEXT => 768,
        _ => return None,
    };
    Some(dims)
}
152
/// Response from Ollama's `api/embed` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    /// One embedding vector per input document, in input order.
    pub embeddings: Vec<Vec<f64>>,
    // Timing/usage metadata; absent fields default to `None`.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
164
165impl From<ApiErrorResponse> for EmbeddingError {
166 fn from(err: ApiErrorResponse) -> Self {
167 EmbeddingError::ProviderError(err.message)
168 }
169}
170
171impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
172 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
173 match value {
174 ApiResponse::Ok(response) => Ok(response),
175 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
176 }
177 }
178}
179
/// Ollama embedding model handle.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    // Expected output dimensionality; 0 when unknown.
    ndims: usize,
}
188
189impl<T> EmbeddingModel<T> {
190 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
191 Self {
192 client,
193 model: model.into(),
194 ndims,
195 }
196 }
197
198 pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
199 Self {
200 client,
201 model: model.into(),
202 ndims,
203 }
204 }
205}
206
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Client = Client<T>;

    /// Constructs the model, inferring dimensions from well-known model names
    /// when the caller does not supply them (falls back to 0 if unknown).
    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        let model = model.into();
        let dims = dims
            .or(model_dimensions_from_identifier(&model))
            .unwrap_or_default();
        Self::new(client.clone(), model, dims)
    }

    const MAX_DOCUMENTS: usize = 1024;
    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embeds `documents` via Ollama's `api/embed` endpoint in one batch call.
    ///
    /// # Errors
    /// Returns an error when serialization or the HTTP request fails, the
    /// server responds with a non-success status, or the number of returned
    /// embeddings does not match the number of input documents.
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let docs: Vec<String> = documents.into_iter().collect();

        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": docs
        }))?;

        let req = self
            .client
            .post("api/embed")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        // Surface the raw body as a provider error on non-2xx responses.
        if !response.status().is_success() {
            let text = http_client::text(response).await?;
            return Err(EmbeddingError::ProviderError(text));
        }

        let bytes: Vec<u8> = response.into_body().await?;

        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;

        // The API must return exactly one embedding per input document.
        if api_resp.embeddings.len() != docs.len() {
            return Err(EmbeddingError::ResponseError(
                "Number of returned embeddings does not match input".into(),
            ));
        }
        Ok(api_resp
            .embeddings
            .into_iter()
            .zip(docs.into_iter())
            .map(|(vec, document)| embeddings::Embedding { document, vec })
            .collect())
    }
}
267
/// Well-known Ollama chat model identifiers.
pub const LLAMA3_2: &str = "llama3.2";
pub const LLAVA: &str = "llava";
pub const MISTRAL: &str = "mistral";
273
/// Response from Ollama's `api/chat` endpoint (one such object per NDJSON
/// line when streaming).
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    /// `true` on the final (or only) response for a request.
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    // Timing/usage counters; typically absent on intermediate streaming chunks.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;
    /// Converts a raw Ollama chat response into the provider-agnostic
    /// completion response, extracting text and tool calls from the
    /// assistant message.
    ///
    /// # Errors
    /// Fails when the response carries no assistant message, or when the
    /// assistant message has neither text content nor tool calls.
    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                // Ollama returns no tool-call ids, so the function name
                // doubles as the call id.
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                // `many` errors on an empty vec: no text and no tool calls.
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                // Reassemble the full raw response, re-wrapping the
                // destructured assistant message.
                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                        cached_input_tokens: 0,
                    },
                    raw_response,
                    message_id: None,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}
364
/// Request body for Ollama's `api/chat` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OllamaCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    pub stream: bool,
    /// Enables "thinking" output on models that support it.
    think: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    /// Keep-alive duration string forwarded verbatim to Ollama.
    #[serde(skip_serializing_if = "Option::is_none")]
    keep_alive: Option<String>,
    /// JSON schema requesting structured output.
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<schemars::Schema>,
    /// Generation options: temperature merged with passthrough extras.
    options: serde_json::Value,
}
383
impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
    type Error = CompletionError;

    /// Builds the Ollama wire request from a generic completion request,
    /// using `model` as a fallback when the request names no model.
    ///
    /// Recognizes the Ollama-specific `think` and `keep_alive` keys inside
    /// `additional_params`; remaining extras are forwarded under `options`.
    ///
    /// # Errors
    /// Fails when message conversion fails or when `think`/`keep_alive`
    /// carry the wrong JSON type.
    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        let model = req.model.clone().unwrap_or_else(|| model.to_string());
        if req.tool_choice.is_some() {
            tracing::warn!("WARNING: `tool_choice` not supported for Ollama");
        }
        // Documents (if any) are prepended ahead of the chat history.
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        // The preamble becomes a leading system message.
        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        // One internal message may expand into several Ollama messages.
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let mut think = false;
        let mut keep_alive: Option<String> = None;

        let options = if let Some(mut extra) = req.additional_params {
            if let Some(obj) = extra.as_object_mut() {
                // `think` and `keep_alive` are top-level request fields, not
                // generation options, so pull them out of the extras.
                if let Some(think_val) = obj.remove("think") {
                    think = think_val.as_bool().ok_or_else(|| {
                        CompletionError::RequestError("`think` must be a bool".into())
                    })?;
                }

                if let Some(keep_alive_val) = obj.remove("keep_alive") {
                    keep_alive = Some(
                        keep_alive_val
                            .as_str()
                            .ok_or_else(|| {
                                CompletionError::RequestError(
                                    "`keep_alive` must be a string".into(),
                                )
                            })?
                            .to_string(),
                    );
                }
            }

            // NOTE(review): temperature ends up both at the top level and
            // inside `options` — confirm the duplicate is intentional.
            json_utils::merge(json!({ "temperature": req.temperature }), extra)
        } else {
            json!({ "temperature": req.temperature })
        };

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            max_tokens: req.max_tokens,
            stream: false,
            think,
            keep_alive,
            format: req.output_schema,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            options,
        })
    }
}
468
/// Ollama chat completion model handle.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}
474
475impl<T> CompletionModel<T> {
476 pub fn new(client: Client<T>, model: &str) -> Self {
477 Self {
478 client,
479 model: model.to_owned(),
480 }
481 }
482}
483
/// Final metadata chunk of an Ollama streaming completion (fields are taken
/// from the terminating `done: true` NDJSON object).
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
496
497impl GetTokenUsage for StreamingCompletionResponse {
498 fn token_usage(&self) -> Option<crate::completion::Usage> {
499 let mut usage = crate::completion::Usage::new();
500 let input_tokens = self.prompt_eval_count.unwrap_or_default();
501 let output_tokens = self.eval_count.unwrap_or_default();
502 usage.input_tokens = input_tokens;
503 usage.output_tokens = output_tokens;
504 usage.total_tokens = input_tokens + output_tokens;
505
506 Some(usage)
507 }
508}
509
510impl<T> completion::CompletionModel for CompletionModel<T>
511where
512 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
513{
514 type Response = CompletionResponse;
515 type StreamingResponse = StreamingCompletionResponse;
516
517 type Client = Client<T>;
518
519 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
520 Self::new(client.clone(), model.into().as_str())
521 }
522
523 async fn completion(
524 &self,
525 completion_request: CompletionRequest,
526 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
527 let span = if tracing::Span::current().is_disabled() {
528 info_span!(
529 target: "rig::completions",
530 "chat",
531 gen_ai.operation.name = "chat",
532 gen_ai.provider.name = "ollama",
533 gen_ai.request.model = self.model,
534 gen_ai.system_instructions = tracing::field::Empty,
535 gen_ai.response.id = tracing::field::Empty,
536 gen_ai.response.model = tracing::field::Empty,
537 gen_ai.usage.output_tokens = tracing::field::Empty,
538 gen_ai.usage.input_tokens = tracing::field::Empty,
539 gen_ai.usage.cached_tokens = tracing::field::Empty,
540 )
541 } else {
542 tracing::Span::current()
543 };
544
545 span.record("gen_ai.system_instructions", &completion_request.preamble);
546 let request = OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
547
548 if tracing::enabled!(tracing::Level::TRACE) {
549 tracing::trace!(target: "rig::completions",
550 "Ollama completion request: {}",
551 serde_json::to_string_pretty(&request)?
552 );
553 }
554
555 let body = serde_json::to_vec(&request)?;
556
557 let req = self
558 .client
559 .post("api/chat")?
560 .body(body)
561 .map_err(http_client::Error::from)?;
562
563 let async_block = async move {
564 let response = self.client.send::<_, Bytes>(req).await?;
565 let status = response.status();
566 let response_body = response.into_body().into_future().await?.to_vec();
567
568 if !status.is_success() {
569 return Err(CompletionError::ProviderError(
570 String::from_utf8_lossy(&response_body).to_string(),
571 ));
572 }
573
574 let response: CompletionResponse = serde_json::from_slice(&response_body)?;
575 let span = tracing::Span::current();
576 span.record("gen_ai.response.model_name", &response.model);
577 span.record(
578 "gen_ai.usage.input_tokens",
579 response.prompt_eval_count.unwrap_or_default(),
580 );
581 span.record(
582 "gen_ai.usage.output_tokens",
583 response.eval_count.unwrap_or_default(),
584 );
585
586 if tracing::enabled!(tracing::Level::TRACE) {
587 tracing::trace!(target: "rig::completions",
588 "Ollama completion response: {}",
589 serde_json::to_string_pretty(&response)?
590 );
591 }
592
593 let response: completion::CompletionResponse<CompletionResponse> =
594 response.try_into()?;
595
596 Ok(response)
597 };
598
599 tracing::Instrument::instrument(async_block, span).await
600 }
601
602 async fn stream(
603 &self,
604 request: CompletionRequest,
605 ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
606 {
607 let span = if tracing::Span::current().is_disabled() {
608 info_span!(
609 target: "rig::completions",
610 "chat_streaming",
611 gen_ai.operation.name = "chat_streaming",
612 gen_ai.provider.name = "ollama",
613 gen_ai.request.model = self.model,
614 gen_ai.system_instructions = tracing::field::Empty,
615 gen_ai.response.id = tracing::field::Empty,
616 gen_ai.response.model = self.model,
617 gen_ai.usage.output_tokens = tracing::field::Empty,
618 gen_ai.usage.input_tokens = tracing::field::Empty,
619 gen_ai.usage.cached_tokens = tracing::field::Empty,
620 )
621 } else {
622 tracing::Span::current()
623 };
624
625 span.record("gen_ai.system_instructions", &request.preamble);
626
627 let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
628 request.stream = true;
629
630 if tracing::enabled!(tracing::Level::TRACE) {
631 tracing::trace!(target: "rig::completions",
632 "Ollama streaming completion request: {}",
633 serde_json::to_string_pretty(&request)?
634 );
635 }
636
637 let body = serde_json::to_vec(&request)?;
638
639 let req = self
640 .client
641 .post("api/chat")?
642 .body(body)
643 .map_err(http_client::Error::from)?;
644
645 let response = self.client.send_streaming(req).await?;
646 let status = response.status();
647 let mut byte_stream = response.into_body();
648
649 if !status.is_success() {
650 return Err(CompletionError::ProviderError(format!(
651 "Got error status code trying to send a request to Ollama: {status}"
652 )));
653 }
654
655 let stream = try_stream! {
656 let span = tracing::Span::current();
657 let mut tool_calls_final = Vec::new();
658 let mut text_response = String::new();
659 let mut thinking_response = String::new();
660
661 while let Some(chunk) = byte_stream.next().await {
662 let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;
663
664 for line in bytes.split(|&b| b == b'\n') {
665 if line.is_empty() {
666 continue;
667 }
668
669 tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));
670
671 let response: CompletionResponse = serde_json::from_slice(line)?;
672
673 if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
674 if let Some(thinking_content) = thinking && !thinking_content.is_empty() {
675 thinking_response += &thinking_content;
676 yield RawStreamingChoice::ReasoningDelta {
677 id: None,
678 reasoning: thinking_content,
679 };
680 }
681
682 if !content.is_empty() {
683 text_response += &content;
684 yield RawStreamingChoice::Message(content);
685 }
686
687 for tool_call in tool_calls {
688 tool_calls_final.push(tool_call.clone());
689 yield RawStreamingChoice::ToolCall(
690 crate::streaming::RawStreamingToolCall::new(String::new(), tool_call.function.name, tool_call.function.arguments)
691 );
692 }
693 }
694
695 if response.done {
696 span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
697 span.record("gen_ai.usage.output_tokens", response.eval_count);
698 let message = Message::Assistant {
699 content: text_response.clone(),
700 thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
701 images: None,
702 name: None,
703 tool_calls: tool_calls_final.clone()
704 };
705 span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
706 yield RawStreamingChoice::FinalResponse(
707 StreamingCompletionResponse {
708 total_duration: response.total_duration,
709 load_duration: response.load_duration,
710 prompt_eval_count: response.prompt_eval_count,
711 prompt_eval_duration: response.prompt_eval_duration,
712 eval_count: response.eval_count,
713 eval_duration: response.eval_duration,
714 done_reason: response.done_reason,
715 }
716 );
717 break;
718 }
719 }
720 }
721 }.instrument(span);
722
723 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
724 stream,
725 )))
726 }
727}
728
729#[derive(Clone, Debug, Deserialize, Serialize)]
733pub struct ToolDefinition {
734 #[serde(rename = "type")]
735 pub type_field: String, pub function: completion::ToolDefinition,
737}
738
739impl From<crate::completion::ToolDefinition> for ToolDefinition {
741 fn from(tool: crate::completion::ToolDefinition) -> Self {
742 ToolDefinition {
743 type_field: "function".to_owned(),
744 function: completion::ToolDefinition {
745 name: tool.name,
746 description: tool.description,
747 parameters: tool.parameters,
748 },
749 }
750 }
751}
752
/// A tool invocation as exchanged with Ollama.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Tool call kind; only `function` exists in the Ollama API.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// Name and JSON arguments of a called function.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
770
/// Chat message in Ollama's wire format, tagged by `role`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        /// Base64-encoded images attached to the message.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        #[serde(default)]
        content: String,
        /// Reasoning text emitted by "thinking" models.
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Ollama may send `null` for tool_calls; treat it as an empty list.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Result of a tool invocation, sent back with role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
809
impl TryFrom<crate::message::Message> for Vec<Message> {
    type Error = crate::message::MessageError;
    /// Expands one internal message into one or more Ollama wire messages.
    ///
    /// Tool results become `role: "tool"` messages; user text/images collapse
    /// into a single user message; assistant content splits into text,
    /// thinking, and tool calls.
    ///
    /// # Errors
    /// Fails when assistant content contains an image, which Ollama does not
    /// support on assistant messages.
    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;
        match internal_msg {
            InternalMessage::System { content } => Ok(vec![Message::System {
                content,
                images: None,
                name: None,
            }]),
            InternalMessage::User { content, .. } => {
                let (tool_results, other_content): (Vec<_>, Vec<_>) =
                    content.into_iter().partition(|content| {
                        matches!(content, crate::message::UserContent::ToolResult(_))
                    });

                // NOTE(review): when tool results are present, any other user
                // content in the same message is discarded — confirm intended.
                if !tool_results.is_empty() {
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            crate::message::UserContent::ToolResult(
                                crate::message::ToolResult { id, content, .. },
                            ) => {
                                // Flatten multi-part results to newline-joined text.
                                let content_string = content
                                    .into_iter()
                                    .map(|content| match content {
                                        crate::message::ToolResultContent::Text(text) => text.text,
                                        _ => "[Non-text content]".to_string(),
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n");

                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
                                    name: id,
                                    content: content_string,
                                })
                            }
                            // Partition guarantees only ToolResult values here.
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    // Collect text (including inline documents) and base64 images.
                    let (texts, images) = other_content.into_iter().fold(
                        (Vec::new(), Vec::new()),
                        |(mut texts, mut images), content| {
                            match content {
                                crate::message::UserContent::Text(crate::message::Text {
                                    text,
                                }) => texts.push(text),
                                crate::message::UserContent::Image(crate::message::Image {
                                    data: DocumentSourceKind::Base64(data),
                                    ..
                                }) => images.push(data),
                                crate::message::UserContent::Document(
                                    crate::message::Document {
                                        data:
                                            DocumentSourceKind::Base64(data)
                                            | DocumentSourceKind::String(data),
                                        ..
                                    },
                                ) => texts.push(data),
                                // Any other content kind is silently dropped.
                                _ => {} }
                            (texts, images)
                        },
                    );

                    Ok(vec![Message::User {
                        content: texts.join(" "),
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(
                                images
                                    .into_iter()
                                    .map(|x| x.to_string())
                                    .collect::<Vec<String>>(),
                            )
                        },
                        name: None,
                    }])
                }
            }
            InternalMessage::Assistant { content, .. } => {
                let mut thinking: Option<String> = None;
                let mut text_content = Vec::new();
                let mut tool_calls = Vec::new();

                for content in content.into_iter() {
                    match content {
                        crate::message::AssistantContent::Text(text) => {
                            text_content.push(text.text)
                        }
                        crate::message::AssistantContent::ToolCall(tool_call) => {
                            tool_calls.push(tool_call)
                        }
                        crate::message::AssistantContent::Reasoning(reasoning) => {
                            // NOTE(review): a later non-empty reasoning block
                            // overwrites an earlier one; only the last is kept.
                            let display = reasoning.display_text();
                            if !display.is_empty() {
                                thinking = Some(display);
                            }
                        }
                        crate::message::AssistantContent::Image(_) => {
                            return Err(crate::message::MessageError::ConversionError(
                                "Ollama currently doesn't support images.".into(),
                            ));
                        }
                    }
                }

                Ok(vec![Message::Assistant {
                    content: text_content.join(" "),
                    thinking,
                    images: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}
942
impl From<Message> for crate::completion::Message {
    /// Maps an Ollama wire message back to the internal completion message.
    ///
    /// System messages and tool results are both represented as user
    /// messages on the internal side.
    fn from(msg: Message) -> Self {
        match msg {
            Message::User { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut assistant_contents =
                    vec![crate::completion::message::AssistantContent::Text(Text {
                        text: content,
                    })];
                // Ollama has no tool-call ids; the function name doubles as id.
                for tc in tool_calls {
                    assistant_contents.push(
                        crate::completion::message::AssistantContent::tool_call(
                            tc.function.name.clone(),
                            tc.function.name,
                            tc.function.arguments,
                        ),
                    );
                }
                crate::completion::Message::Assistant {
                    id: None,
                    // Never empty: always contains at least the text element.
                    content: OneOrMany::many(assistant_contents).unwrap(),
                }
            }
            Message::System { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::ToolResult { name, content } => crate::completion::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    name,
                    OneOrMany::one(message::ToolResultContent::text(content)),
                )),
            },
        }
    }
}
991
992impl Message {
993 pub fn system(content: &str) -> Self {
995 Message::System {
996 content: content.to_owned(),
997 images: None,
998 name: None,
999 }
1000 }
1001}
1002
1003impl From<crate::message::ToolCall> for ToolCall {
1006 fn from(tool_call: crate::message::ToolCall) -> Self {
1007 Self {
1008 r#type: ToolType::Function,
1009 function: Function {
1010 name: tool_call.function.name,
1011 arguments: tool_call.function.arguments,
1012 },
1013 }
1014 }
1015}
1016
/// Typed system content part (always text today).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}
1023
/// Kind tag for [`SystemContent`]; only `text` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
1030
1031impl From<String> for SystemContent {
1032 fn from(s: String) -> Self {
1033 SystemContent {
1034 r#type: SystemContentType::default(),
1035 text: s,
1036 }
1037 }
1038}
1039
1040impl FromStr for SystemContent {
1041 type Err = std::convert::Infallible;
1042 fn from_str(s: &str) -> Result<Self, Self::Err> {
1043 Ok(SystemContent {
1044 r#type: SystemContentType::default(),
1045 text: s.to_string(),
1046 })
1047 }
1048}
1049
/// Plain assistant text content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}
1054
1055impl FromStr for AssistantContent {
1056 type Err = std::convert::Infallible;
1057 fn from_str(s: &str) -> Result<Self, Self::Err> {
1058 Ok(AssistantContent { text: s.to_owned() })
1059 }
1060}
1061
/// User content part: plain text or an image URL, tagged by `type`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
}
1069
1070impl FromStr for UserContent {
1071 type Err = std::convert::Infallible;
1072 fn from_str(s: &str) -> Result<Self, Self::Err> {
1073 Ok(UserContent::Text { text: s.to_owned() })
1074 }
1075}
1076
/// Image reference with an optional detail hint.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}
1083
1084#[cfg(test)]
1089mod tests {
1090 use super::*;
1091 use serde_json::json;
1092
    // Verifies that a full chat response (including a tool call) deserializes
    // and converts into a non-empty provider-agnostic completion response.
    #[tokio::test]
    async fn test_chat_completion() {
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }
1136
    // Checks that a provider `Message::User` converts into the internal
    // completion message type with the text preserved.
    #[test]
    fn test_message_conversion() {
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                let first_content = content.first();
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }
1163
    // Checks that an internal tool definition gains the `"function"` type tag
    // and keeps its name, description, and parameter schema.
    #[test]
    fn test_tool_definition_conversion() {
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }
1199
    // Verifies the optional `thinking` field deserializes alongside regular
    // assistant content.
    #[tokio::test]
    async fn test_chat_completion_with_thinking() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The answer is 42.",
                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "Let me think about this carefully. The question asks for the meaning of life..."
            );
            assert_eq!(content, "The answer is 42.");
        } else {
            panic!("Expected Assistant message");
        }
    }
1239
1240 #[tokio::test]
1242 async fn test_chat_completion_without_thinking() {
1243 let sample_response = json!({
1244 "model": "llama3.2",
1245 "created_at": "2023-08-04T19:22:45.499127Z",
1246 "message": {
1247 "role": "assistant",
1248 "content": "Hello!",
1249 "images": null,
1250 "tool_calls": []
1251 },
1252 "done": true,
1253 "total_duration": 8000000000u64,
1254 "load_duration": 6000000u64,
1255 "prompt_eval_count": 10u64,
1256 "prompt_eval_duration": 400000000u64,
1257 "eval_count": 5u64,
1258 "eval_duration": 7700000000u64
1259 });
1260
1261 let chat_resp: CompletionResponse =
1262 serde_json::from_value(sample_response).expect("Failed to deserialize");
1263
1264 if let Message::Assistant {
1266 thinking, content, ..
1267 } = &chat_resp.message
1268 {
1269 assert!(thinking.is_none());
1270 assert_eq!(content, "Hello!");
1271 } else {
1272 panic!("Expected Assistant message");
1273 }
1274 }
1275
1276 #[test]
1278 fn test_streaming_response_with_thinking() {
1279 let sample_chunk = json!({
1280 "model": "qwen-thinking",
1281 "created_at": "2023-08-04T19:22:45.499127Z",
1282 "message": {
1283 "role": "assistant",
1284 "content": "",
1285 "thinking": "Analyzing the problem...",
1286 "images": null,
1287 "tool_calls": []
1288 },
1289 "done": false
1290 });
1291
1292 let chunk: CompletionResponse =
1293 serde_json::from_value(sample_chunk).expect("Failed to deserialize");
1294
1295 if let Message::Assistant {
1296 thinking, content, ..
1297 } = &chunk.message
1298 {
1299 assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
1300 assert_eq!(content, "");
1301 } else {
1302 panic!("Expected Assistant message");
1303 }
1304 }
1305
1306 #[test]
1308 fn test_message_conversion_with_thinking() {
1309 let reasoning_content = crate::message::Reasoning::new("Step 1: Consider the problem");
1311
1312 let internal_msg = crate::message::Message::Assistant {
1313 id: None,
1314 content: crate::OneOrMany::many(vec![
1315 crate::message::AssistantContent::Reasoning(reasoning_content),
1316 crate::message::AssistantContent::Text(crate::message::Text {
1317 text: "The answer is X".to_string(),
1318 }),
1319 ])
1320 .unwrap(),
1321 };
1322
1323 let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();
1325 assert_eq!(provider_msgs.len(), 1);
1326
1327 if let Message::Assistant {
1328 thinking, content, ..
1329 } = &provider_msgs[0]
1330 {
1331 assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
1332 assert_eq!(content, "The answer is X");
1333 } else {
1334 panic!("Expected Assistant message with thinking");
1335 }
1336 }
1337
1338 #[test]
1340 fn test_empty_thinking_content() {
1341 let sample_response = json!({
1342 "model": "llama3.2",
1343 "created_at": "2023-08-04T19:22:45.499127Z",
1344 "message": {
1345 "role": "assistant",
1346 "content": "Response",
1347 "thinking": "",
1348 "images": null,
1349 "tool_calls": []
1350 },
1351 "done": true,
1352 "total_duration": 8000000000u64,
1353 "load_duration": 6000000u64,
1354 "prompt_eval_count": 10u64,
1355 "prompt_eval_duration": 400000000u64,
1356 "eval_count": 5u64,
1357 "eval_duration": 7700000000u64
1358 });
1359
1360 let chat_resp: CompletionResponse =
1361 serde_json::from_value(sample_response).expect("Failed to deserialize");
1362
1363 if let Message::Assistant {
1364 thinking, content, ..
1365 } = &chat_resp.message
1366 {
1367 assert_eq!(thinking.as_ref().unwrap(), "");
1369 assert_eq!(content, "Response");
1370 } else {
1371 panic!("Expected Assistant message");
1372 }
1373 }
1374
1375 #[test]
1377 fn test_thinking_with_tool_calls() {
1378 let sample_response = json!({
1379 "model": "qwen-thinking",
1380 "created_at": "2023-08-04T19:22:45.499127Z",
1381 "message": {
1382 "role": "assistant",
1383 "content": "Let me check the weather.",
1384 "thinking": "User wants weather info, I should use the weather tool",
1385 "images": null,
1386 "tool_calls": [
1387 {
1388 "type": "function",
1389 "function": {
1390 "name": "get_weather",
1391 "arguments": {
1392 "location": "San Francisco"
1393 }
1394 }
1395 }
1396 ]
1397 },
1398 "done": true,
1399 "total_duration": 8000000000u64,
1400 "load_duration": 6000000u64,
1401 "prompt_eval_count": 30u64,
1402 "prompt_eval_duration": 400000000u64,
1403 "eval_count": 50u64,
1404 "eval_duration": 7700000000u64
1405 });
1406
1407 let chat_resp: CompletionResponse =
1408 serde_json::from_value(sample_response).expect("Failed to deserialize");
1409
1410 if let Message::Assistant {
1411 thinking,
1412 content,
1413 tool_calls,
1414 ..
1415 } = &chat_resp.message
1416 {
1417 assert_eq!(
1418 thinking.as_ref().unwrap(),
1419 "User wants weather info, I should use the weather tool"
1420 );
1421 assert_eq!(content, "Let me check the weather.");
1422 assert_eq!(tool_calls.len(), 1);
1423 assert_eq!(tool_calls[0].function.name, "get_weather");
1424 } else {
1425 panic!("Expected Assistant message with thinking and tool calls");
1426 }
1427 }
1428
1429 #[test]
1431 fn test_completion_request_with_think_param() {
1432 use crate::OneOrMany;
1433 use crate::completion::Message as CompletionMessage;
1434 use crate::message::{Text, UserContent};
1435
1436 let completion_request = CompletionRequest {
1438 model: None,
1439 preamble: Some("You are a helpful assistant.".to_string()),
1440 chat_history: OneOrMany::one(CompletionMessage::User {
1441 content: OneOrMany::one(UserContent::Text(Text {
1442 text: "What is 2 + 2?".to_string(),
1443 })),
1444 }),
1445 documents: vec![],
1446 tools: vec![],
1447 temperature: Some(0.7),
1448 max_tokens: Some(1024),
1449 tool_choice: None,
1450 additional_params: Some(json!({
1451 "think": true,
1452 "keep_alive": "-1m",
1453 "num_ctx": 4096
1454 })),
1455 output_schema: None,
1456 };
1457
1458 let ollama_request = OllamaCompletionRequest::try_from(("qwen3:8b", completion_request))
1460 .expect("Failed to create Ollama request");
1461
1462 let serialized =
1464 serde_json::to_value(&ollama_request).expect("Failed to serialize request");
1465
1466 let expected = json!({
1472 "model": "qwen3:8b",
1473 "messages": [
1474 {
1475 "role": "system",
1476 "content": "You are a helpful assistant."
1477 },
1478 {
1479 "role": "user",
1480 "content": "What is 2 + 2?"
1481 }
1482 ],
1483 "temperature": 0.7,
1484 "stream": false,
1485 "think": true,
1486 "max_tokens": 1024,
1487 "keep_alive": "-1m",
1488 "options": {
1489 "temperature": 0.7,
1490 "num_ctx": 4096
1491 }
1492 });
1493
1494 assert_eq!(serialized, expected);
1495 }
1496
1497 #[test]
1499 fn test_completion_request_with_think_false_default() {
1500 use crate::OneOrMany;
1501 use crate::completion::Message as CompletionMessage;
1502 use crate::message::{Text, UserContent};
1503
1504 let completion_request = CompletionRequest {
1506 model: None,
1507 preamble: Some("You are a helpful assistant.".to_string()),
1508 chat_history: OneOrMany::one(CompletionMessage::User {
1509 content: OneOrMany::one(UserContent::Text(Text {
1510 text: "Hello!".to_string(),
1511 })),
1512 }),
1513 documents: vec![],
1514 tools: vec![],
1515 temperature: Some(0.5),
1516 max_tokens: None,
1517 tool_choice: None,
1518 additional_params: None,
1519 output_schema: None,
1520 };
1521
1522 let ollama_request = OllamaCompletionRequest::try_from(("llama3.2", completion_request))
1524 .expect("Failed to create Ollama request");
1525
1526 let serialized =
1528 serde_json::to_value(&ollama_request).expect("Failed to serialize request");
1529
1530 let expected = json!({
1532 "model": "llama3.2",
1533 "messages": [
1534 {
1535 "role": "system",
1536 "content": "You are a helpful assistant."
1537 },
1538 {
1539 "role": "user",
1540 "content": "Hello!"
1541 }
1542 ],
1543 "temperature": 0.5,
1544 "stream": false,
1545 "think": false,
1546 "options": {
1547 "temperature": 0.5
1548 }
1549 });
1550
1551 assert_eq!(serialized, expected);
1552 }
1553
1554 #[test]
1555 fn test_completion_request_with_output_schema() {
1556 use crate::OneOrMany;
1557 use crate::completion::Message as CompletionMessage;
1558 use crate::message::{Text, UserContent};
1559
1560 let schema: schemars::Schema = serde_json::from_value(json!({
1561 "type": "object",
1562 "properties": {
1563 "age": { "type": "integer" },
1564 "available": { "type": "boolean" }
1565 },
1566 "required": ["age", "available"]
1567 }))
1568 .expect("Failed to parse schema");
1569
1570 let completion_request = CompletionRequest {
1571 model: Some("llama3.1".to_string()),
1572 preamble: None,
1573 chat_history: OneOrMany::one(CompletionMessage::User {
1574 content: OneOrMany::one(UserContent::Text(Text {
1575 text: "How old is Ollama?".to_string(),
1576 })),
1577 }),
1578 documents: vec![],
1579 tools: vec![],
1580 temperature: None,
1581 max_tokens: None,
1582 tool_choice: None,
1583 additional_params: None,
1584 output_schema: Some(schema),
1585 };
1586
1587 let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
1588 .expect("Failed to create Ollama request");
1589
1590 let serialized =
1591 serde_json::to_value(&ollama_request).expect("Failed to serialize request");
1592
1593 let format = serialized
1594 .get("format")
1595 .expect("format field should be present");
1596 assert_eq!(
1597 *format,
1598 json!({
1599 "type": "object",
1600 "properties": {
1601 "age": { "type": "integer" },
1602 "available": { "type": "boolean" }
1603 },
1604 "required": ["age", "available"]
1605 })
1606 );
1607 }
1608
1609 #[test]
1610 fn test_completion_request_without_output_schema() {
1611 use crate::OneOrMany;
1612 use crate::completion::Message as CompletionMessage;
1613 use crate::message::{Text, UserContent};
1614
1615 let completion_request = CompletionRequest {
1616 model: Some("llama3.1".to_string()),
1617 preamble: None,
1618 chat_history: OneOrMany::one(CompletionMessage::User {
1619 content: OneOrMany::one(UserContent::Text(Text {
1620 text: "Hello!".to_string(),
1621 })),
1622 }),
1623 documents: vec![],
1624 tools: vec![],
1625 temperature: None,
1626 max_tokens: None,
1627 tool_choice: None,
1628 additional_params: None,
1629 output_schema: None,
1630 };
1631
1632 let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
1633 .expect("Failed to create Ollama request");
1634
1635 let serialized =
1636 serde_json::to_value(&ollama_request).expect("Failed to serialize request");
1637
1638 assert!(
1639 serialized.get("format").is_none(),
1640 "format field should be absent when output_schema is None"
1641 );
1642 }
1643
1644 #[test]
1645 fn test_client_initialization() {
1646 let _client = crate::providers::ollama::Client::new(Nothing).expect("Client::new() failed");
1647 let _client_from_builder = crate::providers::ollama::Client::builder()
1648 .api_key(Nothing)
1649 .build()
1650 .expect("Client::builder() failed");
1651 }
1652}