rig/agent.rs
1//! This module contains the implementation of the [Agent] struct and its builder.
2//!
3//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
4//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
5//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
6//!
7//! The [Agent] struct is highly configurable, allowing the user to define anything from
8//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
9//! context documents and tools.
10//!
11//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating
//! completion responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to
13//! be used for generating chat completions.
14//!
15//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
16//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters
17//! before building the agent.
18//!
19//! # Example
20//! ```rust
21//! use rig::{
22//! completion::{Chat, Completion, Prompt},
23//! providers::openai,
24//! };
25//!
26//! let openai = openai::Client::from_env();
27//!
28//! // Configure the agent
29//! let agent = openai.agent("gpt-4o")
30//! .preamble("System prompt")
31//! .context("Context document 1")
32//! .context("Context document 2")
33//! .tool(tool1)
34//! .tool(tool2)
35//! .temperature(0.8)
36//! .additional_params(json!({"foo": "bar"}))
37//! .build();
38//!
39//! // Use the agent for completions and prompts
40//! // Generate a chat completion response from a prompt and chat history
41//! let chat_response = agent.chat("Prompt", chat_history)
42//! .await
43//! .expect("Failed to chat with Agent");
44//!
45//! // Generate a prompt completion response from a simple prompt
46//! let chat_response = agent.prompt("Prompt")
47//! .await
48//! .expect("Failed to prompt the Agent");
49//!
50//! // Generate a completion request builder from a prompt and chat history. The builder
51//! // will contain the agent's configuration (i.e.: preamble, context documents, tools,
52//! // model parameters, etc.), but these can be overwritten.
53//! let completion_req_builder = agent.completion("Prompt", chat_history)
54//! .await
55//! .expect("Failed to create completion request builder");
56//!
57//! let response = completion_req_builder
58//! .temperature(0.9) // Overwrite the agent's temperature
59//! .send()
60//! .await
61//! .expect("Failed to send completion request");
62//! ```
63//!
64//! RAG Agent example
65//! ```rust
66//! use rig::{
67//! completion::Prompt,
68//! embeddings::EmbeddingsBuilder,
69//! providers::openai,
70//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
71//! };
72//!
73//! // Initialize OpenAI client
74//! let openai = openai::Client::from_env();
75//!
76//! // Initialize OpenAI embedding model
77//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
78//!
79//! // Create vector store, compute embeddings and load them in the store
80//! let mut vector_store = InMemoryVectorStore::default();
81//!
82//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
83//! .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
//! .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
85//! .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
86//! .build()
87//! .await
88//! .expect("Failed to build embeddings");
89//!
90//! vector_store.add_documents(embeddings)
91//! .await
92//! .expect("Failed to add documents");
93//!
94//! // Create vector store index
95//! let index = vector_store.index(embedding_model);
96//!
97//! let agent = openai.agent(openai::GPT_4O)
98//! .preamble("
99//! You are a dictionary assistant here to assist the user in understanding the meaning of words.
100//! You will find additional non-standard word definitions that could be useful below.
101//! ")
102//! .dynamic_context(1, index)
103//! .build();
104//!
105//! // Prompt the agent and print the response
106//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
107//! .expect("Failed to prompt the agent");
108//! ```
109use std::collections::HashMap;
110
111use futures::{stream, StreamExt, TryStreamExt};
112
113use crate::{
114 completion::{
115 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
116 Message, Prompt, PromptError,
117 },
118 message::AssistantContent,
119 streaming::{
120 StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
121 StreamingResult,
122 },
123 tool::{Tool, ToolSet},
124 vector_store::{VectorStoreError, VectorStoreIndexDyn},
125};
126
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All static context documents and tools are always provided to the agent when prompted;
/// dynamic context documents and tools are retrieved from vector store indices at prompt time.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
/// .agent("gpt-4o")
/// .preamble("You are a comedian here to entertain the user using humour and jokes.")
/// .temperature(0.9)
/// .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
/// .await
/// .expect("Failed to prompt the agent");
/// ```
pub struct Agent<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    model: M,
    /// System prompt sent with every completion request
    preamble: String,
    /// Context documents always available to the agent
    static_context: Vec<Document>,
    /// Tools that are always available to the agent (identified by their name);
    /// the implementations live in `tools` below
    static_tools: Vec<String>,
    /// Temperature of the model (None = provider default)
    temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    additional_params: Option<serde_json::Value>,
    /// Vector store indices RAGged for context documents at prompt time,
    /// each paired with the number of documents to sample per prompt
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Vector store indices RAGged for tools at prompt time,
    /// each paired with the number of tools to sample per prompt
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Actual tool implementations (both static and dynamic)
    pub tools: ToolSet,
}
169
170impl<M: CompletionModel> Completion<M> for Agent<M> {
171 async fn completion(
172 &self,
173 prompt: impl Into<Message> + Send,
174 chat_history: Vec<Message>,
175 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
176 let prompt = prompt.into();
177 let rag_text = prompt.rag_text().clone();
178
179 let completion_request = self
180 .model
181 .completion_request(prompt)
182 .preamble(self.preamble.clone())
183 .messages(chat_history)
184 .temperature_opt(self.temperature)
185 .max_tokens_opt(self.max_tokens)
186 .additional_params_opt(self.additional_params.clone())
187 .documents(self.static_context.clone());
188
189 let agent = match &rag_text {
190 Some(text) => {
191 let dynamic_context = stream::iter(self.dynamic_context.iter())
192 .then(|(num_sample, index)| async {
193 Ok::<_, VectorStoreError>(
194 index
195 .top_n(text, *num_sample)
196 .await?
197 .into_iter()
198 .map(|(_, id, doc)| {
199 // Pretty print the document if possible for better readability
200 let text = serde_json::to_string_pretty(&doc)
201 .unwrap_or_else(|_| doc.to_string());
202
203 Document {
204 id,
205 text,
206 additional_props: HashMap::new(),
207 }
208 })
209 .collect::<Vec<_>>(),
210 )
211 })
212 .try_fold(vec![], |mut acc, docs| async {
213 acc.extend(docs);
214 Ok(acc)
215 })
216 .await
217 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
218
219 let dynamic_tools = stream::iter(self.dynamic_tools.iter())
220 .then(|(num_sample, index)| async {
221 Ok::<_, VectorStoreError>(
222 index
223 .top_n_ids(text, *num_sample)
224 .await?
225 .into_iter()
226 .map(|(_, id)| id)
227 .collect::<Vec<_>>(),
228 )
229 })
230 .try_fold(vec![], |mut acc, docs| async {
231 for doc in docs {
232 if let Some(tool) = self.tools.get(&doc) {
233 acc.push(tool.definition(text.into()).await)
234 } else {
235 tracing::warn!("Tool implementation not found in toolset: {}", doc);
236 }
237 }
238 Ok(acc)
239 })
240 .await
241 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
242
243 let static_tools = stream::iter(self.static_tools.iter())
244 .filter_map(|toolname| async move {
245 if let Some(tool) = self.tools.get(toolname) {
246 Some(tool.definition(text.into()).await)
247 } else {
248 tracing::warn!(
249 "Tool implementation not found in toolset: {}",
250 toolname
251 );
252 None
253 }
254 })
255 .collect::<Vec<_>>()
256 .await;
257
258 completion_request
259 .documents(dynamic_context)
260 .tools([static_tools.clone(), dynamic_tools].concat())
261 }
262 None => {
263 let static_tools = stream::iter(self.static_tools.iter())
264 .filter_map(|toolname| async move {
265 if let Some(tool) = self.tools.get(toolname) {
266 // TODO: tool definitions should likely take an `Option<String>`
267 Some(tool.definition("".into()).await)
268 } else {
269 tracing::warn!(
270 "Tool implementation not found in toolset: {}",
271 toolname
272 );
273 None
274 }
275 })
276 .collect::<Vec<_>>()
277 .await;
278
279 completion_request.tools(static_tools)
280 }
281 };
282
283 Ok(agent)
284 }
285}
286
287impl<M: CompletionModel> Prompt for Agent<M> {
288 async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
289 self.chat(prompt, vec![]).await
290 }
291}
292
293impl<M: CompletionModel> Prompt for &Agent<M> {
294 async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
295 self.chat(prompt, vec![]).await
296 }
297}
298
299impl<M: CompletionModel> Chat for Agent<M> {
300 async fn chat(
301 &self,
302 prompt: impl Into<Message> + Send,
303 chat_history: Vec<Message>,
304 ) -> Result<String, PromptError> {
305 let resp = self.completion(prompt, chat_history).await?.send().await?;
306
307 // TODO: consider returning a `Message` instead of `String` for parallel responses / tool calls
308 match resp.choice.first() {
309 AssistantContent::Text(text) => Ok(text.text.clone()),
310 AssistantContent::ToolCall(tool_call) => Ok(self
311 .tools
312 .call(
313 &tool_call.function.name,
314 tool_call.function.arguments.to_string(),
315 )
316 .await?),
317 }
318 }
319}
320
/// A builder for creating an agent
///
/// # Example
/// ```
/// use rig::{providers::openai, agent::AgentBuilder};
///
/// let openai = openai::Client::from_env();
///
/// let gpt4o = openai.completion_model("gpt-4o");
///
/// // Configure the agent
/// let agent = AgentBuilder::new(gpt4o)
/// .preamble("System prompt")
/// .context("Context document 1")
/// .context("Context document 2")
/// .tool(tool1)
/// .tool(tool2)
/// .temperature(0.8)
/// .additional_params(json!({"foo": "bar"}))
/// .build();
/// ```
pub struct AgentBuilder<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    model: M,
    /// System prompt (None until set via `preamble`/`append_preamble`)
    preamble: Option<String>,
    /// Context documents always available to the agent
    static_context: Vec<Document>,
    /// Tools that are always available to the agent (by name)
    static_tools: Vec<String>,
    /// Additional parameters to be passed to the model
    additional_params: Option<serde_json::Value>,
    /// Maximum number of tokens for the completion
    max_tokens: Option<u64>,
    /// Vector store indices for dynamic context, each with its sample count
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Vector store indices for dynamic tools, each with its sample count
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Temperature of the model
    temperature: Option<f64>,
    /// Actual tool implementations
    tools: ToolSet,
}
364
365impl<M: CompletionModel> AgentBuilder<M> {
366 pub fn new(model: M) -> Self {
367 Self {
368 model,
369 preamble: None,
370 static_context: vec![],
371 static_tools: vec![],
372 temperature: None,
373 max_tokens: None,
374 additional_params: None,
375 dynamic_context: vec![],
376 dynamic_tools: vec![],
377 tools: ToolSet::default(),
378 }
379 }
380
381 /// Set the system prompt
382 pub fn preamble(mut self, preamble: &str) -> Self {
383 self.preamble = Some(preamble.into());
384 self
385 }
386
387 /// Append to the preamble of the agent
388 pub fn append_preamble(mut self, doc: &str) -> Self {
389 self.preamble = Some(format!(
390 "{}\n{}",
391 self.preamble.unwrap_or_else(|| "".into()),
392 doc
393 ));
394 self
395 }
396
397 /// Add a static context document to the agent
398 pub fn context(mut self, doc: &str) -> Self {
399 self.static_context.push(Document {
400 id: format!("static_doc_{}", self.static_context.len()),
401 text: doc.into(),
402 additional_props: HashMap::new(),
403 });
404 self
405 }
406
407 /// Add a static tool to the agent
408 pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
409 let toolname = tool.name();
410 self.tools.add_tool(tool);
411 self.static_tools.push(toolname);
412 self
413 }
414
415 /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
416 /// dynamic context will be inserted in the request.
417 pub fn dynamic_context(
418 mut self,
419 sample: usize,
420 dynamic_context: impl VectorStoreIndexDyn + 'static,
421 ) -> Self {
422 self.dynamic_context
423 .push((sample, Box::new(dynamic_context)));
424 self
425 }
426
427 /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
428 /// dynamic toolset will be inserted in the request.
429 pub fn dynamic_tools(
430 mut self,
431 sample: usize,
432 dynamic_tools: impl VectorStoreIndexDyn + 'static,
433 toolset: ToolSet,
434 ) -> Self {
435 self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
436 self.tools.add_tools(toolset);
437 self
438 }
439
440 /// Set the temperature of the model
441 pub fn temperature(mut self, temperature: f64) -> Self {
442 self.temperature = Some(temperature);
443 self
444 }
445
446 /// Set the maximum number of tokens for the completion
447 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
448 self.max_tokens = Some(max_tokens);
449 self
450 }
451
452 /// Set additional parameters to be passed to the model
453 pub fn additional_params(mut self, params: serde_json::Value) -> Self {
454 self.additional_params = Some(params);
455 self
456 }
457
458 /// Build the agent
459 pub fn build(self) -> Agent<M> {
460 Agent {
461 model: self.model,
462 preamble: self.preamble.unwrap_or_default(),
463 static_context: self.static_context,
464 static_tools: self.static_tools,
465 temperature: self.temperature,
466 max_tokens: self.max_tokens,
467 additional_params: self.additional_params,
468 dynamic_context: self.dynamic_context,
469 dynamic_tools: self.dynamic_tools,
470 tools: self.tools,
471 }
472 }
473}
474
impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
    /// Build a streaming completion request with the agent's full configuration.
    ///
    /// Delegates to [Completion::completion] so streaming and non-streaming
    /// requests are built with identical logic (preamble, context, tools,
    /// model parameters, and dynamic RAG lookups).
    async fn stream_completion(
        &self,
        prompt: &str,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        // Reuse the existing completion implementation to build the request
        // This ensures streaming and non-streaming use the same request building logic
        self.completion(prompt, chat_history).await
    }
}
486
487impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
488 async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
489 self.stream_chat(prompt, vec![]).await
490 }
491}
492
493impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
494 async fn stream_chat(
495 &self,
496 prompt: &str,
497 chat_history: Vec<Message>,
498 ) -> Result<StreamingResult, CompletionError> {
499 self.stream_completion(prompt, chat_history)
500 .await?
501 .stream()
502 .await
503 }
504}