1use std::{collections::HashMap, sync::Arc};
2
3use schemars::{JsonSchema, Schema, schema_for};
4
5use crate::{
6 agent::prompt_request::hooks::PromptHook,
7 completion::{CompletionModel, Document},
8 memory::ConversationMemory,
9 message::ToolChoice,
10 tool::{
11 Tool, ToolDyn, ToolSet,
12 server::{ToolServer, ToolServerHandle},
13 },
14 vector_store::VectorStoreIndexDyn,
15};
16
17#[cfg(feature = "rmcp")]
18#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
19use crate::tool::rmcp::McpTool as RmcpTool;
20
21use super::Agent;
22
/// Tool state of an [`AgentBuilder`] on which no tools have been configured yet.
///
/// Calling `build` in this state starts an empty tool server. Adding tools or a
/// tool server handle moves the builder into another tool state.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct NoToolConfig;
32
/// Tool state of an [`AgentBuilder`] that was given an existing tool server
/// via [`AgentBuilder::tool_server_handle`].
///
/// In this state `build` reuses the stored handle instead of starting a new
/// `ToolServer`.
pub struct WithToolServerHandle {
    /// Handle that the built agent will use as its tool server handle.
    handle: ToolServerHandle,
}
40
/// Tool state of an [`AgentBuilder`] whose tools were registered directly on the
/// builder (`tool`, `tools`, `dynamic_tools`, and the `rmcp` variants).
///
/// At `build` time everything collected here is loaded into a fresh `ToolServer`.
pub struct WithBuilderTools {
    /// Names of statically registered tools. They are passed to
    /// `ToolServer::static_tool_names` when building.
    static_tools: Vec<String>,
    /// Every tool implementation added to the builder, including the toolsets
    /// supplied alongside dynamic tool indices.
    tools: ToolSet,
    /// `(sample, index)` pairs passed to `ToolServer::add_dynamic_tools`.
    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
}
51
/// Builder for [`Agent`].
///
/// - `M` is the completion model.
/// - `P` is the prompt-hook type (defaults to `()`).
/// - `ToolState` records how tools are configured: [`NoToolConfig`],
///   [`WithToolServerHandle`] or [`WithBuilderTools`].
///
/// Each tool state provides its own `build` method.
pub struct AgentBuilder<M, P = (), ToolState = NoToolConfig>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Set via [`AgentBuilder::name`]; copied into the built agent.
    name: Option<String>,
    /// Set via [`AgentBuilder::description`]; copied into the built agent.
    description: Option<String>,
    /// Completion model; wrapped in an `Arc` when the agent is built.
    model: M,
    /// Set via `preamble` / `append_preamble`; cleared by `without_preamble`.
    preamble: Option<String>,
    /// Documents added with `context`, with ids `static_doc_<index>`.
    static_context: Vec<Document>,
    /// Set via `additional_params`; copied into the built agent.
    additional_params: Option<serde_json::Value>,
    /// Set via `max_tokens`; copied into the built agent.
    max_tokens: Option<u64>,
    /// `(sample, index)` pairs registered via `dynamic_context`; wrapped in an
    /// `Arc` when the agent is built.
    dynamic_context: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
    /// Set via `temperature`; copied into the built agent.
    temperature: Option<f64>,
    /// Set via `tool_choice`; copied into the built agent.
    tool_choice: Option<ToolChoice>,
    /// Set via `default_max_turns`; copied into the built agent.
    default_max_turns: Option<usize>,
    /// Typestate-specific tool configuration (see the struct-level docs).
    tool_state: ToolState,
    /// Prompt hook; `None` until `hook` is called.
    hook: Option<P>,
    /// Set via `output_schema` (derived with `schema_for!`) or `output_schema_raw`.
    output_schema: Option<schemars::Schema>,
    /// Conversation memory backend set via `memory`.
    memory: Option<Arc<dyn ConversationMemory>>,
    /// Set via `conversation_id`; copied into the built agent.
    default_conversation_id: Option<String>,
}
115
116impl<M, P, ToolState> AgentBuilder<M, P, ToolState>
117where
118 M: CompletionModel,
119 P: PromptHook<M>,
120{
121 pub fn name(mut self, name: &str) -> Self {
123 self.name = Some(name.into());
124 self
125 }
126
127 pub fn description(mut self, description: &str) -> Self {
129 self.description = Some(description.into());
130 self
131 }
132
133 pub fn preamble(mut self, preamble: &str) -> Self {
135 self.preamble = Some(preamble.into());
136 self
137 }
138
139 pub fn without_preamble(mut self) -> Self {
141 self.preamble = None;
142 self
143 }
144
145 pub fn append_preamble(mut self, doc: &str) -> Self {
147 self.preamble = Some(format!("{}\n{}", self.preamble.unwrap_or_default(), doc));
148 self
149 }
150
151 pub fn context(mut self, doc: &str) -> Self {
153 self.static_context.push(Document {
154 id: format!("static_doc_{}", self.static_context.len()),
155 text: doc.into(),
156 additional_props: HashMap::new(),
157 });
158 self
159 }
160
161 pub fn dynamic_context(
164 mut self,
165 sample: usize,
166 dynamic_context: impl VectorStoreIndexDyn + Send + Sync + 'static,
167 ) -> Self {
168 self.dynamic_context
169 .push((sample, Arc::new(dynamic_context)));
170 self
171 }
172
173 pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
175 self.tool_choice = Some(tool_choice);
176 self
177 }
178
179 pub fn default_max_turns(mut self, default_max_turns: usize) -> Self {
181 self.default_max_turns = Some(default_max_turns);
182 self
183 }
184
185 pub fn temperature(mut self, temperature: f64) -> Self {
187 self.temperature = Some(temperature);
188 self
189 }
190
191 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
193 self.max_tokens = Some(max_tokens);
194 self
195 }
196
197 pub fn additional_params(mut self, params: serde_json::Value) -> Self {
199 self.additional_params = Some(params);
200 self
201 }
202
203 pub fn output_schema<T>(mut self) -> Self
206 where
207 T: JsonSchema,
208 {
209 self.output_schema = Some(schema_for!(T));
210 self
211 }
212
213 pub fn output_schema_raw(mut self, schema: Schema) -> Self {
215 self.output_schema = Some(schema);
216 self
217 }
218
219 pub fn memory<B>(mut self, memory: B) -> Self
227 where
228 B: ConversationMemory + 'static,
229 {
230 self.memory = Some(Arc::new(memory));
231 self
232 }
233
234 pub fn conversation_id(mut self, id: impl Into<String>) -> Self {
239 self.default_conversation_id = Some(id.into());
240 self
241 }
242}
243
244impl<M> AgentBuilder<M, (), NoToolConfig>
245where
246 M: CompletionModel,
247{
248 pub fn new(model: M) -> Self {
250 Self {
251 name: None,
252 description: None,
253 model,
254 preamble: None,
255 static_context: vec![],
256 temperature: None,
257 max_tokens: None,
258 additional_params: None,
259 dynamic_context: vec![],
260 tool_choice: None,
261 default_max_turns: None,
262 tool_state: NoToolConfig,
263 hook: None,
264 output_schema: None,
265 memory: None,
266 default_conversation_id: None,
267 }
268 }
269}
270
271impl<M, P> AgentBuilder<M, P, NoToolConfig>
272where
273 M: CompletionModel,
274 P: PromptHook<M>,
275{
276 pub fn tool_server_handle(
282 self,
283 handle: ToolServerHandle,
284 ) -> AgentBuilder<M, P, WithToolServerHandle> {
285 AgentBuilder {
286 name: self.name,
287 description: self.description,
288 model: self.model,
289 preamble: self.preamble,
290 static_context: self.static_context,
291 additional_params: self.additional_params,
292 max_tokens: self.max_tokens,
293 dynamic_context: self.dynamic_context,
294 temperature: self.temperature,
295 tool_choice: self.tool_choice,
296 default_max_turns: self.default_max_turns,
297 tool_state: WithToolServerHandle { handle },
298 hook: self.hook,
299 output_schema: self.output_schema,
300 memory: self.memory,
301 default_conversation_id: self.default_conversation_id,
302 }
303 }
304
305 pub fn tool(self, tool: impl Tool + 'static) -> AgentBuilder<M, P, WithBuilderTools> {
310 let toolname = tool.name();
311 AgentBuilder {
312 name: self.name,
313 description: self.description,
314 model: self.model,
315 preamble: self.preamble,
316 static_context: self.static_context,
317 additional_params: self.additional_params,
318 max_tokens: self.max_tokens,
319 dynamic_context: self.dynamic_context,
320 temperature: self.temperature,
321 tool_choice: self.tool_choice,
322 default_max_turns: self.default_max_turns,
323 tool_state: WithBuilderTools {
324 static_tools: vec![toolname],
325 tools: ToolSet::from_tools(vec![tool]),
326 dynamic_tools: vec![],
327 },
328 hook: self.hook,
329 output_schema: self.output_schema,
330 memory: self.memory,
331 default_conversation_id: self.default_conversation_id,
332 }
333 }
334
335 pub fn tools(self, tools: Vec<Box<dyn ToolDyn>>) -> AgentBuilder<M, P, WithBuilderTools> {
340 let static_tools = tools.iter().map(|tool| tool.name()).collect();
341 let tools = ToolSet::from_tools_boxed(tools);
342
343 AgentBuilder {
344 name: self.name,
345 description: self.description,
346 model: self.model,
347 preamble: self.preamble,
348 static_context: self.static_context,
349 additional_params: self.additional_params,
350 max_tokens: self.max_tokens,
351 dynamic_context: self.dynamic_context,
352 temperature: self.temperature,
353 tool_choice: self.tool_choice,
354 default_max_turns: self.default_max_turns,
355 hook: self.hook,
356 output_schema: self.output_schema,
357 memory: self.memory,
358 default_conversation_id: self.default_conversation_id,
359 tool_state: WithBuilderTools {
360 static_tools,
361 tools,
362 dynamic_tools: vec![],
363 },
364 }
365 }
366
367 #[cfg(feature = "rmcp")]
371 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
372 pub fn rmcp_tool(
373 self,
374 tool: rmcp::model::Tool,
375 client: rmcp::service::ServerSink,
376 ) -> AgentBuilder<M, P, WithBuilderTools> {
377 let toolname = tool.name.clone().to_string();
378 let tools = ToolSet::from_tools(vec![RmcpTool::from_mcp_server(tool, client)]);
379
380 AgentBuilder {
381 name: self.name,
382 description: self.description,
383 model: self.model,
384 preamble: self.preamble,
385 static_context: self.static_context,
386 additional_params: self.additional_params,
387 max_tokens: self.max_tokens,
388 dynamic_context: self.dynamic_context,
389 temperature: self.temperature,
390 tool_choice: self.tool_choice,
391 default_max_turns: self.default_max_turns,
392 hook: self.hook,
393 output_schema: self.output_schema,
394 memory: self.memory,
395 default_conversation_id: self.default_conversation_id,
396 tool_state: WithBuilderTools {
397 static_tools: vec![toolname],
398 tools,
399 dynamic_tools: vec![],
400 },
401 }
402 }
403
404 #[cfg(feature = "rmcp")]
408 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
409 pub fn rmcp_tools(
410 self,
411 tools: Vec<rmcp::model::Tool>,
412 client: rmcp::service::ServerSink,
413 ) -> AgentBuilder<M, P, WithBuilderTools> {
414 let (static_tools, tools) = tools.into_iter().fold(
415 (Vec::new(), Vec::new()),
416 |(mut toolnames, mut toolset), tool| {
417 let tool_name = tool.name.to_string();
418 let tool = RmcpTool::from_mcp_server(tool, client.clone());
419 toolnames.push(tool_name);
420 toolset.push(tool);
421 (toolnames, toolset)
422 },
423 );
424
425 let tools = ToolSet::from_tools(tools);
426
427 AgentBuilder {
428 name: self.name,
429 description: self.description,
430 model: self.model,
431 preamble: self.preamble,
432 static_context: self.static_context,
433 additional_params: self.additional_params,
434 max_tokens: self.max_tokens,
435 dynamic_context: self.dynamic_context,
436 temperature: self.temperature,
437 tool_choice: self.tool_choice,
438 default_max_turns: self.default_max_turns,
439 hook: self.hook,
440 output_schema: self.output_schema,
441 memory: self.memory,
442 default_conversation_id: self.default_conversation_id,
443 tool_state: WithBuilderTools {
444 static_tools,
445 tools,
446 dynamic_tools: vec![],
447 },
448 }
449 }
450
451 pub fn dynamic_tools(
456 self,
457 sample: usize,
458 dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
459 toolset: ToolSet,
460 ) -> AgentBuilder<M, P, WithBuilderTools> {
461 AgentBuilder {
462 name: self.name,
463 description: self.description,
464 model: self.model,
465 preamble: self.preamble,
466 static_context: self.static_context,
467 additional_params: self.additional_params,
468 max_tokens: self.max_tokens,
469 dynamic_context: self.dynamic_context,
470 temperature: self.temperature,
471 tool_choice: self.tool_choice,
472 default_max_turns: self.default_max_turns,
473 hook: self.hook,
474 output_schema: self.output_schema,
475 memory: self.memory,
476 default_conversation_id: self.default_conversation_id,
477 tool_state: WithBuilderTools {
478 static_tools: vec![],
479 tools: toolset,
480 dynamic_tools: vec![(sample, Arc::new(dynamic_tools))],
481 },
482 }
483 }
484
485 pub fn hook<P2>(self, hook: P2) -> AgentBuilder<M, P2, NoToolConfig>
490 where
491 P2: PromptHook<M>,
492 {
493 AgentBuilder {
494 name: self.name,
495 description: self.description,
496 model: self.model,
497 preamble: self.preamble,
498 static_context: self.static_context,
499 additional_params: self.additional_params,
500 max_tokens: self.max_tokens,
501 dynamic_context: self.dynamic_context,
502 temperature: self.temperature,
503 tool_choice: self.tool_choice,
504 default_max_turns: self.default_max_turns,
505 tool_state: self.tool_state,
506 hook: Some(hook),
507 output_schema: self.output_schema,
508 memory: self.memory,
509 default_conversation_id: self.default_conversation_id,
510 }
511 }
512
513 pub fn build(self) -> Agent<M, P> {
517 let tool_server_handle = ToolServer::new().run();
518
519 Agent {
520 name: self.name,
521 description: self.description,
522 model: Arc::new(self.model),
523 preamble: self.preamble,
524 static_context: self.static_context,
525 temperature: self.temperature,
526 max_tokens: self.max_tokens,
527 additional_params: self.additional_params,
528 tool_choice: self.tool_choice,
529 dynamic_context: Arc::new(self.dynamic_context),
530 tool_server_handle,
531 default_max_turns: self.default_max_turns,
532 hook: self.hook,
533 output_schema: self.output_schema,
534 memory: self.memory,
535 default_conversation_id: self.default_conversation_id,
536 }
537 }
538}
539
540impl<M, P> AgentBuilder<M, P, WithToolServerHandle>
541where
542 M: CompletionModel,
543 P: PromptHook<M>,
544{
545 pub fn build(self) -> Agent<M, P> {
547 Agent {
548 name: self.name,
549 description: self.description,
550 model: Arc::new(self.model),
551 preamble: self.preamble,
552 static_context: self.static_context,
553 temperature: self.temperature,
554 max_tokens: self.max_tokens,
555 additional_params: self.additional_params,
556 tool_choice: self.tool_choice,
557 dynamic_context: Arc::new(self.dynamic_context),
558 tool_server_handle: self.tool_state.handle,
559 default_max_turns: self.default_max_turns,
560 hook: self.hook,
561 output_schema: self.output_schema,
562 memory: self.memory,
563 default_conversation_id: self.default_conversation_id,
564 }
565 }
566}
567
568impl<M, P> AgentBuilder<M, P, WithBuilderTools>
569where
570 M: CompletionModel,
571 P: PromptHook<M>,
572{
573 pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
575 let toolname = tool.name();
576 self.tool_state.tools.add_tool(tool);
577 self.tool_state.static_tools.push(toolname);
578 self
579 }
580
581 pub fn tools(mut self, tools: Vec<Box<dyn ToolDyn>>) -> Self {
583 let toolnames: Vec<String> = tools.iter().map(|tool| tool.name()).collect();
584 let tools = ToolSet::from_tools_boxed(tools);
585 self.tool_state.tools.add_tools(tools);
586 self.tool_state.static_tools.extend(toolnames);
587 self
588 }
589
590 #[cfg(feature = "rmcp")]
592 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
593 pub fn rmcp_tools(
594 mut self,
595 tools: Vec<rmcp::model::Tool>,
596 client: rmcp::service::ServerSink,
597 ) -> Self {
598 for tool in tools {
599 let tool_name = tool.name.to_string();
600 let tool = RmcpTool::from_mcp_server(tool, client.clone());
601 self.tool_state.static_tools.push(tool_name);
602 self.tool_state.tools.add_tool(tool);
603 }
604
605 self
606 }
607
608 pub fn dynamic_tools(
611 mut self,
612 sample: usize,
613 dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
614 toolset: ToolSet,
615 ) -> Self {
616 self.tool_state
617 .dynamic_tools
618 .push((sample, Arc::new(dynamic_tools)));
619 self.tool_state.tools.add_tools(toolset);
620 self
621 }
622
623 pub fn build(self) -> Agent<M, P> {
628 let tool_server_handle = ToolServer::new()
629 .static_tool_names(self.tool_state.static_tools)
630 .add_tools(self.tool_state.tools)
631 .add_dynamic_tools(self.tool_state.dynamic_tools)
632 .run();
633
634 Agent {
635 name: self.name,
636 description: self.description,
637 model: Arc::new(self.model),
638 preamble: self.preamble,
639 static_context: self.static_context,
640 temperature: self.temperature,
641 max_tokens: self.max_tokens,
642 additional_params: self.additional_params,
643 tool_choice: self.tool_choice,
644 dynamic_context: Arc::new(self.dynamic_context),
645 tool_server_handle,
646 default_max_turns: self.default_max_turns,
647 hook: self.hook,
648 output_schema: self.output_schema,
649 memory: self.memory,
650 default_conversation_id: self.default_conversation_id,
651 }
652 }
653}