rig/completion/request.rs

//! This module provides functionality for working with completion models.
//! It provides traits, structs, and enums for generating completion requests,
//! handling completion responses, and defining completion models.
//!
//! The main traits defined in this module are:
//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
//! - [Chat]: Defines a high-level LLM chat interface with chat history.
//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
//! - [CompletionModel]: Defines a completion model that can be used to generate completion
//!   responses from requests.
//!
//! The [Prompt] and [Chat] traits are high-level traits that users are expected to use
//! to interact with LLM models. Moreover, it is good practice to implement one of these
//! traits for composite agents that use multiple LLM models to generate responses.
//!
//! The [Completion] trait defines a lower-level interface that is useful when the user wants
//! to further customize the request before sending it to the completion model provider.
//!
//! The [CompletionModel] trait is meant to act as the interface between providers and
//! the library. It defines the methods that need to be implemented by the user to define
//! a custom base completion model (i.e.: a private or third-party LLM provider).
//!
//! The module also provides various structs and enums for representing generic completion requests,
//! responses, and errors.
//!
//! Example Usage:
//! ```rust
//! use rig::providers::openai::{Client, self};
//! use rig::completion::*;
//!
//! // Initialize the OpenAI client and a completion model
//! let openai = Client::new("your-openai-api-key");
//!
//! let gpt_4 = openai.completion_model(openai::GPT_4);
//!
//! // Create the completion request
//! let request = gpt_4.completion_request("Who are you?")
//!     .preamble("\
//!         You are Marvin, an extremely smart but depressed robot who is \
//!         nonetheless helpful towards humanity.\
//!     ")
//!     .temperature(0.5)
//!     .build();
//!
//! // Send the completion request and get the completion response
//! let response = gpt_4.completion(request)
//!     .await
//!     .expect("Failed to get completion response");
//!
//! // Handle the completion response
//! match response.choice.first() {
//!     AssistantContent::Text(text) => {
//!         // Handle the completion response as a text message
//!         println!("Received message: {}", text.text);
//!     }
//!     AssistantContent::ToolCall(tool_call) => {
//!         // Handle the completion response as a tool call
//!         println!("Received tool call: {:?}", tool_call);
//!     }
//! }
//! ```
//!
//! For more information on how to use the completion functionality, refer to the documentation of
//! the individual traits, structs, and enums defined in this module.

use super::message::{AssistantContent, ContentFormat, DocumentMediaType};
use crate::client::completion::CompletionModelHandle;
use crate::streaming::StreamingCompletionResponse;
use crate::{OneOrMany, streaming};
use crate::{
    json_utils,
    message::{Message, UserContent},
    tool::ToolSetError,
};
use futures::future::BoxFuture;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::ops::{Add, AddAssign};
use std::sync::Arc;
use thiserror::Error;

// Errors
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] reqwest::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}

/// Prompt errors
#[derive(Debug, Error)]
pub enum PromptError {
    /// Something went wrong with the completion
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// There was an error while using a tool
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// The LLM tried to call too many tools during a multi-turn conversation.
    /// To fix this, either reduce the number of tools your model has access to (and create
    /// other agents to share the tool load) or increase the number of turns allowed in `.multi_turn()`.
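    ///
    /// A minimal sketch of recovering the partial conversation when this error
    /// occurs (`agent` is an illustrative value implementing [Prompt]):
    ///
    /// ```rust,ignore
    /// match agent.prompt("Plan my trip").await {
    ///     Err(PromptError::MaxDepthError { chat_history, .. }) => {
    ///         // Inspect or resume from the accumulated history.
    ///         println!("Stopped after {} messages", chat_history.len());
    ///     }
    ///     other => println!("{other:?}"),
    /// }
    /// ```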
121    #[error("MaxDepthError: (reached limit: {max_depth})")]
122    MaxDepthError {
123        max_depth: usize,
124        chat_history: Vec<Message>,
125        prompt: Message,
126    },
127}
128
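/// A text document that can be attached to a completion request as context.
///
/// When rendered via [`std::fmt::Display`] (see the implementation below), the
/// document is wrapped in a `<file id: ..>` tag, and any `additional_props`
/// are emitted as a sorted `<metadata .. />` line. A minimal construction sketch:
///
/// ```rust
/// use std::collections::HashMap;
/// use rig::completion::Document;
///
/// let doc = Document {
///     id: "doc1".to_string(),
///     text: "Rig is a Rust library for building LLM applications.".to_string(),
///     additional_props: HashMap::new(),
/// };
/// assert!(doc.to_string().starts_with("<file id: doc1>"));
/// ```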
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    pub id: String,
    pub text: String,
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}

impl std::fmt::Display for Document {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
            self.id,
            if self.additional_props.is_empty() {
                self.text.clone()
            } else {
                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
                let metadata = sorted_props
                    .iter()
                    .map(|(k, v)| format!("{k}: {v:?}"))
                    .collect::<Vec<_>>()
                    .join(" ");
                format!("<metadata {} />\n{}", metadata, self.text)
            }
        )
    }
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

// ================================================================
// Implementations
// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: Send + Sync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
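    ///
    /// A short usage sketch (`agent` stands in for any [Prompt] implementor,
    /// e.g. an agent built with [`crate::agent::AgentBuilder`]):
    ///
    /// ```rust,ignore
    /// let answer = agent.prompt("What is the capital of France?").await?;
    /// println!("{answer}");
    /// ```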
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}

/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: Send + Sync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
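    ///
    /// A short usage sketch (`agent` stands in for any [Chat] implementor; the
    /// [Message] helper constructors are assumed for building the history):
    ///
    /// ```rust,ignore
    /// let history = vec![
    ///     Message::user("Remember that my favorite color is blue."),
    ///     Message::assistant("Noted!"),
    /// ];
    /// let reply = agent.chat("What is my favorite color?", history).await?;
    /// ```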
    fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}

/// Trait defining a low-level LLM completion interface.
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
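    ///
    /// A hedged sketch of customizing a request at prompt time (`agent` is an
    /// illustrative [`crate::agent::Agent`]):
    ///
    /// ```rust,ignore
    /// let response = agent
    ///     .completion("Summarize our conversation so far", vec![])
    ///     .await?
    ///     .temperature(0.2) // overrides any temperature the agent already set
    ///     .send()
    ///     .await?;
    /// ```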
    fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>> + Send;
}

/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more pieces of assistant content.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message content)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// Tokens used during prompting and responding
    pub usage: Usage,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
}

/// Struct representing the token usage for a completion request.
/// If the token counts are `0`, the provider failed to supply token usage metrics.
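///
/// `Usage` implements [`Add`] and [`AddAssign`] (see below), so usage can be
/// accumulated across multiple responses. A small sketch:
///
/// ```rust
/// use rig::completion::Usage;
///
/// let mut total = Usage::new();
/// total += Usage { input_tokens: 10, output_tokens: 5, total_tokens: 15 };
/// total += Usage { input_tokens: 20, output_tokens: 7, total_tokens: 27 };
/// assert_eq!(total.total_tokens, 42);
/// ```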
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct Usage {
    pub input_tokens: u64,
    pub output_tokens: u64,
    // We store this separately as some providers may only report one number
    pub total_tokens: u64,
}

impl Usage {
    pub fn new() -> Self {
        Self {
            input_tokens: 0,
            output_tokens: 0,
            total_tokens: 0,
        }
    }
}

impl Default for Usage {
    fn default() -> Self {
        Self::new()
    }
}

impl Add for Usage {
    type Output = Self;

    fn add(self, other: Self) -> Self::Output {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            total_tokens: self.total_tokens + other.total_tokens,
        }
    }
}

impl AddAssign for Usage {
    fn add_assign(&mut self, other: Self) {
        self.input_tokens += other.input_tokens;
        self.output_tokens += other.output_tokens;
        self.total_tokens += other.total_tokens;
    }
}

/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third-party provider (e.g.: OpenAI) or a local model.
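///
/// A skeletal sketch of a custom provider implementation (all names are
/// illustrative; the wire-format translation and HTTP calls are elided):
///
/// ```rust,ignore
/// #[derive(Clone)]
/// struct MyProviderModel { /* http client, model name, ... */ }
///
/// impl CompletionModel for MyProviderModel {
///     type Response = serde_json::Value;
///     type StreamingResponse = serde_json::Value;
///
///     async fn completion(
///         &self,
///         request: CompletionRequest,
///     ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
///         // Translate `request` into the provider's wire format, send it,
///         // then map the provider's reply back into a `CompletionResponse`.
///         todo!()
///     }
///
///     async fn stream(
///         &self,
///         request: CompletionRequest,
///     ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
///         todo!()
///     }
/// }
/// ```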
pub trait CompletionModel: Clone + Send + Sync {
    /// The raw response type returned by the underlying completion model.
    type Response: Send + Sync + Serialize + DeserializeOwned;
    /// The raw response type returned by the underlying completion model when streaming.
    type StreamingResponse: Clone + Unpin + Send + Sync + Serialize + DeserializeOwned;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + Send;

    /// Generates a streaming completion response for the given completion request.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + Send;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}

/// Dyn-compatible counterpart to [CompletionModel], used where the concrete
/// model type must be erased (e.g. behind a [CompletionModelHandle]).
pub trait CompletionModelDyn: Send + Sync {
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<'_, Result<CompletionResponse<()>, CompletionError>>;

    fn stream(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<'_, Result<StreamingCompletionResponse<()>, CompletionError>>;

    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
}

327impl<T, R> CompletionModelDyn for T
328where
329    T: CompletionModel<StreamingResponse = R>,
330    R: Clone + Unpin + 'static,
331{
332    fn completion(
333        &self,
334        request: CompletionRequest,
335    ) -> BoxFuture<Result<CompletionResponse<()>, CompletionError>> {
336        Box::pin(async move {
337            self.completion(request)
338                .await
339                .map(|resp| CompletionResponse {
340                    choice: resp.choice,
341                    usage: resp.usage,
342                    raw_response: (),
343                })
344        })
345    }
346
347    fn stream(
348        &self,
349        request: CompletionRequest,
350    ) -> BoxFuture<Result<StreamingCompletionResponse<()>, CompletionError>> {
351        Box::pin(async move {
352            let resp = self.stream(request).await?;
353            let inner = resp.inner;
354
355            let stream = Box::pin(streaming::StreamingResultDyn {
356                inner: Box::pin(inner),
357            });
358
359            Ok(StreamingCompletionResponse::stream(stream))
360        })
361    }
362
363    /// Generates a completion request builder for the given `prompt`.
364    fn completion_request(
365        &self,
366        prompt: Message,
367    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
368        CompletionRequestBuilder::new(
369            CompletionModelHandle {
370                inner: Arc::new(self.clone()),
371            },
372            prompt,
373        )
374    }
375}
376
/// Struct representing a general completion request that can be sent to a completion model provider.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider.
    /// The very last message will always be the prompt (hence there is always at least one message).
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
}

impl CompletionRequest {
    /// Returns the documents normalized into a message (if any).
    /// Most providers do not accept documents directly as input, so they need to be converted
    /// into a `Message` that can be incorporated into the `chat_history`.
    pub fn normalized_documents(&self) -> Option<Message> {
        if self.documents.is_empty() {
            return None;
        }

        // Most providers will convert documents into text unless they can handle document
        // messages directly. We use `UserContent::document` for those that handle it directly!
        let messages = self
            .documents
            .iter()
            .map(|doc| {
                UserContent::document(
                    doc.to_string(),
                    // In the future, we can customize `Document` to pass these extra types through.
                    // Most providers ditch these but they might want to use them.
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                )
            })
            .collect::<Vec<_>>();

        Some(Message::User {
            content: OneOrMany::many(messages).expect("There will be at least one document"),
        })
    }
}

/// Builder struct for constructing a completion request.
///
/// Example usage:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it separately
/// let request = CompletionRequestBuilder::new(model.clone(), "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .build();
///
/// let response = model.completion(request)
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Alternatively, you can execute the completion request directly from the builder:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it directly
/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .send()
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Note: It is usually unnecessary to create a completion request builder directly.
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    model: M,
    prompt: Message,
    preamble: Option<String>,
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<serde_json::Value>,
}

impl<M: CompletionModel> CompletionRequestBuilder<M> {
    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
        Self {
            model,
            prompt: prompt.into(),
            preamble: None,
            chat_history: Vec::new(),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        }
    }

    /// Sets the preamble for the completion request.
    pub fn preamble(mut self, preamble: String) -> Self {
        self.preamble = Some(preamble);
        self
    }

    /// Adds a message to the chat history for the completion request.
    pub fn message(mut self, message: Message) -> Self {
        self.chat_history.push(message);
        self
    }

    /// Adds a list of messages to the chat history for the completion request.
    pub fn messages(self, messages: Vec<Message>) -> Self {
        messages
            .into_iter()
            .fold(self, |builder, msg| builder.message(msg))
    }

    /// Adds a document to the completion request.
    pub fn document(mut self, document: Document) -> Self {
        self.documents.push(document);
        self
    }

    /// Adds a list of documents to the completion request.
    pub fn documents(self, documents: Vec<Document>) -> Self {
        documents
            .into_iter()
            .fold(self, |builder, doc| builder.document(doc))
    }

    /// Adds a tool to the completion request.
    pub fn tool(mut self, tool: ToolDefinition) -> Self {
        self.tools.push(tool);
        self
    }

    /// Adds a list of tools to the completion request.
    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
        tools
            .into_iter()
            .fold(self, |builder, tool| builder.tool(tool))
    }

    /// Adds additional parameters to the completion request.
    /// Repeated calls merge the JSON objects together. This can be used to set additional
    /// provider-specific parameters. For example, Cohere's completion models accept a
    /// `connectors` parameter that can be used to specify the data connectors used by Cohere
    /// when executing the completion (see `examples/cohere_connectors.rs`).
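    ///
    /// A minimal sketch (the `connectors` value is illustrative, and `builder`
    /// is assumed to be a [CompletionRequestBuilder] obtained from a model):
    ///
    /// ```rust,ignore
    /// use serde_json::json;
    ///
    /// let request = builder
    ///     .additional_params(json!({ "connectors": [{ "id": "web-search" }] }))
    ///     .build();
    /// ```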
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        match self.additional_params {
            Some(params) => {
                self.additional_params = Some(json_utils::merge(params, additional_params));
            }
            None => {
                self.additional_params = Some(additional_params);
            }
        }
        self
    }

    /// Sets the additional parameters for the completion request from an optional value,
    /// replacing any previously set parameters.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the temperature for the completion request from an optional value.
    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
        self.temperature = temperature;
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required when using Anthropic.
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the max tokens for the completion request from an optional value.
    /// Note: This is required when using Anthropic.
    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Builds the completion request.
    pub fn build(self) -> CompletionRequest {
        let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
            .expect("There will always be at least the prompt");

        CompletionRequest {
            preamble: self.preamble,
            chat_history,
            documents: self.documents,
            tools: self.tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
        }
    }

    /// Sends the completion request to the completion model provider and returns the completion response.
    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
        let model = self.model.clone();
        model.completion(self.build()).await
    }

    /// Streams the completion request.
    pub async fn stream<'a>(
        self,
    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
    where
        <M as CompletionModel>::StreamingResponse: 'a,
        Self: 'a,
    {
        let model = self.model.clone();
        model.stream(self.build()).await
    }
}

#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }
}