use crate::client::{
    self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
};
use crate::completion::{GetTokenUsage, Usage};
use crate::http_client::{self, HttpClientExt};
use crate::message::DocumentSourceKind;
use crate::streaming::RawStreamingChoice;
use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils, message,
    message::{ImageDetail, Text},
    streaming,
};
use async_stream::try_stream;
use bytes::Bytes;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::{convert::TryFrom, str::FromStr};
use tracing::info_span;
use tracing_futures::Instrument;

const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";

#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;

#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaBuilder;

impl Provider for OllamaExt {
    type Builder = OllamaBuilder;

    const VERIFY_PATH: &'static str = "api/tags";

    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}

impl<H> Capabilities<H> for OllamaExt {
    type Completion = Capable<CompletionModel<H>>;
    type Transcription = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for OllamaExt {}

impl ProviderBuilder for OllamaBuilder {
    type Output = OllamaExt;
    type ApiKey = Nothing;

    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;
}

pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, Nothing, H>;

impl ProviderClient for Client {
    type Input = Nothing;

    fn from_env() -> Self {
        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");

        Self::builder()
            .api_key(Nothing)
            .base_url(&api_base)
            .build()
            .unwrap()
    }

    fn from_val(_: Self::Input) -> Self {
        Self::builder().api_key(Nothing).build().unwrap()
    }
}
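
// Construction sketch (mirrors `test_client_initialization` below; assumes a
// reachable Ollama daemon, which by default listens on
// http://localhost:11434):
//
//     let client: Client = Client::builder()
//         .api_key(Nothing)
//         .build()
//         .expect("failed to build Ollama client");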

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

/// The `all-minilm` embedding model (384 dimensions).
pub const ALL_MINILM: &str = "all-minilm";
/// The `nomic-embed-text` embedding model (768 dimensions).
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    match identifier {
        ALL_MINILM => Some(384),
        NOMIC_EMBED_TEXT => Some(768),
        _ => None,
    }
}

#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    pub embeddings: Vec<Vec<f64>>,
    // Ollama reports durations in nanoseconds.
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    ndims: usize,
}

impl<T> EmbeddingModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }

    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        let model = model.into();
        // Fall back to the known dimensions for the model identifier; unknown
        // models without an explicit `dims` end up with 0 dimensions.
        let dims = dims
            .or(model_dimensions_from_identifier(&model))
            .unwrap_or_default();
        Self::new(client.clone(), model, dims)
    }

    const MAX_DOCUMENTS: usize = 1024;

    fn ndims(&self) -> usize {
        self.ndims
    }

    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let docs: Vec<String> = documents.into_iter().collect();

        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": docs
        }))?;

        let req = self
            .client
            .post("api/embed")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if !response.status().is_success() {
            let text = http_client::text(response).await?;
            return Err(EmbeddingError::ProviderError(text));
        }

        let bytes: Vec<u8> = response.into_body().await?;

        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;

        if api_resp.embeddings.len() != docs.len() {
            return Err(EmbeddingError::ResponseError(
                "Number of returned embeddings does not match input".into(),
            ));
        }

        Ok(api_resp
            .embeddings
            .into_iter()
            .zip(docs)
            .map(|(vec, document)| embeddings::Embedding { document, vec })
            .collect())
    }
}
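
// Embedding sketch (hypothetical usage; assumes a tokio runtime, the
// `embeddings::EmbeddingModel` trait in scope, and a pulled
// `nomic-embed-text` model):
//
//     let model = EmbeddingModel::new(client, NOMIC_EMBED_TEXT, 768);
//     let results = model
//         .embed_texts(vec!["hello world".to_string()])
//         .await?;
//     assert_eq!(results[0].vec.len(), 768);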

pub const LLAMA3_2: &str = "llama3.2";
pub const LLAVA: &str = "llava";
pub const MISTRAL: &str = "mistral";

#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();

                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }

                // Ollama does not return tool-call ids, so the function name
                // doubles as the id.
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }

                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;

                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                        cached_input_tokens: 0,
                    },
                    raw_response,
                    message_id: None,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OllamaCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    // Note: `temperature` is serialized both here and inside `options`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    pub stream: bool,
    think: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    keep_alive: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<schemars::Schema>,
    options: serde_json::Value,
}

impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        let model = req.model.clone().unwrap_or_else(|| model.to_string());

        if req.tool_choice.is_some() {
            tracing::warn!("`tool_choice` is not supported by Ollama and will be ignored");
        }

        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten(),
        );

        let mut think = false;
        let mut keep_alive: Option<String> = None;

        // `think` and `keep_alive` are top-level request fields, so pull them
        // out of `additional_params`; everything left over goes into `options`.
        let options = if let Some(mut extra) = req.additional_params {
            if let Some(obj) = extra.as_object_mut() {
                if let Some(think_val) = obj.remove("think") {
                    think = think_val.as_bool().ok_or_else(|| {
                        CompletionError::RequestError("`think` must be a bool".into())
                    })?;
                }

                if let Some(keep_alive_val) = obj.remove("keep_alive") {
                    keep_alive = Some(
                        keep_alive_val
                            .as_str()
                            .ok_or_else(|| {
                                CompletionError::RequestError(
                                    "`keep_alive` must be a string".into(),
                                )
                            })?
                            .to_string(),
                    );
                }
            }

            json_utils::merge(json!({ "temperature": req.temperature }), extra)
        } else {
            json!({ "temperature": req.temperature })
        };

        Ok(Self {
            model,
            messages: full_history,
            temperature: req.temperature,
            max_tokens: req.max_tokens,
            stream: false,
            think,
            keep_alive,
            format: req.output_schema,
            tools: req
                .tools
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            options,
        })
    }
}
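
// How `additional_params` is split, per the conversion above (see also
// `test_completion_request_with_think_param` below): `think` and `keep_alive`
// are hoisted to top-level request fields, everything else is merged into
// `options` alongside `temperature`. For example, an input of
//
//     json!({ "think": true, "keep_alive": "-1m", "num_ctx": 4096 })
//
// with `temperature: Some(0.7)` serializes as
//
//     "think": true,
//     "keep_alive": "-1m",
//     "options": { "temperature": 0.7, "num_ctx": 4096 }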

#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_owned(),
        }
    }
}

#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        let input_tokens = self.prompt_eval_count.unwrap_or_default();
        let output_tokens = self.eval_count.unwrap_or_default();
        usage.input_tokens = input_tokens;
        usage.output_tokens = output_tokens;
        usage.total_tokens = input_tokens + output_tokens;

        Some(usage)
    }
}
508
509impl<T> completion::CompletionModel for CompletionModel<T>
510where
511 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
512{
513 type Response = CompletionResponse;
514 type StreamingResponse = StreamingCompletionResponse;
515
516 type Client = Client<T>;
517
518 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
519 Self::new(client.clone(), model.into().as_str())
520 }
521
522 async fn completion(
523 &self,
524 completion_request: CompletionRequest,
525 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
526 let span = if tracing::Span::current().is_disabled() {
527 info_span!(
528 target: "rig::completions",
529 "chat",
530 gen_ai.operation.name = "chat",
531 gen_ai.provider.name = "ollama",
532 gen_ai.request.model = self.model,
533 gen_ai.system_instructions = tracing::field::Empty,
534 gen_ai.response.id = tracing::field::Empty,
535 gen_ai.response.model = tracing::field::Empty,
536 gen_ai.usage.output_tokens = tracing::field::Empty,
537 gen_ai.usage.input_tokens = tracing::field::Empty,
538 )
539 } else {
540 tracing::Span::current()
541 };
542
543 span.record("gen_ai.system_instructions", &completion_request.preamble);
544 let request = OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
545
546 if tracing::enabled!(tracing::Level::TRACE) {
547 tracing::trace!(target: "rig::completions",
548 "Ollama completion request: {}",
549 serde_json::to_string_pretty(&request)?
550 );
551 }
552
553 let body = serde_json::to_vec(&request)?;
554
555 let req = self
556 .client
557 .post("api/chat")?
558 .body(body)
559 .map_err(http_client::Error::from)?;
560
561 let async_block = async move {
562 let response = self.client.send::<_, Bytes>(req).await?;
563 let status = response.status();
564 let response_body = response.into_body().into_future().await?.to_vec();
565
566 if !status.is_success() {
567 return Err(CompletionError::ProviderError(
568 String::from_utf8_lossy(&response_body).to_string(),
569 ));
570 }
571
572 let response: CompletionResponse = serde_json::from_slice(&response_body)?;
573 let span = tracing::Span::current();
574 span.record("gen_ai.response.model_name", &response.model);
575 span.record(
576 "gen_ai.usage.input_tokens",
577 response.prompt_eval_count.unwrap_or_default(),
578 );
579 span.record(
580 "gen_ai.usage.output_tokens",
581 response.eval_count.unwrap_or_default(),
582 );
583
584 if tracing::enabled!(tracing::Level::TRACE) {
585 tracing::trace!(target: "rig::completions",
586 "Ollama completion response: {}",
587 serde_json::to_string_pretty(&response)?
588 );
589 }
590
591 let response: completion::CompletionResponse<CompletionResponse> =
592 response.try_into()?;
593
594 Ok(response)
595 };
596
597 tracing::Instrument::instrument(async_block, span).await
598 }

    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
    {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "ollama",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                // Declared here so the `record` call in the stream below takes effect.
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &request.preamble);

        let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
        request.stream = true;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Ollama streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("api/chat")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let response = self.client.send_streaming(req).await?;
        let status = response.status();
        let mut byte_stream = response.into_body();

        if !status.is_success() {
            return Err(CompletionError::ProviderError(format!(
                "Got error status code trying to send a request to Ollama: {status}"
            )));
        }

        let stream = try_stream! {
            let span = tracing::Span::current();
            let mut tool_calls_final = Vec::new();
            let mut text_response = String::new();
            let mut thinking_response = String::new();

            while let Some(chunk) = byte_stream.next().await {
                let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;

                // Ollama streams NDJSON: one `CompletionResponse` per line.
                for line in bytes.split(|&b| b == b'\n') {
                    if line.is_empty() {
                        continue;
                    }

                    tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));

                    let response: CompletionResponse = serde_json::from_slice(line)?;

                    if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
                        if let Some(thinking_content) = thinking && !thinking_content.is_empty() {
                            thinking_response += &thinking_content;
                            yield RawStreamingChoice::ReasoningDelta {
                                id: None,
                                reasoning: thinking_content,
                            };
                        }

                        if !content.is_empty() {
                            text_response += &content;
                            yield RawStreamingChoice::Message(content);
                        }

                        for tool_call in tool_calls {
                            tool_calls_final.push(tool_call.clone());
                            // Ollama does not assign tool-call ids, hence the empty id.
                            yield RawStreamingChoice::ToolCall(
                                crate::streaming::RawStreamingToolCall::new(String::new(), tool_call.function.name, tool_call.function.arguments)
                            );
                        }
                    }

                    if response.done {
                        span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
                        span.record("gen_ai.usage.output_tokens", response.eval_count);
                        let message = Message::Assistant {
                            content: text_response.clone(),
                            thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
                            images: None,
                            name: None,
                            tool_calls: tool_calls_final.clone()
                        };
                        span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
                        yield RawStreamingChoice::FinalResponse(
                            StreamingCompletionResponse {
                                total_duration: response.total_duration,
                                load_duration: response.load_duration,
                                prompt_eval_count: response.prompt_eval_count,
                                prompt_eval_duration: response.prompt_eval_duration,
                                eval_count: response.eval_count,
                                eval_duration: response.eval_duration,
                                done_reason: response.done_reason,
                            }
                        );
                        break;
                    }
                }
            }
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
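
// Streaming sketch (hypothetical usage; assumes a tokio runtime and the
// `completion::CompletionModel` and `futures::StreamExt` traits in scope;
// error handling elided):
//
//     let model = CompletionModel::new(client, LLAMA3_2);
//     let mut stream = model.stream(request).await?;
//     while let Some(chunk) = stream.next().await {
//         // text deltas, reasoning deltas, tool calls, then the final
//         // `StreamingCompletionResponse` carrying the usage counters
//     }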

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    #[serde(rename = "type")]
    pub type_field: String,
    pub function: completion::ToolDefinition,
}

impl From<crate::completion::ToolDefinition> for ToolDefinition {
    fn from(tool: crate::completion::ToolDefinition) -> Self {
        ToolDefinition {
            type_field: "function".to_owned(),
            function: completion::ToolDefinition {
                name: tool.name,
                description: tool.description,
                parameters: tool.parameters,
            },
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        #[serde(default)]
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
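
// On the wire these serialize with a lowercase `role` tag, e.g. (abridged
// from the test fixtures below; the base64 image payload is illustrative):
//
//     {"role": "system", "content": "You are a helpful assistant."}
//     {"role": "user", "content": "What is 2 + 2?", "images": ["<base64>"]}
//     {"role": "assistant", "content": "...", "thinking": "...", "tool_calls": []}
//     {"role": "tool", "tool_name": "get_weather", "content": "..."}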

impl TryFrom<crate::message::Message> for Vec<Message> {
    type Error = crate::message::MessageError;

    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;
        match internal_msg {
            InternalMessage::User { content, .. } => {
                let (tool_results, other_content): (Vec<_>, Vec<_>) =
                    content.into_iter().partition(|content| {
                        matches!(content, crate::message::UserContent::ToolResult(_))
                    });

                if !tool_results.is_empty() {
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            crate::message::UserContent::ToolResult(
                                crate::message::ToolResult { id, content, .. },
                            ) => {
                                let content_string = content
                                    .into_iter()
                                    .map(|content| match content {
                                        crate::message::ToolResultContent::Text(text) => text.text,
                                        _ => "[Non-text content]".to_string(),
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n");

                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
                                    name: id,
                                    content: content_string,
                                })
                            }
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    let (texts, images) = other_content.into_iter().fold(
                        (Vec::new(), Vec::new()),
                        |(mut texts, mut images), content| {
                            match content {
                                crate::message::UserContent::Text(crate::message::Text {
                                    text,
                                }) => texts.push(text),
                                crate::message::UserContent::Image(crate::message::Image {
                                    data: DocumentSourceKind::Base64(data),
                                    ..
                                }) => images.push(data),
                                crate::message::UserContent::Document(
                                    crate::message::Document {
                                        data:
                                            DocumentSourceKind::Base64(data)
                                            | DocumentSourceKind::String(data),
                                        ..
                                    },
                                ) => texts.push(data),
                                _ => {}
                            }
                            (texts, images)
                        },
                    );

                    Ok(vec![Message::User {
                        content: texts.join(" "),
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(
                                images
                                    .into_iter()
                                    .map(|x| x.to_string())
                                    .collect::<Vec<String>>(),
                            )
                        },
                        name: None,
                    }])
                }
            }
            InternalMessage::Assistant { content, .. } => {
                let mut thinking: Option<String> = None;
                let mut text_content = Vec::new();
                let mut tool_calls = Vec::new();

                for content in content.into_iter() {
                    match content {
                        crate::message::AssistantContent::Text(text) => {
                            text_content.push(text.text)
                        }
                        crate::message::AssistantContent::ToolCall(tool_call) => {
                            tool_calls.push(tool_call)
                        }
                        crate::message::AssistantContent::Reasoning(reasoning) => {
                            let display = reasoning.display_text();
                            if !display.is_empty() {
                                thinking = Some(display);
                            }
                        }
                        crate::message::AssistantContent::Image(_) => {
                            return Err(crate::message::MessageError::ConversionError(
                                "Ollama doesn't support image content in assistant messages."
                                    .into(),
                            ));
                        }
                    }
                }

                Ok(vec![Message::Assistant {
                    content: text_content.join(" "),
                    thinking,
                    images: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}

impl From<Message> for crate::completion::Message {
    fn from(msg: Message) -> Self {
        match msg {
            Message::User { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut assistant_contents =
                    vec![crate::completion::message::AssistantContent::Text(Text {
                        text: content,
                    })];

                for tc in tool_calls {
                    assistant_contents.push(
                        crate::completion::message::AssistantContent::tool_call(
                            tc.function.name.clone(),
                            tc.function.name,
                            tc.function.arguments,
                        ),
                    );
                }

                crate::completion::Message::Assistant {
                    id: None,
                    // `assistant_contents` always holds at least the text entry,
                    // so `many` cannot fail here.
                    content: OneOrMany::many(assistant_contents).unwrap(),
                }
            }
            // The core message type has no system role, so system messages are
            // folded into the user role.
            Message::System { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::ToolResult { name, content } => crate::completion::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    name,
                    OneOrMany::one(message::ToolResultContent::text(content)),
                )),
            },
        }
    }
}

impl Message {
    pub fn system(content: &str) -> Self {
        Message::System {
            content: content.to_owned(),
            images: None,
            name: None,
        }
    }
}

impl From<crate::message::ToolCall> for ToolCall {
    fn from(tool_call: crate::message::ToolCall) -> Self {
        Self {
            r#type: ToolType::Function,
            function: Function {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}
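
// A serialized tool call as Ollama returns it (taken from the
// `test_chat_completion` fixture below); note that `arguments` arrives as a
// JSON object, not a string:
//
//     {
//         "type": "function",
//         "function": {
//             "name": "get_current_weather",
//             "arguments": { "location": "San Francisco, CA", "format": "celsius" }
//         }
//     }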

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}

impl From<String> for SystemContent {
    fn from(s: String) -> Self {
        SystemContent {
            r#type: SystemContentType::default(),
            text: s,
        }
    }
}

impl FromStr for SystemContent {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(SystemContent {
            r#type: SystemContentType::default(),
            text: s.to_string(),
        })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}

impl FromStr for AssistantContent {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(AssistantContent { text: s.to_owned() })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
}

impl FromStr for UserContent {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(UserContent::Text { text: s.to_owned() })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_chat_completion() {
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    #[test]
    fn test_message_conversion() {
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        let comp_msg: crate::completion::Message = provider_msg.into();

        match comp_msg {
            crate::completion::Message::User { content } => {
                let first_content = content.first();
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    #[test]
    fn test_tool_definition_conversion() {
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        let ollama_tool: ToolDefinition = internal_tool.into();

        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );

        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }

    #[tokio::test]
    async fn test_chat_completion_with_thinking() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The answer is 42.",
                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "Let me think about this carefully. The question asks for the meaning of life..."
            );
            assert_eq!(content, "The answer is 42.");
        } else {
            panic!("Expected Assistant message");
        }
    }

    #[tokio::test]
    async fn test_chat_completion_without_thinking() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Hello!",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert!(thinking.is_none());
            assert_eq!(content, "Hello!");
        } else {
            panic!("Expected Assistant message");
        }
    }

    #[test]
    fn test_streaming_response_with_thinking() {
        let sample_chunk = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "",
                "thinking": "Analyzing the problem...",
                "images": null,
                "tool_calls": []
            },
            "done": false
        });

        let chunk: CompletionResponse =
            serde_json::from_value(sample_chunk).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chunk.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
            assert_eq!(content, "");
        } else {
            panic!("Expected Assistant message");
        }
    }

    #[test]
    fn test_message_conversion_with_thinking() {
        let reasoning_content = crate::message::Reasoning::new("Step 1: Consider the problem");

        let internal_msg = crate::message::Message::Assistant {
            id: None,
            content: crate::OneOrMany::many(vec![
                crate::message::AssistantContent::Reasoning(reasoning_content),
                crate::message::AssistantContent::Text(crate::message::Text {
                    text: "The answer is X".to_string(),
                }),
            ])
            .unwrap(),
        };

        let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();

        assert_eq!(provider_msgs.len(), 1);

        if let Message::Assistant {
            thinking, content, ..
        } = &provider_msgs[0]
        {
            assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
            assert_eq!(content, "The answer is X");
        } else {
            panic!("Expected Assistant message with thinking");
        }
    }

    #[test]
    fn test_empty_thinking_content() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Response",
                "thinking": "",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "");
            assert_eq!(content, "Response");
        } else {
            panic!("Expected Assistant message");
        }
    }

    #[test]
    fn test_thinking_with_tool_calls() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Let me check the weather.",
                "thinking": "User wants weather info, I should use the weather tool",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {
                                "location": "San Francisco"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 30u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 50u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking,
            content,
            tool_calls,
            ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "User wants weather info, I should use the weather tool"
            );
            assert_eq!(content, "Let me check the weather.");
            assert_eq!(tool_calls.len(), 1);
            assert_eq!(tool_calls[0].function.name, "get_weather");
        } else {
            panic!("Expected Assistant message with thinking and tool calls");
        }
    }

    #[test]
    fn test_completion_request_with_think_param() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "What is 2 + 2?".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.7),
            max_tokens: Some(1024),
            tool_choice: None,
            additional_params: Some(json!({
                "think": true,
                "keep_alive": "-1m",
                "num_ctx": 4096
            })),
            output_schema: None,
        };

        let ollama_request = OllamaCompletionRequest::try_from(("qwen3:8b", completion_request))
            .expect("Failed to create Ollama request");

        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        let expected = json!({
            "model": "qwen3:8b",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "What is 2 + 2?"
                }
            ],
            "temperature": 0.7,
            "stream": false,
            "think": true,
            "max_tokens": 1024,
            "keep_alive": "-1m",
            "options": {
                "temperature": 0.7,
                "num_ctx": 4096
            }
        });

        assert_eq!(serialized, expected);
    }

    #[test]
    fn test_completion_request_with_think_false_default() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "Hello!".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.5),
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let ollama_request = OllamaCompletionRequest::try_from(("llama3.2", completion_request))
            .expect("Failed to create Ollama request");

        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        let expected = json!({
            "model": "llama3.2",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "Hello!"
                }
            ],
            "temperature": 0.5,
            "stream": false,
            "think": false,
            "options": {
                "temperature": 0.5
            }
        });

        assert_eq!(serialized, expected);
    }

    #[test]
    fn test_completion_request_with_output_schema() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let schema: schemars::Schema = serde_json::from_value(json!({
            "type": "object",
            "properties": {
                "age": { "type": "integer" },
                "available": { "type": "boolean" }
            },
            "required": ["age", "available"]
        }))
        .expect("Failed to parse schema");

        let completion_request = CompletionRequest {
            model: Some("llama3.1".to_string()),
            preamble: None,
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "How old is Ollama?".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: Some(schema),
        };

        let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
            .expect("Failed to create Ollama request");

        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        let format = serialized
            .get("format")
            .expect("format field should be present");
        assert_eq!(
            *format,
            json!({
                "type": "object",
                "properties": {
                    "age": { "type": "integer" },
                    "available": { "type": "boolean" }
                },
                "required": ["age", "available"]
            })
        );
    }

    #[test]
    fn test_completion_request_without_output_schema() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: Some("llama3.1".to_string()),
            preamble: None,
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "Hello!".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
            .expect("Failed to create Ollama request");

        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        assert!(
            serialized.get("format").is_none(),
            "format field should be absent when output_schema is None"
        );
    }

    #[test]
    fn test_client_initialization() {
        let _client: crate::providers::ollama::Client =
            crate::providers::ollama::Client::new(Nothing).expect("Client::new() failed");
        let _client_from_builder: crate::providers::ollama::Client =
            crate::providers::ollama::Client::builder()
                .api_key(Nothing)
                .build()
                .expect("Client::builder() failed");
    }
}