Skip to main content

ash_rpc/
registry.rs

//! Method registry for organizing and dispatching JSON-RPC methods.
//!
//! ## Usage
//!
//! ### Basic usage (runtime dispatch)
//! Implement each method with the `JsonRPCMethod` trait and hand the implementations to a
//! [`MethodRegistry`], which looks methods up by name when a call arrives:
//!
//! ```rust
//! use ash_rpc::*;
//!
//! struct PingMethod;
//!
//! #[async_trait::async_trait]
//! impl JsonRPCMethod for PingMethod {
//!     fn method_name(&self) -> &'static str { "ping" }
//!
//!     async fn call(
//!         &self,
//!         _params: Option<serde_json::Value>,
//!         id: Option<RequestId>,
//!     ) -> Response {
//!         rpc_success!("pong", id)
//!     }
//! }
//!
//! let registry = MethodRegistry::new(register_methods![PingMethod]);
//! ```
//!
//! ### Static dispatch over a fixed method set
//! When the set of methods is known at compile time, the `dispatch_call!` macro compares the
//! requested name against each listed method in order and calls the first match directly,
//! without boxing the methods as `dyn JsonRPCMethod`. Because the macro `return`s the
//! matching method's response, it must be used inside an `async fn` that returns `Response`:
//!
//! ```text
//! // In your handler function:
//! async fn handle_call(method_name: &str, params: Option<serde_json::Value>, id: Option<RequestId>) -> Response {
//!     dispatch_call!(method_name, params, id => PingMethod, EchoMethod, CalculatorMethod)
//! }
//! ```
38
39use crate::builders::*;
40use crate::traits::*;
41use crate::types::*;
42use std::sync::Arc;
43
44/// Method registry with optional authentication
45pub struct MethodRegistry {
46    methods: Vec<Box<dyn JsonRPCMethod>>,
47    auth_policy: Option<Arc<dyn crate::auth::AuthPolicy>>,
48}
49
/// Build a `Vec<Box<dyn JsonRPCMethod>>` from a list of method implementations,
/// ready to pass to `MethodRegistry::new`.
///
/// Each argument is an expression whose type implements `JsonRPCMethod`; it is boxed and
/// coerced to a trait object, preserving argument order. A trailing comma is accepted and
/// an empty invocation yields an empty `Vec`. The trait is referenced through `$crate`, so
/// callers do not need to import it.
#[macro_export]
macro_rules! register_methods {
    ($($method:expr),* $(,)?) => {
        vec![
            $(
                Box::new($method) as Box<dyn $crate::JsonRPCMethod>
            ),*
        ]
    };
}
61
/// Dispatch a call across a fixed list of method implementations without a registry.
///
/// Expands to a chain of `if` comparisons, one per listed method, evaluated in order.
/// Each candidate is called as its concrete type, so no `Box<dyn JsonRPCMethod>` is
/// involved. The first method whose `method_name()` equals `$method_name` is invoked with
/// `$params` and `$id`, and its response is `return`ed from the *enclosing* function.
/// The macro must therefore be used inside an `async` function that returns `Response`.
/// If no listed method matches, the macro evaluates to a `METHOD_NOT_FOUND` error response
/// carrying `$id`.
///
/// `$method_name` is evaluated exactly once. Each `$method` expression is evaluated only
/// until a match is found. All crate items are referenced through `$crate`, so callers
/// need not import `JsonRPCMethod`, `Response`, `ErrorBuilder` or `error_codes`.
#[macro_export]
macro_rules! dispatch_call {
    ($method_name:expr, $params:expr, $id:expr => $($method:expr),* $(,)?) => {{
        // Make the trait's methods callable without requiring the caller to import it.
        #[allow(unused_imports)]
        use $crate::JsonRPCMethod as _;

        // Evaluate the requested name once rather than once per candidate.
        let requested = &$method_name;
        $(
            let candidate = $method;
            if *requested == candidate.method_name() {
                return candidate.call($params, $id).await;
            }
        )*

        // No candidate matched.
        $crate::Response::error(
            $crate::ErrorBuilder::new($crate::error_codes::METHOD_NOT_FOUND, "Method not found")
                .build(),
            $id,
        )
    }};
}
84
85impl MethodRegistry {
86    /// Create a new method registry with the given method implementations
87    pub fn new(methods: Vec<Box<dyn JsonRPCMethod>>) -> Self {
88        tracing::debug!(method_count = methods.len(), "registry created");
89        Self {
90            methods,
91            auth_policy: None,
92        }
93    }
94
95    /// Create an empty registry
96    pub fn empty() -> Self {
97        Self {
98            methods: Vec::new(),
99            auth_policy: None,
100        }
101    }
102
103    /// Set an authentication/authorization policy
104    ///
105    /// When set, `can_access` will be checked before executing methods.
106    /// The user implements ALL auth logic in the trait.
107    ///
108    /// # Example
109    /// ```text
110    /// let registry = MethodRegistry::new(methods)
111    ///     .with_auth(MyAuthPolicy::new());
112    /// ```
113    pub fn with_auth<A: crate::auth::AuthPolicy + 'static>(mut self, policy: A) -> Self {
114        self.auth_policy = Some(Arc::new(policy));
115        self
116    }
117
118    /// Add a method implementation to the registry
119    pub fn add_method(mut self, method: Box<dyn JsonRPCMethod>) -> Self {
120        tracing::trace!("adding method to registry");
121        self.methods.push(method);
122        self
123    }
124
125    /// Call a registered method asynchronously using compile-time dispatch
126    /// Note: This method should typically be replaced by using the dispatch_methods! macro directly
127    /// for better compile-time optimization
128    pub async fn call(
129        &self,
130        method_name: &str,
131        params: Option<serde_json::Value>,
132        id: Option<RequestId>,
133    ) -> Response {
134        self.call_with_context(
135            method_name,
136            params,
137            id,
138            &crate::auth::ConnectionContext::default(),
139        )
140        .await
141    }
142
143    /// Call a registered method with authentication context
144    ///
145    /// Use this when you have connection context from your transport layer.
146    pub async fn call_with_context(
147        &self,
148        method_name: &str,
149        params: Option<serde_json::Value>,
150        id: Option<RequestId>,
151        ctx: &crate::auth::ConnectionContext,
152    ) -> Response {
153        // Check authentication if policy is set
154        if let Some(auth) = &self.auth_policy
155            && !auth.can_access(method_name, params.as_ref(), ctx)
156        {
157            tracing::warn!(
158                method = %method_name,
159                remote_addr = ?ctx.remote_addr,
160                "access denied by auth policy"
161            );
162            return auth.unauthorized_error(method_name);
163        }
164
165        // Fallback to runtime dispatch if compile-time dispatch is not used
166        for method in &self.methods {
167            if method.method_name() == method_name {
168                tracing::debug!(method = %method_name, "calling method");
169                return method.call(params, id).await;
170            }
171        }
172
173        tracing::warn!(method = %method_name, "method not found");
174        ResponseBuilder::new()
175            .error(ErrorBuilder::new(error_codes::METHOD_NOT_FOUND, "Method not found").build())
176            .id(id)
177            .build()
178    }
179
180    /// Check if a method is registered
181    pub fn has_method(&self, method_name: &str) -> bool {
182        self.methods.iter().any(|m| m.method_name() == method_name)
183    }
184
185    /// Get list of all registered methods
186    pub fn get_methods(&self) -> Vec<String> {
187        self.methods
188            .iter()
189            .map(|m| m.method_name().to_string())
190            .collect()
191    }
192
193    /// Get the number of registered methods
194    pub fn method_count(&self) -> usize {
195        self.methods.len()
196    }
197
198    /// Generate OpenAPI specification for all registered methods
199    pub fn generate_openapi_spec(&self, title: &str, version: &str) -> OpenApiSpec {
200        tracing::debug!(method_count = self.methods.len(), "generating openapi spec");
201        let mut spec = OpenApiSpec::new(title, version);
202
203        for method in &self.methods {
204            let method_spec = method.openapi_components();
205            spec.add_method(method_spec);
206        }
207
208        spec
209    }
210
211    /// Generate OpenAPI specification with custom info and servers
212    pub fn generate_openapi_spec_with_info(
213        &self,
214        title: &str,
215        version: &str,
216        description: Option<&str>,
217        servers: Vec<OpenApiServer>,
218    ) -> OpenApiSpec {
219        let mut spec = self.generate_openapi_spec(title, version);
220
221        if let Some(desc) = description {
222            spec.info.description = Some(desc.to_string());
223        }
224
225        for server in servers {
226            spec.add_server(server);
227        }
228
229        spec
230    }
231
232    /// Export OpenAPI spec as JSON string
233    pub fn export_openapi_json(
234        &self,
235        title: &str,
236        version: &str,
237    ) -> Result<String, serde_json::Error> {
238        let spec = self.generate_openapi_spec(title, version);
239        serde_json::to_string_pretty(&spec)
240    }
241}
242
243impl Default for MethodRegistry {
244    fn default() -> Self {
245        Self::empty()
246    }
247}
248
249#[async_trait::async_trait]
250impl MessageProcessor for MethodRegistry {
251    async fn process_message(&self, message: Message) -> Option<Response> {
252        match message {
253            Message::Request(request) => {
254                tracing::trace!(method = %request.method, correlation_id = ?request.correlation_id, "processing request");
255                let response = self.call(&request.method, request.params, request.id).await;
256                Some(response)
257            }
258            Message::Notification(notification) => {
259                tracing::trace!(method = %notification.method, "processing notification");
260                let _ = self
261                    .call(&notification.method, notification.params, None)
262                    .await;
263                None
264            }
265            Message::Response(_) => None,
266        }
267    }
268
269    async fn process_batch(&self, messages: Vec<Message>) -> Vec<Response> {
270        let capabilities = self.get_capabilities();
271
272        // Validate batch size
273        if let Some(max_size) = capabilities.max_batch_size
274            && messages.len() > max_size
275        {
276            tracing::warn!(
277                batch_size = messages.len(),
278                max_batch_size = max_size,
279                "batch size limit exceeded"
280            );
281            return vec![crate::Response::error(
282                crate::ErrorBuilder::new(
283                    crate::error_codes::INVALID_REQUEST,
284                    format!("Batch size {} exceeds maximum {}", messages.len(), max_size),
285                )
286                .build(),
287                None,
288            )];
289        }
290
291        tracing::debug!(batch_size = messages.len(), "processing batch");
292        let mut results = Vec::new();
293        for msg in messages {
294            if let Some(response) = self.process_message(msg).await {
295                results.push(response);
296            }
297        }
298        results
299    }
300
301    fn get_capabilities(&self) -> ProcessorCapabilities {
302        ProcessorCapabilities {
303            supports_batch: true,
304            supports_notifications: true,
305            max_batch_size: Some(100),
306            max_request_size: Some(1024 * 1024), // 1 MB
307            request_timeout_secs: Some(30),
308            supported_versions: vec!["2.0".to_string()],
309        }
310    }
311}
312
313#[async_trait::async_trait]
314impl Handler for MethodRegistry {
315    async fn handle_request(&self, request: Request) -> Response {
316        self.call(&request.method, request.params, request.id).await
317    }
318
319    async fn handle_notification(&self, notification: Notification) {
320        let _ = self
321            .call(&notification.method, notification.params, None)
322            .await;
323    }
324
325    fn supports_method(&self, method: &str) -> bool {
326        self.has_method(method)
327    }
328
329    fn get_supported_methods(&self) -> Vec<String> {
330        self.get_methods()
331    }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal method that answers with `{"method": <its own name>}`.
    struct TestMethod {
        name: &'static str,
    }

    #[async_trait::async_trait]
    impl JsonRPCMethod for TestMethod {
        fn method_name(&self) -> &'static str {
            self.name
        }

        async fn call(
            &self,
            _params: Option<serde_json::Value>,
            id: Option<RequestId>,
        ) -> Response {
            ResponseBuilder::new()
                .success(json!({"method": self.name}))
                .id(id)
                .build()
        }
    }

    /// Auth policy that admits only the methods listed in `allowed_methods`.
    struct TestAuthPolicy {
        allowed_methods: Vec<String>,
    }

    impl crate::auth::AuthPolicy for TestAuthPolicy {
        fn can_access(
            &self,
            method: &str,
            _params: Option<&serde_json::Value>,
            _ctx: &crate::auth::ConnectionContext,
        ) -> bool {
            self.allowed_methods.iter().any(|allowed| allowed == method)
        }

        fn unauthorized_error(&self, method: &str) -> Response {
            ResponseBuilder::new()
                .error(
                    ErrorBuilder::new(
                        crate::error_codes::INTERNAL_ERROR,
                        format!("Access denied for method '{}'", method),
                    )
                    .build(),
                )
                .build()
        }
    }

    /// Build a request message for `method` with a numeric id.
    fn request(method: &str, id: i64) -> Message {
        Message::Request(Request {
            jsonrpc: "2.0".to_string(),
            method: method.to_string(),
            params: None,
            id: Some(json!(id)),
            correlation_id: None,
        })
    }

    /// Build a notification message for `method`.
    fn notification(method: &str) -> Message {
        Message::Notification(Notification {
            jsonrpc: "2.0".to_string(),
            method: method.to_string(),
            params: None,
        })
    }

    #[tokio::test]
    async fn test_registry_without_auth() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod {
            name: "test_method",
        })]);

        let response = registry.call("test_method", None, Some(json!(1))).await;
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    #[tokio::test]
    async fn test_registry_with_auth_allowed() {
        let auth = TestAuthPolicy {
            allowed_methods: vec!["allowed_method".to_string()],
        };

        let registry = MethodRegistry::new(vec![Box::new(TestMethod {
            name: "allowed_method",
        })])
        .with_auth(auth);

        let response = registry.call("allowed_method", None, Some(json!(1))).await;
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    #[tokio::test]
    async fn test_registry_with_auth_denied() {
        let auth = TestAuthPolicy {
            allowed_methods: vec!["allowed_method".to_string()],
        };

        let registry = MethodRegistry::new(vec![Box::new(TestMethod {
            name: "blocked_method",
        })])
        .with_auth(auth);

        let response = registry.call("blocked_method", None, Some(json!(1))).await;
        assert!(response.result.is_none());
        assert!(response.error.is_some());

        let error = response.error.unwrap();
        assert_eq!(error.code, crate::error_codes::INTERNAL_ERROR);
        assert!(error.message.contains("Access denied"));
    }

    #[tokio::test]
    async fn test_registry_allow_all() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "any_method" })])
            .with_auth(crate::auth::AllowAll);

        let response = registry.call("any_method", None, Some(json!(1))).await;
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    #[tokio::test]
    async fn test_registry_deny_all() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "any_method" })])
            .with_auth(crate::auth::DenyAll);

        let response = registry.call("any_method", None, Some(json!(1))).await;
        assert!(response.result.is_none());
        assert!(response.error.is_some());
    }

    #[test]
    fn test_registry_empty() {
        let registry = MethodRegistry::empty();
        assert_eq!(registry.method_count(), 0);
    }

    #[test]
    fn test_registry_default() {
        let registry = MethodRegistry::default();
        assert_eq!(registry.method_count(), 0);
    }

    #[test]
    fn test_registry_add_method() {
        let registry = MethodRegistry::empty()
            .add_method(Box::new(TestMethod { name: "method1" }))
            .add_method(Box::new(TestMethod { name: "method2" }));

        assert_eq!(registry.method_count(), 2);
        assert!(registry.has_method("method1"));
        assert!(registry.has_method("method2"));
    }

    #[test]
    fn test_registry_has_method() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "exists" })]);

        assert!(registry.has_method("exists"));
        assert!(!registry.has_method("not_exists"));
    }

    #[test]
    fn test_registry_get_methods() {
        let registry = MethodRegistry::new(vec![
            Box::new(TestMethod { name: "method1" }),
            Box::new(TestMethod { name: "method2" }),
            Box::new(TestMethod { name: "method3" }),
        ]);

        let methods = registry.get_methods();
        assert_eq!(methods.len(), 3);
        assert!(methods.contains(&"method1".to_string()));
        assert!(methods.contains(&"method2".to_string()));
        assert!(methods.contains(&"method3".to_string()));
    }

    #[test]
    fn test_registry_method_count() {
        let registry = MethodRegistry::new(vec![
            Box::new(TestMethod { name: "m1" }),
            Box::new(TestMethod { name: "m2" }),
        ]);

        assert_eq!(registry.method_count(), 2);
    }

    #[tokio::test]
    async fn test_registry_call_method_not_found() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "exists" })]);

        let response = registry.call("not_exists", None, Some(json!(1))).await;
        // The error response must echo the request id so the client can correlate it.
        assert_eq!(response.id, Some(json!(1)));
        let error = response.error.expect("unknown method must yield an error");
        assert_eq!(error.code, error_codes::METHOD_NOT_FOUND);
    }

    #[tokio::test]
    async fn test_registry_call_with_params() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let params = json!({"key": "value"});
        let response = registry.call("test", Some(params), Some(json!(1))).await;
        assert!(response.result.is_some());
    }

    #[tokio::test]
    async fn test_registry_call_with_context() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let ctx = crate::auth::ConnectionContext::default();
        let response = registry
            .call_with_context("test", None, Some(json!(1)), &ctx)
            .await;
        assert!(response.result.is_some());
    }

    #[tokio::test]
    async fn test_registry_call_with_context_auth_denied() {
        let auth = TestAuthPolicy {
            allowed_methods: vec!["allowed".to_string()],
        };

        let registry =
            MethodRegistry::new(vec![Box::new(TestMethod { name: "blocked" })]).with_auth(auth);

        let ctx = crate::auth::ConnectionContext::default();
        let response = registry
            .call_with_context("blocked", None, Some(json!(1)), &ctx)
            .await;
        assert!(response.error.is_some());
    }

    #[test]
    fn test_registry_generate_openapi_spec() {
        let registry = MethodRegistry::new(vec![
            Box::new(TestMethod { name: "method1" }),
            Box::new(TestMethod { name: "method2" }),
        ]);

        let spec = registry.generate_openapi_spec("Test API", "1.0.0");
        assert_eq!(spec.info.title, "Test API");
        assert_eq!(spec.info.version, "1.0.0");
        assert_eq!(spec.methods.len(), 2);
        assert!(spec.methods.contains_key("method1"));
        assert!(spec.methods.contains_key("method2"));
    }

    #[test]
    fn test_registry_generate_openapi_spec_with_info() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let servers = vec![OpenApiServer::new("http://localhost:8080")];

        let spec = registry.generate_openapi_spec_with_info(
            "API",
            "2.0.0",
            Some("Test description"),
            servers,
        );

        assert_eq!(spec.info.title, "API");
        assert_eq!(spec.info.version, "2.0.0");
        assert_eq!(spec.info.description, Some("Test description".to_string()));
        assert_eq!(spec.servers.len(), 1);
        assert_eq!(spec.servers[0].url, "http://localhost:8080");
    }

    #[test]
    fn test_registry_export_openapi_json() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let json_str = registry.export_openapi_json("API", "1.0").unwrap();
        assert!(json_str.contains("\"title\": \"API\""));
        assert!(json_str.contains("\"version\": \"1.0\""));
        assert!(json_str.contains("\"openapi\": \"3.0.3\""));
    }

    #[tokio::test]
    async fn test_registry_message_processor_request() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let response = registry.process_message(request("test", 1)).await;
        assert!(response.is_some());
        assert!(response.unwrap().result.is_some());
    }

    #[tokio::test]
    async fn test_registry_message_processor_notification() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let response = registry.process_message(notification("test")).await;
        assert!(response.is_none());
    }

    #[tokio::test]
    async fn test_registry_message_processor_response() {
        let registry = MethodRegistry::new(vec![]);

        let response_msg = Response {
            jsonrpc: "2.0".to_string(),
            result: Some(json!(42)),
            error: None,
            id: Some(json!(1)),
            correlation_id: None,
        };

        let response = registry
            .process_message(Message::Response(response_msg))
            .await;
        assert!(response.is_none());
    }

    #[tokio::test]
    async fn test_registry_process_batch() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let messages = vec![request("test", 1), request("test", 2)];

        let responses = registry.process_batch(messages).await;
        assert_eq!(responses.len(), 2);
    }

    #[tokio::test]
    async fn test_registry_process_batch_skips_notifications() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let messages = vec![request("test", 1), notification("test"), request("test", 2)];

        let responses = registry.process_batch(messages).await;
        assert_eq!(responses.len(), 2);
        assert_eq!(responses[0].id, Some(json!(1)));
        assert_eq!(responses[1].id, Some(json!(2)));
    }

    #[tokio::test]
    async fn test_registry_process_batch_rejects_oversized_batch() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);
        let limit = registry
            .get_capabilities()
            .max_batch_size
            .expect("registry advertises a batch size limit");

        let messages: Vec<Message> = (0..=limit).map(|_| notification("test")).collect();

        let responses = registry.process_batch(messages).await;
        assert_eq!(responses.len(), 1);
        let error = responses[0]
            .error
            .as_ref()
            .expect("oversized batch must be rejected");
        assert_eq!(error.code, crate::error_codes::INVALID_REQUEST);
        assert!(responses[0].id.is_none());
    }

    #[test]
    fn test_register_methods_macro() {
        let methods = register_methods![TestMethod { name: "m1" }, TestMethod { name: "m2" },];
        assert_eq!(methods.len(), 2);
    }
}