rig/completion/request.rs
//! This module provides functionality for working with completion models.
//! It provides traits, structs, and enums for generating completion requests,
//! handling completion responses, and defining completion models.
//!
//! The main traits defined in this module are:
//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
//! - [Chat]: Defines a high-level LLM chat interface with chat history.
//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
//! - [CompletionModel]: Defines a completion model that can be used to generate completion
//!   responses from requests.
//!
//! The [Prompt] and [Chat] traits are high-level traits that users are expected to use
//! to interact with LLM models. Moreover, it is good practice to implement one of these
//! traits for composite agents that use multiple LLM models to generate responses.
//!
//! The [Completion] trait defines a lower-level interface that is useful when the user wants
//! to further customize the request before sending it to the completion model provider.
//!
//! The [CompletionModel] trait is meant to act as the interface between providers and
//! the library. It defines the methods that need to be implemented by the user to define
//! a custom base completion model (i.e.: a private or third-party LLM provider).
//!
//! The module also provides various structs and enums for representing generic completion requests,
//! responses, and errors.
//!
//! Example usage:
//! ```rust
//! use rig::providers::openai::{Client, self};
//! use rig::completion::*;
//!
//! // Initialize the OpenAI client and a completion model
//! let openai = Client::new("your-openai-api-key");
//!
//! let gpt_4 = openai.completion_model(openai::GPT_4);
//!
//! // Create the completion request
//! let request = gpt_4.completion_request("Who are you?")
//!     .preamble("\
//!         You are Marvin, an extremely smart but depressed robot who is \
//!         nonetheless helpful towards humanity.\
//!     ".to_string())
//!     .temperature(0.5)
//!     .build();
//!
//! // Send the completion request and get the completion response
//! let response = gpt_4.completion(request)
//!     .await
//!     .expect("Failed to get completion response");
//!
//! // Handle the completion response
//! match response.choice.first() {
//!     AssistantContent::Text(text) => {
//!         // Handle the completion response as a text message
//!         println!("Received message: {}", text.text);
//!     }
//!     AssistantContent::ToolCall(tool_call) => {
//!         // Handle the completion response as a tool call
//!         println!("Received tool call: {:?}", tool_call);
//!     }
//!     // Other content types (if any) are ignored in this example
//!     _ => {}
//! }
//! ```
//!
//! For more information on how to use the completion functionality, refer to the documentation of
//! the individual traits, structs, and enums defined in this module.

use super::message::{AssistantContent, ContentFormat, DocumentMediaType};
use crate::client::completion::CompletionModelHandle;
use crate::streaming::StreamingCompletionResponse;
use crate::{OneOrMany, streaming};
use crate::{
    json_utils,
    message::{Message, UserContent},
    tool::ToolSetError,
};
use futures::future::BoxFuture;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;

// Errors
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] reqwest::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}

/// Prompt errors
#[derive(Debug, Error)]
pub enum PromptError {
    /// Something went wrong with the completion
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// There was an error while using a tool
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// The LLM tried to call too many tools during a multi-turn conversation.
    /// To fix this, you may either need to reduce the number of tools your model has
    /// access to (and create other agents to share the tool load) or increase the
    /// number of turns allowed in `.multi_turn()`.
    #[error("MaxDepthError: (reached limit: {max_depth})")]
    MaxDepthError {
        max_depth: usize,
        chat_history: Vec<Message>,
        prompt: Message,
    },
}

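/// A text document that can be provided to a completion model as context. Any extra
/// key/value pairs placed in `additional_props` are rendered as inline metadata by the
/// [`std::fmt::Display`] implementation below.
///
/// A minimal sketch of how a document renders (see the tests at the bottom of this
/// module for the metadata case):
/// ```rust
/// use rig::completion::Document;
/// use std::collections::HashMap;
///
/// let doc = Document {
///     id: "doc1".to_string(),
///     text: "Hello, world!".to_string(),
///     additional_props: HashMap::new(),
/// };
///
/// assert_eq!(doc.to_string(), "<file id: doc1>\nHello, world!\n</file>\n");
/// ```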
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    pub id: String,
    pub text: String,
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}

impl std::fmt::Display for Document {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
            self.id,
            if self.additional_props.is_empty() {
                self.text.clone()
            } else {
                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
                let metadata = sorted_props
                    .iter()
                    .map(|(k, v)| format!("{k}: {v:?}"))
                    .collect::<Vec<_>>()
                    .join(" ");
                format!("<metadata {} />\n{}", metadata, self.text)
            }
        )
    }
}

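/// Definition of a tool that can be provided to a completion model. The `parameters`
/// field is expected to hold a JSON Schema describing the tool's arguments.
///
/// A minimal sketch (the schema below is illustrative, not prescriptive):
/// ```rust
/// use rig::completion::ToolDefinition;
/// use serde_json::json;
///
/// let add_tool = ToolDefinition {
///     name: "add".to_string(),
///     description: "Add two numbers together".to_string(),
///     parameters: json!({
///         "type": "object",
///         "properties": {
///             "a": { "type": "number", "description": "The first number" },
///             "b": { "type": "number", "description": "The second number" }
///         },
///         "required": ["a", "b"]
///     }),
/// };
/// ```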
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

// ================================================================
// Implementations
// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: Send + Sync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
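    ///
    /// A minimal sketch, assuming an agent built from the OpenAI provider:
    /// ```rust
    /// use rig::providers::openai::{Client, self};
    /// use rig::completion::Prompt;
    ///
    /// let openai = Client::new("your-openai-api-key");
    /// let agent = openai.agent(openai::GPT_4).build();
    ///
    /// let response = agent.prompt("Who are you?")
    ///     .await
    ///     .expect("Failed to prompt the model");
    /// ```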
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}

/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: Send + Sync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
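    ///
    /// A minimal sketch, assuming an agent built as in the [Prompt] example (the
    /// history messages are illustrative):
    /// ```rust
    /// use rig::completion::{Chat, Message};
    ///
    /// let chat_history = vec![
    ///     Message::user("Remember: my name is Arthur."),
    ///     Message::assistant("Noted, Arthur."),
    /// ];
    ///
    /// let response = agent.chat("What is my name?", chat_history)
    ///     .await
    ///     .expect("Failed to chat with the model");
    /// ```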
    fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}

/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion)
    /// will already contain the `preamble` provided when creating the agent.
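    ///
    /// A minimal sketch, assuming an agent (which pre-populates fields such as
    /// its `preamble`):
    /// ```rust
    /// use rig::completion::Completion;
    ///
    /// let request = agent
    ///     .completion("What is the capital of France?", vec![])
    ///     .await
    ///     .expect("Failed to create the request builder")
    ///     .temperature(0.0) // overrides any temperature pre-set by the agent
    ///     .build();
    /// ```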
    fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>> + Send;
}

/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more pieces of assistant content.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message contents)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
}

/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third-party provider (e.g.: OpenAI) or a local model.
pub trait CompletionModel: Clone + Send + Sync {
    /// The raw response type returned by the underlying completion model.
    type Response: Send + Sync;
    /// The raw response type returned by the underlying completion model when streaming.
    type StreamingResponse: Clone + Unpin + Send + Sync;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + Send;

    /// Generates a streaming completion response for the given completion request.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + Send;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}

pub trait CompletionModelDyn: Send + Sync {
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<'_, Result<CompletionResponse<()>, CompletionError>>;

    fn stream(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<Result<StreamingCompletionResponse<()>, CompletionError>>;

    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
}

impl<T, R> CompletionModelDyn for T
where
    T: CompletionModel<StreamingResponse = R>,
    R: Clone + Unpin + 'static,
{
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<Result<CompletionResponse<()>, CompletionError>> {
        Box::pin(async move {
            self.completion(request)
                .await
                .map(|resp| CompletionResponse {
                    choice: resp.choice,
                    raw_response: (),
                })
        })
    }

    fn stream(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<Result<StreamingCompletionResponse<()>, CompletionError>> {
        Box::pin(async move {
            let resp = self.stream(request).await?;
            let inner = resp.inner;

            let stream = Box::pin(streaming::StreamingResultDyn {
                inner: Box::pin(inner),
            });

            Ok(StreamingCompletionResponse::stream(stream))
        })
    }

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
        CompletionRequestBuilder::new(
            CompletionModelHandle {
                inner: Arc::new(self.clone()),
            },
            prompt,
        )
    }
}

/// Struct representing a general completion request that can be sent to a completion model provider.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider.
    /// The very last message will always be the prompt (hence there is *always* at least one message)
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
}

impl CompletionRequest {
    /// Returns the documents normalized into a message (if any).
    /// Most providers do not accept documents directly as input, so they need to be
    /// converted into a `Message` that can be incorporated into the `chat_history`.
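    ///
    /// A minimal sketch of the expected behavior (the ids and text are illustrative,
    /// and `model` is assumed to be built as in the module-level example):
    /// ```rust
    /// use rig::completion::Document;
    /// use std::collections::HashMap;
    ///
    /// let request = model.completion_request("What is the capital of France?")
    ///     .document(Document {
    ///         id: "doc1".to_string(),
    ///         text: "Paris is the capital of France.".to_string(),
    ///         additional_props: HashMap::new(),
    ///     })
    ///     .build();
    ///
    /// // Yields a single user message containing:
    /// // "<file id: doc1>\nParis is the capital of France.\n</file>\n"
    /// let doc_message = request.normalized_documents().expect("There is one document");
    /// ```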
    pub fn normalized_documents(&self) -> Option<Message> {
        if self.documents.is_empty() {
            return None;
        }

        // Most providers will convert documents into text unless they can handle document messages.
        // We use `UserContent::document` for those that handle it directly!
        let messages = self
            .documents
            .iter()
            .map(|doc| {
                UserContent::document(
                    doc.to_string(),
                    // In the future, we can customize `Document` to pass these extra types through.
                    // Most providers ditch these, but they might want to use them.
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                )
            })
            .collect::<Vec<_>>();

        Some(Message::User {
            content: OneOrMany::many(messages).expect("There will be at least one document"),
        })
    }
}

/// Builder struct for constructing a completion request.
///
/// Example usage:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it separately
/// let request = CompletionRequestBuilder::new(model.clone(), "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .build();
///
/// let response = model.completion(request)
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Alternatively, you can execute the completion request directly from the builder:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it directly
/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .send()
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Note: It is usually unnecessary to create a completion request builder directly.
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    model: M,
    prompt: Message,
    preamble: Option<String>,
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<serde_json::Value>,
}

impl<M: CompletionModel> CompletionRequestBuilder<M> {
    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
        Self {
            model,
            prompt: prompt.into(),
            preamble: None,
            chat_history: Vec::new(),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        }
    }

    /// Sets the preamble for the completion request.
    pub fn preamble(mut self, preamble: String) -> Self {
        self.preamble = Some(preamble);
        self
    }

    /// Adds a message to the chat history for the completion request.
    pub fn message(mut self, message: Message) -> Self {
        self.chat_history.push(message);
        self
    }

    /// Adds a list of messages to the chat history for the completion request.
    pub fn messages(self, messages: Vec<Message>) -> Self {
        messages
            .into_iter()
            .fold(self, |builder, msg| builder.message(msg))
    }

    /// Adds a document to the completion request.
    pub fn document(mut self, document: Document) -> Self {
        self.documents.push(document);
        self
    }

    /// Adds a list of documents to the completion request.
    pub fn documents(self, documents: Vec<Document>) -> Self {
        documents
            .into_iter()
            .fold(self, |builder, doc| builder.document(doc))
    }

    /// Adds a tool to the completion request.
    pub fn tool(mut self, tool: ToolDefinition) -> Self {
        self.tools.push(tool);
        self
    }

    /// Adds a list of tools to the completion request.
    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
        tools
            .into_iter()
            .fold(self, |builder, tool| builder.tool(tool))
    }

    /// Adds additional parameters to the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
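    ///
    /// A minimal sketch of the merge behavior (the parameter names are illustrative,
    /// and `builder` is assumed to come from [CompletionModel::completion_request]):
    /// ```rust
    /// use serde_json::json;
    ///
    /// // Calling `additional_params` twice merges the two JSON objects
    /// // rather than replacing the first with the second:
    /// let request = builder
    ///     .additional_params(json!({ "connectors": [{ "id": "web-search" }] }))
    ///     .additional_params(json!({ "citation_quality": "accurate" }))
    ///     .build();
    ///
    /// // request.additional_params is now:
    /// // Some(json!({ "connectors": [{ "id": "web-search" }], "citation_quality": "accurate" }))
    /// ```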
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        match self.additional_params {
            Some(params) => {
                self.additional_params = Some(json_utils::merge(params, additional_params));
            }
            None => {
                self.additional_params = Some(additional_params);
            }
        }
        self
    }

    /// Sets the additional parameters for the completion request, replacing any previously set value.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the temperature for the completion request from an `Option`.
    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
        self.temperature = temperature;
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required when using Anthropic.
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the max tokens for the completion request from an `Option`.
    /// Note: This is required when using Anthropic.
    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Builds the completion request.
    pub fn build(self) -> CompletionRequest {
        let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
            .expect("There will always be at least the prompt");

        CompletionRequest {
            preamble: self.preamble,
            chat_history,
            documents: self.documents,
            tools: self.tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
        }
    }

    /// Sends the completion request to the completion model provider and returns the completion response.
    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
        let model = self.model.clone();
        model.completion(self.build()).await
    }

    /// Streams the completion request.
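    ///
    /// A minimal sketch of consuming the resulting stream, assuming the streaming
    /// response can be polled as a `futures::Stream` (chunk types vary by provider):
    /// ```rust
    /// use futures::StreamExt;
    ///
    /// let mut stream = builder.stream().await.expect("Failed to start streaming");
    /// while let Some(chunk) = stream.next().await {
    ///     println!("{chunk:?}");
    /// }
    /// ```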
    pub async fn stream<'a>(
        self,
    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
    where
        <M as CompletionModel>::StreamingResponse: 'a,
        Self: 'a,
    {
        let model = self.model.clone();
        model.stream(self.build()).await
    }
}

#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }
}