rig/
agent.rs

1//! This module contains the implementation of the [Agent] struct and its builder.
2//!
3//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
4//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
5//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
6//!
7//! The [Agent] struct is highly configurable, allowing the user to define anything from
8//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
9//! context documents and tools.
10//!
//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating
//! completion responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to
13//! be used for generating chat completions.
14//!
15//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
16//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters
17//! before building the agent.
18//!
19//! # Example
20//! ```rust
21//! use rig::{
22//!     completion::{Chat, Completion, Prompt},
23//!     providers::openai,
24//! };
25//!
26//! let openai = openai::Client::from_env();
27//!
28//! // Configure the agent
29//! let agent = openai.agent("gpt-4o")
30//!     .preamble("System prompt")
31//!     .context("Context document 1")
32//!     .context("Context document 2")
33//!     .tool(tool1)
34//!     .tool(tool2)
35//!     .temperature(0.8)
36//!     .additional_params(json!({"foo": "bar"}))
37//!     .build();
38//!
39//! // Use the agent for completions and prompts
40//! // Generate a chat completion response from a prompt and chat history
41//! let chat_response = agent.chat("Prompt", chat_history)
42//!     .await
43//!     .expect("Failed to chat with Agent");
44//!
45//! // Generate a prompt completion response from a simple prompt
46//! let chat_response = agent.prompt("Prompt")
47//!     .await
48//!     .expect("Failed to prompt the Agent");
49//!
50//! // Generate a completion request builder from a prompt and chat history. The builder
51//! // will contain the agent's configuration (i.e.: preamble, context documents, tools,
52//! // model parameters, etc.), but these can be overwritten.
53//! let completion_req_builder = agent.completion("Prompt", chat_history)
54//!     .await
55//!     .expect("Failed to create completion request builder");
56//!
57//! let response = completion_req_builder
58//!     .temperature(0.9) // Overwrite the agent's temperature
59//!     .send()
60//!     .await
61//!     .expect("Failed to send completion request");
62//! ```
63//!
64//! RAG Agent example
65//! ```rust
66//! use rig::{
67//!     completion::Prompt,
68//!     embeddings::EmbeddingsBuilder,
69//!     providers::openai,
70//!     vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
71//! };
72//!
73//! // Initialize OpenAI client
74//! let openai = openai::Client::from_env();
75//!
76//! // Initialize OpenAI embedding model
77//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
78//!
79//! // Create vector store, compute embeddings and load them in the store
80//! let mut vector_store = InMemoryVectorStore::default();
81//!
82//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
83//!     .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
//!     .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
85//!     .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
86//!     .build()
87//!     .await
88//!     .expect("Failed to build embeddings");
89//!
90//! vector_store.add_documents(embeddings)
91//!     .await
92//!     .expect("Failed to add documents");
93//!
94//! // Create vector store index
95//! let index = vector_store.index(embedding_model);
96//!
97//! let agent = openai.agent(openai::GPT_4O)
98//!     .preamble("
99//!         You are a dictionary assistant here to assist the user in understanding the meaning of words.
100//!         You will find additional non-standard word definitions that could be useful below.
101//!     ")
102//!     .dynamic_context(1, index)
103//!     .build();
104//!
105//! // Prompt the agent and print the response
106//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
107//!     .expect("Failed to prompt the agent");
108//! ```
109use std::collections::HashMap;
110
111use futures::{stream, StreamExt, TryStreamExt};
112
113use crate::{
114    completion::{
115        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
116        Message, Prompt, PromptError,
117    },
118    message::AssistantContent,
119    streaming::{
120        StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
121        StreamingResult,
122    },
123    tool::{Tool, ToolSet},
124    vector_store::{VectorStoreError, VectorStoreIndexDyn},
125};
126
127/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
128/// (i.e.: system prompt) and a static set of context documents and tools.
129/// All context documents and tools are always provided to the agent when prompted.
130///
131/// # Example
132/// ```
133/// use rig::{completion::Prompt, providers::openai};
134///
135/// let openai = openai::Client::from_env();
136///
137/// let comedian_agent = openai
138///     .agent("gpt-4o")
139///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
140///     .temperature(0.9)
141///     .build();
142///
143/// let response = comedian_agent.prompt("Entertain me!")
144///     .await
145///     .expect("Failed to prompt the agent");
146/// ```
pub struct Agent<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    model: M,
    /// System prompt sent with every completion request
    preamble: String,
    /// Context documents always available to the agent
    static_context: Vec<Document>,
    /// Tools that are always available to the agent (identified by their name)
    static_tools: Vec<String>,
    /// Temperature of the model
    temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    additional_params: Option<serde_json::Value>,
    /// Vector store indexes queried at prompt time for dynamic context documents,
    /// each paired with the number of documents to sample per prompt
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Vector store indexes queried at prompt time to select dynamic tools,
    /// each paired with the number of tools to sample per prompt
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Actual tool implementations (both static and dynamic), looked up by name
    pub tools: ToolSet,
}
169
170impl<M: CompletionModel> Completion<M> for Agent<M> {
171    async fn completion(
172        &self,
173        prompt: impl Into<Message> + Send,
174        chat_history: Vec<Message>,
175    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
176        let prompt = prompt.into();
177        let rag_text = prompt.rag_text().clone();
178
179        let completion_request = self
180            .model
181            .completion_request(prompt)
182            .preamble(self.preamble.clone())
183            .messages(chat_history)
184            .temperature_opt(self.temperature)
185            .max_tokens_opt(self.max_tokens)
186            .additional_params_opt(self.additional_params.clone())
187            .documents(self.static_context.clone());
188
189        let agent = match &rag_text {
190            Some(text) => {
191                let dynamic_context = stream::iter(self.dynamic_context.iter())
192                    .then(|(num_sample, index)| async {
193                        Ok::<_, VectorStoreError>(
194                            index
195                                .top_n(text, *num_sample)
196                                .await?
197                                .into_iter()
198                                .map(|(_, id, doc)| {
199                                    // Pretty print the document if possible for better readability
200                                    let text = serde_json::to_string_pretty(&doc)
201                                        .unwrap_or_else(|_| doc.to_string());
202
203                                    Document {
204                                        id,
205                                        text,
206                                        additional_props: HashMap::new(),
207                                    }
208                                })
209                                .collect::<Vec<_>>(),
210                        )
211                    })
212                    .try_fold(vec![], |mut acc, docs| async {
213                        acc.extend(docs);
214                        Ok(acc)
215                    })
216                    .await
217                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
218
219                let dynamic_tools = stream::iter(self.dynamic_tools.iter())
220                    .then(|(num_sample, index)| async {
221                        Ok::<_, VectorStoreError>(
222                            index
223                                .top_n_ids(text, *num_sample)
224                                .await?
225                                .into_iter()
226                                .map(|(_, id)| id)
227                                .collect::<Vec<_>>(),
228                        )
229                    })
230                    .try_fold(vec![], |mut acc, docs| async {
231                        for doc in docs {
232                            if let Some(tool) = self.tools.get(&doc) {
233                                acc.push(tool.definition(text.into()).await)
234                            } else {
235                                tracing::warn!("Tool implementation not found in toolset: {}", doc);
236                            }
237                        }
238                        Ok(acc)
239                    })
240                    .await
241                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
242
243                let static_tools = stream::iter(self.static_tools.iter())
244                    .filter_map(|toolname| async move {
245                        if let Some(tool) = self.tools.get(toolname) {
246                            Some(tool.definition(text.into()).await)
247                        } else {
248                            tracing::warn!(
249                                "Tool implementation not found in toolset: {}",
250                                toolname
251                            );
252                            None
253                        }
254                    })
255                    .collect::<Vec<_>>()
256                    .await;
257
258                completion_request
259                    .documents(dynamic_context)
260                    .tools([static_tools.clone(), dynamic_tools].concat())
261            }
262            None => {
263                let static_tools = stream::iter(self.static_tools.iter())
264                    .filter_map(|toolname| async move {
265                        if let Some(tool) = self.tools.get(toolname) {
266                            // TODO: tool definitions should likely take an `Option<String>`
267                            Some(tool.definition("".into()).await)
268                        } else {
269                            tracing::warn!(
270                                "Tool implementation not found in toolset: {}",
271                                toolname
272                            );
273                            None
274                        }
275                    })
276                    .collect::<Vec<_>>()
277                    .await;
278
279                completion_request.tools(static_tools)
280            }
281        };
282
283        Ok(agent)
284    }
285}
286
287impl<M: CompletionModel> Prompt for Agent<M> {
288    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
289        self.chat(prompt, vec![]).await
290    }
291}
292
293impl<M: CompletionModel> Prompt for &Agent<M> {
294    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
295        self.chat(prompt, vec![]).await
296    }
297}
298
299impl<M: CompletionModel> Chat for Agent<M> {
300    async fn chat(
301        &self,
302        prompt: impl Into<Message> + Send,
303        chat_history: Vec<Message>,
304    ) -> Result<String, PromptError> {
305        let resp = self.completion(prompt, chat_history).await?.send().await?;
306
307        // TODO: consider returning a `Message` instead of `String` for parallel responses / tool calls
308        match resp.choice.first() {
309            AssistantContent::Text(text) => Ok(text.text.clone()),
310            AssistantContent::ToolCall(tool_call) => Ok(self
311                .tools
312                .call(
313                    &tool_call.function.name,
314                    tool_call.function.arguments.to_string(),
315                )
316                .await?),
317        }
318    }
319}
320
321/// A builder for creating an agent
322///
323/// # Example
324/// ```
325/// use rig::{providers::openai, agent::AgentBuilder};
326///
327/// let openai = openai::Client::from_env();
328///
329/// let gpt4o = openai.completion_model("gpt-4o");
330///
331/// // Configure the agent
332/// let agent = AgentBuilder::new(model)
333///     .preamble("System prompt")
334///     .context("Context document 1")
335///     .context("Context document 2")
336///     .tool(tool1)
337///     .tool(tool2)
338///     .temperature(0.8)
339///     .additional_params(json!({"foo": "bar"}))
340///     .build();
341/// ```
pub struct AgentBuilder<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    model: M,
    /// System prompt; `None` until set, and defaults to empty on build
    preamble: Option<String>,
    /// Context documents always available to the agent
    static_context: Vec<Document>,
    /// Tools that are always available to the agent (by name)
    static_tools: Vec<String>,
    /// Additional parameters to be passed to the model
    additional_params: Option<serde_json::Value>,
    /// Maximum number of tokens for the completion
    max_tokens: Option<u64>,
    /// Vector store indexes for dynamic context documents, each paired with
    /// the number of documents to sample per prompt
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Vector store indexes for dynamic tools, each paired with the number
    /// of tools to sample per prompt
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Temperature of the model
    temperature: Option<f64>,
    /// Actual tool implementations (both static and dynamic), looked up by name
    tools: ToolSet,
}
364
365impl<M: CompletionModel> AgentBuilder<M> {
366    pub fn new(model: M) -> Self {
367        Self {
368            model,
369            preamble: None,
370            static_context: vec![],
371            static_tools: vec![],
372            temperature: None,
373            max_tokens: None,
374            additional_params: None,
375            dynamic_context: vec![],
376            dynamic_tools: vec![],
377            tools: ToolSet::default(),
378        }
379    }
380
381    /// Set the system prompt
382    pub fn preamble(mut self, preamble: &str) -> Self {
383        self.preamble = Some(preamble.into());
384        self
385    }
386
387    /// Append to the preamble of the agent
388    pub fn append_preamble(mut self, doc: &str) -> Self {
389        self.preamble = Some(format!(
390            "{}\n{}",
391            self.preamble.unwrap_or_else(|| "".into()),
392            doc
393        ));
394        self
395    }
396
397    /// Add a static context document to the agent
398    pub fn context(mut self, doc: &str) -> Self {
399        self.static_context.push(Document {
400            id: format!("static_doc_{}", self.static_context.len()),
401            text: doc.into(),
402            additional_props: HashMap::new(),
403        });
404        self
405    }
406
407    /// Add a static tool to the agent
408    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
409        let toolname = tool.name();
410        self.tools.add_tool(tool);
411        self.static_tools.push(toolname);
412        self
413    }
414
415    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
416    /// dynamic context will be inserted in the request.
417    pub fn dynamic_context(
418        mut self,
419        sample: usize,
420        dynamic_context: impl VectorStoreIndexDyn + 'static,
421    ) -> Self {
422        self.dynamic_context
423            .push((sample, Box::new(dynamic_context)));
424        self
425    }
426
427    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
428    /// dynamic toolset will be inserted in the request.
429    pub fn dynamic_tools(
430        mut self,
431        sample: usize,
432        dynamic_tools: impl VectorStoreIndexDyn + 'static,
433        toolset: ToolSet,
434    ) -> Self {
435        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
436        self.tools.add_tools(toolset);
437        self
438    }
439
440    /// Set the temperature of the model
441    pub fn temperature(mut self, temperature: f64) -> Self {
442        self.temperature = Some(temperature);
443        self
444    }
445
446    /// Set the maximum number of tokens for the completion
447    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
448        self.max_tokens = Some(max_tokens);
449        self
450    }
451
452    /// Set additional parameters to be passed to the model
453    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
454        self.additional_params = Some(params);
455        self
456    }
457
458    /// Build the agent
459    pub fn build(self) -> Agent<M> {
460        Agent {
461            model: self.model,
462            preamble: self.preamble.unwrap_or_default(),
463            static_context: self.static_context,
464            static_tools: self.static_tools,
465            temperature: self.temperature,
466            max_tokens: self.max_tokens,
467            additional_params: self.additional_params,
468            dynamic_context: self.dynamic_context,
469            dynamic_tools: self.dynamic_tools,
470            tools: self.tools,
471        }
472    }
473}
474
475impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
476    async fn stream_completion(
477        &self,
478        prompt: &str,
479        chat_history: Vec<Message>,
480    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
481        // Reuse the existing completion implementation to build the request
482        // This ensures streaming and non-streaming use the same request building logic
483        self.completion(prompt, chat_history).await
484    }
485}
486
487impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
488    async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
489        self.stream_chat(prompt, vec![]).await
490    }
491}
492
493impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
494    async fn stream_chat(
495        &self,
496        prompt: &str,
497        chat_history: Vec<Message>,
498    ) -> Result<StreamingResult, CompletionError> {
499        self.stream_completion(prompt, chat_history)
500            .await?
501            .stream()
502            .await
503    }
504}