rig/completion/request.rs

//! This module provides functionality for working with completion models.
//! It provides traits, structs, and enums for generating completion requests,
//! handling completion responses, and defining completion models.
//!
//! The main traits defined in this module are:
//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
//! - [Chat]: Defines a high-level LLM chat interface with chat history.
//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
//! - [CompletionModel]: Defines a completion model that can be used to generate completion
//!   responses from requests.
//!
//! The [Prompt] and [Chat] traits are high-level traits that users are expected to use
//! to interact with LLM models. Moreover, it is good practice to implement one of these
//! traits for composite agents that use multiple LLM models to generate responses.
//!
//! The [Completion] trait defines a lower-level interface that is useful when the user wants
//! to further customize the request before sending it to the completion model provider.
//!
//! The [CompletionModel] trait is meant to act as the interface between providers and
//! the library. It defines the methods that need to be implemented by the user to define
//! a custom base completion model (i.e.: a private or third party LLM provider).
//!
//! The module also provides various structs and enums for representing generic completion requests,
//! responses, and errors.
//!
//! Example Usage:
//! ```rust
//! use rig::providers::openai::{Client, self};
//! use rig::completion::*;
//!
//! // Initialize the OpenAI client and a completion model
//! let openai = Client::new("your-openai-api-key");
//!
//! let gpt_4 = openai.completion_model(openai::GPT_4);
//!
//! // Create the completion request
//! let request = gpt_4.completion_request("Who are you?")
//!     .preamble("\
//!         You are Marvin, an extremely smart but depressed robot who is \
//!         nonetheless helpful towards humanity.\
//!     ")
//!     .temperature(0.5)
//!     .build();
//!
//! // Send the completion request and get the completion response
//! let response = gpt_4.completion(request)
//!     .await
//!     .expect("Failed to get completion response");
//!
//! // Handle the completion response
//! match response.choice.first() {
//!     AssistantContent::Text(text) => {
//!         // Handle the completion response as a text message
//!         println!("Received message: {}", text.text);
//!     }
//!     AssistantContent::ToolCall(tool_call) => {
//!         // Handle the completion response as a tool call
//!         println!(
//!             "Received tool call: {} {:?}",
//!             tool_call.function.name, tool_call.function.arguments
//!         );
//!     }
//! }
//! ```
//!
//! For more information on how to use the completion functionality, refer to the documentation of
//! the individual traits, structs, and enums defined in this module.

use super::message::{AssistantContent, ContentFormat, DocumentMediaType};
use crate::client::completion::CompletionModelHandle;
use crate::streaming::StreamingCompletionResponse;
use crate::{OneOrMany, streaming};
use crate::{
    json_utils,
    message::{Message, UserContent},
    tool::ToolSetError,
};
use futures::future::BoxFuture;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;

// Errors
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] reqwest::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}

/// Prompt errors
#[derive(Debug, Error)]
pub enum PromptError {
    /// Something went wrong with the completion
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// There was an error while using a tool
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// The LLM tried to call too many tools during a multi-turn conversation.
    /// To fix this, you may either need to reduce the number of tools your model has access
    /// to (and create other agents to share the tool load) or increase the number of turns
    /// given in `.multi_turn()`.
    #[error("MaxDepthError: (reached limit: {max_depth})")]
    MaxDepthError {
        max_depth: usize,
        chat_history: Vec<Message>,
        prompt: Message,
    },
}

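/// A document that can be provided to a completion model as context. When rendered into the
/// prompt, it is formatted as an XML-like `<file>` block (see the [`std::fmt::Display`] impl below).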
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    pub id: String,
    pub text: String,
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}

impl std::fmt::Display for Document {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
            self.id,
            if self.additional_props.is_empty() {
                self.text.clone()
            } else {
                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
                let metadata = sorted_props
                    .iter()
                    .map(|(k, v)| format!("{k}: {v:?}"))
                    .collect::<Vec<_>>()
                    .join(" ");
                format!("<metadata {} />\n{}", metadata, self.text)
            }
        )
    }
}

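/// Definition of a tool (function) that the completion model may call, including a JSON schema
/// describing its parameters.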
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

// ================================================================
// Implementations
// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: Send + Sync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
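    ///
    /// Example (a minimal sketch; assumes `agent` is a type that implements [Prompt],
    /// such as an [`Agent`](crate::agent::Agent)):
    /// ```rust,ignore
    /// let answer = agent.prompt("What is the capital of France?").await?;
    /// println!("{answer}");
    /// ```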
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}

/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: Send + Sync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
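    ///
    /// Example (a minimal sketch; assumes `agent` implements [Chat] and that the
    /// `Message::user`/`Message::assistant` constructors are available):
    /// ```rust,ignore
    /// let history = vec![
    ///     Message::user("My name is Arthur."),
    ///     Message::assistant("Nice to meet you, Arthur!"),
    /// ];
    /// let answer = agent.chat("What is my name?", history).await?;
    /// ```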
    fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}

/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
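    ///
    /// Example (a minimal sketch; assumes `agent` is an [`Agent`](crate::agent::Agent)
    /// and `chat_history` already exists):
    /// ```rust,ignore
    /// let response = agent
    ///     .completion("Summarize our conversation so far.", chat_history)
    ///     .await?
    ///     .temperature(0.2) // overrides any temperature the agent may have set
    ///     .send()
    ///     .await?;
    /// ```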
    fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>> + Send;
}

/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more pieces of assistant content.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message content)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
}

/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third party provider (e.g.: OpenAI) or a local model.
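///
/// Example (a skeleton sketch of a custom provider implementation; the method bodies are
/// left unimplemented and `MyProviderResponse` is a hypothetical type):
/// ```rust,ignore
/// #[derive(Clone)]
/// struct MyModel { /* http client, model name, etc. */ }
///
/// impl CompletionModel for MyModel {
///     type Response = MyProviderResponse;
///     type StreamingResponse = ();
///
///     async fn completion(
///         &self,
///         request: CompletionRequest,
///     ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
///         // Translate `request` into your provider's API call and map the
///         // result back into a `CompletionResponse`.
///         todo!()
///     }
///
///     async fn stream(
///         &self,
///         request: CompletionRequest,
///     ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
///         todo!()
///     }
/// }
/// ```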
pub trait CompletionModel: Clone + Send + Sync {
    /// The raw response type returned by the underlying completion model.
    type Response: Send + Sync;
    /// The raw response type returned by the underlying completion model when streaming.
    type StreamingResponse: Clone + Unpin + Send + Sync;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + Send;

    /// Generates a streaming completion response for the given completion request.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + Send;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}

/// Object-safe counterpart of [CompletionModel] that returns boxed futures and
/// type-erased responses.
pub trait CompletionModelDyn: Send + Sync {
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<'_, Result<CompletionResponse<()>, CompletionError>>;

    fn stream(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<Result<StreamingCompletionResponse<()>, CompletionError>>;

    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
}

impl<T, R> CompletionModelDyn for T
where
    T: CompletionModel<StreamingResponse = R>,
    R: Clone + Unpin + 'static,
{
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<Result<CompletionResponse<()>, CompletionError>> {
        Box::pin(async move {
            self.completion(request)
                .await
                .map(|resp| CompletionResponse {
                    choice: resp.choice,
                    raw_response: (),
                })
        })
    }

    fn stream(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<Result<StreamingCompletionResponse<()>, CompletionError>> {
        Box::pin(async move {
            let resp = self.stream(request).await?;
            let inner = resp.inner;

            let stream = Box::pin(streaming::StreamingResultDyn {
                inner: Box::pin(inner),
            });

            Ok(StreamingCompletionResponse::stream(stream))
        })
    }

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
        CompletionRequestBuilder::new(
            CompletionModelHandle {
                inner: Arc::new(self.clone()),
            },
            prompt,
        )
    }
}

/// Struct representing a general completion request that can be sent to a completion model provider.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider.
    /// The very last message will always be the prompt (hence there is *always* at least one message).
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
}

impl CompletionRequest {
    /// Returns the documents normalized into a single user [`Message`] (if there are any).
    /// Most providers do not accept documents directly as input, so they need to be converted
    /// into a `Message` that can be incorporated into the `chat_history`.
    pub fn normalized_documents(&self) -> Option<Message> {
        if self.documents.is_empty() {
            return None;
        }

        // Most providers will convert documents into plain text unless they can handle document
        // messages directly. We use `UserContent::document` for those that do!
        let messages = self
            .documents
            .iter()
            .map(|doc| {
                UserContent::document(
                    doc.to_string(),
                    // In the future, we can customize `Document` to pass these extra types through.
                    // Most providers ditch these but they might want to use them.
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                )
            })
            .collect::<Vec<_>>();

        Some(Message::User {
            content: OneOrMany::many(messages).expect("There will be at least one document"),
        })
    }
}

/// Builder struct for constructing a completion request.
///
/// Example usage:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it separately
/// let request = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .build();
///
/// let response = model.completion(request)
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Alternatively, you can execute the completion request directly from the builder:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it directly
/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .send()
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Note: It is usually unnecessary to create a completion request builder directly.
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    model: M,
    prompt: Message,
    preamble: Option<String>,
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<serde_json::Value>,
}

impl<M: CompletionModel> CompletionRequestBuilder<M> {
    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
        Self {
            model,
            prompt: prompt.into(),
            preamble: None,
            chat_history: Vec::new(),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        }
    }

    /// Sets the preamble for the completion request.
    pub fn preamble(mut self, preamble: String) -> Self {
        self.preamble = Some(preamble);
        self
    }

    /// Adds a message to the chat history for the completion request.
    pub fn message(mut self, message: Message) -> Self {
        self.chat_history.push(message);
        self
    }

    /// Adds a list of messages to the chat history for the completion request.
    pub fn messages(self, messages: Vec<Message>) -> Self {
        messages
            .into_iter()
            .fold(self, |builder, msg| builder.message(msg))
    }

    /// Adds a document to the completion request.
    pub fn document(mut self, document: Document) -> Self {
        self.documents.push(document);
        self
    }

    /// Adds a list of documents to the completion request.
    pub fn documents(self, documents: Vec<Document>) -> Self {
        documents
            .into_iter()
            .fold(self, |builder, doc| builder.document(doc))
    }

    /// Adds a tool to the completion request.
    pub fn tool(mut self, tool: ToolDefinition) -> Self {
        self.tools.push(tool);
        self
    }

    /// Adds a list of tools to the completion request.
    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
        tools
            .into_iter()
            .fold(self, |builder, tool| builder.tool(tool))
    }

    /// Adds additional parameters to the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
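    ///
    /// Note that repeated calls to this method are merged together rather than overwriting
    /// each other. For example (a sketch; the parameter names are illustrative):
    /// ```rust,ignore
    /// let builder = builder
    ///     .additional_params(serde_json::json!({ "connectors": [{ "id": "web-search" }] }))
    ///     .additional_params(serde_json::json!({ "citation_quality": "accurate" }));
    /// // The final request carries both `connectors` and `citation_quality`.
    /// ```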
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        match self.additional_params {
            Some(params) => {
                self.additional_params = Some(json_utils::merge(params, additional_params));
            }
            None => {
                self.additional_params = Some(additional_params);
            }
        }
        self
    }

    /// Sets the additional parameters for the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
        self.temperature = temperature;
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required if using Anthropic
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required if using Anthropic
    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Builds the completion request.
    pub fn build(self) -> CompletionRequest {
        let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
            .expect("There will always be at least the prompt");

        CompletionRequest {
            preamble: self.preamble,
            chat_history,
            documents: self.documents,
            tools: self.tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
        }
    }

    /// Sends the completion request to the completion model provider and returns the completion response.
    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
        let model = self.model.clone();
        model.completion(self.build()).await
    }

    /// Stream the completion request
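    ///
    /// Example (a sketch; assumes a `model` implementing [CompletionModel] and
    /// `futures::StreamExt` in scope to poll the returned stream):
    /// ```rust,ignore
    /// let mut stream = model.completion_request("Tell me a story").stream().await?;
    /// while let Some(chunk) = stream.next().await {
    ///     println!("{:?}", chunk?);
    /// }
    /// ```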
    pub async fn stream<'a>(
        self,
    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
    where
        <M as CompletionModel>::StreamingResponse: 'a,
        Self: 'a,
    {
        let model = self.model.clone();
        model.stream(self.build()).await
    }
}

#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }
}