1use crate::client::{
41 self, ApiKey, Capabilities, Capable, DebugExt, ModelLister, Nothing, Provider, ProviderBuilder,
42 ProviderClient,
43};
44use crate::completion::{GetTokenUsage, Usage};
45use crate::http_client::{self, HttpClientExt};
46use crate::message::DocumentSourceKind;
47use crate::model::{Model, ModelList, ModelListingError};
48use crate::streaming::RawStreamingChoice;
49use crate::{
50 OneOrMany,
51 completion::{self, CompletionError, CompletionRequest},
52 embeddings::{self, EmbeddingError},
53 json_utils, message,
54 message::{ImageDetail, Text},
55 streaming,
56 wasm_compat::{WasmCompatSend, WasmCompatSync},
57};
58use async_stream::try_stream;
59use bytes::Bytes;
60use futures::StreamExt;
61use serde::{Deserialize, Serialize};
62use serde_json::{Value, json};
63use std::{convert::TryFrom, str::FromStr};
64use tracing::info_span;
65use tracing_futures::Instrument;
/// Default Ollama server address, used when no base URL override is supplied.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
69
70#[derive(Debug, Default, Clone)]
73pub struct OllamaApiKey(Option<String>);
74
75impl ApiKey for OllamaApiKey {
76 fn into_header(
77 self,
78 ) -> Option<http_client::Result<(http::header::HeaderName, http::header::HeaderValue)>> {
79 self.0.map(http_client::make_auth_header)
80 }
81}
82
83impl From<Nothing> for OllamaApiKey {
84 fn from(_: Nothing) -> Self {
85 Self(None)
86 }
87}
88
89impl From<String> for OllamaApiKey {
90 fn from(key: String) -> Self {
91 if key.is_empty() {
92 Self(None)
93 } else {
94 Self(Some(key))
95 }
96 }
97}
98
99impl From<&str> for OllamaApiKey {
100 fn from(key: &str) -> Self {
101 if key.is_empty() {
102 Self(None)
103 } else {
104 Self(Some(key.to_owned()))
105 }
106 }
107}
108
/// Marker type for the Ollama provider extension: declares which capabilities
/// (completion, embeddings, model listing) this provider exposes.
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;

/// Builder marker used to construct an Ollama [`client::Client`].
#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaBuilder;

impl Provider for OllamaExt {
    type Builder = OllamaBuilder;
    // Endpoint probed (relative to the base URL) to verify the server is up.
    const VERIFY_PATH: &'static str = "api/tags";
}

impl<H> Capabilities<H> for OllamaExt {
    type Completion = Capable<CompletionModel<H>>;
    // Ollama exposes no transcription capability here.
    type Transcription = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type ModelListing = Capable<OllamaModelLister<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for OllamaExt {}

impl ProviderBuilder for OllamaBuilder {
    type Extension<H>
        = OllamaExt
    where
        H: HttpClientExt;
    type ApiKey = OllamaApiKey;

    const BASE_URL: &'static str = OLLAMA_API_BASE_URL;

    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        // The extension is stateless, so building always succeeds.
        Ok(OllamaExt)
    }
}
152
/// Ollama client specialized over an HTTP backend (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
/// Builder alias for [`Client`].
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, OllamaApiKey, H>;
155
impl ProviderClient for Client {
    type Input = OllamaApiKey;
    type Error = crate::client::ProviderClientError;

    /// Builds a client from environment variables.
    ///
    /// `OLLAMA_API_BASE_URL` overrides the default base URL and
    /// `OLLAMA_API_KEY` optionally supplies a key; both variables are
    /// optional and fall back to defaults when unset.
    fn from_env() -> Result<Self, Self::Error> {
        let api_base = crate::client::optional_env_var("OLLAMA_API_BASE_URL")?
            .unwrap_or_else(|| OLLAMA_API_BASE_URL.to_string());

        let api_key = crate::client::optional_env_var("OLLAMA_API_KEY")?
            .map(OllamaApiKey::from)
            .unwrap_or_default();

        Self::builder()
            .api_key(api_key)
            .base_url(&api_base)
            .build()
            .map_err(Into::into)
    }

    /// Builds a client with the default base URL and the given key.
    fn from_val(api_key: Self::Input) -> Result<Self, Self::Error> {
        Self::builder().api_key(api_key).build().map_err(Into::into)
    }
}
179
/// Error payload returned by the Ollama API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a success payload or an API error. `untagged` makes serde try the
/// success shape first and fall back to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
193
/// `all-minilm` embedding model identifier.
pub const ALL_MINILM: &str = "all-minilm";
/// `nomic-embed-text` embedding model identifier.
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";

/// Returns the known embedding dimension for recognized Ollama embedding
/// models, or `None` for an unknown identifier.
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    let dims = match identifier {
        ALL_MINILM => 384,
        NOMIC_EMBED_TEXT => 768,
        _ => return None,
    };
    Some(dims)
}
206
/// Response body of Ollama's `api/embed` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub model: String,
    /// One embedding vector per input document (callers assume input order).
    pub embeddings: Vec<Vec<f64>>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
218
219impl From<ApiErrorResponse> for EmbeddingError {
220 fn from(err: ApiErrorResponse) -> Self {
221 EmbeddingError::ProviderError(err.message)
222 }
223}
224
225impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
226 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
227 match value {
228 ApiResponse::Ok(response) => Ok(response),
229 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
230 }
231 }
232}
233
234#[derive(Clone)]
237pub struct EmbeddingModel<T = reqwest::Client> {
238 client: Client<T>,
239 pub model: String,
240 ndims: usize,
241}
242
243impl<T> EmbeddingModel<T> {
244 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
245 Self {
246 client,
247 model: model.into(),
248 ndims,
249 }
250 }
251
252 pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
253 Self {
254 client,
255 model: model.into(),
256 ndims,
257 }
258 }
259}
260
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        let model = model.into();
        // Prefer an explicitly requested dimension count, then the known
        // dimensions for recognized models, then 0 as a last resort.
        let dims = dims
            .or(model_dimensions_from_identifier(&model))
            .unwrap_or_default();
        Self::new(client.clone(), model, dims)
    }

    const MAX_DOCUMENTS: usize = 1024;
    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embeds the given documents via Ollama's `api/embed` endpoint.
    ///
    /// # Errors
    /// Fails if the request cannot be sent, the server returns a non-success
    /// status, the body cannot be parsed, or the number of returned
    /// embeddings does not match the number of input documents.
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let docs: Vec<String> = documents.into_iter().collect();

        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": docs
        }))?;

        let req = self
            .client
            .post("api/embed")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send::<_, Vec<u8>>(req).await?;

        if !response.status().is_success() {
            // Surface the raw body as the provider error message.
            let text = http_client::text(response).await?;
            return Err(EmbeddingError::ProviderError(text));
        }

        let bytes: Vec<u8> = response.into_body().await?;

        let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?;

        if api_resp.embeddings.len() != docs.len() {
            return Err(EmbeddingError::ResponseError(
                "Number of returned embeddings does not match input".into(),
            ));
        }
        // Pair each returned vector with its source document, in order.
        Ok(api_resp
            .embeddings
            .into_iter()
            .zip(docs.into_iter())
            .map(|(vec, document)| embeddings::Embedding { document, vec })
            .collect())
    }
}
321
/// `llama3.2` chat model identifier.
pub const LLAMA3_2: &str = "llama3.2";
/// `llava` multimodal model identifier.
pub const LLAVA: &str = "llava";
/// `mistral` chat model identifier.
pub const MISTRAL: &str = "mistral";
327
/// Response body of Ollama's `api/chat` endpoint. When streaming, each NDJSON
/// line deserializes into one of these; the counters below are only populated
/// on the final (`done == true`) object.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: Message,
    /// `true` on the final (or only) response object.
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;
    /// Converts a raw Ollama chat response into the generic completion
    /// response, extracting text and tool calls from the assistant message.
    ///
    /// # Errors
    /// Fails if the response's message is not an assistant message, or if it
    /// contains neither text nor tool calls.
    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        match resp.message {
            Message::Assistant {
                content,
                thinking,
                tool_calls,
                ..
            } => {
                let mut assistant_contents = Vec::new();
                if !content.is_empty() {
                    assistant_contents.push(completion::AssistantContent::text(&content));
                }
                // NOTE(review): the function name is passed as both call id
                // and name — presumably because Ollama supplies no per-call
                // ids; confirm downstream consumers tolerate duplicate ids.
                for tc in tool_calls.iter() {
                    assistant_contents.push(completion::AssistantContent::tool_call(
                        tc.function.name.clone(),
                        tc.function.name.clone(),
                        tc.function.arguments.clone(),
                    ));
                }
                // `OneOrMany::many` rejects an empty vec, which maps to the
                // empty-response error below.
                let choice = OneOrMany::many(assistant_contents).map_err(|_| {
                    CompletionError::ResponseError("No content provided".to_owned())
                })?;
                let prompt_tokens = resp.prompt_eval_count.unwrap_or(0);
                let completion_tokens = resp.eval_count.unwrap_or(0);

                // Rebuild the raw response since the message was moved out of
                // `resp` by the match above.
                let raw_response = CompletionResponse {
                    model: resp.model,
                    created_at: resp.created_at,
                    done: resp.done,
                    done_reason: resp.done_reason,
                    total_duration: resp.total_duration,
                    load_duration: resp.load_duration,
                    prompt_eval_count: resp.prompt_eval_count,
                    prompt_eval_duration: resp.prompt_eval_duration,
                    eval_count: resp.eval_count,
                    eval_duration: resp.eval_duration,
                    message: Message::Assistant {
                        content,
                        thinking,
                        images: None,
                        name: None,
                        tool_calls,
                    },
                };

                Ok(completion::CompletionResponse {
                    choice,
                    usage: Usage {
                        input_tokens: prompt_tokens,
                        output_tokens: completion_tokens,
                        total_tokens: prompt_tokens + completion_tokens,
                        cached_input_tokens: 0,
                        cache_creation_input_tokens: 0,
                    },
                    raw_response,
                    message_id: None,
                })
            }
            _ => Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            )),
        }
    }
}
419
/// Request body sent to Ollama's `api/chat` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OllamaCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Omitted from the payload entirely when no tools are defined.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    pub stream: bool,
    // Forwarded from `additional_params.think` (reasoning output toggle).
    think: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    // Forwarded verbatim from `additional_params.keep_alive`.
    #[serde(skip_serializing_if = "Option::is_none")]
    keep_alive: Option<String>,
    // Optional JSON schema constraining the output format.
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<schemars::Schema>,
    // Model options object: temperature plus any remaining extra params.
    options: serde_json::Value,
}
438
impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
    type Error = CompletionError;

    /// Builds an Ollama request body from a generic completion request plus a
    /// fallback model name (used when the request does not name a model).
    ///
    /// `think` and `keep_alive` are extracted from `additional_params`; the
    /// remaining extras are merged into `options` alongside `temperature`.
    ///
    /// # Errors
    /// Fails if a chat-history message cannot be converted, if `think` is not
    /// a bool, or if `keep_alive` is not a string.
    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        let model = req.model.clone().unwrap_or_else(|| model.to_string());
        // Ollama has no tool_choice parameter; warn rather than fail.
        if req.tool_choice.is_some() {
            tracing::warn!("WARNING: `tool_choice` not supported for Ollama");
        }
        // Documents (if any) come first, followed by the chat history.
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        // An optional system message leads the final history.
        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        // Each internal message may expand into several provider messages
        // (e.g. one ToolResult message per tool result).
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let mut think = false;
        let mut keep_alive: Option<String> = None;

        let options = if let Some(mut extra) = req.additional_params {
            if let Some(obj) = extra.as_object_mut() {
                // `think` and `keep_alive` are top-level request fields in
                // Ollama, not model options, so remove them from the extras.
                if let Some(think_val) = obj.remove("think") {
                    think = think_val.as_bool().ok_or_else(|| {
                        CompletionError::RequestError("`think` must be a bool".into())
                    })?;
                }

                if let Some(keep_alive_val) = obj.remove("keep_alive") {
                    keep_alive = Some(
                        keep_alive_val
                            .as_str()
                            .ok_or_else(|| {
                                CompletionError::RequestError(
                                    "`keep_alive` must be a string".into(),
                                )
                            })?
                            .to_string(),
                    );
                }
            }

            // Remaining extras win over the base temperature on key clashes
            // per `json_utils::merge` — TODO confirm merge precedence.
            json_utils::merge(json!({ "temperature": req.temperature }), extra)
        } else {
            json!({ "temperature": req.temperature })
        };

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            max_tokens: req.max_tokens,
            // Non-streaming by default; `stream()` flips this to true.
            stream: false,
            think,
            keep_alive,
            format: req.output_schema,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            options,
        })
    }
}
523
524#[derive(Clone)]
525pub struct CompletionModel<T = reqwest::Client> {
526 client: Client<T>,
527 pub model: String,
528}
529
530impl<T> CompletionModel<T> {
531 pub fn new(client: Client<T>, model: &str) -> Self {
532 Self {
533 client,
534 model: model.to_owned(),
535 }
536 }
537}
538
539#[derive(Clone, Serialize, Deserialize, Debug)]
542pub struct StreamingCompletionResponse {
543 pub done_reason: Option<String>,
544 pub total_duration: Option<u64>,
545 pub load_duration: Option<u64>,
546 pub prompt_eval_count: Option<u64>,
547 pub prompt_eval_duration: Option<u64>,
548 pub eval_count: Option<u64>,
549 pub eval_duration: Option<u64>,
550}
551
552impl GetTokenUsage for StreamingCompletionResponse {
553 fn token_usage(&self) -> Option<crate::completion::Usage> {
554 let mut usage = crate::completion::Usage::new();
555 let input_tokens = self.prompt_eval_count.unwrap_or_default();
556 let output_tokens = self.eval_count.unwrap_or_default();
557 usage.input_tokens = input_tokens;
558 usage.output_tokens = output_tokens;
559 usage.total_tokens = input_tokens + output_tokens;
560
561 Some(usage)
562 }
563}
564
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into().as_str())
    }

    /// Sends a non-streaming chat request to Ollama's `api/chat` endpoint.
    ///
    /// Opens a GenAI tracing span (unless one is already active), records
    /// request/response telemetry on it, and converts the provider response
    /// into the generic completion response.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
        // Only open a fresh span when the caller has not already set one up.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "ollama",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request = OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Ollama completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("api/chat")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let async_block = async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                // Surface the raw body as the provider error message.
                return Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ));
            }

            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
            let span = tracing::Span::current();
            span.record("gen_ai.response.model", &response.model);
            span.record(
                "gen_ai.usage.input_tokens",
                response.prompt_eval_count.unwrap_or_default(),
            );
            span.record(
                "gen_ai.usage.output_tokens",
                response.eval_count.unwrap_or_default(),
            );

            if tracing::enabled!(tracing::Level::TRACE) {
                tracing::trace!(target: "rig::completions",
                    "Ollama completion response: {}",
                    serde_json::to_string_pretty(&response)?
                );
            }

            let response: completion::CompletionResponse<CompletionResponse> =
                response.try_into()?;

            Ok(response)
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    /// Sends a streaming chat request and yields incremental chunks.
    ///
    /// The NDJSON body is split on newlines; each line is a standalone
    /// `CompletionResponse` carrying a content/thinking/tool-call delta, and
    /// the `done == true` line carries the final usage statistics.
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
    {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "ollama",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &request.preamble);

        let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
        request.stream = true;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Ollama streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("api/chat")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let response = self.client.send_streaming(req).await?;
        let status = response.status();
        let mut byte_stream = response.into_body();

        if !status.is_success() {
            return Err(CompletionError::ProviderError(format!(
                "Got error status code trying to send a request to Ollama: {status}"
            )));
        }

        let stream = try_stream! {
            let span = tracing::Span::current();
            // Accumulators for reassembling the full assistant message for
            // telemetry on the final chunk.
            let mut tool_calls_final = Vec::new();
            let mut text_response = String::new();
            let mut thinking_response = String::new();

            while let Some(chunk) = byte_stream.next().await {
                let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?;

                // NOTE(review): assumes each HTTP chunk contains only whole
                // NDJSON lines; a JSON object split across two chunks would
                // fail to parse — confirm the transport buffers per line.
                for line in bytes.split(|&b| b == b'\n') {
                    if line.is_empty() {
                        continue;
                    }

                    tracing::debug!(target: "rig", "Received NDJSON line from Ollama: {}", String::from_utf8_lossy(line));

                    let response: CompletionResponse = serde_json::from_slice(line)?;

                    if let Message::Assistant { content, thinking, tool_calls, .. } = response.message {
                        if let Some(thinking_content) = thinking && !thinking_content.is_empty() {
                            thinking_response += &thinking_content;
                            yield RawStreamingChoice::ReasoningDelta {
                                id: None,
                                reasoning: thinking_content,
                            };
                        }

                        if !content.is_empty() {
                            text_response += &content;
                            yield RawStreamingChoice::Message(content);
                        }

                        for tool_call in tool_calls {
                            tool_calls_final.push(tool_call.clone());
                            // An empty id is passed since none is tracked per call.
                            yield RawStreamingChoice::ToolCall(
                                crate::streaming::RawStreamingToolCall::new(String::new(), tool_call.function.name, tool_call.function.arguments)
                            );
                        }
                    }

                    if response.done {
                        // Usage counters only appear on the final chunk.
                        span.record("gen_ai.usage.input_tokens", response.prompt_eval_count);
                        span.record("gen_ai.usage.output_tokens", response.eval_count);
                        let message = Message::Assistant {
                            content: text_response.clone(),
                            thinking: if thinking_response.is_empty() { None } else { Some(thinking_response.clone()) },
                            images: None,
                            name: None,
                            tool_calls: tool_calls_final.clone()
                        };
                        if let Ok(serialized_message) = serde_json::to_string(&vec![message]) {
                            span.record("gen_ai.output.messages", serialized_message);
                        }
                        yield RawStreamingChoice::FinalResponse(
                            StreamingCompletionResponse {
                                total_duration: response.total_duration,
                                load_duration: response.load_duration,
                                prompt_eval_count: response.prompt_eval_count,
                                prompt_eval_duration: response.prompt_eval_duration,
                                eval_count: response.eval_count,
                                eval_duration: response.eval_duration,
                                done_reason: response.done_reason,
                            }
                        );
                        // NOTE(review): this `break` exits only the per-chunk
                        // line loop; the outer `while` ends when the byte
                        // stream itself closes.
                        break;
                    }
                }
            }
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
785
/// Response shape of Ollama's `api/tags` (installed models) endpoint.
#[derive(Debug, Deserialize)]
struct ListModelsResponse {
    models: Vec<ListModelEntry>,
}

/// One installed-model entry from `api/tags`.
#[derive(Debug, Deserialize)]
struct ListModelEntry {
    name: String,
    model: String,
}

impl From<ListModelEntry> for Model {
    fn from(value: ListModelEntry) -> Self {
        Model::new(value.model, value.name)
    }
}
804
805#[derive(Clone)]
807pub struct OllamaModelLister<H = reqwest::Client> {
808 client: Client<H>,
809}
810
811impl<H> ModelLister<H> for OllamaModelLister<H>
812where
813 H: HttpClientExt + WasmCompatSend + WasmCompatSync + 'static,
814{
815 type Client = Client<H>;
816
817 fn new(client: Self::Client) -> Self {
818 Self { client }
819 }
820
821 async fn list_all(&self) -> Result<ModelList, ModelListingError> {
822 let path = "/api/tags";
823 let req = self.client.get(path)?.body(http_client::NoBody)?;
824 let response = self.client.send::<_, Vec<u8>>(req).await?;
825
826 if !response.status().is_success() {
827 let status_code = response.status().as_u16();
828 let body = response.into_body().await?;
829 return Err(ModelListingError::api_error_with_context(
830 "Ollama",
831 path,
832 status_code,
833 &body,
834 ));
835 }
836
837 let body = response.into_body().await?;
838 let api_resp: ListModelsResponse = serde_json::from_slice(&body).map_err(|error| {
839 ModelListingError::parse_error_with_context("Ollama", path, &error, &body)
840 })?;
841 let models = api_resp.models.into_iter().map(Model::from).collect();
842
843 Ok(ModelList::new(models))
844 }
845}
846
/// Tool definition in the shape Ollama's chat API expects: a `"function"`
/// type tag wrapping the generic tool definition.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    #[serde(rename = "type")]
    pub type_field: String,
    pub function: completion::ToolDefinition,
}
856
857impl From<crate::completion::ToolDefinition> for ToolDefinition {
859 fn from(tool: crate::completion::ToolDefinition) -> Self {
860 ToolDefinition {
861 type_field: "function".to_owned(),
862 function: completion::ToolDefinition {
863 name: tool.name,
864 description: tool.description,
865 parameters: tool.parameters,
866 },
867 }
868 }
869}
870
/// A tool call as carried inside an Ollama assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    // Defaults to `Function` when the field is absent.
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    pub function: Function,
}
/// Kind of tool call; only function calls exist.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
/// Invoked function name plus its JSON arguments.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    pub arguments: Value,
}
888
/// Chat message in Ollama's wire format, discriminated by the `role` field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
        // Base64-encoded image payloads attached to the message, if any.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        // Defaults to "" when omitted (e.g. tool-call-only chunks).
        #[serde(default)]
        content: String,
        // Optional reasoning ("thinking") text.
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // JSON `null` deserializes to an empty vec via the custom helper.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Tool execution result, serialized with role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        #[serde(rename = "tool_name")]
        name: String,
        content: String,
    },
}
927
impl TryFrom<crate::message::Message> for Vec<Message> {
    type Error = crate::message::MessageError;
    /// Converts an internal chat message into one or more provider messages.
    ///
    /// A user message containing tool results becomes one `ToolResult`
    /// message per result; otherwise its text/document parts are joined into
    /// a single `User` message with any base64 images attached.
    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;
        match internal_msg {
            InternalMessage::System { content } => Ok(vec![Message::System {
                content,
                images: None,
                name: None,
            }]),
            InternalMessage::User { content, .. } => {
                // Split tool results from all other content parts.
                let (tool_results, other_content): (Vec<_>, Vec<_>) =
                    content.into_iter().partition(|content| {
                        matches!(content, crate::message::UserContent::ToolResult(_))
                    });

                if !tool_results.is_empty() {
                    // NOTE(review): when a message mixes tool results with
                    // other content, the non-tool content is dropped here —
                    // confirm this is intended.
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            crate::message::UserContent::ToolResult(
                                crate::message::ToolResult { id, content, .. },
                            ) => {
                                // Flatten the result's parts to plain text;
                                // non-text parts become a placeholder marker.
                                let content_string = content
                                    .into_iter()
                                    .map(|content| match content {
                                        crate::message::ToolResultContent::Text(text) => text.text,
                                        _ => "[Non-text content]".to_string(),
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n");

                                Ok::<_, crate::message::MessageError>(Message::ToolResult {
                                    name: id,
                                    content: content_string,
                                })
                            }
                            // Unreachable in practice: the partition above
                            // guarantees only ToolResult variants land here.
                            _ => Err(crate::message::MessageError::ConversionError(
                                "expected tool result content while converting Ollama input".into(),
                            )),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    // Collect text/document parts and base64 image payloads.
                    // NOTE(review): URL-sourced images fall into the `_` arm
                    // and are silently ignored — only base64 data is kept.
                    let (texts, images) = other_content.into_iter().fold(
                        (Vec::new(), Vec::new()),
                        |(mut texts, mut images), content| {
                            match content {
                                crate::message::UserContent::Text(crate::message::Text {
                                    text,
                                }) => texts.push(text),
                                crate::message::UserContent::Image(crate::message::Image {
                                    data: DocumentSourceKind::Base64(data),
                                    ..
                                }) => images.push(data),
                                crate::message::UserContent::Document(
                                    crate::message::Document {
                                        data:
                                            DocumentSourceKind::Base64(data)
                                            | DocumentSourceKind::String(data),
                                        ..
                                    },
                                ) => texts.push(data),
                                _ => {}
                            }
                            (texts, images)
                        },
                    );

                    Ok(vec![Message::User {
                        content: texts.join(" "),
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(
                                images
                                    .into_iter()
                                    .map(|x| x.to_string())
                                    .collect::<Vec<String>>(),
                            )
                        },
                        name: None,
                    }])
                }
            }
            InternalMessage::Assistant { content, .. } => {
                let mut thinking: Option<String> = None;
                let mut text_content = Vec::new();
                let mut tool_calls = Vec::new();

                for content in content.into_iter() {
                    match content {
                        crate::message::AssistantContent::Text(text) => {
                            text_content.push(text.text)
                        }
                        crate::message::AssistantContent::ToolCall(tool_call) => {
                            tool_calls.push(tool_call)
                        }
                        crate::message::AssistantContent::Reasoning(reasoning) => {
                            // Only the last non-empty reasoning block is kept.
                            let display = reasoning.display_text();
                            if !display.is_empty() {
                                thinking = Some(display);
                            }
                        }
                        crate::message::AssistantContent::Image(_) => {
                            return Err(crate::message::MessageError::ConversionError(
                                "Ollama currently doesn't support images.".into(),
                            ));
                        }
                    }
                }

                Ok(vec![Message::Assistant {
                    content: text_content.join(" "),
                    thinking,
                    images: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}
1062
impl From<Message> for crate::completion::Message {
    /// Maps a provider message back to the internal message type.
    ///
    /// System messages are folded into user messages, and the `thinking` and
    /// `images` data on assistant messages is not carried over.
    fn from(msg: Message) -> Self {
        match msg {
            Message::User { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // Text first, then one tool-call entry per call; the function
                // name doubles as the call id (Ollama tracks no ids).
                let mut assistant_contents =
                    vec![crate::completion::message::AssistantContent::Text(Text {
                        text: content,
                    })];
                for tc in tool_calls {
                    assistant_contents.push(
                        crate::completion::message::AssistantContent::tool_call(
                            tc.function.name.clone(),
                            tc.function.name,
                            tc.function.arguments,
                        ),
                    );
                }
                // Fall back to a single empty text entry rather than failing
                // on an empty list.
                let content =
                    OneOrMany::from_iter_optional(assistant_contents).unwrap_or_else(|| {
                        OneOrMany::one(crate::completion::message::AssistantContent::Text(Text {
                            text: String::new(),
                        }))
                    });

                crate::completion::Message::Assistant { id: None, content }
            }
            // System messages become plain user messages internally.
            Message::System { content, .. } => crate::completion::Message::User {
                content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
                    text: content,
                })),
            },
            Message::ToolResult { name, content } => crate::completion::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    name,
                    OneOrMany::one(message::ToolResultContent::text(content)),
                )),
            },
        }
    }
}
1115
1116impl Message {
1117 pub fn system(content: &str) -> Self {
1119 Message::System {
1120 content: content.to_owned(),
1121 images: None,
1122 name: None,
1123 }
1124 }
1125}
1126
impl From<crate::message::ToolCall> for ToolCall {
    /// Converts an internal tool call into Ollama's wire representation.
    /// The internal call id (if any) is not carried over — the wire format
    /// has no id field.
    fn from(tool_call: crate::message::ToolCall) -> Self {
        Self {
            r#type: ToolType::Function,
            function: Function {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}
1140
1141#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1142pub struct SystemContent {
1143 #[serde(default)]
1144 r#type: SystemContentType,
1145 text: String,
1146}
1147
1148#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
1149#[serde(rename_all = "lowercase")]
1150pub enum SystemContentType {
1151 #[default]
1152 Text,
1153}
1154
1155impl From<String> for SystemContent {
1156 fn from(s: String) -> Self {
1157 SystemContent {
1158 r#type: SystemContentType::default(),
1159 text: s,
1160 }
1161 }
1162}
1163
1164impl FromStr for SystemContent {
1165 type Err = std::convert::Infallible;
1166 fn from_str(s: &str) -> Result<Self, Self::Err> {
1167 Ok(SystemContent {
1168 r#type: SystemContentType::default(),
1169 text: s.to_string(),
1170 })
1171 }
1172}
1173
1174#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1175pub struct AssistantContent {
1176 pub text: String,
1177}
1178
1179impl FromStr for AssistantContent {
1180 type Err = std::convert::Infallible;
1181 fn from_str(s: &str) -> Result<Self, Self::Err> {
1182 Ok(AssistantContent { text: s.to_owned() })
1183 }
1184}
1185
/// User-message content part, tagged by `"type"` as either text or an image
/// URL.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
}

impl FromStr for UserContent {
    type Err = std::convert::Infallible;
    /// Infallibly wraps any string as text content.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(UserContent::Text { text: s.to_owned() })
    }
}
1200
/// Image reference for user content; `detail` falls back to its default when
/// absent from the payload.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}
1207
#[cfg(test)]
mod tests {
    //! Unit tests for the Ollama provider: response deserialization,
    //! message/tool conversions, "thinking" (reasoning) handling, and
    //! request serialization.
    //!
    //! NOTE(review): three tests were declared `#[tokio::test] async fn` but
    //! contain no `.await`; they are demoted to plain `#[test] fn` so they no
    //! longer spin up a tokio runtime for purely synchronous work.
    use super::*;
    use serde_json::json;

    /// A full chat response (including a tool call) deserializes and
    /// converts into the generic `completion::CompletionResponse`.
    #[test]
    fn test_chat_completion() {
        let sample_chat_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });
        let sample_text = sample_chat_response.to_string();

        let chat_resp: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let conv: completion::CompletionResponse<CompletionResponse> =
            chat_resp.try_into().unwrap();
        assert!(
            !conv.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    /// A provider `Message::User` converts to a completion-layer user
    /// message carrying the same text.
    #[test]
    fn test_message_conversion() {
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };
        let comp_msg: crate::completion::Message = provider_msg.into();
        match comp_msg {
            crate::completion::Message::User { content } => {
                let first_content = content.first();
                match first_content {
                    crate::completion::message::UserContent::Text(text_struct) => {
                        assert_eq!(text_struct.text, "Test message");
                    }
                    _ => panic!("Expected text content in conversion"),
                }
            }
            _ => panic!("Conversion from provider Message to completion Message failed"),
        }
    }

    /// An internal `ToolDefinition` converts to the Ollama wire format with
    /// `type: "function"` and the parameter schema preserved.
    #[test]
    fn test_tool_definition_conversion() {
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to get the weather for, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location", "format"]
            }),
        };
        let ollama_tool: ToolDefinition = internal_tool.into();
        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        let params = &ollama_tool.function.parameters;
        assert_eq!(params["properties"]["location"]["type"], "string");
    }

    /// A response carrying a `thinking` field deserializes with both the
    /// reasoning text and the regular content intact.
    #[test]
    fn test_chat_completion_with_thinking() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The answer is 42.",
                "thinking": "Let me think about this carefully. The question asks for the meaning of life...",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "Let me think about this carefully. The question asks for the meaning of life..."
            );
            assert_eq!(content, "The answer is 42.");
        } else {
            panic!("Expected Assistant message");
        }
    }

    /// Absence of the `thinking` field deserializes as `None`.
    #[test]
    fn test_chat_completion_without_thinking() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Hello!",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert!(thinking.is_none());
            assert_eq!(content, "Hello!");
        } else {
            panic!("Expected Assistant message");
        }
    }

    /// A streaming chunk (`done: false`) with thinking but empty content
    /// still deserializes correctly.
    #[test]
    fn test_streaming_response_with_thinking() {
        let sample_chunk = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "",
                "thinking": "Analyzing the problem...",
                "images": null,
                "tool_calls": []
            },
            "done": false
        });

        let chunk: CompletionResponse =
            serde_json::from_value(sample_chunk).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chunk.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "Analyzing the problem...");
            assert_eq!(content, "");
        } else {
            panic!("Expected Assistant message");
        }
    }

    /// An internal assistant message mixing `Reasoning` and `Text` content
    /// folds into a single provider message with `thinking` populated.
    #[test]
    fn test_message_conversion_with_thinking() {
        let reasoning_content = crate::message::Reasoning::new("Step 1: Consider the problem");

        let internal_msg = crate::message::Message::Assistant {
            id: None,
            content: crate::OneOrMany::many(vec![
                crate::message::AssistantContent::Reasoning(reasoning_content),
                crate::message::AssistantContent::Text(crate::message::Text {
                    text: "The answer is X".to_string(),
                }),
            ])
            .unwrap(),
        };

        let provider_msgs: Vec<Message> = internal_msg.try_into().unwrap();
        assert_eq!(provider_msgs.len(), 1);

        if let Message::Assistant {
            thinking, content, ..
        } = &provider_msgs[0]
        {
            assert_eq!(thinking.as_ref().unwrap(), "Step 1: Consider the problem");
            assert_eq!(content, "The answer is X");
        } else {
            panic!("Expected Assistant message with thinking");
        }
    }

    /// An explicitly empty `thinking` string is preserved as `Some("")`,
    /// not collapsed to `None`.
    #[test]
    fn test_empty_thinking_content() {
        let sample_response = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Response",
                "thinking": "",
                "images": null,
                "tool_calls": []
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 10u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 5u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking, content, ..
        } = &chat_resp.message
        {
            assert_eq!(thinking.as_ref().unwrap(), "");
            assert_eq!(content, "Response");
        } else {
            panic!("Expected Assistant message");
        }
    }

    /// Thinking and tool calls can coexist in the same assistant message.
    #[test]
    fn test_thinking_with_tool_calls() {
        let sample_response = json!({
            "model": "qwen-thinking",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "Let me check the weather.",
                "thinking": "User wants weather info, I should use the weather tool",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {
                                "location": "San Francisco"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 30u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 50u64,
            "eval_duration": 7700000000u64
        });

        let chat_resp: CompletionResponse =
            serde_json::from_value(sample_response).expect("Failed to deserialize");

        if let Message::Assistant {
            thinking,
            content,
            tool_calls,
            ..
        } = &chat_resp.message
        {
            assert_eq!(
                thinking.as_ref().unwrap(),
                "User wants weather info, I should use the weather tool"
            );
            assert_eq!(content, "Let me check the weather.");
            assert_eq!(tool_calls.len(), 1);
            assert_eq!(tool_calls[0].function.name, "get_weather");
        } else {
            panic!("Expected Assistant message with thinking and tool calls");
        }
    }

    /// `additional_params` (`think`, `keep_alive`, `num_ctx`) are hoisted to
    /// the right places in the serialized Ollama request.
    #[test]
    fn test_completion_request_with_think_param() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "What is 2 + 2?".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.7),
            max_tokens: Some(1024),
            tool_choice: None,
            additional_params: Some(json!({
                "think": true,
                "keep_alive": "-1m",
                "num_ctx": 4096
            })),
            output_schema: None,
        };

        let ollama_request = OllamaCompletionRequest::try_from(("qwen3:8b", completion_request))
            .expect("Failed to create Ollama request");

        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        let expected = json!({
            "model": "qwen3:8b",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "What is 2 + 2?"
                }
            ],
            "temperature": 0.7,
            "stream": false,
            "think": true,
            "max_tokens": 1024,
            "keep_alive": "-1m",
            "options": {
                "temperature": 0.7,
                "num_ctx": 4096
            }
        });

        assert_eq!(serialized, expected);
    }

    /// Without `additional_params`, `think` defaults to `false` in the
    /// serialized request.
    #[test]
    fn test_completion_request_with_think_false_default() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "Hello!".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.5),
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let ollama_request = OllamaCompletionRequest::try_from(("llama3.2", completion_request))
            .expect("Failed to create Ollama request");

        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        let expected = json!({
            "model": "llama3.2",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "Hello!"
                }
            ],
            "temperature": 0.5,
            "stream": false,
            "think": false,
            "options": {
                "temperature": 0.5
            }
        });

        assert_eq!(serialized, expected);
    }

    /// An `output_schema` becomes Ollama's `format` field verbatim.
    #[test]
    fn test_completion_request_with_output_schema() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let schema: schemars::Schema = serde_json::from_value(json!({
            "type": "object",
            "properties": {
                "age": { "type": "integer" },
                "available": { "type": "boolean" }
            },
            "required": ["age", "available"]
        }))
        .expect("Failed to parse schema");

        let completion_request = CompletionRequest {
            model: Some("llama3.1".to_string()),
            preamble: None,
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "How old is Ollama?".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: Some(schema),
        };

        let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
            .expect("Failed to create Ollama request");

        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        let format = serialized
            .get("format")
            .expect("format field should be present");
        assert_eq!(
            *format,
            json!({
                "type": "object",
                "properties": {
                    "age": { "type": "integer" },
                    "available": { "type": "boolean" }
                },
                "required": ["age", "available"]
            })
        );
    }

    /// No `output_schema` means no `format` field at all.
    #[test]
    fn test_completion_request_without_output_schema() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: Some("llama3.1".to_string()),
            preamble: None,
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "Hello!".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let ollama_request = OllamaCompletionRequest::try_from(("llama3.1", completion_request))
            .expect("Failed to create Ollama request");

        let serialized =
            serde_json::to_value(&ollama_request).expect("Failed to serialize request");

        assert!(
            serialized.get("format").is_none(),
            "format field should be absent when output_schema is None"
        );
    }

    /// The client can be built both directly and through the builder with no
    /// API key (`Nothing`).
    #[test]
    fn test_client_initialization() {
        let _client = crate::providers::ollama::Client::new(Nothing).expect("Client::new() failed");
        let _client_from_builder = crate::providers::ollama::Client::builder()
            .api_key(Nothing)
            .build()
            .expect("Client::builder() failed");
    }
}