1use crate::builders::*;
40use crate::traits::*;
41use crate::types::*;
42use std::sync::Arc;
43
/// Registry of JSON-RPC methods dispatched by name, with an optional
/// authorization policy that is consulted before every call.
pub struct MethodRegistry {
    /// Registered methods. Lookup is a linear scan comparing `method_name()`,
    /// so when two methods share a name the earlier one handles every call.
    methods: Vec<Box<dyn JsonRPCMethod>>,
    /// Optional access-control policy; `None` means every call is allowed.
    auth_policy: Option<Arc<dyn crate::auth::AuthPolicy>>,
}
49
/// Builds a `Vec<Box<dyn JsonRPCMethod>>` from a comma-separated list of
/// method values, boxing each one (a trailing comma is accepted).
///
/// `JsonRPCMethod` is named unqualified in the expansion, so the trait must be
/// in scope wherever the macro is invoked.
#[macro_export]
macro_rules! register_methods {
    ($($method:expr),* $(,)?) => {{
        let registered: Vec<Box<dyn JsonRPCMethod>> =
            vec![$(Box::new($method) as Box<dyn JsonRPCMethod>),*];
        registered
    }};
}
61
/// Dispatches one call against an inline list of method values.
///
/// Usage: `dispatch_call!(name, params, id => MethodA, MethodB, ...)`.
///
/// Each `$method` expression is constructed in order and its `method_name()`
/// is compared with the requested name. On the first match the macro executes
/// `return <method>.call(params, id).await`. This returns from the *enclosing*
/// async function, so the macro can only be used inside an async fn whose
/// return type is the method's response type. If nothing matches, the block
/// evaluates to a `METHOD_NOT_FOUND` error response carrying `$id`.
///
/// The requested name is evaluated exactly once, by reference, so a `String`
/// argument is not moved. `$params` and `$id` are evaluated only in the
/// branch that is actually taken.
///
/// The expansion names `ResponseBuilder`, `ErrorBuilder` and `error_codes`
/// unqualified, so they must be in scope at the call site.
#[macro_export]
macro_rules! dispatch_call {
    ($method_name:expr, $params:expr, $id:expr => $($method:expr),* $(,)?) => {
        {
            // Evaluate the requested name once instead of once per candidate.
            // Binding a reference keeps owned arguments (e.g. a `String`
            // field) usable by the caller afterwards.
            #[allow(unused_variables)]
            let requested_name = &$method_name;
            $(
                let temp_method = $method;
                if *requested_name == temp_method.method_name() {
                    return temp_method.call($params, $id).await;
                }
            )*

            ResponseBuilder::new()
                .error(ErrorBuilder::new(error_codes::METHOD_NOT_FOUND, "Method not found").build())
                .id($id)
                .build()
        }
    };
}
84
85impl MethodRegistry {
86 pub fn new(methods: Vec<Box<dyn JsonRPCMethod>>) -> Self {
88 tracing::debug!(method_count = methods.len(), "registry created");
89 Self {
90 methods,
91 auth_policy: None,
92 }
93 }
94
95 pub fn empty() -> Self {
97 Self {
98 methods: Vec::new(),
99 auth_policy: None,
100 }
101 }
102
103 pub fn with_auth<A: crate::auth::AuthPolicy + 'static>(mut self, policy: A) -> Self {
114 self.auth_policy = Some(Arc::new(policy));
115 self
116 }
117
118 pub fn add_method(mut self, method: Box<dyn JsonRPCMethod>) -> Self {
120 tracing::trace!("adding method to registry");
121 self.methods.push(method);
122 self
123 }
124
125 pub async fn call(
129 &self,
130 method_name: &str,
131 params: Option<serde_json::Value>,
132 id: Option<RequestId>,
133 ) -> Response {
134 self.call_with_context(
135 method_name,
136 params,
137 id,
138 &crate::auth::ConnectionContext::default(),
139 )
140 .await
141 }
142
143 pub async fn call_with_context(
147 &self,
148 method_name: &str,
149 params: Option<serde_json::Value>,
150 id: Option<RequestId>,
151 ctx: &crate::auth::ConnectionContext,
152 ) -> Response {
153 if let Some(auth) = &self.auth_policy
155 && !auth.can_access(method_name, params.as_ref(), ctx)
156 {
157 tracing::warn!(
158 method = %method_name,
159 remote_addr = ?ctx.remote_addr,
160 "access denied by auth policy"
161 );
162 return auth.unauthorized_error(method_name);
163 }
164
165 for method in &self.methods {
167 if method.method_name() == method_name {
168 tracing::debug!(method = %method_name, "calling method");
169 return method.call(params, id).await;
170 }
171 }
172
173 tracing::warn!(method = %method_name, "method not found");
174 ResponseBuilder::new()
175 .error(ErrorBuilder::new(error_codes::METHOD_NOT_FOUND, "Method not found").build())
176 .id(id)
177 .build()
178 }
179
180 pub fn has_method(&self, method_name: &str) -> bool {
182 self.methods.iter().any(|m| m.method_name() == method_name)
183 }
184
185 pub fn get_methods(&self) -> Vec<String> {
187 self.methods
188 .iter()
189 .map(|m| m.method_name().to_string())
190 .collect()
191 }
192
193 pub fn method_count(&self) -> usize {
195 self.methods.len()
196 }
197
198 pub fn generate_openapi_spec(&self, title: &str, version: &str) -> OpenApiSpec {
200 tracing::debug!(method_count = self.methods.len(), "generating openapi spec");
201 let mut spec = OpenApiSpec::new(title, version);
202
203 for method in &self.methods {
204 let method_spec = method.openapi_components();
205 spec.add_method(method_spec);
206 }
207
208 spec
209 }
210
211 pub fn generate_openapi_spec_with_info(
213 &self,
214 title: &str,
215 version: &str,
216 description: Option<&str>,
217 servers: Vec<OpenApiServer>,
218 ) -> OpenApiSpec {
219 let mut spec = self.generate_openapi_spec(title, version);
220
221 if let Some(desc) = description {
222 spec.info.description = Some(desc.to_string());
223 }
224
225 for server in servers {
226 spec.add_server(server);
227 }
228
229 spec
230 }
231
232 pub fn export_openapi_json(
234 &self,
235 title: &str,
236 version: &str,
237 ) -> Result<String, serde_json::Error> {
238 let spec = self.generate_openapi_spec(title, version);
239 serde_json::to_string_pretty(&spec)
240 }
241}
242
243impl Default for MethodRegistry {
244 fn default() -> Self {
245 Self::empty()
246 }
247}
248
249#[async_trait::async_trait]
250impl MessageProcessor for MethodRegistry {
251 async fn process_message(&self, message: Message) -> Option<Response> {
252 match message {
253 Message::Request(request) => {
254 tracing::trace!(method = %request.method, correlation_id = ?request.correlation_id, "processing request");
255 let response = self.call(&request.method, request.params, request.id).await;
256 Some(response)
257 }
258 Message::Notification(notification) => {
259 tracing::trace!(method = %notification.method, "processing notification");
260 let _ = self
261 .call(¬ification.method, notification.params, None)
262 .await;
263 None
264 }
265 Message::Response(_) => None,
266 }
267 }
268
269 async fn process_batch(&self, messages: Vec<Message>) -> Vec<Response> {
270 let capabilities = self.get_capabilities();
271
272 if let Some(max_size) = capabilities.max_batch_size
274 && messages.len() > max_size
275 {
276 tracing::warn!(
277 batch_size = messages.len(),
278 max_batch_size = max_size,
279 "batch size limit exceeded"
280 );
281 return vec![crate::Response::error(
282 crate::ErrorBuilder::new(
283 crate::error_codes::INVALID_REQUEST,
284 format!("Batch size {} exceeds maximum {}", messages.len(), max_size),
285 )
286 .build(),
287 None,
288 )];
289 }
290
291 tracing::debug!(batch_size = messages.len(), "processing batch");
292 let mut results = Vec::new();
293 for msg in messages {
294 if let Some(response) = self.process_message(msg).await {
295 results.push(response);
296 }
297 }
298 results
299 }
300
301 fn get_capabilities(&self) -> ProcessorCapabilities {
302 ProcessorCapabilities {
303 supports_batch: true,
304 supports_notifications: true,
305 max_batch_size: Some(100),
306 max_request_size: Some(1024 * 1024), request_timeout_secs: Some(30),
308 supported_versions: vec!["2.0".to_string()],
309 }
310 }
311}
312
313#[async_trait::async_trait]
314impl Handler for MethodRegistry {
315 async fn handle_request(&self, request: Request) -> Response {
316 self.call(&request.method, request.params, request.id).await
317 }
318
319 async fn handle_notification(&self, notification: Notification) {
320 let _ = self
321 .call(¬ification.method, notification.params, None)
322 .await;
323 }
324
325 fn supports_method(&self, method: &str) -> bool {
326 self.has_method(method)
327 }
328
329 fn get_supported_methods(&self) -> Vec<String> {
330 self.get_methods()
331 }
332}
333
#[cfg(test)]
mod tests {
    // Unit tests for `MethodRegistry`: dispatch, auth gating, introspection,
    // OpenAPI generation, and the `MessageProcessor` / `register_methods!` glue.
    use super::*;
    use serde_json::json;

    // Stub method: succeeds with `{"method": <name>}` and echoes the request id.
    struct TestMethod {
        name: &'static str,
    }

    #[async_trait::async_trait]
    impl JsonRPCMethod for TestMethod {
        fn method_name(&self) -> &'static str {
            self.name
        }

        async fn call(
            &self,
            _params: Option<serde_json::Value>,
            id: Option<RequestId>,
        ) -> Response {
            ResponseBuilder::new()
                .success(json!({"method": self.name}))
                .id(id)
                .build()
        }
    }

    // Allow-list policy: only methods named in `allowed_methods` may be called.
    struct TestAuthPolicy {
        allowed_methods: Vec<String>,
    }

    impl crate::auth::AuthPolicy for TestAuthPolicy {
        fn can_access(
            &self,
            method: &str,
            _params: Option<&serde_json::Value>,
            _ctx: &crate::auth::ConnectionContext,
        ) -> bool {
            self.allowed_methods.contains(&method.to_string())
        }

        // Denial response with INTERNAL_ERROR. It sets no id, because this
        // method only receives the method name.
        fn unauthorized_error(&self, method: &str) -> Response {
            ResponseBuilder::new()
                .error(
                    ErrorBuilder::new(
                        crate::error_codes::INTERNAL_ERROR,
                        format!("Access denied for method '{}'", method),
                    )
                    .build(),
                )
                .build()
        }
    }

    // With no policy installed, a registered method is always callable.
    #[tokio::test]
    async fn test_registry_without_auth() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod {
            name: "test_method",
        })]);

        let response = registry.call("test_method", None, Some(json!(1))).await;
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    // Method is on the allow-list, so the call goes through.
    #[tokio::test]
    async fn test_registry_with_auth_allowed() {
        let auth = TestAuthPolicy {
            allowed_methods: vec!["allowed_method".to_string()],
        };

        let registry = MethodRegistry::new(vec![Box::new(TestMethod {
            name: "allowed_method",
        })])
        .with_auth(auth);

        let response = registry.call("allowed_method", None, Some(json!(1))).await;
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    // Method is registered but not allow-listed: the policy's error is returned.
    #[tokio::test]
    async fn test_registry_with_auth_denied() {
        let auth = TestAuthPolicy {
            allowed_methods: vec!["allowed_method".to_string()],
        };

        let registry = MethodRegistry::new(vec![Box::new(TestMethod {
            name: "blocked_method",
        })])
        .with_auth(auth);

        let response = registry.call("blocked_method", None, Some(json!(1))).await;
        assert!(response.result.is_none());
        assert!(response.error.is_some());

        let error = response.error.unwrap();
        assert_eq!(error.code, crate::error_codes::INTERNAL_ERROR);
        assert!(error.message.contains("Access denied"));
    }

    // Built-in permissive policy.
    #[tokio::test]
    async fn test_registry_allow_all() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "any_method" })])
            .with_auth(crate::auth::AllowAll);

        let response = registry.call("any_method", None, Some(json!(1))).await;
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    // Built-in restrictive policy.
    #[tokio::test]
    async fn test_registry_deny_all() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "any_method" })])
            .with_auth(crate::auth::DenyAll);

        let response = registry.call("any_method", None, Some(json!(1))).await;
        assert!(response.result.is_none());
        assert!(response.error.is_some());
    }

    #[tokio::test]
    async fn test_registry_empty() {
        let registry = MethodRegistry::empty();
        assert_eq!(registry.method_count(), 0);
    }

    #[tokio::test]
    async fn test_registry_default() {
        let registry = MethodRegistry::default();
        assert_eq!(registry.method_count(), 0);
    }

    // Builder-style registration.
    #[tokio::test]
    async fn test_registry_add_method() {
        let registry = MethodRegistry::empty()
            .add_method(Box::new(TestMethod { name: "method1" }))
            .add_method(Box::new(TestMethod { name: "method2" }));

        assert_eq!(registry.method_count(), 2);
        assert!(registry.has_method("method1"));
        assert!(registry.has_method("method2"));
    }

    #[tokio::test]
    async fn test_registry_has_method() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "exists" })]);

        assert!(registry.has_method("exists"));
        assert!(!registry.has_method("not_exists"));
    }

    #[tokio::test]
    async fn test_registry_get_methods() {
        let registry = MethodRegistry::new(vec![
            Box::new(TestMethod { name: "method1" }),
            Box::new(TestMethod { name: "method2" }),
            Box::new(TestMethod { name: "method3" }),
        ]);

        let methods = registry.get_methods();
        assert_eq!(methods.len(), 3);
        assert!(methods.contains(&"method1".to_string()));
        assert!(methods.contains(&"method2".to_string()));
        assert!(methods.contains(&"method3".to_string()));
    }

    #[tokio::test]
    async fn test_registry_method_count() {
        let registry = MethodRegistry::new(vec![
            Box::new(TestMethod { name: "m1" }),
            Box::new(TestMethod { name: "m2" }),
        ]);

        assert_eq!(registry.method_count(), 2);
    }

    // Unknown names yield METHOD_NOT_FOUND.
    #[tokio::test]
    async fn test_registry_call_method_not_found() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "exists" })]);

        let response = registry.call("not_exists", None, Some(json!(1))).await;
        assert!(response.error.is_some());
        let error = response.error.unwrap();
        assert_eq!(error.code, error_codes::METHOD_NOT_FOUND);
    }

    #[tokio::test]
    async fn test_registry_call_with_params() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let params = json!({"key": "value"});
        let response = registry.call("test", Some(params), Some(json!(1))).await;
        assert!(response.result.is_some());
    }

    #[tokio::test]
    async fn test_registry_call_with_context() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let ctx = crate::auth::ConnectionContext::default();
        let response = registry
            .call_with_context("test", None, Some(json!(1)), &ctx)
            .await;
        assert!(response.result.is_some());
    }

    // Same denial path as `test_registry_with_auth_denied`, via an explicit context.
    #[tokio::test]
    async fn test_registry_call_with_context_auth_denied() {
        let auth = TestAuthPolicy {
            allowed_methods: vec!["allowed".to_string()],
        };

        let registry =
            MethodRegistry::new(vec![Box::new(TestMethod { name: "blocked" })]).with_auth(auth);

        let ctx = crate::auth::ConnectionContext::default();
        let response = registry
            .call_with_context("blocked", None, Some(json!(1)), &ctx)
            .await;
        assert!(response.error.is_some());
    }

    #[tokio::test]
    async fn test_registry_generate_openapi_spec() {
        let registry = MethodRegistry::new(vec![
            Box::new(TestMethod { name: "method1" }),
            Box::new(TestMethod { name: "method2" }),
        ]);

        let spec = registry.generate_openapi_spec("Test API", "1.0.0");
        assert_eq!(spec.info.title, "Test API");
        assert_eq!(spec.info.version, "1.0.0");
        assert_eq!(spec.methods.len(), 2);
        assert!(spec.methods.contains_key("method1"));
        assert!(spec.methods.contains_key("method2"));
    }

    #[tokio::test]
    async fn test_registry_generate_openapi_spec_with_info() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let servers = vec![OpenApiServer::new("http://localhost:8080")];

        let spec = registry.generate_openapi_spec_with_info(
            "API",
            "2.0.0",
            Some("Test description"),
            servers,
        );

        assert_eq!(spec.info.title, "API");
        assert_eq!(spec.info.version, "2.0.0");
        assert_eq!(spec.info.description, Some("Test description".to_string()));
        assert_eq!(spec.servers.len(), 1);
        assert_eq!(spec.servers[0].url, "http://localhost:8080");
    }

    // Substring checks rely on `to_string_pretty` formatting (`"key": value`).
    #[tokio::test]
    async fn test_registry_export_openapi_json() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let json_str = registry.export_openapi_json("API", "1.0").unwrap();
        assert!(json_str.contains("\"title\": \"API\""));
        assert!(json_str.contains("\"version\": \"1.0\""));
        assert!(json_str.contains("\"openapi\": \"3.0.3\""));
    }

    // A request message produces a response.
    #[tokio::test]
    async fn test_registry_message_processor_request() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let request = Request {
            jsonrpc: "2.0".to_string(),
            method: "test".to_string(),
            params: None,
            id: Some(json!(1)),
            correlation_id: None,
        };

        let response = registry.process_message(Message::Request(request)).await;
        assert!(response.is_some());
        assert!(response.unwrap().result.is_some());
    }

    // A notification is executed but never yields a response.
    #[tokio::test]
    async fn test_registry_message_processor_notification() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let notification = Notification {
            jsonrpc: "2.0".to_string(),
            method: "test".to_string(),
            params: None,
        };

        let response = registry
            .process_message(Message::Notification(notification))
            .await;
        assert!(response.is_none());
    }

    // Incoming response messages are ignored.
    #[tokio::test]
    async fn test_registry_message_processor_response() {
        let registry = MethodRegistry::new(vec![]);

        let response_msg = Response {
            jsonrpc: "2.0".to_string(),
            result: Some(json!(42)),
            error: None,
            id: Some(json!(1)),
            correlation_id: None,
        };

        let response = registry
            .process_message(Message::Response(response_msg))
            .await;
        assert!(response.is_none());
    }

    // Two requests under the batch limit produce two responses.
    #[tokio::test]
    async fn test_registry_process_batch() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let messages = vec![
            Message::Request(Request {
                jsonrpc: "2.0".to_string(),
                method: "test".to_string(),
                params: None,
                id: Some(json!(1)),
                correlation_id: None,
            }),
            Message::Request(Request {
                jsonrpc: "2.0".to_string(),
                method: "test".to_string(),
                params: None,
                id: Some(json!(2)),
                correlation_id: None,
            }),
        ];

        let responses = registry.process_batch(messages).await;
        assert_eq!(responses.len(), 2);
    }

    // The macro accepts a trailing comma and boxes each value.
    #[test]
    fn test_register_methods_macro() {
        let methods = register_methods![TestMethod { name: "m1" }, TestMethod { name: "m2" },];
        assert_eq!(methods.len(), 2);
    }
}