rig/agent/
builder.rs

1use std::{collections::HashMap, sync::Arc};
2
3use tokio::sync::RwLock;
4
5use crate::{
6    completion::{CompletionModel, Document},
7    message::ToolChoice,
8    tool::{
9        Tool, ToolDyn, ToolSet,
10        server::{ToolServer, ToolServerHandle},
11    },
12    vector_store::VectorStoreIndexDyn,
13};
14
15#[cfg(feature = "rmcp")]
16#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
17use crate::tool::rmcp::McpTool as RmcpTool;
18
19use super::Agent;
20
21/// A builder for creating an agent
22///
23/// # Example
24/// ```
25/// use rig::{providers::openai, agent::AgentBuilder};
26///
27/// let openai = openai::Client::from_env();
28///
29/// let gpt4o = openai.completion_model("gpt-4o");
30///
31/// // Configure the agent
32/// let agent = AgentBuilder::new(gpt4o)
33///     .preamble("System prompt")
34///     .context("Context document 1")
35///     .context("Context document 2")
36///     .tool(tool1)
37///     .tool(tool2)
38///     .temperature(0.8)
39///     .additional_params(json!({"foo": "bar"}))
40///     .build();
41/// ```
42pub struct AgentBuilder<M>
43where
44    M: CompletionModel,
45{
46    /// Name of the agent used for logging and debugging
47    name: Option<String>,
48    /// Agent description. Primarily useful when using sub-agents as part of an agent workflow and converting agents to other formats.
49    description: Option<String>,
50    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
51    model: M,
52    /// System prompt
53    preamble: Option<String>,
54    /// Context documents always available to the agent
55    static_context: Vec<Document>,
56    /// Additional parameters to be passed to the model
57    additional_params: Option<serde_json::Value>,
58    /// Maximum number of tokens for the completion
59    max_tokens: Option<u64>,
60    /// List of vector store, with the sample number
61    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
62    /// Temperature of the model
63    temperature: Option<f64>,
64    /// Tool server handle
65    tool_server_handle: Option<ToolServerHandle>,
66    /// Whether or not the underlying LLM should be forced to use a tool before providing a response.
67    tool_choice: Option<ToolChoice>,
68    /// Default maximum depth for multi-turn agent calls
69    default_max_depth: Option<usize>,
70}
71
72impl<M> AgentBuilder<M>
73where
74    M: CompletionModel,
75{
76    pub fn new(model: M) -> Self {
77        Self {
78            name: None,
79            description: None,
80            model,
81            preamble: None,
82            static_context: vec![],
83            temperature: None,
84            max_tokens: None,
85            additional_params: None,
86            dynamic_context: vec![],
87            tool_server_handle: None,
88            tool_choice: None,
89            default_max_depth: None,
90        }
91    }
92
93    /// Set the name of the agent
94    pub fn name(mut self, name: &str) -> Self {
95        self.name = Some(name.into());
96        self
97    }
98
99    /// Set the description of the agent
100    pub fn description(mut self, description: &str) -> Self {
101        self.description = Some(description.into());
102        self
103    }
104
105    /// Set the system prompt
106    pub fn preamble(mut self, preamble: &str) -> Self {
107        self.preamble = Some(preamble.into());
108        self
109    }
110
111    /// Remove the system prompt
112    pub fn without_preamble(mut self) -> Self {
113        self.preamble = None;
114        self
115    }
116
117    /// Append to the preamble of the agent
118    pub fn append_preamble(mut self, doc: &str) -> Self {
119        self.preamble = Some(format!(
120            "{}\n{}",
121            self.preamble.unwrap_or_else(|| "".into()),
122            doc
123        ));
124        self
125    }
126
127    /// Add a static context document to the agent
128    pub fn context(mut self, doc: &str) -> Self {
129        self.static_context.push(Document {
130            id: format!("static_doc_{}", self.static_context.len()),
131            text: doc.into(),
132            additional_props: HashMap::new(),
133        });
134        self
135    }
136
137    /// Add a static tool to the agent
138    pub fn tool(self, tool: impl Tool + 'static) -> AgentBuilderSimple<M> {
139        let toolname = tool.name();
140        let tools = ToolSet::from_tools(vec![tool]);
141        let static_tools = vec![toolname];
142
143        AgentBuilderSimple {
144            name: self.name,
145            description: self.description,
146            model: self.model,
147            preamble: self.preamble,
148            static_context: self.static_context,
149            static_tools,
150            additional_params: self.additional_params,
151            max_tokens: self.max_tokens,
152            dynamic_context: vec![],
153            dynamic_tools: vec![],
154            temperature: self.temperature,
155            tools,
156            tool_choice: self.tool_choice,
157            default_max_depth: self.default_max_depth,
158        }
159    }
160
161    /// Add a vector of boxed static tools to the agent
162    /// This is useful when you need to dynamically add static tools to the agent
163    pub fn tools(self, tools: Vec<Box<dyn ToolDyn>>) -> AgentBuilderSimple<M> {
164        let static_tools = tools.iter().map(|tool| tool.name()).collect();
165        let tools = ToolSet::from_tools_boxed(tools);
166
167        AgentBuilderSimple {
168            name: self.name,
169            description: self.description,
170            model: self.model,
171            preamble: self.preamble,
172            static_context: self.static_context,
173            static_tools,
174            additional_params: self.additional_params,
175            max_tokens: self.max_tokens,
176            dynamic_context: vec![],
177            dynamic_tools: vec![],
178            temperature: self.temperature,
179            tools,
180            tool_choice: self.tool_choice,
181            default_max_depth: self.default_max_depth,
182        }
183    }
184
185    pub fn tool_server_handle(mut self, handle: ToolServerHandle) -> Self {
186        self.tool_server_handle = Some(handle);
187        self
188    }
189
190    /// Add an MCP tool (from `rmcp`) to the agent
191    #[cfg(feature = "rmcp")]
192    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
193    pub fn rmcp_tool(
194        self,
195        tool: rmcp::model::Tool,
196        client: rmcp::service::ServerSink,
197    ) -> AgentBuilderSimple<M> {
198        let toolname = tool.name.clone().to_string();
199        let tools = ToolSet::from_tools(vec![RmcpTool::from_mcp_server(tool, client)]);
200        let static_tools = vec![toolname];
201
202        AgentBuilderSimple {
203            name: self.name,
204            description: self.description,
205            model: self.model,
206            preamble: self.preamble,
207            static_context: self.static_context,
208            static_tools,
209            additional_params: self.additional_params,
210            max_tokens: self.max_tokens,
211            dynamic_context: vec![],
212            dynamic_tools: vec![],
213            temperature: self.temperature,
214            tools,
215            tool_choice: self.tool_choice,
216            default_max_depth: self.default_max_depth,
217        }
218    }
219
220    /// Add an array of MCP tools (from `rmcp`) to the agent
221    #[cfg(feature = "rmcp")]
222    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
223    pub fn rmcp_tools(
224        self,
225        tools: Vec<rmcp::model::Tool>,
226        client: rmcp::service::ServerSink,
227    ) -> AgentBuilderSimple<M> {
228        let (static_tools, tools) = tools.into_iter().fold(
229            (Vec::new(), Vec::new()),
230            |(mut toolnames, mut toolset), tool| {
231                let tool_name = tool.name.to_string();
232                let tool = RmcpTool::from_mcp_server(tool, client.clone());
233                toolnames.push(tool_name);
234                toolset.push(tool);
235                (toolnames, toolset)
236            },
237        );
238
239        let tools = ToolSet::from_tools(tools);
240
241        AgentBuilderSimple {
242            name: self.name,
243            description: self.description,
244            model: self.model,
245            preamble: self.preamble,
246            static_context: self.static_context,
247            static_tools,
248            additional_params: self.additional_params,
249            max_tokens: self.max_tokens,
250            dynamic_context: vec![],
251            dynamic_tools: vec![],
252            temperature: self.temperature,
253            tools,
254            tool_choice: self.tool_choice,
255            default_max_depth: self.default_max_depth,
256        }
257    }
258
259    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
260    /// dynamic context will be inserted in the request.
261    pub fn dynamic_context(
262        mut self,
263        sample: usize,
264        dynamic_context: impl VectorStoreIndexDyn + Send + Sync + 'static,
265    ) -> Self {
266        self.dynamic_context
267            .push((sample, Box::new(dynamic_context)));
268        self
269    }
270
271    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
272        self.tool_choice = Some(tool_choice);
273        self
274    }
275
276    /// Set the default maximum depth that an agent will use for multi-turn.
277    pub fn default_max_depth(mut self, default_max_depth: usize) -> Self {
278        self.default_max_depth = Some(default_max_depth);
279        self
280    }
281
282    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
283    /// dynamic toolset will be inserted in the request.
284    pub fn dynamic_tools(
285        self,
286        sample: usize,
287        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
288        toolset: ToolSet,
289    ) -> AgentBuilderSimple<M> {
290        let thing: Box<dyn VectorStoreIndexDyn + Send + Sync + 'static> = Box::new(dynamic_tools);
291        let dynamic_tools = vec![(sample, thing)];
292
293        AgentBuilderSimple {
294            name: self.name,
295            description: self.description,
296            model: self.model,
297            preamble: self.preamble,
298            static_context: self.static_context,
299            static_tools: vec![],
300            additional_params: self.additional_params,
301            max_tokens: self.max_tokens,
302            dynamic_context: vec![],
303            dynamic_tools,
304            temperature: self.temperature,
305            tools: toolset,
306            tool_choice: self.tool_choice,
307            default_max_depth: self.default_max_depth,
308        }
309    }
310
311    /// Set the temperature of the model
312    pub fn temperature(mut self, temperature: f64) -> Self {
313        self.temperature = Some(temperature);
314        self
315    }
316
317    /// Set the maximum number of tokens for the completion
318    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
319        self.max_tokens = Some(max_tokens);
320        self
321    }
322
323    /// Set additional parameters to be passed to the model
324    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
325        self.additional_params = Some(params);
326        self
327    }
328
329    /// Build the agent
330    pub fn build(self) -> Agent<M> {
331        let tool_server_handle = if let Some(handle) = self.tool_server_handle {
332            handle
333        } else {
334            ToolServer::new().run()
335        };
336
337        Agent {
338            name: self.name,
339            description: self.description,
340            model: Arc::new(self.model),
341            preamble: self.preamble,
342            static_context: self.static_context,
343            temperature: self.temperature,
344            max_tokens: self.max_tokens,
345            additional_params: self.additional_params,
346            tool_choice: self.tool_choice,
347            dynamic_context: Arc::new(RwLock::new(self.dynamic_context)),
348            tool_server_handle,
349            default_max_depth: self.default_max_depth,
350        }
351    }
352}
353
354/// A fluent builder variation of `AgentBuilder`. Allows adding tools directly to the builder rather than using the tool server handle.
355///
356/// # Example
357/// ```
358/// use rig::{providers::openai, agent::AgentBuilder};
359///
360/// let openai = openai::Client::from_env();
361///
362/// let gpt4o = openai.completion_model("gpt-4o");
363///
364/// // Configure the agent
365/// let agent = AgentBuilder::new(gpt4o)
366///     .preamble("System prompt")
367///     .context("Context document 1")
368///     .context("Context document 2")
369///     .tool(tool1)
370///     .tool(tool2)
371///     .temperature(0.8)
372///     .additional_params(json!({"foo": "bar"}))
373///     .build();
374/// ```
375pub struct AgentBuilderSimple<M>
376where
377    M: CompletionModel,
378{
379    /// Name of the agent used for logging and debugging
380    name: Option<String>,
381    /// Agent description. Primarily useful when using sub-agents as part of an agent workflow and converting agents to other formats.
382    description: Option<String>,
383    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
384    model: M,
385    /// System prompt
386    preamble: Option<String>,
387    /// Context documents always available to the agent
388    static_context: Vec<Document>,
389    /// Tools that are always available to the agent (by name)
390    static_tools: Vec<String>,
391    /// Additional parameters to be passed to the model
392    additional_params: Option<serde_json::Value>,
393    /// Maximum number of tokens for the completion
394    max_tokens: Option<u64>,
395    /// List of vector store, with the sample number
396    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
397    /// Dynamic tools
398    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
399    /// Temperature of the model
400    temperature: Option<f64>,
401    /// Actual tool implementations
402    tools: ToolSet,
403    /// Whether or not the underlying LLM should be forced to use a tool before providing a response.
404    tool_choice: Option<ToolChoice>,
405    /// Default maximum depth for multi-turn agent calls
406    default_max_depth: Option<usize>,
407}
408
409impl<M> AgentBuilderSimple<M>
410where
411    M: CompletionModel,
412{
413    pub fn new(model: M) -> Self {
414        Self {
415            name: None,
416            description: None,
417            model,
418            preamble: None,
419            static_context: vec![],
420            static_tools: vec![],
421            temperature: None,
422            max_tokens: None,
423            additional_params: None,
424            dynamic_context: vec![],
425            dynamic_tools: vec![],
426            tools: ToolSet::default(),
427            tool_choice: None,
428            default_max_depth: None,
429        }
430    }
431
432    /// Set the name of the agent
433    pub fn name(mut self, name: &str) -> Self {
434        self.name = Some(name.into());
435        self
436    }
437
438    /// Set the description of the agent
439    pub fn description(mut self, description: &str) -> Self {
440        self.description = Some(description.into());
441        self
442    }
443
444    /// Set the system prompt
445    pub fn preamble(mut self, preamble: &str) -> Self {
446        self.preamble = Some(preamble.into());
447        self
448    }
449
450    /// Remove the system prompt
451    pub fn without_preamble(mut self) -> Self {
452        self.preamble = None;
453        self
454    }
455
456    /// Append to the preamble of the agent
457    pub fn append_preamble(mut self, doc: &str) -> Self {
458        self.preamble = Some(format!(
459            "{}\n{}",
460            self.preamble.unwrap_or_else(|| "".into()),
461            doc
462        ));
463        self
464    }
465
466    /// Add a static context document to the agent
467    pub fn context(mut self, doc: &str) -> Self {
468        self.static_context.push(Document {
469            id: format!("static_doc_{}", self.static_context.len()),
470            text: doc.into(),
471            additional_props: HashMap::new(),
472        });
473        self
474    }
475
476    /// Add a static tool to the agent
477    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
478        let toolname = tool.name();
479        self.tools.add_tool(tool);
480        self.static_tools.push(toolname);
481        self
482    }
483
484    pub fn tools(mut self, tools: Vec<Box<dyn ToolDyn>>) -> Self {
485        let toolnames: Vec<String> = tools.iter().map(|tool| tool.name()).collect();
486        let tools = ToolSet::from_tools_boxed(tools);
487        self.tools.add_tools(tools);
488        self.static_tools.extend(toolnames);
489        self
490    }
491
492    /// Add an array of MCP tools (from `rmcp`) to the agent
493    #[cfg(feature = "rmcp")]
494    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
495    pub fn rmcp_tools(
496        mut self,
497        tools: Vec<rmcp::model::Tool>,
498        client: rmcp::service::ServerSink,
499    ) -> Self {
500        for tool in tools {
501            let tool_name = tool.name.to_string();
502            let tool = RmcpTool::from_mcp_server(tool, client.clone());
503            self.static_tools.push(tool_name);
504            self.tools.add_tool(tool);
505        }
506
507        self
508    }
509
510    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
511    /// dynamic context will be inserted in the request.
512    pub fn dynamic_context(
513        mut self,
514        sample: usize,
515        dynamic_context: impl VectorStoreIndexDyn + Send + Sync + 'static,
516    ) -> Self {
517        self.dynamic_context
518            .push((sample, Box::new(dynamic_context)));
519        self
520    }
521
522    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
523        self.tool_choice = Some(tool_choice);
524        self
525    }
526
527    /// Set the default maximum depth that an agent will use for multi-turn.
528    pub fn default_max_depth(mut self, default_max_depth: usize) -> Self {
529        self.default_max_depth = Some(default_max_depth);
530        self
531    }
532
533    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
534    /// dynamic toolset will be inserted in the request.
535    pub fn dynamic_tools(
536        mut self,
537        sample: usize,
538        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
539        toolset: ToolSet,
540    ) -> Self {
541        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
542        self.tools.add_tools(toolset);
543        self
544    }
545
546    /// Set the temperature of the model
547    pub fn temperature(mut self, temperature: f64) -> Self {
548        self.temperature = Some(temperature);
549        self
550    }
551
552    /// Set the maximum number of tokens for the completion
553    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
554        self.max_tokens = Some(max_tokens);
555        self
556    }
557
558    /// Set additional parameters to be passed to the model
559    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
560        self.additional_params = Some(params);
561        self
562    }
563
564    /// Build the agent
565    pub fn build(self) -> Agent<M> {
566        let tool_server_handle = ToolServer::new()
567            .static_tool_names(self.static_tools)
568            .add_tools(self.tools)
569            .add_dynamic_tools(self.dynamic_tools)
570            .run();
571
572        Agent {
573            name: self.name,
574            description: self.description,
575            model: Arc::new(self.model),
576            preamble: self.preamble,
577            static_context: self.static_context,
578            temperature: self.temperature,
579            max_tokens: self.max_tokens,
580            additional_params: self.additional_params,
581            tool_choice: self.tool_choice,
582            dynamic_context: Arc::new(RwLock::new(self.dynamic_context)),
583            tool_server_handle,
584            default_max_depth: self.default_max_depth,
585        }
586    }
587}