Skip to main content

rig/agent/
builder.rs

1use std::{collections::HashMap, sync::Arc};
2
3use schemars::{JsonSchema, Schema, schema_for};
4use tokio::sync::RwLock;
5
6use crate::{
7    agent::prompt_request::hooks::PromptHook,
8    completion::{CompletionModel, Document},
9    message::ToolChoice,
10    tool::{
11        Tool, ToolDyn, ToolSet,
12        server::{ToolServer, ToolServerHandle},
13    },
14    vector_store::VectorStoreIndexDyn,
15};
16
17#[cfg(feature = "rmcp")]
18#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
19use crate::tool::rmcp::McpTool as RmcpTool;
20
21use super::Agent;
22
/// Marker type indicating no tool configuration has been set yet.
///
/// This is the default state for a new [`AgentBuilder`]. From this state,
/// you can either:
/// - Add tools via `.tool()`, `.tools()`, `.dynamic_tools()`, etc. (transitions to [`WithBuilderTools`])
/// - Set a pre-existing `ToolServerHandle` via `.tool_server_handle()` (transitions to [`WithToolServerHandle`])
/// - Call `.build()` to create an agent with no tools
///
/// The type carries no data, so it is freely `Copy`able and `Debug`-printable.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoToolConfig;
32
33/// Typestate indicating a pre-existing `ToolServerHandle` has been provided.
34///
35/// In this state, tool-adding methods (`.tool()`, `.tools()`, etc.) are not available.
36/// The provided handle will be used directly when building the agent.
37pub struct WithToolServerHandle {
38    handle: ToolServerHandle,
39}
40
41/// Typestate indicating tools are being configured via the builder API.
42///
43/// In this state, you can continue adding tools via `.tool()`, `.tools()`,
44/// `.dynamic_tools()`, etc. When `.build()` is called, a new `ToolServer`
45/// will be created with all the configured tools.
46pub struct WithBuilderTools {
47    static_tools: Vec<String>,
48    tools: ToolSet,
49    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
50}
51
52/// A builder for creating an agent
53///
54/// The builder uses a typestate pattern to enforce that tool configuration
55/// is done in a mutually exclusive way: either provide a pre-existing
56/// `ToolServerHandle`, or add tools via the builder API, but not both.
57///
58/// # Example
59/// ```
60/// use rig::{providers::openai, agent::AgentBuilder};
61///
62/// let openai = openai::Client::from_env();
63///
64/// let gpt4o = openai.completion_model("gpt-4o");
65///
66/// // Configure the agent
67/// let agent = AgentBuilder::new(gpt4o)
68///     .preamble("System prompt")
69///     .context("Context document 1")
70///     .context("Context document 2")
71///     .tool(tool1)
72///     .tool(tool2)
73///     .temperature(0.8)
74///     .additional_params(json!({"foo": "bar"}))
75///     .build();
76/// ```
77pub struct AgentBuilder<M, P = (), ToolState = NoToolConfig>
78where
79    M: CompletionModel,
80    P: PromptHook<M>,
81{
82    /// Name of the agent used for logging and debugging
83    name: Option<String>,
84    /// Agent description. Primarily useful when using sub-agents as part of an agent workflow and converting agents to other formats.
85    description: Option<String>,
86    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
87    model: M,
88    /// System prompt
89    preamble: Option<String>,
90    /// Context documents always available to the agent
91    static_context: Vec<Document>,
92    /// Additional parameters to be passed to the model
93    additional_params: Option<serde_json::Value>,
94    /// Maximum number of tokens for the completion
95    max_tokens: Option<u64>,
96    /// List of vector store, with the sample number
97    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
98    /// Temperature of the model
99    temperature: Option<f64>,
100    /// Whether or not the underlying LLM should be forced to use a tool before providing a response.
101    tool_choice: Option<ToolChoice>,
102    /// Default maximum depth for multi-turn agent calls
103    default_max_turns: Option<usize>,
104    /// Tool configuration state (typestate pattern)
105    tool_state: ToolState,
106    /// Prompt hook
107    hook: Option<P>,
108    /// Optional JSON Schema for structured output
109    output_schema: Option<schemars::Schema>,
110}
111
112impl<M, P, ToolState> AgentBuilder<M, P, ToolState>
113where
114    M: CompletionModel,
115    P: PromptHook<M>,
116{
117    /// Set the name of the agent
118    pub fn name(mut self, name: &str) -> Self {
119        self.name = Some(name.into());
120        self
121    }
122
123    /// Set the description of the agent
124    pub fn description(mut self, description: &str) -> Self {
125        self.description = Some(description.into());
126        self
127    }
128
129    /// Set the system prompt
130    pub fn preamble(mut self, preamble: &str) -> Self {
131        self.preamble = Some(preamble.into());
132        self
133    }
134
135    /// Remove the system prompt
136    pub fn without_preamble(mut self) -> Self {
137        self.preamble = None;
138        self
139    }
140
141    /// Append to the preamble of the agent
142    pub fn append_preamble(mut self, doc: &str) -> Self {
143        self.preamble = Some(format!("{}\n{}", self.preamble.unwrap_or_default(), doc));
144        self
145    }
146
147    /// Add a static context document to the agent
148    pub fn context(mut self, doc: &str) -> Self {
149        self.static_context.push(Document {
150            id: format!("static_doc_{}", self.static_context.len()),
151            text: doc.into(),
152            additional_props: HashMap::new(),
153        });
154        self
155    }
156
157    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
158    /// dynamic context will be inserted in the request.
159    pub fn dynamic_context(
160        mut self,
161        sample: usize,
162        dynamic_context: impl VectorStoreIndexDyn + Send + Sync + 'static,
163    ) -> Self {
164        self.dynamic_context
165            .push((sample, Box::new(dynamic_context)));
166        self
167    }
168
169    /// Set the tool choice for the agent
170    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
171        self.tool_choice = Some(tool_choice);
172        self
173    }
174
175    /// Set the default maximum depth that an agent will use for multi-turn.
176    pub fn default_max_turns(mut self, default_max_turns: usize) -> Self {
177        self.default_max_turns = Some(default_max_turns);
178        self
179    }
180
181    /// Set the temperature of the model
182    pub fn temperature(mut self, temperature: f64) -> Self {
183        self.temperature = Some(temperature);
184        self
185    }
186
187    /// Set the maximum number of tokens for the completion
188    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
189        self.max_tokens = Some(max_tokens);
190        self
191    }
192
193    /// Set additional parameters to be passed to the model
194    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
195        self.additional_params = Some(params);
196        self
197    }
198
199    /// Set the output schema for structured output. When set, providers that support
200    /// native structured outputs will constrain the model's response to match this schema.
201    pub fn output_schema<T>(mut self) -> Self
202    where
203        T: JsonSchema,
204    {
205        self.output_schema = Some(schema_for!(T));
206        self
207    }
208
209    /// Set the output schema for structured output. In comparison to `AgentBuilder::schema()` which requires type annotation, you can put in any schema you'd like here.
210    pub fn output_schema_raw(mut self, schema: Schema) -> Self {
211        self.output_schema = Some(schema);
212        self
213    }
214}
215
216impl<M> AgentBuilder<M, (), NoToolConfig>
217where
218    M: CompletionModel,
219{
220    /// Create a new agent builder with the given model
221    pub fn new(model: M) -> Self {
222        Self {
223            name: None,
224            description: None,
225            model,
226            preamble: None,
227            static_context: vec![],
228            temperature: None,
229            max_tokens: None,
230            additional_params: None,
231            dynamic_context: vec![],
232            tool_choice: None,
233            default_max_turns: None,
234            tool_state: NoToolConfig,
235            hook: None,
236            output_schema: None,
237        }
238    }
239}
240
241impl<M, P> AgentBuilder<M, P, NoToolConfig>
242where
243    M: CompletionModel,
244    P: PromptHook<M>,
245{
246    /// Set a pre-existing ToolServerHandle for the agent.
247    ///
248    /// After calling this method, tool-adding methods (`.tool()`, `.tools()`, etc.)
249    /// will not be available. Use this when you want to share a `ToolServer`
250    /// between multiple agents or have pre-configured tools.
251    pub fn tool_server_handle(
252        self,
253        handle: ToolServerHandle,
254    ) -> AgentBuilder<M, P, WithToolServerHandle> {
255        AgentBuilder {
256            name: self.name,
257            description: self.description,
258            model: self.model,
259            preamble: self.preamble,
260            static_context: self.static_context,
261            additional_params: self.additional_params,
262            max_tokens: self.max_tokens,
263            dynamic_context: self.dynamic_context,
264            temperature: self.temperature,
265            tool_choice: self.tool_choice,
266            default_max_turns: self.default_max_turns,
267            tool_state: WithToolServerHandle { handle },
268            hook: self.hook,
269            output_schema: self.output_schema,
270        }
271    }
272
273    /// Add a static tool to the agent.
274    ///
275    /// This transitions the builder to the `WithBuilderTools` state, where
276    /// additional tools can be added but `tool_server_handle()` is no longer available.
277    pub fn tool(self, tool: impl Tool + 'static) -> AgentBuilder<M, P, WithBuilderTools> {
278        let toolname = tool.name();
279        AgentBuilder {
280            name: self.name,
281            description: self.description,
282            model: self.model,
283            preamble: self.preamble,
284            static_context: self.static_context,
285            additional_params: self.additional_params,
286            max_tokens: self.max_tokens,
287            dynamic_context: self.dynamic_context,
288            temperature: self.temperature,
289            tool_choice: self.tool_choice,
290            default_max_turns: self.default_max_turns,
291            tool_state: WithBuilderTools {
292                static_tools: vec![toolname],
293                tools: ToolSet::from_tools(vec![tool]),
294                dynamic_tools: vec![],
295            },
296            hook: self.hook,
297            output_schema: self.output_schema,
298        }
299    }
300
301    /// Add a vector of boxed static tools to the agent.
302    ///
303    /// This is useful when you need to dynamically add static tools to the agent.
304    /// Transitions the builder to the `WithBuilderTools` state.
305    pub fn tools(self, tools: Vec<Box<dyn ToolDyn>>) -> AgentBuilder<M, P, WithBuilderTools> {
306        let static_tools = tools.iter().map(|tool| tool.name()).collect();
307        let tools = ToolSet::from_tools_boxed(tools);
308
309        AgentBuilder {
310            name: self.name,
311            description: self.description,
312            model: self.model,
313            preamble: self.preamble,
314            static_context: self.static_context,
315            additional_params: self.additional_params,
316            max_tokens: self.max_tokens,
317            dynamic_context: self.dynamic_context,
318            temperature: self.temperature,
319            tool_choice: self.tool_choice,
320            default_max_turns: self.default_max_turns,
321            hook: self.hook,
322            output_schema: self.output_schema,
323            tool_state: WithBuilderTools {
324                static_tools,
325                tools,
326                dynamic_tools: vec![],
327            },
328        }
329    }
330
331    /// Add an MCP tool (from `rmcp`) to the agent.
332    ///
333    /// Transitions the builder to the `WithBuilderTools` state.
334    #[cfg(feature = "rmcp")]
335    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
336    pub fn rmcp_tool(
337        self,
338        tool: rmcp::model::Tool,
339        client: rmcp::service::ServerSink,
340    ) -> AgentBuilder<M, P, WithBuilderTools> {
341        let toolname = tool.name.clone().to_string();
342        let tools = ToolSet::from_tools(vec![RmcpTool::from_mcp_server(tool, client)]);
343
344        AgentBuilder {
345            name: self.name,
346            description: self.description,
347            model: self.model,
348            preamble: self.preamble,
349            static_context: self.static_context,
350            additional_params: self.additional_params,
351            max_tokens: self.max_tokens,
352            dynamic_context: self.dynamic_context,
353            temperature: self.temperature,
354            tool_choice: self.tool_choice,
355            default_max_turns: self.default_max_turns,
356            hook: self.hook,
357            output_schema: self.output_schema,
358            tool_state: WithBuilderTools {
359                static_tools: vec![toolname],
360                tools,
361                dynamic_tools: vec![],
362            },
363        }
364    }
365
366    /// Add an array of MCP tools (from `rmcp`) to the agent.
367    ///
368    /// Transitions the builder to the `WithBuilderTools` state.
369    #[cfg(feature = "rmcp")]
370    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
371    pub fn rmcp_tools(
372        self,
373        tools: Vec<rmcp::model::Tool>,
374        client: rmcp::service::ServerSink,
375    ) -> AgentBuilder<M, P, WithBuilderTools> {
376        let (static_tools, tools) = tools.into_iter().fold(
377            (Vec::new(), Vec::new()),
378            |(mut toolnames, mut toolset), tool| {
379                let tool_name = tool.name.to_string();
380                let tool = RmcpTool::from_mcp_server(tool, client.clone());
381                toolnames.push(tool_name);
382                toolset.push(tool);
383                (toolnames, toolset)
384            },
385        );
386
387        let tools = ToolSet::from_tools(tools);
388
389        AgentBuilder {
390            name: self.name,
391            description: self.description,
392            model: self.model,
393            preamble: self.preamble,
394            static_context: self.static_context,
395            additional_params: self.additional_params,
396            max_tokens: self.max_tokens,
397            dynamic_context: self.dynamic_context,
398            temperature: self.temperature,
399            tool_choice: self.tool_choice,
400            default_max_turns: self.default_max_turns,
401            hook: self.hook,
402            output_schema: self.output_schema,
403            tool_state: WithBuilderTools {
404                static_tools,
405                tools,
406                dynamic_tools: vec![],
407            },
408        }
409    }
410
411    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
412    /// dynamic toolset will be inserted in the request.
413    ///
414    /// Transitions the builder to the `WithBuilderTools` state.
415    pub fn dynamic_tools(
416        self,
417        sample: usize,
418        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
419        toolset: ToolSet,
420    ) -> AgentBuilder<M, P, WithBuilderTools> {
421        AgentBuilder {
422            name: self.name,
423            description: self.description,
424            model: self.model,
425            preamble: self.preamble,
426            static_context: self.static_context,
427            additional_params: self.additional_params,
428            max_tokens: self.max_tokens,
429            dynamic_context: self.dynamic_context,
430            temperature: self.temperature,
431            tool_choice: self.tool_choice,
432            default_max_turns: self.default_max_turns,
433            hook: self.hook,
434            output_schema: self.output_schema,
435            tool_state: WithBuilderTools {
436                static_tools: vec![],
437                tools: toolset,
438                dynamic_tools: vec![(sample, Box::new(dynamic_tools))],
439            },
440        }
441    }
442
443    /// Set the default hook for the agent.
444    ///
445    /// This hook will be used for all prompt requests unless overridden
446    /// via `.with_hook()` on the request.
447    pub fn hook<P2>(self, hook: P2) -> AgentBuilder<M, P2, NoToolConfig>
448    where
449        P2: PromptHook<M>,
450    {
451        AgentBuilder {
452            name: self.name,
453            description: self.description,
454            model: self.model,
455            preamble: self.preamble,
456            static_context: self.static_context,
457            additional_params: self.additional_params,
458            max_tokens: self.max_tokens,
459            dynamic_context: self.dynamic_context,
460            temperature: self.temperature,
461            tool_choice: self.tool_choice,
462            default_max_turns: self.default_max_turns,
463            tool_state: self.tool_state,
464            hook: Some(hook),
465            output_schema: self.output_schema,
466        }
467    }
468
469    /// Build the agent with no tools configured.
470    ///
471    /// An empty `ToolServer` will be created for the agent.
472    pub fn build(self) -> Agent<M, P> {
473        let tool_server_handle = ToolServer::new().run();
474
475        Agent {
476            name: self.name,
477            description: self.description,
478            model: Arc::new(self.model),
479            preamble: self.preamble,
480            static_context: self.static_context,
481            temperature: self.temperature,
482            max_tokens: self.max_tokens,
483            additional_params: self.additional_params,
484            tool_choice: self.tool_choice,
485            dynamic_context: Arc::new(RwLock::new(self.dynamic_context)),
486            tool_server_handle,
487            default_max_turns: self.default_max_turns,
488            hook: self.hook,
489            output_schema: self.output_schema,
490        }
491    }
492}
493
494impl<M, P> AgentBuilder<M, P, WithToolServerHandle>
495where
496    M: CompletionModel,
497    P: PromptHook<M>,
498{
499    /// Build the agent using the pre-configured ToolServerHandle.
500    pub fn build(self) -> Agent<M, P> {
501        Agent {
502            name: self.name,
503            description: self.description,
504            model: Arc::new(self.model),
505            preamble: self.preamble,
506            static_context: self.static_context,
507            temperature: self.temperature,
508            max_tokens: self.max_tokens,
509            additional_params: self.additional_params,
510            tool_choice: self.tool_choice,
511            dynamic_context: Arc::new(RwLock::new(self.dynamic_context)),
512            tool_server_handle: self.tool_state.handle,
513            default_max_turns: self.default_max_turns,
514            hook: self.hook,
515            output_schema: self.output_schema,
516        }
517    }
518}
519
520impl<M, P> AgentBuilder<M, P, WithBuilderTools>
521where
522    M: CompletionModel,
523    P: PromptHook<M>,
524{
525    /// Add another static tool to the agent.
526    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
527        let toolname = tool.name();
528        self.tool_state.tools.add_tool(tool);
529        self.tool_state.static_tools.push(toolname);
530        self
531    }
532
533    /// Add a vector of boxed static tools to the agent.
534    pub fn tools(mut self, tools: Vec<Box<dyn ToolDyn>>) -> Self {
535        let toolnames: Vec<String> = tools.iter().map(|tool| tool.name()).collect();
536        let tools = ToolSet::from_tools_boxed(tools);
537        self.tool_state.tools.add_tools(tools);
538        self.tool_state.static_tools.extend(toolnames);
539        self
540    }
541
542    /// Add an array of MCP tools (from `rmcp`) to the agent.
543    #[cfg(feature = "rmcp")]
544    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
545    pub fn rmcp_tools(
546        mut self,
547        tools: Vec<rmcp::model::Tool>,
548        client: rmcp::service::ServerSink,
549    ) -> Self {
550        for tool in tools {
551            let tool_name = tool.name.to_string();
552            let tool = RmcpTool::from_mcp_server(tool, client.clone());
553            self.tool_state.static_tools.push(tool_name);
554            self.tool_state.tools.add_tool(tool);
555        }
556
557        self
558    }
559
560    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
561    /// dynamic toolset will be inserted in the request.
562    pub fn dynamic_tools(
563        mut self,
564        sample: usize,
565        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
566        toolset: ToolSet,
567    ) -> Self {
568        self.tool_state
569            .dynamic_tools
570            .push((sample, Box::new(dynamic_tools)));
571        self.tool_state.tools.add_tools(toolset);
572        self
573    }
574
575    /// Build the agent with the configured tools.
576    ///
577    /// A new `ToolServer` will be created containing all tools added via
578    /// `.tool()`, `.tools()`, `.dynamic_tools()`, etc.
579    pub fn build(self) -> Agent<M, P> {
580        let tool_server_handle = ToolServer::new()
581            .static_tool_names(self.tool_state.static_tools)
582            .add_tools(self.tool_state.tools)
583            .add_dynamic_tools(self.tool_state.dynamic_tools)
584            .run();
585
586        Agent {
587            name: self.name,
588            description: self.description,
589            model: Arc::new(self.model),
590            preamble: self.preamble,
591            static_context: self.static_context,
592            temperature: self.temperature,
593            max_tokens: self.max_tokens,
594            additional_params: self.additional_params,
595            tool_choice: self.tool_choice,
596            dynamic_context: Arc::new(RwLock::new(self.dynamic_context)),
597            tool_server_handle,
598            default_max_turns: self.default_max_turns,
599            hook: self.hook,
600            output_schema: self.output_schema,
601        }
602    }
603}