bep/agent.rs
1//! This module contains the implementation of the [Agent] struct and its builder.
2//!
3//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
4//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
5//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
6//!
7//! The [Agent] struct is highly configurable, allowing the user to define anything from
8//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
9//! context documents and tools.
10//!
11//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating
12//! completions responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to
13//! be used for generating chat completions.
14//!
15//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
16//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters
17//! before building the agent.
18//!
19//! # Example
20//! ```rust
21//! use bep::{
22//! completion::{Chat, Completion, Prompt},
23//! providers::openai,
24//! };
25//!
26//! let openai = openai::Client::from_env();
27//!
28//! // Configure the agent
29//! let agent = openai.agent("gpt-4o")
30//! .preamble("System prompt")
31//! .context("Context document 1")
32//! .context("Context document 2")
33//! .tool(tool1)
34//! .tool(tool2)
35//! .temperature(0.8)
36//! .additional_params(json!({"foo": "bar"}))
37//! .build();
38//!
39//! // Use the agent for completions and prompts
40//! // Generate a chat completion response from a prompt and chat history
41//! let chat_response = agent.chat("Prompt", chat_history)
42//! .await
43//! .expect("Failed to chat with Agent");
44//!
45//! // Generate a prompt completion response from a simple prompt
46//! let chat_response = agent.prompt("Prompt")
47//! .await
48//! .expect("Failed to prompt the Agent");
49//!
50//! // Generate a completion request builder from a prompt and chat history. The builder
51//! // will contain the agent's configuration (i.e.: preamble, context documents, tools,
52//! // model parameters, etc.), but these can be overwritten.
53//! let completion_req_builder = agent.completion("Prompt", chat_history)
54//! .await
55//! .expect("Failed to create completion request builder");
56//!
57//! let response = completion_req_builder
58//! .temperature(0.9) // Overwrite the agent's temperature
59//! .send()
60//! .await
61//! .expect("Failed to send completion request");
62//! ```
63//!
64//! RAG Agent example
65//! ```rust
66//! use bep::{
67//! completion::Prompt,
68//! embeddings::EmbeddingsBuilder,
69//! providers::openai,
70//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
71//! };
72//!
73//! // Initialize OpenAI client
74//! let openai = openai::Client::from_env();
75//!
76//! // Initialize OpenAI embedding model
77//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
78//!
79//! // Create vector store, compute embeddings and load them in the store
80//! let mut vector_store = InMemoryVectorStore::default();
81//!
82//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
83//! .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
//!         .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
85//! .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
86//! .build()
87//! .await
88//! .expect("Failed to build embeddings");
89//!
90//! vector_store.add_documents(embeddings)
91//! .await
92//! .expect("Failed to add documents");
93//!
94//! // Create vector store index
95//! let index = vector_store.index(embedding_model);
96//!
97//! let agent = openai.agent(openai::GPT_4O)
98//! .preamble("
99//! You are a dictionary assistant here to assist the user in understanding the meaning of words.
100//! You will find additional non-standard word definitions that could be useful below.
101//! ")
102//! .dynamic_context(1, index)
103//! .build();
104//!
105//! // Prompt the agent and print the response
106//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
107//! .expect("Failed to prompt the agent");
108//! ```
109use std::collections::HashMap;
110
111use futures::{stream, StreamExt, TryStreamExt};
112
113use crate::{
114 completion::{
115 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder,
116 CompletionResponse, Document, Message, ModelChoice, Prompt, PromptError,
117 },
118 tool::{Tool, ToolSet},
119 vector_store::{VectorStoreError, VectorStoreIndexDyn},
120};
121
122/// Struct reprensenting an LLM agent. An agent is an LLM model combined with a preamble
123/// (i.e.: system prompt) and a static set of context documents and tools.
124/// All context documents and tools are always provided to the agent when prompted.
125///
126/// # Example
127/// ```
128/// use bep::{completion::Prompt, providers::openai};
129///
130/// let openai = openai::Client::from_env();
131///
132/// let comedian_agent = openai
133/// .agent("gpt-4o")
134/// .preamble("You are a comedian here to entertain the user using humour and jokes.")
135/// .temperature(0.9)
136/// .build();
137///
138/// let response = comedian_agent.prompt("Entertain me!")
139/// .await
140/// .expect("Failed to prompt the agent");
141/// ```
142pub struct Agent<M: CompletionModel> {
143 /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
144 model: M,
145 /// System prompt
146 preamble: String,
147 /// Context documents always available to the agent
148 static_context: Vec<Document>,
149 /// Tools that are always available to the agent (identified by their name)
150 static_tools: Vec<String>,
151 /// Temperature of the model
152 temperature: Option<f64>,
153 /// Maximum number of tokens for the completion
154 max_tokens: Option<u64>,
155 /// Additional parameters to be passed to the model
156 additional_params: Option<serde_json::Value>,
157 /// List of vector store, with the sample number
158 dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
159 /// Dynamic tools
160 dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
161 /// Actual tool implementations
162 pub tools: ToolSet,
163}
164
165impl<M: CompletionModel> Completion<M> for Agent<M> {
166 async fn completion(
167 &self,
168 prompt: &str,
169 chat_history: Vec<Message>,
170 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
171 let dynamic_context = stream::iter(self.dynamic_context.iter())
172 .then(|(num_sample, index)| async {
173 Ok::<_, VectorStoreError>(
174 index
175 .top_n(prompt, *num_sample)
176 .await?
177 .into_iter()
178 .map(|(_, id, doc)| {
179 // Pretty print the document if possible for better readability
180 let text = serde_json::to_string_pretty(&doc)
181 .unwrap_or_else(|_| doc.to_string());
182
183 Document {
184 id,
185 text,
186 additional_props: HashMap::new(),
187 }
188 })
189 .collect::<Vec<_>>(),
190 )
191 })
192 .try_fold(vec![], |mut acc, docs| async {
193 acc.extend(docs);
194 Ok(acc)
195 })
196 .await
197 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
198
199 let dynamic_tools = stream::iter(self.dynamic_tools.iter())
200 .then(|(num_sample, index)| async {
201 Ok::<_, VectorStoreError>(
202 index
203 .top_n_ids(prompt, *num_sample)
204 .await?
205 .into_iter()
206 .map(|(_, id)| id)
207 .collect::<Vec<_>>(),
208 )
209 })
210 .try_fold(vec![], |mut acc, docs| async {
211 for doc in docs {
212 if let Some(tool) = self.tools.get(&doc) {
213 acc.push(tool.definition(prompt.into()).await)
214 } else {
215 tracing::warn!("Tool implementation not found in toolset: {}", doc);
216 }
217 }
218 Ok(acc)
219 })
220 .await
221 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
222
223 let static_tools = stream::iter(self.static_tools.iter())
224 .filter_map(|toolname| async move {
225 if let Some(tool) = self.tools.get(toolname) {
226 Some(tool.definition(prompt.into()).await)
227 } else {
228 tracing::warn!("Tool implementation not found in toolset: {}", toolname);
229 None
230 }
231 })
232 .collect::<Vec<_>>()
233 .await;
234
235 Ok(self
236 .model
237 .completion_request(prompt)
238 .preamble(self.preamble.clone())
239 .messages(chat_history)
240 .documents([self.static_context.clone(), dynamic_context].concat())
241 .tools([static_tools.clone(), dynamic_tools].concat())
242 .temperature_opt(self.temperature)
243 .max_tokens_opt(self.max_tokens)
244 .additional_params_opt(self.additional_params.clone()))
245 }
246}
247
248impl<M: CompletionModel> Prompt for Agent<M> {
249 async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
250 self.chat(prompt, vec![]).await
251 }
252}
253
254impl<M: CompletionModel> Prompt for &Agent<M> {
255 async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
256 self.chat(prompt, vec![]).await
257 }
258}
259
260impl<M: CompletionModel> Chat for Agent<M> {
261 async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> {
262 match self.completion(prompt, chat_history).await?.send().await? {
263 CompletionResponse {
264 choice: ModelChoice::Message(msg),
265 ..
266 } => Ok(msg),
267 CompletionResponse {
268 choice: ModelChoice::ToolCall(toolname, args),
269 ..
270 } => Ok(self.tools.call(&toolname, args.to_string()).await?),
271 }
272 }
273}
274
/// A builder for creating an agent
///
/// # Example
/// ```
/// use bep::{providers::openai, agent::AgentBuilder};
///
/// let openai = openai::Client::from_env();
///
/// let gpt4o = openai.completion_model("gpt-4o");
///
/// // Configure the agent
/// let agent = AgentBuilder::new(gpt4o)
///     .preamble("System prompt")
///     .context("Context document 1")
///     .context("Context document 2")
///     .tool(tool1)
///     .tool(tool2)
///     .temperature(0.8)
///     .additional_params(json!({"foo": "bar"}))
///     .build();
/// ```
pub struct AgentBuilder<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    model: M,
    /// System prompt; `None` until set via `preamble`/`append_preamble`
    preamble: Option<String>,
    /// Context documents always available to the agent
    static_context: Vec<Document>,
    /// Tools that are always available to the agent (by name)
    static_tools: Vec<String>,
    /// Additional parameters to be passed to the model
    additional_params: Option<serde_json::Value>,
    /// Maximum number of tokens for the completion
    max_tokens: Option<u64>,
    /// Vector store indices for dynamic context, each paired with a sample count
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Vector store indices for dynamic tools, each paired with a sample count
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Temperature of the model
    temperature: Option<f64>,
    /// Actual tool implementations (backs both static and dynamic tools)
    tools: ToolSet,
}
318
319impl<M: CompletionModel> AgentBuilder<M> {
320 pub fn new(model: M) -> Self {
321 Self {
322 model,
323 preamble: None,
324 static_context: vec![],
325 static_tools: vec![],
326 temperature: None,
327 max_tokens: None,
328 additional_params: None,
329 dynamic_context: vec![],
330 dynamic_tools: vec![],
331 tools: ToolSet::default(),
332 }
333 }
334
335 /// Set the system prompt
336 pub fn preamble(mut self, preamble: &str) -> Self {
337 self.preamble = Some(preamble.into());
338 self
339 }
340
341 /// Append to the preamble of the agent
342 pub fn append_preamble(mut self, doc: &str) -> Self {
343 self.preamble = Some(format!(
344 "{}\n{}",
345 self.preamble.unwrap_or_else(|| "".into()),
346 doc
347 ));
348 self
349 }
350
351 /// Add a static context document to the agent
352 pub fn context(mut self, doc: &str) -> Self {
353 self.static_context.push(Document {
354 id: format!("static_doc_{}", self.static_context.len()),
355 text: doc.into(),
356 additional_props: HashMap::new(),
357 });
358 self
359 }
360
361 /// Add a static tool to the agent
362 pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
363 let toolname = tool.name();
364 self.tools.add_tool(tool);
365 self.static_tools.push(toolname);
366 self
367 }
368
369 /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
370 /// dynamic context will be inserted in the request.
371 pub fn dynamic_context(
372 mut self,
373 sample: usize,
374 dynamic_context: impl VectorStoreIndexDyn + 'static,
375 ) -> Self {
376 self.dynamic_context
377 .push((sample, Box::new(dynamic_context)));
378 self
379 }
380
381 /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
382 /// dynamic toolset will be inserted in the request.
383 pub fn dynamic_tools(
384 mut self,
385 sample: usize,
386 dynamic_tools: impl VectorStoreIndexDyn + 'static,
387 toolset: ToolSet,
388 ) -> Self {
389 self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
390 self.tools.add_tools(toolset);
391 self
392 }
393
394 /// Set the temperature of the model
395 pub fn temperature(mut self, temperature: f64) -> Self {
396 self.temperature = Some(temperature);
397 self
398 }
399
400 /// Set the maximum number of tokens for the completion
401 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
402 self.max_tokens = Some(max_tokens);
403 self
404 }
405
406 /// Set additional parameters to be passed to the model
407 pub fn additional_params(mut self, params: serde_json::Value) -> Self {
408 self.additional_params = Some(params);
409 self
410 }
411
412 /// Build the agent
413 pub fn build(self) -> Agent<M> {
414 Agent {
415 model: self.model,
416 preamble: self.preamble.unwrap_or_default(),
417 static_context: self.static_context,
418 static_tools: self.static_tools,
419 temperature: self.temperature,
420 max_tokens: self.max_tokens,
421 additional_params: self.additional_params,
422 dynamic_context: self.dynamic_context,
423 dynamic_tools: self.dynamic_tools,
424 tools: self.tools,
425 }
426 }
427}