bep/
agent.rs

1//! This module contains the implementation of the [Agent] struct and its builder.
2//!
3//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
4//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
5//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
6//!
7//! The [Agent] struct is highly configurable, allowing the user to define anything from
8//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
9//! context documents and tools.
10//!
11//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating
12//! completions responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to
13//! be used for generating chat completions.
14//!
15//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
16//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters
17//! before building the agent.
18//!
19//! # Example
20//! ```rust
21//! use bep::{
22//!     completion::{Chat, Completion, Prompt},
23//!     providers::openai,
24//! };
25//!
26//! let openai = openai::Client::from_env();
27//!
28//! // Configure the agent
29//! let agent = openai.agent("gpt-4o")
30//!     .preamble("System prompt")
31//!     .context("Context document 1")
32//!     .context("Context document 2")
33//!     .tool(tool1)
34//!     .tool(tool2)
35//!     .temperature(0.8)
36//!     .additional_params(json!({"foo": "bar"}))
37//!     .build();
38//!
39//! // Use the agent for completions and prompts
40//! // Generate a chat completion response from a prompt and chat history
41//! let chat_response = agent.chat("Prompt", chat_history)
42//!     .await
43//!     .expect("Failed to chat with Agent");
44//!
45//! // Generate a prompt completion response from a simple prompt
46//! let chat_response = agent.prompt("Prompt")
47//!     .await
48//!     .expect("Failed to prompt the Agent");
49//!
50//! // Generate a completion request builder from a prompt and chat history. The builder
51//! // will contain the agent's configuration (i.e.: preamble, context documents, tools,
52//! // model parameters, etc.), but these can be overwritten.
53//! let completion_req_builder = agent.completion("Prompt", chat_history)
54//!     .await
55//!     .expect("Failed to create completion request builder");
56//!
57//! let response = completion_req_builder
58//!     .temperature(0.9) // Overwrite the agent's temperature
59//!     .send()
60//!     .await
61//!     .expect("Failed to send completion request");
62//! ```
63//!
64//! RAG Agent example
65//! ```rust
66//! use bep::{
67//!     completion::Prompt,
68//!     embeddings::EmbeddingsBuilder,
69//!     providers::openai,
70//!     vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
71//! };
72//!
73//! // Initialize OpenAI client
74//! let openai = openai::Client::from_env();
75//!
76//! // Initialize OpenAI embedding model
77//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
78//!
79//! // Create vector store, compute embeddings and load them in the store
80//! let mut vector_store = InMemoryVectorStore::default();
81//!
82//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
83//!     .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
84//!     .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
85//!     .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
86//!     .build()
87//!     .await
88//!     .expect("Failed to build embeddings");
89//!
90//! vector_store.add_documents(embeddings)
91//!     .await
92//!     .expect("Failed to add documents");
93//!
94//! // Create vector store index
95//! let index = vector_store.index(embedding_model);
96//!
97//! let agent = openai.agent(openai::GPT_4O)
98//!     .preamble("
99//!         You are a dictionary assistant here to assist the user in understanding the meaning of words.
100//!         You will find additional non-standard word definitions that could be useful below.
101//!     ")
102//!     .dynamic_context(1, index)
103//!     .build();
104//!
105//! // Prompt the agent and print the response
106//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
107//!     .expect("Failed to prompt the agent");
108//! ```
109use std::collections::HashMap;
110
111use futures::{stream, StreamExt, TryStreamExt};
112
113use crate::{
114    completion::{
115        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder,
116        CompletionResponse, Document, Message, ModelChoice, Prompt, PromptError,
117    },
118    tool::{Tool, ToolSet},
119    vector_store::{VectorStoreError, VectorStoreIndexDyn},
120};
121
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// # Example
/// ```
/// use bep::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
///     .agent("gpt-4o")
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
///     .await
///     .expect("Failed to prompt the agent");
/// ```
pub struct Agent<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    model: M,
    /// System prompt
    preamble: String,
    /// Context documents always available to the agent
    static_context: Vec<Document>,
    /// Tools that are always available to the agent (identified by their name)
    static_tools: Vec<String>,
    /// Temperature of the model
    temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    additional_params: Option<serde_json::Value>,
    /// Dynamic context sources: each entry pairs the number of documents to
    /// sample per prompt with the vector store index to retrieve them from
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Dynamic tool sources: each entry pairs the number of tools to sample
    /// per prompt with the vector store index used to look up tool names
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Actual tool implementations
    pub tools: ToolSet,
}
164
165impl<M: CompletionModel> Completion<M> for Agent<M> {
166    async fn completion(
167        &self,
168        prompt: &str,
169        chat_history: Vec<Message>,
170    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
171        let dynamic_context = stream::iter(self.dynamic_context.iter())
172            .then(|(num_sample, index)| async {
173                Ok::<_, VectorStoreError>(
174                    index
175                        .top_n(prompt, *num_sample)
176                        .await?
177                        .into_iter()
178                        .map(|(_, id, doc)| {
179                            // Pretty print the document if possible for better readability
180                            let text = serde_json::to_string_pretty(&doc)
181                                .unwrap_or_else(|_| doc.to_string());
182
183                            Document {
184                                id,
185                                text,
186                                additional_props: HashMap::new(),
187                            }
188                        })
189                        .collect::<Vec<_>>(),
190                )
191            })
192            .try_fold(vec![], |mut acc, docs| async {
193                acc.extend(docs);
194                Ok(acc)
195            })
196            .await
197            .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
198
199        let dynamic_tools = stream::iter(self.dynamic_tools.iter())
200            .then(|(num_sample, index)| async {
201                Ok::<_, VectorStoreError>(
202                    index
203                        .top_n_ids(prompt, *num_sample)
204                        .await?
205                        .into_iter()
206                        .map(|(_, id)| id)
207                        .collect::<Vec<_>>(),
208                )
209            })
210            .try_fold(vec![], |mut acc, docs| async {
211                for doc in docs {
212                    if let Some(tool) = self.tools.get(&doc) {
213                        acc.push(tool.definition(prompt.into()).await)
214                    } else {
215                        tracing::warn!("Tool implementation not found in toolset: {}", doc);
216                    }
217                }
218                Ok(acc)
219            })
220            .await
221            .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
222
223        let static_tools = stream::iter(self.static_tools.iter())
224            .filter_map(|toolname| async move {
225                if let Some(tool) = self.tools.get(toolname) {
226                    Some(tool.definition(prompt.into()).await)
227                } else {
228                    tracing::warn!("Tool implementation not found in toolset: {}", toolname);
229                    None
230                }
231            })
232            .collect::<Vec<_>>()
233            .await;
234
235        Ok(self
236            .model
237            .completion_request(prompt)
238            .preamble(self.preamble.clone())
239            .messages(chat_history)
240            .documents([self.static_context.clone(), dynamic_context].concat())
241            .tools([static_tools.clone(), dynamic_tools].concat())
242            .temperature_opt(self.temperature)
243            .max_tokens_opt(self.max_tokens)
244            .additional_params_opt(self.additional_params.clone()))
245    }
246}
247
248impl<M: CompletionModel> Prompt for Agent<M> {
249    async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
250        self.chat(prompt, vec![]).await
251    }
252}
253
254impl<M: CompletionModel> Prompt for &Agent<M> {
255    async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
256        self.chat(prompt, vec![]).await
257    }
258}
259
260impl<M: CompletionModel> Chat for Agent<M> {
261    async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> {
262        match self.completion(prompt, chat_history).await?.send().await? {
263            CompletionResponse {
264                choice: ModelChoice::Message(msg),
265                ..
266            } => Ok(msg),
267            CompletionResponse {
268                choice: ModelChoice::ToolCall(toolname, args),
269                ..
270            } => Ok(self.tools.call(&toolname, args.to_string()).await?),
271        }
272    }
273}
274
/// A builder for creating an agent
///
/// # Example
/// ```
/// use bep::{providers::openai, agent::AgentBuilder};
///
/// let openai = openai::Client::from_env();
///
/// let gpt4o = openai.completion_model("gpt-4o");
///
/// // Configure the agent
/// let agent = AgentBuilder::new(gpt4o)
///     .preamble("System prompt")
///     .context("Context document 1")
///     .context("Context document 2")
///     .tool(tool1)
///     .tool(tool2)
///     .temperature(0.8)
///     .additional_params(json!({"foo": "bar"}))
///     .build();
/// ```
pub struct AgentBuilder<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    model: M,
    /// System prompt (None until set; defaults to empty on build)
    preamble: Option<String>,
    /// Context documents always available to the agent
    static_context: Vec<Document>,
    /// Tools that are always available to the agent (by name)
    static_tools: Vec<String>,
    /// Additional parameters to be passed to the model
    additional_params: Option<serde_json::Value>,
    /// Maximum number of tokens for the completion
    max_tokens: Option<u64>,
    /// Dynamic context sources: (number of documents to sample per prompt, index)
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Dynamic tool sources: (number of tools to sample per prompt, index)
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Temperature of the model
    temperature: Option<f64>,
    /// Actual tool implementations
    tools: ToolSet,
}
318
319impl<M: CompletionModel> AgentBuilder<M> {
320    pub fn new(model: M) -> Self {
321        Self {
322            model,
323            preamble: None,
324            static_context: vec![],
325            static_tools: vec![],
326            temperature: None,
327            max_tokens: None,
328            additional_params: None,
329            dynamic_context: vec![],
330            dynamic_tools: vec![],
331            tools: ToolSet::default(),
332        }
333    }
334
335    /// Set the system prompt
336    pub fn preamble(mut self, preamble: &str) -> Self {
337        self.preamble = Some(preamble.into());
338        self
339    }
340
341    /// Append to the preamble of the agent
342    pub fn append_preamble(mut self, doc: &str) -> Self {
343        self.preamble = Some(format!(
344            "{}\n{}",
345            self.preamble.unwrap_or_else(|| "".into()),
346            doc
347        ));
348        self
349    }
350
351    /// Add a static context document to the agent
352    pub fn context(mut self, doc: &str) -> Self {
353        self.static_context.push(Document {
354            id: format!("static_doc_{}", self.static_context.len()),
355            text: doc.into(),
356            additional_props: HashMap::new(),
357        });
358        self
359    }
360
361    /// Add a static tool to the agent
362    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
363        let toolname = tool.name();
364        self.tools.add_tool(tool);
365        self.static_tools.push(toolname);
366        self
367    }
368
369    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
370    /// dynamic context will be inserted in the request.
371    pub fn dynamic_context(
372        mut self,
373        sample: usize,
374        dynamic_context: impl VectorStoreIndexDyn + 'static,
375    ) -> Self {
376        self.dynamic_context
377            .push((sample, Box::new(dynamic_context)));
378        self
379    }
380
381    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
382    /// dynamic toolset will be inserted in the request.
383    pub fn dynamic_tools(
384        mut self,
385        sample: usize,
386        dynamic_tools: impl VectorStoreIndexDyn + 'static,
387        toolset: ToolSet,
388    ) -> Self {
389        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
390        self.tools.add_tools(toolset);
391        self
392    }
393
394    /// Set the temperature of the model
395    pub fn temperature(mut self, temperature: f64) -> Self {
396        self.temperature = Some(temperature);
397        self
398    }
399
400    /// Set the maximum number of tokens for the completion
401    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
402        self.max_tokens = Some(max_tokens);
403        self
404    }
405
406    /// Set additional parameters to be passed to the model
407    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
408        self.additional_params = Some(params);
409        self
410    }
411
412    /// Build the agent
413    pub fn build(self) -> Agent<M> {
414        Agent {
415            model: self.model,
416            preamble: self.preamble.unwrap_or_default(),
417            static_context: self.static_context,
418            static_tools: self.static_tools,
419            temperature: self.temperature,
420            max_tokens: self.max_tokens,
421            additional_params: self.additional_params,
422            dynamic_context: self.dynamic_context,
423            dynamic_tools: self.dynamic_tools,
424            tools: self.tools,
425        }
426    }
427}