// rig/completion/request.rs
1//! This module provides functionality for working with completion models.
2//! It provides traits, structs, and enums for generating completion requests,
3//! handling completion responses, and defining completion models.
4//!
5//! The main traits defined in this module are:
6//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
7//! - [Chat]: Defines a high-level LLM chat interface with chat history.
8//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
9//! - [CompletionModel]: Defines a completion model that can be used to generate completion
10//! responses from requests.
11//!
12//! The [Prompt] and [Chat] traits are high level traits that users are expected to use
13//! to interact with LLM models. Moreover, it is good practice to implement one of these
14//! traits for composite agents that use multiple LLM models to generate responses.
15//!
//! The [Completion] trait defines a lower-level interface that is useful when the user wants
17//! to further customize the request before sending it to the completion model provider.
18//!
19//! The [CompletionModel] trait is meant to act as the interface between providers and
20//! the library. It defines the methods that need to be implemented by the user to define
21//! a custom base completion model (i.e.: a private or third party LLM provider).
22//!
23//! The module also provides various structs and enums for representing generic completion requests,
24//! responses, and errors.
25//!
26//! Example Usage:
27//! ```rust
28//! use rig::providers::openai::{Client, self};
29//! use rig::completion::*;
30//!
31//! // Initialize the OpenAI client and a completion model
32//! let openai = Client::new("your-openai-api-key");
33//!
34//! let gpt_4 = openai.completion_model(openai::GPT_4);
35//!
36//! // Create the completion request
37//! let request = gpt_4.completion_request("Who are you?")
38//! .preamble("\
39//! You are Marvin, an extremely smart but depressed robot who is \
40//! nonetheless helpful towards humanity.\
41//! ")
42//! .temperature(0.5)
43//! .build();
44//!
45//! // Send the completion request and get the completion response
46//! let response = gpt_4.completion(request)
47//! .await
48//! .expect("Failed to get completion response");
49//!
50//! // Handle the completion response
//! match response.choice {
52//! ModelChoice::Message(message) => {
53//! // Handle the completion response as a message
54//! println!("Received message: {}", message);
55//! }
56//! ModelChoice::ToolCall(tool_name, tool_params) => {
57//! // Handle the completion response as a tool call
58//! println!("Received tool call: {} {:?}", tool_name, tool_params);
59//! }
60//! }
61//! ```
62//!
63//! For more information on how to use the completion functionality, refer to the documentation of
64//! the individual traits, structs, and enums defined in this module.
65
66use super::message::{AssistantContent, ContentFormat, DocumentMediaType};
67use crate::client::completion::CompletionModelHandle;
68use crate::streaming::StreamingCompletionResponse;
69use crate::{OneOrMany, streaming};
70use crate::{
71 json_utils,
72 message::{Message, UserContent},
73 tool::ToolSetError,
74};
75use futures::future::BoxFuture;
76use serde::de::DeserializeOwned;
77use serde::{Deserialize, Serialize};
78use std::collections::HashMap;
79use std::ops::{Add, AddAssign};
80use std::sync::Arc;
81use thiserror::Error;
82
// Errors
/// Errors that can occur while building or sending a completion request,
/// or while parsing the completion model provider's response.
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] reqwest::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
106
/// Prompt errors
#[derive(Debug, Error)]
pub enum PromptError {
    /// Something went wrong with the completion
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// There was an error while using a tool
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// The LLM tried to call too many tools during a multi-turn conversation.
    /// To fix this, you may either need to lower the amount of tools your model has access to (and then create other agents to share the tool load)
    /// or increase the amount of turns given in `.multi_turn()`.
    #[error("MaxDepthError: (reached limit: {max_depth})")]
    MaxDepthError {
        /// The turn limit that was reached.
        max_depth: usize,
        /// The chat history accumulated up to the point the limit was hit.
        chat_history: Vec<Message>,
        /// The prompt that was being processed when the limit was hit.
        prompt: Message,
    },
}
128
/// A text document that can be attached to a completion request as context.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    /// Identifier for the document (rendered in the `<file id: …>` tag by `Display`).
    pub id: String,
    /// The document's text contents.
    pub text: String,
    /// Arbitrary key/value metadata. `#[serde(flatten)]` merges these keys into
    /// the same JSON object as `id` and `text` when (de)serializing.
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}
136
137impl std::fmt::Display for Document {
138 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139 write!(
140 f,
141 concat!("<file id: {}>\n", "{}\n", "</file>\n"),
142 self.id,
143 if self.additional_props.is_empty() {
144 self.text.clone()
145 } else {
146 let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
147 sorted_props.sort_by(|a, b| a.0.cmp(b.0));
148 let metadata = sorted_props
149 .iter()
150 .map(|(k, v)| format!("{k}: {v:?}"))
151 .collect::<Vec<_>>()
152 .join(" ");
153 format!("<metadata {} />\n{}", metadata, self.text)
154 }
155 )
156 }
157}
158
/// Provider-agnostic definition of a tool that the completion model may call.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// The name the model uses to reference the tool.
    pub name: String,
    /// Description of what the tool does, shown to the model.
    pub description: String,
    /// JSON value describing the tool's expected arguments
    /// (typically a JSON Schema — format depends on the provider).
    pub parameters: serde_json::Value,
}
165
166// ================================================================
167// Implementations
168// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: Send + Sync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// The prompt can be anything that converts into a [Message] (e.g. `&str`).
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    ///
    /// The return value implements [std::future::IntoFuture], so it can be `.await`ed
    /// directly or further configured by the implementor before awaiting.
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}
184
/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: Send + Sync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// The prompt can be anything that converts into a [Message] (e.g. `&str`);
    /// `chat_history` carries the prior messages of the conversation.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}
201
/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
    ///
    /// # Errors
    /// Returns a [CompletionError] if the builder could not be constructed.
    fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>> + Send;
}
221
/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more assistant content.
///
/// The generic `T` is the provider-specific raw response type
/// (see [CompletionModel::Response]); it is `()` when the response has been type-erased.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message content)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// Tokens used during prompting and responding
    pub usage: Usage,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
}
234
/// Struct representing the token usage for a completion request.
/// If tokens used are `0`, then the provider failed to supply token usage metrics.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct Usage {
    /// Tokens consumed by the input (prompt) side of the request.
    pub input_tokens: u64,
    /// Tokens produced in the model's output.
    pub output_tokens: u64,
    // We store this separately as some providers may only report one number
    pub total_tokens: u64,
}

impl Usage {
    /// Creates a `Usage` with all counters zeroed.
    pub fn new() -> Self {
        Usage {
            input_tokens: 0,
            output_tokens: 0,
            total_tokens: 0,
        }
    }
}

impl Default for Usage {
    fn default() -> Self {
        Usage::new()
    }
}

impl Add for Usage {
    type Output = Usage;

    /// Component-wise sum of two usage reports.
    fn add(self, rhs: Self) -> Self::Output {
        Usage {
            input_tokens: self.input_tokens + rhs.input_tokens,
            output_tokens: self.output_tokens + rhs.output_tokens,
            total_tokens: self.total_tokens + rhs.total_tokens,
        }
    }
}

impl AddAssign for Usage {
    /// In-place accumulation, defined in terms of [Add] (`Usage` is `Copy`).
    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs;
    }
}
280
/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third party provider (e.g.: OpenAI) or a local model.
pub trait CompletionModel: Clone + Send + Sync {
    /// The raw response type returned by the underlying completion model.
    type Response: Send + Sync + Serialize + DeserializeOwned;
    /// The raw response type returned by the underlying completion model when streaming.
    type StreamingResponse: Clone + Unpin + Send + Sync + Serialize + DeserializeOwned;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + Send;

    /// Generates a streaming completion response for the given completion request.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + Send;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}
/// Dyn-compatible counterpart of [CompletionModel].
///
/// Response payloads are type-erased (the provider-specific raw response becomes `()`)
/// and futures are boxed, so implementors can be used behind `dyn` / trait objects.
pub trait CompletionModelDyn: Send + Sync {
    /// Generates a completion response, discarding the provider-specific raw response.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<'_, Result<CompletionResponse<()>, CompletionError>>;

    /// Generates a streaming completion response with a type-erased stream payload.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<Result<StreamingCompletionResponse<()>, CompletionError>>;

    /// Generates a completion request builder backed by a dynamic model handle.
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
}
326
// Blanket implementation: any statically-typed `CompletionModel` can be used as a
// type-erased `CompletionModelDyn` by boxing its futures and dropping/erasing the
// provider-specific response payloads.
impl<T, R> CompletionModelDyn for T
where
    T: CompletionModel<StreamingResponse = R>,
    R: Clone + Unpin + 'static,
{
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<Result<CompletionResponse<()>, CompletionError>> {
        Box::pin(async move {
            self.completion(request)
                .await
                .map(|resp| CompletionResponse {
                    choice: resp.choice,
                    usage: resp.usage,
                    // Discard the provider-specific payload to erase its type.
                    raw_response: (),
                })
        })
    }

    fn stream(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<Result<StreamingCompletionResponse<()>, CompletionError>> {
        Box::pin(async move {
            let resp = self.stream(request).await?;
            let inner = resp.inner;

            // Re-wrap the concrete stream in the dynamic adapter so the
            // `StreamingResponse` type parameter can be erased to `()`.
            let stream = Box::pin(streaming::StreamingResultDyn {
                inner: Box::pin(inner),
            });

            Ok(StreamingCompletionResponse::stream(stream))
        })
    }

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
        CompletionRequestBuilder::new(
            // Clone the model into an `Arc` so the handle owns a shareable copy.
            CompletionModelHandle {
                inner: Arc::new(self.clone()),
            },
            prompt,
        )
    }
}
376
/// Struct representing a general completion request that can be sent to a completion model provider.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider.
    /// The very last message is always the prompt (which is why the history is
    /// guaranteed to contain at least one message).
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
}
396
397impl CompletionRequest {
398 /// Returns documents normalized into a message (if any).
399 /// Most providers do not accept documents directly as input, so it needs to convert into a
400 /// `Message` so that it can be incorporated into `chat_history` as a
401 pub fn normalized_documents(&self) -> Option<Message> {
402 if self.documents.is_empty() {
403 return None;
404 }
405
406 // Most providers will convert documents into a text unless it can handle document messages.
407 // We use `UserContent::document` for those who handle it directly!
408 let messages = self
409 .documents
410 .iter()
411 .map(|doc| {
412 UserContent::document(
413 doc.to_string(),
414 // In the future, we can customize `Document` to pass these extra types through.
415 // Most providers ditch these but they might want to use them.
416 Some(ContentFormat::String),
417 Some(DocumentMediaType::TXT),
418 )
419 })
420 .collect::<Vec<_>>();
421
422 Some(Message::User {
423 content: OneOrMany::many(messages).expect("There will be atleast one document"),
424 })
425 }
426}
427
428/// Builder struct for constructing a completion request.
429///
430/// Example usage:
431/// ```rust
432/// use rig::{
433/// providers::openai::{Client, self},
434/// completion::CompletionRequestBuilder,
435/// };
436///
437/// let openai = Client::new("your-openai-api-key");
438/// let model = openai.completion_model(openai::GPT_4O).build();
439///
440/// // Create the completion request and execute it separately
441/// let request = CompletionRequestBuilder::new(model, "Who are you?".to_string())
442/// .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
443/// .temperature(0.5)
444/// .build();
445///
446/// let response = model.completion(request)
447/// .await
448/// .expect("Failed to get completion response");
449/// ```
450///
451/// Alternatively, you can execute the completion request directly from the builder:
452/// ```rust
453/// use rig::{
454/// providers::openai::{Client, self},
455/// completion::CompletionRequestBuilder,
456/// };
457///
458/// let openai = Client::new("your-openai-api-key");
459/// let model = openai.completion_model(openai::GPT_4O).build();
460///
461/// // Create the completion request and execute it directly
462/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
463/// .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
464/// .temperature(0.5)
465/// .send()
466/// .await
467/// .expect("Failed to get completion response");
468/// ```
469///
470/// Note: It is usually unnecessary to create a completion request builder directly.
471/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    /// The model the request will be sent to by `send`/`stream`.
    model: M,
    /// The prompt; appended as the final chat-history message at build time.
    prompt: Message,
    /// Optional system preamble.
    preamble: Option<String>,
    /// Prior conversation messages (the prompt is appended after these).
    chat_history: Vec<Message>,
    /// Context documents to attach to the request.
    documents: Vec<Document>,
    /// Tool definitions made available to the model.
    tools: Vec<ToolDefinition>,
    /// Optional sampling temperature.
    temperature: Option<f64>,
    /// Optional response token limit.
    max_tokens: Option<u64>,
    /// Provider-specific extra parameters (merged JSON object).
    additional_params: Option<serde_json::Value>,
}
483
484impl<M: CompletionModel> CompletionRequestBuilder<M> {
485 pub fn new(model: M, prompt: impl Into<Message>) -> Self {
486 Self {
487 model,
488 prompt: prompt.into(),
489 preamble: None,
490 chat_history: Vec::new(),
491 documents: Vec::new(),
492 tools: Vec::new(),
493 temperature: None,
494 max_tokens: None,
495 additional_params: None,
496 }
497 }
498
499 /// Sets the preamble for the completion request.
500 pub fn preamble(mut self, preamble: String) -> Self {
501 self.preamble = Some(preamble);
502 self
503 }
504
505 /// Adds a message to the chat history for the completion request.
506 pub fn message(mut self, message: Message) -> Self {
507 self.chat_history.push(message);
508 self
509 }
510
511 /// Adds a list of messages to the chat history for the completion request.
512 pub fn messages(self, messages: Vec<Message>) -> Self {
513 messages
514 .into_iter()
515 .fold(self, |builder, msg| builder.message(msg))
516 }
517
518 /// Adds a document to the completion request.
519 pub fn document(mut self, document: Document) -> Self {
520 self.documents.push(document);
521 self
522 }
523
524 /// Adds a list of documents to the completion request.
525 pub fn documents(self, documents: Vec<Document>) -> Self {
526 documents
527 .into_iter()
528 .fold(self, |builder, doc| builder.document(doc))
529 }
530
531 /// Adds a tool to the completion request.
532 pub fn tool(mut self, tool: ToolDefinition) -> Self {
533 self.tools.push(tool);
534 self
535 }
536
537 /// Adds a list of tools to the completion request.
538 pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
539 tools
540 .into_iter()
541 .fold(self, |builder, tool| builder.tool(tool))
542 }
543
544 /// Adds additional parameters to the completion request.
545 /// This can be used to set additional provider-specific parameters. For example,
546 /// Cohere's completion models accept a `connectors` parameter that can be used to
547 /// specify the data connectors used by Cohere when executing the completion
548 /// (see `examples/cohere_connectors.rs`).
549 pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
550 match self.additional_params {
551 Some(params) => {
552 self.additional_params = Some(json_utils::merge(params, additional_params));
553 }
554 None => {
555 self.additional_params = Some(additional_params);
556 }
557 }
558 self
559 }
560
561 /// Sets the additional parameters for the completion request.
562 /// This can be used to set additional provider-specific parameters. For example,
563 /// Cohere's completion models accept a `connectors` parameter that can be used to
564 /// specify the data connectors used by Cohere when executing the completion
565 /// (see `examples/cohere_connectors.rs`).
566 pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
567 self.additional_params = additional_params;
568 self
569 }
570
571 /// Sets the temperature for the completion request.
572 pub fn temperature(mut self, temperature: f64) -> Self {
573 self.temperature = Some(temperature);
574 self
575 }
576
577 /// Sets the temperature for the completion request.
578 pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
579 self.temperature = temperature;
580 self
581 }
582
583 /// Sets the max tokens for the completion request.
584 /// Note: This is required if using Anthropic
585 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
586 self.max_tokens = Some(max_tokens);
587 self
588 }
589
590 /// Sets the max tokens for the completion request.
591 /// Note: This is required if using Anthropic
592 pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
593 self.max_tokens = max_tokens;
594 self
595 }
596
597 /// Builds the completion request.
598 pub fn build(self) -> CompletionRequest {
599 let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
600 .expect("There will always be atleast the prompt");
601
602 CompletionRequest {
603 preamble: self.preamble,
604 chat_history,
605 documents: self.documents,
606 tools: self.tools,
607 temperature: self.temperature,
608 max_tokens: self.max_tokens,
609 additional_params: self.additional_params,
610 }
611 }
612
613 /// Sends the completion request to the completion model provider and returns the completion response.
614 pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
615 let model = self.model.clone();
616 model.completion(self.build()).await
617 }
618
619 /// Stream the completion request
620 pub async fn stream<'a>(
621 self,
622 ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
623 where
624 <M as CompletionModel>::StreamingResponse: 'a,
625 Self: 'a,
626 {
627 let model = self.model.clone();
628 model.stream(self.build()).await
629 }
630}
631
#[cfg(test)]
mod tests {

    use super::*;

    // A document with no metadata renders as a bare `<file>` block.
    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{doc}"), expected);
    }

    // Metadata renders as a single `<metadata … />` line with keys in sorted order
    // and Debug-quoted values.
    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{doc}"), expected);
    }

    // Multiple documents normalize into one user message containing one
    // document content entry per document, in insertion order.
    #[test]
    fn test_normalize_documents_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    // With no documents attached, normalization yields no message at all.
    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }
}