rig/agent/
builder.rs

1use std::collections::HashMap;
2
3use crate::{
4    completion::{CompletionModel, Document},
5    tool::{Tool, ToolSet},
6    vector_store::VectorStoreIndexDyn,
7};
8
9#[allow(deprecated)]
10#[cfg(feature = "mcp")]
11use crate::tool::mcp::McpTool;
12
13#[cfg(feature = "rmcp")]
14use crate::tool::rmcp::McpTool as RmcpTool;
15
16use super::Agent;
17
18/// A builder for creating an agent
19///
20/// # Example
21/// ```
22/// use rig::{providers::openai, agent::AgentBuilder};
23///
24/// let openai = openai::Client::from_env();
25///
26/// let gpt4o = openai.completion_model("gpt-4o");
27///
28/// // Configure the agent
29/// let agent = AgentBuilder::new(model)
30///     .preamble("System prompt")
31///     .context("Context document 1")
32///     .context("Context document 2")
33///     .tool(tool1)
34///     .tool(tool2)
35///     .temperature(0.8)
36///     .additional_params(json!({"foo": "bar"}))
37///     .build();
38/// ```
39pub struct AgentBuilder<M: CompletionModel> {
40    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
41    model: M,
42    /// System prompt
43    preamble: Option<String>,
44    /// Context documents always available to the agent
45    static_context: Vec<Document>,
46    /// Tools that are always available to the agent (by name)
47    static_tools: Vec<String>,
48    /// Additional parameters to be passed to the model
49    additional_params: Option<serde_json::Value>,
50    /// Maximum number of tokens for the completion
51    max_tokens: Option<u64>,
52    /// List of vector store, with the sample number
53    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
54    /// Dynamic tools
55    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
56    /// Temperature of the model
57    temperature: Option<f64>,
58    /// Actual tool implementations
59    tools: ToolSet,
60}
61
62impl<M: CompletionModel> AgentBuilder<M> {
63    pub fn new(model: M) -> Self {
64        Self {
65            model,
66            preamble: None,
67            static_context: vec![],
68            static_tools: vec![],
69            temperature: None,
70            max_tokens: None,
71            additional_params: None,
72            dynamic_context: vec![],
73            dynamic_tools: vec![],
74            tools: ToolSet::default(),
75        }
76    }
77
78    /// Set the system prompt
79    pub fn preamble(mut self, preamble: &str) -> Self {
80        self.preamble = Some(preamble.into());
81        self
82    }
83
84    /// Append to the preamble of the agent
85    pub fn append_preamble(mut self, doc: &str) -> Self {
86        self.preamble = Some(format!(
87            "{}\n{}",
88            self.preamble.unwrap_or_else(|| "".into()),
89            doc
90        ));
91        self
92    }
93
94    /// Add a static context document to the agent
95    pub fn context(mut self, doc: &str) -> Self {
96        self.static_context.push(Document {
97            id: format!("static_doc_{}", self.static_context.len()),
98            text: doc.into(),
99            additional_props: HashMap::new(),
100        });
101        self
102    }
103
104    /// Add a static tool to the agent
105    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
106        let toolname = tool.name();
107        self.tools.add_tool(tool);
108        self.static_tools.push(toolname);
109        self
110    }
111
112    // Add an MCP tool to the agent
113    #[cfg(feature = "mcp")]
114    pub fn mcp_tool<T: mcp_core::transport::Transport>(
115        mut self,
116        tool: mcp_core::types::Tool,
117        client: mcp_core::client::Client<T>,
118    ) -> Self {
119        let toolname = tool.name.clone();
120        #[allow(deprecated)]
121        self.tools.add_tool(McpTool::from_mcp_server(tool, client));
122        self.static_tools.push(toolname);
123        self
124    }
125
126    // Add an MCP tool (from `rmcp`) to the agent
127    #[cfg(feature = "rmcp")]
128    pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
129        let toolname = tool.name.clone();
130        self.tools.add_tool(RmcpTool::from_mcp_server(tool, client));
131        self.static_tools.push(toolname.to_string());
132        self
133    }
134
135    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
136    /// dynamic context will be inserted in the request.
137    pub fn dynamic_context(
138        mut self,
139        sample: usize,
140        dynamic_context: impl VectorStoreIndexDyn + 'static,
141    ) -> Self {
142        self.dynamic_context
143            .push((sample, Box::new(dynamic_context)));
144        self
145    }
146
147    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
148    /// dynamic toolset will be inserted in the request.
149    pub fn dynamic_tools(
150        mut self,
151        sample: usize,
152        dynamic_tools: impl VectorStoreIndexDyn + 'static,
153        toolset: ToolSet,
154    ) -> Self {
155        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
156        self.tools.add_tools(toolset);
157        self
158    }
159
160    /// Set the temperature of the model
161    pub fn temperature(mut self, temperature: f64) -> Self {
162        self.temperature = Some(temperature);
163        self
164    }
165
166    /// Set the maximum number of tokens for the completion
167    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
168        self.max_tokens = Some(max_tokens);
169        self
170    }
171
172    /// Set additional parameters to be passed to the model
173    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
174        self.additional_params = Some(params);
175        self
176    }
177
178    /// Build the agent
179    pub fn build(self) -> Agent<M> {
180        Agent {
181            model: self.model,
182            preamble: self.preamble.unwrap_or_default(),
183            static_context: self.static_context,
184            static_tools: self.static_tools,
185            temperature: self.temperature,
186            max_tokens: self.max_tokens,
187            additional_params: self.additional_params,
188            dynamic_context: self.dynamic_context,
189            dynamic_tools: self.dynamic_tools,
190            tools: self.tools,
191        }
192    }
193}