rig/completion/request.rs
//! This module provides functionality for working with completion models.
//! It provides traits, structs, and enums for generating completion requests,
//! handling completion responses, and defining completion models.
//!
//! The main traits defined in this module are:
//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
//! - [Chat]: Defines a high-level LLM chat interface with chat history.
//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
//! - [CompletionModel]: Defines a completion model that can be used to generate completion
//!   responses from requests.
//!
//! The [Prompt] and [Chat] traits are high-level traits that users are expected to use
//! to interact with LLM models. Moreover, it is good practice to implement one of these
//! traits for composite agents that use multiple LLM models to generate responses.
//!
//! The [Completion] trait defines a lower-level interface that is useful when the user wants
//! to further customize the request before sending it to the completion model provider.
//!
//! The [CompletionModel] trait is meant to act as the interface between providers and
//! the library. It defines the methods that need to be implemented by the user to define
//! a custom base completion model (i.e.: a private or third-party LLM provider).
//!
//! The module also provides various structs and enums for representing generic completion requests,
//! responses, and errors.
//!
//! Example Usage:
//! ```rust
//! use rig::providers::openai::{Client, self};
//! use rig::completion::*;
//!
//! // Initialize the OpenAI client and a completion model
//! let openai = Client::new("your-openai-api-key");
//!
//! let gpt_4 = openai.completion_model(openai::GPT_4);
//!
//! // Create the completion request
//! let request = gpt_4.completion_request("Who are you?")
//!     .preamble("\
//!         You are Marvin, an extremely smart but depressed robot who is \
//!         nonetheless helpful towards humanity.\
//!     ".to_string())
//!     .temperature(0.5)
//!     .build();
//!
//! // Send the completion request and get the completion response
//! let response = gpt_4.completion(request)
//!     .await
//!     .expect("Failed to get completion response");
//!
//! // Handle the completion response
//! match response.choice.first() {
//!     AssistantContent::Text(text) => {
//!         // Handle the completion response as a text message
//!         println!("Received message: {}", text.text);
//!     }
//!     AssistantContent::ToolCall(tool_call) => {
//!         // Handle the completion response as a tool call
//!         println!(
//!             "Received tool call: {} {:?}",
//!             tool_call.function.name, tool_call.function.arguments
//!         );
//!     }
//! }
//! ```
//!
//! For more information on how to use the completion functionality, refer to the documentation of
//! the individual traits, structs, and enums defined in this module.
use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::OneOrMany;
use crate::{
    json_utils,
    message::{Message, UserContent},
    tool::ToolSetError,
};

use super::message::{AssistantContent, ContentFormat, DocumentMediaType};

// Errors
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] reqwest::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}

#[derive(Debug, Error)]
pub enum PromptError {
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    #[error("MaxDepthError: (reached limit: {max_depth})")]
    MaxDepthError {
        max_depth: usize,
        chat_history: Vec<Message>,
        prompt: Message,
    },
}

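/// A document that can be attached to a completion request as extra context.
/// When rendered with [`std::fmt::Display`] (as done by
/// [`CompletionRequest::normalized_documents`]), it produces a pseudo-XML block
/// (see the tests at the bottom of this module), e.g.:
/// ```text
/// <file id: 123>
/// This is a test document.
/// </file>
/// ```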
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    pub id: String,
    pub text: String,
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}

impl std::fmt::Display for Document {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
            self.id,
            if self.additional_props.is_empty() {
                self.text.clone()
            } else {
                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
                let metadata = sorted_props
                    .iter()
                    .map(|(k, v)| format!("{}: {:?}", k, v))
                    .collect::<Vec<_>>()
                    .join(" ");
                format!("<metadata {} />\n{}", metadata, self.text)
            }
        )
    }
}

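/// Definition of a tool that can be provided to the completion model. The
/// `parameters` field is expected to hold a JSON Schema object describing the
/// tool's arguments, e.g. (a sketch; the exact schema depends on the tool):
/// ```rust
/// let parameters = serde_json::json!({
///     "type": "object",
///     "properties": {
///         "city": { "type": "string", "description": "The city to look up" }
///     },
///     "required": ["city"]
/// });
/// ```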
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

// ================================================================
// Implementations
// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: Send + Sync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
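    ///
    /// Example (a minimal sketch; assumes an OpenAI agent, which implements this trait):
    /// ```rust
    /// use rig::providers::openai::{Client, self};
    /// use rig::completion::Prompt;
    ///
    /// let openai = Client::new("your-openai-api-key");
    /// let agent = openai.agent(openai::GPT_4).build();
    ///
    /// // One-shot prompt: returns the model's reply (or a tool result) as a string
    /// let answer = agent.prompt("What is the capital of France?").await?;
    /// println!("{answer}");
    /// ```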
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}

/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: Send + Sync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
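    ///
    /// Example (a minimal sketch; assumes `agent` is any type implementing [Chat],
    /// such as an agent built from one of the providers):
    /// ```rust
    /// use rig::completion::{Chat, Message};
    ///
    /// let chat_history = vec![
    ///     Message::user("Hello, who are you?"),
    ///     Message::assistant("I'm Marvin, a depressed robot."),
    /// ];
    ///
    /// // The prompt is appended after the chat history
    /// let answer = agent.chat("What did I just ask you?", chat_history).await?;
    /// ```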
    fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}

/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
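    ///
    /// Example (a minimal sketch; assumes `agent` implements [Completion]):
    /// ```rust
    /// use rig::completion::Completion;
    ///
    /// // Get the pre-populated builder from the agent, then tweak it at prompt time
    /// let response = agent
    ///     .completion("Summarize the attached documents.", vec![])
    ///     .await?
    ///     .temperature(0.2)
    ///     .send()
    ///     .await?;
    /// ```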
    fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>> + Send;
}

/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more pieces of assistant content.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more pieces of assistant message content)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
}

/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third-party provider (e.g.: OpenAI) or a local model.
pub trait CompletionModel: Clone + Send + Sync {
    /// The raw response type returned by the underlying completion model.
    type Response: Send + Sync;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>>
           + Send;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}

/// Struct representing a general completion request that can be sent to a completion model provider.
pub struct CompletionRequest {
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider.
    /// The very last message will always be the prompt (hence why there is *always* at least one)
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
}

impl CompletionRequest {
    /// Returns the documents normalized into a single user message (if there are any).
    /// Most providers do not accept documents directly as input, so they need to be
    /// converted into a `Message` that can be incorporated into the `chat_history`.
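    ///
    /// Example (a minimal sketch of how a provider implementation might use this):
    /// ```rust
    /// let mut full_history = Vec::new();
    /// if let Some(docs) = request.normalized_documents() {
    ///     full_history.push(docs);
    /// }
    /// full_history.extend(request.chat_history.into_iter());
    /// ```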
    pub fn normalized_documents(&self) -> Option<Message> {
        if self.documents.is_empty() {
            return None;
        }

        // Most providers will convert documents into plain text unless they can handle
        // document messages directly. We use `UserContent::document` for those that do!
        let messages = self
            .documents
            .iter()
            .map(|doc| {
                UserContent::document(
                    doc.to_string(),
                    // In the future, we can customize `Document` to pass these extra types through.
                    // Most providers ditch these but they might want to use them.
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                )
            })
            .collect::<Vec<_>>();

        Some(Message::User {
            content: OneOrMany::many(messages).expect("There will be at least one document"),
        })
    }
}

/// Builder struct for constructing a completion request.
///
/// Example usage:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it separately
/// let request = CompletionRequestBuilder::new(model.clone(), "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .build();
///
/// let response = model.completion(request)
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Alternatively, you can execute the completion request directly from the builder:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it directly
/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .send()
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Note: It is usually unnecessary to create a completion request builder directly.
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    model: M,
    prompt: Message,
    preamble: Option<String>,
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<serde_json::Value>,
}

impl<M: CompletionModel> CompletionRequestBuilder<M> {
    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
        Self {
            model,
            prompt: prompt.into(),
            preamble: None,
            chat_history: Vec::new(),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        }
    }

    /// Sets the preamble for the completion request.
    pub fn preamble(mut self, preamble: String) -> Self {
        self.preamble = Some(preamble);
        self
    }

    /// Adds a message to the chat history for the completion request.
    pub fn message(mut self, message: Message) -> Self {
        self.chat_history.push(message);
        self
    }

    /// Adds a list of messages to the chat history for the completion request.
    pub fn messages(self, messages: Vec<Message>) -> Self {
        messages
            .into_iter()
            .fold(self, |builder, msg| builder.message(msg))
    }

    /// Adds a document to the completion request.
    pub fn document(mut self, document: Document) -> Self {
        self.documents.push(document);
        self
    }

    /// Adds a list of documents to the completion request.
    pub fn documents(self, documents: Vec<Document>) -> Self {
        documents
            .into_iter()
            .fold(self, |builder, doc| builder.document(doc))
    }

    /// Adds a tool to the completion request.
    pub fn tool(mut self, tool: ToolDefinition) -> Self {
        self.tools.push(tool);
        self
    }

    /// Adds a list of tools to the completion request.
    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
        tools
            .into_iter()
            .fold(self, |builder, tool| builder.tool(tool))
    }

    /// Adds additional parameters to the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
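    ///
    /// Example (a sketch; the accepted keys are provider-specific):
    /// ```rust
    /// let builder = builder.additional_params(serde_json::json!({
    ///     "connectors": [{ "id": "web-search" }]
    /// }));
    /// ```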
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        match self.additional_params {
            Some(params) => {
                self.additional_params = Some(json_utils::merge(params, additional_params));
            }
            None => {
                self.additional_params = Some(additional_params);
            }
        }
        self
    }

    /// Sets the additional parameters for the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
        self.temperature = temperature;
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required if using Anthropic.
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required if using Anthropic.
    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Builds the completion request.
    pub fn build(self) -> CompletionRequest {
        let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
            .expect("There will always be at least the prompt");

        CompletionRequest {
            preamble: self.preamble,
            chat_history,
            documents: self.documents,
            tools: self.tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
        }
    }

    /// Sends the completion request to the completion model provider and returns the completion response.
    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
        let model = self.model.clone();
        model.completion(self.build()).await
    }
}

impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
    /// Stream the completion request
    pub async fn stream(self) -> Result<StreamingResult, CompletionError> {
        let model = self.model.clone();
        model.stream(self.build()).await
    }
}

#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{}", doc), expected);
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{}", doc), expected);
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }
}
589}