bep/
completion.rs

//! This module provides functionality for working with completion models.
//! It provides traits, structs, and enums for generating completion requests,
//! handling completion responses, and defining completion models.
//!
//! The main traits defined in this module are:
//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
//! - [Chat]: Defines a high-level LLM chat interface with chat history.
//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
//! - [CompletionModel]: Defines a completion model that can be used to generate completion
//!   responses from requests.
//!
//! The [Prompt] and [Chat] traits are high-level traits that users are expected to use
//! to interact with LLM models. Moreover, it is good practice to implement one of these
//! traits for composite agents that use multiple LLM models to generate responses.
//!
//! The [Completion] trait defines a lower-level interface that is useful when the user wants
//! to further customize the request before sending it to the completion model provider.
//!
//! The [CompletionModel] trait is meant to act as the interface between providers and
//! the library. It defines the methods that need to be implemented by the user to define
//! a custom base completion model (i.e.: a private or third-party LLM provider).
//!
//! The module also provides various structs and enums for representing generic completion requests,
//! responses, and errors.
//!
//! Example Usage:
//! ```rust
//! use bep::providers::openai::{Client, self};
//! use bep::completion::*;
//!
//! // Initialize the OpenAI client and a completion model
//! let openai = Client::new("your-openai-api-key");
//!
//! let gpt_4 = openai.completion_model(openai::GPT_4);
//!
//! // Create the completion request
//! let request = gpt_4.completion_request("Who are you?")
//!     .preamble("\
//!         You are Marvin, an extremely smart but depressed robot who is \
//!         nonetheless helpful towards humanity.\
//!     ".to_string())
//!     .temperature(0.5)
//!     .build();
//!
//! // Send the completion request and get the completion response
//! let response = gpt_4.completion(request)
//!     .await
//!     .expect("Failed to get completion response");
//!
//! // Handle the completion response
//! match response.choice {
//!     ModelChoice::Message(message) => {
//!         // Handle the completion response as a message
//!         println!("Received message: {}", message);
//!     }
//!     ModelChoice::ToolCall(tool_name, tool_params) => {
//!         // Handle the completion response as a tool call
//!         println!("Received tool call: {} {:?}", tool_name, tool_params);
//!     }
//! }
//! ```
//!
//! For more information on how to use the completion functionality, refer to the documentation of
//! the individual traits, structs, and enums defined in this module.
use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::{json_utils, tool::ToolSetError};

// ================================================================
// Errors
// ================================================================
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] reqwest::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}

#[derive(Debug, Error)]
pub enum PromptError {
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),
}

// ================================================================
// Request models
// ================================================================
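/// A single chat message, tagged with the role that produced it.
///
/// Example:
/// ```rust
/// use bep::completion::Message;
///
/// let msg = Message {
///     role: "user".to_string(),
///     content: "What is the capital of France?".to_string(),
/// };
/// ```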
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Message {
    /// "system", "user", or "assistant"
    pub role: String,
    pub content: String,
}

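/// A document that can be attached to a completion request as extra context.
///
/// Since `additional_props` is marked `#[serde(flatten)]`, any extra string
/// fields encountered during deserialization are collected into it. For example:
/// ```rust
/// use bep::completion::Document;
///
/// let doc: Document = serde_json::from_value(serde_json::json!({
///     "id": "doc1",
///     "text": "Document 1 text.",
///     // Any other string field is captured in `additional_props`
///     "author": "John Doe",
/// }))
/// .expect("valid document");
///
/// assert_eq!(doc.additional_props["author"], "John Doe");
/// ```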
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    pub id: String,
    pub text: String,
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}

impl std::fmt::Display for Document {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
            self.id,
            if self.additional_props.is_empty() {
                self.text.clone()
            } else {
                // Render metadata keys in a stable (sorted) order before the document text
                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
                let metadata = sorted_props
                    .iter()
                    .map(|(k, v)| format!("{}: {:?}", k, v))
                    .collect::<Vec<_>>()
                    .join(" ");
                format!("<metadata {} />\n{}", metadata, self.text)
            }
        )
    }
}

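/// Definition of a tool that the completion model can call, where `parameters`
/// is typically a JSON Schema object describing the tool's arguments. For example:
/// ```rust
/// use bep::completion::ToolDefinition;
///
/// let tool = ToolDefinition {
///     name: "add".to_string(),
///     description: "Add two numbers together".to_string(),
///     parameters: serde_json::json!({
///         "type": "object",
///         "properties": {
///             "x": { "type": "number" },
///             "y": { "type": "number" }
///         },
///         "required": ["x", "y"]
///     }),
/// };
/// ```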
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

// ================================================================
// Implementations
// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: Send + Sync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
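    ///
    /// Example (sketch, assuming `agent` is some value implementing [Prompt]):
    /// ```rust,ignore
    /// let answer = agent.prompt("What is the capital of France?").await?;
    /// println!("{answer}");
    /// ```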
    fn prompt(
        &self,
        prompt: &str,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
}

/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: Send + Sync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
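    ///
    /// Example (sketch, assuming `chatbot` is some value implementing [Chat]):
    /// ```rust,ignore
    /// use bep::completion::Message;
    ///
    /// let history = vec![Message {
    ///     role: "user".to_string(),
    ///     content: "Remember that my favorite color is blue.".to_string(),
    /// }];
    /// let answer = chatbot.chat("What is my favorite color?", history).await?;
    /// ```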
    fn chat(
        &self,
        prompt: &str,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
}

/// Trait defining a low-level LLM completion interface.
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
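    ///
    /// Example (sketch, assuming `agent` is some value implementing [Completion]):
    /// ```rust,ignore
    /// let response = agent
    ///     .completion("Who are you?", vec![])
    ///     .await?
    ///     // Customize the request before sending it
    ///     .temperature(0.9)
    ///     .send()
    ///     .await?;
    /// ```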
    fn completion(
        &self,
        prompt: &str,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>> + Send;
}

/// General completion response struct that contains the high-level completion choice
/// and the raw response.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice returned by the completion model provider
    pub choice: ModelChoice,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
}

/// Enum representing the high-level completion choice returned by the completion model provider.
#[derive(Debug)]
pub enum ModelChoice {
    /// Represents a completion response as a message
    Message(String),
    /// Represents a completion response as a tool call of the form
    /// `ToolCall(function_name, function_params)`.
    ToolCall(String, serde_json::Value),
}

/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third party provider (e.g.: OpenAI) or a local model.
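///
/// Example (a minimal sketch of a hypothetical local model that echoes the
/// prompt back; a real implementation would call out to a provider API):
/// ```rust,ignore
/// use bep::completion::{
///     CompletionError, CompletionModel, CompletionRequest, CompletionResponse, ModelChoice,
/// };
///
/// #[derive(Clone)]
/// struct EchoModel;
///
/// impl CompletionModel for EchoModel {
///     type Response = ();
///
///     fn completion(
///         &self,
///         request: CompletionRequest,
///     ) -> impl std::future::Future<Output = Result<CompletionResponse<()>, CompletionError>> + Send
///     {
///         // Echo the request's prompt back as the model's message
///         async move {
///             Ok(CompletionResponse {
///                 choice: ModelChoice::Message(request.prompt),
///                 raw_response: (),
///             })
///         }
///     }
/// }
/// ```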
pub trait CompletionModel: Clone + Send + Sync {
    /// The raw response type returned by the underlying completion model.
    type Response: Send + Sync;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>>
           + Send;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(&self, prompt: &str) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt.to_string())
    }
}

/// Struct representing a general completion request that can be sent to a completion model provider.
pub struct CompletionRequest {
    /// The prompt to be sent to the completion model provider
    pub prompt: String,
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider
    pub chat_history: Vec<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
}

impl CompletionRequest {
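    /// Returns the prompt prefixed with the request's documents (if any),
    /// rendered inside an `<attachments>` block.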
    pub(crate) fn prompt_with_context(&self) -> String {
        if !self.documents.is_empty() {
            format!(
                "<attachments>\n{}</attachments>\n\n{}",
                self.documents
                    .iter()
                    .map(|doc| doc.to_string())
                    .collect::<Vec<_>>()
                    .join(""),
                self.prompt
            )
        } else {
            self.prompt.clone()
        }
    }
}

/// Builder struct for constructing a completion request.
///
/// Example usage:
/// ```rust
/// use bep::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it separately
/// let request = CompletionRequestBuilder::new(model.clone(), "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .build();
///
/// let response = model.completion(request)
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Alternatively, you can execute the completion request directly from the builder:
/// ```rust
/// use bep::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it directly
/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .send()
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Note: It is usually unnecessary to create a completion request builder directly.
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    model: M,
    prompt: String,
    preamble: Option<String>,
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<serde_json::Value>,
}

impl<M: CompletionModel> CompletionRequestBuilder<M> {
    pub fn new(model: M, prompt: String) -> Self {
        Self {
            model,
            prompt,
            preamble: None,
            chat_history: Vec::new(),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        }
    }

    /// Sets the preamble for the completion request.
    pub fn preamble(mut self, preamble: String) -> Self {
        self.preamble = Some(preamble);
        self
    }

    /// Adds a message to the chat history for the completion request.
    pub fn message(mut self, message: Message) -> Self {
        self.chat_history.push(message);
        self
    }

    /// Adds a list of messages to the chat history for the completion request.
    pub fn messages(self, messages: Vec<Message>) -> Self {
        messages
            .into_iter()
            .fold(self, |builder, msg| builder.message(msg))
    }

    /// Adds a document to the completion request.
    pub fn document(mut self, document: Document) -> Self {
        self.documents.push(document);
        self
    }

    /// Adds a list of documents to the completion request.
    pub fn documents(self, documents: Vec<Document>) -> Self {
        documents
            .into_iter()
            .fold(self, |builder, doc| builder.document(doc))
    }

    /// Adds a tool to the completion request.
    pub fn tool(mut self, tool: ToolDefinition) -> Self {
        self.tools.push(tool);
        self
    }

    /// Adds a list of tools to the completion request.
    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
        tools
            .into_iter()
            .fold(self, |builder, tool| builder.tool(tool))
    }

    /// Adds additional parameters to the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
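    ///
    /// Note that if additional parameters were already set, the new parameters are
    /// merged with (rather than replace) the existing ones. A sketch, assuming a
    /// hypothetical `builder` and that `json_utils::merge` combines both JSON objects:
    /// ```rust,ignore
    /// let builder = builder
    ///     .additional_params(serde_json::json!({ "connectors": [{ "id": "web-search" }] }))
    ///     .additional_params(serde_json::json!({ "citation_quality": "accurate" }));
    /// // The built request's `additional_params` carries both keys.
    /// ```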
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        match self.additional_params {
            Some(params) => {
                self.additional_params = Some(json_utils::merge(params, additional_params));
            }
            None => {
                self.additional_params = Some(additional_params);
            }
        }
        self
    }

    /// Sets the additional parameters for the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
        self.temperature = temperature;
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required if using Anthropic
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required if using Anthropic
    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Builds the completion request.
    pub fn build(self) -> CompletionRequest {
        CompletionRequest {
            prompt: self.prompt,
            preamble: self.preamble,
            chat_history: self.chat_history,
            documents: self.documents,
            tools: self.tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
        }
    }

    /// Sends the completion request to the completion model provider and returns the completion response.
    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
        let model = self.model.clone();
        model.completion(self.build()).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{}", doc), expected);
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{}", doc), expected);
    }

    #[test]
    fn test_prompt_with_context_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            prompt: "What is the capital of France?".to_string(),
            preamble: None,
            chat_history: Vec::new(),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        let expected = concat!(
            "<attachments>\n",
            "<file id: doc1>\nDocument 1 text.\n</file>\n",
            "<file id: doc2>\nDocument 2 text.\n</file>\n",
            "</attachments>\n\n",
            "What is the capital of France?"
        )
        .to_string();

        assert_eq!(request.prompt_with_context(), expected);
    }
}