1use crate::client::{
35 self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
36};
37use crate::completion::{GetTokenUsage, Usage};
38use crate::http_client::{self, HttpClientExt};
39use crate::message::DocumentSourceKind;
40use crate::streaming::RawStreamingChoice;
41use crate::{
42 OneOrMany,
43 completion::{self, CompletionError, CompletionRequest},
44 embeddings::{self, EmbeddingError},
45 json_utils, message,
46 message::{ImageDetail, Text},
47 streaming,
48};
49use async_stream::try_stream;
50use bytes::Bytes;
51use futures::StreamExt;
52use serde::{Deserialize, Serialize};
53use serde_json::{Value, json};
54use std::{convert::TryFrom, str::FromStr};
55use tracing::info_span;
56use tracing_futures::Instrument;
/// Default base URL of a locally-running Ollama server.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
60
/// Marker type identifying the Ollama provider extension.
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;
63
/// Builder marker for constructing [`OllamaExt`]-backed clients.
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaBuilder;
66
impl Provider for OllamaExt {
    type Builder = OllamaBuilder;
    // Lightweight endpoint used to verify that the Ollama server is reachable.
    const VERIFY_PATH: &'static str = "api/tags";
}
71
// Capability map: Ollama supports completions and embeddings here;
// transcription, model listing, image and audio generation are not wired up.
impl<H> Capabilities<H> for OllamaExt {
    type Completion = Capable<CompletionModel<H>>;
    type Transcription = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
83
// Default Debug-related behavior is sufficient for this provider.
impl DebugExt for OllamaExt {}
85
impl ProviderBuilder for OllamaBuilder {
    // The extension carries no HTTP-client-specific state.
    type Extension<H>
        = OllamaExt
    where
        H: HttpClientExt;
    // Ollama requires no API key.
    type ApiKey = Nothing;

    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;

    /// Produces the extension value; infallible since no configuration is read.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(OllamaExt)
    }
}
104
/// Ollama client, defaulting to a `reqwest`-backed HTTP transport.
pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
/// Builder for [`Client`]; Ollama uses no API key, hence `Nothing`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, Nothing, H>;
107
impl ProviderClient for Client {
    type Input = Nothing;

    /// Builds a client from the `OLLAMA_API_BASE_URL` environment variable.
    ///
    /// # Panics
    /// Panics when the variable is unset or when building the client fails.
    fn from_env() -> Self {
        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");

        Self::builder()
            .api_key(Nothing)
            .base_url(&api_base)
            .build()
            .unwrap()
    }

    /// Builds a client with the default base URL; the input is ignored.
    fn from_val(_: Self::Input) -> Self {
        Self::builder().api_key(Nothing).build().unwrap()
    }
}
125
/// Error payload returned by the Ollama API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
132
/// Untagged wrapper: a successful payload deserializes as `Ok`, otherwise
/// serde falls through to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
139
/// `all-minilm` embedding model (384 dimensions).
pub const ALL_MINILM: &str = "all-minilm";
/// `nomic-embed-text` embedding model (768 dimensions).
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
144
145fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
146 match identifier {
147 ALL_MINILM => Some(384),
148 NOMIC_EMBED_TEXT => Some(768),
149 _ => None,
150 }
151}
152
/// Response payload from Ollama's `api/embed` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    /// One embedding vector per input document, in request order.
    pub embeddings: Vec<Vec<f64>>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
164
165impl From<ApiErrorResponse> for EmbeddingError {
166 fn from(err: ApiErrorResponse) -> Self {
167 EmbeddingError::ProviderError(err.message)
168 }
169}
170
impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    /// Flattens the untagged API response into a `Result`, mapping the error
    /// payload's message onto `EmbeddingError::ProviderError`.
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}
179
/// Handle for running embeddings against a specific Ollama model.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    // Declared embedding dimensionality; 0 when unknown (see `make`).
    ndims: usize,
}
188
189impl<T> EmbeddingModel<T> {
190 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
191 Self {
192 client,
193 model: model.into(),
194 ndims,
195 }
196 }
197
198 pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
199 Self {
200 client,
201 model: model.into(),
202 ndims,
203 }
204 }
205}
206
207impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
208where
209 T: HttpClientExt + Clone + 'static,
210{
211 type Client = Client<T>;
212
213 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
214 let model = model.into();
215 let dims = dims
216 .or(model_dimensions_from_identifier(&model))
217 .unwrap_or_default();
218 Self::new(client.clone(), model, dims)
219 }
220
221 const MAX_DOCUMENTS: usize = 1024;
222 fn ndims(&self) -> usize {
223 self.ndims
224 }
225
226 async fn embed_texts(
227 &self,
228 documents: impl IntoIterator<Item = String>,
229 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
230 let docs: Vec<String> = documents.into_iter().collect();
231
232 let body = serde_json::to_vec(&json!({
233 "model": self.model,
234 "input": docs
235 }))?;
236
237 let req = self
238 .client
239 .post("api/embed")?
240 .body(body)
241 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
242
243 let response = self.client.send(req).await?;
244
245 if !response.status().is_success() {
246 let text = http_client::text(response).await?;
247 return Err(EmbeddingError::ProviderError(text));
248 }
249
250 let bytes: Vec<u8> = response.into_body().await?;
251
252 let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;
253
254 if api_resp.embeddings.len() != docs.len() {
255 return Err(EmbeddingError::ResponseError(
256 "Number of returned embeddings does not match input".into(),
257 ));
258 }
259 Ok(api_resp
260 .embeddings
261 .into_iter()
262 .zip(docs.into_iter())
263 .map(|(vec, document)| embeddings::Embedding { document, vec })
264 .collect())
265 }
266}
267
/// `llama3.2` chat model.
pub const LLAMA3_2: &str = "llama3.2";
/// `llava` multimodal model.
pub const LLAVA: &str = "llava";
/// `mistral` chat model.
pub const MISTRAL: &str = "mistral";
273
/// Raw response payload from Ollama's `api/chat` endpoint. The same shape is
/// used for each NDJSON chunk when streaming.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    /// `true` on the final (or only) chunk of a response.
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    /// Input token count.
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    /// Output token count.
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    /// Converts a raw Ollama chat response into the generic completion
    /// response, splitting the assistant message into text and tool-call
    /// contents and mapping the eval counters onto token usage.
    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                // NOTE(review): Ollama tool calls carry no call id, so the
                // function name doubles as the id here — confirm downstream
                // consumers expect that.
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                // Errors when both the text and the tool-call list are empty.
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                // Rebuild the raw response because `resp.message` was moved out
                // by the match above.
                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                        cached_input_tokens: 0,
                        cache_creation_input_tokens: 0,
                    },
                    raw_response,
                    message_id: None,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}
365
/// Request payload for Ollama's `api/chat` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OllamaCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// `false` for one-shot completions; flipped to `true` by `stream()`.
    pub stream: bool,
    think: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    keep_alive: Option<String>,
    /// JSON schema for structured output, when requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<schemars::Schema>,
    /// Free-form model options (temperature plus any extra parameters).
    options: serde_json::Value,
}
384
385impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
386 type Error = CompletionError;
387
388 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
389 let model = req.model.clone().unwrap_or_else(|| model.to_string());
390 if req.tool_choice.is_some() {
391 tracing::warn!("WARNING: `tool_choice` not supported for Ollama");
392 }
393 let mut partial_history = vec![];
395 if let Some(docs) = req.normalized_documents() {
396 partial_history.push(docs);
397 }
398 partial_history.extend(req.chat_history);
399
400 let mut full_history: Vec<Message> = match &req.preamble {
402 Some(preamble) => vec![Message::system(preamble)],
403 None => vec![],
404 };
405
406 full_history.extend(
408 partial_history
409 .into_iter()
410 .map(message::Message::try_into)
411 .collect::<Result<Vec<Vec<Message>>, _>>()?
412 .into_iter()
413 .flatten()
414 .collect::<Vec<_>>(),
415 );
416
417 let mut think = false;
418 let mut keep_alive: Option<String> = None;
419
420 let options = if let Some(mut extra) = req.additional_params {
421 if let Some(obj) = extra.as_object_mut() {
423 if let Some(think_val) = obj.remove("think") {
425 think = think_val.as_bool().ok_or_else(|| {
426 CompletionError::RequestError("`think` must be a bool".into())
427 })?;
428 }
429
430 if let Some(keep_alive_val) = obj.remove("keep_alive") {
432 keep_alive = Some(
433 keep_alive_val
434 .as_str()
435 .ok_or_else(|| {
436 CompletionError::RequestError(
437 "`keep_alive` must be a string".into(),
438 )
439 })?
440 .to_string(),
441 );
442 }
443 }
444
445 json_utils::merge(json!({ "temperature": req.temperature }), extra)
446 } else {
447 json!({ "temperature": req.temperature })
448 };
449
450 Ok(Self {
451 model: model.to_string(),
452 messages: full_history,
453 temperature: req.temperature,
454 max_tokens: req.max_tokens,
455 stream: false,
456 think,
457 keep_alive,
458 format: req.output_schema,
459 tools: req
460 .tools
461 .clone()
462 .into_iter()
463 .map(ToolDefinition::from)
464 .collect::<Vec<_>>(),
465 options,
466 })
467 }
468}
469
/// Handle for running chat completions against a specific Ollama model.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}
475
476impl<T> CompletionModel<T> {
477 pub fn new(client: Client<T>, model: &str) -> Self {
478 Self {
479 client,
480 model: model.to_owned(),
481 }
482 }
483}
484
/// Final metadata payload emitted at the end of a streaming completion,
/// taken from the last (`done: true`) NDJSON chunk.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    /// Input token count.
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    /// Output token count.
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}
497
498impl GetTokenUsage for StreamingCompletionResponse {
499 fn token_usage(&self) -> Option<crate::completion::Usage> {
500 let mut usage = crate::completion::Usage::new();
501 let input_tokens = self.prompt_eval_count.unwrap_or_default();
502 let output_tokens = self.eval_count.unwrap_or_default();
503 usage.input_tokens = input_tokens;
504 usage.output_tokens = output_tokens;
505 usage.total_tokens = input_tokens + output_tokens;
506
507 Some(usage)
508 }
509}
510
511impl<T> completion::CompletionModel for CompletionModel<T>
512where
513 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
514{
515 type Response = CompletionResponse;
516 type StreamingResponse = StreamingCompletionResponse;
517
518 type Client = Client<T>;
519
520 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
521 Self::new(client.clone(), model.into().as_str())
522 }
523
524 async fn completion(
525 &self,
526 completion_request: CompletionRequest,
527 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
528 let span = if tracing::Span::current().is_disabled() {
529 info_span!(
530 target: "rig::completions",
531 "chat",
532 gen_ai.operation.name = "chat",
533 gen_ai.provider.name = "ollama",
534 gen_ai.request.model = self.model,
535 gen_ai.system_instructions = tracing::field::Empty,
536 gen_ai.response.id = tracing::field::Empty,
537 gen_ai.response.model = tracing::field::Empty,
538 gen_ai.usage.output_tokens = tracing::field::Empty,
539 gen_ai.usage.input_tokens = tracing::field::Empty,
540 gen_ai.usage.cached_tokens = tracing::field::Empty,
541 )
542 } else {
543 tracing::Span::current()
544 };
545
546 span.record("gen_ai.system_instructions", &completion_request.preamble);
547 let request = OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
548
549 if tracing::enabled!(tracing::Level::TRACE) {
550 tracing::trace!(target: "rig::completions",
551 "Ollama completion request: {}",
552 serde_json::to_string_pretty(&request)?
553 );
554 }
555
556 let body = serde_json::to_vec(&request)?;
557
558 let req = self
559 .client
560 .post("api/chat")?
561 .body(body)
562 .map_err(http_client::Error::from)?;
563
564 let async_block = async move {
565 let response = self.client.send::<_, Bytes>(req).await?;
566 let status = response.status();
567 let response_body = response.into_body().into_future().await?.to_vec();
568
569 if !status.is_success() {
570 return Err(CompletionError::ProviderError(
571 String::from_utf8_lossy(&response_body).to_string(),
572 ));
573 }
574
575 let response: CompletionResponse = serde_json::from_slice(&response_body)?;
576 let span = tracing::Span::current();
577 span.record("gen_ai.response.model_name", &response.model);
578 span.record(
579 "gen_ai.usage.input_tokens",
580 response.prompt_eval_count.unwrap_or_default(),
581 );
582 span.record(
583 "gen_ai.usage.output_tokens",
584 response.eval_count.unwrap_or_default(),
585 );
586
587 if tracing::enabled!(tracing::Level::TRACE) {
588 tracing::trace!(target: "rig::completions",
589 "Ollama completion response: {}",
590 serde_json::to_string_pretty(&response)?
591 );
592 }
593
594 let response: completion::CompletionResponse<CompletionResponse> =
595 response.try_into()?;
596
597 Ok(response)
598 };
599
600 tracing::Instrument::instrument(async_block, span).await
601 }
602
603 async fn stream(
604 &self,
605 request: CompletionRequest,
606 ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
607 {
608 let span = if tracing::Span::current().is_disabled() {
609 info_span!(
610 target: "rig::completions",
611 "chat_streaming",
612 gen_ai.operation.name = "chat_streaming",
613 gen_ai.provider.name = "ollama",
614 gen_ai.request.model = self.model,
615 gen_ai.system_instructions = tracing::field::Empty,
616 gen_ai.response.id = tracing::field::Empty,
617 gen_ai.response.model = self.model,
618 gen_ai.usage.output_tokens = tracing::field::Empty,
619 gen_ai.usage.input_tokens = tracing::field::Empty,
620 gen_ai.usage.cached_tokens = tracing::field::Empty,
621 )
622 } else {
623 tracing::Span::current()
624 };
625
626 span.record("gen_ai.system_instructions", &request.preamble);
627
628 let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
629 request.stream = true;
630
631 if tracing::enabled!(tracing::Level::TRACE) {
632 tracing::trace!(target: "rig::completions",
633 "Ollama streaming completion request: {}",
634 serde_json::to_string_pretty(&request)?
635 );
636 }
637
638 let body = serde_json::to_vec(&request)?;
639
640 let req = self
641 .client
642 .post("api/chat")?
643 .body(body)
644 .map_err(http_client::Error::from)?;
645
646 let response = self.client.send_streaming(req).await?;
647 let status = response.status();
648 let mut byte_stream = response.into_body();
649
650 if !status.is_success() {
651 return Err(CompletionError::ProviderError(format!(
652 "Got error status code trying to send a request to Ollama: {status}"
653 )));
654 }
655
656 let stream = try_stream! {
657 let span = tracing::Span::current();
658 let mut tool_calls_final = Vec::new();
659 let mut text_response = String::new();
660 let mut thinking_response = String::new();
661
662 while let Some(chunk) = byte_stream.next().await {
663 let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;
664
665 for line in bytes.split(|&b| b == b'\n') {
666 if line.is_empty() {
667 continue;
668 }
669
670 tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));
671
672 let response: CompletionResponse = serde_json::from_slice(line)?;
673
674 if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
675 if let Some(thinking_content) = thinking && !thinking_content.is_empty() {
676 thinking_response += &thinking_content;
677 yield RawStreamingChoice::ReasoningDelta {
678 id: None,
679 reasoning: thinking_content,
680 };
681 }
682
683 if !content.is_empty() {
684 text_response += &content;
685 yield RawStreamingChoice::Message(content);
686 }
687
688 for tool_call in tool_calls {
689 tool_calls_final.push(tool_call.clone());
690 yield RawStreamingChoice::ToolCall(
691 crate::streaming::RawStreamingToolCall::new(String::new(), tool_call.function.name, tool_call.function.arguments)
692 );
693 }
694 }
695
696 if response.done {
697 span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
698 span.record("gen_ai.usage.output_tokens", response.eval_count);
699 let message = Message::Assistant {
700 content: text_response.clone(),
701 thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
702 images: None,
703 name: None,
704 tool_calls: tool_calls_final.clone()
705 };
706 span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
707 yield RawStreamingChoice::FinalResponse(
708 StreamingCompletionResponse {
709 total_duration: response.total_duration,
710 load_duration: response.load_duration,
711 prompt_eval_count: response.prompt_eval_count,
712 prompt_eval_duration: response.prompt_eval_duration,
713 eval_count: response.eval_count,
714 eval_duration: response.eval_duration,
715 done_reason: response.done_reason,
716 }
717 );
718 break;
719 }
720 }
721 }
722 }.instrument(span);
723
724 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
725 stream,
726 )))
727 }
728}
729
/// Tool definition in the shape Ollama expects:
/// `{"type": "function", "function": {…}}`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    // Always "function" (see the `From` impl below).
    #[serde(rename = "type")]
    pub type_field: String, pub function: completion::ToolDefinition,
}
739
740impl From<crate::completion::ToolDefinition> for ToolDefinition {
742 fn from(tool: crate::completion::ToolDefinition) -> Self {
743 ToolDefinition {
744 type_field: "function".to_owned(),
745 function: completion::ToolDefinition {
746 name: tool.name,
747 description: tool.description,
748 parameters: tool.parameters,
749 },
750 }
751 }
752}
753
/// Tool call emitted by the model inside an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Kind of tool call; Ollama currently only emits function calls.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// Function name plus JSON arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
771
/// Chat message in Ollama's wire format; the `role` tag selects the variant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        /// Base64-encoded images attached to the message.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        // Defaulted: streaming chunks may omit `content`.
        #[serde(default)]
        content: String,
        /// Reasoning text from "thinking" models.
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Ollama may send `null` instead of an empty list.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
810
impl TryFrom<crate::message::Message> for Vec<Message> {
    type Error = crate::message::MessageError;

    /// Expands one internal message into one or more Ollama messages.
    ///
    /// User messages containing tool results become one `tool` message per
    /// result; otherwise text/images/documents fold into a single user
    /// message. Assistant messages collect text, thinking and tool calls.
    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;
        match internal_msg {
            InternalMessage::System { content } => Ok(vec![Message::System {
                content,
                images: None,
                name: None,
            }]),
            InternalMessage::User { content, .. } => {
                let (tool_results, other_content): (Vec<_>, Vec<_>) =
                    content.into_iter().partition(|content| {
                        matches!(content, crate::message::UserContent::ToolResult(_))
                    });

                // NOTE(review): when tool results are present, any other
                // user content in the same message is dropped — confirm
                // this is intended.
                if !tool_results.is_empty() {
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            crate::message::UserContent::ToolResult(
                                crate::message::ToolResult { id, content, .. },
                            ) => {
                                // Non-text result parts become a placeholder string.
                                let content_string = content
                                    .into_iter()
                                    .map(|content| match content {
                                        crate::message::ToolResultContent::Text(text) => text.text,
                                        _ => "[Non-text content]".to_string(),
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n");

                                // The tool-call id is used as Ollama's `tool_name`.
                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
                                    name: id,
                                    content: content_string,
                                })
                            }
                            // The partition above guarantees only ToolResult here.
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    // Split the remaining content into plain text and
                    // base64-encoded images.
                    let (texts, images) = other_content.into_iter().fold(
                        (Vec::new(), Vec::new()),
                        |(mut texts, mut images), content| {
                            match content {
                                crate::message::UserContent::Text(crate::message::Text {
                                    text,
                                }) => texts.push(text),
                                crate::message::UserContent::Image(crate::message::Image {
                                    data: DocumentSourceKind::Base64(data),
                                    ..
                                }) => images.push(data),
                                crate::message::UserContent::Document(
                                    crate::message::Document {
                                        data:
                                            DocumentSourceKind::Base64(data)
                                            | DocumentSourceKind::String(data),
                                        ..
                                    },
                                ) => texts.push(data),
                                // Other content kinds are silently skipped.
                                _ => {}
                            }
                            (texts, images)
                        },
                    );

                    Ok(vec![Message::User {
                        content: texts.join(" "),
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(
                                images
                                    .into_iter()
                                    .map(|x| x.to_string())
                                    .collect::<Vec<String>>(),
                            )
                        },
                        name: None,
                    }])
                }
            }
            InternalMessage::Assistant { content, .. } => {
                let mut thinking: Option<String> = None;
                let mut text_content = Vec::new();
                let mut tool_calls = Vec::new();

                for content in content.into_iter() {
                    match content {
                        crate::message::AssistantContent::Text(text) => {
                            text_content.push(text.text)
                        }
                        crate::message::AssistantContent::ToolCall(tool_call) => {
                            tool_calls.push(tool_call)
                        }
                        crate::message::AssistantContent::Reasoning(reasoning) => {
                            // A later non-empty reasoning block overwrites an
                            // earlier one.
                            let display = reasoning.display_text();
                            if !display.is_empty() {
                                thinking = Some(display);
                            }
                        }
                        crate::message::AssistantContent::Image(_) => {
                            return Err(crate::message::MessageError::ConversionError(
                                "Ollama currently doesn't support images.".into(),
                            ));
                        }
                    }
                }

                Ok(vec![Message::Assistant {
                    content: text_content.join(" "),
                    thinking,
                    images: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}
943
impl From<Message> for crate::completion::Message {
    /// Converts an Ollama message back into the generic message type.
    ///
    /// System and tool-result messages are mapped onto user messages, since
    /// this conversion targets only the user/assistant variants.
    fn from(msg: Message) -> Self {
        match msg {
            Message::User { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut assistant_contents =
                    vec![crate::completion::message::AssistantContent::Text(Text {
                        text: content,
                    })];
                // NOTE(review): the function name is reused as the tool-call
                // id, mirroring the non-streaming response conversion.
                for tc in tool_calls {
                    assistant_contents.push(
                        crate::completion::message::AssistantContent::tool_call(
                            tc.function.name.clone(),
                            tc.function.name,
                            tc.function.arguments,
                        ),
                    );
                }
                crate::completion::Message::Assistant {
                    id: None,
                    // Safe: the vec always holds at least the text element.
                    content: OneOrMany::many(assistant_contents).unwrap(),
                }
            }
            Message::System { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::ToolResult { name, content } => crate::completion::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    name,
                    OneOrMany::one(message::ToolResultContent::text(content)),
                )),
            },
        }
    }
}
992
993impl Message {
994 pub fn system(content: &str) -> Self {
996 Message::System {
997 content: content.to_owned(),
998 images: None,
999 name: None,
1000 }
1001 }
1002}
1003
1004impl From<crate::message::ToolCall> for ToolCall {
1007 fn from(tool_call: crate::message::ToolCall) -> Self {
1008 Self {
1009 r#type: ToolType::Function,
1010 function: Function {
1011 name: tool_call.function.name,
1012 arguments: tool_call.function.arguments,
1013 },
1014 }
1015 }
1016}
1017
/// Typed system-message content (currently always text).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}
1024
/// Content type tag for [`SystemContent`]; only text exists today.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
1031
1032impl From<String> for SystemContent {
1033 fn from(s: String) -> Self {
1034 SystemContent {
1035 r#type: SystemContentType::default(),
1036 text: s,
1037 }
1038 }
1039}
1040
1041impl FromStr for SystemContent {
1042 type Err = std::convert::Infallible;
1043 fn from_str(s: &str) -> Result<Self, Self::Err> {
1044 Ok(SystemContent {
1045 r#type: SystemContentType::default(),
1046 text: s.to_string(),
1047 })
1048 }
1049}
1050
/// Plain-text assistant content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}
1055
1056impl FromStr for AssistantContent {
1057 type Err = std::convert::Infallible;
1058 fn from_str(s: &str) -> Result<Self, Self::Err> {
1059 Ok(AssistantContent { text: s.to_owned() })
1060 }
1061}
1062
/// Typed user content: plain text or an image reference.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
}
1070
1071impl FromStr for UserContent {
1072 type Err = std::convert::Infallible;
1073 fn from_str(s: &str) -> Result<Self, Self::Err> {
1074 Ok(UserContent::Text { text: s.to_owned() })
1075 }
1076}
1077
/// Image reference with an optional detail level.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}
1084
1085#[cfg(test)]
1090mod tests {
1091 use super::*;
1092 use serde_json::json;
1093
    // A full chat payload (including a tool call) must deserialize and
    // convert into the generic completion response with non-empty content.
    #[tokio::test]
    async fn test_chat_completion() {
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }
1137
    // Provider user messages must round-trip their text into the generic
    // completion message type.
    #[test]
    fn test_message_conversion() {
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                let first_content = content.first();
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }
1164
    // The generic tool definition must be wrapped in Ollama's
    // `{"type": "function", "function": …}` envelope with fields preserved.
    #[test]
    fn test_tool_definition_conversion() {
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }
1200
    // Responses from "thinking" models must surface both the `thinking`
    // field and the regular content after deserialization.
    #[tokio::test]
    async fn test_chat_completion_with_thinking() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The answer is 42.",
                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "Let me think about this carefully. The question asks for the meaning of life..."
            );
            assert_eq!(content, "The answer is 42.");
        } else {
            panic!("Expected Assistant message");
        }
    }
1240
1241 #[tokio::test]
1243 async fn test_chat_completion_without_thinking() {
1244 let sample_response = json!({
1245 "model": "llama3.2",
1246 "created_at": "2023-08-04T19:22:45.499127Z",
1247 "message": {
1248 "role": "assistant",
1249 "content": "Hello!",
1250 "images": null,
1251 "tool_calls": []
1252 },
1253 "done": true,
1254 "total_duration": 8000000000u64,
1255 "load_duration": 6000000u64,
1256 "prompt_eval_count": 10u64,
1257 "prompt_eval_duration": 400000000u64,
1258 "eval_count": 5u64,
1259 "eval_duration": 7700000000u64
1260 });
1261
1262 let chat_resp: CompletionResponse =
1263 serde_json::from_value(sample_response).expect("Failed to deserialize");
1264
1265 if let Message::Assistant {
1267 thinking, content, ..
1268 } = &chat_resp.message
1269 {
1270 assert!(thinking.is_none());
1271 assert_eq!(content, "Hello!");
1272 } else {
1273 panic!("Expected Assistant message");
1274 }
1275 }
1276
1277 #[test]
1279 fn test_streaming_response_with_thinking() {
1280 let sample_chunk = json!({
1281 "model": "qwen-thinking",
1282 "created_at": "2023-08-04T19:22:45.499127Z",
1283 "message": {
1284 "role": "assistant",
1285 "content": "",
1286 "thinking": "Analyzing the problem...",
1287 "images": null,
1288 "tool_calls": []
1289 },
1290 "done": false
1291 });
1292
1293 let chunk: CompletionResponse =
1294 serde_json::from_value(sample_chunk).expect("Failed to deserialize");
1295
1296 if let Message::Assistant {
1297 thinking, content, ..
1298 } = &chunk.message
1299 {
1300 assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
1301 assert_eq!(content, "");
1302 } else {
1303 panic!("Expected Assistant message");
1304 }
1305 }
1306
1307 #[test]
1309 fn test_message_conversion_with_thinking() {
1310 let reasoning_content = crate::message::Reasoning::new("Step 1: Consider the problem");
1312
1313 let internal_msg = crate::message::Message::Assistant {
1314 id: None,
1315 content: crate::OneOrMany::many(vec![
1316 crate::message::AssistantContent::Reasoning(reasoning_content),
1317 crate::message::AssistantContent::Text(crate::message::Text {
1318 text: "The answer is X".to_string(),
1319 }),
1320 ])
1321 .unwrap(),
1322 };
1323
1324 let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();
1326 assert_eq!(provider_msgs.len(), 1);
1327
1328 if let Message::Assistant {
1329 thinking, content, ..
1330 } = &provider_msgs[0]
1331 {
1332 assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
1333 assert_eq!(content, "The answer is X");
1334 } else {
1335 panic!("Expected Assistant message with thinking");
1336 }
1337 }
1338
1339 #[test]
1341 fn test_empty_thinking_content() {
1342 let sample_response = json!({
1343 "model": "llama3.2",
1344 "created_at": "2023-08-04T19:22:45.499127Z",
1345 "message": {
1346 "role": "assistant",
1347 "content": "Response",
1348 "thinking": "",
1349 "images": null,
1350 "tool_calls": []
1351 },
1352 "done": true,
1353 "total_duration": 8000000000u64,
1354 "load_duration": 6000000u64,
1355 "prompt_eval_count": 10u64,
1356 "prompt_eval_duration": 400000000u64,
1357 "eval_count": 5u64,
1358 "eval_duration": 7700000000u64
1359 });
1360
1361 let chat_resp: CompletionResponse =
1362 serde_json::from_value(sample_response).expect("Failed to deserialize");
1363
1364 if let Message::Assistant {
1365 thinking, content, ..
1366 } = &chat_resp.message
1367 {
1368 assert_eq!(thinking.as_ref().unwrap(), "");
1370 assert_eq!(content, "Response");
1371 } else {
1372 panic!("Expected Assistant message");
1373 }
1374 }
1375
1376 #[test]
1378 fn test_thinking_with_tool_calls() {
1379 let sample_response = json!({
1380 "model": "qwen-thinking",
1381 "created_at": "2023-08-04T19:22:45.499127Z",
1382 "message": {
1383 "role": "assistant",
1384 "content": "Let me check the weather.",
1385 "thinking": "User wants weather info, I should use the weather tool",
1386 "images": null,
1387 "tool_calls": [
1388 {
1389 "type": "function",
1390 "function": {
1391 "name": "get_weather",
1392 "arguments": {
1393 "location": "San Francisco"
1394 }
1395 }
1396 }
1397 ]
1398 },
1399 "done": true,
1400 "total_duration": 8000000000u64,
1401 "load_duration": 6000000u64,
1402 "prompt_eval_count": 30u64,
1403 "prompt_eval_duration": 400000000u64,
1404 "eval_count": 50u64,
1405 "eval_duration": 7700000000u64
1406 });
1407
1408 let chat_resp: CompletionResponse =
1409 serde_json::from_value(sample_response).expect("Failed to deserialize");
1410
1411 if let Message::Assistant {
1412 thinking,
1413 content,
1414 tool_calls,
1415 ..
1416 } = &chat_resp.message
1417 {
1418 assert_eq!(
1419 thinking.as_ref().unwrap(),
1420 "User wants weather info, I should use the weather tool"
1421 );
1422 assert_eq!(content, "Let me check the weather.");
1423 assert_eq!(tool_calls.len(), 1);
1424 assert_eq!(tool_calls[0].function.name, "get_weather");
1425 } else {
1426 panic!("Expected Assistant message with thinking and tool calls");
1427 }
1428 }
1429
1430 #[test]
1432 fn test_completion_request_with_think_param() {
1433 use crate::OneOrMany;
1434 use crate::completion::Message as CompletionMessage;
1435 use crate::message::{Text, UserContent};
1436
1437 let completion_request = CompletionRequest {
1439 model: None,
1440 preamble: Some("You are a helpful assistant.".to_string()),
1441 chat_history: OneOrMany::one(CompletionMessage::User {
1442 content: OneOrMany::one(UserContent::Text(Text {
1443 text: "What is 2 + 2?".to_string(),
1444 })),
1445 }),
1446 documents: vec![],
1447 tools: vec![],
1448 temperature: Some(0.7),
1449 max_tokens: Some(1024),
1450 tool_choice: None,
1451 additional_params: Some(json!({
1452 "think": true,
1453 "keep_alive": "-1m",
1454 "num_ctx": 4096
1455 })),
1456 output_schema: None,
1457 };
1458
1459 let ollama_request = OllamaCompletionRequest::try_from(("qwen3:8b", completion_request))
1461 .expect("Failed to create Ollama request");
1462
1463 let serialized =
1465 serde_json::to_value(&ollama_request).expect("Failed to serialize request");
1466
1467 let expected = json!({
1473 "model": "qwen3:8b",
1474 "messages": [
1475 {
1476 "role": "system",
1477 "content": "You are a helpful assistant."
1478 },
1479 {
1480 "role": "user",
1481 "content": "What is 2 + 2?"
1482 }
1483 ],
1484 "temperature": 0.7,
1485 "stream": false,
1486 "think": true,
1487 "max_tokens": 1024,
1488 "keep_alive": "-1m",
1489 "options": {
1490 "temperature": 0.7,
1491 "num_ctx": 4096
1492 }
1493 });
1494
1495 assert_eq!(serialized, expected);
1496 }
1497
1498 #[test]
1500 fn test_completion_request_with_think_false_default() {
1501 use crate::OneOrMany;
1502 use crate::completion::Message as CompletionMessage;
1503 use crate::message::{Text, UserContent};
1504
1505 let completion_request = CompletionRequest {
1507 model: None,
1508 preamble: Some("You are a helpful assistant.".to_string()),
1509 chat_history: OneOrMany::one(CompletionMessage::User {
1510 content: OneOrMany::one(UserContent::Text(Text {
1511 text: "Hello!".to_string(),
1512 })),
1513 }),
1514 documents: vec![],
1515 tools: vec![],
1516 temperature: Some(0.5),
1517 max_tokens: None,
1518 tool_choice: None,
1519 additional_params: None,
1520 output_schema: None,
1521 };
1522
1523 let ollama_request = OllamaCompletionRequest::try_from(("llama3.2", completion_request))
1525 .expect("Failed to create Ollama request");
1526
1527 let serialized =
1529 serde_json::to_value(&ollama_request).expect("Failed to serialize request");
1530
1531 let expected = json!({
1533 "model": "llama3.2",
1534 "messages": [
1535 {
1536 "role": "system",
1537 "content": "You are a helpful assistant."
1538 },
1539 {
1540 "role": "user",
1541 "content": "Hello!"
1542 }
1543 ],
1544 "temperature": 0.5,
1545 "stream": false,
1546 "think": false,
1547 "options": {
1548 "temperature": 0.5
1549 }
1550 });
1551
1552 assert_eq!(serialized, expected);
1553 }
1554
1555 #[test]
1556 fn test_completion_request_with_output_schema() {
1557 use crate::OneOrMany;
1558 use crate::completion::Message as CompletionMessage;
1559 use crate::message::{Text, UserContent};
1560
1561 let schema: schemars::Schema = serde_json::from_value(json!({
1562 "type": "object",
1563 "properties": {
1564 "age": { "type": "integer" },
1565 "available": { "type": "boolean" }
1566 },
1567 "required": ["age", "available"]
1568 }))
1569 .expect("Failed to parse schema");
1570
1571 let completion_request = CompletionRequest {
1572 model: Some("llama3.1".to_string()),
1573 preamble: None,
1574 chat_history: OneOrMany::one(CompletionMessage::User {
1575 content: OneOrMany::one(UserContent::Text(Text {
1576 text: "How old is Ollama?".to_string(),
1577 })),
1578 }),
1579 documents: vec![],
1580 tools: vec![],
1581 temperature: None,
1582 max_tokens: None,
1583 tool_choice: None,
1584 additional_params: None,
1585 output_schema: Some(schema),
1586 };
1587
1588 let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
1589 .expect("Failed to create Ollama request");
1590
1591 let serialized =
1592 serde_json::to_value(&ollama_request).expect("Failed to serialize request");
1593
1594 let format = serialized
1595 .get("format")
1596 .expect("format field should be present");
1597 assert_eq!(
1598 *format,
1599 json!({
1600 "type": "object",
1601 "properties": {
1602 "age": { "type": "integer" },
1603 "available": { "type": "boolean" }
1604 },
1605 "required": ["age", "available"]
1606 })
1607 );
1608 }
1609
1610 #[test]
1611 fn test_completion_request_without_output_schema() {
1612 use crate::OneOrMany;
1613 use crate::completion::Message as CompletionMessage;
1614 use crate::message::{Text, UserContent};
1615
1616 let completion_request = CompletionRequest {
1617 model: Some("llama3.1".to_string()),
1618 preamble: None,
1619 chat_history: OneOrMany::one(CompletionMessage::User {
1620 content: OneOrMany::one(UserContent::Text(Text {
1621 text: "Hello!".to_string(),
1622 })),
1623 }),
1624 documents: vec![],
1625 tools: vec![],
1626 temperature: None,
1627 max_tokens: None,
1628 tool_choice: None,
1629 additional_params: None,
1630 output_schema: None,
1631 };
1632
1633 let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
1634 .expect("Failed to create Ollama request");
1635
1636 let serialized =
1637 serde_json::to_value(&ollama_request).expect("Failed to serialize request");
1638
1639 assert!(
1640 serialized.get("format").is_none(),
1641 "format field should be absent when output_schema is None"
1642 );
1643 }
1644
1645 #[test]
1646 fn test_client_initialization() {
1647 let _client = crate::providers::ollama::Client::new(Nothing).expect("Client::new() failed");
1648 let _client_from_builder = crate::providers::ollama::Client::builder()
1649 .api_key(Nothing)
1650 .build()
1651 .expect("Client::builder() failed");
1652 }
1653}