// rig_core/agent/builder.rs

1use std::{collections::HashMap, sync::Arc};
2
3use schemars::{JsonSchema, Schema, schema_for};
4
5use crate::{
6    agent::prompt_request::hooks::PromptHook,
7    completion::{CompletionModel, Document},
8    memory::ConversationMemory,
9    message::ToolChoice,
10    tool::{
11        Tool, ToolDyn, ToolSet,
12        server::{ToolServer, ToolServerHandle},
13    },
14    vector_store::VectorStoreIndexDyn,
15};
16
17#[cfg(feature = "rmcp")]
18#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
19use crate::tool::rmcp::McpTool as RmcpTool;
20
21use super::Agent;
22
/// Marker type indicating no tool configuration has been set yet.
///
/// This is the default state for a new `AgentBuilder`. From this state,
/// you can either:
/// - Add tools via `.tool()`, `.tools()`, `.dynamic_tools()`, etc. (transitions to `WithBuilderTools`)
/// - Set a pre-existing `ToolServerHandle` via `.tool_server_handle()` (transitions to `WithToolServerHandle`)
/// - Call `.build()` to create an agent with no tools
///
/// Being a zero-sized marker, it is trivially `Copy`, and deriving `Debug` lets
/// builder states be printed in diagnostics.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoToolConfig;
32
33/// Typestate indicating a pre-existing `ToolServerHandle` has been provided.
34///
35/// In this state, tool-adding methods (`.tool()`, `.tools()`, etc.) are not available.
36/// The provided handle will be used directly when building the agent.
37pub struct WithToolServerHandle {
38    handle: ToolServerHandle,
39}
40
41/// Typestate indicating tools are being configured via the builder API.
42///
43/// In this state, you can continue adding tools via `.tool()`, `.tools()`,
44/// `.dynamic_tools()`, etc. When `.build()` is called, a new `ToolServer`
45/// will be created with all the configured tools.
46pub struct WithBuilderTools {
47    static_tools: Vec<String>,
48    tools: ToolSet,
49    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
50}
51
52/// A builder for creating an agent
53///
54/// The builder uses a typestate pattern to enforce that tool configuration
55/// is done in a mutually exclusive way: either provide a pre-existing
56/// `ToolServerHandle`, or add tools via the builder API, but not both.
57///
58/// # Example
59/// ```no_run
60/// use rig_core::{agent::AgentBuilder, client::{CompletionClient, ProviderClient}, providers::openai};
61///
62/// # fn run() -> Result<(), Box<dyn std::error::Error>> {
63/// let openai = openai::Client::from_env()?;
64///
65/// let model = openai.completion_model(openai::GPT_5_2);
66///
67/// // Configure the agent
68/// let agent = AgentBuilder::new(model)
69///     .preamble("System prompt")
70///     .context("Context document 1")
71///     .context("Context document 2")
72///     .temperature(0.8)
73///     .build();
74/// # Ok(())
75/// # }
76/// ```
77pub struct AgentBuilder<M, P = (), ToolState = NoToolConfig>
78where
79    M: CompletionModel,
80    P: PromptHook<M>,
81{
82    /// Name of the agent used for logging and debugging
83    name: Option<String>,
84    /// Agent description. Primarily useful when using sub-agents as part of an agent workflow and converting agents to other formats.
85    description: Option<String>,
86    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
87    model: M,
88    /// System prompt
89    preamble: Option<String>,
90    /// Context documents always available to the agent
91    static_context: Vec<Document>,
92    /// Additional parameters to be passed to the model
93    additional_params: Option<serde_json::Value>,
94    /// Maximum number of tokens for the completion
95    max_tokens: Option<u64>,
96    /// List of vector store, with the sample number
97    dynamic_context: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
98    /// Temperature of the model
99    temperature: Option<f64>,
100    /// Whether or not the underlying LLM should be forced to use a tool before providing a response.
101    tool_choice: Option<ToolChoice>,
102    /// Default maximum depth for multi-turn agent calls
103    default_max_turns: Option<usize>,
104    /// Tool configuration state (typestate pattern)
105    tool_state: ToolState,
106    /// Prompt hook
107    hook: Option<P>,
108    /// Optional JSON Schema for structured output
109    output_schema: Option<schemars::Schema>,
110    /// Optional conversation memory backend that loads/saves history per conversation id.
111    memory: Option<Arc<dyn ConversationMemory>>,
112    /// Optional default conversation id used when none is set per-request.
113    default_conversation_id: Option<String>,
114}
115
116impl<M, P, ToolState> AgentBuilder<M, P, ToolState>
117where
118    M: CompletionModel,
119    P: PromptHook<M>,
120{
121    /// Set the name of the agent
122    pub fn name(mut self, name: &str) -> Self {
123        self.name = Some(name.into());
124        self
125    }
126
127    /// Set the description of the agent
128    pub fn description(mut self, description: &str) -> Self {
129        self.description = Some(description.into());
130        self
131    }
132
133    /// Set the system prompt
134    pub fn preamble(mut self, preamble: &str) -> Self {
135        self.preamble = Some(preamble.into());
136        self
137    }
138
139    /// Remove the system prompt
140    pub fn without_preamble(mut self) -> Self {
141        self.preamble = None;
142        self
143    }
144
145    /// Append to the preamble of the agent
146    pub fn append_preamble(mut self, doc: &str) -> Self {
147        self.preamble = Some(format!("{}\n{}", self.preamble.unwrap_or_default(), doc));
148        self
149    }
150
151    /// Add a static context document to the agent
152    pub fn context(mut self, doc: &str) -> Self {
153        self.static_context.push(Document {
154            id: format!("static_doc_{}", self.static_context.len()),
155            text: doc.into(),
156            additional_props: HashMap::new(),
157        });
158        self
159    }
160
161    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
162    /// dynamic context will be inserted in the request.
163    pub fn dynamic_context(
164        mut self,
165        sample: usize,
166        dynamic_context: impl VectorStoreIndexDyn + Send + Sync + 'static,
167    ) -> Self {
168        self.dynamic_context
169            .push((sample, Arc::new(dynamic_context)));
170        self
171    }
172
173    /// Set the tool choice for the agent
174    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
175        self.tool_choice = Some(tool_choice);
176        self
177    }
178
179    /// Set the default maximum depth that an agent will use for multi-turn.
180    pub fn default_max_turns(mut self, default_max_turns: usize) -> Self {
181        self.default_max_turns = Some(default_max_turns);
182        self
183    }
184
185    /// Set the temperature of the model
186    pub fn temperature(mut self, temperature: f64) -> Self {
187        self.temperature = Some(temperature);
188        self
189    }
190
191    /// Set the maximum number of tokens for the completion
192    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
193        self.max_tokens = Some(max_tokens);
194        self
195    }
196
197    /// Set additional parameters to be passed to the model
198    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
199        self.additional_params = Some(params);
200        self
201    }
202
203    /// Set the output schema for structured output. When set, providers that support
204    /// native structured outputs will constrain the model's response to match this schema.
205    pub fn output_schema<T>(mut self) -> Self
206    where
207        T: JsonSchema,
208    {
209        self.output_schema = Some(schema_for!(T));
210        self
211    }
212
213    /// Set the output schema for structured output. In comparison to `AgentBuilder::schema()` which requires type annotation, you can put in any schema you'd like here.
214    pub fn output_schema_raw(mut self, schema: Schema) -> Self {
215        self.output_schema = Some(schema);
216        self
217    }
218
219    /// Attach a [`ConversationMemory`] backend.
220    ///
221    /// When set, the agent will automatically load prior conversation history before
222    /// each prompt and append the new turn after a successful response. A
223    /// `conversation_id` must be supplied either via [`AgentBuilder::conversation_id`]
224    /// or per-request via [`crate::agent::prompt_request::PromptRequest::conversation`].
225    /// If neither is set, memory is silently bypassed.
226    pub fn memory<B>(mut self, memory: B) -> Self
227    where
228        B: ConversationMemory + 'static,
229    {
230        self.memory = Some(Arc::new(memory));
231        self
232    }
233
234    /// Set a default conversation id used when none is provided per-request.
235    ///
236    /// Most agents are reused across users or threads; prefer setting the id
237    /// per-request via [`crate::agent::prompt_request::PromptRequest::conversation`].
238    pub fn conversation_id(mut self, id: impl Into<String>) -> Self {
239        self.default_conversation_id = Some(id.into());
240        self
241    }
242
243    /// Set the default hook for the agent.
244    ///
245    /// This hook will be used for all prompt requests unless overridden
246    /// via `.with_hook()` on the request.
247    pub fn hook<P2>(self, hook: P2) -> AgentBuilder<M, P2, ToolState>
248    where
249        P2: PromptHook<M>,
250    {
251        AgentBuilder {
252            name: self.name,
253            description: self.description,
254            model: self.model,
255            preamble: self.preamble,
256            static_context: self.static_context,
257            additional_params: self.additional_params,
258            max_tokens: self.max_tokens,
259            dynamic_context: self.dynamic_context,
260            temperature: self.temperature,
261            tool_choice: self.tool_choice,
262            default_max_turns: self.default_max_turns,
263            tool_state: self.tool_state,
264            hook: Some(hook),
265            output_schema: self.output_schema,
266            memory: self.memory,
267            default_conversation_id: self.default_conversation_id,
268        }
269    }
270}
271
272impl<M> AgentBuilder<M, (), NoToolConfig>
273where
274    M: CompletionModel,
275{
276    /// Create a new agent builder with the given model
277    pub fn new(model: M) -> Self {
278        Self {
279            name: None,
280            description: None,
281            model,
282            preamble: None,
283            static_context: vec![],
284            temperature: None,
285            max_tokens: None,
286            additional_params: None,
287            dynamic_context: vec![],
288            tool_choice: None,
289            default_max_turns: None,
290            tool_state: NoToolConfig,
291            hook: None,
292            output_schema: None,
293            memory: None,
294            default_conversation_id: None,
295        }
296    }
297}
298
299impl<M, P> AgentBuilder<M, P, NoToolConfig>
300where
301    M: CompletionModel,
302    P: PromptHook<M>,
303{
304    /// Set a pre-existing ToolServerHandle for the agent.
305    ///
306    /// After calling this method, tool-adding methods (`.tool()`, `.tools()`, etc.)
307    /// will not be available. Use this when you want to share a `ToolServer`
308    /// between multiple agents or have pre-configured tools.
309    pub fn tool_server_handle(
310        self,
311        handle: ToolServerHandle,
312    ) -> AgentBuilder<M, P, WithToolServerHandle> {
313        AgentBuilder {
314            name: self.name,
315            description: self.description,
316            model: self.model,
317            preamble: self.preamble,
318            static_context: self.static_context,
319            additional_params: self.additional_params,
320            max_tokens: self.max_tokens,
321            dynamic_context: self.dynamic_context,
322            temperature: self.temperature,
323            tool_choice: self.tool_choice,
324            default_max_turns: self.default_max_turns,
325            tool_state: WithToolServerHandle { handle },
326            hook: self.hook,
327            output_schema: self.output_schema,
328            memory: self.memory,
329            default_conversation_id: self.default_conversation_id,
330        }
331    }
332
333    /// Add a static tool to the agent.
334    ///
335    /// This transitions the builder to the `WithBuilderTools` state, where
336    /// additional tools can be added but `tool_server_handle()` is no longer available.
337    pub fn tool(self, tool: impl Tool + 'static) -> AgentBuilder<M, P, WithBuilderTools> {
338        let toolname = tool.name();
339        AgentBuilder {
340            name: self.name,
341            description: self.description,
342            model: self.model,
343            preamble: self.preamble,
344            static_context: self.static_context,
345            additional_params: self.additional_params,
346            max_tokens: self.max_tokens,
347            dynamic_context: self.dynamic_context,
348            temperature: self.temperature,
349            tool_choice: self.tool_choice,
350            default_max_turns: self.default_max_turns,
351            tool_state: WithBuilderTools {
352                static_tools: vec![toolname],
353                tools: ToolSet::from_tools(vec![tool]),
354                dynamic_tools: vec![],
355            },
356            hook: self.hook,
357            output_schema: self.output_schema,
358            memory: self.memory,
359            default_conversation_id: self.default_conversation_id,
360        }
361    }
362
363    /// Add a vector of boxed static tools to the agent.
364    ///
365    /// This is useful when you need to dynamically add static tools to the agent.
366    /// Transitions the builder to the `WithBuilderTools` state.
367    pub fn tools(self, tools: Vec<Box<dyn ToolDyn>>) -> AgentBuilder<M, P, WithBuilderTools> {
368        let static_tools = tools.iter().map(|tool| tool.name()).collect();
369        let tools = ToolSet::from_tools_boxed(tools);
370
371        AgentBuilder {
372            name: self.name,
373            description: self.description,
374            model: self.model,
375            preamble: self.preamble,
376            static_context: self.static_context,
377            additional_params: self.additional_params,
378            max_tokens: self.max_tokens,
379            dynamic_context: self.dynamic_context,
380            temperature: self.temperature,
381            tool_choice: self.tool_choice,
382            default_max_turns: self.default_max_turns,
383            hook: self.hook,
384            output_schema: self.output_schema,
385            memory: self.memory,
386            default_conversation_id: self.default_conversation_id,
387            tool_state: WithBuilderTools {
388                static_tools,
389                tools,
390                dynamic_tools: vec![],
391            },
392        }
393    }
394
395    /// Add an MCP tool (from `rmcp`) to the agent.
396    ///
397    /// Transitions the builder to the `WithBuilderTools` state.
398    #[cfg(feature = "rmcp")]
399    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
400    pub fn rmcp_tool(
401        self,
402        tool: rmcp::model::Tool,
403        client: rmcp::service::ServerSink,
404    ) -> AgentBuilder<M, P, WithBuilderTools> {
405        let toolname = tool.name.clone().to_string();
406        let tools = ToolSet::from_tools(vec![RmcpTool::from_mcp_server(tool, client)]);
407
408        AgentBuilder {
409            name: self.name,
410            description: self.description,
411            model: self.model,
412            preamble: self.preamble,
413            static_context: self.static_context,
414            additional_params: self.additional_params,
415            max_tokens: self.max_tokens,
416            dynamic_context: self.dynamic_context,
417            temperature: self.temperature,
418            tool_choice: self.tool_choice,
419            default_max_turns: self.default_max_turns,
420            hook: self.hook,
421            output_schema: self.output_schema,
422            memory: self.memory,
423            default_conversation_id: self.default_conversation_id,
424            tool_state: WithBuilderTools {
425                static_tools: vec![toolname],
426                tools,
427                dynamic_tools: vec![],
428            },
429        }
430    }
431
432    /// Add an array of MCP tools (from `rmcp`) to the agent.
433    ///
434    /// Transitions the builder to the `WithBuilderTools` state.
435    #[cfg(feature = "rmcp")]
436    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
437    pub fn rmcp_tools(
438        self,
439        tools: Vec<rmcp::model::Tool>,
440        client: rmcp::service::ServerSink,
441    ) -> AgentBuilder<M, P, WithBuilderTools> {
442        let (static_tools, tools) = tools.into_iter().fold(
443            (Vec::new(), Vec::new()),
444            |(mut toolnames, mut toolset), tool| {
445                let tool_name = tool.name.to_string();
446                let tool = RmcpTool::from_mcp_server(tool, client.clone());
447                toolnames.push(tool_name);
448                toolset.push(tool);
449                (toolnames, toolset)
450            },
451        );
452
453        let tools = ToolSet::from_tools(tools);
454
455        AgentBuilder {
456            name: self.name,
457            description: self.description,
458            model: self.model,
459            preamble: self.preamble,
460            static_context: self.static_context,
461            additional_params: self.additional_params,
462            max_tokens: self.max_tokens,
463            dynamic_context: self.dynamic_context,
464            temperature: self.temperature,
465            tool_choice: self.tool_choice,
466            default_max_turns: self.default_max_turns,
467            hook: self.hook,
468            output_schema: self.output_schema,
469            memory: self.memory,
470            default_conversation_id: self.default_conversation_id,
471            tool_state: WithBuilderTools {
472                static_tools,
473                tools,
474                dynamic_tools: vec![],
475            },
476        }
477    }
478
479    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
480    /// dynamic toolset will be inserted in the request.
481    ///
482    /// Transitions the builder to the `WithBuilderTools` state.
483    pub fn dynamic_tools(
484        self,
485        sample: usize,
486        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
487        toolset: ToolSet,
488    ) -> AgentBuilder<M, P, WithBuilderTools> {
489        AgentBuilder {
490            name: self.name,
491            description: self.description,
492            model: self.model,
493            preamble: self.preamble,
494            static_context: self.static_context,
495            additional_params: self.additional_params,
496            max_tokens: self.max_tokens,
497            dynamic_context: self.dynamic_context,
498            temperature: self.temperature,
499            tool_choice: self.tool_choice,
500            default_max_turns: self.default_max_turns,
501            hook: self.hook,
502            output_schema: self.output_schema,
503            memory: self.memory,
504            default_conversation_id: self.default_conversation_id,
505            tool_state: WithBuilderTools {
506                static_tools: vec![],
507                tools: toolset,
508                dynamic_tools: vec![(sample, Arc::new(dynamic_tools))],
509            },
510        }
511    }
512
513    /// Build the agent with no tools configured.
514    ///
515    /// An empty `ToolServer` will be created for the agent.
516    pub fn build(self) -> Agent<M, P> {
517        let tool_server_handle = ToolServer::new().run();
518
519        Agent {
520            name: self.name,
521            description: self.description,
522            model: Arc::new(self.model),
523            preamble: self.preamble,
524            static_context: self.static_context,
525            temperature: self.temperature,
526            max_tokens: self.max_tokens,
527            additional_params: self.additional_params,
528            tool_choice: self.tool_choice,
529            dynamic_context: Arc::new(self.dynamic_context),
530            tool_server_handle,
531            default_max_turns: self.default_max_turns,
532            hook: self.hook,
533            output_schema: self.output_schema,
534            memory: self.memory,
535            default_conversation_id: self.default_conversation_id,
536        }
537    }
538}
539
540impl<M, P> AgentBuilder<M, P, WithToolServerHandle>
541where
542    M: CompletionModel,
543    P: PromptHook<M>,
544{
545    /// Build the agent using the pre-configured ToolServerHandle.
546    pub fn build(self) -> Agent<M, P> {
547        Agent {
548            name: self.name,
549            description: self.description,
550            model: Arc::new(self.model),
551            preamble: self.preamble,
552            static_context: self.static_context,
553            temperature: self.temperature,
554            max_tokens: self.max_tokens,
555            additional_params: self.additional_params,
556            tool_choice: self.tool_choice,
557            dynamic_context: Arc::new(self.dynamic_context),
558            tool_server_handle: self.tool_state.handle,
559            default_max_turns: self.default_max_turns,
560            hook: self.hook,
561            output_schema: self.output_schema,
562            memory: self.memory,
563            default_conversation_id: self.default_conversation_id,
564        }
565    }
566}
567
568impl<M, P> AgentBuilder<M, P, WithBuilderTools>
569where
570    M: CompletionModel,
571    P: PromptHook<M>,
572{
573    /// Add another static tool to the agent.
574    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
575        let toolname = tool.name();
576        self.tool_state.tools.add_tool(tool);
577        self.tool_state.static_tools.push(toolname);
578        self
579    }
580
581    /// Add a vector of boxed static tools to the agent.
582    pub fn tools(mut self, tools: Vec<Box<dyn ToolDyn>>) -> Self {
583        let toolnames: Vec<String> = tools.iter().map(|tool| tool.name()).collect();
584        let tools = ToolSet::from_tools_boxed(tools);
585        self.tool_state.tools.add_tools(tools);
586        self.tool_state.static_tools.extend(toolnames);
587        self
588    }
589
590    /// Add an array of MCP tools (from `rmcp`) to the agent.
591    #[cfg(feature = "rmcp")]
592    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
593    pub fn rmcp_tools(
594        mut self,
595        tools: Vec<rmcp::model::Tool>,
596        client: rmcp::service::ServerSink,
597    ) -> Self {
598        for tool in tools {
599            let tool_name = tool.name.to_string();
600            let tool = RmcpTool::from_mcp_server(tool, client.clone());
601            self.tool_state.static_tools.push(tool_name);
602            self.tool_state.tools.add_tool(tool);
603        }
604
605        self
606    }
607
608    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
609    /// dynamic toolset will be inserted in the request.
610    pub fn dynamic_tools(
611        mut self,
612        sample: usize,
613        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
614        toolset: ToolSet,
615    ) -> Self {
616        self.tool_state
617            .dynamic_tools
618            .push((sample, Arc::new(dynamic_tools)));
619        self.tool_state.tools.add_tools(toolset);
620        self
621    }
622
623    /// Build the agent with the configured tools.
624    ///
625    /// A new `ToolServer` will be created containing all tools added via
626    /// `.tool()`, `.tools()`, `.dynamic_tools()`, etc.
627    pub fn build(self) -> Agent<M, P> {
628        let tool_server_handle = ToolServer::new()
629            .static_tool_names(self.tool_state.static_tools)
630            .add_tools(self.tool_state.tools)
631            .add_dynamic_tools(self.tool_state.dynamic_tools)
632            .run();
633
634        Agent {
635            name: self.name,
636            description: self.description,
637            model: Arc::new(self.model),
638            preamble: self.preamble,
639            static_context: self.static_context,
640            temperature: self.temperature,
641            max_tokens: self.max_tokens,
642            additional_params: self.additional_params,
643            tool_choice: self.tool_choice,
644            dynamic_context: Arc::new(self.dynamic_context),
645            tool_server_handle,
646            default_max_turns: self.default_max_turns,
647            hook: self.hook,
648            output_schema: self.output_schema,
649            memory: self.memory,
650            default_conversation_id: self.default_conversation_id,
651        }
652    }
653}
654
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{MockAddTool, MockCompletionModel};

    /// Prompt hook that relies entirely on the trait's default methods.
    #[derive(Clone)]
    struct NoopHook;

    impl PromptHook<MockCompletionModel> for NoopHook {}

    /// `.hook()` is implemented for every tool state, so it must remain callable
    /// after the builder has moved into `WithBuilderTools`.
    #[test]
    fn hook_can_be_set_after_tool_configuration() {
        let model = MockCompletionModel::text("ok");
        let with_tools = AgentBuilder::new(model).tool(MockAddTool);
        let _agent = with_tools.hook(NoopHook).build();
    }
}