rig/completion/
request.rs

1//! This module provides functionality for working with completion models.
2//! It provides traits, structs, and enums for generating completion requests,
3//! handling completion responses, and defining completion models.
4//!
5//! The main traits defined in this module are:
6//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
7//! - [Chat]: Defines a high-level LLM chat interface with chat history.
8//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
9//! - [CompletionModel]: Defines a completion model that can be used to generate completion
10//!   responses from requests.
11//!
12//! The [Prompt] and [Chat] traits are high level traits that users are expected to use
13//! to interact with LLM models. Moreover, it is good practice to implement one of these
14//! traits for composite agents that use multiple LLM models to generate responses.
15//!
16//! The [Completion] trait defines a lower level interface that is useful when the user want
17//! to further customize the request before sending it to the completion model provider.
18//!
19//! The [CompletionModel] trait is meant to act as the interface between providers and
20//! the library. It defines the methods that need to be implemented by the user to define
21//! a custom base completion model (i.e.: a private or third party LLM provider).
22//!
23//! The module also provides various structs and enums for representing generic completion requests,
24//! responses, and errors.
25//!
26//! Example Usage:
27//! ```rust
28//! use rig::providers::openai::{Client, self};
29//! use rig::completion::*;
30//!
31//! // Initialize the OpenAI client and a completion model
32//! let openai = Client::new("your-openai-api-key");
33//!
34//! let gpt_4 = openai.completion_model(openai::GPT_4);
35//!
36//! // Create the completion request
37//! let request = gpt_4.completion_request("Who are you?")
38//!     .preamble("\
39//!         You are Marvin, an extremely smart but depressed robot who is \
40//!         nonetheless helpful towards humanity.\
41//!     ")
42//!     .temperature(0.5)
43//!     .build();
44//!
45//! // Send the completion request and get the completion response
46//! let response = gpt_4.completion(request)
47//!     .await
48//!     .expect("Failed to get completion response");
49//!
50//! // Handle the completion response
//! match response.choice {
52//!     ModelChoice::Message(message) => {
53//!         // Handle the completion response as a message
54//!         println!("Received message: {}", message);
55//!     }
56//!     ModelChoice::ToolCall(tool_name, tool_params) => {
57//!         // Handle the completion response as a tool call
58//!         println!("Received tool call: {} {:?}", tool_name, tool_params);
59//!     }
60//! }
61//! ```
62//!
63//! For more information on how to use the completion functionality, refer to the documentation of
64//! the individual traits, structs, and enums defined in this module.
65
66use super::message::{AssistantContent, ContentFormat, DocumentMediaType};
67use crate::client::completion::CompletionModelHandle;
68use crate::streaming::StreamingCompletionResponse;
69use crate::{OneOrMany, streaming};
70use crate::{
71    json_utils,
72    message::{Message, UserContent},
73    tool::ToolSetError,
74};
75use futures::future::BoxFuture;
76use serde::{Deserialize, Serialize};
77use std::collections::HashMap;
78use std::ops::{Add, AddAssign};
79use std::sync::Arc;
80use thiserror::Error;
81
82// Errors
/// Errors that can occur while building, sending, or parsing a completion request.
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] reqwest::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
105
/// Errors that can occur during a high-level prompt or chat interaction
/// (see [Prompt] and [Chat]).
#[derive(Debug, Error)]
pub enum PromptError {
    /// Something went wrong with the completion
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// There was an error while using a tool
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// The LLM tried to call too many tools during a multi-turn conversation.
    /// To fix this, you may either need to lower the amount of tools your model has access to (and then create other agents to share the tool load)
    /// or increase the amount of turns given in `.multi_turn()`.
    #[error("MaxDepthError: (reached limit: {max_depth})")]
    MaxDepthError {
        /// The turn limit that was reached.
        max_depth: usize,
        // Carried so callers can inspect/resume the conversation state at the
        // point the limit was hit (presumably; confirm against the producer of
        // this error).
        chat_history: Vec<Message>,
        /// The prompt that was being processed when the limit was hit.
        prompt: Message,
    },
}
127
/// A text document that can be attached to a completion request as context.
/// Rendered into an XML-like `<file>` block by its [`std::fmt::Display`] impl.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    // Identifier rendered into the `<file id: …>` tag by `Display`.
    pub id: String,
    // The document body text.
    pub text: String,
    // Arbitrary key/value metadata; flattened into the top level when (de)serialized.
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}
135
136impl std::fmt::Display for Document {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        write!(
139            f,
140            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
141            self.id,
142            if self.additional_props.is_empty() {
143                self.text.clone()
144            } else {
145                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
146                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
147                let metadata = sorted_props
148                    .iter()
149                    .map(|(k, v)| format!("{k}: {v:?}"))
150                    .collect::<Vec<_>>()
151                    .join(" ");
152                format!("<metadata {} />\n{}", metadata, self.text)
153            }
154        )
155    }
156}
157
/// Definition of a tool that can be offered to a completion model,
/// typically serialized into the provider's tool/function-calling schema.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    // Name the model uses to reference the tool.
    pub name: String,
    // Natural-language description of what the tool does.
    pub description: String,
    // JSON value describing the tool's parameters (typically a JSON Schema).
    pub parameters: serde_json::Value,
}
164
165// ================================================================
166// Implementations
167// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: Send + Sync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    ///
    /// The return value implements [`IntoFuture`](std::future::IntoFuture),
    /// so it can be `.await`ed directly.
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}
183
/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: Send + Sync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    ///
    /// The return value implements [`IntoFuture`](std::future::IntoFuture),
    /// so it can be `.await`ed directly.
    fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}
200
/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
    ///
    /// The resulting [CompletionRequestBuilder] can then be finalized with
    /// `.send()` / `.stream()` (or `.build()` for manual dispatch).
    fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>> + Send;
}
220
/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more assistant content.
///
/// `T` is the provider-specific raw response type.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message content)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// Tokens used during prompting and responding
    pub usage: Usage,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
}
233
/// Struct representing the token usage for a completion request.
/// If tokens used are `0`, then the provider failed to supply token usage metrics.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct Usage {
    /// Tokens consumed by the prompt (input).
    pub input_tokens: u64,
    /// Tokens produced in the response (output).
    pub output_tokens: u64,
    // We store this separately as some providers may only report one number
    pub total_tokens: u64,
}

impl Usage {
    /// Creates a zeroed usage record. `const` so it can also be used in
    /// constant contexts; equivalent to [`Usage::default`].
    pub const fn new() -> Self {
        Self {
            input_tokens: 0,
            output_tokens: 0,
            total_tokens: 0,
        }
    }
}

impl Default for Usage {
    fn default() -> Self {
        Self::new()
    }
}

// Field-wise addition, useful for accumulating usage across multiple requests.
impl Add for Usage {
    type Output = Self;

    fn add(self, other: Self) -> Self::Output {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            total_tokens: self.total_tokens + other.total_tokens,
        }
    }
}

// In-place counterpart of `Add`.
impl AddAssign for Usage {
    fn add_assign(&mut self, other: Self) {
        self.input_tokens += other.input_tokens;
        self.output_tokens += other.output_tokens;
        self.total_tokens += other.total_tokens;
    }
}
279
/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third party provider (e.g.: OpenAI) or a local model.
pub trait CompletionModel: Clone + Send + Sync {
    /// The raw response type returned by the underlying completion model.
    type Response: Send + Sync;
    /// The raw response type returned by the underlying completion model when streaming.
    type StreamingResponse: Clone + Unpin + Send + Sync;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + Send;

    /// Generates a streaming completion response for the given completion request.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + Send;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}
/// Dyn-compatible (object-safe) counterpart of [CompletionModel].
///
/// Response payloads are type-erased (`CompletionResponse<()>` /
/// `StreamingCompletionResponse<()>`) and futures are boxed, so implementors
/// can be used behind trait objects such as [CompletionModelHandle].
pub trait CompletionModelDyn: Send + Sync {
    /// Generates a completion response, discarding the provider-specific raw response.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<'_, Result<CompletionResponse<()>, CompletionError>>;

    /// Generates a streaming completion response with an erased streaming type.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<Result<StreamingCompletionResponse<()>, CompletionError>>;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
}
325
// Blanket impl: every typed `CompletionModel` is usable as a type-erased
// `CompletionModelDyn` by dropping the provider-specific response payloads.
impl<T, R> CompletionModelDyn for T
where
    T: CompletionModel<StreamingResponse = R>,
    R: Clone + Unpin + 'static,
{
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<Result<CompletionResponse<()>, CompletionError>> {
        Box::pin(async move {
            self.completion(request)
                .await
                .map(|resp| CompletionResponse {
                    choice: resp.choice,
                    usage: resp.usage,
                    // Erase the provider-specific raw response.
                    raw_response: (),
                })
        })
    }

    fn stream(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<Result<StreamingCompletionResponse<()>, CompletionError>> {
        Box::pin(async move {
            let resp = self.stream(request).await?;
            let inner = resp.inner;

            // Re-wrap the typed stream in the dynamic adapter so the item
            // type no longer mentions `Self::StreamingResponse`.
            let stream = Box::pin(streaming::StreamingResultDyn {
                inner: Box::pin(inner),
            });

            Ok(StreamingCompletionResponse::stream(stream))
        })
    }

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
        CompletionRequestBuilder::new(
            // Clone the model into an Arc-backed handle so the builder owns it.
            CompletionModelHandle {
                inner: Arc::new(self.clone()),
            },
            prompt,
        )
    }
}
375
/// Struct representing a general completion request that can be sent to a completion model provider.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider.
    /// The very last message will always be the prompt (hence why there is *always* one)
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
}
395
396impl CompletionRequest {
397    /// Returns documents normalized into a message (if any).
398    /// Most providers do not accept documents directly as input, so it needs to convert into a
399    ///  `Message` so that it can be incorporated into `chat_history` as a
400    pub fn normalized_documents(&self) -> Option<Message> {
401        if self.documents.is_empty() {
402            return None;
403        }
404
405        // Most providers will convert documents into a text unless it can handle document messages.
406        // We use `UserContent::document` for those who handle it directly!
407        let messages = self
408            .documents
409            .iter()
410            .map(|doc| {
411                UserContent::document(
412                    doc.to_string(),
413                    // In the future, we can customize `Document` to pass these extra types through.
414                    // Most providers ditch these but they might want to use them.
415                    Some(ContentFormat::String),
416                    Some(DocumentMediaType::TXT),
417                )
418            })
419            .collect::<Vec<_>>();
420
421        Some(Message::User {
422            content: OneOrMany::many(messages).expect("There will be atleast one document"),
423        })
424    }
425}
426
427/// Builder struct for constructing a completion request.
428///
429/// Example usage:
430/// ```rust
431/// use rig::{
432///     providers::openai::{Client, self},
433///     completion::CompletionRequestBuilder,
434/// };
435///
436/// let openai = Client::new("your-openai-api-key");
437/// let model = openai.completion_model(openai::GPT_4O).build();
438///
439/// // Create the completion request and execute it separately
440/// let request = CompletionRequestBuilder::new(model, "Who are you?".to_string())
441///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
442///     .temperature(0.5)
443///     .build();
444///
445/// let response = model.completion(request)
446///     .await
447///     .expect("Failed to get completion response");
448/// ```
449///
450/// Alternatively, you can execute the completion request directly from the builder:
451/// ```rust
452/// use rig::{
453///     providers::openai::{Client, self},
454///     completion::CompletionRequestBuilder,
455/// };
456///
457/// let openai = Client::new("your-openai-api-key");
458/// let model = openai.completion_model(openai::GPT_4O).build();
459///
460/// // Create the completion request and execute it directly
461/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
462///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
463///     .temperature(0.5)
464///     .send()
465///     .await
466///     .expect("Failed to get completion response");
467/// ```
468///
469/// Note: It is usually unnecessary to create a completion request builder directly.
470/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    // Model the request is sent to by `send`/`stream`.
    model: M,
    // The prompt; appended as the final chat-history message at `build` time.
    prompt: Message,
    // Optional system preamble.
    preamble: Option<String>,
    // Prior messages; the prompt is appended after these by `build`.
    chat_history: Vec<Message>,
    // Documents to attach to the request.
    documents: Vec<Document>,
    // Tool definitions offered to the model.
    tools: Vec<ToolDefinition>,
    // Optional sampling temperature.
    temperature: Option<f64>,
    // Optional maximum output tokens (required by some providers, e.g. Anthropic).
    max_tokens: Option<u64>,
    // Provider-specific extra parameters (merged JSON).
    additional_params: Option<serde_json::Value>,
}
482
483impl<M: CompletionModel> CompletionRequestBuilder<M> {
484    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
485        Self {
486            model,
487            prompt: prompt.into(),
488            preamble: None,
489            chat_history: Vec::new(),
490            documents: Vec::new(),
491            tools: Vec::new(),
492            temperature: None,
493            max_tokens: None,
494            additional_params: None,
495        }
496    }
497
498    /// Sets the preamble for the completion request.
499    pub fn preamble(mut self, preamble: String) -> Self {
500        self.preamble = Some(preamble);
501        self
502    }
503
504    /// Adds a message to the chat history for the completion request.
505    pub fn message(mut self, message: Message) -> Self {
506        self.chat_history.push(message);
507        self
508    }
509
510    /// Adds a list of messages to the chat history for the completion request.
511    pub fn messages(self, messages: Vec<Message>) -> Self {
512        messages
513            .into_iter()
514            .fold(self, |builder, msg| builder.message(msg))
515    }
516
517    /// Adds a document to the completion request.
518    pub fn document(mut self, document: Document) -> Self {
519        self.documents.push(document);
520        self
521    }
522
523    /// Adds a list of documents to the completion request.
524    pub fn documents(self, documents: Vec<Document>) -> Self {
525        documents
526            .into_iter()
527            .fold(self, |builder, doc| builder.document(doc))
528    }
529
530    /// Adds a tool to the completion request.
531    pub fn tool(mut self, tool: ToolDefinition) -> Self {
532        self.tools.push(tool);
533        self
534    }
535
536    /// Adds a list of tools to the completion request.
537    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
538        tools
539            .into_iter()
540            .fold(self, |builder, tool| builder.tool(tool))
541    }
542
543    /// Adds additional parameters to the completion request.
544    /// This can be used to set additional provider-specific parameters. For example,
545    /// Cohere's completion models accept a `connectors` parameter that can be used to
546    /// specify the data connectors used by Cohere when executing the completion
547    /// (see `examples/cohere_connectors.rs`).
548    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
549        match self.additional_params {
550            Some(params) => {
551                self.additional_params = Some(json_utils::merge(params, additional_params));
552            }
553            None => {
554                self.additional_params = Some(additional_params);
555            }
556        }
557        self
558    }
559
560    /// Sets the additional parameters for the completion request.
561    /// This can be used to set additional provider-specific parameters. For example,
562    /// Cohere's completion models accept a `connectors` parameter that can be used to
563    /// specify the data connectors used by Cohere when executing the completion
564    /// (see `examples/cohere_connectors.rs`).
565    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
566        self.additional_params = additional_params;
567        self
568    }
569
570    /// Sets the temperature for the completion request.
571    pub fn temperature(mut self, temperature: f64) -> Self {
572        self.temperature = Some(temperature);
573        self
574    }
575
576    /// Sets the temperature for the completion request.
577    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
578        self.temperature = temperature;
579        self
580    }
581
582    /// Sets the max tokens for the completion request.
583    /// Note: This is required if using Anthropic
584    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
585        self.max_tokens = Some(max_tokens);
586        self
587    }
588
589    /// Sets the max tokens for the completion request.
590    /// Note: This is required if using Anthropic
591    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
592        self.max_tokens = max_tokens;
593        self
594    }
595
596    /// Builds the completion request.
597    pub fn build(self) -> CompletionRequest {
598        let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
599            .expect("There will always be atleast the prompt");
600
601        CompletionRequest {
602            preamble: self.preamble,
603            chat_history,
604            documents: self.documents,
605            tools: self.tools,
606            temperature: self.temperature,
607            max_tokens: self.max_tokens,
608            additional_params: self.additional_params,
609        }
610    }
611
612    /// Sends the completion request to the completion model provider and returns the completion response.
613    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
614        let model = self.model.clone();
615        model.completion(self.build()).await
616    }
617
618    /// Stream the completion request
619    pub async fn stream<'a>(
620        self,
621    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
622    where
623        <M as CompletionModel>::StreamingResponse: 'a,
624        Self: 'a,
625    {
626        let model = self.model.clone();
627        model.stream(self.build()).await
628    }
629}
630
#[cfg(test)]
mod tests {

    use super::*;

    // `Display` with no metadata: just the `<file>` wrapper around the text.
    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{doc}"), expected);
    }

    // With metadata: a `<metadata … />` line (props sorted by key, values
    // Debug-quoted) precedes the text.
    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{doc}"), expected);
    }

    // All attached documents collapse into one user message, with one content
    // part per document.
    #[test]
    fn test_normalize_documents_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    // No documents attached -> no normalized message.
    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }
}