1use crate::json_utils::merge_inplace;
42use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
43use crate::{
44 agent::AgentBuilder,
45 completion::{self, CompletionError, CompletionRequest},
46 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
47 extractor::ExtractorBuilder,
48 json_utils, message,
49 message::{ImageDetail, Text},
50 Embed, OneOrMany,
51};
52use async_stream::stream;
53use futures::StreamExt;
54use reqwest;
55use schemars::JsonSchema;
56use serde::{Deserialize, Serialize};
57use serde_json::{json, Value};
58use std::convert::Infallible;
59use std::{convert::TryFrom, str::FromStr};
/// Base URL that [`Client::new`] passes to [`Client::from_url`].
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
63
/// Handle used to send HTTP requests to an Ollama server.
///
/// The type is `Clone`; the model handles created by its factory methods each
/// store their own clone of the client.
#[derive(Clone)]
pub struct Client {
    /// Server root that request paths are appended to.
    base_url: String,
    /// `reqwest` client used to issue every request.
    http_client: reqwest::Client,
}
69
70impl Default for Client {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl Client {
77 pub fn new() -> Self {
78 Self::from_url(OLLAMA_API_BASE_URL)
79 }
80 pub fn from_url(base_url: &str) -> Self {
81 Self {
82 base_url: base_url.to_owned(),
83 http_client: reqwest::Client::builder()
84 .build()
85 .expect("Ollama reqwest client should build"),
86 }
87 }
88 fn post(&self, path: &str) -> reqwest::RequestBuilder {
89 let url = format!("{}/{}", self.base_url, path);
90 self.http_client.post(url)
91 }
92 pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
93 EmbeddingModel::new(self.clone(), model, 0)
94 }
95 pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
96 EmbeddingModel::new(self.clone(), model, ndims)
97 }
98 pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
99 EmbeddingsBuilder::new(self.embedding_model(model))
100 }
101 pub fn completion_model(&self, model: &str) -> CompletionModel {
102 CompletionModel::new(self.clone(), model)
103 }
104 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
105 AgentBuilder::new(self.completion_model(model))
106 }
107 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
108 &self,
109 model: &str,
110 ) -> ExtractorBuilder<T, CompletionModel> {
111 ExtractorBuilder::new(self.completion_model(model))
112 }
113}
114
/// Error payload: a JSON object carrying a `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text carried by the payload.
    message: String,
}
121
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// Deserialized as an untagged enum: no discriminator field is expected in
/// the JSON, the variant is chosen by which shape matches.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The payload matched `T`.
    Ok(T),
    /// The payload matched the error shape.
    Err(ApiErrorResponse),
}
128
/// Model name `all-minilm` (presumably an embedding model, given its
/// placement next to the embedding types — confirm against the Ollama catalog).
pub const ALL_MINILM: &str = "all-minilm";
/// Model name `nomic-embed-text` (presumably an embedding model — confirm).
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
133
/// Body deserialized from a successful `api/embed` response.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    /// Model name reported by the server.
    pub model: String,
    /// One vector per input document; `embed_texts` pairs them with the
    /// inputs by position.
    pub embeddings: Vec<Vec<f64>>,
    /// Optional; `None` when the field is absent.
    #[serde(default)]
    pub total_duration: Option<u64>,
    /// Optional; `None` when the field is absent.
    #[serde(default)]
    pub load_duration: Option<u64>,
    /// Optional; `None` when the field is absent.
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
}
145
146impl From<ApiErrorResponse> for EmbeddingError {
147 fn from(err: ApiErrorResponse) -> Self {
148 EmbeddingError::ProviderError(err.message)
149 }
150}
151
152impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
153 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
154 match value {
155 ApiResponse::Ok(response) => Ok(response),
156 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
157 }
158 }
159}
160
/// Embedding model handle bound to a [`Client`] and a model name.
#[derive(Clone)]
pub struct EmbeddingModel {
    /// Client used to send embedding requests.
    client: Client,
    /// Model name sent in the request's `model` field.
    pub model: String,
    /// Value returned by `ndims()`; `Client::embedding_model` sets it to `0`.
    ndims: usize,
}
169
170impl EmbeddingModel {
171 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
172 Self {
173 client,
174 model: model.to_owned(),
175 ndims,
176 }
177 }
178}
179
180impl embeddings::EmbeddingModel for EmbeddingModel {
181 const MAX_DOCUMENTS: usize = 1024;
182 fn ndims(&self) -> usize {
183 self.ndims
184 }
185 #[cfg_attr(feature = "worker", worker::send)]
186 async fn embed_texts(
187 &self,
188 documents: impl IntoIterator<Item = String>,
189 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
190 let docs: Vec<String> = documents.into_iter().collect();
191 let payload = json!({
192 "model": self.model,
193 "input": docs,
194 });
195 let response = self
196 .client
197 .post("api/embed")
198 .json(&payload)
199 .send()
200 .await
201 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
202 if response.status().is_success() {
203 let api_resp: EmbeddingResponse = response
204 .json()
205 .await
206 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
207 if api_resp.embeddings.len() != docs.len() {
208 return Err(EmbeddingError::ResponseError(
209 "Number of returned embeddings does not match input".into(),
210 ));
211 }
212 Ok(api_resp
213 .embeddings
214 .into_iter()
215 .zip(docs.into_iter())
216 .map(|(vec, document)| embeddings::Embedding { document, vec })
217 .collect())
218 } else {
219 Err(EmbeddingError::ProviderError(response.text().await?))
220 }
221 }
222}
223
/// Model name `llama3.2`.
pub const LLAMA3_2: &str = "llama3.2";
/// Model name `llava`.
pub const LLAVA: &str = "llava";
/// Model name `mistral`.
pub const MISTRAL: &str = "mistral";
229
/// Body of an `api/chat` response. It is used both for the single
/// non-streaming reply and for each line of a streamed reply.
///
/// Every `Option` field is marked `#[serde(default)]` and is `None` when the
/// field is absent from the JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Model name reported by the server.
    pub model: String,
    /// Creation timestamp, kept as the raw string from the response.
    pub created_at: String,
    /// The chat message carried by this response.
    pub message: Message,
    /// The `done` flag from the response.
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub total_duration: Option<u64>,
    #[serde(default)]
    pub load_duration: Option<u64>,
    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
    #[serde(default)]
    pub eval_count: Option<u64>,
    #[serde(default)]
    pub eval_duration: Option<u64>,
}
251impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
252 type Error = CompletionError;
253 fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
254 match resp.message {
255 Message::Assistant {
257 content,
258 tool_calls,
259 ..
260 } => {
261 let mut assistant_contents = Vec::new();
262 if !content.is_empty() {
264 assistant_contents.push(completion::AssistantContent::text(&content));
265 }
266 for tc in tool_calls.iter() {
269 assistant_contents.push(completion::AssistantContent::tool_call(
270 tc.function.name.clone(),
271 tc.function.name.clone(),
272 tc.function.arguments.clone(),
273 ));
274 }
275 let choice = OneOrMany::many(assistant_contents).map_err(|_| {
276 CompletionError::ResponseError("No content provided".to_owned())
277 })?;
278 let raw_response = CompletionResponse {
279 model: resp.model,
280 created_at: resp.created_at,
281 done: resp.done,
282 done_reason: resp.done_reason,
283 total_duration: resp.total_duration,
284 load_duration: resp.load_duration,
285 prompt_eval_count: resp.prompt_eval_count,
286 prompt_eval_duration: resp.prompt_eval_duration,
287 eval_count: resp.eval_count,
288 eval_duration: resp.eval_duration,
289 message: Message::Assistant {
290 content,
291 images: None,
292 name: None,
293 tool_calls,
294 },
295 };
296 Ok(completion::CompletionResponse {
297 choice,
298 raw_response,
299 })
300 }
301 _ => Err(CompletionError::ResponseError(
302 "Chat response does not include an assistant message".into(),
303 )),
304 }
305 }
306}
307
/// Chat completion model handle bound to a [`Client`] and a model name.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send chat requests.
    client: Client,
    /// Model name sent in the request's `model` field.
    pub model: String,
}
315
316impl CompletionModel {
317 pub fn new(client: Client, model: &str) -> Self {
318 Self {
319 client,
320 model: model.to_owned(),
321 }
322 }
323
324 fn create_completion_request(
325 &self,
326 completion_request: CompletionRequest,
327 ) -> Result<Value, CompletionError> {
328 let prompt: Message = completion_request.prompt_with_context().try_into()?;
330 let options = if let Some(extra) = completion_request.additional_params {
331 json_utils::merge(
332 json!({ "temperature": completion_request.temperature }),
333 extra,
334 )
335 } else {
336 json!({ "temperature": completion_request.temperature })
337 };
338
339 let mut full_history = Vec::new();
341 if let Some(preamble) = completion_request.preamble {
342 full_history.push(Message::system(&preamble));
343 }
344 for msg in completion_request.chat_history.into_iter() {
345 full_history.push(Message::try_from(msg)?);
346 }
347 full_history.push(prompt);
348
349 let mut request_payload = json!({
350 "model": self.model,
351 "messages": full_history,
352 "options": options,
353 "stream": false,
354 });
355 if !completion_request.tools.is_empty() {
356 request_payload["tools"] = json!(completion_request
357 .tools
358 .into_iter()
359 .map(|tool| tool.into())
360 .collect::<Vec<ToolDefinition>>());
361 }
362
363 tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
364
365 Ok(request_payload)
366 }
367}
368
369impl completion::CompletionModel for CompletionModel {
372 type Response = CompletionResponse;
373
374 #[cfg_attr(feature = "worker", worker::send)]
375 async fn completion(
376 &self,
377 completion_request: CompletionRequest,
378 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
379 let request_payload = self.create_completion_request(completion_request)?;
380
381 let response = self
382 .client
383 .post("api/chat")
384 .json(&request_payload)
385 .send()
386 .await
387 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
388 if response.status().is_success() {
389 let text = response
390 .text()
391 .await
392 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
393 tracing::debug!(target: "rig", "Ollama chat response: {}", text);
394 let chat_resp: CompletionResponse = serde_json::from_str(&text)
395 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
396 let conv: completion::CompletionResponse<CompletionResponse> = chat_resp.try_into()?;
397 Ok(conv)
398 } else {
399 let err_text = response
400 .text()
401 .await
402 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
403 Err(CompletionError::ProviderError(err_text))
404 }
405 }
406}
407
408impl StreamingCompletionModel for CompletionModel {
409 async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
410 let mut request_payload = self.create_completion_request(request)?;
411 merge_inplace(&mut request_payload, json!({"stream": true}));
412
413 let response = self
414 .client
415 .post("api/chat")
416 .json(&request_payload)
417 .send()
418 .await
419 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
420
421 if !response.status().is_success() {
422 let err_text = response
423 .text()
424 .await
425 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
426 return Err(CompletionError::ProviderError(err_text));
427 }
428
429 Ok(Box::pin(stream! {
430 let mut stream = response.bytes_stream();
431 while let Some(chunk_result) = stream.next().await {
432 let chunk = match chunk_result {
433 Ok(c) => c,
434 Err(e) => {
435 yield Err(CompletionError::from(e));
436 break;
437 }
438 };
439
440 let text = match String::from_utf8(chunk.to_vec()) {
441 Ok(t) => t,
442 Err(e) => {
443 yield Err(CompletionError::ResponseError(e.to_string()));
444 break;
445 }
446 };
447
448
449 for line in text.lines() {
450 let line = line.to_string();
451
452 let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
453 continue;
454 };
455
456 match response.message {
457 Message::Assistant{ content, tool_calls, .. } => {
458 if !content.is_empty() {
459 yield Ok(StreamingChoice::Message(content))
460 }
461
462 for tool_call in tool_calls.iter() {
463 let function = tool_call.function.clone();
464
465 yield Ok(StreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
466 }
467 }
468 _ => {
469 continue;
470 }
471 }
472 }
473 }
474 }))
475 }
476}
477
/// Tool description in the shape placed in the chat request's `tools` array:
/// a `type` tag plus the function definition.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Serialized under the key `type`; the `From` conversion sets it to `"function"`.
    #[serde(rename = "type")]
    pub type_field: String,
    /// Name, description and parameters of the tool.
    pub function: completion::ToolDefinition,
}
487
488impl From<crate::completion::ToolDefinition> for ToolDefinition {
490 fn from(tool: crate::completion::ToolDefinition) -> Self {
491 ToolDefinition {
492 type_field: "function".to_owned(),
493 function: completion::ToolDefinition {
494 name: tool.name,
495 description: tool.description,
496 parameters: tool.parameters,
497 },
498 }
499 }
500}
501
/// A tool invocation carried by an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Serialized as `type`; falls back to the default [`ToolType`] when absent.
    #[serde(default, rename = "type")]
    pub r#type: ToolType,
    /// The function being called and its arguments.
    pub function: Function,
}
/// Kind of tool referenced by a [`ToolCall`]; serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// The only variant, and the default; serialized as `"function"`.
    #[default]
    Function,
}
/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function to call.
    pub name: String,
    /// Arguments as an arbitrary JSON value.
    pub arguments: Value,
}
520
521#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
524#[serde(tag = "role", rename_all = "lowercase")]
525pub enum Message {
526 User {
527 content: String,
528 #[serde(skip_serializing_if = "Option::is_none")]
529 images: Option<Vec<String>>,
530 #[serde(skip_serializing_if = "Option::is_none")]
531 name: Option<String>,
532 },
533 Assistant {
534 #[serde(default)]
535 content: String,
536 #[serde(skip_serializing_if = "Option::is_none")]
537 images: Option<Vec<String>>,
538 #[serde(skip_serializing_if = "Option::is_none")]
539 name: Option<String>,
540 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
541 tool_calls: Vec<ToolCall>,
542 },
543 System {
544 content: String,
545 #[serde(skip_serializing_if = "Option::is_none")]
546 images: Option<Vec<String>>,
547 #[serde(skip_serializing_if = "Option::is_none")]
548 name: Option<String>,
549 },
550 #[serde(rename = "Tool")]
551 ToolResult {
552 tool_call_id: String,
553 content: OneOrMany<ToolResultContent>,
554 },
555}
556
557impl TryFrom<crate::message::Message> for Message {
563 type Error = crate::message::MessageError;
564 fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
565 use crate::message::Message as InternalMessage;
566 match internal_msg {
567 InternalMessage::User { content, .. } => {
568 let mut texts = Vec::new();
569 let mut images = Vec::new();
570 for uc in content.into_iter() {
571 match uc {
572 crate::message::UserContent::Text(t) => texts.push(t.text),
573 crate::message::UserContent::Image(img) => images.push(img.data),
574 _ => {} }
576 }
577 let content_str = texts.join(" ");
578 let images_opt = if images.is_empty() {
579 None
580 } else {
581 Some(images)
582 };
583 Ok(Message::User {
584 content: content_str,
585 images: images_opt,
586 name: None,
587 })
588 }
589 InternalMessage::Assistant { content, .. } => {
590 let mut texts = Vec::new();
591 let mut tool_calls = Vec::new();
592 for ac in content.into_iter() {
593 match ac {
594 crate::message::AssistantContent::Text(t) => texts.push(t.text),
595 crate::message::AssistantContent::ToolCall(tc) => {
596 tool_calls.push(ToolCall {
597 r#type: ToolType::Function, function: Function {
599 name: tc.function.name,
600 arguments: tc.function.arguments,
601 },
602 });
603 }
604 }
605 }
606 let content_str = texts.join(" ");
607 Ok(Message::Assistant {
608 content: content_str,
609 images: None,
610 name: None,
611 tool_calls,
612 })
613 }
614 }
615 }
616}
617
618impl From<Message> for crate::completion::Message {
621 fn from(msg: Message) -> Self {
622 match msg {
623 Message::User { content, .. } => crate::completion::Message::User {
624 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
625 text: content,
626 })),
627 },
628 Message::Assistant {
629 content,
630 tool_calls,
631 ..
632 } => {
633 let mut assistant_contents =
634 vec![crate::completion::message::AssistantContent::Text(Text {
635 text: content,
636 })];
637 for tc in tool_calls {
638 assistant_contents.push(
639 crate::completion::message::AssistantContent::tool_call(
640 tc.function.name.clone(),
641 tc.function.name,
642 tc.function.arguments,
643 ),
644 );
645 }
646 crate::completion::Message::Assistant {
647 content: OneOrMany::many(assistant_contents).unwrap(),
648 }
649 }
650 Message::System { content, .. } => crate::completion::Message::User {
652 content: OneOrMany::one(crate::completion::message::UserContent::Text(Text {
653 text: content,
654 })),
655 },
656 Message::ToolResult {
657 tool_call_id,
658 content,
659 } => crate::completion::Message::User {
660 content: OneOrMany::one(message::UserContent::tool_result(
661 tool_call_id,
662 content.map(|content| message::ToolResultContent::text(content.text)),
663 )),
664 },
665 }
666 }
667}
668
669impl Message {
670 pub fn system(content: &str) -> Self {
672 Message::System {
673 content: content.to_owned(),
674 images: None,
675 name: None,
676 }
677 }
678}
679
/// One item of a tool result: a piece of text.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    /// The result text.
    text: String,
}
686
687impl FromStr for ToolResultContent {
688 type Err = Infallible;
689
690 fn from_str(s: &str) -> Result<Self, Self::Err> {
691 Ok(s.to_owned().into())
692 }
693}
694
695impl From<String> for ToolResultContent {
696 fn from(s: String) -> Self {
697 ToolResultContent { text: s }
698 }
699}
700
/// A typed piece of system content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    /// Content kind; falls back to the default [`SystemContentType`] when absent.
    #[serde(default)]
    r#type: SystemContentType,
    /// The content text.
    text: String,
}
707
/// Kind of a [`SystemContent`] item; serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    /// The only variant, and the default; serialized as `"text"`.
    #[default]
    Text,
}
714
715impl From<String> for SystemContent {
716 fn from(s: String) -> Self {
717 SystemContent {
718 r#type: SystemContentType::default(),
719 text: s,
720 }
721 }
722}
723
724impl FromStr for SystemContent {
725 type Err = std::convert::Infallible;
726 fn from_str(s: &str) -> Result<Self, Self::Err> {
727 Ok(SystemContent {
728 r#type: SystemContentType::default(),
729 text: s.to_string(),
730 })
731 }
732}
733
/// A piece of assistant text.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
    /// The text itself.
    pub text: String,
}
738
739impl FromStr for AssistantContent {
740 type Err = std::convert::Infallible;
741 fn from_str(s: &str) -> Result<Self, Self::Err> {
742 Ok(AssistantContent { text: s.to_owned() })
743 }
744}
745
/// A piece of user content, internally tagged on `type` with lowercase
/// variant names (`text`, `image`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text.
    Text { text: String },
    /// An image referenced through an [`ImageUrl`].
    Image { image_url: ImageUrl },
}
753
754impl FromStr for UserContent {
755 type Err = std::convert::Infallible;
756 fn from_str(s: &str) -> Result<Self, Self::Err> {
757 Ok(UserContent::Text { text: s.to_owned() })
758 }
759}
760
/// Image reference used by [`UserContent::Image`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// The image URL.
    pub url: String,
    /// Detail setting; falls back to the default `ImageDetail` when absent.
    #[serde(default)]
    pub detail: ImageDetail,
}
767
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A chat payload with text and one tool call must deserialize and
    /// convert into a completion response with a non-empty choice.
    #[tokio::test]
    async fn test_chat_completion() {
        let raw_json = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        })
        .to_string();

        let parsed = serde_json::from_str::<CompletionResponse>(&raw_json)
            .expect("Invalid JSON structure");
        let converted =
            completion::CompletionResponse::<CompletionResponse>::try_from(parsed).unwrap();

        assert!(
            !converted.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    /// A provider user message converts into a crate-level user message
    /// whose first content item is the same text.
    #[test]
    fn test_message_conversion() {
        let original = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };

        let converted = crate::completion::Message::from(original);
        let crate::completion::Message::User { content } = converted else {
            panic!("Conversion from provider Message to completion Message failed");
        };

        match content.first() {
            crate::completion::message::UserContent::Text(text_struct) => {
                assert_eq!(text_struct.text, "Test message");
            }
            _ => panic!("Expected text content in conversion"),
        }
    }

    /// A crate-level tool definition converts into the provider shape with
    /// type `function` and its name, description and parameters preserved.
    #[test]
    fn test_tool_definition_conversion() {
        let parameters = json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The location to get the weather for, e.g. San Francisco, CA"
                },
                "format": {
                    "type": "string",
                    "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["location", "format"]
        });
        let source_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters,
        };

        let converted = ToolDefinition::from(source_tool);

        assert_eq!(converted.type_field, "function");
        assert_eq!(converted.function.name, "get_current_weather");
        assert_eq!(
            converted.function.description,
            "Get the current weather for a location"
        );
        let schema = &converted.function.parameters;
        assert_eq!(schema["properties"]["location"]["type"], "string");
    }
}