use crate::client::{
    self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
};
use crate::completion::{GetTokenUsage, Usage};
use crate::http_client::{self, HttpClientExt};
use crate::message::DocumentSourceKind;
use crate::streaming::RawStreamingChoice;
use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils, message,
    message::{ImageDetail, Text},
    streaming,
};
use async_stream::try_stream;
use bytes::Bytes;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::{convert::TryFrom, str::FromStr};
use tracing::info_span;
use tracing_futures::Instrument;

const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";

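// ================================================================
// Main Ollama Client
// ================================================================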
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;

#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaBuilder;

impl Provider for OllamaExt {
    type Builder = OllamaBuilder;

    const VERIFY_PATH: &'static str = "api/tags";

    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}

impl<H> Capabilities<H> for OllamaExt {
    type Completion = Capable<CompletionModel<H>>;
    type Transcription = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;

    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for OllamaExt {}

impl ProviderBuilder for OllamaBuilder {
    type Output = OllamaExt;
    type ApiKey = Nothing;

    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;
}

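/// Type aliases for an Ollama-backed [`client::Client`], defaulting to `reqwest`.
///
/// A minimal construction sketch (module paths are assumptions about the crate
/// layout; Ollama needs no API key, so [`Nothing`] stands in for one):
///
/// ```rust,ignore
/// use rig::client::Nothing;
/// use rig::providers::ollama;
///
/// // Talks to a local Ollama instance at http://localhost:11434 by default.
/// let client: ollama::Client = ollama::Client::builder()
///     .api_key(Nothing)
///     .build()
///     .unwrap();
/// ```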
pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, Nothing, H>;

impl ProviderClient for Client {
    type Input = Nothing;

    fn from_env() -> Self {
        let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");

        Self::builder()
            .api_key(Nothing)
            .base_url(&api_base)
            .build()
            .unwrap()
    }

    fn from_val(_: Self::Input) -> Self {
        Self::builder().api_key(Nothing).build().unwrap()
    }
}

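/// Error payload returned by the Ollama API.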
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

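// ================================================================
// Ollama Embedding API
// ================================================================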
pub const ALL_MINILM: &str = "all-minilm";
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    pub embeddings: Vec<Vec<f64>>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

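/// An Ollama embedding model, e.g. [`ALL_MINILM`] or [`NOMIC_EMBED_TEXT`].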
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    ndims: usize,
}

impl<T> EmbeddingModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }

    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        Self::new(
            client.clone(),
            model,
            dims.expect("Ollama embedding models require an explicit dimension count"),
        )
    }

    const MAX_DOCUMENTS: usize = 1024;
    fn ndims(&self) -> usize {
        self.ndims
    }

    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let docs: Vec<String> = documents.into_iter().collect();

        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": docs
        }))?;

        let req = self
            .client
            .post("api/embed")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if !response.status().is_success() {
            let text = http_client::text(response).await?;
            return Err(EmbeddingError::ProviderError(text));
        }

        let bytes: Vec<u8> = response.into_body().await?;

        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;

        if api_resp.embeddings.len() != docs.len() {
            return Err(EmbeddingError::ResponseError(
                "Number of returned embeddings does not match input".into(),
            ));
        }
        Ok(api_resp
            .embeddings
            .into_iter()
            .zip(docs.into_iter())
            .map(|(vec, document)| embeddings::Embedding { document, vec })
            .collect())
    }
}

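// ================================================================
// Ollama Completion API
// ================================================================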
pub const LLAMA3_2: &str = "llama3.2";
pub const LLAVA: &str = "llava";
pub const MISTRAL: &str = "mistral";

#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;
    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                for tc in tool_calls.iter() {
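                    // Ollama does not return tool-call IDs, so the function
                    // name is reused in place of an ID.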
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                        cached_input_tokens: 0,
                    },
                    raw_response,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}

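/// Request body for Ollama's `/api/chat` endpoint.
///
/// A representative serialized payload (an illustrative sketch; the values are
/// hypothetical):
///
/// ```json
/// {
///     "model": "llama3.2",
///     "messages": [{ "role": "user", "content": "Why is the sky blue?" }],
///     "temperature": 0.7,
///     "stream": false,
///     "think": false,
///     "options": { "temperature": 0.7 }
/// }
/// ```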
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OllamaCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    pub stream: bool,
    think: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    options: serde_json::Value,
}

impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.tool_choice.is_some() {
            tracing::warn!("`tool_choice` is not supported by Ollama and will be ignored");
        }

        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let mut think = false;

        let options = if let Some(mut extra) = req.additional_params {
            // `think` is a top-level request field rather than an Ollama
            // option, so pull it out before merging the remaining params.
            if extra.get("think").is_some() {
                think = extra["think"].take().as_bool().ok_or_else(|| {
                    CompletionError::RequestError("`think` must be a bool".into())
                })?;
            }
            json_utils::merge(json!({ "temperature": req.temperature }), extra)
        } else {
            json!({ "temperature": req.temperature })
        };

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            max_tokens: req.max_tokens,
            stream: false,
            think,
            tools: req.tools.into_iter().map(ToolDefinition::from).collect(),
            options,
        })
    }
}

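/// An Ollama completion (chat) model, e.g. [`LLAMA3_2`] or [`MISTRAL`].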
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_owned(),
        }
    }
}

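/// Usage and timing metadata carried by the final chunk of an Ollama stream.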
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub done_reason: Option<String>,
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u64>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u64>,
    pub eval_duration: Option<u64>,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        let input_tokens = self.prompt_eval_count.unwrap_or_default();
        let output_tokens = self.eval_count.unwrap_or_default();
        usage.input_tokens = input_tokens;
        usage.output_tokens = output_tokens;
        usage.total_tokens = input_tokens + output_tokens;

        Some(usage)
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into().as_str())
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "ollama",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request =
            OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Ollama completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("api/chat")?
            .body(body)
            .map_err(http_client::Error::from)?;
        let async_block = async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ));
            }

            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
            let span = tracing::Span::current();
            span.record("gen_ai.response.model", &response.model);
            span.record(
                "gen_ai.usage.input_tokens",
                response.prompt_eval_count.unwrap_or_default(),
            );
            span.record(
                "gen_ai.usage.output_tokens",
                response.eval_count.unwrap_or_default(),
            );

            if tracing::enabled!(tracing::Level::TRACE) {
                tracing::trace!(target: "rig::completions",
                    "Ollama completion response: {}",
                    serde_json::to_string_pretty(&response)?
                );
            }

            let response: completion::CompletionResponse<CompletionResponse> =
                response.try_into()?;

            Ok(response)
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
    {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "ollama",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &request.preamble);

        let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
        request.stream = true;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Ollama streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("api/chat")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let response = self.client.send_streaming(req).await?;
        let status = response.status();
        let mut byte_stream = response.into_body();

        if !status.is_success() {
            return Err(CompletionError::ProviderError(format!(
                "Got error status code trying to send a request to Ollama: {status}"
            )));
        }

        let stream = try_stream! {
            let span = tracing::Span::current();
            let mut tool_calls_final = Vec::new();
            let mut text_response = String::new();
            let mut thinking_response = String::new();

            while let Some(chunk) = byte_stream.next().await {
                let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;

                // Ollama streams NDJSON: each non-empty line is a complete JSON chunk.
                for line in bytes.split(|&b| b == b'\n') {
                    if line.is_empty() {
                        continue;
                    }

                    tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));

                    let response: CompletionResponse = serde_json::from_slice(line)?;

                    if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
                        if let Some(thinking_content) = thinking
                            && !thinking_content.is_empty()
                        {
                            thinking_response += &thinking_content;
                            yield RawStreamingChoice::ReasoningDelta {
                                id: None,
                                reasoning: thinking_content,
                            };
                        }

                        if !content.is_empty() {
                            text_response += &content;
                            yield RawStreamingChoice::Message(content);
                        }

                        for tool_call in tool_calls {
                            tool_calls_final.push(tool_call.clone());
                            // Ollama does not return tool-call IDs, so an empty ID is used.
                            yield RawStreamingChoice::ToolCall(
                                crate::streaming::RawStreamingToolCall::new(
                                    String::new(),
                                    tool_call.function.name,
                                    tool_call.function.arguments,
                                )
                            );
                        }
                    }

                    if response.done {
                        span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
                        span.record("gen_ai.usage.output_tokens", response.eval_count);
                        let message = Message::Assistant {
                            content: text_response.clone(),
                            thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
                            images: None,
                            name: None,
                            tool_calls: tool_calls_final.clone(),
                        };
                        span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap());
                        yield RawStreamingChoice::FinalResponse(
                            StreamingCompletionResponse {
                                total_duration: response.total_duration,
                                load_duration: response.load_duration,
                                prompt_eval_count: response.prompt_eval_count,
                                prompt_eval_duration: response.prompt_eval_duration,
                                eval_count: response.eval_count,
                                eval_duration: response.eval_duration,
                                done_reason: response.done_reason,
                            }
                        );
                        break;
                    }
                }
            }
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}

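/// A tool definition in the shape expected by Ollama's `/api/chat` endpoint.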
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    #[serde(rename = "type")]
    pub type_field: String,
    pub function: completion::ToolDefinition,
}

impl From<crate::completion::ToolDefinition> for ToolDefinition {
    fn from(tool: crate::completion::ToolDefinition) -> Self {
        ToolDefinition {
            type_field: "function".to_owned(),
            function: completion::ToolDefinition {
                name: tool.name,
                description: tool.description,
                parameters: tool.parameters,
            },
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}

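/// A chat message in Ollama's wire format, tagged by `role`.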
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        #[serde(default)]
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}

impl TryFrom<crate::message::Message> for Vec<Message> {
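    // A single internal message can fan out into several Ollama messages:
    // each tool result becomes its own `tool`-role message, while text and
    // image content are folded into a single user or assistant message.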
    type Error = crate::message::MessageError;
    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;
        match internal_msg {
            InternalMessage::User { content, .. } => {
                let (tool_results, other_content): (Vec<_>, Vec<_>) =
                    content.into_iter().partition(|content| {
                        matches!(content, crate::message::UserContent::ToolResult(_))
                    });

                if !tool_results.is_empty() {
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            crate::message::UserContent::ToolResult(
                                crate::message::ToolResult { id, content, .. },
                            ) => {
                                let content_string = content
                                    .into_iter()
                                    .map(|content| match content {
                                        crate::message::ToolResultContent::Text(text) => text.text,
                                        _ => "[Non-text content]".to_string(),
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n");

                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
                                    name: id,
                                    content: content_string,
                                })
                            }
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    let (texts, images) = other_content.into_iter().fold(
                        (Vec::new(), Vec::new()),
                        |(mut texts, mut images), content| {
                            match content {
                                crate::message::UserContent::Text(crate::message::Text {
                                    text,
                                }) => texts.push(text),
                                crate::message::UserContent::Image(crate::message::Image {
                                    data: DocumentSourceKind::Base64(data),
                                    ..
                                }) => images.push(data),
                                crate::message::UserContent::Document(
                                    crate::message::Document {
                                        data:
                                            DocumentSourceKind::Base64(data)
                                            | DocumentSourceKind::String(data),
                                        ..
                                    },
                                ) => texts.push(data),
                                _ => {}
                            }
                            (texts, images)
                        },
                    );

                    Ok(vec![Message::User {
                        content: texts.join(" "),
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(
                                images
                                    .into_iter()
                                    .map(|x| x.to_string())
                                    .collect::<Vec<String>>(),
                            )
                        },
                        name: None,
                    }])
                }
            }
            InternalMessage::Assistant { content, .. } => {
                let mut thinking: Option<String> = None;
                let mut text_content = Vec::new();
                let mut tool_calls = Vec::new();

                for content in content.into_iter() {
                    match content {
                        crate::message::AssistantContent::Text(text) => {
                            text_content.push(text.text)
                        }
                        crate::message::AssistantContent::ToolCall(tool_call) => {
                            tool_calls.push(tool_call)
                        }
                        crate::message::AssistantContent::Reasoning(
                            crate::message::Reasoning { reasoning, .. },
                        ) => {
                            thinking = Some(reasoning.first().cloned().unwrap_or_default());
                        }
                        crate::message::AssistantContent::Image(_) => {
                            return Err(crate::message::MessageError::ConversionError(
                                "Ollama currently doesn't support images.".into(),
                            ));
                        }
                    }
                }

                Ok(vec![Message::Assistant {
                    content: text_content.join(" "),
                    thinking,
                    images: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}

impl From<Message> for crate::completion::Message {
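    // System messages are mapped back to user messages, and tool results
    // become user-side tool-result content.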
    fn from(msg: Message) -> Self {
        match msg {
            Message::User { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut assistant_contents =
                    vec![crate::completion::message::AssistantContent::Text(Text {
                        text: content,
                    })];
                for tc in tool_calls {
                    assistant_contents.push(
                        crate::completion::message::AssistantContent::tool_call(
                            tc.function.name.clone(),
                            tc.function.name,
                            tc.function.arguments,
                        ),
                    );
                }
                crate::completion::Message::Assistant {
                    id: None,
                    content: OneOrMany::many(assistant_contents).unwrap(),
                }
            }
            Message::System { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::ToolResult { name, content } => crate::completion::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    name,
                    OneOrMany::one(message::ToolResultContent::text(content)),
                )),
            },
        }
    }
}

impl Message {
    pub fn system(content: &str) -> Self {
        Message::System {
            content: content.to_owned(),
            images: None,
            name: None,
        }
    }
}

impl From<crate::message::ToolCall> for ToolCall {
    fn from(tool_call: crate::message::ToolCall) -> Self {
        Self {
            r#type: ToolType::Function,
            function: Function {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}

impl From<String> for SystemContent {
    fn from(s: String) -> Self {
        SystemContent {
            r#type: SystemContentType::default(),
            text: s,
        }
    }
}

impl FromStr for SystemContent {
    type Err = std::convert::Infallible;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(SystemContent {
            r#type: SystemContentType::default(),
            text: s.to_string(),
        })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    pub text: String,
}

impl FromStr for AssistantContent {
    type Err = std::convert::Infallible;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(AssistantContent { text: s.to_owned() })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
}

impl FromStr for UserContent {
    type Err = std::convert::Infallible;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(UserContent::Text { text: s.to_owned() })
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_chat_completion() {
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    #[test]
    fn test_message_conversion() {
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        let comp_msg: crate::completion::Message = provider_msg.into();

        match comp_msg {
            crate::completion::Message::User { content } => {
                let first_content = content.first();
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    #[test]
    fn test_tool_definition_conversion() {
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        let ollama_tool: ToolDefinition = internal_tool.into();

        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );

        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }

    #[tokio::test]
    async fn test_chat_completion_with_thinking() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The answer is 42.",
                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "Let me think about this carefully. The question asks for the meaning of life..."
            );
            assert_eq!(content, "The answer is 42.");
        } else {
            panic!("Expected Assistant message");
        }
    }

    #[tokio::test]
    async fn test_chat_completion_without_thinking() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Hello!",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert!(thinking.is_none());
            assert_eq!(content, "Hello!");
        } else {
            panic!("Expected Assistant message");
        }
    }

    #[test]
    fn test_streaming_response_with_thinking() {
        let sample_chunk = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "",
                "thinking": "Analyzing the problem...",
                "images": null,
                "tool_calls": []
            },
            "done": false
        });

        let chunk: CompletionResponse =
            serde_json::from_value(sample_chunk).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chunk.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
            assert_eq!(content, "");
        } else {
            panic!("Expected Assistant message");
        }
    }

    #[test]
    fn test_message_conversion_with_thinking() {
        let reasoning_content = crate::message::Reasoning {
            id: None,
            reasoning: vec!["Step 1: Consider the problem".to_string()],
            signature: None,
        };

        let internal_msg = crate::message::Message::Assistant {
            id: None,
            content: crate::OneOrMany::many(vec![
                crate::message::AssistantContent::Reasoning(reasoning_content),
                crate::message::AssistantContent::Text(crate::message::Text {
                    text: "The answer is X".to_string(),
                }),
            ])
            .unwrap(),
        };

        let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();

        assert_eq!(provider_msgs.len(), 1);

        if let Message::Assistant {
            thinking, content, ..
        } = &provider_msgs[0]
        {
            assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
            assert_eq!(content, "The answer is X");
        } else {
            panic!("Expected Assistant message with thinking");
        }
    }

    #[test]
    fn test_empty_thinking_content() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Response",
                "thinking": "",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "");
            assert_eq!(content, "Response");
        } else {
            panic!("Expected Assistant message");
        }
    }

    #[test]
    fn test_thinking_with_tool_calls() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Let me check the weather.",
                "thinking": "User wants weather info, I should use the weather tool",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {
                                "location": "San Francisco"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 30u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 50u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking,
            content,
            tool_calls,
            ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "User wants weather info, I should use the weather tool"
            );
            assert_eq!(content, "Let me check the weather.");
            assert_eq!(tool_calls.len(), 1);
            assert_eq!(tool_calls[0].function.name, "get_weather");
        } else {
            panic!("Expected Assistant message with thinking and tool calls");
        }
    }
}