bep/completion.rs
//! This module provides functionality for working with completion models.
//! It provides traits, structs, and enums for generating completion requests,
//! handling completion responses, and defining completion models.
//!
//! The main traits defined in this module are:
//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
//! - [Chat]: Defines a high-level LLM chat interface with chat history.
//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
//! - [CompletionModel]: Defines a completion model that can be used to generate completion
//!   responses from requests.
//!
//! The [Prompt] and [Chat] traits are high-level traits that users are expected to use
//! to interact with LLM models. Moreover, it is good practice to implement one of these
//! traits for composite agents that use multiple LLM models to generate responses.
//!
//! The [Completion] trait defines a lower-level interface that is useful when the user wants
//! to further customize the request before sending it to the completion model provider.
//!
//! The [CompletionModel] trait is meant to act as the interface between providers and
//! the library. It defines the methods that need to be implemented by the user to define
//! a custom base completion model (i.e.: a private or third-party LLM provider).
//!
//! The module also provides various structs and enums for representing generic completion requests,
//! responses, and errors.
//!
//! Example Usage:
//! ```rust
//! use bep::providers::openai::{Client, self};
//! use bep::completion::*;
//!
//! // Initialize the OpenAI client and a completion model
//! let openai = Client::new("your-openai-api-key");
//!
//! let gpt_4 = openai.completion_model(openai::GPT_4);
//!
//! // Create the completion request
//! let request = gpt_4.completion_request("Who are you?")
//!     .preamble("\
//!         You are Marvin, an extremely smart but depressed robot who is \
//!         nonetheless helpful towards humanity.\
//!     ")
//!     .temperature(0.5)
//!     .build();
//!
//! // Send the completion request and get the completion response
//! let response = gpt_4.completion(request)
//!     .await
//!     .expect("Failed to get completion response");
//!
//! // Handle the completion response
//! match response.choice {
//!     ModelChoice::Message(message) => {
//!         // Handle the completion response as a message
//!         println!("Received message: {}", message);
//!     }
//!     ModelChoice::ToolCall(tool_name, tool_params) => {
//!         // Handle the completion response as a tool call
//!         println!("Received tool call: {} {:?}", tool_name, tool_params);
//!     }
//! }
//! ```
//!
//! For more information on how to use the completion functionality, refer to the documentation of
//! the individual traits, structs, and enums defined in this module.
use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::{json_utils, tool::ToolSetError};

// Errors
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] reqwest::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}

#[derive(Debug, Error)]
pub enum PromptError {
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),
}

// ================================================================
// Request models
// ================================================================
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Message {
    /// "system", "user", or "assistant"
    pub role: String,
    pub content: String,
}

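/// A plain-text document that can be attached to a completion request as context.
/// Extra key/value pairs are flattened into `additional_props` and rendered as
/// metadata by the [std::fmt::Display] impl below.
///
/// Example (a minimal sketch; the rendered format matches the one exercised by the
/// tests at the bottom of this module):
/// ```rust
/// use std::collections::HashMap;
/// use bep::completion::Document;
///
/// let doc = Document {
///     id: "doc1".to_string(),
///     text: "Paris is the capital of France.".to_string(),
///     additional_props: HashMap::new(),
/// };
///
/// assert_eq!(doc.to_string(), "<file id: doc1>\nParis is the capital of France.\n</file>\n");
/// ```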
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    pub id: String,
    pub text: String,
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}

impl std::fmt::Display for Document {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
            self.id,
            if self.additional_props.is_empty() {
                self.text.clone()
            } else {
                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
                let metadata = sorted_props
                    .iter()
                    .map(|(k, v)| format!("{}: {:?}", k, v))
                    .collect::<Vec<_>>()
                    .join(" ");
                format!("<metadata {} />\n{}", metadata, self.text)
            }
        )
    }
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

// ================================================================
// Implementations
// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: Send + Sync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
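    ///
    /// Example (a minimal sketch; assumes `agent` is some value implementing [Prompt],
    /// e.g. an agent built from a provider client):
    /// ```rust,ignore
    /// let answer = agent.prompt("What is the capital of France?").await?;
    /// println!("{}", answer);
    /// ```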
    fn prompt(
        &self,
        prompt: &str,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
}

/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: Send + Sync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
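    ///
    /// Example (a minimal sketch; assumes `agent` implements [Chat] and `history` is a
    /// previously accumulated `Vec<Message>`):
    /// ```rust,ignore
    /// let reply = agent.chat("And what is its population?", history).await?;
    /// println!("{}", reply);
    /// ```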
    fn chat(
        &self,
        prompt: &str,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
}

/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
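    ///
    /// Example (a minimal sketch; assumes `agent` implements [Completion] and that a
    /// temperature override is desired at prompt time):
    /// ```rust,ignore
    /// let response = agent
    ///     .completion("Summarize the attached documents.", chat_history)
    ///     .await?
    ///     .temperature(0.2)
    ///     .send()
    ///     .await?;
    /// ```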
    fn completion(
        &self,
        prompt: &str,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>> + Send;
}

/// General completion response struct that contains the high-level completion choice
/// and the raw response.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice returned by the completion model provider
    pub choice: ModelChoice,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
}

/// Enum representing the high-level completion choice returned by the completion model provider.
#[derive(Debug)]
pub enum ModelChoice {
    /// Represents a completion response as a message
    Message(String),
    /// Represents a completion response as a tool call of the form
    /// `ToolCall(function_name, function_params)`.
    ToolCall(String, serde_json::Value),
}

/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third-party provider (e.g.: OpenAI) or a local model.
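///
/// Example (a minimal sketch of a custom provider integration; `MyClient`,
/// `MyRawResponse`, and their methods are hypothetical placeholders, not part of
/// this library):
/// ```rust,ignore
/// #[derive(Clone)]
/// struct MyModel {
///     client: MyClient,
/// }
///
/// impl CompletionModel for MyModel {
///     type Response = MyRawResponse;
///
///     async fn completion(
///         &self,
///         request: CompletionRequest,
///     ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
///         // Call the (hypothetical) provider API and map its output onto the
///         // library's generic response type.
///         let raw = self.client.generate(&request).await?;
///         Ok(CompletionResponse {
///             choice: ModelChoice::Message(raw.text().to_string()),
///             raw_response: raw,
///         })
///     }
/// }
/// ```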
pub trait CompletionModel: Clone + Send + Sync {
    /// The raw response type returned by the underlying completion model.
    type Response: Send + Sync;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>>
           + Send;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(&self, prompt: &str) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt.to_string())
    }
}

/// Struct representing a general completion request that can be sent to a completion model provider.
pub struct CompletionRequest {
    /// The prompt to be sent to the completion model provider
    pub prompt: String,
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider
    pub chat_history: Vec<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
}

impl CompletionRequest {
    pub(crate) fn prompt_with_context(&self) -> String {
        if !self.documents.is_empty() {
            format!(
                "<attachments>\n{}</attachments>\n\n{}",
                self.documents
                    .iter()
                    .map(|doc| doc.to_string())
                    .collect::<Vec<_>>()
                    .join(""),
                self.prompt
            )
        } else {
            self.prompt.clone()
        }
    }
}

/// Builder struct for constructing a completion request.
///
/// Example usage:
/// ```rust
/// use bep::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it separately
/// let request = CompletionRequestBuilder::new(model.clone(), "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .build();
///
/// let response = model.completion(request)
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Alternatively, you can execute the completion request directly from the builder:
/// ```rust
/// use bep::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it directly
/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .send()
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Note: It is usually unnecessary to create a completion request builder directly.
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    model: M,
    prompt: String,
    preamble: Option<String>,
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<serde_json::Value>,
}

impl<M: CompletionModel> CompletionRequestBuilder<M> {
    pub fn new(model: M, prompt: String) -> Self {
        Self {
            model,
            prompt,
            preamble: None,
            chat_history: Vec::new(),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        }
    }

    /// Sets the preamble for the completion request.
    pub fn preamble(mut self, preamble: String) -> Self {
        self.preamble = Some(preamble);
        self
    }

    /// Adds a message to the chat history for the completion request.
    pub fn message(mut self, message: Message) -> Self {
        self.chat_history.push(message);
        self
    }

    /// Adds a list of messages to the chat history for the completion request.
    pub fn messages(self, messages: Vec<Message>) -> Self {
        messages
            .into_iter()
            .fold(self, |builder, msg| builder.message(msg))
    }

    /// Adds a document to the completion request.
    pub fn document(mut self, document: Document) -> Self {
        self.documents.push(document);
        self
    }

    /// Adds a list of documents to the completion request.
    pub fn documents(self, documents: Vec<Document>) -> Self {
        documents
            .into_iter()
            .fold(self, |builder, doc| builder.document(doc))
    }

    /// Adds a tool to the completion request.
    pub fn tool(mut self, tool: ToolDefinition) -> Self {
        self.tools.push(tool);
        self
    }

    /// Adds a list of tools to the completion request.
    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
        tools
            .into_iter()
            .fold(self, |builder, tool| builder.tool(tool))
    }

    /// Adds additional parameters to the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
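    ///
    /// Example (a minimal sketch; repeated calls are merged via `json_utils::merge`,
    /// and the exact parameter names accepted are provider-specific):
    /// ```rust,ignore
    /// let builder = builder
    ///     .additional_params(serde_json::json!({ "connectors": [{ "id": "web-search" }] }));
    /// ```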
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        match self.additional_params {
            Some(params) => {
                self.additional_params = Some(json_utils::merge(params, additional_params));
            }
            None => {
                self.additional_params = Some(additional_params);
            }
        }
        self
    }

    /// Sets the additional parameters for the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
        self.temperature = temperature;
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required if using Anthropic.
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required if using Anthropic.
    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Builds the completion request.
    pub fn build(self) -> CompletionRequest {
        CompletionRequest {
            prompt: self.prompt,
            preamble: self.preamble,
            chat_history: self.chat_history,
            documents: self.documents,
            tools: self.tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
        }
    }

    /// Sends the completion request to the completion model provider and returns the completion response.
    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
        let model = self.model.clone();
        model.completion(self.build()).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{}", doc), expected);
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{}", doc), expected);
    }

    #[test]
    fn test_prompt_with_context_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            prompt: "What is the capital of France?".to_string(),
            preamble: None,
            chat_history: Vec::new(),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        let expected = concat!(
            "<attachments>\n",
            "<file id: doc1>\nDocument 1 text.\n</file>\n",
            "<file id: doc2>\nDocument 2 text.\n</file>\n",
            "</attachments>\n\n",
            "What is the capital of France?"
        )
        .to_string();

        assert_eq!(request.prompt_with_context(), expected);
    }
}
549}