rig/completion/
request.rs

//! This module provides functionality for working with completion models.
//! It provides traits, structs, and enums for generating completion requests,
//! handling completion responses, and defining completion models.
//!
//! The main traits defined in this module are:
//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
//! - [Chat]: Defines a high-level LLM chat interface with chat history.
//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
//! - [CompletionModel]: Defines a completion model that can be used to generate completion
//!   responses from requests.
//!
//! The [Prompt] and [Chat] traits are high level traits that users are expected to use
//! to interact with LLM models. Moreover, it is good practice to implement one of these
//! traits for composite agents that use multiple LLM models to generate responses.
//!
//! The [Completion] trait defines a lower level interface that is useful when the user wants
//! to further customize the request before sending it to the completion model provider.
//!
//! The [CompletionModel] trait is meant to act as the interface between providers and
//! the library. It defines the methods that need to be implemented by the user to define
//! a custom base completion model (i.e.: a private or third party LLM provider).
//!
//! The module also provides various structs and enums for representing generic completion requests,
//! responses, and errors.
//!
//! Example Usage:
//! ```rust
//! use rig::providers::openai::{Client, self};
//! use rig::completion::*;
//!
//! // Initialize the OpenAI client and a completion model
//! let openai = Client::new("your-openai-api-key");
//!
//! let gpt_4 = openai.completion_model(openai::GPT_4);
//!
//! // Create the completion request
//! let request = gpt_4.completion_request("Who are you?")
//!     .preamble("\
//!         You are Marvin, an extremely smart but depressed robot who is \
//!         nonetheless helpful towards humanity.\
//!     ")
//!     .temperature(0.5)
//!     .build();
//!
//! // Send the completion request and get the completion response
//! let response = gpt_4.completion(request)
//!     .await
//!     .expect("Failed to get completion response");
//!
//! // Handle the completion response
//! match response.choice {
//!     ModelChoice::Message(message) => {
//!         // Handle the completion response as a message
//!         println!("Received message: {}", message);
//!     }
//!     ModelChoice::ToolCall(tool_name, tool_params) => {
//!         // Handle the completion response as a tool call
//!         println!("Received tool call: {} {:?}", tool_name, tool_params);
//!     }
//! }
//! ```
//!
//! For more information on how to use the completion functionality, refer to the documentation of
//! the individual traits, structs, and enums defined in this module.
use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::OneOrMany;
use crate::{
    json_utils,
    message::{Message, UserContent},
    tool::ToolSetError,
};

use super::message::{AssistantContent, ContentFormat, DocumentMediaType};

// Errors
/// Errors that can occur while building or sending a completion request,
/// or while handling the provider's response.
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] reqwest::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}

/// Errors that can occur when prompting a model through the high-level
/// [Prompt] or [Chat] interfaces.
#[derive(Debug, Error)]
pub enum PromptError {
    /// The underlying completion call failed.
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// A tool requested by the model could not be found or failed when invoked.
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// The configured depth limit was reached before the exchange completed.
    #[error("MaxDepthError: (reached limit: {max_depth})")]
    MaxDepthError {
        /// The depth limit that was hit.
        max_depth: usize,
        /// The conversation accumulated up to the point the limit was reached.
        chat_history: Vec<Message>,
        /// The prompt that started the exchange.
        prompt: Message,
    },
}

/// A text document that can be attached to a completion request as context.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    /// Identifier for the document; rendered in the `<file id: …>` header by
    /// the `Display` impl.
    pub id: String,
    /// The document's body text.
    pub text: String,
    /// Extra key/value properties. Flattened into the document object during
    /// (de)serialization and rendered as a `<metadata … />` line by `Display`.
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}

128impl std::fmt::Display for Document {
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        write!(
131            f,
132            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
133            self.id,
134            if self.additional_props.is_empty() {
135                self.text.clone()
136            } else {
137                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
138                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
139                let metadata = sorted_props
140                    .iter()
141                    .map(|(k, v)| format!("{}: {:?}", k, v))
142                    .collect::<Vec<_>>()
143                    .join(" ");
144                format!("<metadata {} />\n{}", metadata, self.text)
145            }
146        )
147    }
148}
149
/// Definition of a tool that can be offered to the completion model.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// The tool's name.
    pub name: String,
    /// Description of what the tool does, intended for the model.
    pub description: String,
    /// Free-form JSON value describing the tool's parameters
    /// (typically a JSON Schema — depends on the provider's expectations).
    pub parameters: serde_json::Value,
}

// ================================================================
// Implementations
// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: Send + Sync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    ///
    /// The returned value implements [std::future::IntoFuture], so it can be
    /// `.await`ed directly.
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}

/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: Send + Sync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    ///
    /// The returned value implements [std::future::IntoFuture], so it can be
    /// `.await`ed directly.
    fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}

/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
    ///
    /// # Errors
    /// Returns a [CompletionError] if the builder could not be assembled
    /// (the method is async, so implementors may perform fallible work here).
    fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>> + Send;
}

/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more assistant content.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message content)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// The raw response returned by the completion model provider;
    /// `T` is provider-specific and kept for callers that need full details.
    pub raw_response: T,
}

/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third party provider (e.g.: OpenAI) or a local model.
pub trait CompletionModel: Clone + Send + Sync {
    /// The raw response type returned by the underlying completion model.
    type Response: Send + Sync;

    /// Generates a completion response for the given completion request.
    ///
    /// # Errors
    /// Returns a [CompletionError] if the request fails or the provider's
    /// response cannot be handled.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>>
           + Send;

    /// Generates a completion request builder for the given `prompt`.
    /// The builder holds a clone of this model so the request can later be
    /// executed via [CompletionRequestBuilder::send].
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}

/// Struct representing a general completion request that can be sent to a completion model provider.
pub struct CompletionRequest {
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider.
    /// The very last message will always be the prompt (hence why there is *always* at least one message)
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
}

263impl CompletionRequest {
264    /// Returns documents normalized into a message (if any).
265    /// Most providers do not accept documents directly as input, so it needs to convert into a
266    ///  `Message` so that it can be incorperated into `chat_history` as a
267    pub fn normalized_documents(&self) -> Option<Message> {
268        if self.documents.is_empty() {
269            return None;
270        }
271
272        // Most providers will convert documents into a text unless it can handle document messages.
273        // We use `UserContent::document` for those who handle it directly!
274        let messages = self
275            .documents
276            .iter()
277            .map(|doc| {
278                UserContent::document(
279                    doc.to_string(),
280                    // In the future, we can customize `Document` to pass these extra types through.
281                    // Most providers ditch these but they might want to use them.
282                    Some(ContentFormat::String),
283                    Some(DocumentMediaType::TXT),
284                )
285            })
286            .collect::<Vec<_>>();
287
288        Some(Message::User {
289            content: OneOrMany::many(messages).expect("There will be atleast one document"),
290        })
291    }
292}
293
/// Builder struct for constructing a completion request.
///
/// Example usage:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it separately
/// let request = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .build();
///
/// let response = model.completion(request)
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Alternatively, you can execute the completion request directly from the builder:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it directly
/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .send()
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Note: It is usually unnecessary to create a completion request builder directly.
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    // The model that will execute the request (used by `send`/`stream`).
    model: M,
    // The prompt message; `build` appends it after `chat_history`.
    prompt: Message,
    // Optional preamble forwarded verbatim into the request.
    preamble: Option<String>,
    // Messages placed before the prompt in the built request.
    chat_history: Vec<Message>,
    // Documents forwarded verbatim into the request.
    documents: Vec<Document>,
    // Tool definitions forwarded verbatim into the request.
    tools: Vec<ToolDefinition>,
    // Optional sampling temperature.
    temperature: Option<f64>,
    // Optional max token count (required by some providers, e.g. Anthropic).
    max_tokens: Option<u64>,
    // Provider-specific JSON parameters; merged by `additional_params`.
    additional_params: Option<serde_json::Value>,
}

350impl<M: CompletionModel> CompletionRequestBuilder<M> {
351    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
352        Self {
353            model,
354            prompt: prompt.into(),
355            preamble: None,
356            chat_history: Vec::new(),
357            documents: Vec::new(),
358            tools: Vec::new(),
359            temperature: None,
360            max_tokens: None,
361            additional_params: None,
362        }
363    }
364
365    /// Sets the preamble for the completion request.
366    pub fn preamble(mut self, preamble: String) -> Self {
367        self.preamble = Some(preamble);
368        self
369    }
370
371    /// Adds a message to the chat history for the completion request.
372    pub fn message(mut self, message: Message) -> Self {
373        self.chat_history.push(message);
374        self
375    }
376
377    /// Adds a list of messages to the chat history for the completion request.
378    pub fn messages(self, messages: Vec<Message>) -> Self {
379        messages
380            .into_iter()
381            .fold(self, |builder, msg| builder.message(msg))
382    }
383
384    /// Adds a document to the completion request.
385    pub fn document(mut self, document: Document) -> Self {
386        self.documents.push(document);
387        self
388    }
389
390    /// Adds a list of documents to the completion request.
391    pub fn documents(self, documents: Vec<Document>) -> Self {
392        documents
393            .into_iter()
394            .fold(self, |builder, doc| builder.document(doc))
395    }
396
397    /// Adds a tool to the completion request.
398    pub fn tool(mut self, tool: ToolDefinition) -> Self {
399        self.tools.push(tool);
400        self
401    }
402
403    /// Adds a list of tools to the completion request.
404    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
405        tools
406            .into_iter()
407            .fold(self, |builder, tool| builder.tool(tool))
408    }
409
410    /// Adds additional parameters to the completion request.
411    /// This can be used to set additional provider-specific parameters. For example,
412    /// Cohere's completion models accept a `connectors` parameter that can be used to
413    /// specify the data connectors used by Cohere when executing the completion
414    /// (see `examples/cohere_connectors.rs`).
415    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
416        match self.additional_params {
417            Some(params) => {
418                self.additional_params = Some(json_utils::merge(params, additional_params));
419            }
420            None => {
421                self.additional_params = Some(additional_params);
422            }
423        }
424        self
425    }
426
427    /// Sets the additional parameters for the completion request.
428    /// This can be used to set additional provider-specific parameters. For example,
429    /// Cohere's completion models accept a `connectors` parameter that can be used to
430    /// specify the data connectors used by Cohere when executing the completion
431    /// (see `examples/cohere_connectors.rs`).
432    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
433        self.additional_params = additional_params;
434        self
435    }
436
437    /// Sets the temperature for the completion request.
438    pub fn temperature(mut self, temperature: f64) -> Self {
439        self.temperature = Some(temperature);
440        self
441    }
442
443    /// Sets the temperature for the completion request.
444    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
445        self.temperature = temperature;
446        self
447    }
448
449    /// Sets the max tokens for the completion request.
450    /// Note: This is required if using Anthropic
451    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
452        self.max_tokens = Some(max_tokens);
453        self
454    }
455
456    /// Sets the max tokens for the completion request.
457    /// Note: This is required if using Anthropic
458    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
459        self.max_tokens = max_tokens;
460        self
461    }
462
463    /// Builds the completion request.
464    pub fn build(self) -> CompletionRequest {
465        let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
466            .expect("There will always be atleast the prompt");
467
468        CompletionRequest {
469            preamble: self.preamble,
470            chat_history,
471            documents: self.documents,
472            tools: self.tools,
473            temperature: self.temperature,
474            max_tokens: self.max_tokens,
475            additional_params: self.additional_params,
476        }
477    }
478
479    /// Sends the completion request to the completion model provider and returns the completion response.
480    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
481        let model = self.model.clone();
482        model.completion(self.build()).await
483    }
484}
485
486impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
487    /// Stream the completion request
488    pub async fn stream(self) -> Result<StreamingResult, CompletionError> {
489        let model = self.model.clone();
490        model.stream(self.build()).await
491    }
492}
493
#[cfg(test)]
mod tests {

    use super::*;

    // Builds a `Document` with no additional properties.
    fn plain_doc(id: &str, text: &str) -> Document {
        Document {
            id: id.to_string(),
            text: text.to_string(),
            additional_props: HashMap::new(),
        }
    }

    // Builds a minimal `CompletionRequest` around the given documents.
    fn request_with_docs(documents: Vec<Document>) -> CompletionRequest {
        CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents,
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        }
    }

    #[test]
    fn test_document_display_without_metadata() {
        let document = plain_doc("123", "This is a test document.");

        assert_eq!(
            document.to_string(),
            "<file id: 123>\nThis is a test document.\n</file>\n"
        );
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut document = plain_doc("123", "This is a test document.");
        document.additional_props = HashMap::from([
            ("author".to_string(), "John Doe".to_string()),
            ("length".to_string(), "42".to_string()),
        ]);

        // Properties are rendered sorted by key, debug-quoted.
        assert_eq!(
            document.to_string(),
            concat!(
                "<file id: 123>\n",
                "<metadata author: \"John Doe\" length: \"42\" />\n",
                "This is a test document.\n",
                "</file>\n"
            )
        );
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let request = request_with_docs(vec![
            plain_doc("doc1", "Document 1 text."),
            plain_doc("doc2", "Document 2 text."),
        ]);

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = request_with_docs(Vec::new());
        assert_eq!(request.normalized_documents(), None);
    }
}