rig/agent/
builder.rs

1use std::{collections::HashMap, sync::Arc};
2
3use tokio::sync::RwLock;
4
5use crate::{
6    completion::{CompletionModel, Document},
7    message::ToolChoice,
8    tool::{
9        Tool, ToolSet,
10        server::{ToolServer, ToolServerHandle},
11    },
12    vector_store::VectorStoreIndexDyn,
13};
14
15#[cfg(feature = "rmcp")]
16#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
17use crate::tool::rmcp::McpTool as RmcpTool;
18
19use super::Agent;
20
21/// A builder for creating an agent
22///
23/// # Example
24/// ```
25/// use rig::{providers::openai, agent::AgentBuilder};
26///
27/// let openai = openai::Client::from_env();
28///
29/// let gpt4o = openai.completion_model("gpt-4o");
30///
31/// // Configure the agent
32/// let agent = AgentBuilder::new(model)
33///     .preamble("System prompt")
34///     .context("Context document 1")
35///     .context("Context document 2")
36///     .tool(tool1)
37///     .tool(tool2)
38///     .temperature(0.8)
39///     .additional_params(json!({"foo": "bar"}))
40///     .build();
41/// ```
42pub struct AgentBuilder<M>
43where
44    M: CompletionModel,
45{
46    /// Name of the agent used for logging and debugging
47    name: Option<String>,
48    /// Agent description. Primarily useful when using sub-agents as part of an agent workflow and converting agents to other formats.
49    description: Option<String>,
50    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
51    model: M,
52    /// System prompt
53    preamble: Option<String>,
54    /// Context documents always available to the agent
55    static_context: Vec<Document>,
56    /// Additional parameters to be passed to the model
57    additional_params: Option<serde_json::Value>,
58    /// Maximum number of tokens for the completion
59    max_tokens: Option<u64>,
60    /// List of vector store, with the sample number
61    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
62    /// Temperature of the model
63    temperature: Option<f64>,
64    /// Tool server handle
65    tool_server_handle: Option<ToolServerHandle>,
66    /// Whether or not the underlying LLM should be forced to use a tool before providing a response.
67    tool_choice: Option<ToolChoice>,
68}
69
70impl<M> AgentBuilder<M>
71where
72    M: CompletionModel,
73{
74    pub fn new(model: M) -> Self {
75        Self {
76            name: None,
77            description: None,
78            model,
79            preamble: None,
80            static_context: vec![],
81            temperature: None,
82            max_tokens: None,
83            additional_params: None,
84            dynamic_context: vec![],
85            tool_server_handle: None,
86            tool_choice: None,
87        }
88    }
89
90    /// Set the name of the agent
91    pub fn name(mut self, name: &str) -> Self {
92        self.name = Some(name.into());
93        self
94    }
95
96    /// Set the description of the agent
97    pub fn description(mut self, description: &str) -> Self {
98        self.description = Some(description.into());
99        self
100    }
101
102    /// Set the system prompt
103    pub fn preamble(mut self, preamble: &str) -> Self {
104        self.preamble = Some(preamble.into());
105        self
106    }
107
108    /// Remove the system prompt
109    pub fn without_preamble(mut self) -> Self {
110        self.preamble = None;
111        self
112    }
113
114    /// Append to the preamble of the agent
115    pub fn append_preamble(mut self, doc: &str) -> Self {
116        self.preamble = Some(format!(
117            "{}\n{}",
118            self.preamble.unwrap_or_else(|| "".into()),
119            doc
120        ));
121        self
122    }
123
124    /// Add a static context document to the agent
125    pub fn context(mut self, doc: &str) -> Self {
126        self.static_context.push(Document {
127            id: format!("static_doc_{}", self.static_context.len()),
128            text: doc.into(),
129            additional_props: HashMap::new(),
130        });
131        self
132    }
133
134    /// Add a static tool to the agent
135    pub fn tool(self, tool: impl Tool + 'static) -> AgentBuilderSimple<M> {
136        let toolname = tool.name();
137        let tools = ToolSet::from_tools(vec![tool]);
138        let static_tools = vec![toolname];
139
140        AgentBuilderSimple {
141            name: self.name,
142            description: self.description,
143            model: self.model,
144            preamble: self.preamble,
145            static_context: self.static_context,
146            static_tools,
147            additional_params: self.additional_params,
148            max_tokens: self.max_tokens,
149            dynamic_context: vec![],
150            dynamic_tools: vec![],
151            temperature: self.temperature,
152            tools,
153            tool_choice: self.tool_choice,
154        }
155    }
156
157    pub fn tool_server_handle(mut self, handle: ToolServerHandle) -> Self {
158        self.tool_server_handle = Some(handle);
159        self
160    }
161
162    /// Add an MCP tool (from `rmcp`) to the agent
163    #[cfg(feature = "rmcp")]
164    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
165    pub fn rmcp_tool(
166        self,
167        tool: rmcp::model::Tool,
168        client: rmcp::service::ServerSink,
169    ) -> AgentBuilderSimple<M> {
170        let toolname = tool.name.clone().to_string();
171        let tools = ToolSet::from_tools(vec![RmcpTool::from_mcp_server(tool, client)]);
172        let static_tools = vec![toolname];
173
174        AgentBuilderSimple {
175            name: self.name,
176            description: self.description,
177            model: self.model,
178            preamble: self.preamble,
179            static_context: self.static_context,
180            static_tools,
181            additional_params: self.additional_params,
182            max_tokens: self.max_tokens,
183            dynamic_context: vec![],
184            dynamic_tools: vec![],
185            temperature: self.temperature,
186            tools,
187            tool_choice: self.tool_choice,
188        }
189    }
190
191    /// Add an array of MCP tools (from `rmcp`) to the agent
192    #[cfg(feature = "rmcp")]
193    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
194    pub fn rmcp_tools(
195        self,
196        tools: Vec<rmcp::model::Tool>,
197        client: rmcp::service::ServerSink,
198    ) -> AgentBuilderSimple<M> {
199        let (static_tools, tools) = tools.into_iter().fold(
200            (Vec::new(), Vec::new()),
201            |(mut toolnames, mut toolset), tool| {
202                let tool_name = tool.name.to_string();
203                let tool = RmcpTool::from_mcp_server(tool, client.clone());
204                toolnames.push(tool_name);
205                toolset.push(tool);
206                (toolnames, toolset)
207            },
208        );
209
210        let tools = ToolSet::from_tools(tools);
211
212        AgentBuilderSimple {
213            name: self.name,
214            description: self.description,
215            model: self.model,
216            preamble: self.preamble,
217            static_context: self.static_context,
218            static_tools,
219            additional_params: self.additional_params,
220            max_tokens: self.max_tokens,
221            dynamic_context: vec![],
222            dynamic_tools: vec![],
223            temperature: self.temperature,
224            tools,
225            tool_choice: self.tool_choice,
226        }
227    }
228
229    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
230    /// dynamic context will be inserted in the request.
231    pub fn dynamic_context(
232        mut self,
233        sample: usize,
234        dynamic_context: impl VectorStoreIndexDyn + 'static,
235    ) -> Self {
236        self.dynamic_context
237            .push((sample, Box::new(dynamic_context)));
238        self
239    }
240
241    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
242        self.tool_choice = Some(tool_choice);
243        self
244    }
245
246    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
247    /// dynamic toolset will be inserted in the request.
248    pub fn dynamic_tools(
249        self,
250        sample: usize,
251        dynamic_tools: impl VectorStoreIndexDyn + 'static,
252        toolset: ToolSet,
253    ) -> AgentBuilderSimple<M> {
254        let thing: Box<dyn VectorStoreIndexDyn + 'static> = Box::new(dynamic_tools);
255        let dynamic_tools = vec![(sample, thing)];
256
257        AgentBuilderSimple {
258            name: self.name,
259            description: self.description,
260            model: self.model,
261            preamble: self.preamble,
262            static_context: self.static_context,
263            static_tools: vec![],
264            additional_params: self.additional_params,
265            max_tokens: self.max_tokens,
266            dynamic_context: vec![],
267            dynamic_tools,
268            temperature: self.temperature,
269            tools: toolset,
270            tool_choice: self.tool_choice,
271        }
272    }
273
274    /// Set the temperature of the model
275    pub fn temperature(mut self, temperature: f64) -> Self {
276        self.temperature = Some(temperature);
277        self
278    }
279
280    /// Set the maximum number of tokens for the completion
281    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
282        self.max_tokens = Some(max_tokens);
283        self
284    }
285
286    /// Set additional parameters to be passed to the model
287    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
288        self.additional_params = Some(params);
289        self
290    }
291
292    /// Build the agent
293    pub fn build(self) -> Agent<M> {
294        let tool_server_handle = if let Some(handle) = self.tool_server_handle {
295            handle
296        } else {
297            ToolServer::new().run()
298        };
299
300        Agent {
301            name: self.name,
302            description: self.description,
303            model: Arc::new(self.model),
304            preamble: self.preamble,
305            static_context: self.static_context,
306            temperature: self.temperature,
307            max_tokens: self.max_tokens,
308            additional_params: self.additional_params,
309            tool_choice: self.tool_choice,
310            dynamic_context: Arc::new(RwLock::new(self.dynamic_context)),
311            tool_server_handle,
312        }
313    }
314}
315
316/// A fluent builder variation of `AgentBuilder`. Allows adding tools directly to the builder rather than using the tool server handle.
317///
318/// # Example
319/// ```
320/// use rig::{providers::openai, agent::AgentBuilder};
321///
322/// let openai = openai::Client::from_env();
323///
324/// let gpt4o = openai.completion_model("gpt-4o");
325///
326/// // Configure the agent
327/// let agent = AgentBuilder::new(model)
328///     .preamble("System prompt")
329///     .context("Context document 1")
330///     .context("Context document 2")
331///     .tool(tool1)
332///     .tool(tool2)
333///     .temperature(0.8)
334///     .additional_params(json!({"foo": "bar"}))
335///     .build();
336/// ```
337pub struct AgentBuilderSimple<M>
338where
339    M: CompletionModel,
340{
341    /// Name of the agent used for logging and debugging
342    name: Option<String>,
343    /// Agent description. Primarily useful when using sub-agents as part of an agent workflow and converting agents to other formats.
344    description: Option<String>,
345    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
346    model: M,
347    /// System prompt
348    preamble: Option<String>,
349    /// Context documents always available to the agent
350    static_context: Vec<Document>,
351    /// Tools that are always available to the agent (by name)
352    static_tools: Vec<String>,
353    /// Additional parameters to be passed to the model
354    additional_params: Option<serde_json::Value>,
355    /// Maximum number of tokens for the completion
356    max_tokens: Option<u64>,
357    /// List of vector store, with the sample number
358    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
359    /// Dynamic tools
360    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
361    /// Temperature of the model
362    temperature: Option<f64>,
363    /// Actual tool implementations
364    tools: ToolSet,
365    /// Whether or not the underlying LLM should be forced to use a tool before providing a response.
366    tool_choice: Option<ToolChoice>,
367}
368
369impl<M> AgentBuilderSimple<M>
370where
371    M: CompletionModel,
372{
373    pub fn new(model: M) -> Self {
374        Self {
375            name: None,
376            description: None,
377            model,
378            preamble: None,
379            static_context: vec![],
380            static_tools: vec![],
381            temperature: None,
382            max_tokens: None,
383            additional_params: None,
384            dynamic_context: vec![],
385            dynamic_tools: vec![],
386            tools: ToolSet::default(),
387            tool_choice: None,
388        }
389    }
390
391    /// Set the name of the agent
392    pub fn name(mut self, name: &str) -> Self {
393        self.name = Some(name.into());
394        self
395    }
396
397    /// Set the description of the agent
398    pub fn description(mut self, description: &str) -> Self {
399        self.description = Some(description.into());
400        self
401    }
402
403    /// Set the system prompt
404    pub fn preamble(mut self, preamble: &str) -> Self {
405        self.preamble = Some(preamble.into());
406        self
407    }
408
409    /// Remove the system prompt
410    pub fn without_preamble(mut self) -> Self {
411        self.preamble = None;
412        self
413    }
414
415    /// Append to the preamble of the agent
416    pub fn append_preamble(mut self, doc: &str) -> Self {
417        self.preamble = Some(format!(
418            "{}\n{}",
419            self.preamble.unwrap_or_else(|| "".into()),
420            doc
421        ));
422        self
423    }
424
425    /// Add a static context document to the agent
426    pub fn context(mut self, doc: &str) -> Self {
427        self.static_context.push(Document {
428            id: format!("static_doc_{}", self.static_context.len()),
429            text: doc.into(),
430            additional_props: HashMap::new(),
431        });
432        self
433    }
434
435    /// Add a static tool to the agent
436    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
437        let toolname = tool.name();
438        self.tools.add_tool(tool);
439        self.static_tools.push(toolname);
440        self
441    }
442
443    /// Add an array of MCP tools (from `rmcp`) to the agent
444    #[cfg(feature = "rmcp")]
445    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
446    pub fn rmcp_tools(
447        mut self,
448        tools: Vec<rmcp::model::Tool>,
449        client: rmcp::service::ServerSink,
450    ) -> Self {
451        for tool in tools {
452            let tool_name = tool.name.to_string();
453            let tool = RmcpTool::from_mcp_server(tool, client.clone());
454            self.static_tools.push(tool_name);
455            self.tools.add_tool(tool);
456        }
457
458        self
459    }
460
461    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
462    /// dynamic context will be inserted in the request.
463    pub fn dynamic_context(
464        mut self,
465        sample: usize,
466        dynamic_context: impl VectorStoreIndexDyn + 'static,
467    ) -> Self {
468        self.dynamic_context
469            .push((sample, Box::new(dynamic_context)));
470        self
471    }
472
473    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
474        self.tool_choice = Some(tool_choice);
475        self
476    }
477
478    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
479    /// dynamic toolset will be inserted in the request.
480    pub fn dynamic_tools(
481        mut self,
482        sample: usize,
483        dynamic_tools: impl VectorStoreIndexDyn + 'static,
484        toolset: ToolSet,
485    ) -> Self {
486        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
487        self.tools.add_tools(toolset);
488        self
489    }
490
491    /// Set the temperature of the model
492    pub fn temperature(mut self, temperature: f64) -> Self {
493        self.temperature = Some(temperature);
494        self
495    }
496
497    /// Set the maximum number of tokens for the completion
498    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
499        self.max_tokens = Some(max_tokens);
500        self
501    }
502
503    /// Set additional parameters to be passed to the model
504    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
505        self.additional_params = Some(params);
506        self
507    }
508
509    /// Build the agent
510    pub fn build(self) -> Agent<M> {
511        let tool_server_handle = ToolServer::new()
512            .static_tool_names(self.static_tools)
513            .add_tools(self.tools)
514            .add_dynamic_tools(self.dynamic_tools)
515            .run();
516
517        Agent {
518            name: self.name,
519            description: self.description,
520            model: Arc::new(self.model),
521            preamble: self.preamble,
522            static_context: self.static_context,
523            temperature: self.temperature,
524            max_tokens: self.max_tokens,
525            additional_params: self.additional_params,
526            tool_choice: self.tool_choice,
527            dynamic_context: Arc::new(RwLock::new(self.dynamic_context)),
528            tool_server_handle,
529        }
530    }
531}