1use std::{collections::HashMap, sync::Arc};
2
3use schemars::{JsonSchema, Schema, schema_for};
4
5use crate::{
6 agent::prompt_request::hooks::PromptHook,
7 completion::{CompletionModel, Document},
8 memory::ConversationMemory,
9 message::ToolChoice,
10 tool::{
11 Tool, ToolDyn, ToolSet,
12 server::{ToolServer, ToolServerHandle},
13 },
14 vector_store::VectorStoreIndexDyn,
15};
16
17#[cfg(feature = "rmcp")]
18#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
19use crate::tool::rmcp::McpTool as RmcpTool;
20
21use super::Agent;
22
#[cfg(feature = "rmcp")]
/// Wraps each MCP tool definition in an [`RmcpTool`] bound to `client`.
///
/// Every wrapper receives its own clone of `client` and the same `timeout`
/// value (passed to `RmcpTool::with_timeout`). The returned vector pairs each
/// wrapper with the name copied from its definition, in the order of `tools`.
fn build_rmcp_tools(
    tools: Vec<rmcp::model::Tool>,
    client: rmcp::service::ServerSink,
    timeout: Option<std::time::Duration>,
) -> Vec<(String, RmcpTool)> {
    let mut built = Vec::with_capacity(tools.len());

    for definition in tools {
        // Copy the name out first: the definition is moved into the wrapper below.
        let label = definition.name.to_string();
        let wrapped =
            RmcpTool::from_mcp_server(definition, client.clone()).with_timeout(timeout);
        built.push((label, wrapped));
    }

    built
}
41
/// Tool-configuration state of an [`AgentBuilder`] on which neither tools nor
/// a tool server handle have been configured yet.
///
/// This is the default `ToolState` type parameter of [`AgentBuilder`] and the
/// state produced by `AgentBuilder::new`.
#[derive(Default)]
pub struct NoToolConfig;
51
/// Tool-configuration state of an [`AgentBuilder`] that was handed an existing
/// [`ToolServerHandle`] through `tool_server_handle`.
pub struct WithToolServerHandle {
    /// Handle that `build` installs as the agent's `tool_server_handle`.
    handle: ToolServerHandle,
}
59
/// Tool-configuration state of an [`AgentBuilder`] that collects tools on the
/// builder itself; `build` passes all three fields to a new `ToolServer`.
pub struct WithBuilderTools {
    /// Tool names passed to `ToolServer::static_tool_names` at build time.
    static_tools: Vec<String>,
    /// Tool implementations registered so far, passed to `ToolServer::add_tools`
    /// at build time. Holds both individually added tools and the toolsets
    /// supplied together with a dynamic tool source.
    tools: ToolSet,
    /// `(sample, index)` pairs passed to `ToolServer::add_dynamic_tools` at
    /// build time.
    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
}
70
/// Builder for [`Agent`].
///
/// `M` is the completion model, `P` the prompt hook type (`()` until `hook` is
/// called) and `ToolState` a type-level marker recording how tools have been
/// configured: [`NoToolConfig`], [`WithToolServerHandle`] or
/// [`WithBuilderTools`]. Each marker has its own `impl` block, so the set of
/// available tool methods and the behaviour of `build` depend on the state.
pub struct AgentBuilder<M, P = (), ToolState = NoToolConfig>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Optional agent name, set with `name`.
    name: Option<String>,
    /// Optional agent description, set with `description`.
    description: Option<String>,
    /// Completion model; `build` wraps it in an `Arc`.
    model: M,
    /// Optional preamble text, managed by `preamble`, `without_preamble` and
    /// `append_preamble`.
    preamble: Option<String>,
    /// Documents added with `context`; ids are `static_doc_<index>`.
    static_context: Vec<Document>,
    /// Extra JSON value supplied through `additional_params`; forwarded
    /// unchanged to the agent.
    additional_params: Option<serde_json::Value>,
    /// Optional `max_tokens` setting forwarded to the agent.
    max_tokens: Option<u64>,
    /// `(sample, index)` pairs added with `dynamic_context`.
    dynamic_context: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
    /// Optional `temperature` setting forwarded to the agent.
    temperature: Option<f64>,
    /// Optional tool-choice setting forwarded to the agent.
    tool_choice: Option<ToolChoice>,
    /// Optional default for the agent's maximum number of turns.
    default_max_turns: Option<usize>,
    /// Tool-configuration data for the current typestate.
    tool_state: ToolState,
    /// Optional prompt hook, set with `hook`.
    hook: Option<P>,
    /// Optional schema set with `output_schema` or `output_schema_raw`.
    output_schema: Option<schemars::Schema>,
    /// Optional conversation memory backend, set with `memory`.
    memory: Option<Arc<dyn ConversationMemory>>,
    /// Optional default conversation id, set with `conversation_id`.
    default_conversation_id: Option<String>,
}
134
135impl<M, P, ToolState> AgentBuilder<M, P, ToolState>
136where
137 M: CompletionModel,
138 P: PromptHook<M>,
139{
140 pub fn name(mut self, name: &str) -> Self {
142 self.name = Some(name.into());
143 self
144 }
145
146 pub fn description(mut self, description: &str) -> Self {
148 self.description = Some(description.into());
149 self
150 }
151
152 pub fn preamble(mut self, preamble: &str) -> Self {
154 self.preamble = Some(preamble.into());
155 self
156 }
157
158 pub fn without_preamble(mut self) -> Self {
160 self.preamble = None;
161 self
162 }
163
164 pub fn append_preamble(mut self, doc: &str) -> Self {
166 self.preamble = Some(format!("{}\n{}", self.preamble.unwrap_or_default(), doc));
167 self
168 }
169
170 pub fn context(mut self, doc: &str) -> Self {
172 self.static_context.push(Document {
173 id: format!("static_doc_{}", self.static_context.len()),
174 text: doc.into(),
175 additional_props: HashMap::new(),
176 });
177 self
178 }
179
180 pub fn dynamic_context(
183 mut self,
184 sample: usize,
185 dynamic_context: impl VectorStoreIndexDyn + Send + Sync + 'static,
186 ) -> Self {
187 self.dynamic_context
188 .push((sample, Arc::new(dynamic_context)));
189 self
190 }
191
192 pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
194 self.tool_choice = Some(tool_choice);
195 self
196 }
197
198 pub fn default_max_turns(mut self, default_max_turns: usize) -> Self {
200 self.default_max_turns = Some(default_max_turns);
201 self
202 }
203
204 pub fn temperature(mut self, temperature: f64) -> Self {
206 self.temperature = Some(temperature);
207 self
208 }
209
210 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
212 self.max_tokens = Some(max_tokens);
213 self
214 }
215
216 pub fn additional_params(mut self, params: serde_json::Value) -> Self {
218 self.additional_params = Some(params);
219 self
220 }
221
222 pub fn output_schema<T>(mut self) -> Self
225 where
226 T: JsonSchema,
227 {
228 self.output_schema = Some(schema_for!(T));
229 self
230 }
231
232 pub fn output_schema_raw(mut self, schema: Schema) -> Self {
234 self.output_schema = Some(schema);
235 self
236 }
237
238 pub fn memory<B>(mut self, memory: B) -> Self
246 where
247 B: ConversationMemory + 'static,
248 {
249 self.memory = Some(Arc::new(memory));
250 self
251 }
252
253 pub fn conversation_id(mut self, id: impl Into<String>) -> Self {
258 self.default_conversation_id = Some(id.into());
259 self
260 }
261
262 pub fn hook<P2>(self, hook: P2) -> AgentBuilder<M, P2, ToolState>
267 where
268 P2: PromptHook<M>,
269 {
270 AgentBuilder {
271 name: self.name,
272 description: self.description,
273 model: self.model,
274 preamble: self.preamble,
275 static_context: self.static_context,
276 additional_params: self.additional_params,
277 max_tokens: self.max_tokens,
278 dynamic_context: self.dynamic_context,
279 temperature: self.temperature,
280 tool_choice: self.tool_choice,
281 default_max_turns: self.default_max_turns,
282 tool_state: self.tool_state,
283 hook: Some(hook),
284 output_schema: self.output_schema,
285 memory: self.memory,
286 default_conversation_id: self.default_conversation_id,
287 }
288 }
289}
290
291impl<M> AgentBuilder<M, (), NoToolConfig>
292where
293 M: CompletionModel,
294{
295 pub fn new(model: M) -> Self {
297 Self {
298 name: None,
299 description: None,
300 model,
301 preamble: None,
302 static_context: vec![],
303 temperature: None,
304 max_tokens: None,
305 additional_params: None,
306 dynamic_context: vec![],
307 tool_choice: None,
308 default_max_turns: None,
309 tool_state: NoToolConfig,
310 hook: None,
311 output_schema: None,
312 memory: None,
313 default_conversation_id: None,
314 }
315 }
316}
317
318impl<M, P> AgentBuilder<M, P, NoToolConfig>
319where
320 M: CompletionModel,
321 P: PromptHook<M>,
322{
323 pub fn tool_server_handle(
329 self,
330 handle: ToolServerHandle,
331 ) -> AgentBuilder<M, P, WithToolServerHandle> {
332 AgentBuilder {
333 name: self.name,
334 description: self.description,
335 model: self.model,
336 preamble: self.preamble,
337 static_context: self.static_context,
338 additional_params: self.additional_params,
339 max_tokens: self.max_tokens,
340 dynamic_context: self.dynamic_context,
341 temperature: self.temperature,
342 tool_choice: self.tool_choice,
343 default_max_turns: self.default_max_turns,
344 tool_state: WithToolServerHandle { handle },
345 hook: self.hook,
346 output_schema: self.output_schema,
347 memory: self.memory,
348 default_conversation_id: self.default_conversation_id,
349 }
350 }
351
352 pub fn tool(self, tool: impl Tool + 'static) -> AgentBuilder<M, P, WithBuilderTools> {
357 let toolname = tool.name();
358 AgentBuilder {
359 name: self.name,
360 description: self.description,
361 model: self.model,
362 preamble: self.preamble,
363 static_context: self.static_context,
364 additional_params: self.additional_params,
365 max_tokens: self.max_tokens,
366 dynamic_context: self.dynamic_context,
367 temperature: self.temperature,
368 tool_choice: self.tool_choice,
369 default_max_turns: self.default_max_turns,
370 tool_state: WithBuilderTools {
371 static_tools: vec![toolname],
372 tools: ToolSet::from_tools(vec![tool]),
373 dynamic_tools: vec![],
374 },
375 hook: self.hook,
376 output_schema: self.output_schema,
377 memory: self.memory,
378 default_conversation_id: self.default_conversation_id,
379 }
380 }
381
382 pub fn tools(self, tools: Vec<Box<dyn ToolDyn>>) -> AgentBuilder<M, P, WithBuilderTools> {
387 let static_tools = tools.iter().map(|tool| tool.name()).collect();
388 let tools = ToolSet::from_tools_boxed(tools);
389
390 AgentBuilder {
391 name: self.name,
392 description: self.description,
393 model: self.model,
394 preamble: self.preamble,
395 static_context: self.static_context,
396 additional_params: self.additional_params,
397 max_tokens: self.max_tokens,
398 dynamic_context: self.dynamic_context,
399 temperature: self.temperature,
400 tool_choice: self.tool_choice,
401 default_max_turns: self.default_max_turns,
402 hook: self.hook,
403 output_schema: self.output_schema,
404 memory: self.memory,
405 default_conversation_id: self.default_conversation_id,
406 tool_state: WithBuilderTools {
407 static_tools,
408 tools,
409 dynamic_tools: vec![],
410 },
411 }
412 }
413
414 #[cfg(feature = "rmcp")]
421 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
422 pub fn rmcp_tool(
423 self,
424 tool: rmcp::model::Tool,
425 client: rmcp::service::ServerSink,
426 ) -> AgentBuilder<M, P, WithBuilderTools> {
427 self.rmcp_tool_with_timeout(tool, client, crate::tool::rmcp::DEFAULT_MCP_TOOL_TIMEOUT)
428 }
429
430 #[cfg(feature = "rmcp")]
437 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
438 pub fn rmcp_tool_with_timeout(
439 self,
440 tool: rmcp::model::Tool,
441 client: rmcp::service::ServerSink,
442 timeout: impl Into<Option<std::time::Duration>>,
443 ) -> AgentBuilder<M, P, WithBuilderTools> {
444 self.with_rmcp_toolset(build_rmcp_tools(vec![tool], client, timeout.into()))
445 }
446
447 #[cfg(feature = "rmcp")]
454 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
455 pub fn rmcp_tools(
456 self,
457 tools: Vec<rmcp::model::Tool>,
458 client: rmcp::service::ServerSink,
459 ) -> AgentBuilder<M, P, WithBuilderTools> {
460 self.rmcp_tools_with_timeout(tools, client, crate::tool::rmcp::DEFAULT_MCP_TOOL_TIMEOUT)
461 }
462
463 #[cfg(feature = "rmcp")]
471 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
472 pub fn rmcp_tools_with_timeout(
473 self,
474 tools: Vec<rmcp::model::Tool>,
475 client: rmcp::service::ServerSink,
476 timeout: impl Into<Option<std::time::Duration>>,
477 ) -> AgentBuilder<M, P, WithBuilderTools> {
478 self.with_rmcp_toolset(build_rmcp_tools(tools, client, timeout.into()))
479 }
480
481 #[cfg(feature = "rmcp")]
484 fn with_rmcp_toolset(
485 self,
486 built: Vec<(String, RmcpTool)>,
487 ) -> AgentBuilder<M, P, WithBuilderTools> {
488 let (static_tools, toolset): (Vec<String>, Vec<RmcpTool>) = built.into_iter().unzip();
489
490 AgentBuilder {
491 name: self.name,
492 description: self.description,
493 model: self.model,
494 preamble: self.preamble,
495 static_context: self.static_context,
496 additional_params: self.additional_params,
497 max_tokens: self.max_tokens,
498 dynamic_context: self.dynamic_context,
499 temperature: self.temperature,
500 tool_choice: self.tool_choice,
501 default_max_turns: self.default_max_turns,
502 hook: self.hook,
503 output_schema: self.output_schema,
504 memory: self.memory,
505 default_conversation_id: self.default_conversation_id,
506 tool_state: WithBuilderTools {
507 static_tools,
508 tools: ToolSet::from_tools(toolset),
509 dynamic_tools: vec![],
510 },
511 }
512 }
513
514 pub fn dynamic_tools(
519 self,
520 sample: usize,
521 dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
522 toolset: ToolSet,
523 ) -> AgentBuilder<M, P, WithBuilderTools> {
524 AgentBuilder {
525 name: self.name,
526 description: self.description,
527 model: self.model,
528 preamble: self.preamble,
529 static_context: self.static_context,
530 additional_params: self.additional_params,
531 max_tokens: self.max_tokens,
532 dynamic_context: self.dynamic_context,
533 temperature: self.temperature,
534 tool_choice: self.tool_choice,
535 default_max_turns: self.default_max_turns,
536 hook: self.hook,
537 output_schema: self.output_schema,
538 memory: self.memory,
539 default_conversation_id: self.default_conversation_id,
540 tool_state: WithBuilderTools {
541 static_tools: vec![],
542 tools: toolset,
543 dynamic_tools: vec![(sample, Arc::new(dynamic_tools))],
544 },
545 }
546 }
547
548 pub fn build(self) -> Agent<M, P> {
552 let tool_server_handle = ToolServer::new().run();
553
554 Agent {
555 name: self.name,
556 description: self.description,
557 model: Arc::new(self.model),
558 preamble: self.preamble,
559 static_context: self.static_context,
560 temperature: self.temperature,
561 max_tokens: self.max_tokens,
562 additional_params: self.additional_params,
563 tool_choice: self.tool_choice,
564 dynamic_context: Arc::new(self.dynamic_context),
565 tool_server_handle,
566 default_max_turns: self.default_max_turns,
567 hook: self.hook,
568 output_schema: self.output_schema,
569 memory: self.memory,
570 default_conversation_id: self.default_conversation_id,
571 }
572 }
573}
574
575impl<M, P> AgentBuilder<M, P, WithToolServerHandle>
576where
577 M: CompletionModel,
578 P: PromptHook<M>,
579{
580 pub fn build(self) -> Agent<M, P> {
582 Agent {
583 name: self.name,
584 description: self.description,
585 model: Arc::new(self.model),
586 preamble: self.preamble,
587 static_context: self.static_context,
588 temperature: self.temperature,
589 max_tokens: self.max_tokens,
590 additional_params: self.additional_params,
591 tool_choice: self.tool_choice,
592 dynamic_context: Arc::new(self.dynamic_context),
593 tool_server_handle: self.tool_state.handle,
594 default_max_turns: self.default_max_turns,
595 hook: self.hook,
596 output_schema: self.output_schema,
597 memory: self.memory,
598 default_conversation_id: self.default_conversation_id,
599 }
600 }
601}
602
603impl<M, P> AgentBuilder<M, P, WithBuilderTools>
604where
605 M: CompletionModel,
606 P: PromptHook<M>,
607{
608 pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
610 let toolname = tool.name();
611 self.tool_state.tools.add_tool(tool);
612 self.tool_state.static_tools.push(toolname);
613 self
614 }
615
616 pub fn tools(mut self, tools: Vec<Box<dyn ToolDyn>>) -> Self {
618 let toolnames: Vec<String> = tools.iter().map(|tool| tool.name()).collect();
619 let tools = ToolSet::from_tools_boxed(tools);
620 self.tool_state.tools.add_tools(tools);
621 self.tool_state.static_tools.extend(toolnames);
622 self
623 }
624
625 #[cfg(feature = "rmcp")]
630 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
631 pub fn rmcp_tools(
632 self,
633 tools: Vec<rmcp::model::Tool>,
634 client: rmcp::service::ServerSink,
635 ) -> Self {
636 self.rmcp_tools_with_timeout(tools, client, crate::tool::rmcp::DEFAULT_MCP_TOOL_TIMEOUT)
637 }
638
639 #[cfg(feature = "rmcp")]
646 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
647 pub fn rmcp_tools_with_timeout(
648 self,
649 tools: Vec<rmcp::model::Tool>,
650 client: rmcp::service::ServerSink,
651 timeout: impl Into<Option<std::time::Duration>>,
652 ) -> Self {
653 self.add_rmcp_tools(build_rmcp_tools(tools, client, timeout.into()))
654 }
655
656 #[cfg(feature = "rmcp")]
657 fn add_rmcp_tools(mut self, built: Vec<(String, RmcpTool)>) -> Self {
658 for (name, tool) in built {
659 self.tool_state.static_tools.push(name);
660 self.tool_state.tools.add_tool(tool);
661 }
662
663 self
664 }
665
666 pub fn dynamic_tools(
669 mut self,
670 sample: usize,
671 dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
672 toolset: ToolSet,
673 ) -> Self {
674 self.tool_state
675 .dynamic_tools
676 .push((sample, Arc::new(dynamic_tools)));
677 self.tool_state.tools.add_tools(toolset);
678 self
679 }
680
681 pub fn build(self) -> Agent<M, P> {
686 let tool_server_handle = ToolServer::new()
687 .static_tool_names(self.tool_state.static_tools)
688 .add_tools(self.tool_state.tools)
689 .add_dynamic_tools(self.tool_state.dynamic_tools)
690 .run();
691
692 Agent {
693 name: self.name,
694 description: self.description,
695 model: Arc::new(self.model),
696 preamble: self.preamble,
697 static_context: self.static_context,
698 temperature: self.temperature,
699 max_tokens: self.max_tokens,
700 additional_params: self.additional_params,
701 tool_choice: self.tool_choice,
702 dynamic_context: Arc::new(self.dynamic_context),
703 tool_server_handle,
704 default_max_turns: self.default_max_turns,
705 hook: self.hook,
706 output_schema: self.output_schema,
707 memory: self.memory,
708 default_conversation_id: self.default_conversation_id,
709 }
710 }
711}
712
// Unit tests for the builder's typestate transitions and the rmcp helper.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{MockAddTool, MockCompletionModel};

    /// Hook type with no overridden methods, used only to exercise `hook`.
    #[derive(Clone)]
    struct BuilderHook;

    impl PromptHook<MockCompletionModel> for BuilderHook {}

    /// Checks that `hook` is callable after the builder has moved into a tool
    /// state. There are no assertions: the test passes when the chain
    /// type-checks and `build` returns.
    #[test]
    fn hook_can_be_set_after_tool_configuration() {
        let _agent = AgentBuilder::new(MockCompletionModel::text("ok"))
            .tool(MockAddTool)
            .hook(BuilderHook)
            .build();
    }

    /// Checks that `build_rmcp_tools` applies the given timeout to the tools it
    /// returns, both as reported by `timeout()` and as observed when calling a
    /// tool against a server that never answers.
    #[cfg(feature = "rmcp")]
    #[tokio::test]
    async fn build_rmcp_tools_threads_timeout_into_built_tools() {
        use crate::tool::ToolDyn;
        use crate::tool::rmcp::DEFAULT_MCP_TOOL_TIMEOUT;
        use rmcp::model::{
            CallToolRequestParams, CallToolResult, ClientInfo, ErrorData, Implementation,
            ProtocolVersion, ServerCapabilities, ServerInfo, Tool,
        };
        use rmcp::service::RequestContext;
        use rmcp::{RoleServer, ServerHandler, ServiceExt};
        use std::sync::Arc;
        use std::time::Duration;

        // MCP server whose `call_tool` awaits a future that never completes.
        #[derive(Clone)]
        struct HangingServer;
        impl ServerHandler for HangingServer {
            fn get_info(&self) -> ServerInfo {
                ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
                    .with_protocol_version(ProtocolVersion::LATEST)
                    .with_server_info(Implementation::new("builder-timeout-test", "0.1.0"))
            }
            async fn call_tool(
                &self,
                _request: CallToolRequestParams,
                _context: RequestContext<RoleServer>,
            ) -> Result<CallToolResult, ErrorData> {
                std::future::pending::<Result<CallToolResult, ErrorData>>().await
            }
        }

        // Minimal tool definition: the given name, an empty description and an
        // empty JSON map as the schema argument.
        fn tool(name: &str) -> Tool {
            Tool::new(
                name.to_string(),
                String::new(),
                Arc::new(serde_json::Map::new()),
            )
        }

        // Two in-memory duplex pipes, one per direction, connect the client
        // and the server without touching the network.
        let (c2s, sfc) = tokio::io::duplex(8192);
        let (s2c, cfs) = tokio::io::duplex(8192);
        // The server runs in its own task so the client handshake below can
        // make progress.
        let server_task = tokio::spawn(async move {
            let running = HangingServer.serve((sfc, s2c)).await.expect("server start");
            running.waiting().await.expect("server error");
        });
        let client = ClientInfo::default()
            .serve((cfs, c2s))
            .await
            .expect("client connect");
        let peer = client.peer().clone();

        // An explicit `Some(..)` timeout is reported back unchanged.
        let built_default = build_rmcp_tools(
            vec![tool("a")],
            peer.clone(),
            Some(DEFAULT_MCP_TOOL_TIMEOUT),
        );
        assert_eq!(built_default[0].1.timeout(), Some(DEFAULT_MCP_TOOL_TIMEOUT));
        // `None` is reported back as `None`.
        let built_none = build_rmcp_tools(vec![tool("b")], peer.clone(), None);
        assert_eq!(built_none[0].1.timeout(), None);

        // Behavioural check: with a 200 ms timeout, a call against the hanging
        // server must fail with a "timed out" error.
        let built = build_rmcp_tools(
            vec![tool("hang_forever")],
            peer,
            Some(Duration::from_millis(200)),
        );
        assert_eq!(built.len(), 1);
        assert_eq!(built[0].0, "hang_forever");
        // The outer 5 s timeout is a safety net so a regression fails the test
        // instead of hanging it.
        let timed =
            tokio::time::timeout(Duration::from_secs(5), built[0].1.call("{}".to_string())).await;
        let err = timed
            .expect("built tool hung past the safety timeout")
            .expect_err("call should time out");
        assert!(err.to_string().contains("timed out"), "got: {err}");

        // Tear down the client first, then stop the server task.
        drop(client);
        server_task.abort();
    }
}