1use std::{collections::HashMap, sync::Arc};
2
3use schemars::{JsonSchema, Schema, schema_for};
4use tokio::sync::RwLock;
5
6use crate::{
7 agent::prompt_request::hooks::PromptHook,
8 completion::{CompletionModel, Document},
9 message::ToolChoice,
10 tool::{
11 Tool, ToolDyn, ToolSet,
12 server::{ToolServer, ToolServerHandle},
13 },
14 vector_store::VectorStoreIndexDyn,
15};
16
17#[cfg(feature = "rmcp")]
18#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
19use crate::tool::rmcp::McpTool as RmcpTool;
20
21use super::Agent;
22
/// Typestate marker for an `AgentBuilder` on which no tool configuration has
/// been chosen yet.
///
/// This is the default `ToolState` type parameter of `AgentBuilder`. It is a
/// zero-sized unit struct, so it derives the common cheap traits (`Debug`,
/// `Clone`, `Copy`) in addition to `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoToolConfig;
32
/// Typestate for an `AgentBuilder` that was given an existing
/// [`ToolServerHandle`] through `AgentBuilder::tool_server_handle`.
pub struct WithToolServerHandle {
    /// Handle supplied by the caller; `build` moves it into the resulting
    /// `Agent` as its `tool_server_handle`.
    handle: ToolServerHandle,
}
40
/// Typestate for an `AgentBuilder` whose tools are collected on the builder
/// itself.
///
/// When `build` is called in this state, the three fields are handed to a new
/// `ToolServer` (`static_tool_names`, `add_tools`, `add_dynamic_tools`).
pub struct WithBuilderTools {
    /// Names of the tools that were registered as static tools.
    static_tools: Vec<String>,
    /// Tool implementations known to the builder (static tools as well as the
    /// toolsets passed to `dynamic_tools`).
    tools: ToolSet,
    /// `(sample, index)` pairs exactly as passed to `dynamic_tools`.
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
}
51
/// Builder that accumulates the settings of an `Agent` and produces it via
/// `build`.
///
/// Type parameters:
/// - `M`: the completion model the agent wraps.
/// - `P`: the prompt hook type; defaults to `()` until `hook` is called.
/// - `ToolState`: typestate describing how tools are configured
///   ([`NoToolConfig`], [`WithToolServerHandle`] or [`WithBuilderTools`]);
///   it decides which tool-related methods and which `build` are available.
pub struct AgentBuilder<M, P = (), ToolState = NoToolConfig>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Optional agent name, set via `name`.
    name: Option<String>,
    /// Optional agent description, set via `description`.
    description: Option<String>,
    /// Completion model; wrapped in an `Arc` when the agent is built.
    model: M,
    /// Optional preamble text, set via `preamble` / `append_preamble`.
    preamble: Option<String>,
    /// Documents added via `context`; ids are `static_doc_{index}`.
    static_context: Vec<Document>,
    /// Optional extra JSON parameters, set via `additional_params`.
    additional_params: Option<serde_json::Value>,
    /// Optional maximum token count, set via `max_tokens`.
    max_tokens: Option<u64>,
    /// `(sample, index)` pairs added via `dynamic_context`.
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
    /// Optional temperature, set via `temperature`.
    temperature: Option<f64>,
    /// Optional tool choice, set via `tool_choice`.
    tool_choice: Option<ToolChoice>,
    /// Optional default maximum number of turns, set via `default_max_turns`.
    default_max_turns: Option<usize>,
    /// Typestate payload; see the `ToolState` type parameter.
    tool_state: ToolState,
    /// Optional prompt hook, set via `hook`.
    hook: Option<P>,
    /// Optional JSON schema for the output, set via `output_schema` or
    /// `output_schema_raw`.
    output_schema: Option<schemars::Schema>,
}
111
112impl<M, P, ToolState> AgentBuilder<M, P, ToolState>
113where
114 M: CompletionModel,
115 P: PromptHook<M>,
116{
117 pub fn name(mut self, name: &str) -> Self {
119 self.name = Some(name.into());
120 self
121 }
122
123 pub fn description(mut self, description: &str) -> Self {
125 self.description = Some(description.into());
126 self
127 }
128
129 pub fn preamble(mut self, preamble: &str) -> Self {
131 self.preamble = Some(preamble.into());
132 self
133 }
134
135 pub fn without_preamble(mut self) -> Self {
137 self.preamble = None;
138 self
139 }
140
141 pub fn append_preamble(mut self, doc: &str) -> Self {
143 self.preamble = Some(format!("{}\n{}", self.preamble.unwrap_or_default(), doc));
144 self
145 }
146
147 pub fn context(mut self, doc: &str) -> Self {
149 self.static_context.push(Document {
150 id: format!("static_doc_{}", self.static_context.len()),
151 text: doc.into(),
152 additional_props: HashMap::new(),
153 });
154 self
155 }
156
157 pub fn dynamic_context(
160 mut self,
161 sample: usize,
162 dynamic_context: impl VectorStoreIndexDyn + Send + Sync + 'static,
163 ) -> Self {
164 self.dynamic_context
165 .push((sample, Box::new(dynamic_context)));
166 self
167 }
168
169 pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
171 self.tool_choice = Some(tool_choice);
172 self
173 }
174
175 pub fn default_max_turns(mut self, default_max_turns: usize) -> Self {
177 self.default_max_turns = Some(default_max_turns);
178 self
179 }
180
181 pub fn temperature(mut self, temperature: f64) -> Self {
183 self.temperature = Some(temperature);
184 self
185 }
186
187 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
189 self.max_tokens = Some(max_tokens);
190 self
191 }
192
193 pub fn additional_params(mut self, params: serde_json::Value) -> Self {
195 self.additional_params = Some(params);
196 self
197 }
198
199 pub fn output_schema<T>(mut self) -> Self
202 where
203 T: JsonSchema,
204 {
205 self.output_schema = Some(schema_for!(T));
206 self
207 }
208
209 pub fn output_schema_raw(mut self, schema: Schema) -> Self {
211 self.output_schema = Some(schema);
212 self
213 }
214}
215
216impl<M> AgentBuilder<M, (), NoToolConfig>
217where
218 M: CompletionModel,
219{
220 pub fn new(model: M) -> Self {
222 Self {
223 name: None,
224 description: None,
225 model,
226 preamble: None,
227 static_context: vec![],
228 temperature: None,
229 max_tokens: None,
230 additional_params: None,
231 dynamic_context: vec![],
232 tool_choice: None,
233 default_max_turns: None,
234 tool_state: NoToolConfig,
235 hook: None,
236 output_schema: None,
237 }
238 }
239}
240
241impl<M, P> AgentBuilder<M, P, NoToolConfig>
242where
243 M: CompletionModel,
244 P: PromptHook<M>,
245{
246 pub fn tool_server_handle(
252 self,
253 handle: ToolServerHandle,
254 ) -> AgentBuilder<M, P, WithToolServerHandle> {
255 AgentBuilder {
256 name: self.name,
257 description: self.description,
258 model: self.model,
259 preamble: self.preamble,
260 static_context: self.static_context,
261 additional_params: self.additional_params,
262 max_tokens: self.max_tokens,
263 dynamic_context: self.dynamic_context,
264 temperature: self.temperature,
265 tool_choice: self.tool_choice,
266 default_max_turns: self.default_max_turns,
267 tool_state: WithToolServerHandle { handle },
268 hook: self.hook,
269 output_schema: self.output_schema,
270 }
271 }
272
273 pub fn tool(self, tool: impl Tool + 'static) -> AgentBuilder<M, P, WithBuilderTools> {
278 let toolname = tool.name();
279 AgentBuilder {
280 name: self.name,
281 description: self.description,
282 model: self.model,
283 preamble: self.preamble,
284 static_context: self.static_context,
285 additional_params: self.additional_params,
286 max_tokens: self.max_tokens,
287 dynamic_context: self.dynamic_context,
288 temperature: self.temperature,
289 tool_choice: self.tool_choice,
290 default_max_turns: self.default_max_turns,
291 tool_state: WithBuilderTools {
292 static_tools: vec![toolname],
293 tools: ToolSet::from_tools(vec![tool]),
294 dynamic_tools: vec![],
295 },
296 hook: self.hook,
297 output_schema: self.output_schema,
298 }
299 }
300
301 pub fn tools(self, tools: Vec<Box<dyn ToolDyn>>) -> AgentBuilder<M, P, WithBuilderTools> {
306 let static_tools = tools.iter().map(|tool| tool.name()).collect();
307 let tools = ToolSet::from_tools_boxed(tools);
308
309 AgentBuilder {
310 name: self.name,
311 description: self.description,
312 model: self.model,
313 preamble: self.preamble,
314 static_context: self.static_context,
315 additional_params: self.additional_params,
316 max_tokens: self.max_tokens,
317 dynamic_context: self.dynamic_context,
318 temperature: self.temperature,
319 tool_choice: self.tool_choice,
320 default_max_turns: self.default_max_turns,
321 hook: self.hook,
322 output_schema: self.output_schema,
323 tool_state: WithBuilderTools {
324 static_tools,
325 tools,
326 dynamic_tools: vec![],
327 },
328 }
329 }
330
331 #[cfg(feature = "rmcp")]
335 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
336 pub fn rmcp_tool(
337 self,
338 tool: rmcp::model::Tool,
339 client: rmcp::service::ServerSink,
340 ) -> AgentBuilder<M, P, WithBuilderTools> {
341 let toolname = tool.name.clone().to_string();
342 let tools = ToolSet::from_tools(vec![RmcpTool::from_mcp_server(tool, client)]);
343
344 AgentBuilder {
345 name: self.name,
346 description: self.description,
347 model: self.model,
348 preamble: self.preamble,
349 static_context: self.static_context,
350 additional_params: self.additional_params,
351 max_tokens: self.max_tokens,
352 dynamic_context: self.dynamic_context,
353 temperature: self.temperature,
354 tool_choice: self.tool_choice,
355 default_max_turns: self.default_max_turns,
356 hook: self.hook,
357 output_schema: self.output_schema,
358 tool_state: WithBuilderTools {
359 static_tools: vec![toolname],
360 tools,
361 dynamic_tools: vec![],
362 },
363 }
364 }
365
366 #[cfg(feature = "rmcp")]
370 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
371 pub fn rmcp_tools(
372 self,
373 tools: Vec<rmcp::model::Tool>,
374 client: rmcp::service::ServerSink,
375 ) -> AgentBuilder<M, P, WithBuilderTools> {
376 let (static_tools, tools) = tools.into_iter().fold(
377 (Vec::new(), Vec::new()),
378 |(mut toolnames, mut toolset), tool| {
379 let tool_name = tool.name.to_string();
380 let tool = RmcpTool::from_mcp_server(tool, client.clone());
381 toolnames.push(tool_name);
382 toolset.push(tool);
383 (toolnames, toolset)
384 },
385 );
386
387 let tools = ToolSet::from_tools(tools);
388
389 AgentBuilder {
390 name: self.name,
391 description: self.description,
392 model: self.model,
393 preamble: self.preamble,
394 static_context: self.static_context,
395 additional_params: self.additional_params,
396 max_tokens: self.max_tokens,
397 dynamic_context: self.dynamic_context,
398 temperature: self.temperature,
399 tool_choice: self.tool_choice,
400 default_max_turns: self.default_max_turns,
401 hook: self.hook,
402 output_schema: self.output_schema,
403 tool_state: WithBuilderTools {
404 static_tools,
405 tools,
406 dynamic_tools: vec![],
407 },
408 }
409 }
410
411 pub fn dynamic_tools(
416 self,
417 sample: usize,
418 dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
419 toolset: ToolSet,
420 ) -> AgentBuilder<M, P, WithBuilderTools> {
421 AgentBuilder {
422 name: self.name,
423 description: self.description,
424 model: self.model,
425 preamble: self.preamble,
426 static_context: self.static_context,
427 additional_params: self.additional_params,
428 max_tokens: self.max_tokens,
429 dynamic_context: self.dynamic_context,
430 temperature: self.temperature,
431 tool_choice: self.tool_choice,
432 default_max_turns: self.default_max_turns,
433 hook: self.hook,
434 output_schema: self.output_schema,
435 tool_state: WithBuilderTools {
436 static_tools: vec![],
437 tools: toolset,
438 dynamic_tools: vec![(sample, Box::new(dynamic_tools))],
439 },
440 }
441 }
442
443 pub fn hook<P2>(self, hook: P2) -> AgentBuilder<M, P2, NoToolConfig>
448 where
449 P2: PromptHook<M>,
450 {
451 AgentBuilder {
452 name: self.name,
453 description: self.description,
454 model: self.model,
455 preamble: self.preamble,
456 static_context: self.static_context,
457 additional_params: self.additional_params,
458 max_tokens: self.max_tokens,
459 dynamic_context: self.dynamic_context,
460 temperature: self.temperature,
461 tool_choice: self.tool_choice,
462 default_max_turns: self.default_max_turns,
463 tool_state: self.tool_state,
464 hook: Some(hook),
465 output_schema: self.output_schema,
466 }
467 }
468
469 pub fn build(self) -> Agent<M, P> {
473 let tool_server_handle = ToolServer::new().run();
474
475 Agent {
476 name: self.name,
477 description: self.description,
478 model: Arc::new(self.model),
479 preamble: self.preamble,
480 static_context: self.static_context,
481 temperature: self.temperature,
482 max_tokens: self.max_tokens,
483 additional_params: self.additional_params,
484 tool_choice: self.tool_choice,
485 dynamic_context: Arc::new(RwLock::new(self.dynamic_context)),
486 tool_server_handle,
487 default_max_turns: self.default_max_turns,
488 hook: self.hook,
489 output_schema: self.output_schema,
490 }
491 }
492}
493
494impl<M, P> AgentBuilder<M, P, WithToolServerHandle>
495where
496 M: CompletionModel,
497 P: PromptHook<M>,
498{
499 pub fn build(self) -> Agent<M, P> {
501 Agent {
502 name: self.name,
503 description: self.description,
504 model: Arc::new(self.model),
505 preamble: self.preamble,
506 static_context: self.static_context,
507 temperature: self.temperature,
508 max_tokens: self.max_tokens,
509 additional_params: self.additional_params,
510 tool_choice: self.tool_choice,
511 dynamic_context: Arc::new(RwLock::new(self.dynamic_context)),
512 tool_server_handle: self.tool_state.handle,
513 default_max_turns: self.default_max_turns,
514 hook: self.hook,
515 output_schema: self.output_schema,
516 }
517 }
518}
519
520impl<M, P> AgentBuilder<M, P, WithBuilderTools>
521where
522 M: CompletionModel,
523 P: PromptHook<M>,
524{
525 pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
527 let toolname = tool.name();
528 self.tool_state.tools.add_tool(tool);
529 self.tool_state.static_tools.push(toolname);
530 self
531 }
532
533 pub fn tools(mut self, tools: Vec<Box<dyn ToolDyn>>) -> Self {
535 let toolnames: Vec<String> = tools.iter().map(|tool| tool.name()).collect();
536 let tools = ToolSet::from_tools_boxed(tools);
537 self.tool_state.tools.add_tools(tools);
538 self.tool_state.static_tools.extend(toolnames);
539 self
540 }
541
542 #[cfg(feature = "rmcp")]
544 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
545 pub fn rmcp_tools(
546 mut self,
547 tools: Vec<rmcp::model::Tool>,
548 client: rmcp::service::ServerSink,
549 ) -> Self {
550 for tool in tools {
551 let tool_name = tool.name.to_string();
552 let tool = RmcpTool::from_mcp_server(tool, client.clone());
553 self.tool_state.static_tools.push(tool_name);
554 self.tool_state.tools.add_tool(tool);
555 }
556
557 self
558 }
559
560 pub fn dynamic_tools(
563 mut self,
564 sample: usize,
565 dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
566 toolset: ToolSet,
567 ) -> Self {
568 self.tool_state
569 .dynamic_tools
570 .push((sample, Box::new(dynamic_tools)));
571 self.tool_state.tools.add_tools(toolset);
572 self
573 }
574
575 pub fn build(self) -> Agent<M, P> {
580 let tool_server_handle = ToolServer::new()
581 .static_tool_names(self.tool_state.static_tools)
582 .add_tools(self.tool_state.tools)
583 .add_dynamic_tools(self.tool_state.dynamic_tools)
584 .run();
585
586 Agent {
587 name: self.name,
588 description: self.description,
589 model: Arc::new(self.model),
590 preamble: self.preamble,
591 static_context: self.static_context,
592 temperature: self.temperature,
593 max_tokens: self.max_tokens,
594 additional_params: self.additional_params,
595 tool_choice: self.tool_choice,
596 dynamic_context: Arc::new(RwLock::new(self.dynamic_context)),
597 tool_server_handle,
598 default_max_turns: self.default_max_turns,
599 hook: self.hook,
600 output_schema: self.output_schema,
601 }
602 }
603}