Skip to main content

ash_rpc/
registry.rs

1//! Method registry for organizing and dispatching JSON-RPC methods.
2//!
3//! ## Usage
4//!
5//! ### Basic Usage (Runtime Dispatch)
6//! Create method implementations using the `JsonRPCMethod` trait:
7//!
8//! ```rust
9//! use ash_rpc::*;
10//!
11//! struct PingMethod;
12//!
13//! #[async_trait::async_trait]
14//! impl JsonRPCMethod for PingMethod {
15//!     fn method_name(&self) -> &'static str { "ping" }
16//!     
17//!     async fn call(
18//!         &self,
19//!         _params: Option<serde_json::Value>,
20//!         id: Option<RequestId>,
21//!     ) -> Response {
22//!         rpc_success!("pong", id)
23//!     }
24//! }
25//!
26//! let registry = MethodRegistry::new(register_methods![PingMethod]);
27//! ```
28//!
29//! ### Optimized Usage (Compile-time Dispatch)
//! To avoid boxing and dynamic dispatch when the set of methods is known at compile time, use the `dispatch_call!` macro:
31//!
32//! ```text
33//! // In your handler function:
34//! async fn handle_call(method_name: &str, params: Option<serde_json::Value>, id: Option<RequestId>) -> Response {
35//!     dispatch_call!(method_name, params, id => PingMethod, EchoMethod, CalculatorMethod)
36//! }
37//! ```
38
39use crate::builders::{ErrorBuilder, ResponseBuilder};
40use crate::traits::{
41    Handler, JsonRPCMethod, MessageProcessor, OpenApiServer, OpenApiSpec, ProcessorCapabilities,
42};
43use crate::types::{Message, Notification, Request, RequestId, Response, error_codes};
44use std::sync::Arc;
45
46/// Method registry with optional authentication
47pub struct MethodRegistry {
48    methods: Vec<Box<dyn JsonRPCMethod>>,
49    auth_policy: Option<Arc<dyn crate::auth::AuthPolicy>>,
50}
51
/// Build a `Vec<Box<dyn JsonRPCMethod>>` from a list of method implementations.
///
/// Each argument is an expression whose type implements [`JsonRPCMethod`]. It is
/// boxed and coerced to a trait object, preserving argument order. The result is
/// suitable for [`MethodRegistry::new`]. A trailing comma is accepted.
///
/// The trait and std items are referenced by absolute path (`$crate`, `::std`),
/// so the call site does not need to import `JsonRPCMethod` for the expansion
/// to compile.
#[macro_export]
macro_rules! register_methods {
    ($($method:expr),* $(,)?) => {
        ::std::vec![
            $(
                ::std::boxed::Box::new($method)
                    as ::std::boxed::Box<dyn $crate::JsonRPCMethod>
            ),*
        ]
    };
}
63
/// Dispatch a call to one of a fixed list of [`JsonRPCMethod`] implementations.
///
/// Syntax: `dispatch_call!(name, params, id => MethodA, MethodB, ...)`.
///
/// Each listed method expression is evaluated in order and its `method_name()`
/// is compared with `name`. On the first match, the macro **returns** the awaited
/// result of that method's `call` from the *enclosing* function. If no listed
/// method matches, the macro evaluates to a `METHOD_NOT_FOUND` error
/// [`Response`] carrying `id`. Because of the early `return`, use it as the tail
/// expression of an `async fn` that returns [`Response`].
///
/// Crate items are referenced through `$crate`, and the trait is imported inside
/// the expansion. The call site therefore does not need its own imports of
/// `JsonRPCMethod`, `ResponseBuilder`, `ErrorBuilder` or `error_codes`.
#[macro_export]
macro_rules! dispatch_call {
    ($method_name:expr, $params:expr, $id:expr => $($method:expr),* $(,)?) => {
        {
            // Bring the trait into scope anonymously so `.method_name()` / `.call()`
            // resolve (with normal auto-deref) without requiring a caller import.
            #[allow(unused_imports)]
            use $crate::JsonRPCMethod as _;

            // Evaluate the requested name once. Borrow it so that passing a field
            // such as `request.method` does not move it out of the caller's value.
            let __dispatch_name = &$method_name;

            $(
                let __dispatch_method = $method;
                if *__dispatch_name == __dispatch_method.method_name() {
                    return __dispatch_method.call($params, $id).await;
                }
            )*

            // No listed method matched the requested name.
            $crate::ResponseBuilder::new()
                .error(
                    $crate::ErrorBuilder::new(
                        $crate::error_codes::METHOD_NOT_FOUND,
                        "Method not found",
                    )
                    .build(),
                )
                .id($id)
                .build()
        }
    };
}
86
87impl MethodRegistry {
88    /// Create a new method registry with the given method implementations
89    pub fn new(methods: Vec<Box<dyn JsonRPCMethod>>) -> Self {
90        tracing::debug!(method_count = methods.len(), "registry created");
91        Self {
92            methods,
93            auth_policy: None,
94        }
95    }
96
97    /// Create an empty registry
98    #[must_use]
99    pub fn empty() -> Self {
100        Self {
101            methods: Vec::new(),
102            auth_policy: None,
103        }
104    }
105
106    /// Set an authentication/authorization policy
107    ///
108    /// When set, `can_access` will be checked before executing methods.
109    /// The user implements ALL auth logic in the trait.
110    ///
111    /// # Example
112    /// ```text
113    /// let registry = MethodRegistry::new(methods)
114    ///     .with_auth(MyAuthPolicy::new());
115    /// ```
116    #[must_use]
117    pub fn with_auth<A: crate::auth::AuthPolicy + 'static>(mut self, policy: A) -> Self {
118        self.auth_policy = Some(Arc::new(policy));
119        self
120    }
121
122    /// Add a method implementation to the registry
123    #[must_use]
124    pub fn add_method(mut self, method: Box<dyn JsonRPCMethod>) -> Self {
125        tracing::trace!("adding method to registry");
126        self.methods.push(method);
127        self
128    }
129
130    /// Call a registered method asynchronously using compile-time dispatch
131    /// Note: This method should typically be replaced by using the `dispatch_methods`! macro directly
132    /// for better compile-time optimization
133    pub async fn call(
134        &self,
135        method_name: &str,
136        params: Option<serde_json::Value>,
137        id: Option<RequestId>,
138    ) -> Response {
139        self.call_with_context(
140            method_name,
141            params,
142            id,
143            &crate::auth::ConnectionContext::default(),
144        )
145        .await
146    }
147
148    /// Call a registered method with authentication context
149    ///
150    /// Use this when you have connection context from your transport layer.
151    pub async fn call_with_context(
152        &self,
153        method_name: &str,
154        params: Option<serde_json::Value>,
155        id: Option<RequestId>,
156        ctx: &crate::auth::ConnectionContext,
157    ) -> Response {
158        // Check authentication if policy is set
159        if let Some(auth) = &self.auth_policy
160            && !auth.can_access(method_name, params.as_ref(), ctx)
161        {
162            tracing::warn!(
163                method = %method_name,
164                remote_addr = ?ctx.remote_addr,
165                "access denied by auth policy"
166            );
167            return auth.unauthorized_error(method_name);
168        }
169
170        // Fallback to runtime dispatch if compile-time dispatch is not used
171        for method in &self.methods {
172            if method.method_name() == method_name {
173                tracing::debug!(method = %method_name, "calling method");
174                return method.call(params, id).await;
175            }
176        }
177
178        tracing::warn!(method = %method_name, "method not found");
179        ResponseBuilder::new()
180            .error(ErrorBuilder::new(error_codes::METHOD_NOT_FOUND, "Method not found").build())
181            .id(id)
182            .build()
183    }
184
185    /// Check if a method is registered
186    #[must_use]
187    pub fn has_method(&self, method_name: &str) -> bool {
188        self.methods.iter().any(|m| m.method_name() == method_name)
189    }
190
191    /// Get list of all registered methods
192    #[must_use]
193    pub fn get_methods(&self) -> Vec<String> {
194        self.methods
195            .iter()
196            .map(|m| m.method_name().to_owned())
197            .collect()
198    }
199
200    /// Get the number of registered methods
201    #[must_use]
202    pub fn method_count(&self) -> usize {
203        self.methods.len()
204    }
205
206    /// Generate `OpenAPI` specification for all registered methods
207    pub fn generate_openapi_spec(&self, title: &str, version: &str) -> OpenApiSpec {
208        tracing::debug!(method_count = self.methods.len(), "generating openapi spec");
209        let mut spec = OpenApiSpec::new(title, version);
210
211        for method in &self.methods {
212            let method_spec = method.openapi_components();
213            spec.add_method(method_spec);
214        }
215
216        spec
217    }
218
219    /// Generate `OpenAPI` specification with custom info and servers
220    #[must_use]
221    pub fn generate_openapi_spec_with_info(
222        &self,
223        title: &str,
224        version: &str,
225        description: Option<&str>,
226        servers: Vec<OpenApiServer>,
227    ) -> OpenApiSpec {
228        let mut spec = self.generate_openapi_spec(title, version);
229
230        if let Some(desc) = description {
231            spec.info.description = Some(desc.to_owned());
232        }
233
234        for server in servers {
235            spec.add_server(server);
236        }
237
238        spec
239    }
240
241    /// Export `OpenAPI` spec as JSON string
242    ///
243    /// # Errors
244    /// Returns error if JSON serialization fails
245    pub fn export_openapi_json(
246        &self,
247        title: &str,
248        version: &str,
249    ) -> Result<String, serde_json::Error> {
250        let spec = self.generate_openapi_spec(title, version);
251        serde_json::to_string_pretty(&spec)
252    }
253}
254
255impl Default for MethodRegistry {
256    fn default() -> Self {
257        Self::empty()
258    }
259}
260
261#[async_trait::async_trait]
262impl MessageProcessor for MethodRegistry {
263    async fn process_message(&self, message: Message) -> Option<Response> {
264        match message {
265            Message::Request(request) => {
266                tracing::trace!(method = %request.method, correlation_id = ?request.correlation_id, "processing request");
267                let response = self.call(&request.method, request.params, request.id).await;
268                Some(response)
269            }
270            Message::Notification(notification) => {
271                tracing::trace!(method = %notification.method, "processing notification");
272                let _ = self
273                    .call(&notification.method, notification.params, None)
274                    .await;
275                None
276            }
277            Message::Response(_) => None,
278        }
279    }
280
281    async fn process_batch(&self, messages: Vec<Message>) -> Vec<Response> {
282        let capabilities = self.get_capabilities();
283
284        // Validate batch size
285        if let Some(max_size) = capabilities.max_batch_size
286            && messages.len() > max_size
287        {
288            tracing::warn!(
289                batch_size = messages.len(),
290                max_batch_size = max_size,
291                "batch size limit exceeded"
292            );
293            return vec![crate::Response::error(
294                crate::ErrorBuilder::new(
295                    crate::error_codes::INVALID_REQUEST,
296                    format!("Batch size {} exceeds maximum {}", messages.len(), max_size),
297                )
298                .build(),
299                None,
300            )];
301        }
302
303        tracing::debug!(batch_size = messages.len(), "processing batch");
304        let mut results = Vec::new();
305        for msg in messages {
306            if let Some(response) = self.process_message(msg).await {
307                results.push(response);
308            }
309        }
310        results
311    }
312
313    fn get_capabilities(&self) -> ProcessorCapabilities {
314        ProcessorCapabilities {
315            supports_batch: true,
316            supports_notifications: true,
317            max_batch_size: Some(100),
318            max_request_size: Some(1024 * 1024), // 1 MB
319            request_timeout_secs: Some(30),
320            supported_versions: vec!["2.0".to_owned()],
321        }
322    }
323}
324
325#[async_trait::async_trait]
326impl Handler for MethodRegistry {
327    async fn handle_request(&self, request: Request) -> Response {
328        self.call(&request.method, request.params, request.id).await
329    }
330
331    async fn handle_notification(&self, notification: Notification) {
332        let _ = self
333            .call(&notification.method, notification.params, None)
334            .await;
335    }
336
337    fn supports_method(&self, method: &str) -> bool {
338        self.has_method(method)
339    }
340
341    fn get_supported_methods(&self) -> Vec<String> {
342        self.get_methods()
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal method whose result echoes its own name back.
    struct TestMethod {
        name: &'static str,
    }

    #[async_trait::async_trait]
    impl JsonRPCMethod for TestMethod {
        fn method_name(&self) -> &'static str {
            self.name
        }

        async fn call(
            &self,
            _params: Option<serde_json::Value>,
            id: Option<RequestId>,
        ) -> Response {
            ResponseBuilder::new()
                .success(json!({"method": self.name}))
                .id(id)
                .build()
        }
    }

    /// Auth policy that allows exactly the listed method names.
    struct TestAuthPolicy {
        allowed_methods: Vec<String>,
    }

    impl crate::auth::AuthPolicy for TestAuthPolicy {
        fn can_access(
            &self,
            method: &str,
            _params: Option<&serde_json::Value>,
            _ctx: &crate::auth::ConnectionContext,
        ) -> bool {
            // Compare by reference; no need to allocate a String per check.
            self.allowed_methods.iter().any(|allowed| allowed == method)
        }

        fn unauthorized_error(&self, method: &str) -> Response {
            ResponseBuilder::new()
                .error(
                    ErrorBuilder::new(
                        crate::error_codes::INTERNAL_ERROR,
                        format!("Access denied for method '{method}'"),
                    )
                    .build(),
                )
                .build()
        }
    }

    #[tokio::test]
    async fn test_registry_without_auth() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod {
            name: "test_method",
        })]);

        let response = registry.call("test_method", None, Some(json!(1))).await;
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    #[tokio::test]
    async fn test_registry_with_auth_allowed() {
        let auth = TestAuthPolicy {
            allowed_methods: vec!["allowed_method".to_string()],
        };

        let registry = MethodRegistry::new(vec![Box::new(TestMethod {
            name: "allowed_method",
        })])
        .with_auth(auth);

        let response = registry.call("allowed_method", None, Some(json!(1))).await;
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    #[tokio::test]
    async fn test_registry_with_auth_denied() {
        let auth = TestAuthPolicy {
            allowed_methods: vec!["allowed_method".to_string()],
        };

        let registry = MethodRegistry::new(vec![Box::new(TestMethod {
            name: "blocked_method",
        })])
        .with_auth(auth);

        let response = registry.call("blocked_method", None, Some(json!(1))).await;
        assert!(response.result.is_none());
        assert!(response.error.is_some());

        let error = response.error.unwrap();
        assert_eq!(error.code, crate::error_codes::INTERNAL_ERROR);
        assert!(error.message.contains("Access denied"));
    }

    #[tokio::test]
    async fn test_registry_allow_all() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "any_method" })])
            .with_auth(crate::auth::AllowAll);

        let response = registry.call("any_method", None, Some(json!(1))).await;
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    #[tokio::test]
    async fn test_registry_deny_all() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "any_method" })])
            .with_auth(crate::auth::DenyAll);

        let response = registry.call("any_method", None, Some(json!(1))).await;
        assert!(response.result.is_none());
        assert!(response.error.is_some());
    }

    #[test]
    fn test_registry_empty() {
        let registry = MethodRegistry::empty();
        assert_eq!(registry.method_count(), 0);
    }

    #[test]
    fn test_registry_default() {
        let registry = MethodRegistry::default();
        assert_eq!(registry.method_count(), 0);
    }

    #[test]
    fn test_registry_add_method() {
        let registry = MethodRegistry::empty()
            .add_method(Box::new(TestMethod { name: "method1" }))
            .add_method(Box::new(TestMethod { name: "method2" }));

        assert_eq!(registry.method_count(), 2);
        assert!(registry.has_method("method1"));
        assert!(registry.has_method("method2"));
    }

    #[test]
    fn test_registry_has_method() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "exists" })]);

        assert!(registry.has_method("exists"));
        assert!(!registry.has_method("not_exists"));
    }

    #[test]
    fn test_registry_get_methods() {
        let registry = MethodRegistry::new(vec![
            Box::new(TestMethod { name: "method1" }),
            Box::new(TestMethod { name: "method2" }),
            Box::new(TestMethod { name: "method3" }),
        ]);

        let methods = registry.get_methods();
        assert_eq!(methods.len(), 3);
        assert!(methods.contains(&"method1".to_string()));
        assert!(methods.contains(&"method2".to_string()));
        assert!(methods.contains(&"method3".to_string()));
    }

    #[test]
    fn test_registry_method_count() {
        let registry = MethodRegistry::new(vec![
            Box::new(TestMethod { name: "m1" }),
            Box::new(TestMethod { name: "m2" }),
        ]);

        assert_eq!(registry.method_count(), 2);
    }

    #[tokio::test]
    async fn test_registry_call_method_not_found() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "exists" })]);

        let response = registry.call("not_exists", None, Some(json!(1))).await;
        assert!(response.error.is_some());
        let error = response.error.unwrap();
        assert_eq!(error.code, error_codes::METHOD_NOT_FOUND);
    }

    #[tokio::test]
    async fn test_registry_call_with_params() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let params = json!({"key": "value"});
        let response = registry.call("test", Some(params), Some(json!(1))).await;
        assert!(response.result.is_some());
    }

    #[tokio::test]
    async fn test_registry_call_with_context() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let ctx = crate::auth::ConnectionContext::default();
        let response = registry
            .call_with_context("test", None, Some(json!(1)), &ctx)
            .await;
        assert!(response.result.is_some());
    }

    #[tokio::test]
    async fn test_registry_call_with_context_auth_denied() {
        let auth = TestAuthPolicy {
            allowed_methods: vec!["allowed".to_string()],
        };

        let registry =
            MethodRegistry::new(vec![Box::new(TestMethod { name: "blocked" })]).with_auth(auth);

        let ctx = crate::auth::ConnectionContext::default();
        let response = registry
            .call_with_context("blocked", None, Some(json!(1)), &ctx)
            .await;
        assert!(response.error.is_some());
    }

    #[test]
    fn test_registry_generate_openapi_spec() {
        let registry = MethodRegistry::new(vec![
            Box::new(TestMethod { name: "method1" }),
            Box::new(TestMethod { name: "method2" }),
        ]);

        let spec = registry.generate_openapi_spec("Test API", "1.0.0");
        assert_eq!(spec.info.title, "Test API");
        assert_eq!(spec.info.version, "1.0.0");
        assert_eq!(spec.methods.len(), 2);
        assert!(spec.methods.contains_key("method1"));
        assert!(spec.methods.contains_key("method2"));
    }

    #[test]
    fn test_registry_generate_openapi_spec_with_info() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let servers = vec![OpenApiServer::new("http://localhost:8080")];

        let spec = registry.generate_openapi_spec_with_info(
            "API",
            "2.0.0",
            Some("Test description"),
            servers,
        );

        assert_eq!(spec.info.title, "API");
        assert_eq!(spec.info.version, "2.0.0");
        assert_eq!(spec.info.description, Some("Test description".to_string()));
        assert_eq!(spec.servers.len(), 1);
        assert_eq!(spec.servers[0].url, "http://localhost:8080");
    }

    #[test]
    fn test_registry_export_openapi_json() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let json_str = registry.export_openapi_json("API", "1.0").unwrap();
        assert!(json_str.contains("\"title\": \"API\""));
        assert!(json_str.contains("\"version\": \"1.0\""));
        assert!(json_str.contains("\"openapi\": \"3.0.3\""));
    }

    #[tokio::test]
    async fn test_registry_message_processor_request() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let request = Request {
            jsonrpc: "2.0".to_string(),
            method: "test".to_string(),
            params: None,
            id: Some(json!(1)),
            correlation_id: None,
        };

        let response = registry.process_message(Message::Request(request)).await;
        assert!(response.is_some());
        assert!(response.unwrap().result.is_some());
    }

    #[tokio::test]
    async fn test_registry_message_processor_notification() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let notification = Notification {
            jsonrpc: "2.0".to_string(),
            method: "test".to_string(),
            params: None,
        };

        let response = registry
            .process_message(Message::Notification(notification))
            .await;
        assert!(response.is_none());
    }

    #[tokio::test]
    async fn test_registry_message_processor_response() {
        let registry = MethodRegistry::new(vec![]);

        let response_msg = Response {
            jsonrpc: "2.0".to_string(),
            result: Some(json!(42)),
            error: None,
            id: Some(json!(1)),
            correlation_id: None,
        };

        let response = registry
            .process_message(Message::Response(response_msg))
            .await;
        assert!(response.is_none());
    }

    #[tokio::test]
    async fn test_registry_process_batch() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let messages = vec![
            Message::Request(Request {
                jsonrpc: "2.0".to_string(),
                method: "test".to_string(),
                params: None,
                id: Some(json!(1)),
                correlation_id: None,
            }),
            Message::Request(Request {
                jsonrpc: "2.0".to_string(),
                method: "test".to_string(),
                params: None,
                id: Some(json!(2)),
                correlation_id: None,
            }),
        ];

        let responses = registry.process_batch(messages).await;
        assert_eq!(responses.len(), 2);
    }

    #[test]
    fn test_register_methods_macro() {
        let methods = register_methods![TestMethod { name: "m1" }, TestMethod { name: "m2" },];
        assert_eq!(methods.len(), 2);
    }
}