rig/completion/request.rs

//! This module provides functionality for working with completion models.
//! It provides traits, structs, and enums for generating completion requests,
//! handling completion responses, and defining completion models.
//!
//! The main traits defined in this module are:
//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
//! - [Chat]: Defines a high-level LLM chat interface with chat history.
//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
//! - [CompletionModel]: Defines a completion model that can be used to generate completion
//!   responses from requests.
//!
//! The [Prompt] and [Chat] traits are high-level traits that users are expected to use
//! to interact with LLM models. Moreover, it is good practice to implement one of these
//! traits for composite agents that use multiple LLM models to generate responses.
//!
//! The [Completion] trait defines a lower-level interface that is useful when the user wants
//! to further customize the request before sending it to the completion model provider.
//!
//! The [CompletionModel] trait is meant to act as the interface between providers and
//! the library. It defines the methods that need to be implemented by the user to define
//! a custom base completion model (i.e.: a private or third party LLM provider).
//!
//! The module also provides various structs and enums for representing generic completion requests,
//! responses, and errors.
//!
//! Example Usage:
//! ```rust
//! use rig::providers::openai::{Client, self};
//! use rig::completion::*;
//!
//! // Initialize the OpenAI client and a completion model
//! let openai = Client::new("your-openai-api-key");
//!
//! let gpt_4 = openai.completion_model(openai::GPT_4);
//!
//! // Create the completion request
//! let request = gpt_4.completion_request("Who are you?")
//!     .preamble("\
//!         You are Marvin, an extremely smart but depressed robot who is \
//!         nonetheless helpful towards humanity.\
//!     ".to_string())
//!     .temperature(0.5)
//!     .build();
//!
//! // Send the completion request and get the completion response
//! let response = gpt_4.completion(request)
//!     .await
//!     .expect("Failed to get completion response");
//!
//! // Handle the completion response
//! match response.choice.first() {
//!     AssistantContent::Text(text) => {
//!         // Handle the completion response as a text message
//!         println!("Received message: {}", text.text);
//!     }
//!     AssistantContent::ToolCall(tool_call) => {
//!         // Handle the completion response as a tool call
//!         println!("Received tool call: {:?}", tool_call);
//!     }
//! }
//! ```
//!
//! For more information on how to use the completion functionality, refer to the documentation of
//! the individual traits, structs, and enums defined in this module.
use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::OneOrMany;
use crate::{
    json_utils,
    message::{Message, UserContent},
    tool::ToolSetError,
};

use super::message::AssistantContent;

// Errors
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] reqwest::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}

#[derive(Debug, Error)]
pub enum PromptError {
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),
}

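/// A text document that can be attached to a completion request as context.
/// Fields beyond `id` and `text` are collected into `additional_props` via
/// `#[serde(flatten)]` and rendered as a `<metadata ... />` line by the `Display` impl.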
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    pub id: String,
    pub text: String,
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}

impl std::fmt::Display for Document {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
            self.id,
            if self.additional_props.is_empty() {
                self.text.clone()
            } else {
                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
                let metadata = sorted_props
                    .iter()
                    .map(|(k, v)| format!("{}: {:?}", k, v))
                    .collect::<Vec<_>>()
                    .join(" ");
                format!("<metadata {} />\n{}", metadata, self.text)
            }
        )
    }
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

// ================================================================
// Implementations
// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: Send + Sync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
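    ///
    /// Example (a sketch; assumes an `agent` value, such as a [`crate::agent::Agent`], that implements [Prompt]):
    /// ```rust
    /// let answer = agent.prompt("What is the capital of France?").await?;
    /// println!("{answer}");
    /// ```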
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
}

/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: Send + Sync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
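    ///
    /// Example (a sketch; assumes an `agent` that implements [Chat] and a `history: Vec<Message>` from earlier turns):
    /// ```rust
    /// let answer = agent.chat("And what is its population?", history).await?;
    /// println!("{answer}");
    /// ```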
    fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
}

/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
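    ///
    /// Example (a sketch; assumes an `agent` that implements [Completion]):
    /// ```rust
    /// let response = agent
    ///     .completion("What is the capital of France?", vec![])
    ///     .await?
    ///     .temperature(0.2) // overrides any temperature the agent may have set
    ///     .send()
    ///     .await?;
    /// ```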
    fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>> + Send;
}

/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more items of assistant content.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message content)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
}

/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third party provider (e.g.: OpenAI) or a local model.
pub trait CompletionModel: Clone + Send + Sync {
    /// The raw response type returned by the underlying completion model.
    type Response: Send + Sync;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>>
           + Send;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}

/// Struct representing a general completion request that can be sent to a completion model provider.
pub struct CompletionRequest {
    /// The prompt to be sent to the completion model provider
    pub prompt: Message,
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider
    pub chat_history: Vec<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
}

impl CompletionRequest {
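    /// Returns the prompt message with the request's documents prepended as an
    /// `<attachments>` text block. If there are no documents, or the prompt is not
    /// a user message, the prompt is returned unchanged.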
    pub fn prompt_with_context(&self) -> Message {
        let mut new_prompt = self.prompt.clone();
        if let Message::User { ref mut content } = new_prompt {
            if !self.documents.is_empty() {
                let attachments = self
                    .documents
                    .iter()
                    .map(|doc| doc.to_string())
                    .collect::<Vec<_>>()
                    .join("");
                let formatted_content = format!("<attachments>\n{}</attachments>", attachments);
                let mut new_content = vec![UserContent::text(formatted_content)];
                new_content.extend(content.clone());
                *content = OneOrMany::many(new_content).expect("This has more than 1 item");
            }
        }
        new_prompt
    }
}

/// Builder struct for constructing a completion request.
///
/// Example usage:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it separately
/// let request = CompletionRequestBuilder::new(model.clone(), "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .build();
///
/// let response = model.completion(request)
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Alternatively, you can execute the completion request directly from the builder:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it directly
/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .send()
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Note: It is usually unnecessary to create a completion request builder directly.
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    model: M,
    prompt: Message,
    preamble: Option<String>,
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<serde_json::Value>,
}

impl<M: CompletionModel> CompletionRequestBuilder<M> {
    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
        Self {
            model,
            prompt: prompt.into(),
            preamble: None,
            chat_history: Vec::new(),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        }
    }

    /// Sets the preamble for the completion request.
    pub fn preamble(mut self, preamble: String) -> Self {
        self.preamble = Some(preamble);
        self
    }

    /// Adds a message to the chat history for the completion request.
    pub fn message(mut self, message: Message) -> Self {
        self.chat_history.push(message);
        self
    }

    /// Adds a list of messages to the chat history for the completion request.
    pub fn messages(self, messages: Vec<Message>) -> Self {
        messages
            .into_iter()
            .fold(self, |builder, msg| builder.message(msg))
    }

    /// Adds a document to the completion request.
    pub fn document(mut self, document: Document) -> Self {
        self.documents.push(document);
        self
    }

    /// Adds a list of documents to the completion request.
    pub fn documents(self, documents: Vec<Document>) -> Self {
        documents
            .into_iter()
            .fold(self, |builder, doc| builder.document(doc))
    }

    /// Adds a tool to the completion request.
    pub fn tool(mut self, tool: ToolDefinition) -> Self {
        self.tools.push(tool);
        self
    }

    /// Adds a list of tools to the completion request.
    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
        tools
            .into_iter()
            .fold(self, |builder, tool| builder.tool(tool))
    }

    /// Adds additional parameters to the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
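    ///
    /// Repeated calls merge the given JSON objects. Example (a sketch; assumes an
    /// existing `builder` and a provider that understands a `connectors` parameter):
    /// ```rust
    /// let request = builder
    ///     .additional_params(serde_json::json!({ "connectors": [{ "id": "web-search" }] }))
    ///     .build();
    /// ```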
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        match self.additional_params {
            Some(params) => {
                self.additional_params = Some(json_utils::merge(params, additional_params));
            }
            None => {
                self.additional_params = Some(additional_params);
            }
        }
        self
    }

    /// Sets the additional parameters for the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
        self.temperature = temperature;
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required if using Anthropic
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required if using Anthropic
    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Builds the completion request.
    pub fn build(self) -> CompletionRequest {
        CompletionRequest {
            prompt: self.prompt,
            preamble: self.preamble,
            chat_history: self.chat_history,
            documents: self.documents,
            tools: self.tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
        }
    }

    /// Sends the completion request to the completion model provider and returns the completion response.
    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
        let model = self.model.clone();
        model.completion(self.build()).await
    }
}

impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
    /// Stream the completion request
    pub async fn stream(self) -> Result<StreamingResult, CompletionError> {
        let model = self.model.clone();
        model.stream(self.build()).await
    }
}

#[cfg(test)]
mod tests {
    use crate::OneOrMany;

    use super::*;

    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{}", doc), expected);
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{}", doc), expected);
    }

    #[test]
    fn test_prompt_with_context_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            prompt: "What is the capital of France?".into(),
            preamble: None,
            chat_history: Vec::new(),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::text(concat!(
                    "<attachments>\n",
                    "<file id: doc1>\nDocument 1 text.\n</file>\n",
                    "<file id: doc2>\nDocument 2 text.\n</file>\n",
                    "</attachments>"
                )),
                UserContent::text("What is the capital of France?"),
            ])
            .expect("This has more than 1 item"),
        };

        assert_eq!(request.prompt_with_context(), expected);
    }
}