rig/agent/
builder.rs

1use std::{collections::HashMap, sync::Arc};
2
3use crate::{
4    completion::{CompletionModel, Document},
5    tool::{Tool, ToolSet},
6    vector_store::VectorStoreIndexDyn,
7};
8
9#[cfg(feature = "rmcp")]
10#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
11use crate::tool::rmcp::McpTool as RmcpTool;
12
13use super::Agent;
14
15/// A builder for creating an agent
16///
17/// # Example
18/// ```
19/// use rig::{providers::openai, agent::AgentBuilder};
20///
21/// let openai = openai::Client::from_env();
22///
23/// let gpt4o = openai.completion_model("gpt-4o");
24///
25/// // Configure the agent
26/// let agent = AgentBuilder::new(model)
27///     .preamble("System prompt")
28///     .context("Context document 1")
29///     .context("Context document 2")
30///     .tool(tool1)
31///     .tool(tool2)
32///     .temperature(0.8)
33///     .additional_params(json!({"foo": "bar"}))
34///     .build();
35/// ```
36pub struct AgentBuilder<M>
37where
38    M: CompletionModel,
39{
40    /// Name of the agent used for logging and debugging
41    name: Option<String>,
42    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
43    model: M,
44    /// System prompt
45    preamble: Option<String>,
46    /// Context documents always available to the agent
47    static_context: Vec<Document>,
48    /// Tools that are always available to the agent (by name)
49    static_tools: Vec<String>,
50    /// Additional parameters to be passed to the model
51    additional_params: Option<serde_json::Value>,
52    /// Maximum number of tokens for the completion
53    max_tokens: Option<u64>,
54    /// List of vector store, with the sample number
55    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
56    /// Dynamic tools
57    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
58    /// Temperature of the model
59    temperature: Option<f64>,
60    /// Actual tool implementations
61    tools: ToolSet,
62}
63
64impl<M> AgentBuilder<M>
65where
66    M: CompletionModel,
67{
68    pub fn new(model: M) -> Self {
69        Self {
70            name: None,
71            model,
72            preamble: None,
73            static_context: vec![],
74            static_tools: vec![],
75            temperature: None,
76            max_tokens: None,
77            additional_params: None,
78            dynamic_context: vec![],
79            dynamic_tools: vec![],
80            tools: ToolSet::default(),
81        }
82    }
83
84    /// Set the name of the agent
85    pub fn name(mut self, name: &str) -> Self {
86        self.name = Some(name.into());
87        self
88    }
89
90    /// Set the system prompt
91    pub fn preamble(mut self, preamble: &str) -> Self {
92        self.preamble = Some(preamble.into());
93        self
94    }
95
96    /// Append to the preamble of the agent
97    pub fn append_preamble(mut self, doc: &str) -> Self {
98        self.preamble = Some(format!(
99            "{}\n{}",
100            self.preamble.unwrap_or_else(|| "".into()),
101            doc
102        ));
103        self
104    }
105
106    /// Add a static context document to the agent
107    pub fn context(mut self, doc: &str) -> Self {
108        self.static_context.push(Document {
109            id: format!("static_doc_{}", self.static_context.len()),
110            text: doc.into(),
111            additional_props: HashMap::new(),
112        });
113        self
114    }
115
116    /// Add a static tool to the agent
117    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
118        let toolname = tool.name();
119        self.tools.add_tool(tool);
120        self.static_tools.push(toolname);
121        self
122    }
123
124    // Add an MCP tool (from `rmcp`) to the agent
125    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
126    #[cfg(feature = "rmcp")]
127    pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
128        let toolname = tool.name.clone();
129        self.tools.add_tool(RmcpTool::from_mcp_server(tool, client));
130        self.static_tools.push(toolname.to_string());
131        self
132    }
133
134    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
135    /// dynamic context will be inserted in the request.
136    pub fn dynamic_context(
137        mut self,
138        sample: usize,
139        dynamic_context: impl VectorStoreIndexDyn + 'static,
140    ) -> Self {
141        self.dynamic_context
142            .push((sample, Box::new(dynamic_context)));
143        self
144    }
145
146    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
147    /// dynamic toolset will be inserted in the request.
148    pub fn dynamic_tools(
149        mut self,
150        sample: usize,
151        dynamic_tools: impl VectorStoreIndexDyn + 'static,
152        toolset: ToolSet,
153    ) -> Self {
154        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
155        self.tools.add_tools(toolset);
156        self
157    }
158
159    /// Set the temperature of the model
160    pub fn temperature(mut self, temperature: f64) -> Self {
161        self.temperature = Some(temperature);
162        self
163    }
164
165    /// Set the maximum number of tokens for the completion
166    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
167        self.max_tokens = Some(max_tokens);
168        self
169    }
170
171    /// Set additional parameters to be passed to the model
172    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
173        self.additional_params = Some(params);
174        self
175    }
176
177    /// Build the agent
178    pub fn build(self) -> Agent<M> {
179        Agent {
180            name: self.name,
181            model: Arc::new(self.model),
182            preamble: self.preamble.unwrap_or_default(),
183            static_context: self.static_context,
184            static_tools: self.static_tools,
185            temperature: self.temperature,
186            max_tokens: self.max_tokens,
187            additional_params: self.additional_params,
188            dynamic_context: Arc::new(self.dynamic_context),
189            dynamic_tools: Arc::new(self.dynamic_tools),
190            tools: Arc::new(self.tools),
191        }
192    }
193}