rig/agent/
builder.rs

1use std::collections::HashMap;
2
3use crate::{
4    completion::{CompletionModel, Document},
5    tool::{Tool, ToolSet},
6    vector_store::VectorStoreIndexDyn,
7};
8
9#[cfg(feature = "rmcp")]
10use crate::tool::rmcp::McpTool as RmcpTool;
11
12use super::Agent;
13
14/// A builder for creating an agent
15///
16/// # Example
17/// ```
18/// use rig::{providers::openai, agent::AgentBuilder};
19///
20/// let openai = openai::Client::from_env();
21///
22/// let gpt4o = openai.completion_model("gpt-4o");
23///
24/// // Configure the agent
25/// let agent = AgentBuilder::new(model)
26///     .preamble("System prompt")
27///     .context("Context document 1")
28///     .context("Context document 2")
29///     .tool(tool1)
30///     .tool(tool2)
31///     .temperature(0.8)
32///     .additional_params(json!({"foo": "bar"}))
33///     .build();
34/// ```
35pub struct AgentBuilder<M: CompletionModel> {
36    /// Name of the agent used for logging and debugging
37    name: Option<String>,
38    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
39    model: M,
40    /// System prompt
41    preamble: Option<String>,
42    /// Context documents always available to the agent
43    static_context: Vec<Document>,
44    /// Tools that are always available to the agent (by name)
45    static_tools: Vec<String>,
46    /// Additional parameters to be passed to the model
47    additional_params: Option<serde_json::Value>,
48    /// Maximum number of tokens for the completion
49    max_tokens: Option<u64>,
50    /// List of vector store, with the sample number
51    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
52    /// Dynamic tools
53    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
54    /// Temperature of the model
55    temperature: Option<f64>,
56    /// Actual tool implementations
57    tools: ToolSet,
58}
59
60impl<M: CompletionModel> AgentBuilder<M> {
61    pub fn new(model: M) -> Self {
62        Self {
63            name: None,
64            model,
65            preamble: None,
66            static_context: vec![],
67            static_tools: vec![],
68            temperature: None,
69            max_tokens: None,
70            additional_params: None,
71            dynamic_context: vec![],
72            dynamic_tools: vec![],
73            tools: ToolSet::default(),
74        }
75    }
76
77    /// Set the name of the agent
78    pub fn name(mut self, name: &str) -> Self {
79        self.name = Some(name.into());
80        self
81    }
82
83    /// Set the system prompt
84    pub fn preamble(mut self, preamble: &str) -> Self {
85        self.preamble = Some(preamble.into());
86        self
87    }
88
89    /// Append to the preamble of the agent
90    pub fn append_preamble(mut self, doc: &str) -> Self {
91        self.preamble = Some(format!(
92            "{}\n{}",
93            self.preamble.unwrap_or_else(|| "".into()),
94            doc
95        ));
96        self
97    }
98
99    /// Add a static context document to the agent
100    pub fn context(mut self, doc: &str) -> Self {
101        self.static_context.push(Document {
102            id: format!("static_doc_{}", self.static_context.len()),
103            text: doc.into(),
104            additional_props: HashMap::new(),
105        });
106        self
107    }
108
109    /// Add a static tool to the agent
110    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
111        let toolname = tool.name();
112        self.tools.add_tool(tool);
113        self.static_tools.push(toolname);
114        self
115    }
116
117    // Add an MCP tool (from `rmcp`) to the agent
118    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
119    #[cfg(feature = "rmcp")]
120    pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
121        let toolname = tool.name.clone();
122        self.tools.add_tool(RmcpTool::from_mcp_server(tool, client));
123        self.static_tools.push(toolname.to_string());
124        self
125    }
126
127    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
128    /// dynamic context will be inserted in the request.
129    pub fn dynamic_context(
130        mut self,
131        sample: usize,
132        dynamic_context: impl VectorStoreIndexDyn + 'static,
133    ) -> Self {
134        self.dynamic_context
135            .push((sample, Box::new(dynamic_context)));
136        self
137    }
138
139    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
140    /// dynamic toolset will be inserted in the request.
141    pub fn dynamic_tools(
142        mut self,
143        sample: usize,
144        dynamic_tools: impl VectorStoreIndexDyn + 'static,
145        toolset: ToolSet,
146    ) -> Self {
147        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
148        self.tools.add_tools(toolset);
149        self
150    }
151
152    /// Set the temperature of the model
153    pub fn temperature(mut self, temperature: f64) -> Self {
154        self.temperature = Some(temperature);
155        self
156    }
157
158    /// Set the maximum number of tokens for the completion
159    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
160        self.max_tokens = Some(max_tokens);
161        self
162    }
163
164    /// Set additional parameters to be passed to the model
165    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
166        self.additional_params = Some(params);
167        self
168    }
169
170    /// Build the agent
171    pub fn build(self) -> Agent<M> {
172        Agent {
173            name: self.name,
174            model: self.model,
175            preamble: self.preamble.unwrap_or_default(),
176            static_context: self.static_context,
177            static_tools: self.static_tools,
178            temperature: self.temperature,
179            max_tokens: self.max_tokens,
180            additional_params: self.additional_params,
181            dynamic_context: self.dynamic_context,
182            dynamic_tools: self.dynamic_tools,
183            tools: self.tools,
184        }
185    }
186}