1use crate::json_utils::merge_inplace;
42use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
43use crate::{
44 agent::AgentBuilder,
45 completion::{self, CompletionError, CompletionRequest},
46 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
47 extractor::ExtractorBuilder,
48 json_utils, message,
49 message::{ImageDetail, Text},
50 Embed, OneOrMany,
51};
52use async_stream::stream;
53use futures::StreamExt;
54use reqwest;
55use schemars::JsonSchema;
56use serde::{Deserialize, Serialize};
57use serde_json::{json, Value};
58use std::convert::Infallible;
59use std::{convert::TryFrom, str::FromStr};
/// Base URL used by [`Client::new`] when the caller does not supply one.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
63
64#[derive(Clone)]
65pub struct Client {
66 base_url: String,
67 http_client: reqwest::Client,
68}
69
70impl Default for Client {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl Client {
77 pub fn new() -> Self {
78 Self::from_url(OLLAMA_API_BASE_URL)
79 }
80 pub fn from_url(base_url: &str) -> Self {
81 Self {
82 base_url: base_url.to_owned(),
83 http_client: reqwest::Client::builder()
84 .build()
85 .expect("Ollama reqwest client should build"),
86 }
87 }
88 fn post(&self, path: &str) -> reqwest::RequestBuilder {
89 let url = format!("{}/{}", self.base_url, path);
90 self.http_client.post(url)
91 }
92 pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
93 EmbeddingModel::new(self.clone(), model, 0)
94 }
95 pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
96 EmbeddingModel::new(self.clone(), model, ndims)
97 }
98 pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
99 EmbeddingsBuilder::new(self.embedding_model(model))
100 }
101 pub fn completion_model(&self, model: &str) -> CompletionModel {
102 CompletionModel::new(self.clone(), model)
103 }
104 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
105 AgentBuilder::new(self.completion_model(model))
106 }
107 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
108 &self,
109 model: &str,
110 ) -> ExtractorBuilder<T, CompletionModel> {
111 ExtractorBuilder::new(self.completion_model(model))
112 }
113}
114
115#[derive(Debug, Deserialize)]
118struct ApiErrorResponse {
119 message: String,
120}
121
122#[derive(Debug, Deserialize)]
123#[serde(untagged)]
124enum ApiResponse<T> {
125 Ok(T),
126 Err(ApiErrorResponse),
127}
128
/// Model identifier `all-minilm`.
pub const ALL_MINILM: &str = "all-minilm";
/// Model identifier `nomic-embed-text`.
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
133
134#[derive(Debug, Serialize, Deserialize)]
135pub struct EmbeddingResponse {
136 pub model: String,
137 pub embeddings: Vec<Vec<f64>>,
138 #[serde(default)]
139 pub total_duration: Option<u64>,
140 #[serde(default)]
141 pub load_duration: Option<u64>,
142 #[serde(default)]
143 pub prompt_eval_count: Option<u64>,
144}
145
146impl From<ApiErrorResponse> for EmbeddingError {
147 fn from(err: ApiErrorResponse) -> Self {
148 EmbeddingError::ProviderError(err.message)
149 }
150}
151
152impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
153 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
154 match value {
155 ApiResponse::Ok(response) => Ok(response),
156 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
157 }
158 }
159}
160
161#[derive(Clone)]
164pub struct EmbeddingModel {
165 client: Client,
166 pub model: String,
167 ndims: usize,
168}
169
170impl EmbeddingModel {
171 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
172 Self {
173 client,
174 model: model.to_owned(),
175 ndims,
176 }
177 }
178}
179
180impl embeddings::EmbeddingModel for EmbeddingModel {
181 const MAX_DOCUMENTS: usize = 1024;
182 fn ndims(&self) -> usize {
183 self.ndims
184 }
185 #[cfg_attr(feature = "worker", worker::send)]
186 async fn embed_texts(
187 &self,
188 documents: impl IntoIterator<Item = String>,
189 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
190 let docs: Vec<String> = documents.into_iter().collect();
191 let payload = json!({
192 "model": self.model,
193 "input": docs,
194 });
195 let response = self
196 .client
197 .post("api/embed")
198 .json(&payload)
199 .send()
200 .await
201 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
202 if response.status().is_success() {
203 let api_resp: EmbeddingResponse = response
204 .json()
205 .await
206 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
207 if api_resp.embeddings.len() != docs.len() {
208 return Err(EmbeddingError::ResponseError(
209 "Number of returned embeddings does not match input".into(),
210 ));
211 }
212 Ok(api_resp
213 .embeddings
214 .into_iter()
215 .zip(docs.into_iter())
216 .map(|(vec, document)| embeddings::Embedding { document, vec })
217 .collect())
218 } else {
219 Err(EmbeddingError::ProviderError(response.text().await?))
220 }
221 }
222}
223
/// Model identifier `llama3.2`.
pub const LLAMA3_2: &str = "llama3.2";
/// Model identifier `llava`.
pub const LLAVA: &str = "llava";
/// Model identifier `mistral`.
pub const MISTRAL: &str = "mistral";
229
230#[derive(Debug, Serialize, Deserialize)]
231pub struct CompletionResponse {
232 pub model: String,
233 pub created_at: String,
234 pub message: Message,
235 pub done: bool,
236 #[serde(default)]
237 pub done_reason: Option<String>,
238 #[serde(default)]
239 pub total_duration: Option<u64>,
240 #[serde(default)]
241 pub load_duration: Option<u64>,
242 #[serde(default)]
243 pub prompt_eval_count: Option<u64>,
244 #[serde(default)]
245 pub prompt_eval_duration: Option<u64>,
246 #[serde(default)]
247 pub eval_count: Option<u64>,
248 #[serde(default)]
249 pub eval_duration: Option<u64>,
250}
251impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
252 type Error = CompletionError;
253 fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
254 match resp.message {
255 Message::Assistant {
257 content,
258 tool_calls,
259 ..
260 } => {
261 let mut assistant_contents = Vec::new();
262 if !content.is_empty() {
264 assistant_contents.push(completion::AssistantContent::text(&content));
265 }
266 for tc in tool_calls.iter() {
269 assistant_contents.push(completion::AssistantContent::tool_call(
270 tc.function.name.clone(),
271 tc.function.name.clone(),
272 tc.function.arguments.clone(),
273 ));
274 }
275 let choice = OneOrMany::many(assistant_contents).map_err(|_| {
276 CompletionError::ResponseError("No content provided".to_owned())
277 })?;
278 let raw_response = CompletionResponse {
279 model: resp.model,
280 created_at: resp.created_at,
281 done: resp.done,
282 done_reason: resp.done_reason,
283 total_duration: resp.total_duration,
284 load_duration: resp.load_duration,
285 prompt_eval_count: resp.prompt_eval_count,
286 prompt_eval_duration: resp.prompt_eval_duration,
287 eval_count: resp.eval_count,
288 eval_duration: resp.eval_duration,
289 message: Message::Assistant {
290 content,
291 images: None,
292 name: None,
293 tool_calls,
294 },
295 };
296 Ok(completion::CompletionResponse {
297 choice,
298 raw_response,
299 })
300 }
301 _ => Err(CompletionError::ResponseError(
302 "Chat response does not include an assistant message".into(),
303 )),
304 }
305 }
306}
307
308#[derive(Clone)]
311pub struct CompletionModel {
312 client: Client,
313 pub model: String,
314}
315
316impl CompletionModel {
317 pub fn new(client: Client, model: &str) -> Self {
318 Self {
319 client,
320 model: model.to_owned(),
321 }
322 }
323
324 fn create_completion_request(
325 &self,
326 completion_request: CompletionRequest,
327 ) -> Result<Value, CompletionError> {
328 let mut partial_history = vec![];
330 if let Some(docs) = completion_request.normalized_documents() {
331 partial_history.push(docs);
332 }
333 partial_history.extend(completion_request.chat_history);
334
335 let mut full_history: Vec<Message> = completion_request
337 .preamble
338 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
339
340 full_history.extend(
342 partial_history
343 .into_iter()
344 .map(|msg| msg.try_into())
345 .collect::<Result<Vec<Message>, _>>()?,
346 );
347
348 let options = if let Some(extra) = completion_request.additional_params {
350 json_utils::merge(
351 json!({ "temperature": completion_request.temperature }),
352 extra,
353 )
354 } else {
355 json!({ "temperature": completion_request.temperature })
356 };
357
358 let mut request_payload = json!({
359 "model": self.model,
360 "messages": full_history,
361 "options": options,
362 "stream": false,
363 });
364 if !completion_request.tools.is_empty() {
365 request_payload["tools"] = json!(completion_request
366 .tools
367 .into_iter()
368 .map(|tool| tool.into())
369 .collect::<Vec<ToolDefinition>>());
370 }
371
372 tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
373
374 Ok(request_payload)
375 }
376}
377
378impl completion::CompletionModel for CompletionModel {
381 type Response = CompletionResponse;
382
383 #[cfg_attr(feature = "worker", worker::send)]
384 async fn completion(
385 &self,
386 completion_request: CompletionRequest,
387 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
388 let request_payload = self.create_completion_request(completion_request)?;
389
390 let response = self
391 .client
392 .post("api/chat")
393 .json(&request_payload)
394 .send()
395 .await
396 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
397 if response.status().is_success() {
398 let text = response
399 .text()
400 .await
401 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
402 tracing::debug!(target: "rig", "Ollama chat response: {}", text);
403 let chat_resp: CompletionResponse = serde_json::from_str(&text)
404 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
405 let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
406 Ok(conv)
407 } else {
408 let err_text = response
409 .text()
410 .await
411 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
412 Err(CompletionError::ProviderError(err_text))
413 }
414 }
415}
416
417impl StreamingCompletionModel for CompletionModel {
418 async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
419 let mut request_payload = self.create_completion_request(request)?;
420 merge_inplace(&mut request_payload, json!({"stream": true}));
421
422 let response = self
423 .client
424 .post("api/chat")
425 .json(&request_payload)
426 .send()
427 .await
428 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
429
430 if !response.status().is_success() {
431 let err_text = response
432 .text()
433 .await
434 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
435 return Err(CompletionError::ProviderError(err_text));
436 }
437
438 Ok(Box::pin(stream! {
439 let mut stream = response.bytes_stream();
440 while let Some(chunk_result) = stream.next().await {
441 let chunk = match chunk_result {
442 Ok(c) => c,
443 Err(e) => {
444 yield Err(CompletionError::from(e));
445 break;
446 }
447 };
448
449 let text = match String::from_utf8(chunk.to_vec()) {
450 Ok(t) => t,
451 Err(e) => {
452 yield Err(CompletionError::ResponseError(e.to_string()));
453 break;
454 }
455 };
456
457
458 for line in text.lines() {
459 let line = line.to_string();
460
461 let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
462 continue;
463 };
464
465 match response.message {
466 Message::Assistant{ content, tool_calls, .. } => {
467 if !content.is_empty() {
468 yield Ok(StreamingChoice::Message(content))
469 }
470
471 for tool_call in tool_calls.iter() {
472 let function = tool_call.function.clone();
473
474 yield Ok(StreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
475 }
476 }
477 _ => {
478 continue;
479 }
480 }
481 }
482 }
483 }))
484 }
485}
486
487#[derive(Clone, Debug, Deserialize, Serialize)]
491pub struct ToolDefinition {
492 #[serde(rename = "type")]
493 pub type_field: String, pub function: completion::ToolDefinition,
495}
496
497impl From<crate::completion::ToolDefinition> for ToolDefinition {
499 fn from(tool: crate::completion::ToolDefinition) -> Self {
500 ToolDefinition {
501 type_field: "function".to_owned(),
502 function: completion::ToolDefinition {
503 name: tool.name,
504 description: tool.description,
505 parameters: tool.parameters,
506 },
507 }
508 }
509}
510
511#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
512pub struct ToolCall {
513 #[serde(default, rename = "type")]
515 pub r#type: ToolType,
516 pub function: Function,
517}
518#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
519#[serde(rename_all = "lowercase")]
520pub enum ToolType {
521 #[default]
522 Function,
523}
524#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
525pub struct Function {
526 pub name: String,
527 pub arguments: Value,
528}
529
530#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
533#[serde(tag = "role", rename_all = "lowercase")]
534pub enum Message {
535 User {
536 content: String,
537 #[serde(skip_serializing_if = "Option::is_none")]
538 images: Option<Vec<String>>,
539 #[serde(skip_serializing_if = "Option::is_none")]
540 name: Option<String>,
541 },
542 Assistant {
543 #[serde(default)]
544 content: String,
545 #[serde(skip_serializing_if = "Option::is_none")]
546 images: Option<Vec<String>>,
547 #[serde(skip_serializing_if = "Option::is_none")]
548 name: Option<String>,
549 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
550 tool_calls: Vec<ToolCall>,
551 },
552 System {
553 content: String,
554 #[serde(skip_serializing_if = "Option::is_none")]
555 images: Option<Vec<String>>,
556 #[serde(skip_serializing_if = "Option::is_none")]
557 name: Option<String>,
558 },
559 #[serde(rename = "Tool")]
560 ToolResult {
561 tool_call_id: String,
562 content: OneOrMany<ToolResultContent>,
563 },
564}
565
566impl TryFrom<crate::message::Message> for Message {
572 type Error = crate::message::MessageError;
573 fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
574 use crate::message::Message as InternalMessage;
575 match internal_msg {
576 InternalMessage::User { content, .. } => {
577 let mut texts = Vec::new();
578 let mut images = Vec::new();
579 for uc in content.into_iter() {
580 match uc {
581 crate::message::UserContent::Text(t) => texts.push(t.text),
582 crate::message::UserContent::Image(img) => images.push(img.data),
583 _ => {} }
585 }
586 let content_str = texts.join(" ");
587 let images_opt = if images.is_empty() {
588 None
589 } else {
590 Some(images)
591 };
592 Ok(Message::User {
593 content: content_str,
594 images: images_opt,
595 name: None,
596 })
597 }
598 InternalMessage::Assistant { content, .. } => {
599 let mut texts = Vec::new();
600 let mut tool_calls = Vec::new();
601 for ac in content.into_iter() {
602 match ac {
603 crate::message::AssistantContent::Text(t) => texts.push(t.text),
604 crate::message::AssistantContent::ToolCall(tc) => {
605 tool_calls.push(ToolCall {
606 r#type: ToolType::Function, function: Function {
608 name: tc.function.name,
609 arguments: tc.function.arguments,
610 },
611 });
612 }
613 }
614 }
615 let content_str = texts.join(" ");
616 Ok(Message::Assistant {
617 content: content_str,
618 images: None,
619 name: None,
620 tool_calls,
621 })
622 }
623 }
624 }
625}
626
627impl From<Message> for crate::completion::Message {
630 fn from(msg: Message) -> Self {
631 match msg {
632 Message::User { content, .. } => crate::completion::Message::User {
633 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
634 text: content,
635 })),
636 },
637 Message::Assistant {
638 content,
639 tool_calls,
640 ..
641 } => {
642 let mut assistant_contents =
643 vec![crate::completion::message::AssistantContent::Text(Text {
644 text: content,
645 })];
646 for tc in tool_calls {
647 assistant_contents.push(
648 crate::completion::message::AssistantContent::tool_call(
649 tc.function.name.clone(),
650 tc.function.name,
651 tc.function.arguments,
652 ),
653 );
654 }
655 crate::completion::Message::Assistant {
656 content: OneOrMany::many(assistant_contents).unwrap(),
657 }
658 }
659 Message::System { content, .. } => crate::completion::Message::User {
661 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
662 text: content,
663 })),
664 },
665 Message::ToolResult {
666 tool_call_id,
667 content,
668 } => crate::completion::Message::User {
669 content: OneOrMany::one(message::UserContent::tool_result(
670 tool_call_id,
671 content.map(|content| message::ToolResultContent::text(content.text)),
672 )),
673 },
674 }
675 }
676}
677
678impl Message {
679 pub fn system(content: &str) -> Self {
681 Message::System {
682 content: content.to_owned(),
683 images: None,
684 name: None,
685 }
686 }
687}
688
689#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
692pub struct ToolResultContent {
693 text: String,
694}
695
696impl FromStr for ToolResultContent {
697 type Err = Infallible;
698
699 fn from_str(s: &str) -> Result<Self, Self::Err> {
700 Ok(s.to_owned().into())
701 }
702}
703
704impl From<String> for ToolResultContent {
705 fn from(s: String) -> Self {
706 ToolResultContent { text: s }
707 }
708}
709
710#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
711pub struct SystemContent {
712 #[serde(default)]
713 r#type: SystemContentType,
714 text: String,
715}
716
717#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
718#[serde(rename_all = "lowercase")]
719pub enum SystemContentType {
720 #[default]
721 Text,
722}
723
724impl From<String> for SystemContent {
725 fn from(s: String) -> Self {
726 SystemContent {
727 r#type: SystemContentType::default(),
728 text: s,
729 }
730 }
731}
732
733impl FromStr for SystemContent {
734 type Err = std::convert::Infallible;
735 fn from_str(s: &str) -> Result<Self, Self::Err> {
736 Ok(SystemContent {
737 r#type: SystemContentType::default(),
738 text: s.to_string(),
739 })
740 }
741}
742
743#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
744pub struct AssistantContent {
745 pub text: String,
746}
747
748impl FromStr for AssistantContent {
749 type Err = std::convert::Infallible;
750 fn from_str(s: &str) -> Result<Self, Self::Err> {
751 Ok(AssistantContent { text: s.to_owned() })
752 }
753}
754
755#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
756#[serde(tag = "type", rename_all = "lowercase")]
757pub enum UserContent {
758 Text { text: String },
759 Image { image_url: ImageUrl },
760 }
762
763impl FromStr for UserContent {
764 type Err = std::convert::Infallible;
765 fn from_str(s: &str) -> Result<Self, Self::Err> {
766 Ok(UserContent::Text { text: s.to_owned() })
767 }
768}
769
770#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
771pub struct ImageUrl {
772 pub url: String,
773 #[serde(default)]
774 pub detail: ImageDetail,
775}
776
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A sample chat response with text and one tool call deserializes and
    /// converts into a completion response with a non-empty choice.
    #[tokio::test]
    async fn test_chat_completion() {
        let sample_text = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        })
        .to_string();

        let parsed: CompletionResponse =
            serde_json::from_str(&sample_text).expect("Invalid JSON structure");
        let converted: completion::CompletionResponse<CompletionResponse> =
            parsed.try_into().unwrap();

        assert!(
            !converted.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    /// A provider user message converts into a crate-level user message
    /// carrying the same text.
    #[test]
    fn test_message_conversion() {
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };

        let converted: crate::completion::Message = provider_msg.into();
        let crate::completion::Message::User { content } = converted else {
            panic!("Conversion from provider Message to completion Message failed");
        };

        match content.first() {
            crate::completion::message::UserContent::Text(text_struct) => {
                assert_eq!(text_struct.text, "Test message");
            }
            _ => panic!("Expected text content in conversion"),
        }
    }

    /// A crate-level tool definition converts into the provider shape with
    /// type `"function"` and its fields intact.
    #[test]
    fn test_tool_definition_conversion() {
        let parameters = json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The location to get the weather for, e.g. San Francisco, CA"
                },
                "format": {
                    "type": "string",
                    "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["location", "format"]
        });
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters,
        };

        let ollama_tool = ToolDefinition::from(internal_tool);

        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        assert_eq!(
            ollama_tool.function.parameters["properties"]["location"]["type"],
            "string"
        );
    }
}