1use crate::events::AgentEvent;
36use crate::llm;
37use crate::types::{ToolOutcome, ToolResult, ToolTier};
38use anyhow::Result;
39use async_trait::async_trait;
40use futures::Stream;
41use serde::{Deserialize, Serialize, de::DeserializeOwned};
42use serde_json::Value;
43use std::collections::HashMap;
44use std::future::Future;
45use std::marker::PhantomData;
46use std::pin::Pin;
47use std::sync::Arc;
48use tokio::sync::mpsc;
49
50pub trait ToolName: Send + Sync + Serialize + DeserializeOwned + 'static {}
73
74#[must_use]
82pub fn tool_name_to_string<N: ToolName>(name: &N) -> String {
83 serde_json::to_string(name)
84 .expect("ToolName must serialize to string")
85 .trim_matches('"')
86 .to_string()
87}
88
89pub fn tool_name_from_str<N: ToolName>(s: &str) -> Result<N, serde_json::Error> {
94 serde_json::from_str(&format!("\"{s}\""))
95}
96
97#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
99#[serde(rename_all = "snake_case")]
100pub enum PrimitiveToolName {
101 Read,
102 Write,
103 Edit,
104 MultiEdit,
105 Bash,
106 Glob,
107 Grep,
108 NotebookRead,
109 NotebookEdit,
110 TodoRead,
111 TodoWrite,
112 AskUser,
113 LinkFetch,
114 WebSearch,
115}
116
117impl ToolName for PrimitiveToolName {}
118
119#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
121#[serde(transparent)]
122pub struct DynamicToolName(String);
123
124impl DynamicToolName {
125 #[must_use]
126 pub fn new(name: impl Into<String>) -> Self {
127 Self(name.into())
128 }
129
130 #[must_use]
131 pub fn as_str(&self) -> &str {
132 &self.0
133 }
134}
135
136impl ToolName for DynamicToolName {}
137
138pub trait ProgressStage: Clone + Send + Sync + Serialize + DeserializeOwned + 'static {}
161
162#[must_use]
169pub fn stage_to_string<S: ProgressStage>(stage: &S) -> String {
170 serde_json::to_string(stage)
171 .expect("ProgressStage must serialize to string")
172 .trim_matches('"')
173 .to_string()
174}
175
176#[derive(Clone, Debug, Serialize)]
178pub enum ToolStatus<S: ProgressStage> {
179 Progress {
181 stage: S,
182 message: String,
183 data: Option<serde_json::Value>,
184 },
185
186 Completed(ToolResult),
188
189 Failed(ToolResult),
191}
192
193#[derive(Clone, Debug, Serialize, Deserialize)]
195pub enum ErasedToolStatus {
196 Progress {
198 stage: String,
199 message: String,
200 data: Option<serde_json::Value>,
201 },
202 Completed(ToolResult),
204 Failed(ToolResult),
206}
207
208impl<S: ProgressStage> From<ToolStatus<S>> for ErasedToolStatus {
209 fn from(status: ToolStatus<S>) -> Self {
210 match status {
211 ToolStatus::Progress {
212 stage,
213 message,
214 data,
215 } => Self::Progress {
216 stage: stage_to_string(&stage),
217 message,
218 data,
219 },
220 ToolStatus::Completed(r) => Self::Completed(r),
221 ToolStatus::Failed(r) => Self::Failed(r),
222 }
223 }
224}
225
226pub struct ToolContext<Ctx> {
228 pub app: Ctx,
230 pub metadata: HashMap<String, Value>,
232 event_tx: Option<mpsc::Sender<AgentEvent>>,
234}
235
236impl<Ctx> ToolContext<Ctx> {
237 #[must_use]
238 pub fn new(app: Ctx) -> Self {
239 Self {
240 app,
241 metadata: HashMap::new(),
242 event_tx: None,
243 }
244 }
245
246 #[must_use]
247 pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
248 self.metadata.insert(key.into(), value);
249 self
250 }
251
252 #[must_use]
254 pub fn with_event_tx(mut self, tx: mpsc::Sender<AgentEvent>) -> Self {
255 self.event_tx = Some(tx);
256 self
257 }
258
259 pub fn emit_event(&self, event: AgentEvent) {
264 if let Some(tx) = &self.event_tx {
265 let _ = tx.try_send(event);
266 }
267 }
268
269 #[must_use]
274 pub fn event_tx(&self) -> Option<mpsc::Sender<AgentEvent>> {
275 self.event_tx.clone()
276 }
277}
278
279pub trait Tool<Ctx>: Send + Sync {
293 type Name: ToolName;
295
296 fn name(&self) -> Self::Name;
298
299 fn display_name(&self) -> &'static str;
303
304 fn description(&self) -> &'static str;
306
307 fn input_schema(&self) -> Value;
309
310 fn tier(&self) -> ToolTier {
312 ToolTier::Observe
313 }
314
315 fn execute(
320 &self,
321 ctx: &ToolContext<Ctx>,
322 input: Value,
323 ) -> impl Future<Output = Result<ToolResult>> + Send;
324}
325
326pub trait AsyncTool<Ctx>: Send + Sync {
375 type Name: ToolName;
377 type Stage: ProgressStage;
379
380 fn name(&self) -> Self::Name;
382
383 fn display_name(&self) -> &'static str;
385
386 fn description(&self) -> &'static str;
388
389 fn input_schema(&self) -> Value;
391
392 fn tier(&self) -> ToolTier {
394 ToolTier::Observe
395 }
396
397 fn execute(
404 &self,
405 ctx: &ToolContext<Ctx>,
406 input: Value,
407 ) -> impl Future<Output = Result<ToolOutcome>> + Send;
408
409 fn check_status(
412 &self,
413 ctx: &ToolContext<Ctx>,
414 operation_id: &str,
415 ) -> impl Stream<Item = ToolStatus<Self::Stage>> + Send;
416}
417
418#[async_trait]
435pub trait ErasedTool<Ctx>: Send + Sync {
436 fn name_str(&self) -> &str;
438 fn display_name(&self) -> &'static str;
440 fn description(&self) -> &'static str;
442 fn input_schema(&self) -> Value;
444 fn tier(&self) -> ToolTier;
446 async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult>;
448}
449
450struct ToolWrapper<T, Ctx>
452where
453 T: Tool<Ctx>,
454{
455 inner: T,
456 name_cache: String,
457 _marker: PhantomData<Ctx>,
458}
459
460impl<T, Ctx> ToolWrapper<T, Ctx>
461where
462 T: Tool<Ctx>,
463{
464 fn new(tool: T) -> Self {
465 let name_cache = tool_name_to_string(&tool.name());
466 Self {
467 inner: tool,
468 name_cache,
469 _marker: PhantomData,
470 }
471 }
472}
473
474#[async_trait]
475impl<T, Ctx> ErasedTool<Ctx> for ToolWrapper<T, Ctx>
476where
477 T: Tool<Ctx> + 'static,
478 Ctx: Send + Sync + 'static,
479{
480 fn name_str(&self) -> &str {
481 &self.name_cache
482 }
483
484 fn display_name(&self) -> &'static str {
485 self.inner.display_name()
486 }
487
488 fn description(&self) -> &'static str {
489 self.inner.description()
490 }
491
492 fn input_schema(&self) -> Value {
493 self.inner.input_schema()
494 }
495
496 fn tier(&self) -> ToolTier {
497 self.inner.tier()
498 }
499
500 async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
501 self.inner.execute(ctx, input).await
502 }
503}
504
505#[async_trait]
514pub trait ErasedAsyncTool<Ctx>: Send + Sync {
515 fn name_str(&self) -> &str;
517 fn display_name(&self) -> &'static str;
519 fn description(&self) -> &'static str;
521 fn input_schema(&self) -> Value;
523 fn tier(&self) -> ToolTier;
525 async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome>;
527 fn check_status_stream<'a>(
529 &'a self,
530 ctx: &'a ToolContext<Ctx>,
531 operation_id: &'a str,
532 ) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>>;
533}
534
535struct AsyncToolWrapper<T, Ctx>
537where
538 T: AsyncTool<Ctx>,
539{
540 inner: T,
541 name_cache: String,
542 _marker: PhantomData<Ctx>,
543}
544
545impl<T, Ctx> AsyncToolWrapper<T, Ctx>
546where
547 T: AsyncTool<Ctx>,
548{
549 fn new(tool: T) -> Self {
550 let name_cache = tool_name_to_string(&tool.name());
551 Self {
552 inner: tool,
553 name_cache,
554 _marker: PhantomData,
555 }
556 }
557}
558
559#[async_trait]
560impl<T, Ctx> ErasedAsyncTool<Ctx> for AsyncToolWrapper<T, Ctx>
561where
562 T: AsyncTool<Ctx> + 'static,
563 Ctx: Send + Sync + 'static,
564{
565 fn name_str(&self) -> &str {
566 &self.name_cache
567 }
568
569 fn display_name(&self) -> &'static str {
570 self.inner.display_name()
571 }
572
573 fn description(&self) -> &'static str {
574 self.inner.description()
575 }
576
577 fn input_schema(&self) -> Value {
578 self.inner.input_schema()
579 }
580
581 fn tier(&self) -> ToolTier {
582 self.inner.tier()
583 }
584
585 async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome> {
586 self.inner.execute(ctx, input).await
587 }
588
589 fn check_status_stream<'a>(
590 &'a self,
591 ctx: &'a ToolContext<Ctx>,
592 operation_id: &'a str,
593 ) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>> {
594 use futures::StreamExt;
595 let stream = self.inner.check_status(ctx, operation_id);
596 Box::pin(stream.map(ErasedToolStatus::from))
597 }
598}
599
600pub struct ToolRegistry<Ctx> {
612 tools: HashMap<String, Arc<dyn ErasedTool<Ctx>>>,
613 async_tools: HashMap<String, Arc<dyn ErasedAsyncTool<Ctx>>>,
614}
615
616impl<Ctx> Clone for ToolRegistry<Ctx> {
617 fn clone(&self) -> Self {
618 Self {
619 tools: self.tools.clone(),
620 async_tools: self.async_tools.clone(),
621 }
622 }
623}
624
625impl<Ctx: Send + Sync + 'static> Default for ToolRegistry<Ctx> {
626 fn default() -> Self {
627 Self::new()
628 }
629}
630
631impl<Ctx: Send + Sync + 'static> ToolRegistry<Ctx> {
632 #[must_use]
633 pub fn new() -> Self {
634 Self {
635 tools: HashMap::new(),
636 async_tools: HashMap::new(),
637 }
638 }
639
640 pub fn register<T>(&mut self, tool: T) -> &mut Self
645 where
646 T: Tool<Ctx> + 'static,
647 {
648 let wrapper = ToolWrapper::new(tool);
649 let name = wrapper.name_str().to_string();
650 self.tools.insert(name, Arc::new(wrapper));
651 self
652 }
653
654 pub fn register_async<T>(&mut self, tool: T) -> &mut Self
659 where
660 T: AsyncTool<Ctx> + 'static,
661 {
662 let wrapper = AsyncToolWrapper::new(tool);
663 let name = wrapper.name_str().to_string();
664 self.async_tools.insert(name, Arc::new(wrapper));
665 self
666 }
667
668 #[must_use]
670 pub fn get(&self, name: &str) -> Option<&Arc<dyn ErasedTool<Ctx>>> {
671 self.tools.get(name)
672 }
673
674 #[must_use]
676 pub fn get_async(&self, name: &str) -> Option<&Arc<dyn ErasedAsyncTool<Ctx>>> {
677 self.async_tools.get(name)
678 }
679
680 #[must_use]
682 pub fn is_async(&self, name: &str) -> bool {
683 self.async_tools.contains_key(name)
684 }
685
686 pub fn all(&self) -> impl Iterator<Item = &Arc<dyn ErasedTool<Ctx>>> {
688 self.tools.values()
689 }
690
691 pub fn all_async(&self) -> impl Iterator<Item = &Arc<dyn ErasedAsyncTool<Ctx>>> {
693 self.async_tools.values()
694 }
695
696 #[must_use]
698 pub fn len(&self) -> usize {
699 self.tools.len() + self.async_tools.len()
700 }
701
702 #[must_use]
704 pub fn is_empty(&self) -> bool {
705 self.tools.is_empty() && self.async_tools.is_empty()
706 }
707
708 pub fn filter<F>(&mut self, predicate: F)
720 where
721 F: Fn(&str) -> bool,
722 {
723 self.tools.retain(|name, _| predicate(name));
724 self.async_tools.retain(|name, _| predicate(name));
725 }
726
727 #[must_use]
729 pub fn to_llm_tools(&self) -> Vec<llm::Tool> {
730 let mut tools: Vec<_> = self
731 .tools
732 .values()
733 .map(|tool| llm::Tool {
734 name: tool.name_str().to_string(),
735 description: tool.description().to_string(),
736 input_schema: tool.input_schema(),
737 })
738 .collect();
739
740 tools.extend(self.async_tools.values().map(|tool| llm::Tool {
741 name: tool.name_str().to_string(),
742 description: tool.description().to_string(),
743 input_schema: tool.input_schema(),
744 }));
745
746 tools
747 }
748}
749
#[cfg(test)]
mod tests {
    use super::*;

    /// Identifiers for the fixture tools below; serialized in snake_case.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
    #[serde(rename_all = "snake_case")]
    enum TestToolName {
        MockTool,
        AnotherTool,
    }

    impl ToolName for TestToolName {}

    /// Fixture tool that echoes the `message` field of its input.
    struct MockTool;

    impl Tool<()> for MockTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::MockTool
        }

        fn display_name(&self) -> &'static str {
            "Mock Tool"
        }

        fn description(&self) -> &'static str {
            "A mock tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                }
            })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
            // A missing or non-string `message` falls back to a fixed placeholder.
            let message = match input.get("message").and_then(Value::as_str) {
                Some(text) => text,
                None => "no message",
            };
            Ok(ToolResult::success(format!("Received: {message}")))
        }
    }

    /// Second fixture tool, used where a registry needs more than one entry.
    struct AnotherTool;

    impl Tool<()> for AnotherTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::AnotherTool
        }

        fn display_name(&self) -> &'static str {
            "Another Tool"
        }

        fn description(&self) -> &'static str {
            "Another tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, _input: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("Done"))
        }
    }

    /// Builds a registry holding both fixture tools.
    fn registry_with_both() -> ToolRegistry<()> {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool).register(AnotherTool);
        registry
    }

    #[test]
    fn test_tool_name_serialization() {
        assert_eq!(tool_name_to_string(&TestToolName::MockTool), "mock_tool");

        let round_trip = tool_name_from_str::<TestToolName>("mock_tool").unwrap();
        assert_eq!(round_trip, TestToolName::MockTool);
    }

    #[test]
    fn test_dynamic_tool_name() {
        let dynamic = DynamicToolName::new("my_mcp_tool");
        assert_eq!(dynamic.as_str(), "my_mcp_tool");
        assert_eq!(tool_name_to_string(&dynamic), "my_mcp_tool");
    }

    #[test]
    fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        assert_eq!(registry.len(), 1);
        assert!(registry.get("nonexistent").is_none());
        assert!(registry.get("mock_tool").is_some());
    }

    #[test]
    fn test_to_llm_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        let exported = registry.to_llm_tools();
        assert_eq!(exported.len(), 1);
        assert_eq!(exported[0].name, "mock_tool");
    }

    #[test]
    fn test_filter_tools() {
        let mut registry = registry_with_both();
        assert_eq!(registry.len(), 2);

        registry.filter(|name| name != "mock_tool");

        assert_eq!(registry.len(), 1);
        assert!(registry.get("mock_tool").is_none());
        assert!(registry.get("another_tool").is_some());
    }

    #[test]
    fn test_filter_tools_keep_all() {
        let mut registry = registry_with_both();
        registry.filter(|_| true);
        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn test_filter_tools_remove_all() {
        let mut registry = registry_with_both();
        registry.filter(|_| false);
        assert!(registry.is_empty());
    }

    #[test]
    fn test_display_name() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        let found = registry.get("mock_tool").expect("mock_tool should be registered");
        assert_eq!(found.display_name(), "Mock Tool");
    }
}