1use crate::events::{AgentEvent, AgentEventEnvelope, SequenceCounter};
37use crate::llm;
38use crate::types::{ToolOutcome, ToolResult, ToolTier};
39use anyhow::Result;
40use async_trait::async_trait;
41use futures::Stream;
42use serde::{Deserialize, Serialize, de::DeserializeOwned};
43use serde_json::Value;
44use std::collections::HashMap;
45use std::future::Future;
46use std::marker::PhantomData;
47use std::pin::Pin;
48use std::sync::Arc;
49use time::OffsetDateTime;
50use tokio::sync::mpsc;
51
/// Marker trait for types that identify a tool.
///
/// The bounds require a name to be serializable and deserializable with serde,
/// shareable across threads (`Send + Sync`) and free of borrowed data
/// (`'static`). [`tool_name_to_string`] and [`tool_name_from_str`] convert such
/// names to and from their string form.
pub trait ToolName: Send + Sync + Serialize + DeserializeOwned + 'static {}
75
76#[must_use]
84pub fn tool_name_to_string<N: ToolName>(name: &N) -> String {
85 serde_json::to_string(name)
86 .expect("ToolName must serialize to string")
87 .trim_matches('"')
88 .to_string()
89}
90
91pub fn tool_name_from_str<N: ToolName>(s: &str) -> Result<N, serde_json::Error> {
96 serde_json::from_str(&format!("\"{s}\""))
97}
98
/// Closed, compile-time set of tool names.
///
/// `rename_all = "snake_case"` makes every variant serialize to (and parse
/// from) its snake_case spelling, e.g. `MultiEdit` becomes `"multi_edit"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PrimitiveToolName {
    Read,
    Write,
    Edit,
    MultiEdit,
    Bash,
    Glob,
    Grep,
    NotebookRead,
    NotebookEdit,
    TodoRead,
    TodoWrite,
    AskUser,
    LinkFetch,
    WebSearch,
}
118
// Opts the enum into the `ToolName` marker so it can be used as a tool's `Name` type.
impl ToolName for PrimitiveToolName {}
120
/// Tool name whose value is only known at runtime.
///
/// `#[serde(transparent)]` makes the wrapper serialize and deserialize exactly
/// like the `String` it contains.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DynamicToolName(String);
125
126impl DynamicToolName {
127 #[must_use]
128 pub fn new(name: impl Into<String>) -> Self {
129 Self(name.into())
130 }
131
132 #[must_use]
133 pub fn as_str(&self) -> &str {
134 &self.0
135 }
136}
137
// Opts the runtime-named wrapper into the `ToolName` marker.
impl ToolName for DynamicToolName {}
139
/// Marker trait for types that label a stage of a tool's progress.
///
/// A stage must be cloneable, serde-serializable and -deserializable, shareable
/// across threads and free of borrowed data. [`stage_to_string`] turns a stage
/// into its string form.
pub trait ProgressStage: Clone + Send + Sync + Serialize + DeserializeOwned + 'static {}
163
164#[must_use]
171pub fn stage_to_string<S: ProgressStage>(stage: &S) -> String {
172 serde_json::to_string(stage)
173 .expect("ProgressStage must serialize to string")
174 .trim_matches('"')
175 .to_string()
176}
177
/// Status item reported for a tool operation, with a typed progress stage `S`.
///
/// Items of this type are what `AsyncTool::check_status` streams.
#[derive(Clone, Debug, Serialize)]
pub enum ToolStatus<S: ProgressStage> {
    /// Intermediate update.
    Progress {
        /// Typed stage the operation is reporting.
        stage: S,
        /// Text accompanying the update.
        message: String,
        /// Optional free-form JSON payload.
        data: Option<serde_json::Value>,
    },

    /// Terminal item carrying the operation's result.
    Completed(ToolResult),

    /// Terminal item carrying the result of a failed operation.
    Failed(ToolResult),
}
194
/// Type-erased counterpart of [`ToolStatus`]: the stage is carried as a plain
/// `String`, so the type has no generic parameter and can also be deserialized.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ErasedToolStatus {
    /// Intermediate update.
    Progress {
        /// Stage in string form (see the `From<ToolStatus<S>>` conversion).
        stage: String,
        /// Text accompanying the update.
        message: String,
        /// Optional free-form JSON payload.
        data: Option<serde_json::Value>,
    },
    /// Terminal item carrying the operation's result.
    Completed(ToolResult),
    /// Terminal item carrying the result of a failed operation.
    Failed(ToolResult),
}
209
210#[derive(Clone, Debug, Serialize, Deserialize)]
215pub enum ListenToolUpdate {
216 Listening {
218 operation_id: String,
220 revision: u64,
222 message: String,
224 snapshot: Option<serde_json::Value>,
226 #[serde(with = "time::serde::rfc3339::option")]
228 expires_at: Option<OffsetDateTime>,
229 },
230
231 Ready {
233 operation_id: String,
235 revision: u64,
237 message: String,
239 snapshot: serde_json::Value,
241 #[serde(with = "time::serde::rfc3339::option")]
243 expires_at: Option<OffsetDateTime>,
244 },
245
246 Invalidated {
248 operation_id: String,
250 message: String,
252 recoverable: bool,
254 },
255}
256
/// Reason passed to `ListenExecuteTool::cancel` when a listen operation is
/// stopped.
///
/// NOTE(review): the per-variant meanings below are read off the variant names;
/// confirm them against the code that constructs these values.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum ListenStopReason {
    /// Presumably: the user declined the operation.
    UserRejected,
    /// Presumably: the operation was blocked from proceeding.
    Blocked,
    /// Presumably: the update stream was disconnected.
    StreamDisconnected,
    /// Presumably: the update stream finished.
    StreamEnded,
}
269
270impl<S: ProgressStage> From<ToolStatus<S>> for ErasedToolStatus {
271 fn from(status: ToolStatus<S>) -> Self {
272 match status {
273 ToolStatus::Progress {
274 stage,
275 message,
276 data,
277 } => Self::Progress {
278 stage: stage_to_string(&stage),
279 message,
280 data,
281 },
282 ToolStatus::Completed(r) => Self::Completed(r),
283 ToolStatus::Failed(r) => Self::Failed(r),
284 }
285 }
286}
287
/// Context handed to a tool on every call.
pub struct ToolContext<Ctx> {
    /// Application-specific context value supplied by the caller.
    pub app: Ctx,
    /// Free-form JSON metadata keyed by string.
    pub metadata: HashMap<String, Value>,
    /// Channel used by `emit_event`; `None` until `with_event_tx` is called.
    event_tx: Option<mpsc::Sender<AgentEventEnvelope>>,
    /// Sequence counter used to wrap emitted events; set together with `event_tx`.
    event_seq: Option<SequenceCounter>,
}
299
300impl<Ctx> ToolContext<Ctx> {
301 #[must_use]
302 pub fn new(app: Ctx) -> Self {
303 Self {
304 app,
305 metadata: HashMap::new(),
306 event_tx: None,
307 event_seq: None,
308 }
309 }
310
311 #[must_use]
312 pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
313 self.metadata.insert(key.into(), value);
314 self
315 }
316
317 #[must_use]
320 pub fn with_event_tx(
321 mut self,
322 tx: mpsc::Sender<AgentEventEnvelope>,
323 seq: SequenceCounter,
324 ) -> Self {
325 self.event_tx = Some(tx);
326 self.event_seq = Some(seq);
327 self
328 }
329
330 pub fn emit_event(&self, event: AgentEvent) {
338 if let Some((tx, seq)) = self.event_tx.as_ref().zip(self.event_seq.as_ref()) {
339 let envelope = AgentEventEnvelope::wrap(event, seq);
340 let _ = tx.try_send(envelope);
341 }
342 }
343
344 #[must_use]
349 pub fn event_tx(&self) -> Option<mpsc::Sender<AgentEventEnvelope>> {
350 self.event_tx.clone()
351 }
352
353 #[must_use]
358 pub fn event_seq(&self) -> Option<SequenceCounter> {
359 self.event_seq.clone()
360 }
361}
362
/// A tool whose `execute` call resolves directly to a [`ToolResult`].
///
/// `Ctx` is the application context type carried by [`ToolContext`].
pub trait Tool<Ctx>: Send + Sync {
    /// Typed name of the tool.
    type Name: ToolName;

    /// Returns the tool's typed name; the registry keys tools by its string form.
    fn name(&self) -> Self::Name;

    /// Returns the tool's display name.
    fn display_name(&self) -> &'static str;

    /// Returns the tool's description (forwarded to the LLM tool definition by
    /// `ToolRegistry::to_llm_tools`).
    fn description(&self) -> &'static str;

    /// Returns the JSON value describing the tool's input (forwarded as
    /// `input_schema` by `ToolRegistry::to_llm_tools`).
    fn input_schema(&self) -> Value;

    /// Returns the tool's tier; defaults to `ToolTier::Observe`.
    fn tier(&self) -> ToolTier {
        ToolTier::Observe
    }

    /// Runs the tool with the given context and JSON input.
    ///
    /// The returned future must be `Send` and resolves to the tool's result or
    /// an error.
    fn execute(
        &self,
        ctx: &ToolContext<Ctx>,
        input: Value,
    ) -> impl Future<Output = Result<ToolResult>> + Send;
}
409
/// A tool whose `execute` call resolves to a [`ToolOutcome`] and whose status
/// can afterwards be followed as a stream of [`ToolStatus`] items.
pub trait AsyncTool<Ctx>: Send + Sync {
    /// Typed name of the tool.
    type Name: ToolName;
    /// Typed progress stage reported in `ToolStatus::Progress`.
    type Stage: ProgressStage;

    /// Returns the tool's typed name; the registry keys tools by its string form.
    fn name(&self) -> Self::Name;

    /// Returns the tool's display name.
    fn display_name(&self) -> &'static str;

    /// Returns the tool's description.
    fn description(&self) -> &'static str;

    /// Returns the JSON value describing the tool's input.
    fn input_schema(&self) -> Value;

    /// Returns the tool's tier; defaults to `ToolTier::Observe`.
    fn tier(&self) -> ToolTier {
        ToolTier::Observe
    }

    /// Starts the tool with the given context and JSON input.
    ///
    /// The returned future must be `Send` and resolves to a `ToolOutcome` or an
    /// error.
    fn execute(
        &self,
        ctx: &ToolContext<Ctx>,
        input: Value,
    ) -> impl Future<Output = Result<ToolOutcome>> + Send;

    /// Returns a `Send` stream of status items for the operation identified by
    /// `operation_id`.
    fn check_status(
        &self,
        ctx: &ToolContext<Ctx>,
        operation_id: &str,
    ) -> impl Stream<Item = ToolStatus<Self::Stage>> + Send;
}
501
/// A tool split into a `listen` phase that streams [`ListenToolUpdate`] items
/// and an `execute` phase that acts on a specific operation revision.
pub trait ListenExecuteTool<Ctx>: Send + Sync {
    /// Typed name of the tool.
    type Name: ToolName;

    /// Returns the tool's typed name; the registry keys tools by its string form.
    fn name(&self) -> Self::Name;

    /// Returns the tool's display name.
    fn display_name(&self) -> &'static str;

    /// Returns the tool's description.
    fn description(&self) -> &'static str;

    /// Returns the JSON value describing the tool's input.
    fn input_schema(&self) -> Value;

    /// Returns the tool's tier; defaults to `ToolTier::Confirm` (the other tool
    /// traits in this file default to `ToolTier::Observe`).
    fn tier(&self) -> ToolTier {
        ToolTier::Confirm
    }

    /// Starts listening with the given JSON input and returns a `Send` stream
    /// of updates.
    fn listen(
        &self,
        ctx: &ToolContext<Ctx>,
        input: Value,
    ) -> impl Stream<Item = ListenToolUpdate> + Send;

    /// Executes the operation identified by `operation_id`.
    ///
    /// NOTE(review): `expected_revision` presumably has to match the revision of
    /// the latest update seen by the caller — confirm with implementations.
    fn execute(
        &self,
        ctx: &ToolContext<Ctx>,
        operation_id: &str,
        expected_revision: u64,
    ) -> impl Future<Output = Result<ToolResult>> + Send;

    /// Cancels the operation identified by `operation_id` for the given reason.
    ///
    /// The default implementation ignores all arguments and returns `Ok(())`.
    fn cancel(
        &self,
        _ctx: &ToolContext<Ctx>,
        _operation_id: &str,
        _reason: ListenStopReason,
    ) -> impl Future<Output = Result<()>> + Send {
        async { Ok(()) }
    }
}
568
/// Object-safe view of a [`Tool`]: the typed name is exposed as a string and
/// `execute` is an `async_trait` method, so values can be stored as
/// `Arc<dyn ErasedTool<Ctx>>`.
#[async_trait]
pub trait ErasedTool<Ctx>: Send + Sync {
    /// Returns the tool's name in string form.
    fn name_str(&self) -> &str;
    /// Returns the tool's display name.
    fn display_name(&self) -> &'static str;
    /// Returns the tool's description.
    fn description(&self) -> &'static str;
    /// Returns the JSON value describing the tool's input.
    fn input_schema(&self) -> Value;
    /// Returns the tool's tier.
    fn tier(&self) -> ToolTier;
    /// Runs the tool with the given context and JSON input.
    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult>;
}
600
/// Adapter that exposes a typed [`Tool`] through the [`ErasedTool`] trait.
struct ToolWrapper<T, Ctx>
where
    T: Tool<Ctx>,
{
    /// The wrapped tool.
    inner: T,
    /// String form of the tool's name, computed once in `new`.
    name_cache: String,
    /// Ties the wrapper to `Ctx` without storing a value of that type.
    _marker: PhantomData<Ctx>,
}
610
611impl<T, Ctx> ToolWrapper<T, Ctx>
612where
613 T: Tool<Ctx>,
614{
615 fn new(tool: T) -> Self {
616 let name_cache = tool_name_to_string(&tool.name());
617 Self {
618 inner: tool,
619 name_cache,
620 _marker: PhantomData,
621 }
622 }
623}
624
625#[async_trait]
626impl<T, Ctx> ErasedTool<Ctx> for ToolWrapper<T, Ctx>
627where
628 T: Tool<Ctx> + 'static,
629 Ctx: Send + Sync + 'static,
630{
631 fn name_str(&self) -> &str {
632 &self.name_cache
633 }
634
635 fn display_name(&self) -> &'static str {
636 self.inner.display_name()
637 }
638
639 fn description(&self) -> &'static str {
640 self.inner.description()
641 }
642
643 fn input_schema(&self) -> Value {
644 self.inner.input_schema()
645 }
646
647 fn tier(&self) -> ToolTier {
648 self.inner.tier()
649 }
650
651 async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
652 self.inner.execute(ctx, input).await
653 }
654}
655
/// Object-safe view of an [`AsyncTool`]: the typed name is exposed as a string
/// and the status stream is boxed and yields [`ErasedToolStatus`] items.
#[async_trait]
pub trait ErasedAsyncTool<Ctx>: Send + Sync {
    /// Returns the tool's name in string form.
    fn name_str(&self) -> &str;
    /// Returns the tool's display name.
    fn display_name(&self) -> &'static str;
    /// Returns the tool's description.
    fn description(&self) -> &'static str;
    /// Returns the JSON value describing the tool's input.
    fn input_schema(&self) -> Value;
    /// Returns the tool's tier.
    fn tier(&self) -> ToolTier;
    /// Starts the tool with the given context and JSON input.
    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome>;
    /// Returns a boxed stream of status items for `operation_id`.
    ///
    /// The stream may borrow the tool, the context and the operation id, all
    /// for the same lifetime `'a`.
    fn check_status_stream<'a>(
        &'a self,
        ctx: &'a ToolContext<Ctx>,
        operation_id: &'a str,
    ) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>>;
}
685
/// Adapter that exposes a typed [`AsyncTool`] through the [`ErasedAsyncTool`]
/// trait.
struct AsyncToolWrapper<T, Ctx>
where
    T: AsyncTool<Ctx>,
{
    /// The wrapped tool.
    inner: T,
    /// String form of the tool's name, computed once in `new`.
    name_cache: String,
    /// Ties the wrapper to `Ctx` without storing a value of that type.
    _marker: PhantomData<Ctx>,
}
695
696impl<T, Ctx> AsyncToolWrapper<T, Ctx>
697where
698 T: AsyncTool<Ctx>,
699{
700 fn new(tool: T) -> Self {
701 let name_cache = tool_name_to_string(&tool.name());
702 Self {
703 inner: tool,
704 name_cache,
705 _marker: PhantomData,
706 }
707 }
708}
709
710#[async_trait]
711impl<T, Ctx> ErasedAsyncTool<Ctx> for AsyncToolWrapper<T, Ctx>
712where
713 T: AsyncTool<Ctx> + 'static,
714 Ctx: Send + Sync + 'static,
715{
716 fn name_str(&self) -> &str {
717 &self.name_cache
718 }
719
720 fn display_name(&self) -> &'static str {
721 self.inner.display_name()
722 }
723
724 fn description(&self) -> &'static str {
725 self.inner.description()
726 }
727
728 fn input_schema(&self) -> Value {
729 self.inner.input_schema()
730 }
731
732 fn tier(&self) -> ToolTier {
733 self.inner.tier()
734 }
735
736 async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome> {
737 self.inner.execute(ctx, input).await
738 }
739
740 fn check_status_stream<'a>(
741 &'a self,
742 ctx: &'a ToolContext<Ctx>,
743 operation_id: &'a str,
744 ) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>> {
745 use futures::StreamExt;
746 let stream = self.inner.check_status(ctx, operation_id);
747 Box::pin(stream.map(ErasedToolStatus::from))
748 }
749}
750
/// Object-safe view of a [`ListenExecuteTool`]: the typed name is exposed as a
/// string, the update stream is boxed, and `execute`/`cancel` are `async_trait`
/// methods.
#[async_trait]
pub trait ErasedListenTool<Ctx>: Send + Sync {
    /// Returns the tool's name in string form.
    fn name_str(&self) -> &str;
    /// Returns the tool's display name.
    fn display_name(&self) -> &'static str;
    /// Returns the tool's description.
    fn description(&self) -> &'static str;
    /// Returns the JSON value describing the tool's input.
    fn input_schema(&self) -> Value;
    /// Returns the tool's tier.
    fn tier(&self) -> ToolTier;
    /// Starts listening with the given JSON input and returns a boxed stream of
    /// updates that may borrow the tool and the context for `'a`.
    fn listen_stream<'a>(
        &'a self,
        ctx: &'a ToolContext<Ctx>,
        input: Value,
    ) -> Pin<Box<dyn Stream<Item = ListenToolUpdate> + Send + 'a>>;
    /// Executes the operation identified by `operation_id` at
    /// `expected_revision`.
    async fn execute(
        &self,
        ctx: &ToolContext<Ctx>,
        operation_id: &str,
        expected_revision: u64,
    ) -> Result<ToolResult>;
    /// Cancels the operation identified by `operation_id` for the given reason.
    async fn cancel(
        &self,
        ctx: &ToolContext<Ctx>,
        operation_id: &str,
        reason: ListenStopReason,
    ) -> Result<()>;
}
789
/// Adapter that exposes a typed [`ListenExecuteTool`] through the
/// [`ErasedListenTool`] trait.
struct ListenToolWrapper<T, Ctx>
where
    T: ListenExecuteTool<Ctx>,
{
    /// The wrapped tool.
    inner: T,
    /// String form of the tool's name, computed once in `new`.
    name_cache: String,
    /// Ties the wrapper to `Ctx` without storing a value of that type.
    _marker: PhantomData<Ctx>,
}
799
800impl<T, Ctx> ListenToolWrapper<T, Ctx>
801where
802 T: ListenExecuteTool<Ctx>,
803{
804 fn new(tool: T) -> Self {
805 let name_cache = tool_name_to_string(&tool.name());
806 Self {
807 inner: tool,
808 name_cache,
809 _marker: PhantomData,
810 }
811 }
812}
813
814#[async_trait]
815impl<T, Ctx> ErasedListenTool<Ctx> for ListenToolWrapper<T, Ctx>
816where
817 T: ListenExecuteTool<Ctx> + 'static,
818 Ctx: Send + Sync + 'static,
819{
820 fn name_str(&self) -> &str {
821 &self.name_cache
822 }
823
824 fn display_name(&self) -> &'static str {
825 self.inner.display_name()
826 }
827
828 fn description(&self) -> &'static str {
829 self.inner.description()
830 }
831
832 fn input_schema(&self) -> Value {
833 self.inner.input_schema()
834 }
835
836 fn tier(&self) -> ToolTier {
837 self.inner.tier()
838 }
839
840 fn listen_stream<'a>(
841 &'a self,
842 ctx: &'a ToolContext<Ctx>,
843 input: Value,
844 ) -> Pin<Box<dyn Stream<Item = ListenToolUpdate> + Send + 'a>> {
845 let stream = self.inner.listen(ctx, input);
846 Box::pin(stream)
847 }
848
849 async fn execute(
850 &self,
851 ctx: &ToolContext<Ctx>,
852 operation_id: &str,
853 expected_revision: u64,
854 ) -> Result<ToolResult> {
855 self.inner
856 .execute(ctx, operation_id, expected_revision)
857 .await
858 }
859
860 async fn cancel(
861 &self,
862 ctx: &ToolContext<Ctx>,
863 operation_id: &str,
864 reason: ListenStopReason,
865 ) -> Result<()> {
866 self.inner.cancel(ctx, operation_id, reason).await
867 }
868}
869
/// Collection of type-erased tools, keyed by the string form of their names.
///
/// Each of the three tool kinds lives in its own map, so the same name can be
/// present in more than one map at once.
pub struct ToolRegistry<Ctx> {
    /// Tools registered through `register`.
    tools: HashMap<String, Arc<dyn ErasedTool<Ctx>>>,
    /// Tools registered through `register_async`.
    async_tools: HashMap<String, Arc<dyn ErasedAsyncTool<Ctx>>>,
    /// Tools registered through `register_listen`.
    listen_tools: HashMap<String, Arc<dyn ErasedListenTool<Ctx>>>,
}
882
883impl<Ctx> Clone for ToolRegistry<Ctx> {
884 fn clone(&self) -> Self {
885 Self {
886 tools: self.tools.clone(),
887 async_tools: self.async_tools.clone(),
888 listen_tools: self.listen_tools.clone(),
889 }
890 }
891}
892
/// The default registry is the empty registry produced by `ToolRegistry::new`.
impl<Ctx: Send + Sync + 'static> Default for ToolRegistry<Ctx> {
    fn default() -> Self {
        Self::new()
    }
}
898
899impl<Ctx: Send + Sync + 'static> ToolRegistry<Ctx> {
900 #[must_use]
901 pub fn new() -> Self {
902 Self {
903 tools: HashMap::new(),
904 async_tools: HashMap::new(),
905 listen_tools: HashMap::new(),
906 }
907 }
908
909 pub fn register<T>(&mut self, tool: T) -> &mut Self
914 where
915 T: Tool<Ctx> + 'static,
916 {
917 let wrapper = ToolWrapper::new(tool);
918 let name = wrapper.name_str().to_string();
919 self.tools.insert(name, Arc::new(wrapper));
920 self
921 }
922
923 pub fn register_async<T>(&mut self, tool: T) -> &mut Self
928 where
929 T: AsyncTool<Ctx> + 'static,
930 {
931 let wrapper = AsyncToolWrapper::new(tool);
932 let name = wrapper.name_str().to_string();
933 self.async_tools.insert(name, Arc::new(wrapper));
934 self
935 }
936
937 pub fn register_listen<T>(&mut self, tool: T) -> &mut Self
942 where
943 T: ListenExecuteTool<Ctx> + 'static,
944 {
945 let wrapper = ListenToolWrapper::new(tool);
946 let name = wrapper.name_str().to_string();
947 self.listen_tools.insert(name, Arc::new(wrapper));
948 self
949 }
950
951 #[must_use]
953 pub fn get(&self, name: &str) -> Option<&Arc<dyn ErasedTool<Ctx>>> {
954 self.tools.get(name)
955 }
956
957 #[must_use]
959 pub fn get_async(&self, name: &str) -> Option<&Arc<dyn ErasedAsyncTool<Ctx>>> {
960 self.async_tools.get(name)
961 }
962
963 #[must_use]
965 pub fn get_listen(&self, name: &str) -> Option<&Arc<dyn ErasedListenTool<Ctx>>> {
966 self.listen_tools.get(name)
967 }
968
969 #[must_use]
971 pub fn is_async(&self, name: &str) -> bool {
972 self.async_tools.contains_key(name)
973 }
974
975 #[must_use]
977 pub fn is_listen(&self, name: &str) -> bool {
978 self.listen_tools.contains_key(name)
979 }
980
981 pub fn all(&self) -> impl Iterator<Item = &Arc<dyn ErasedTool<Ctx>>> {
983 self.tools.values()
984 }
985
986 pub fn all_async(&self) -> impl Iterator<Item = &Arc<dyn ErasedAsyncTool<Ctx>>> {
988 self.async_tools.values()
989 }
990
991 pub fn all_listen(&self) -> impl Iterator<Item = &Arc<dyn ErasedListenTool<Ctx>>> {
993 self.listen_tools.values()
994 }
995
996 #[must_use]
998 pub fn len(&self) -> usize {
999 self.tools.len() + self.async_tools.len() + self.listen_tools.len()
1000 }
1001
1002 #[must_use]
1004 pub fn is_empty(&self) -> bool {
1005 self.tools.is_empty() && self.async_tools.is_empty() && self.listen_tools.is_empty()
1006 }
1007
1008 pub fn filter<F>(&mut self, predicate: F)
1020 where
1021 F: Fn(&str) -> bool,
1022 {
1023 self.tools.retain(|name, _| predicate(name));
1024 self.async_tools.retain(|name, _| predicate(name));
1025 self.listen_tools.retain(|name, _| predicate(name));
1026 }
1027
1028 #[must_use]
1030 pub fn to_llm_tools(&self) -> Vec<llm::Tool> {
1031 let mut tools: Vec<_> = self
1032 .tools
1033 .values()
1034 .map(|tool| llm::Tool {
1035 name: tool.name_str().to_string(),
1036 description: tool.description().to_string(),
1037 input_schema: tool.input_schema(),
1038 })
1039 .collect();
1040
1041 tools.extend(self.async_tools.values().map(|tool| llm::Tool {
1042 name: tool.name_str().to_string(),
1043 description: tool.description().to_string(),
1044 input_schema: tool.input_schema(),
1045 }));
1046
1047 tools.extend(self.listen_tools.values().map(|tool| llm::Tool {
1048 name: tool.name_str().to_string(),
1049 description: tool.description().to_string(),
1050 input_schema: tool.input_schema(),
1051 }));
1052
1053 tools
1054 }
1055}
1056
// Unit tests for name conversion and for `ToolRegistry` registration, lookup,
// filtering and LLM tool export.
#[cfg(test)]
mod tests {
    use super::*;

    /// Tool-name enum used by the mock tools; serializes as `"mock_tool"` /
    /// `"another_tool"` because of `rename_all = "snake_case"`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
    #[serde(rename_all = "snake_case")]
    enum TestToolName {
        MockTool,
        AnotherTool,
    }

    impl ToolName for TestToolName {}

    /// Minimal `Tool` that echoes the `message` field of its input.
    struct MockTool;

    impl Tool<()> for MockTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::MockTool
        }

        fn display_name(&self) -> &'static str {
            "Mock Tool"
        }

        fn description(&self) -> &'static str {
            "A mock tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                }
            })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
            // Falls back to "no message" when the field is absent or not a string.
            let message = input
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message");
            Ok(ToolResult::success(format!("Received: {message}")))
        }
    }

    // An enum name converts to its snake_case string and parses back to the
    // same variant.
    #[test]
    fn test_tool_name_serialization() {
        let name = TestToolName::MockTool;
        assert_eq!(tool_name_to_string(&name), "mock_tool");

        let parsed: TestToolName = tool_name_from_str("mock_tool").unwrap();
        assert_eq!(parsed, TestToolName::MockTool);
    }

    // A dynamic name converts to the wrapped string unchanged.
    #[test]
    fn test_dynamic_tool_name() {
        let name = DynamicToolName::new("my_mcp_tool");
        assert_eq!(tool_name_to_string(&name), "my_mcp_tool");
        assert_eq!(name.as_str(), "my_mcp_tool");
    }

    // A registered tool is counted and can be looked up by its string name.
    #[test]
    fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        assert_eq!(registry.len(), 1);
        assert!(registry.get("mock_tool").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    // The LLM export contains one entry per registered tool, named by string.
    #[test]
    fn test_to_llm_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        let llm_tools = registry.to_llm_tools();
        assert_eq!(llm_tools.len(), 1);
        assert_eq!(llm_tools[0].name, "mock_tool");
    }

    /// Second mock `Tool`, registered under a different name than `MockTool`.
    struct AnotherTool;

    impl Tool<()> for AnotherTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::AnotherTool
        }

        fn display_name(&self) -> &'static str {
            "Another Tool"
        }

        fn description(&self) -> &'static str {
            "Another tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, _input: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("Done"))
        }
    }

    // `filter` keeps only the tools whose name satisfies the predicate.
    #[test]
    fn test_filter_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);

        assert_eq!(registry.len(), 2);

        registry.filter(|name| name != "mock_tool");

        assert_eq!(registry.len(), 1);
        assert!(registry.get("mock_tool").is_none());
        assert!(registry.get("another_tool").is_some());
    }

    // An always-true predicate removes nothing.
    #[test]
    fn test_filter_tools_keep_all() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);

        registry.filter(|_| true);

        assert_eq!(registry.len(), 2);
    }

    // An always-false predicate empties the registry.
    #[test]
    fn test_filter_tools_remove_all() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);

        registry.filter(|_| false);

        assert!(registry.is_empty());
    }

    // The erased wrapper forwards `display_name` to the wrapped tool.
    #[test]
    fn test_display_name() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        let tool = registry.get("mock_tool").unwrap();
        assert_eq!(tool.display_name(), "Mock Tool");
    }

    /// Mock listen/execute tool. It reuses `TestToolName::MockTool`, so it is
    /// registered under the same string name as `MockTool`, in the separate
    /// listen-tool map.
    struct ListenMockTool;

    impl ListenExecuteTool<()> for ListenMockTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::MockTool
        }

        fn display_name(&self) -> &'static str {
            "Listen Mock Tool"
        }

        fn description(&self) -> &'static str {
            "A listen/execute mock tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        // Yields a single `Ready` update and then ends.
        fn listen(
            &self,
            _ctx: &ToolContext<()>,
            _input: Value,
        ) -> impl futures::Stream<Item = ListenToolUpdate> + Send {
            futures::stream::iter(vec![ListenToolUpdate::Ready {
                operation_id: "op_1".to_string(),
                revision: 1,
                message: "ready".to_string(),
                snapshot: serde_json::json!({"ok": true}),
                expires_at: None,
            }])
        }

        async fn execute(
            &self,
            _ctx: &ToolContext<()>,
            _operation_id: &str,
            _expected_revision: u64,
        ) -> Result<ToolResult> {
            Ok(ToolResult::success("Executed"))
        }
    }

    // A listen tool is counted by `len` and is visible through the listen
    // lookups.
    #[test]
    fn test_listen_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register_listen(ListenMockTool);

        assert_eq!(registry.len(), 1);
        assert!(registry.get_listen("mock_tool").is_some());
        assert!(registry.is_listen("mock_tool"));
    }
}