1use crate::events::{AgentEvent, AgentEventEnvelope, SequenceCounter};
37use crate::llm;
38use crate::types::{ToolOutcome, ToolResult, ToolTier};
39use anyhow::Result;
40use async_trait::async_trait;
41use futures::Stream;
42use serde::{Deserialize, Serialize, de::DeserializeOwned};
43use serde_json::Value;
44use std::collections::HashMap;
45use std::future::Future;
46use std::marker::PhantomData;
47use std::pin::Pin;
48use std::sync::Arc;
49use time::OffsetDateTime;
50use tokio::sync::mpsc;
51use tokio_util::sync::CancellationToken;
52
/// Marker trait for types that can be used as a tool's name.
///
/// Names are converted to and from plain strings through serde
/// (`tool_name_to_string` / `tool_name_from_str`), so implementors must be
/// serializable and deserializable. They must also be `Send + Sync + 'static`.
pub trait ToolName: Send + Sync + Serialize + DeserializeOwned + 'static {}
76
77#[must_use]
82pub fn tool_name_to_string<N: ToolName>(name: &N) -> String {
83 serde_json::to_string(name)
84 .unwrap_or_else(|_| "\"<unknown_tool>\"".to_string())
85 .trim_matches('"')
86 .to_string()
87}
88
89pub fn tool_name_from_str<N: ToolName>(s: &str) -> Result<N, serde_json::Error> {
94 serde_json::from_str(&format!("\"{s}\""))
95}
96
/// Names of the primitive tools defined by this crate.
///
/// Serialized in `snake_case` (e.g. `MultiEdit` -> `multi_edit`,
/// `AskUser` -> `ask_user`). That serialized form is what
/// `tool_name_to_string` produces for these names.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PrimitiveToolName {
    Read,
    Write,
    Edit,
    MultiEdit,
    Bash,
    Glob,
    Grep,
    NotebookRead,
    NotebookEdit,
    TodoRead,
    TodoWrite,
    AskUser,
    LinkFetch,
    WebSearch,
}

impl ToolName for PrimitiveToolName {}
118
119#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
121#[serde(transparent)]
122pub struct DynamicToolName(String);
123
124impl DynamicToolName {
125 #[must_use]
126 pub fn new(name: impl Into<String>) -> Self {
127 Self(name.into())
128 }
129
130 #[must_use]
131 pub fn as_str(&self) -> &str {
132 &self.0
133 }
134}
135
136impl ToolName for DynamicToolName {}
137
/// Marker trait for the typed stage carried by `ToolStatus::Progress`.
///
/// Stages are rendered to strings with `stage_to_string` when a `ToolStatus`
/// is converted into an `ErasedToolStatus`, so implementors must be
/// serializable. They must also be `Clone + Send + Sync + 'static`.
pub trait ProgressStage: Clone + Send + Sync + Serialize + DeserializeOwned + 'static {}
161
162#[must_use]
169pub fn stage_to_string<S: ProgressStage>(stage: &S) -> String {
170 serde_json::to_string(stage)
171 .expect("ProgressStage must serialize to string")
172 .trim_matches('"')
173 .to_string()
174}
175
/// Status item yielded by the stream from `AsyncTool::check_status`.
///
/// Converted into `ErasedToolStatus` (with the stage rendered through
/// `stage_to_string`) when the tool is used through the type-erased
/// registry. Only `Serialize` is derived for this typed form.
#[derive(Clone, Debug, Serialize)]
pub enum ToolStatus<S: ProgressStage> {
    /// Intermediate progress: a typed stage, a message, and optional
    /// free-form JSON data.
    Progress {
        stage: S,
        message: String,
        data: Option<serde_json::Value>,
    },

    /// Terminal status carrying the successful result.
    Completed(ToolResult),

    /// Terminal status carrying the failure result.
    Failed(ToolResult),
}
192
/// Type-erased counterpart of `ToolStatus`, with the stage stored as a string.
///
/// Yielded by `ErasedAsyncTool::check_status_stream`. Produced from a
/// `ToolStatus` via its `From` impl.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ErasedToolStatus {
    /// Intermediate progress. `stage` is the output of `stage_to_string`.
    Progress {
        stage: String,
        message: String,
        data: Option<serde_json::Value>,
    },
    /// Terminal status carrying the successful result.
    Completed(ToolResult),
    /// Terminal status carrying the failure result.
    Failed(ToolResult),
}
207
208#[derive(Clone, Debug, Serialize, Deserialize)]
213pub enum ListenToolUpdate {
214 Listening {
216 operation_id: String,
218 revision: u64,
220 message: String,
222 snapshot: Option<serde_json::Value>,
224 #[serde(with = "time::serde::rfc3339::option")]
226 expires_at: Option<OffsetDateTime>,
227 },
228
229 Ready {
231 operation_id: String,
233 revision: u64,
235 message: String,
237 snapshot: serde_json::Value,
239 #[serde(with = "time::serde::rfc3339::option")]
241 expires_at: Option<OffsetDateTime>,
242 },
243
244 Invalidated {
246 operation_id: String,
248 message: String,
250 recoverable: bool,
252 },
253}
254
255#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
257pub enum ListenStopReason {
258 UserRejected,
260 Blocked,
262 StreamDisconnected,
264 StreamEnded,
266}
267
268impl<S: ProgressStage> From<ToolStatus<S>> for ErasedToolStatus {
269 fn from(status: ToolStatus<S>) -> Self {
270 match status {
271 ToolStatus::Progress {
272 stage,
273 message,
274 data,
275 } => Self::Progress {
276 stage: stage_to_string(&stage),
277 message,
278 data,
279 },
280 ToolStatus::Completed(r) => Self::Completed(r),
281 ToolStatus::Failed(r) => Self::Failed(r),
282 }
283 }
284}
285
/// Per-invocation context handed to every tool method that executes work.
///
/// Wraps the caller-supplied application value plus optional plumbing: a
/// metadata map, an event channel with its sequence counter, and a
/// cancellation token. Each piece is attached through a `with_*` builder.
pub struct ToolContext<Ctx> {
    /// Application value supplied to `ToolContext::new`.
    pub app: Ctx,
    /// Arbitrary key/value metadata added via `with_metadata`.
    pub metadata: HashMap<String, Value>,
    /// Event sink used by `emit_event`; `None` until `with_event_tx` is called.
    event_tx: Option<mpsc::Sender<AgentEventEnvelope>>,
    /// Sequence counter paired with `event_tx` (set together by `with_event_tx`).
    event_seq: Option<SequenceCounter>,
    /// Cancellation token attached via `with_cancel_token`, if any.
    cancel_token: Option<CancellationToken>,
}
299
300impl<Ctx> ToolContext<Ctx> {
301 #[must_use]
302 pub fn new(app: Ctx) -> Self {
303 Self {
304 app,
305 metadata: HashMap::new(),
306 event_tx: None,
307 event_seq: None,
308 cancel_token: None,
309 }
310 }
311
312 #[must_use]
313 pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
314 self.metadata.insert(key.into(), value);
315 self
316 }
317
318 #[must_use]
321 pub fn with_event_tx(
322 mut self,
323 tx: mpsc::Sender<AgentEventEnvelope>,
324 seq: SequenceCounter,
325 ) -> Self {
326 self.event_tx = Some(tx);
327 self.event_seq = Some(seq);
328 self
329 }
330
331 pub fn emit_event(&self, event: AgentEvent) {
339 if let Some((tx, seq)) = self.event_tx.as_ref().zip(self.event_seq.as_ref()) {
340 let envelope = AgentEventEnvelope::wrap(event, seq);
341 let _ = tx.try_send(envelope);
342 }
343 }
344
345 #[must_use]
350 pub fn event_tx(&self) -> Option<mpsc::Sender<AgentEventEnvelope>> {
351 self.event_tx.clone()
352 }
353
354 #[must_use]
359 pub fn event_seq(&self) -> Option<SequenceCounter> {
360 self.event_seq.clone()
361 }
362
363 #[must_use]
365 pub fn with_cancel_token(mut self, token: CancellationToken) -> Self {
366 self.cancel_token = Some(token);
367 self
368 }
369
370 #[must_use]
375 pub fn cancel_token(&self) -> Option<CancellationToken> {
376 self.cancel_token.clone()
377 }
378}
379
/// A tool that runs to completion in a single `execute` call.
///
/// Registered with `ToolRegistry::register`, which wraps it in a type-erased
/// adapter keyed by the string form of `name()`.
pub trait Tool<Ctx>: Send + Sync {
    /// Strongly typed name of this tool.
    type Name: ToolName;

    /// Returns the tool's name. Its `tool_name_to_string` form is the registry
    /// key and the name reported by `ToolRegistry::to_llm_tools`.
    fn name(&self) -> Self::Name;

    /// Human-readable name for display.
    fn display_name(&self) -> &'static str;

    /// Description forwarded by `ToolRegistry::to_llm_tools`.
    fn description(&self) -> &'static str;

    /// Schema of the tool's input, forwarded as `input_schema` by
    /// `ToolRegistry::to_llm_tools`.
    fn input_schema(&self) -> Value;

    /// Tier classification of the tool; defaults to `ToolTier::Observe`.
    fn tier(&self) -> ToolTier {
        ToolTier::Observe
    }

    /// Runs the tool on `input`. The returned future must be `Send`.
    fn execute(
        &self,
        ctx: &ToolContext<Ctx>,
        input: Value,
    ) -> impl Future<Output = Result<ToolResult>> + Send;
}
426
/// A tool whose `execute` returns a `ToolOutcome`. Its progress is observed
/// separately through `check_status`.
///
/// Registered with `ToolRegistry::register_async`.
pub trait AsyncTool<Ctx>: Send + Sync {
    /// Strongly typed name of this tool.
    type Name: ToolName;
    /// Typed progress stage reported in `ToolStatus::Progress`.
    type Stage: ProgressStage;

    /// Returns the tool's name (stringified with `tool_name_to_string` on
    /// registration).
    fn name(&self) -> Self::Name;

    /// Human-readable name for display.
    fn display_name(&self) -> &'static str;

    /// Description forwarded by `ToolRegistry::to_llm_tools`.
    fn description(&self) -> &'static str;

    /// Schema of the tool's input, forwarded by `ToolRegistry::to_llm_tools`.
    fn input_schema(&self) -> Value;

    /// Tier classification of the tool; defaults to `ToolTier::Observe`.
    fn tier(&self) -> ToolTier {
        ToolTier::Observe
    }

    /// Starts the tool on `input`. The returned future must be `Send`.
    fn execute(
        &self,
        ctx: &ToolContext<Ctx>,
        input: Value,
    ) -> impl Future<Output = Result<ToolOutcome>> + Send;

    /// Streams status updates for `operation_id`. Presumably the id comes from
    /// the `ToolOutcome` of `execute`; confirm with implementors.
    fn check_status(
        &self,
        ctx: &ToolContext<Ctx>,
        operation_id: &str,
    ) -> impl Stream<Item = ToolStatus<Self::Stage>> + Send;
}
518
/// A two-phase tool with separate listen and execute steps.
///
/// `listen` streams `ListenToolUpdate`s for an operation. `execute` then acts
/// on a specific `operation_id` at an expected revision, and `cancel` stops it
/// with a `ListenStopReason`.
///
/// Registered with `ToolRegistry::register_listen`.
pub trait ListenExecuteTool<Ctx>: Send + Sync {
    /// Strongly typed name of this tool.
    type Name: ToolName;

    /// Returns the tool's name (stringified with `tool_name_to_string` on
    /// registration).
    fn name(&self) -> Self::Name;

    /// Human-readable name for display.
    fn display_name(&self) -> &'static str;

    /// Description forwarded by `ToolRegistry::to_llm_tools`.
    fn description(&self) -> &'static str;

    /// Schema of the tool's input, forwarded by `ToolRegistry::to_llm_tools`.
    fn input_schema(&self) -> Value;

    /// Tier classification; unlike the other tool traits this defaults to
    /// `ToolTier::Confirm`.
    fn tier(&self) -> ToolTier {
        ToolTier::Confirm
    }

    /// Starts listening for `input` and streams updates. The stream must be `Send`.
    fn listen(
        &self,
        ctx: &ToolContext<Ctx>,
        input: Value,
    ) -> impl Stream<Item = ListenToolUpdate> + Send;

    /// Executes `operation_id`. `expected_revision` is presumably checked
    /// against the latest update's `revision` — confirm with implementors.
    fn execute(
        &self,
        ctx: &ToolContext<Ctx>,
        operation_id: &str,
        expected_revision: u64,
    ) -> impl Future<Output = Result<ToolResult>> + Send;

    /// Stops `operation_id` for `reason`. The default implementation does
    /// nothing and returns `Ok(())`.
    fn cancel(
        &self,
        _ctx: &ToolContext<Ctx>,
        _operation_id: &str,
        _reason: ListenStopReason,
    ) -> impl Future<Output = Result<()>> + Send {
        async { Ok(()) }
    }
}
585
/// Object-safe form of `Tool`, stored in `ToolRegistry` as
/// `Arc<dyn ErasedTool<Ctx>>`.
///
/// The associated `Name` type is replaced by a string name, and `execute` is
/// made object-safe with `async_trait`.
#[async_trait]
pub trait ErasedTool<Ctx>: Send + Sync {
    /// String form of the tool's name (the registry key).
    fn name_str(&self) -> &str;
    /// Human-readable name for display.
    fn display_name(&self) -> &'static str;
    /// Description forwarded to the LLM.
    fn description(&self) -> &'static str;
    /// Schema of the tool's input.
    fn input_schema(&self) -> Value;
    /// Tier classification of the tool.
    fn tier(&self) -> ToolTier;
    /// Runs the tool on `input`.
    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult>;
}
617
618struct ToolWrapper<T, Ctx>
620where
621 T: Tool<Ctx>,
622{
623 inner: T,
624 name_cache: String,
625 _marker: PhantomData<Ctx>,
626}
627
628impl<T, Ctx> ToolWrapper<T, Ctx>
629where
630 T: Tool<Ctx>,
631{
632 fn new(tool: T) -> Self {
633 let name_cache = tool_name_to_string(&tool.name());
634 Self {
635 inner: tool,
636 name_cache,
637 _marker: PhantomData,
638 }
639 }
640}
641
642#[async_trait]
643impl<T, Ctx> ErasedTool<Ctx> for ToolWrapper<T, Ctx>
644where
645 T: Tool<Ctx> + 'static,
646 Ctx: Send + Sync + 'static,
647{
648 fn name_str(&self) -> &str {
649 &self.name_cache
650 }
651
652 fn display_name(&self) -> &'static str {
653 self.inner.display_name()
654 }
655
656 fn description(&self) -> &'static str {
657 self.inner.description()
658 }
659
660 fn input_schema(&self) -> Value {
661 self.inner.input_schema()
662 }
663
664 fn tier(&self) -> ToolTier {
665 self.inner.tier()
666 }
667
668 async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
669 self.inner.execute(ctx, input).await
670 }
671}
672
/// Object-safe form of `AsyncTool`, stored in `ToolRegistry` as
/// `Arc<dyn ErasedAsyncTool<Ctx>>`.
///
/// The typed `Stage` is erased: status streams yield `ErasedToolStatus`.
#[async_trait]
pub trait ErasedAsyncTool<Ctx>: Send + Sync {
    /// String form of the tool's name (the registry key).
    fn name_str(&self) -> &str;
    /// Human-readable name for display.
    fn display_name(&self) -> &'static str;
    /// Description forwarded to the LLM.
    fn description(&self) -> &'static str;
    /// Schema of the tool's input.
    fn input_schema(&self) -> Value;
    /// Tier classification of the tool.
    fn tier(&self) -> ToolTier;
    /// Starts the tool on `input`.
    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome>;
    /// Boxed stream of stage-erased status updates for `operation_id`.
    /// Borrows `self`, `ctx` and `operation_id` for `'a`.
    fn check_status_stream<'a>(
        &'a self,
        ctx: &'a ToolContext<Ctx>,
        operation_id: &'a str,
    ) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>>;
}
702
703struct AsyncToolWrapper<T, Ctx>
705where
706 T: AsyncTool<Ctx>,
707{
708 inner: T,
709 name_cache: String,
710 _marker: PhantomData<Ctx>,
711}
712
713impl<T, Ctx> AsyncToolWrapper<T, Ctx>
714where
715 T: AsyncTool<Ctx>,
716{
717 fn new(tool: T) -> Self {
718 let name_cache = tool_name_to_string(&tool.name());
719 Self {
720 inner: tool,
721 name_cache,
722 _marker: PhantomData,
723 }
724 }
725}
726
727#[async_trait]
728impl<T, Ctx> ErasedAsyncTool<Ctx> for AsyncToolWrapper<T, Ctx>
729where
730 T: AsyncTool<Ctx> + 'static,
731 Ctx: Send + Sync + 'static,
732{
733 fn name_str(&self) -> &str {
734 &self.name_cache
735 }
736
737 fn display_name(&self) -> &'static str {
738 self.inner.display_name()
739 }
740
741 fn description(&self) -> &'static str {
742 self.inner.description()
743 }
744
745 fn input_schema(&self) -> Value {
746 self.inner.input_schema()
747 }
748
749 fn tier(&self) -> ToolTier {
750 self.inner.tier()
751 }
752
753 async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome> {
754 self.inner.execute(ctx, input).await
755 }
756
757 fn check_status_stream<'a>(
758 &'a self,
759 ctx: &'a ToolContext<Ctx>,
760 operation_id: &'a str,
761 ) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>> {
762 use futures::StreamExt;
763 let stream = self.inner.check_status(ctx, operation_id);
764 Box::pin(stream.map(ErasedToolStatus::from))
765 }
766}
767
/// Object-safe form of `ListenExecuteTool`, stored in `ToolRegistry` as
/// `Arc<dyn ErasedListenTool<Ctx>>`.
#[async_trait]
pub trait ErasedListenTool<Ctx>: Send + Sync {
    /// String form of the tool's name (the registry key).
    fn name_str(&self) -> &str;
    /// Human-readable name for display.
    fn display_name(&self) -> &'static str;
    /// Description forwarded to the LLM.
    fn description(&self) -> &'static str;
    /// Schema of the tool's input.
    fn input_schema(&self) -> Value;
    /// Tier classification of the tool.
    fn tier(&self) -> ToolTier;
    /// Boxed stream of listen updates for `input`. Borrows `self` and `ctx` for `'a`.
    fn listen_stream<'a>(
        &'a self,
        ctx: &'a ToolContext<Ctx>,
        input: Value,
    ) -> Pin<Box<dyn Stream<Item = ListenToolUpdate> + Send + 'a>>;
    /// Executes `operation_id` at `expected_revision`.
    async fn execute(
        &self,
        ctx: &ToolContext<Ctx>,
        operation_id: &str,
        expected_revision: u64,
    ) -> Result<ToolResult>;
    /// Stops `operation_id` for `reason`.
    async fn cancel(
        &self,
        ctx: &ToolContext<Ctx>,
        operation_id: &str,
        reason: ListenStopReason,
    ) -> Result<()>;
}
806
807struct ListenToolWrapper<T, Ctx>
809where
810 T: ListenExecuteTool<Ctx>,
811{
812 inner: T,
813 name_cache: String,
814 _marker: PhantomData<Ctx>,
815}
816
817impl<T, Ctx> ListenToolWrapper<T, Ctx>
818where
819 T: ListenExecuteTool<Ctx>,
820{
821 fn new(tool: T) -> Self {
822 let name_cache = tool_name_to_string(&tool.name());
823 Self {
824 inner: tool,
825 name_cache,
826 _marker: PhantomData,
827 }
828 }
829}
830
831#[async_trait]
832impl<T, Ctx> ErasedListenTool<Ctx> for ListenToolWrapper<T, Ctx>
833where
834 T: ListenExecuteTool<Ctx> + 'static,
835 Ctx: Send + Sync + 'static,
836{
837 fn name_str(&self) -> &str {
838 &self.name_cache
839 }
840
841 fn display_name(&self) -> &'static str {
842 self.inner.display_name()
843 }
844
845 fn description(&self) -> &'static str {
846 self.inner.description()
847 }
848
849 fn input_schema(&self) -> Value {
850 self.inner.input_schema()
851 }
852
853 fn tier(&self) -> ToolTier {
854 self.inner.tier()
855 }
856
857 fn listen_stream<'a>(
858 &'a self,
859 ctx: &'a ToolContext<Ctx>,
860 input: Value,
861 ) -> Pin<Box<dyn Stream<Item = ListenToolUpdate> + Send + 'a>> {
862 let stream = self.inner.listen(ctx, input);
863 Box::pin(stream)
864 }
865
866 async fn execute(
867 &self,
868 ctx: &ToolContext<Ctx>,
869 operation_id: &str,
870 expected_revision: u64,
871 ) -> Result<ToolResult> {
872 self.inner
873 .execute(ctx, operation_id, expected_revision)
874 .await
875 }
876
877 async fn cancel(
878 &self,
879 ctx: &ToolContext<Ctx>,
880 operation_id: &str,
881 reason: ListenStopReason,
882 ) -> Result<()> {
883 self.inner.cancel(ctx, operation_id, reason).await
884 }
885}
886
/// Collection of registered tools, keyed by each tool's string name.
///
/// Tools are kept in three separate maps by kind (plain, async,
/// listen/execute), each holding reference-counted, type-erased handles.
pub struct ToolRegistry<Ctx> {
    /// Tools registered via `register`.
    tools: HashMap<String, Arc<dyn ErasedTool<Ctx>>>,
    /// Tools registered via `register_async`.
    async_tools: HashMap<String, Arc<dyn ErasedAsyncTool<Ctx>>>,
    /// Tools registered via `register_listen`.
    listen_tools: HashMap<String, Arc<dyn ErasedListenTool<Ctx>>>,
}
899
900impl<Ctx> Clone for ToolRegistry<Ctx> {
901 fn clone(&self) -> Self {
902 Self {
903 tools: self.tools.clone(),
904 async_tools: self.async_tools.clone(),
905 listen_tools: self.listen_tools.clone(),
906 }
907 }
908}
909
910impl<Ctx: Send + Sync + 'static> Default for ToolRegistry<Ctx> {
911 fn default() -> Self {
912 Self::new()
913 }
914}
915
916impl<Ctx: Send + Sync + 'static> ToolRegistry<Ctx> {
917 #[must_use]
918 pub fn new() -> Self {
919 Self {
920 tools: HashMap::new(),
921 async_tools: HashMap::new(),
922 listen_tools: HashMap::new(),
923 }
924 }
925
926 pub fn register<T>(&mut self, tool: T) -> &mut Self
931 where
932 T: Tool<Ctx> + 'static,
933 {
934 let wrapper = ToolWrapper::new(tool);
935 let name = wrapper.name_str().to_string();
936 self.tools.insert(name, Arc::new(wrapper));
937 self
938 }
939
940 pub fn register_async<T>(&mut self, tool: T) -> &mut Self
945 where
946 T: AsyncTool<Ctx> + 'static,
947 {
948 let wrapper = AsyncToolWrapper::new(tool);
949 let name = wrapper.name_str().to_string();
950 self.async_tools.insert(name, Arc::new(wrapper));
951 self
952 }
953
954 pub fn register_listen<T>(&mut self, tool: T) -> &mut Self
959 where
960 T: ListenExecuteTool<Ctx> + 'static,
961 {
962 let wrapper = ListenToolWrapper::new(tool);
963 let name = wrapper.name_str().to_string();
964 self.listen_tools.insert(name, Arc::new(wrapper));
965 self
966 }
967
968 #[must_use]
970 pub fn get(&self, name: &str) -> Option<&Arc<dyn ErasedTool<Ctx>>> {
971 self.tools.get(name)
972 }
973
974 #[must_use]
976 pub fn get_async(&self, name: &str) -> Option<&Arc<dyn ErasedAsyncTool<Ctx>>> {
977 self.async_tools.get(name)
978 }
979
980 #[must_use]
982 pub fn get_listen(&self, name: &str) -> Option<&Arc<dyn ErasedListenTool<Ctx>>> {
983 self.listen_tools.get(name)
984 }
985
986 #[must_use]
988 pub fn is_async(&self, name: &str) -> bool {
989 self.async_tools.contains_key(name)
990 }
991
992 #[must_use]
994 pub fn is_listen(&self, name: &str) -> bool {
995 self.listen_tools.contains_key(name)
996 }
997
998 pub fn all(&self) -> impl Iterator<Item = &Arc<dyn ErasedTool<Ctx>>> {
1000 self.tools.values()
1001 }
1002
1003 pub fn all_async(&self) -> impl Iterator<Item = &Arc<dyn ErasedAsyncTool<Ctx>>> {
1005 self.async_tools.values()
1006 }
1007
1008 pub fn all_listen(&self) -> impl Iterator<Item = &Arc<dyn ErasedListenTool<Ctx>>> {
1010 self.listen_tools.values()
1011 }
1012
1013 #[must_use]
1015 pub fn len(&self) -> usize {
1016 self.tools.len() + self.async_tools.len() + self.listen_tools.len()
1017 }
1018
1019 #[must_use]
1021 pub fn is_empty(&self) -> bool {
1022 self.tools.is_empty() && self.async_tools.is_empty() && self.listen_tools.is_empty()
1023 }
1024
1025 pub fn filter<F>(&mut self, predicate: F)
1037 where
1038 F: Fn(&str) -> bool,
1039 {
1040 self.tools.retain(|name, _| predicate(name));
1041 self.async_tools.retain(|name, _| predicate(name));
1042 self.listen_tools.retain(|name, _| predicate(name));
1043 }
1044
1045 #[must_use]
1047 pub fn to_llm_tools(&self) -> Vec<llm::Tool> {
1048 let mut tools: Vec<_> = self
1049 .tools
1050 .values()
1051 .map(|tool| llm::Tool {
1052 name: tool.name_str().to_string(),
1053 description: tool.description().to_string(),
1054 input_schema: tool.input_schema(),
1055 })
1056 .collect();
1057
1058 tools.extend(self.async_tools.values().map(|tool| llm::Tool {
1059 name: tool.name_str().to_string(),
1060 description: tool.description().to_string(),
1061 input_schema: tool.input_schema(),
1062 }));
1063
1064 tools.extend(self.listen_tools.values().map(|tool| llm::Tool {
1065 name: tool.name_str().to_string(),
1066 description: tool.description().to_string(),
1067 input_schema: tool.input_schema(),
1068 }));
1069
1070 tools
1071 }
1072}
1073
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
    #[serde(rename_all = "snake_case")]
    enum TestToolName {
        MockTool,
        AnotherTool,
    }

    impl ToolName for TestToolName {}

    /// Sync tool that echoes the `message` field of its input.
    struct MockTool;

    impl Tool<()> for MockTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::MockTool
        }

        fn display_name(&self) -> &'static str {
            "Mock Tool"
        }

        fn description(&self) -> &'static str {
            "A mock tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                }
            })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
            let message = input
                .get("message")
                .and_then(Value::as_str)
                .unwrap_or("no message");
            Ok(ToolResult::success(format!("Received: {message}")))
        }
    }

    /// Second sync tool, used together with `MockTool` in the filter tests.
    struct AnotherTool;

    impl Tool<()> for AnotherTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::AnotherTool
        }

        fn display_name(&self) -> &'static str {
            "Another Tool"
        }

        fn description(&self) -> &'static str {
            "Another tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, _input: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("Done"))
        }
    }

    /// Listen/execute tool whose listen stream yields a single `Ready` update.
    struct ListenMockTool;

    impl ListenExecuteTool<()> for ListenMockTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::MockTool
        }

        fn display_name(&self) -> &'static str {
            "Listen Mock Tool"
        }

        fn description(&self) -> &'static str {
            "A listen/execute mock tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        fn listen(
            &self,
            _ctx: &ToolContext<()>,
            _input: Value,
        ) -> impl futures::Stream<Item = ListenToolUpdate> + Send {
            let ready = ListenToolUpdate::Ready {
                operation_id: "op_1".to_string(),
                revision: 1,
                message: "ready".to_string(),
                snapshot: serde_json::json!({"ok": true}),
                expires_at: None,
            };
            futures::stream::iter(std::iter::once(ready))
        }

        async fn execute(
            &self,
            _ctx: &ToolContext<()>,
            _operation_id: &str,
            _expected_revision: u64,
        ) -> Result<ToolResult> {
            Ok(ToolResult::success("Executed"))
        }
    }

    /// Registry containing `MockTool` and `AnotherTool`.
    fn registry_with_two_tools() -> ToolRegistry<()> {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool).register(AnotherTool);
        registry
    }

    #[test]
    fn test_tool_name_serialization() {
        assert_eq!(tool_name_to_string(&TestToolName::MockTool), "mock_tool");
        assert_eq!(
            tool_name_from_str::<TestToolName>("mock_tool").unwrap(),
            TestToolName::MockTool
        );
    }

    #[test]
    fn test_dynamic_tool_name() {
        let dynamic = DynamicToolName::new("my_mcp_tool");
        assert_eq!(dynamic.as_str(), "my_mcp_tool");
        assert_eq!(tool_name_to_string(&dynamic), "my_mcp_tool");
    }

    #[test]
    fn test_tool_registry() {
        let mut reg: ToolRegistry<()> = ToolRegistry::new();
        reg.register(MockTool);

        assert_eq!(reg.len(), 1);
        assert!(reg.get("nonexistent").is_none());
        assert!(reg.get("mock_tool").is_some());
    }

    #[test]
    fn test_to_llm_tools() {
        let mut reg: ToolRegistry<()> = ToolRegistry::new();
        reg.register(MockTool);

        let defs = reg.to_llm_tools();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "mock_tool");
    }

    #[test]
    fn test_filter_tools() {
        let mut reg = registry_with_two_tools();
        assert_eq!(reg.len(), 2);

        reg.filter(|name| name != "mock_tool");

        assert_eq!(reg.len(), 1);
        assert!(reg.get("another_tool").is_some());
        assert!(reg.get("mock_tool").is_none());
    }

    #[test]
    fn test_filter_tools_keep_all() {
        let mut reg = registry_with_two_tools();
        reg.filter(|_| true);
        assert_eq!(reg.len(), 2);
    }

    #[test]
    fn test_filter_tools_remove_all() {
        let mut reg = registry_with_two_tools();
        reg.filter(|_| false);
        assert!(reg.is_empty());
    }

    #[test]
    fn test_display_name() {
        let mut reg: ToolRegistry<()> = ToolRegistry::new();
        reg.register(MockTool);

        let handle = reg.get("mock_tool").expect("mock_tool should be registered");
        assert_eq!(handle.display_name(), "Mock Tool");
    }

    #[test]
    fn test_listen_tool_registry() {
        let mut reg: ToolRegistry<()> = ToolRegistry::new();
        reg.register_listen(ListenMockTool);

        assert_eq!(reg.len(), 1);
        assert!(reg.is_listen("mock_tool"));
        assert!(reg.get_listen("mock_tool").is_some());
    }
}