rig/agent/builder.rs

1use std::{collections::HashMap, sync::Arc};
2
3use crate::{
4    completion::{CompletionModel, Document},
5    tool::{Tool, ToolSet},
6    vector_store::VectorStoreIndexDyn,
7};
8
9#[cfg(feature = "rmcp")]
10#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
11use crate::tool::rmcp::McpTool as RmcpTool;
12
13use super::Agent;
14
15/// A builder for creating an agent
16///
17/// # Example
18/// ```
19/// use rig::{providers::openai, agent::AgentBuilder};
20///
21/// let openai = openai::Client::from_env();
22///
23/// let gpt4o = openai.completion_model("gpt-4o");
24///
25/// // Configure the agent
26/// let agent = AgentBuilder::new(model)
27///     .preamble("System prompt")
28///     .context("Context document 1")
29///     .context("Context document 2")
30///     .tool(tool1)
31///     .tool(tool2)
32///     .temperature(0.8)
33///     .additional_params(json!({"foo": "bar"}))
34///     .build();
35/// ```
36pub struct AgentBuilder<M>
37where
38    M: CompletionModel,
39{
40    /// Name of the agent used for logging and debugging
41    name: Option<String>,
42    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
43    model: M,
44    /// System prompt
45    preamble: Option<String>,
46    /// Context documents always available to the agent
47    static_context: Vec<Document>,
48    /// Tools that are always available to the agent (by name)
49    static_tools: Vec<String>,
50    /// Additional parameters to be passed to the model
51    additional_params: Option<serde_json::Value>,
52    /// Maximum number of tokens for the completion
53    max_tokens: Option<u64>,
54    /// List of vector store, with the sample number
55    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
56    /// Dynamic tools
57    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
58    /// Temperature of the model
59    temperature: Option<f64>,
60    /// Actual tool implementations
61    tools: ToolSet,
62}
63
64impl<M> AgentBuilder<M>
65where
66    M: CompletionModel,
67{
68    pub fn new(model: M) -> Self {
69        Self {
70            name: None,
71            model,
72            preamble: None,
73            static_context: vec![],
74            static_tools: vec![],
75            temperature: None,
76            max_tokens: None,
77            additional_params: None,
78            dynamic_context: vec![],
79            dynamic_tools: vec![],
80            tools: ToolSet::default(),
81        }
82    }
83
84    /// Set the name of the agent
85    pub fn name(mut self, name: &str) -> Self {
86        self.name = Some(name.into());
87        self
88    }
89
90    /// Set the system prompt
91    pub fn preamble(mut self, preamble: &str) -> Self {
92        self.preamble = Some(preamble.into());
93        self
94    }
95
96    /// Remove the system prompt
97    pub fn without_preamble(mut self) -> Self {
98        self.preamble = None;
99        self
100    }
101
102    /// Append to the preamble of the agent
103    pub fn append_preamble(mut self, doc: &str) -> Self {
104        self.preamble = Some(format!(
105            "{}\n{}",
106            self.preamble.unwrap_or_else(|| "".into()),
107            doc
108        ));
109        self
110    }
111
112    /// Add a static context document to the agent
113    pub fn context(mut self, doc: &str) -> Self {
114        self.static_context.push(Document {
115            id: format!("static_doc_{}", self.static_context.len()),
116            text: doc.into(),
117            additional_props: HashMap::new(),
118        });
119        self
120    }
121
122    /// Add a static tool to the agent
123    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
124        let toolname = tool.name();
125        self.tools.add_tool(tool);
126        self.static_tools.push(toolname);
127        self
128    }
129
130    // Add an MCP tool (from `rmcp`) to the agent
131    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
132    #[cfg(feature = "rmcp")]
133    pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
134        let toolname = tool.name.clone();
135        self.tools.add_tool(RmcpTool::from_mcp_server(tool, client));
136        self.static_tools.push(toolname.to_string());
137        self
138    }
139
140    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
141    /// dynamic context will be inserted in the request.
142    pub fn dynamic_context(
143        mut self,
144        sample: usize,
145        dynamic_context: impl VectorStoreIndexDyn + 'static,
146    ) -> Self {
147        self.dynamic_context
148            .push((sample, Box::new(dynamic_context)));
149        self
150    }
151
152    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
153    /// dynamic toolset will be inserted in the request.
154    pub fn dynamic_tools(
155        mut self,
156        sample: usize,
157        dynamic_tools: impl VectorStoreIndexDyn + 'static,
158        toolset: ToolSet,
159    ) -> Self {
160        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
161        self.tools.add_tools(toolset);
162        self
163    }
164
165    /// Set the temperature of the model
166    pub fn temperature(mut self, temperature: f64) -> Self {
167        self.temperature = Some(temperature);
168        self
169    }
170
171    /// Set the maximum number of tokens for the completion
172    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
173        self.max_tokens = Some(max_tokens);
174        self
175    }
176
177    /// Set additional parameters to be passed to the model
178    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
179        self.additional_params = Some(params);
180        self
181    }
182
183    /// Build the agent
184    pub fn build(self) -> Agent<M> {
185        Agent {
186            name: self.name,
187            model: Arc::new(self.model),
188            preamble: self.preamble,
189            static_context: self.static_context,
190            static_tools: self.static_tools,
191            temperature: self.temperature,
192            max_tokens: self.max_tokens,
193            additional_params: self.additional_params,
194            dynamic_context: Arc::new(self.dynamic_context),
195            dynamic_tools: Arc::new(self.dynamic_tools),
196            tools: Arc::new(self.tools),
197        }
198    }
199}