rig/completion/
request.rs

1//! This module provides functionality for working with completion models.
2//! It provides traits, structs, and enums for generating completion requests,
3//! handling completion responses, and defining completion models.
4//!
5//! The main traits defined in this module are:
6//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
7//! - [Chat]: Defines a high-level LLM chat interface with chat history.
8//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
9//! - [CompletionModel]: Defines a completion model that can be used to generate completion
10//!   responses from requests.
11//!
12//! The [Prompt] and [Chat] traits are high level traits that users are expected to use
13//! to interact with LLM models. Moreover, it is good practice to implement one of these
14//! traits for composite agents that use multiple LLM models to generate responses.
15//!
//! The [Completion] trait defines a lower level interface that is useful when the user wants
17//! to further customize the request before sending it to the completion model provider.
18//!
19//! The [CompletionModel] trait is meant to act as the interface between providers and
20//! the library. It defines the methods that need to be implemented by the user to define
21//! a custom base completion model (i.e.: a private or third party LLM provider).
22//!
23//! The module also provides various structs and enums for representing generic completion requests,
24//! responses, and errors.
25//!
26//! Example Usage:
27//! ```rust
28//! use rig::providers::openai::{Client, self};
29//! use rig::completion::*;
30//!
31//! // Initialize the OpenAI client and a completion model
32//! let openai = Client::new("your-openai-api-key");
33//!
34//! let gpt_4 = openai.completion_model(openai::GPT_4);
35//!
36//! // Create the completion request
37//! let request = gpt_4.completion_request("Who are you?")
38//!     .preamble("\
39//!         You are Marvin, an extremely smart but depressed robot who is \
40//!         nonetheless helpful towards humanity.\
41//!     ")
42//!     .temperature(0.5)
43//!     .build();
44//!
45//! // Send the completion request and get the completion response
46//! let response = gpt_4.completion(request)
47//!     .await
48//!     .expect("Failed to get completion response");
49//!
50//! // Handle the completion response
//! match response.choice {
52//!     ModelChoice::Message(message) => {
53//!         // Handle the completion response as a message
54//!         println!("Received message: {}", message);
55//!     }
56//!     ModelChoice::ToolCall(tool_name, tool_params) => {
57//!         // Handle the completion response as a tool call
58//!         println!("Received tool call: {} {:?}", tool_name, tool_params);
59//!     }
60//! }
61//! ```
62//!
63//! For more information on how to use the completion functionality, refer to the documentation of
64//! the individual traits, structs, and enums defined in this module.
65
66use super::message::{AssistantContent, ContentFormat, DocumentMediaType};
67use crate::client::completion::CompletionModelHandle;
68use crate::streaming::StreamingCompletionResponse;
69use crate::{
70    json_utils,
71    message::{Message, UserContent},
72    tool::ToolSetError,
73};
74use crate::{streaming, OneOrMany};
75use futures::future::BoxFuture;
76use serde::{Deserialize, Serialize};
77use std::collections::HashMap;
78use std::sync::Arc;
79use thiserror::Error;
80
// Errors
/// Errors that can occur while sending a completion request or handling
/// the provider's response.
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] reqwest::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
104
/// Errors that can occur during a high-level prompt or chat interaction
/// (see the [Prompt] and [Chat] traits).
#[derive(Debug, Error)]
pub enum PromptError {
    /// Error propagated from the underlying completion model.
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// Error raised while invoking a tool requested by the model.
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// The configured maximum depth was reached before the interaction
    /// completed (presumably the limit on sequential tool-call turns —
    /// see the caller that constructs this variant).
    #[error("MaxDepthError: (reached limit: {max_depth})")]
    MaxDepthError {
        /// The depth limit that was hit.
        max_depth: usize,
        /// The conversation state accumulated up to the point of failure.
        chat_history: Vec<Message>,
        /// The prompt that was in flight when the limit was hit.
        prompt: Message,
    },
}
120
/// A text document that can be attached to a completion request as context.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    /// Identifier of the document, rendered in the `<file id: ...>` header.
    pub id: String,
    /// The document's text content.
    pub text: String,
    /// Any extra (de)serialized fields are captured here and rendered as a
    /// `<metadata ... />` line by the `Display` impl.
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}
128
129impl std::fmt::Display for Document {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        write!(
132            f,
133            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
134            self.id,
135            if self.additional_props.is_empty() {
136                self.text.clone()
137            } else {
138                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
139                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
140                let metadata = sorted_props
141                    .iter()
142                    .map(|(k, v)| format!("{k}: {v:?}"))
143                    .collect::<Vec<_>>()
144                    .join(" ");
145                format!("<metadata {} />\n{}", metadata, self.text)
146            }
147        )
148    }
149}
150
/// Definition of a tool that can be advertised to the completion model provider.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Name the model uses to reference the tool in tool calls.
    pub name: String,
    /// Description of what the tool does, shown to the model.
    pub description: String,
    /// JSON value describing the tool's parameters (typically a JSON schema).
    pub parameters: serde_json::Value,
}
157
158// ================================================================
159// Implementations
160// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: Send + Sync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// The prompt can be anything convertible into a [Message] (e.g. a `&str` or `String`).
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}
176
/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: Send + Sync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// `chat_history` contains the prior messages of the conversation, in order
    /// (pass an empty `Vec` for a fresh conversation).
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}
193
/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
    ///
    /// # Errors
    /// Returns a [CompletionError] if the builder could not be constructed.
    fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>> + Send;
}
213
/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more assistant content.
/// `T` is the provider-specific raw response type.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message content)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
}
224
/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third party provider (e.g.: OpenAI) or a local model.
pub trait CompletionModel: Clone + Send + Sync {
    /// The raw response type returned by the underlying completion model.
    type Response: Send + Sync;
    /// The raw response type returned by the underlying completion model when streaming.
    type StreamingResponse: Clone + Unpin + Send + Sync;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>>
           + Send;

    /// Generates a streaming completion response for the given completion request.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + Send;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}
/// Dyn-compatible (object-safe) counterpart of [CompletionModel], returning boxed
/// futures so a model can be used behind a trait object. Provider-specific
/// response payloads are erased to `()`.
pub trait CompletionModelDyn: Send + Sync {
    /// Generates a completion response, discarding the provider's raw response
    /// (the returned `raw_response` field is `()`).
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<'_, Result<CompletionResponse<()>, CompletionError>>;

    /// Generates a streaming completion response with the provider-specific
    /// streaming payload type erased to `()`.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> BoxFuture<Result<StreamingCompletionResponse<()>, CompletionError>>;

    /// Generates a completion request builder for the given `prompt`, backed by
    /// a handle that type-erases this model.
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
}
269
270impl<T, R> CompletionModelDyn for T
271where
272    T: CompletionModel<StreamingResponse = R>,
273    R: Clone + Unpin + 'static,
274{
275    fn completion(
276        &self,
277        request: CompletionRequest,
278    ) -> BoxFuture<Result<CompletionResponse<()>, CompletionError>> {
279        Box::pin(async move {
280            self.completion(request)
281                .await
282                .map(|resp| CompletionResponse {
283                    choice: resp.choice,
284                    raw_response: (),
285                })
286        })
287    }
288
289    fn stream(
290        &self,
291        request: CompletionRequest,
292    ) -> BoxFuture<Result<StreamingCompletionResponse<()>, CompletionError>> {
293        Box::pin(async move {
294            let resp = self.stream(request).await?;
295            let inner = resp.inner;
296
297            let stream = Box::pin(streaming::StreamingResultDyn { inner });
298
299            Ok(StreamingCompletionResponse::stream(stream))
300        })
301    }
302
303    /// Generates a completion request builder for the given `prompt`.
304    fn completion_request(
305        &self,
306        prompt: Message,
307    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
308        CompletionRequestBuilder::new(
309            CompletionModelHandle {
310                inner: Arc::new(self.clone()),
311            },
312            prompt,
313        )
314    }
315}
316
/// Struct representing a general completion request that can be sent to a completion model provider.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider.
    /// The very last message will always be the prompt (hence why there is *always* one)
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
}
336
337impl CompletionRequest {
338    /// Returns documents normalized into a message (if any).
339    /// Most providers do not accept documents directly as input, so it needs to convert into a
340    ///  `Message` so that it can be incorporated into `chat_history` as a
341    pub fn normalized_documents(&self) -> Option<Message> {
342        if self.documents.is_empty() {
343            return None;
344        }
345
346        // Most providers will convert documents into a text unless it can handle document messages.
347        // We use `UserContent::document` for those who handle it directly!
348        let messages = self
349            .documents
350            .iter()
351            .map(|doc| {
352                UserContent::document(
353                    doc.to_string(),
354                    // In the future, we can customize `Document` to pass these extra types through.
355                    // Most providers ditch these but they might want to use them.
356                    Some(ContentFormat::String),
357                    Some(DocumentMediaType::TXT),
358                )
359            })
360            .collect::<Vec<_>>();
361
362        Some(Message::User {
363            content: OneOrMany::many(messages).expect("There will be atleast one document"),
364        })
365    }
366}
367
/// Builder struct for constructing a completion request.
///
/// Example usage:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O).build();
///
/// // Create the completion request and execute it separately
/// let request = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .build();
///
/// let response = model.completion(request)
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Alternatively, you can execute the completion request directly from the builder:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O).build();
///
/// // Create the completion request and execute it directly
/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .send()
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Note: It is usually unnecessary to create a completion request builder directly.
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    // The model the built request will be sent to (used by `send`/`stream`).
    model: M,
    // The prompt; appended as the final chat history entry by `build`.
    prompt: Message,
    // Optional system preamble.
    preamble: Option<String>,
    // Messages preceding the prompt, in order.
    chat_history: Vec<Message>,
    // Documents to attach as context.
    documents: Vec<Document>,
    // Tools to advertise to the provider.
    tools: Vec<ToolDefinition>,
    // Optional sampling temperature.
    temperature: Option<f64>,
    // Optional max token limit.
    max_tokens: Option<u64>,
    // Provider-specific extra parameters (merged JSON object).
    additional_params: Option<serde_json::Value>,
}
423
424impl<M: CompletionModel> CompletionRequestBuilder<M> {
425    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
426        Self {
427            model,
428            prompt: prompt.into(),
429            preamble: None,
430            chat_history: Vec::new(),
431            documents: Vec::new(),
432            tools: Vec::new(),
433            temperature: None,
434            max_tokens: None,
435            additional_params: None,
436        }
437    }
438
439    /// Sets the preamble for the completion request.
440    pub fn preamble(mut self, preamble: String) -> Self {
441        self.preamble = Some(preamble);
442        self
443    }
444
445    /// Adds a message to the chat history for the completion request.
446    pub fn message(mut self, message: Message) -> Self {
447        self.chat_history.push(message);
448        self
449    }
450
451    /// Adds a list of messages to the chat history for the completion request.
452    pub fn messages(self, messages: Vec<Message>) -> Self {
453        messages
454            .into_iter()
455            .fold(self, |builder, msg| builder.message(msg))
456    }
457
458    /// Adds a document to the completion request.
459    pub fn document(mut self, document: Document) -> Self {
460        self.documents.push(document);
461        self
462    }
463
464    /// Adds a list of documents to the completion request.
465    pub fn documents(self, documents: Vec<Document>) -> Self {
466        documents
467            .into_iter()
468            .fold(self, |builder, doc| builder.document(doc))
469    }
470
471    /// Adds a tool to the completion request.
472    pub fn tool(mut self, tool: ToolDefinition) -> Self {
473        self.tools.push(tool);
474        self
475    }
476
477    /// Adds a list of tools to the completion request.
478    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
479        tools
480            .into_iter()
481            .fold(self, |builder, tool| builder.tool(tool))
482    }
483
484    /// Adds additional parameters to the completion request.
485    /// This can be used to set additional provider-specific parameters. For example,
486    /// Cohere's completion models accept a `connectors` parameter that can be used to
487    /// specify the data connectors used by Cohere when executing the completion
488    /// (see `examples/cohere_connectors.rs`).
489    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
490        match self.additional_params {
491            Some(params) => {
492                self.additional_params = Some(json_utils::merge(params, additional_params));
493            }
494            None => {
495                self.additional_params = Some(additional_params);
496            }
497        }
498        self
499    }
500
501    /// Sets the additional parameters for the completion request.
502    /// This can be used to set additional provider-specific parameters. For example,
503    /// Cohere's completion models accept a `connectors` parameter that can be used to
504    /// specify the data connectors used by Cohere when executing the completion
505    /// (see `examples/cohere_connectors.rs`).
506    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
507        self.additional_params = additional_params;
508        self
509    }
510
511    /// Sets the temperature for the completion request.
512    pub fn temperature(mut self, temperature: f64) -> Self {
513        self.temperature = Some(temperature);
514        self
515    }
516
517    /// Sets the temperature for the completion request.
518    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
519        self.temperature = temperature;
520        self
521    }
522
523    /// Sets the max tokens for the completion request.
524    /// Note: This is required if using Anthropic
525    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
526        self.max_tokens = Some(max_tokens);
527        self
528    }
529
530    /// Sets the max tokens for the completion request.
531    /// Note: This is required if using Anthropic
532    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
533        self.max_tokens = max_tokens;
534        self
535    }
536
537    /// Builds the completion request.
538    pub fn build(self) -> CompletionRequest {
539        let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
540            .expect("There will always be atleast the prompt");
541
542        CompletionRequest {
543            preamble: self.preamble,
544            chat_history,
545            documents: self.documents,
546            tools: self.tools,
547            temperature: self.temperature,
548            max_tokens: self.max_tokens,
549            additional_params: self.additional_params,
550        }
551    }
552
553    /// Sends the completion request to the completion model provider and returns the completion response.
554    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
555        let model = self.model.clone();
556        model.completion(self.build()).await
557    }
558
559    /// Stream the completion request
560    pub async fn stream<'a>(
561        self,
562    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
563    where
564        <M as CompletionModel>::StreamingResponse: 'a,
565        Self: 'a,
566    {
567        let model = self.model.clone();
568        model.stream(self.build()).await
569    }
570}
571
#[cfg(test)]
mod tests {

    use super::*;

    /// `Display` for a document with no additional props: just the
    /// `<file id: ...>` wrapper around the text.
    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{}", doc), expected);
    }

    /// `Display` with additional props: a key-sorted, debug-quoted
    /// `<metadata ... />` line precedes the text.
    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{}", doc), expected);
    }

    /// `normalized_documents` folds all attached documents into a single
    /// user message with one document content entry per document, in order.
    #[test]
    fn test_normalize_documents_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    /// `normalized_documents` returns `None` when no documents are attached.
    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }
}