rig/completion/request.rs
//! This module provides functionality for working with completion models.
//! It provides traits, structs, and enums for generating completion requests,
//! handling completion responses, and defining completion models.
//!
//! The main traits defined in this module are:
//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
//! - [Chat]: Defines a high-level LLM chat interface with chat history.
//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
//! - [CompletionModel]: Defines a completion model that can be used to generate completion
//!   responses from requests.
//!
//! The [Prompt] and [Chat] traits are high-level traits that users are expected to use
//! to interact with LLM models. Moreover, it is good practice to implement one of these
//! traits for composite agents that use multiple LLM models to generate responses.
//!
//! The [Completion] trait defines a lower-level interface that is useful when the user wants
//! to further customize the request before sending it to the completion model provider.
//!
//! The [CompletionModel] trait is meant to act as the interface between providers and
//! the library. It defines the methods that need to be implemented by the user to define
//! a custom base completion model (i.e.: a private or third-party LLM provider).
//!
//! The module also provides various structs and enums for representing generic completion requests,
//! responses, and errors.
//!
//! Example Usage:
//! ```rust
//! use rig::providers::openai::{Client, self};
//! use rig::completion::*;
//!
//! // Initialize the OpenAI client and a completion model
//! let openai = Client::new("your-openai-api-key");
//!
//! let gpt_4 = openai.completion_model(openai::GPT_4);
//!
//! // Create the completion request
//! let request = gpt_4.completion_request("Who are you?")
//!     .preamble("\
//!         You are Marvin, an extremely smart but depressed robot who is \
//!         nonetheless helpful towards humanity.\
//!     ".to_string())
//!     .temperature(0.5)
//!     .build();
//!
//! // Send the completion request and get the completion response
//! let response = gpt_4.completion(request)
//!     .await
//!     .expect("Failed to get completion response");
//!
//! // Handle the completion response
//! match response.choice.first() {
//!     AssistantContent::Text(text) => {
//!         // Handle the completion response as a text message
//!         println!("Received message: {}", text.text);
//!     }
//!     AssistantContent::ToolCall(tool_call) => {
//!         // Handle the completion response as a tool call
//!         println!("Received tool call: {:?}", tool_call);
//!     }
//! }
//! ```
//!
//! For more information on how to use the completion functionality, refer to the documentation of
//! the individual traits, structs, and enums defined in this module.
use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::OneOrMany;
use crate::{
    json_utils,
    message::{Message, UserContent},
    tool::ToolSetError,
};

use super::message::AssistantContent;

// Errors
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] reqwest::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}

#[derive(Debug, Error)]
pub enum PromptError {
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),
}

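/// A document that can be attached to a completion request to provide additional context.
///
/// When rendered via its [`std::fmt::Display`] implementation (as done by
/// [`CompletionRequest::prompt_with_context`]), a document produces an XML-like block:
/// ```text
/// <file id: 123>
/// This is a test document.
/// </file>
/// ```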
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    pub id: String,
    pub text: String,
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}

impl std::fmt::Display for Document {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
            self.id,
            if self.additional_props.is_empty() {
                self.text.clone()
            } else {
                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
                let metadata = sorted_props
                    .iter()
                    .map(|(k, v)| format!("{}: {:?}", k, v))
                    .collect::<Vec<_>>()
                    .join(" ");
                format!("<metadata {} />\n{}", metadata, self.text)
            }
        )
    }
}

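/// Definition of a tool that can be provided to a completion model, with `parameters`
/// holding the (typically JSON schema) description of the tool's arguments.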
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

// ================================================================
// Implementations
// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: Send + Sync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
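    ///
    /// Example usage (a sketch; assumes `agent` is some type implementing [Prompt],
    /// e.g. an agent built on top of a provider model):
    /// ```rust
    /// let response = agent.prompt("What is the capital of France?").await?;
    /// println!("{}", response);
    /// ```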
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
}

/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: Send + Sync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
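    ///
    /// Example usage (a sketch; assumes `agent` is some type implementing [Chat], and
    /// that the `Message::user`/`Message::assistant` convenience constructors are used
    /// to build the history):
    /// ```rust
    /// let history = vec![
    ///     Message::user("Hi! My name is Arthur."),
    ///     Message::assistant("Hello Arthur! How can I help you?"),
    /// ];
    /// let response = agent.chat("What is my name?", history).await?;
    /// ```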
    fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
}

/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
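    ///
    /// Example usage (a sketch; assumes `agent` implements [Completion] and
    /// `chat_history` is a `Vec<Message>`):
    /// ```rust
    /// let response = agent
    ///     .completion("Summarize our conversation so far.", chat_history)
    ///     .await?
    ///     // Further customize the request before sending it
    ///     .temperature(0.2)
    ///     .send()
    ///     .await?;
    /// ```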
    fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>> + Send;
}

/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more assistant content items.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message content items)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
}

/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third-party provider (e.g.: OpenAI) or a local model.
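///
/// # Example
///
/// A minimal skeleton of a custom model implementation (a sketch; the wire format,
/// transport, and response mapping are assumptions, not a real provider):
/// ```rust
/// use rig::completion::{
///     CompletionError, CompletionModel, CompletionRequest, CompletionResponse,
/// };
///
/// #[derive(Clone)]
/// struct MyModel;
///
/// impl CompletionModel for MyModel {
///     type Response = serde_json::Value;
///
///     async fn completion(
///         &self,
///         request: CompletionRequest,
///     ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
///         // 1. Translate `request` (prompt, chat history, documents, tools, ...)
///         //    into the provider's wire format and send it.
///         // 2. Map the provider's reply into a `CompletionResponse` (or a
///         //    `CompletionError` on failure).
///         todo!()
///     }
/// }
/// ```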
pub trait CompletionModel: Clone + Send + Sync {
    /// The raw response type returned by the underlying completion model.
    type Response: Send + Sync;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>>
        + Send;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}

/// Struct representing a general completion request that can be sent to a completion model provider.
pub struct CompletionRequest {
    /// The prompt to be sent to the completion model provider
    pub prompt: Message,
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider
    pub chat_history: Vec<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
}

impl CompletionRequest {
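    /// Returns the request prompt with the request's documents (if any) prepended
    /// to the user message as an `<attachments>` block. For example, with a single
    /// attached document, the first content item of the returned message looks like:
    ///
    /// ```text
    /// <attachments>
    /// <file id: doc1>
    /// Document 1 text.
    /// </file>
    /// </attachments>
    /// ```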
    pub fn prompt_with_context(&self) -> Message {
        let mut new_prompt = self.prompt.clone();
        if let Message::User { ref mut content } = new_prompt {
            if !self.documents.is_empty() {
                let attachments = self
                    .documents
                    .iter()
                    .map(|doc| doc.to_string())
                    .collect::<Vec<_>>()
                    .join("");
                let formatted_content = format!("<attachments>\n{}</attachments>", attachments);
                let mut new_content = vec![UserContent::text(formatted_content)];
                new_content.extend(content.clone());
                *content = OneOrMany::many(new_content).expect("This has more than 1 item");
            }
        }
        new_prompt
    }
}

/// Builder struct for constructing a completion request.
///
/// Example usage:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it separately
/// let request = CompletionRequestBuilder::new(model.clone(), "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .build();
///
/// let response = model.completion(request)
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Alternatively, you can execute the completion request directly from the builder:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it directly
/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .send()
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Note: It is usually unnecessary to create a completion request builder directly.
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    model: M,
    prompt: Message,
    preamble: Option<String>,
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<serde_json::Value>,
}

impl<M: CompletionModel> CompletionRequestBuilder<M> {
    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
        Self {
            model,
            prompt: prompt.into(),
            preamble: None,
            chat_history: Vec::new(),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        }
    }

    /// Sets the preamble for the completion request.
    pub fn preamble(mut self, preamble: String) -> Self {
        self.preamble = Some(preamble);
        self
    }

    /// Adds a message to the chat history for the completion request.
    pub fn message(mut self, message: Message) -> Self {
        self.chat_history.push(message);
        self
    }

    /// Adds a list of messages to the chat history for the completion request.
    pub fn messages(self, messages: Vec<Message>) -> Self {
        messages
            .into_iter()
            .fold(self, |builder, msg| builder.message(msg))
    }

    /// Adds a document to the completion request.
    pub fn document(mut self, document: Document) -> Self {
        self.documents.push(document);
        self
    }

    /// Adds a list of documents to the completion request.
    pub fn documents(self, documents: Vec<Document>) -> Self {
        documents
            .into_iter()
            .fold(self, |builder, doc| builder.document(doc))
    }

    /// Adds a tool to the completion request.
    pub fn tool(mut self, tool: ToolDefinition) -> Self {
        self.tools.push(tool);
        self
    }

    /// Adds a list of tools to the completion request.
    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
        tools
            .into_iter()
            .fold(self, |builder, tool| builder.tool(tool))
    }

    /// Adds additional parameters to the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
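    ///
    /// Note that successive calls to this method are merged together rather than
    /// overwriting one another (a sketch of the behavior; the parameter values shown
    /// are illustrative):
    /// ```rust
    /// let builder = builder
    ///     .additional_params(serde_json::json!({ "connectors": [{ "id": "web-search" }] }))
    ///     .additional_params(serde_json::json!({ "citation_quality": "accurate" }));
    /// // The request now carries both `connectors` and `citation_quality`.
    /// ```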
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        match self.additional_params {
            Some(params) => {
                self.additional_params = Some(json_utils::merge(params, additional_params));
            }
            None => {
                self.additional_params = Some(additional_params);
            }
        }
        self
    }

    /// Sets the additional provider-specific parameters for the completion request,
    /// replacing any previously set value. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the temperature for the completion request from an `Option`.
    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
        self.temperature = temperature;
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required when using Anthropic.
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the max tokens for the completion request from an `Option`.
    /// Note: This is required when using Anthropic.
    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Builds the completion request.
    pub fn build(self) -> CompletionRequest {
        CompletionRequest {
            prompt: self.prompt,
            preamble: self.preamble,
            chat_history: self.chat_history,
            documents: self.documents,
            tools: self.tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
        }
    }

    /// Sends the completion request to the completion model provider and returns the completion response.
    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
        let model = self.model.clone();
        model.completion(self.build()).await
    }
}

impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
    /// Streams the completion request.
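    ///
    /// (A sketch; assumes the builder's model also implements [StreamingCompletionModel]
    /// and that `futures::StreamExt` is in scope to consume the resulting stream.)
    /// ```rust
    /// let mut stream = builder.stream().await?;
    /// while let Some(chunk) = stream.next().await {
    ///     println!("{:?}", chunk?);
    /// }
    /// ```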
    pub async fn stream(self) -> Result<StreamingResult, CompletionError> {
        let model = self.model.clone();
        model.stream(self.build()).await
    }
}

#[cfg(test)]
mod tests {
    use crate::OneOrMany;

    use super::*;

    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{}", doc), expected);
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{}", doc), expected);
    }

    #[test]
    fn test_prompt_with_context_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            prompt: "What is the capital of France?".into(),
            preamble: None,
            chat_history: Vec::new(),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::text(concat!(
                    "<attachments>\n",
                    "<file id: doc1>\nDocument 1 text.\n</file>\n",
                    "<file id: doc2>\nDocument 2 text.\n</file>\n",
                    "</attachments>"
                )),
                UserContent::text("What is the capital of France?"),
            ])
            .expect("This has more than 1 item"),
        };

        assert_eq!(request.prompt_with_context(), expected);
    }
}
557}