1use std::{collections::HashMap, sync::Arc};
2
3use schemars::{JsonSchema, Schema, schema_for};
4
5use crate::{
6 agent::prompt_request::hooks::PromptHook,
7 completion::{CompletionModel, Document},
8 memory::ConversationMemory,
9 message::ToolChoice,
10 tool::{
11 Tool, ToolDyn, ToolSet,
12 server::{ToolServer, ToolServerHandle},
13 },
14 vector_store::VectorStoreIndexDyn,
15};
16
17#[cfg(feature = "rmcp")]
18#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
19use crate::tool::rmcp::McpTool as RmcpTool;
20
21use super::Agent;
22
/// Type-state marker for an [`AgentBuilder`] that has no tool configuration yet.
///
/// This is the default `ToolState` parameter of [`AgentBuilder`] and the state
/// produced by [`AgentBuilder::new`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoToolConfig;
32
/// Type-state marker for an [`AgentBuilder`] that was given an existing
/// [`ToolServerHandle`] through `AgentBuilder::tool_server_handle`.
pub struct WithToolServerHandle {
    /// Handle that `build` stores on the agent as its `tool_server_handle`.
    handle: ToolServerHandle,
}
40
/// Type-state marker for an [`AgentBuilder`] whose tools are collected on the
/// builder itself; `build` registers them on a newly created [`ToolServer`].
pub struct WithBuilderTools {
    /// Names recorded for each tool added through `tool`, `tools`, `rmcp_tool`
    /// or `rmcp_tools`.
    static_tools: Vec<String>,
    /// Tool implementations, including any `ToolSet` passed to `dynamic_tools`.
    tools: ToolSet,
    /// `(sample, index)` pairs registered through `dynamic_tools`.
    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
}
51
/// Builder that assembles an [`Agent`] step by step.
///
/// * `M` is the completion model type.
/// * `P` is the prompt hook type; it is `()` for a builder created with `new`
///   and changes when `hook` is called.
/// * `ToolState` is a type-state marker ([`NoToolConfig`],
///   [`WithToolServerHandle`] or [`WithBuilderTools`]) that selects which
///   tool-related methods and which `build` implementation are available.
pub struct AgentBuilder<M, P = (), ToolState = NoToolConfig>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Optional agent name, set with `name`.
    name: Option<String>,
    /// Optional agent description, set with `description`.
    description: Option<String>,
    /// Completion model; `build` wraps it in an `Arc` for the agent.
    model: M,
    /// Optional preamble text.
    // NOTE(review): presumably used as the system prompt — confirm against `Agent`.
    preamble: Option<String>,
    /// Documents added with `context`, in insertion order.
    static_context: Vec<Document>,
    /// Optional extra JSON parameters, passed to the agent unchanged.
    additional_params: Option<serde_json::Value>,
    /// Optional token limit, set with `max_tokens`.
    max_tokens: Option<u64>,
    /// `(sample, index)` pairs added with `dynamic_context`.
    dynamic_context: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
    /// Optional sampling temperature, set with `temperature`.
    temperature: Option<f64>,
    /// Optional tool-choice setting, set with `tool_choice`.
    tool_choice: Option<ToolChoice>,
    /// Optional default turn limit, set with `default_max_turns`.
    default_max_turns: Option<usize>,
    /// Tool configuration belonging to the current type-state.
    tool_state: ToolState,
    /// Optional prompt hook, set with `hook`.
    hook: Option<P>,
    /// Optional JSON schema, set with `output_schema` or `output_schema_raw`.
    output_schema: Option<schemars::Schema>,
    /// Optional conversation memory backend, set with `memory`.
    memory: Option<Arc<dyn ConversationMemory>>,
    /// Optional conversation id, set with `conversation_id`.
    default_conversation_id: Option<String>,
}
115
116impl<M, P, ToolState> AgentBuilder<M, P, ToolState>
117where
118 M: CompletionModel,
119 P: PromptHook<M>,
120{
121 pub fn name(mut self, name: &str) -> Self {
123 self.name = Some(name.into());
124 self
125 }
126
127 pub fn description(mut self, description: &str) -> Self {
129 self.description = Some(description.into());
130 self
131 }
132
133 pub fn preamble(mut self, preamble: &str) -> Self {
135 self.preamble = Some(preamble.into());
136 self
137 }
138
139 pub fn without_preamble(mut self) -> Self {
141 self.preamble = None;
142 self
143 }
144
145 pub fn append_preamble(mut self, doc: &str) -> Self {
147 self.preamble = Some(format!("{}\n{}", self.preamble.unwrap_or_default(), doc));
148 self
149 }
150
151 pub fn context(mut self, doc: &str) -> Self {
153 self.static_context.push(Document {
154 id: format!("static_doc_{}", self.static_context.len()),
155 text: doc.into(),
156 additional_props: HashMap::new(),
157 });
158 self
159 }
160
161 pub fn dynamic_context(
164 mut self,
165 sample: usize,
166 dynamic_context: impl VectorStoreIndexDyn + Send + Sync + 'static,
167 ) -> Self {
168 self.dynamic_context
169 .push((sample, Arc::new(dynamic_context)));
170 self
171 }
172
173 pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
175 self.tool_choice = Some(tool_choice);
176 self
177 }
178
179 pub fn default_max_turns(mut self, default_max_turns: usize) -> Self {
181 self.default_max_turns = Some(default_max_turns);
182 self
183 }
184
185 pub fn temperature(mut self, temperature: f64) -> Self {
187 self.temperature = Some(temperature);
188 self
189 }
190
191 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
193 self.max_tokens = Some(max_tokens);
194 self
195 }
196
197 pub fn additional_params(mut self, params: serde_json::Value) -> Self {
199 self.additional_params = Some(params);
200 self
201 }
202
203 pub fn output_schema<T>(mut self) -> Self
206 where
207 T: JsonSchema,
208 {
209 self.output_schema = Some(schema_for!(T));
210 self
211 }
212
213 pub fn output_schema_raw(mut self, schema: Schema) -> Self {
215 self.output_schema = Some(schema);
216 self
217 }
218
219 pub fn memory<B>(mut self, memory: B) -> Self
227 where
228 B: ConversationMemory + 'static,
229 {
230 self.memory = Some(Arc::new(memory));
231 self
232 }
233
234 pub fn conversation_id(mut self, id: impl Into<String>) -> Self {
239 self.default_conversation_id = Some(id.into());
240 self
241 }
242
243 pub fn hook<P2>(self, hook: P2) -> AgentBuilder<M, P2, ToolState>
248 where
249 P2: PromptHook<M>,
250 {
251 AgentBuilder {
252 name: self.name,
253 description: self.description,
254 model: self.model,
255 preamble: self.preamble,
256 static_context: self.static_context,
257 additional_params: self.additional_params,
258 max_tokens: self.max_tokens,
259 dynamic_context: self.dynamic_context,
260 temperature: self.temperature,
261 tool_choice: self.tool_choice,
262 default_max_turns: self.default_max_turns,
263 tool_state: self.tool_state,
264 hook: Some(hook),
265 output_schema: self.output_schema,
266 memory: self.memory,
267 default_conversation_id: self.default_conversation_id,
268 }
269 }
270}
271
272impl<M> AgentBuilder<M, (), NoToolConfig>
273where
274 M: CompletionModel,
275{
276 pub fn new(model: M) -> Self {
278 Self {
279 name: None,
280 description: None,
281 model,
282 preamble: None,
283 static_context: vec![],
284 temperature: None,
285 max_tokens: None,
286 additional_params: None,
287 dynamic_context: vec![],
288 tool_choice: None,
289 default_max_turns: None,
290 tool_state: NoToolConfig,
291 hook: None,
292 output_schema: None,
293 memory: None,
294 default_conversation_id: None,
295 }
296 }
297}
298
299impl<M, P> AgentBuilder<M, P, NoToolConfig>
300where
301 M: CompletionModel,
302 P: PromptHook<M>,
303{
304 pub fn tool_server_handle(
310 self,
311 handle: ToolServerHandle,
312 ) -> AgentBuilder<M, P, WithToolServerHandle> {
313 AgentBuilder {
314 name: self.name,
315 description: self.description,
316 model: self.model,
317 preamble: self.preamble,
318 static_context: self.static_context,
319 additional_params: self.additional_params,
320 max_tokens: self.max_tokens,
321 dynamic_context: self.dynamic_context,
322 temperature: self.temperature,
323 tool_choice: self.tool_choice,
324 default_max_turns: self.default_max_turns,
325 tool_state: WithToolServerHandle { handle },
326 hook: self.hook,
327 output_schema: self.output_schema,
328 memory: self.memory,
329 default_conversation_id: self.default_conversation_id,
330 }
331 }
332
333 pub fn tool(self, tool: impl Tool + 'static) -> AgentBuilder<M, P, WithBuilderTools> {
338 let toolname = tool.name();
339 AgentBuilder {
340 name: self.name,
341 description: self.description,
342 model: self.model,
343 preamble: self.preamble,
344 static_context: self.static_context,
345 additional_params: self.additional_params,
346 max_tokens: self.max_tokens,
347 dynamic_context: self.dynamic_context,
348 temperature: self.temperature,
349 tool_choice: self.tool_choice,
350 default_max_turns: self.default_max_turns,
351 tool_state: WithBuilderTools {
352 static_tools: vec![toolname],
353 tools: ToolSet::from_tools(vec![tool]),
354 dynamic_tools: vec![],
355 },
356 hook: self.hook,
357 output_schema: self.output_schema,
358 memory: self.memory,
359 default_conversation_id: self.default_conversation_id,
360 }
361 }
362
363 pub fn tools(self, tools: Vec<Box<dyn ToolDyn>>) -> AgentBuilder<M, P, WithBuilderTools> {
368 let static_tools = tools.iter().map(|tool| tool.name()).collect();
369 let tools = ToolSet::from_tools_boxed(tools);
370
371 AgentBuilder {
372 name: self.name,
373 description: self.description,
374 model: self.model,
375 preamble: self.preamble,
376 static_context: self.static_context,
377 additional_params: self.additional_params,
378 max_tokens: self.max_tokens,
379 dynamic_context: self.dynamic_context,
380 temperature: self.temperature,
381 tool_choice: self.tool_choice,
382 default_max_turns: self.default_max_turns,
383 hook: self.hook,
384 output_schema: self.output_schema,
385 memory: self.memory,
386 default_conversation_id: self.default_conversation_id,
387 tool_state: WithBuilderTools {
388 static_tools,
389 tools,
390 dynamic_tools: vec![],
391 },
392 }
393 }
394
395 #[cfg(feature = "rmcp")]
399 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
400 pub fn rmcp_tool(
401 self,
402 tool: rmcp::model::Tool,
403 client: rmcp::service::ServerSink,
404 ) -> AgentBuilder<M, P, WithBuilderTools> {
405 let toolname = tool.name.clone().to_string();
406 let tools = ToolSet::from_tools(vec![RmcpTool::from_mcp_server(tool, client)]);
407
408 AgentBuilder {
409 name: self.name,
410 description: self.description,
411 model: self.model,
412 preamble: self.preamble,
413 static_context: self.static_context,
414 additional_params: self.additional_params,
415 max_tokens: self.max_tokens,
416 dynamic_context: self.dynamic_context,
417 temperature: self.temperature,
418 tool_choice: self.tool_choice,
419 default_max_turns: self.default_max_turns,
420 hook: self.hook,
421 output_schema: self.output_schema,
422 memory: self.memory,
423 default_conversation_id: self.default_conversation_id,
424 tool_state: WithBuilderTools {
425 static_tools: vec![toolname],
426 tools,
427 dynamic_tools: vec![],
428 },
429 }
430 }
431
432 #[cfg(feature = "rmcp")]
436 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
437 pub fn rmcp_tools(
438 self,
439 tools: Vec<rmcp::model::Tool>,
440 client: rmcp::service::ServerSink,
441 ) -> AgentBuilder<M, P, WithBuilderTools> {
442 let (static_tools, tools) = tools.into_iter().fold(
443 (Vec::new(), Vec::new()),
444 |(mut toolnames, mut toolset), tool| {
445 let tool_name = tool.name.to_string();
446 let tool = RmcpTool::from_mcp_server(tool, client.clone());
447 toolnames.push(tool_name);
448 toolset.push(tool);
449 (toolnames, toolset)
450 },
451 );
452
453 let tools = ToolSet::from_tools(tools);
454
455 AgentBuilder {
456 name: self.name,
457 description: self.description,
458 model: self.model,
459 preamble: self.preamble,
460 static_context: self.static_context,
461 additional_params: self.additional_params,
462 max_tokens: self.max_tokens,
463 dynamic_context: self.dynamic_context,
464 temperature: self.temperature,
465 tool_choice: self.tool_choice,
466 default_max_turns: self.default_max_turns,
467 hook: self.hook,
468 output_schema: self.output_schema,
469 memory: self.memory,
470 default_conversation_id: self.default_conversation_id,
471 tool_state: WithBuilderTools {
472 static_tools,
473 tools,
474 dynamic_tools: vec![],
475 },
476 }
477 }
478
479 pub fn dynamic_tools(
484 self,
485 sample: usize,
486 dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
487 toolset: ToolSet,
488 ) -> AgentBuilder<M, P, WithBuilderTools> {
489 AgentBuilder {
490 name: self.name,
491 description: self.description,
492 model: self.model,
493 preamble: self.preamble,
494 static_context: self.static_context,
495 additional_params: self.additional_params,
496 max_tokens: self.max_tokens,
497 dynamic_context: self.dynamic_context,
498 temperature: self.temperature,
499 tool_choice: self.tool_choice,
500 default_max_turns: self.default_max_turns,
501 hook: self.hook,
502 output_schema: self.output_schema,
503 memory: self.memory,
504 default_conversation_id: self.default_conversation_id,
505 tool_state: WithBuilderTools {
506 static_tools: vec![],
507 tools: toolset,
508 dynamic_tools: vec![(sample, Arc::new(dynamic_tools))],
509 },
510 }
511 }
512
513 pub fn build(self) -> Agent<M, P> {
517 let tool_server_handle = ToolServer::new().run();
518
519 Agent {
520 name: self.name,
521 description: self.description,
522 model: Arc::new(self.model),
523 preamble: self.preamble,
524 static_context: self.static_context,
525 temperature: self.temperature,
526 max_tokens: self.max_tokens,
527 additional_params: self.additional_params,
528 tool_choice: self.tool_choice,
529 dynamic_context: Arc::new(self.dynamic_context),
530 tool_server_handle,
531 default_max_turns: self.default_max_turns,
532 hook: self.hook,
533 output_schema: self.output_schema,
534 memory: self.memory,
535 default_conversation_id: self.default_conversation_id,
536 }
537 }
538}
539
540impl<M, P> AgentBuilder<M, P, WithToolServerHandle>
541where
542 M: CompletionModel,
543 P: PromptHook<M>,
544{
545 pub fn build(self) -> Agent<M, P> {
547 Agent {
548 name: self.name,
549 description: self.description,
550 model: Arc::new(self.model),
551 preamble: self.preamble,
552 static_context: self.static_context,
553 temperature: self.temperature,
554 max_tokens: self.max_tokens,
555 additional_params: self.additional_params,
556 tool_choice: self.tool_choice,
557 dynamic_context: Arc::new(self.dynamic_context),
558 tool_server_handle: self.tool_state.handle,
559 default_max_turns: self.default_max_turns,
560 hook: self.hook,
561 output_schema: self.output_schema,
562 memory: self.memory,
563 default_conversation_id: self.default_conversation_id,
564 }
565 }
566}
567
568impl<M, P> AgentBuilder<M, P, WithBuilderTools>
569where
570 M: CompletionModel,
571 P: PromptHook<M>,
572{
573 pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
575 let toolname = tool.name();
576 self.tool_state.tools.add_tool(tool);
577 self.tool_state.static_tools.push(toolname);
578 self
579 }
580
581 pub fn tools(mut self, tools: Vec<Box<dyn ToolDyn>>) -> Self {
583 let toolnames: Vec<String> = tools.iter().map(|tool| tool.name()).collect();
584 let tools = ToolSet::from_tools_boxed(tools);
585 self.tool_state.tools.add_tools(tools);
586 self.tool_state.static_tools.extend(toolnames);
587 self
588 }
589
590 #[cfg(feature = "rmcp")]
592 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
593 pub fn rmcp_tools(
594 mut self,
595 tools: Vec<rmcp::model::Tool>,
596 client: rmcp::service::ServerSink,
597 ) -> Self {
598 for tool in tools {
599 let tool_name = tool.name.to_string();
600 let tool = RmcpTool::from_mcp_server(tool, client.clone());
601 self.tool_state.static_tools.push(tool_name);
602 self.tool_state.tools.add_tool(tool);
603 }
604
605 self
606 }
607
608 pub fn dynamic_tools(
611 mut self,
612 sample: usize,
613 dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
614 toolset: ToolSet,
615 ) -> Self {
616 self.tool_state
617 .dynamic_tools
618 .push((sample, Arc::new(dynamic_tools)));
619 self.tool_state.tools.add_tools(toolset);
620 self
621 }
622
623 pub fn build(self) -> Agent<M, P> {
628 let tool_server_handle = ToolServer::new()
629 .static_tool_names(self.tool_state.static_tools)
630 .add_tools(self.tool_state.tools)
631 .add_dynamic_tools(self.tool_state.dynamic_tools)
632 .run();
633
634 Agent {
635 name: self.name,
636 description: self.description,
637 model: Arc::new(self.model),
638 preamble: self.preamble,
639 static_context: self.static_context,
640 temperature: self.temperature,
641 max_tokens: self.max_tokens,
642 additional_params: self.additional_params,
643 tool_choice: self.tool_choice,
644 dynamic_context: Arc::new(self.dynamic_context),
645 tool_server_handle,
646 default_max_turns: self.default_max_turns,
647 hook: self.hook,
648 output_schema: self.output_schema,
649 memory: self.memory,
650 default_conversation_id: self.default_conversation_id,
651 }
652 }
653}
654
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{MockAddTool, MockCompletionModel};

    /// Hook type that relies on the trait's provided methods only.
    #[derive(Clone)]
    struct BuilderHook;

    impl PromptHook<MockCompletionModel> for BuilderHook {}

    /// `hook` must remain callable after the builder has moved into the
    /// `WithBuilderTools` state, and the result must still be buildable.
    #[test]
    fn hook_can_be_set_after_tool_configuration() {
        let builder = AgentBuilder::new(MockCompletionModel::text("ok"));
        let with_tool = builder.tool(MockAddTool);
        let with_hook = with_tool.hook(BuilderHook);
        let _agent = with_hook.build();
    }
}