Skip to main content

rig/agent/
builder.rs

1use std::{collections::HashMap, sync::Arc};
2
3use schemars::{JsonSchema, Schema, schema_for};
4
5use crate::{
6    agent::prompt_request::hooks::PromptHook,
7    completion::{CompletionModel, Document},
8    message::ToolChoice,
9    tool::{
10        Tool, ToolDyn, ToolSet,
11        server::{ToolServer, ToolServerHandle},
12    },
13    vector_store::VectorStoreIndexDyn,
14};
15
16#[cfg(feature = "rmcp")]
17#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
18use crate::tool::rmcp::McpTool as RmcpTool;
19
20use super::Agent;
21
/// Marker type indicating no tool configuration has been set yet.
///
/// This is the default state for a new `AgentBuilder`. From this state,
/// you can either:
/// - Add tools via `.tool()`, `.tools()`, `.dynamic_tools()`, etc. (transitions to `WithBuilderTools`)
/// - Set a pre-existing `ToolServerHandle` via `.tool_server_handle()` (transitions to `WithToolServerHandle`)
/// - Call `.build()` to create an agent with no tools
///
/// Being a field-less marker, it is trivially `Copy` and `Debug`.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoToolConfig;
31
32/// Typestate indicating a pre-existing `ToolServerHandle` has been provided.
33///
34/// In this state, tool-adding methods (`.tool()`, `.tools()`, etc.) are not available.
35/// The provided handle will be used directly when building the agent.
36pub struct WithToolServerHandle {
37    handle: ToolServerHandle,
38}
39
40/// Typestate indicating tools are being configured via the builder API.
41///
42/// In this state, you can continue adding tools via `.tool()`, `.tools()`,
43/// `.dynamic_tools()`, etc. When `.build()` is called, a new `ToolServer`
44/// will be created with all the configured tools.
45pub struct WithBuilderTools {
46    static_tools: Vec<String>,
47    tools: ToolSet,
48    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
49}
50
51/// A builder for creating an agent
52///
53/// The builder uses a typestate pattern to enforce that tool configuration
54/// is done in a mutually exclusive way: either provide a pre-existing
55/// `ToolServerHandle`, or add tools via the builder API, but not both.
56///
57/// # Example
58/// ```
59/// use rig::{client::ProviderClient, providers::openai, agent::AgentBuilder};
60///
61/// let openai = openai::Client::from_env()?;
62///
63/// let gpt4o = openai.completion_model("gpt-4o");
64///
65/// // Configure the agent
66/// let agent = AgentBuilder::new(gpt4o)
67///     .preamble("System prompt")
68///     .context("Context document 1")
69///     .context("Context document 2")
70///     .tool(tool1)
71///     .tool(tool2)
72///     .temperature(0.8)
73///     .additional_params(json!({"foo": "bar"}))
74///     .build();
75/// ```
76pub struct AgentBuilder<M, P = (), ToolState = NoToolConfig>
77where
78    M: CompletionModel,
79    P: PromptHook<M>,
80{
81    /// Name of the agent used for logging and debugging
82    name: Option<String>,
83    /// Agent description. Primarily useful when using sub-agents as part of an agent workflow and converting agents to other formats.
84    description: Option<String>,
85    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
86    model: M,
87    /// System prompt
88    preamble: Option<String>,
89    /// Context documents always available to the agent
90    static_context: Vec<Document>,
91    /// Additional parameters to be passed to the model
92    additional_params: Option<serde_json::Value>,
93    /// Maximum number of tokens for the completion
94    max_tokens: Option<u64>,
95    /// List of vector store, with the sample number
96    dynamic_context: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
97    /// Temperature of the model
98    temperature: Option<f64>,
99    /// Whether or not the underlying LLM should be forced to use a tool before providing a response.
100    tool_choice: Option<ToolChoice>,
101    /// Default maximum depth for multi-turn agent calls
102    default_max_turns: Option<usize>,
103    /// Tool configuration state (typestate pattern)
104    tool_state: ToolState,
105    /// Prompt hook
106    hook: Option<P>,
107    /// Optional JSON Schema for structured output
108    output_schema: Option<schemars::Schema>,
109}
110
111impl<M, P, ToolState> AgentBuilder<M, P, ToolState>
112where
113    M: CompletionModel,
114    P: PromptHook<M>,
115{
116    /// Set the name of the agent
117    pub fn name(mut self, name: &str) -> Self {
118        self.name = Some(name.into());
119        self
120    }
121
122    /// Set the description of the agent
123    pub fn description(mut self, description: &str) -> Self {
124        self.description = Some(description.into());
125        self
126    }
127
128    /// Set the system prompt
129    pub fn preamble(mut self, preamble: &str) -> Self {
130        self.preamble = Some(preamble.into());
131        self
132    }
133
134    /// Remove the system prompt
135    pub fn without_preamble(mut self) -> Self {
136        self.preamble = None;
137        self
138    }
139
140    /// Append to the preamble of the agent
141    pub fn append_preamble(mut self, doc: &str) -> Self {
142        self.preamble = Some(format!("{}\n{}", self.preamble.unwrap_or_default(), doc));
143        self
144    }
145
146    /// Add a static context document to the agent
147    pub fn context(mut self, doc: &str) -> Self {
148        self.static_context.push(Document {
149            id: format!("static_doc_{}", self.static_context.len()),
150            text: doc.into(),
151            additional_props: HashMap::new(),
152        });
153        self
154    }
155
156    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
157    /// dynamic context will be inserted in the request.
158    pub fn dynamic_context(
159        mut self,
160        sample: usize,
161        dynamic_context: impl VectorStoreIndexDyn + Send + Sync + 'static,
162    ) -> Self {
163        self.dynamic_context
164            .push((sample, Arc::new(dynamic_context)));
165        self
166    }
167
168    /// Set the tool choice for the agent
169    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
170        self.tool_choice = Some(tool_choice);
171        self
172    }
173
174    /// Set the default maximum depth that an agent will use for multi-turn.
175    pub fn default_max_turns(mut self, default_max_turns: usize) -> Self {
176        self.default_max_turns = Some(default_max_turns);
177        self
178    }
179
180    /// Set the temperature of the model
181    pub fn temperature(mut self, temperature: f64) -> Self {
182        self.temperature = Some(temperature);
183        self
184    }
185
186    /// Set the maximum number of tokens for the completion
187    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
188        self.max_tokens = Some(max_tokens);
189        self
190    }
191
192    /// Set additional parameters to be passed to the model
193    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
194        self.additional_params = Some(params);
195        self
196    }
197
198    /// Set the output schema for structured output. When set, providers that support
199    /// native structured outputs will constrain the model's response to match this schema.
200    pub fn output_schema<T>(mut self) -> Self
201    where
202        T: JsonSchema,
203    {
204        self.output_schema = Some(schema_for!(T));
205        self
206    }
207
208    /// Set the output schema for structured output. In comparison to `AgentBuilder::schema()` which requires type annotation, you can put in any schema you'd like here.
209    pub fn output_schema_raw(mut self, schema: Schema) -> Self {
210        self.output_schema = Some(schema);
211        self
212    }
213}
214
215impl<M> AgentBuilder<M, (), NoToolConfig>
216where
217    M: CompletionModel,
218{
219    /// Create a new agent builder with the given model
220    pub fn new(model: M) -> Self {
221        Self {
222            name: None,
223            description: None,
224            model,
225            preamble: None,
226            static_context: vec![],
227            temperature: None,
228            max_tokens: None,
229            additional_params: None,
230            dynamic_context: vec![],
231            tool_choice: None,
232            default_max_turns: None,
233            tool_state: NoToolConfig,
234            hook: None,
235            output_schema: None,
236        }
237    }
238}
239
240impl<M, P> AgentBuilder<M, P, NoToolConfig>
241where
242    M: CompletionModel,
243    P: PromptHook<M>,
244{
245    /// Set a pre-existing ToolServerHandle for the agent.
246    ///
247    /// After calling this method, tool-adding methods (`.tool()`, `.tools()`, etc.)
248    /// will not be available. Use this when you want to share a `ToolServer`
249    /// between multiple agents or have pre-configured tools.
250    pub fn tool_server_handle(
251        self,
252        handle: ToolServerHandle,
253    ) -> AgentBuilder<M, P, WithToolServerHandle> {
254        AgentBuilder {
255            name: self.name,
256            description: self.description,
257            model: self.model,
258            preamble: self.preamble,
259            static_context: self.static_context,
260            additional_params: self.additional_params,
261            max_tokens: self.max_tokens,
262            dynamic_context: self.dynamic_context,
263            temperature: self.temperature,
264            tool_choice: self.tool_choice,
265            default_max_turns: self.default_max_turns,
266            tool_state: WithToolServerHandle { handle },
267            hook: self.hook,
268            output_schema: self.output_schema,
269        }
270    }
271
272    /// Add a static tool to the agent.
273    ///
274    /// This transitions the builder to the `WithBuilderTools` state, where
275    /// additional tools can be added but `tool_server_handle()` is no longer available.
276    pub fn tool(self, tool: impl Tool + 'static) -> AgentBuilder<M, P, WithBuilderTools> {
277        let toolname = tool.name();
278        AgentBuilder {
279            name: self.name,
280            description: self.description,
281            model: self.model,
282            preamble: self.preamble,
283            static_context: self.static_context,
284            additional_params: self.additional_params,
285            max_tokens: self.max_tokens,
286            dynamic_context: self.dynamic_context,
287            temperature: self.temperature,
288            tool_choice: self.tool_choice,
289            default_max_turns: self.default_max_turns,
290            tool_state: WithBuilderTools {
291                static_tools: vec![toolname],
292                tools: ToolSet::from_tools(vec![tool]),
293                dynamic_tools: vec![],
294            },
295            hook: self.hook,
296            output_schema: self.output_schema,
297        }
298    }
299
300    /// Add a vector of boxed static tools to the agent.
301    ///
302    /// This is useful when you need to dynamically add static tools to the agent.
303    /// Transitions the builder to the `WithBuilderTools` state.
304    pub fn tools(self, tools: Vec<Box<dyn ToolDyn>>) -> AgentBuilder<M, P, WithBuilderTools> {
305        let static_tools = tools.iter().map(|tool| tool.name()).collect();
306        let tools = ToolSet::from_tools_boxed(tools);
307
308        AgentBuilder {
309            name: self.name,
310            description: self.description,
311            model: self.model,
312            preamble: self.preamble,
313            static_context: self.static_context,
314            additional_params: self.additional_params,
315            max_tokens: self.max_tokens,
316            dynamic_context: self.dynamic_context,
317            temperature: self.temperature,
318            tool_choice: self.tool_choice,
319            default_max_turns: self.default_max_turns,
320            hook: self.hook,
321            output_schema: self.output_schema,
322            tool_state: WithBuilderTools {
323                static_tools,
324                tools,
325                dynamic_tools: vec![],
326            },
327        }
328    }
329
330    /// Add an MCP tool (from `rmcp`) to the agent.
331    ///
332    /// Transitions the builder to the `WithBuilderTools` state.
333    #[cfg(feature = "rmcp")]
334    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
335    pub fn rmcp_tool(
336        self,
337        tool: rmcp::model::Tool,
338        client: rmcp::service::ServerSink,
339    ) -> AgentBuilder<M, P, WithBuilderTools> {
340        let toolname = tool.name.clone().to_string();
341        let tools = ToolSet::from_tools(vec![RmcpTool::from_mcp_server(tool, client)]);
342
343        AgentBuilder {
344            name: self.name,
345            description: self.description,
346            model: self.model,
347            preamble: self.preamble,
348            static_context: self.static_context,
349            additional_params: self.additional_params,
350            max_tokens: self.max_tokens,
351            dynamic_context: self.dynamic_context,
352            temperature: self.temperature,
353            tool_choice: self.tool_choice,
354            default_max_turns: self.default_max_turns,
355            hook: self.hook,
356            output_schema: self.output_schema,
357            tool_state: WithBuilderTools {
358                static_tools: vec![toolname],
359                tools,
360                dynamic_tools: vec![],
361            },
362        }
363    }
364
365    /// Add an array of MCP tools (from `rmcp`) to the agent.
366    ///
367    /// Transitions the builder to the `WithBuilderTools` state.
368    #[cfg(feature = "rmcp")]
369    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
370    pub fn rmcp_tools(
371        self,
372        tools: Vec<rmcp::model::Tool>,
373        client: rmcp::service::ServerSink,
374    ) -> AgentBuilder<M, P, WithBuilderTools> {
375        let (static_tools, tools) = tools.into_iter().fold(
376            (Vec::new(), Vec::new()),
377            |(mut toolnames, mut toolset), tool| {
378                let tool_name = tool.name.to_string();
379                let tool = RmcpTool::from_mcp_server(tool, client.clone());
380                toolnames.push(tool_name);
381                toolset.push(tool);
382                (toolnames, toolset)
383            },
384        );
385
386        let tools = ToolSet::from_tools(tools);
387
388        AgentBuilder {
389            name: self.name,
390            description: self.description,
391            model: self.model,
392            preamble: self.preamble,
393            static_context: self.static_context,
394            additional_params: self.additional_params,
395            max_tokens: self.max_tokens,
396            dynamic_context: self.dynamic_context,
397            temperature: self.temperature,
398            tool_choice: self.tool_choice,
399            default_max_turns: self.default_max_turns,
400            hook: self.hook,
401            output_schema: self.output_schema,
402            tool_state: WithBuilderTools {
403                static_tools,
404                tools,
405                dynamic_tools: vec![],
406            },
407        }
408    }
409
410    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
411    /// dynamic toolset will be inserted in the request.
412    ///
413    /// Transitions the builder to the `WithBuilderTools` state.
414    pub fn dynamic_tools(
415        self,
416        sample: usize,
417        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
418        toolset: ToolSet,
419    ) -> AgentBuilder<M, P, WithBuilderTools> {
420        AgentBuilder {
421            name: self.name,
422            description: self.description,
423            model: self.model,
424            preamble: self.preamble,
425            static_context: self.static_context,
426            additional_params: self.additional_params,
427            max_tokens: self.max_tokens,
428            dynamic_context: self.dynamic_context,
429            temperature: self.temperature,
430            tool_choice: self.tool_choice,
431            default_max_turns: self.default_max_turns,
432            hook: self.hook,
433            output_schema: self.output_schema,
434            tool_state: WithBuilderTools {
435                static_tools: vec![],
436                tools: toolset,
437                dynamic_tools: vec![(sample, Arc::new(dynamic_tools))],
438            },
439        }
440    }
441
442    /// Set the default hook for the agent.
443    ///
444    /// This hook will be used for all prompt requests unless overridden
445    /// via `.with_hook()` on the request.
446    pub fn hook<P2>(self, hook: P2) -> AgentBuilder<M, P2, NoToolConfig>
447    where
448        P2: PromptHook<M>,
449    {
450        AgentBuilder {
451            name: self.name,
452            description: self.description,
453            model: self.model,
454            preamble: self.preamble,
455            static_context: self.static_context,
456            additional_params: self.additional_params,
457            max_tokens: self.max_tokens,
458            dynamic_context: self.dynamic_context,
459            temperature: self.temperature,
460            tool_choice: self.tool_choice,
461            default_max_turns: self.default_max_turns,
462            tool_state: self.tool_state,
463            hook: Some(hook),
464            output_schema: self.output_schema,
465        }
466    }
467
468    /// Build the agent with no tools configured.
469    ///
470    /// An empty `ToolServer` will be created for the agent.
471    pub fn build(self) -> Agent<M, P> {
472        let tool_server_handle = ToolServer::new().run();
473
474        Agent {
475            name: self.name,
476            description: self.description,
477            model: Arc::new(self.model),
478            preamble: self.preamble,
479            static_context: self.static_context,
480            temperature: self.temperature,
481            max_tokens: self.max_tokens,
482            additional_params: self.additional_params,
483            tool_choice: self.tool_choice,
484            dynamic_context: Arc::new(self.dynamic_context),
485            tool_server_handle,
486            default_max_turns: self.default_max_turns,
487            hook: self.hook,
488            output_schema: self.output_schema,
489        }
490    }
491}
492
493impl<M, P> AgentBuilder<M, P, WithToolServerHandle>
494where
495    M: CompletionModel,
496    P: PromptHook<M>,
497{
498    /// Build the agent using the pre-configured ToolServerHandle.
499    pub fn build(self) -> Agent<M, P> {
500        Agent {
501            name: self.name,
502            description: self.description,
503            model: Arc::new(self.model),
504            preamble: self.preamble,
505            static_context: self.static_context,
506            temperature: self.temperature,
507            max_tokens: self.max_tokens,
508            additional_params: self.additional_params,
509            tool_choice: self.tool_choice,
510            dynamic_context: Arc::new(self.dynamic_context),
511            tool_server_handle: self.tool_state.handle,
512            default_max_turns: self.default_max_turns,
513            hook: self.hook,
514            output_schema: self.output_schema,
515        }
516    }
517}
518
519impl<M, P> AgentBuilder<M, P, WithBuilderTools>
520where
521    M: CompletionModel,
522    P: PromptHook<M>,
523{
524    /// Add another static tool to the agent.
525    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
526        let toolname = tool.name();
527        self.tool_state.tools.add_tool(tool);
528        self.tool_state.static_tools.push(toolname);
529        self
530    }
531
532    /// Add a vector of boxed static tools to the agent.
533    pub fn tools(mut self, tools: Vec<Box<dyn ToolDyn>>) -> Self {
534        let toolnames: Vec<String> = tools.iter().map(|tool| tool.name()).collect();
535        let tools = ToolSet::from_tools_boxed(tools);
536        self.tool_state.tools.add_tools(tools);
537        self.tool_state.static_tools.extend(toolnames);
538        self
539    }
540
541    /// Add an array of MCP tools (from `rmcp`) to the agent.
542    #[cfg(feature = "rmcp")]
543    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
544    pub fn rmcp_tools(
545        mut self,
546        tools: Vec<rmcp::model::Tool>,
547        client: rmcp::service::ServerSink,
548    ) -> Self {
549        for tool in tools {
550            let tool_name = tool.name.to_string();
551            let tool = RmcpTool::from_mcp_server(tool, client.clone());
552            self.tool_state.static_tools.push(tool_name);
553            self.tool_state.tools.add_tool(tool);
554        }
555
556        self
557    }
558
559    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
560    /// dynamic toolset will be inserted in the request.
561    pub fn dynamic_tools(
562        mut self,
563        sample: usize,
564        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
565        toolset: ToolSet,
566    ) -> Self {
567        self.tool_state
568            .dynamic_tools
569            .push((sample, Arc::new(dynamic_tools)));
570        self.tool_state.tools.add_tools(toolset);
571        self
572    }
573
574    /// Build the agent with the configured tools.
575    ///
576    /// A new `ToolServer` will be created containing all tools added via
577    /// `.tool()`, `.tools()`, `.dynamic_tools()`, etc.
578    pub fn build(self) -> Agent<M, P> {
579        let tool_server_handle = ToolServer::new()
580            .static_tool_names(self.tool_state.static_tools)
581            .add_tools(self.tool_state.tools)
582            .add_dynamic_tools(self.tool_state.dynamic_tools)
583            .run();
584
585        Agent {
586            name: self.name,
587            description: self.description,
588            model: Arc::new(self.model),
589            preamble: self.preamble,
590            static_context: self.static_context,
591            temperature: self.temperature,
592            max_tokens: self.max_tokens,
593            additional_params: self.additional_params,
594            tool_choice: self.tool_choice,
595            dynamic_context: Arc::new(self.dynamic_context),
596            tool_server_handle,
597            default_max_turns: self.default_max_turns,
598            hook: self.hook,
599            output_schema: self.output_schema,
600        }
601    }
602}