1use crate::{
42 agent::AgentBuilder,
43 completion::{self, CompletionError, CompletionRequest},
44 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
45 extractor::ExtractorBuilder,
46 json_utils, message,
47 message::{ImageDetail, Text},
48 Embed, OneOrMany,
49};
50use reqwest;
51use schemars::JsonSchema;
52use serde::{Deserialize, Serialize};
53use serde_json::{json, Value};
54use std::convert::Infallible;
55use std::{convert::TryFrom, str::FromStr};
56
/// Base URL used by [`Client::new`] when no explicit URL is supplied.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
60
61#[derive(Clone)]
62pub struct Client {
63 base_url: String,
64 http_client: reqwest::Client,
65}
66
67impl Default for Client {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73impl Client {
74 pub fn new() -> Self {
75 Self::from_url(OLLAMA_API_BASE_URL)
76 }
77 pub fn from_url(base_url: &str) -> Self {
78 Self {
79 base_url: base_url.to_owned(),
80 http_client: reqwest::Client::builder()
81 .build()
82 .expect("Ollama reqwest client should build"),
83 }
84 }
85 fn post(&self, path: &str) -> reqwest::RequestBuilder {
86 let url = format!("{}/{}", self.base_url, path);
87 self.http_client.post(url)
88 }
89 pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
90 EmbeddingModel::new(self.clone(), model, 0)
91 }
92 pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
93 EmbeddingModel::new(self.clone(), model, ndims)
94 }
95 pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
96 EmbeddingsBuilder::new(self.embedding_model(model))
97 }
98 pub fn completion_model(&self, model: &str) -> CompletionModel {
99 CompletionModel::new(self.clone(), model)
100 }
101 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
102 AgentBuilder::new(self.completion_model(model))
103 }
104 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
105 &self,
106 model: &str,
107 ) -> ExtractorBuilder<T, CompletionModel> {
108 ExtractorBuilder::new(self.completion_model(model))
109 }
110}
111
112#[derive(Debug, Deserialize)]
115struct ApiErrorResponse {
116 message: String,
117}
118
119#[derive(Debug, Deserialize)]
120#[serde(untagged)]
121enum ApiResponse<T> {
122 Ok(T),
123 Err(ApiErrorResponse),
124}
125
/// Model name string for `all-minilm`.
pub const ALL_MINILM: &str = "all-minilm";
/// Model name string for `nomic-embed-text`.
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
130
131#[derive(Debug, Serialize, Deserialize)]
132pub struct EmbeddingResponse {
133 pub model: String,
134 pub embeddings: Vec<Vec<f64>>,
135 #[serde(default)]
136 pub total_duration: Option<u64>,
137 #[serde(default)]
138 pub load_duration: Option<u64>,
139 #[serde(default)]
140 pub prompt_eval_count: Option<u64>,
141}
142
143impl From<ApiErrorResponse> for EmbeddingError {
144 fn from(err: ApiErrorResponse) -> Self {
145 EmbeddingError::ProviderError(err.message)
146 }
147}
148
149impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
150 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
151 match value {
152 ApiResponse::Ok(response) => Ok(response),
153 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
154 }
155 }
156}
157
158#[derive(Clone)]
161pub struct EmbeddingModel {
162 client: Client,
163 pub model: String,
164 ndims: usize,
165}
166
167impl EmbeddingModel {
168 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
169 Self {
170 client,
171 model: model.to_owned(),
172 ndims,
173 }
174 }
175}
176
177impl embeddings::EmbeddingModel for EmbeddingModel {
178 const MAX_DOCUMENTS: usize = 1024;
179 fn ndims(&self) -> usize {
180 self.ndims
181 }
182 #[cfg_attr(feature = "worker", worker::send)]
183 async fn embed_texts(
184 &self,
185 documents: impl IntoIterator<Item = String>,
186 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
187 let docs: Vec<String> = documents.into_iter().collect();
188 let payload = json!({
189 "model": self.model,
190 "input": docs,
191 });
192 let response = self
193 .client
194 .post("api/embed")
195 .json(&payload)
196 .send()
197 .await
198 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
199 if response.status().is_success() {
200 let api_resp: EmbeddingResponse = response
201 .json()
202 .await
203 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
204 if api_resp.embeddings.len() != docs.len() {
205 return Err(EmbeddingError::ResponseError(
206 "Number of returned embeddings does not match input".into(),
207 ));
208 }
209 Ok(api_resp
210 .embeddings
211 .into_iter()
212 .zip(docs.into_iter())
213 .map(|(vec, document)| embeddings::Embedding { document, vec })
214 .collect())
215 } else {
216 Err(EmbeddingError::ProviderError(response.text().await?))
217 }
218 }
219}
220
/// Model name string for `llama3.2`.
pub const LLAMA3_2: &str = "llama3.2";
/// Model name string for `llava`.
pub const LLAVA: &str = "llava";
/// Model name string for `mistral`.
pub const MISTRAL: &str = "mistral";
226
227#[derive(Debug, Serialize, Deserialize)]
228pub struct CompletionResponse {
229 pub model: String,
230 pub created_at: String,
231 pub message: Message,
232 pub done: bool,
233 #[serde(default)]
234 pub done_reason: Option<String>,
235 #[serde(default)]
236 pub total_duration: Option<u64>,
237 #[serde(default)]
238 pub load_duration: Option<u64>,
239 #[serde(default)]
240 pub prompt_eval_count: Option<u64>,
241 #[serde(default)]
242 pub prompt_eval_duration: Option<u64>,
243 #[serde(default)]
244 pub eval_count: Option<u64>,
245 #[serde(default)]
246 pub eval_duration: Option<u64>,
247}
248impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
249 type Error = CompletionError;
250 fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
251 match resp.message {
252 Message::Assistant {
254 content,
255 tool_calls,
256 ..
257 } => {
258 let mut assistant_contents = Vec::new();
259 if !content.is_empty() {
261 assistant_contents.push(completion::AssistantContent::text(&content));
262 }
263 for tc in tool_calls.iter() {
266 assistant_contents.push(completion::AssistantContent::tool_call(
267 tc.function.name.clone(),
268 tc.function.name.clone(),
269 tc.function.arguments.clone(),
270 ));
271 }
272 let choice = OneOrMany::many(assistant_contents).map_err(|_| {
273 CompletionError::ResponseError("No content provided".to_owned())
274 })?;
275 let raw_response = CompletionResponse {
276 model: resp.model,
277 created_at: resp.created_at,
278 done: resp.done,
279 done_reason: resp.done_reason,
280 total_duration: resp.total_duration,
281 load_duration: resp.load_duration,
282 prompt_eval_count: resp.prompt_eval_count,
283 prompt_eval_duration: resp.prompt_eval_duration,
284 eval_count: resp.eval_count,
285 eval_duration: resp.eval_duration,
286 message: Message::Assistant {
287 content,
288 images: None,
289 name: None,
290 tool_calls,
291 },
292 };
293 Ok(completion::CompletionResponse {
294 choice,
295 raw_response,
296 })
297 }
298 _ => Err(CompletionError::ResponseError(
299 "Chat response does not include an assistant message".into(),
300 )),
301 }
302 }
303}
304
305#[derive(Clone)]
308pub struct CompletionModel {
309 client: Client,
310 pub model: String,
311}
312
313impl CompletionModel {
314 pub fn new(client: Client, model: &str) -> Self {
315 Self {
316 client,
317 model: model.to_owned(),
318 }
319 }
320}
321
322impl completion::CompletionModel for CompletionModel {
325 type Response = CompletionResponse;
326
327 #[cfg_attr(feature = "worker", worker::send)]
328 async fn completion(
329 &self,
330 completion_request: CompletionRequest,
331 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
332 let prompt: Message = completion_request.prompt_with_context().try_into()?;
334 let options = if let Some(extra) = completion_request.additional_params {
335 json_utils::merge(
336 json!({ "temperature": completion_request.temperature }),
337 extra,
338 )
339 } else {
340 json!({ "temperature": completion_request.temperature })
341 };
342
343 let mut full_history = Vec::new();
345 if let Some(preamble) = completion_request.preamble {
346 full_history.push(Message::system(&preamble));
347 }
348 for msg in completion_request.chat_history.into_iter() {
349 full_history.push(Message::try_from(msg)?);
350 }
351 full_history.push(prompt);
352
353 let mut request_payload = json!({
354 "model": self.model,
355 "messages": full_history,
356 "options": options,
357 "stream": false,
358 });
359 if !completion_request.tools.is_empty() {
360 request_payload["tools"] = json!(completion_request
361 .tools
362 .into_iter()
363 .map(|tool| tool.into())
364 .collect::<Vec<ToolDefinition>>());
365 }
366
367 tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
368 let response = self
369 .client
370 .post("api/chat")
371 .json(&request_payload)
372 .send()
373 .await
374 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
375 if response.status().is_success() {
376 let text = response
377 .text()
378 .await
379 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
380 tracing::debug!(target: "rig", "Ollama chat response: {}", text);
381 let chat_resp: CompletionResponse = serde_json::from_str(&text)
382 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
383 let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
384 Ok(conv)
385 } else {
386 let err_text = response
387 .text()
388 .await
389 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
390 Err(CompletionError::ProviderError(err_text))
391 }
392 }
393}
394
395#[derive(Clone, Debug, Deserialize, Serialize)]
399pub struct ToolDefinition {
400 #[serde(rename = "type")]
401 pub type_field: String, pub function: completion::ToolDefinition,
403}
404
405impl From<crate::completion::ToolDefinition> for ToolDefinition {
407 fn from(tool: crate::completion::ToolDefinition) -> Self {
408 ToolDefinition {
409 type_field: "function".to_owned(),
410 function: completion::ToolDefinition {
411 name: tool.name,
412 description: tool.description,
413 parameters: tool.parameters,
414 },
415 }
416 }
417}
418
419#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
420pub struct ToolCall {
421 #[serde(default, rename = "type")]
423 pub r#type: ToolType,
424 pub function: Function,
425}
426#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
427#[serde(rename_all = "lowercase")]
428pub enum ToolType {
429 #[default]
430 Function,
431}
432#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
433pub struct Function {
434 pub name: String,
435 pub arguments: Value,
436}
437
438#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
441#[serde(tag = "role", rename_all = "lowercase")]
442pub enum Message {
443 User {
444 content: String,
445 #[serde(skip_serializing_if = "Option::is_none")]
446 images: Option<Vec<String>>,
447 #[serde(skip_serializing_if = "Option::is_none")]
448 name: Option<String>,
449 },
450 Assistant {
451 #[serde(default)]
452 content: String,
453 #[serde(skip_serializing_if = "Option::is_none")]
454 images: Option<Vec<String>>,
455 #[serde(skip_serializing_if = "Option::is_none")]
456 name: Option<String>,
457 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
458 tool_calls: Vec<ToolCall>,
459 },
460 System {
461 content: String,
462 #[serde(skip_serializing_if = "Option::is_none")]
463 images: Option<Vec<String>>,
464 #[serde(skip_serializing_if = "Option::is_none")]
465 name: Option<String>,
466 },
467 #[serde(rename = "Tool")]
468 ToolResult {
469 tool_call_id: String,
470 content: OneOrMany<ToolResultContent>,
471 },
472}
473
474impl TryFrom<crate::message::Message> for Message {
480 type Error = crate::message::MessageError;
481 fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
482 use crate::message::Message as InternalMessage;
483 match internal_msg {
484 InternalMessage::User { content, .. } => {
485 let mut texts = Vec::new();
486 let mut images = Vec::new();
487 for uc in content.into_iter() {
488 match uc {
489 crate::message::UserContent::Text(t) => texts.push(t.text),
490 crate::message::UserContent::Image(img) => images.push(img.data),
491 _ => {} }
493 }
494 let content_str = texts.join(" ");
495 let images_opt = if images.is_empty() {
496 None
497 } else {
498 Some(images)
499 };
500 Ok(Message::User {
501 content: content_str,
502 images: images_opt,
503 name: None,
504 })
505 }
506 InternalMessage::Assistant { content, .. } => {
507 let mut texts = Vec::new();
508 let mut tool_calls = Vec::new();
509 for ac in content.into_iter() {
510 match ac {
511 crate::message::AssistantContent::Text(t) => texts.push(t.text),
512 crate::message::AssistantContent::ToolCall(tc) => {
513 tool_calls.push(ToolCall {
514 r#type: ToolType::Function, function: Function {
516 name: tc.function.name,
517 arguments: tc.function.arguments,
518 },
519 });
520 }
521 }
522 }
523 let content_str = texts.join(" ");
524 Ok(Message::Assistant {
525 content: content_str,
526 images: None,
527 name: None,
528 tool_calls,
529 })
530 }
531 }
532 }
533}
534
535impl From<Message> for crate::completion::Message {
538 fn from(msg: Message) -> Self {
539 match msg {
540 Message::User { content, .. } => crate::completion::Message::User {
541 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
542 text: content,
543 })),
544 },
545 Message::Assistant {
546 content,
547 tool_calls,
548 ..
549 } => {
550 let mut assistant_contents =
551 vec![crate::completion::message::AssistantContent::Text(Text {
552 text: content,
553 })];
554 for tc in tool_calls {
555 assistant_contents.push(
556 crate::completion::message::AssistantContent::tool_call(
557 tc.function.name.clone(),
558 tc.function.name,
559 tc.function.arguments,
560 ),
561 );
562 }
563 crate::completion::Message::Assistant {
564 content: OneOrMany::many(assistant_contents).unwrap(),
565 }
566 }
567 Message::System { content, .. } => crate::completion::Message::User {
569 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
570 text: content,
571 })),
572 },
573 Message::ToolResult {
574 tool_call_id,
575 content,
576 } => crate::completion::Message::User {
577 content: OneOrMany::one(message::UserContent::tool_result(
578 tool_call_id,
579 content.map(|content| message::ToolResultContent::text(content.text)),
580 )),
581 },
582 }
583 }
584}
585
586impl Message {
587 pub fn system(content: &str) -> Self {
589 Message::System {
590 content: content.to_owned(),
591 images: None,
592 name: None,
593 }
594 }
595}
596
597#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
600pub struct ToolResultContent {
601 text: String,
602}
603
604impl FromStr for ToolResultContent {
605 type Err = Infallible;
606
607 fn from_str(s: &str) -> Result<Self, Self::Err> {
608 Ok(s.to_owned().into())
609 }
610}
611
612impl From<String> for ToolResultContent {
613 fn from(s: String) -> Self {
614 ToolResultContent { text: s }
615 }
616}
617
618#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
619pub struct SystemContent {
620 #[serde(default)]
621 r#type: SystemContentType,
622 text: String,
623}
624
625#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
626#[serde(rename_all = "lowercase")]
627pub enum SystemContentType {
628 #[default]
629 Text,
630}
631
632impl From<String> for SystemContent {
633 fn from(s: String) -> Self {
634 SystemContent {
635 r#type: SystemContentType::default(),
636 text: s,
637 }
638 }
639}
640
641impl FromStr for SystemContent {
642 type Err = std::convert::Infallible;
643 fn from_str(s: &str) -> Result<Self, Self::Err> {
644 Ok(SystemContent {
645 r#type: SystemContentType::default(),
646 text: s.to_string(),
647 })
648 }
649}
650
651#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
652pub struct AssistantContent {
653 pub text: String,
654}
655
656impl FromStr for AssistantContent {
657 type Err = std::convert::Infallible;
658 fn from_str(s: &str) -> Result<Self, Self::Err> {
659 Ok(AssistantContent { text: s.to_owned() })
660 }
661}
662
663#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
664#[serde(tag = "type", rename_all = "lowercase")]
665pub enum UserContent {
666 Text { text: String },
667 Image { image_url: ImageUrl },
668 }
670
671impl FromStr for UserContent {
672 type Err = std::convert::Infallible;
673 fn from_str(s: &str) -> Result<Self, Self::Err> {
674 Ok(UserContent::Text { text: s.to_owned() })
675 }
676}
677
678#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
679pub struct ImageUrl {
680 pub url: String,
681 #[serde(default)]
682 pub detail: ImageDetail,
683}
684
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A canned chat payload with text and one tool call must deserialize and
    /// convert into a non-empty choice.
    #[tokio::test]
    async fn test_chat_completion() {
        let payload = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        })
        .to_string();

        let parsed: CompletionResponse =
            serde_json::from_str(&payload).expect("Invalid JSON structure");
        let converted: completion::CompletionResponse<CompletionResponse> =
            parsed.try_into().unwrap();

        assert!(
            !converted.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    /// A provider user message must convert into a crate-level user message
    /// whose first part is the same text.
    #[test]
    fn test_message_conversion() {
        let source = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };

        let converted: crate::completion::Message = source.into();
        let crate::completion::Message::User { content } = converted else {
            panic!("Conversion from provider Message to completion Message failed");
        };

        match content.first() {
            crate::completion::message::UserContent::Text(text_part) => {
                assert_eq!(text_part.text, "Test message");
            }
            _ => panic!("Expected text content in conversion"),
        }
    }

    /// A crate-level tool definition must be wrapped as a `function` tool
    /// with name, description and parameters carried over.
    #[test]
    fn test_tool_definition_conversion() {
        let parameters = json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The location to get the weather for, e.g. San Francisco, CA"
                },
                "format": {
                    "type": "string",
                    "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["location", "format"]
        });
        let source = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters,
        };

        let converted: ToolDefinition = source.into();

        assert_eq!(converted.type_field, "function");
        assert_eq!(converted.function.name, "get_current_weather");
        assert_eq!(
            converted.function.description,
            "Get the current weather for a location"
        );
        let schema = &converted.function.parameters;
        assert_eq!(schema["properties"]["location"]["type"], "string");
    }
}