1use crate::builders::{ErrorBuilder, ResponseBuilder};
40use crate::traits::{
41 Handler, JsonRPCMethod, MessageProcessor, OpenApiServer, OpenApiSpec, ProcessorCapabilities,
42};
43use crate::types::{Message, Notification, Request, RequestId, Response, error_codes};
44use std::sync::Arc;
45
/// Registry of JSON-RPC method handlers, optionally guarded by an
/// authorization policy.
///
/// Created with `MethodRegistry::new`, `MethodRegistry::empty`, or from the
/// vector produced by the `register_methods!` macro, and dispatched through
/// `call` / `call_with_context`.
pub struct MethodRegistry {
    /// Registered handlers. Dispatch uses the first entry whose
    /// `method_name()` matches the requested name.
    methods: Vec<Box<dyn JsonRPCMethod>>,
    /// Policy checked before every dispatch; `None` means no access check
    /// is performed.
    auth_policy: Option<Arc<dyn crate::auth::AuthPolicy>>,
}
51
/// Builds a `Vec<Box<dyn JsonRPCMethod>>` from a comma-separated list of
/// method values, boxing each one and keeping the listed order.
///
/// A trailing comma is accepted, and an empty invocation yields an empty
/// vector. The expansion names `JsonRPCMethod` unqualified, so that trait
/// must be in scope wherever the macro is invoked.
#[macro_export]
macro_rules! register_methods {
    ($($method:expr),* $(,)?) => {{
        // The annotated binding coerces every `Box<T>` to the trait object.
        let registered: Vec<Box<dyn JsonRPCMethod>> = vec![$(Box::new($method)),*];
        registered
    }};
}
63
/// Routes `$method_name` to the first listed method value whose
/// `method_name()` equals it. It awaits that method's `call($params, $id)`
/// and `return`s the result from the *enclosing* function.
///
/// When no listed method matches, the macro evaluates to a `METHOD_NOT_FOUND`
/// error response carrying `$id`.
///
/// Expansion requirements:
/// - `ResponseBuilder`, `ErrorBuilder` and `error_codes` are named
///   unqualified, so they must be in scope at the call site.
/// - The `return` and `.await` require an enclosing `async` function or
///   block whose output type matches the response type.
/// - `$method_name` is re-evaluated for each candidate compared. `$params`
///   is evaluated only on a match, and `$id` is evaluated exactly once.
#[macro_export]
macro_rules! dispatch_call {
    ($method_name:expr, $params:expr, $id:expr => $($method:expr),* $(,)?) => {{
        $(
            // Candidates are constructed and checked one at a time, in order.
            let candidate = $method;
            if $method_name == candidate.method_name() {
                return candidate.call($params, $id).await;
            }
        )*

        // Nothing matched: report METHOD_NOT_FOUND for this id.
        let fallback = ResponseBuilder::new();
        fallback
            .error(ErrorBuilder::new(error_codes::METHOD_NOT_FOUND, "Method not found").build())
            .id($id)
            .build()
    }};
}
86
87impl MethodRegistry {
88 pub fn new(methods: Vec<Box<dyn JsonRPCMethod>>) -> Self {
90 tracing::debug!(method_count = methods.len(), "registry created");
91 Self {
92 methods,
93 auth_policy: None,
94 }
95 }
96
97 #[must_use]
99 pub fn empty() -> Self {
100 Self {
101 methods: Vec::new(),
102 auth_policy: None,
103 }
104 }
105
106 #[must_use]
117 pub fn with_auth<A: crate::auth::AuthPolicy + 'static>(mut self, policy: A) -> Self {
118 self.auth_policy = Some(Arc::new(policy));
119 self
120 }
121
122 #[must_use]
124 pub fn add_method(mut self, method: Box<dyn JsonRPCMethod>) -> Self {
125 tracing::trace!("adding method to registry");
126 self.methods.push(method);
127 self
128 }
129
130 pub async fn call(
134 &self,
135 method_name: &str,
136 params: Option<serde_json::Value>,
137 id: Option<RequestId>,
138 ) -> Response {
139 self.call_with_context(
140 method_name,
141 params,
142 id,
143 &crate::auth::ConnectionContext::default(),
144 )
145 .await
146 }
147
148 pub async fn call_with_context(
152 &self,
153 method_name: &str,
154 params: Option<serde_json::Value>,
155 id: Option<RequestId>,
156 ctx: &crate::auth::ConnectionContext,
157 ) -> Response {
158 if let Some(auth) = &self.auth_policy
160 && !auth.can_access(method_name, params.as_ref(), ctx)
161 {
162 tracing::warn!(
163 method = %method_name,
164 remote_addr = ?ctx.remote_addr,
165 "access denied by auth policy"
166 );
167 return auth.unauthorized_error(method_name);
168 }
169
170 for method in &self.methods {
172 if method.method_name() == method_name {
173 tracing::debug!(method = %method_name, "calling method");
174 return method.call(params, id).await;
175 }
176 }
177
178 tracing::warn!(method = %method_name, "method not found");
179 ResponseBuilder::new()
180 .error(ErrorBuilder::new(error_codes::METHOD_NOT_FOUND, "Method not found").build())
181 .id(id)
182 .build()
183 }
184
185 #[must_use]
187 pub fn has_method(&self, method_name: &str) -> bool {
188 self.methods.iter().any(|m| m.method_name() == method_name)
189 }
190
191 #[must_use]
193 pub fn get_methods(&self) -> Vec<String> {
194 self.methods
195 .iter()
196 .map(|m| m.method_name().to_owned())
197 .collect()
198 }
199
200 #[must_use]
202 pub fn method_count(&self) -> usize {
203 self.methods.len()
204 }
205
206 pub fn generate_openapi_spec(&self, title: &str, version: &str) -> OpenApiSpec {
208 tracing::debug!(method_count = self.methods.len(), "generating openapi spec");
209 let mut spec = OpenApiSpec::new(title, version);
210
211 for method in &self.methods {
212 let method_spec = method.openapi_components();
213 spec.add_method(method_spec);
214 }
215
216 spec
217 }
218
219 #[must_use]
221 pub fn generate_openapi_spec_with_info(
222 &self,
223 title: &str,
224 version: &str,
225 description: Option<&str>,
226 servers: Vec<OpenApiServer>,
227 ) -> OpenApiSpec {
228 let mut spec = self.generate_openapi_spec(title, version);
229
230 if let Some(desc) = description {
231 spec.info.description = Some(desc.to_owned());
232 }
233
234 for server in servers {
235 spec.add_server(server);
236 }
237
238 spec
239 }
240
241 pub fn export_openapi_json(
246 &self,
247 title: &str,
248 version: &str,
249 ) -> Result<String, serde_json::Error> {
250 let spec = self.generate_openapi_spec(title, version);
251 serde_json::to_string_pretty(&spec)
252 }
253}
254
255impl Default for MethodRegistry {
256 fn default() -> Self {
257 Self::empty()
258 }
259}
260
261#[async_trait::async_trait]
262impl MessageProcessor for MethodRegistry {
263 async fn process_message(&self, message: Message) -> Option<Response> {
264 match message {
265 Message::Request(request) => {
266 tracing::trace!(method = %request.method, correlation_id = ?request.correlation_id, "processing request");
267 let response = self.call(&request.method, request.params, request.id).await;
268 Some(response)
269 }
270 Message::Notification(notification) => {
271 tracing::trace!(method = %notification.method, "processing notification");
272 let _ = self
273 .call(¬ification.method, notification.params, None)
274 .await;
275 None
276 }
277 Message::Response(_) => None,
278 }
279 }
280
281 async fn process_batch(&self, messages: Vec<Message>) -> Vec<Response> {
282 let capabilities = self.get_capabilities();
283
284 if let Some(max_size) = capabilities.max_batch_size
286 && messages.len() > max_size
287 {
288 tracing::warn!(
289 batch_size = messages.len(),
290 max_batch_size = max_size,
291 "batch size limit exceeded"
292 );
293 return vec![crate::Response::error(
294 crate::ErrorBuilder::new(
295 crate::error_codes::INVALID_REQUEST,
296 format!("Batch size {} exceeds maximum {}", messages.len(), max_size),
297 )
298 .build(),
299 None,
300 )];
301 }
302
303 tracing::debug!(batch_size = messages.len(), "processing batch");
304 let mut results = Vec::new();
305 for msg in messages {
306 if let Some(response) = self.process_message(msg).await {
307 results.push(response);
308 }
309 }
310 results
311 }
312
313 fn get_capabilities(&self) -> ProcessorCapabilities {
314 ProcessorCapabilities {
315 supports_batch: true,
316 supports_notifications: true,
317 max_batch_size: Some(100),
318 max_request_size: Some(1024 * 1024), request_timeout_secs: Some(30),
320 supported_versions: vec!["2.0".to_owned()],
321 }
322 }
323}
324
325#[async_trait::async_trait]
326impl Handler for MethodRegistry {
327 async fn handle_request(&self, request: Request) -> Response {
328 self.call(&request.method, request.params, request.id).await
329 }
330
331 async fn handle_notification(&self, notification: Notification) {
332 let _ = self
333 .call(¬ification.method, notification.params, None)
334 .await;
335 }
336
337 fn supports_method(&self, method: &str) -> bool {
338 self.has_method(method)
339 }
340
341 fn get_supported_methods(&self) -> Vec<String> {
342 self.get_methods()
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Test method that answers every call with `{"method": <its name>}`.
    struct TestMethod {
        name: &'static str,
    }

    #[async_trait::async_trait]
    impl JsonRPCMethod for TestMethod {
        fn method_name(&self) -> &'static str {
            self.name
        }

        async fn call(
            &self,
            _params: Option<serde_json::Value>,
            id: Option<RequestId>,
        ) -> Response {
            ResponseBuilder::new()
                .success(json!({"method": self.name}))
                .id(id)
                .build()
        }
    }

    /// Auth policy that allows exactly the methods listed in `allowed_methods`.
    struct TestAuthPolicy {
        allowed_methods: Vec<String>,
    }

    impl crate::auth::AuthPolicy for TestAuthPolicy {
        fn can_access(
            &self,
            method: &str,
            _params: Option<&serde_json::Value>,
            _ctx: &crate::auth::ConnectionContext,
        ) -> bool {
            // Compare by reference rather than allocating a `String` per check.
            self.allowed_methods.iter().any(|allowed| allowed == method)
        }

        fn unauthorized_error(&self, method: &str) -> Response {
            ResponseBuilder::new()
                .error(
                    ErrorBuilder::new(
                        crate::error_codes::INTERNAL_ERROR,
                        format!("Access denied for method '{}'", method),
                    )
                    .build(),
                )
                .build()
        }
    }

    /// Builds a JSON-RPC 2.0 request for `method` with numeric id `id` and no params.
    fn make_request(method: &str, id: usize) -> Request {
        Request {
            jsonrpc: "2.0".to_string(),
            method: method.to_string(),
            params: None,
            id: Some(json!(id)),
            correlation_id: None,
        }
    }

    /// Builds a JSON-RPC 2.0 notification for `method` with no params.
    fn make_notification(method: &str) -> Notification {
        Notification {
            jsonrpc: "2.0".to_string(),
            method: method.to_string(),
            params: None,
        }
    }

    #[tokio::test]
    async fn test_registry_without_auth() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod {
            name: "test_method",
        })]);

        let response = registry.call("test_method", None, Some(json!(1))).await;
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    #[tokio::test]
    async fn test_registry_with_auth_allowed() {
        let auth = TestAuthPolicy {
            allowed_methods: vec!["allowed_method".to_string()],
        };

        let registry = MethodRegistry::new(vec![Box::new(TestMethod {
            name: "allowed_method",
        })])
        .with_auth(auth);

        let response = registry.call("allowed_method", None, Some(json!(1))).await;
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    #[tokio::test]
    async fn test_registry_with_auth_denied() {
        let auth = TestAuthPolicy {
            allowed_methods: vec!["allowed_method".to_string()],
        };

        let registry = MethodRegistry::new(vec![Box::new(TestMethod {
            name: "blocked_method",
        })])
        .with_auth(auth);

        let response = registry.call("blocked_method", None, Some(json!(1))).await;
        assert!(response.result.is_none());
        assert!(response.error.is_some());

        let error = response.error.unwrap();
        assert_eq!(error.code, crate::error_codes::INTERNAL_ERROR);
        assert!(error.message.contains("Access denied"));
    }

    #[tokio::test]
    async fn test_registry_allow_all() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "any_method" })])
            .with_auth(crate::auth::AllowAll);

        let response = registry.call("any_method", None, Some(json!(1))).await;
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    #[tokio::test]
    async fn test_registry_deny_all() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "any_method" })])
            .with_auth(crate::auth::DenyAll);

        let response = registry.call("any_method", None, Some(json!(1))).await;
        assert!(response.result.is_none());
        assert!(response.error.is_some());
    }

    #[test]
    fn test_registry_empty() {
        let registry = MethodRegistry::empty();
        assert_eq!(registry.method_count(), 0);
    }

    #[test]
    fn test_registry_default() {
        let registry = MethodRegistry::default();
        assert_eq!(registry.method_count(), 0);
    }

    #[test]
    fn test_registry_add_method() {
        let registry = MethodRegistry::empty()
            .add_method(Box::new(TestMethod { name: "method1" }))
            .add_method(Box::new(TestMethod { name: "method2" }));

        assert_eq!(registry.method_count(), 2);
        assert!(registry.has_method("method1"));
        assert!(registry.has_method("method2"));
    }

    #[test]
    fn test_registry_has_method() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "exists" })]);

        assert!(registry.has_method("exists"));
        assert!(!registry.has_method("not_exists"));
    }

    #[test]
    fn test_registry_get_methods() {
        let registry = MethodRegistry::new(vec![
            Box::new(TestMethod { name: "method1" }),
            Box::new(TestMethod { name: "method2" }),
            Box::new(TestMethod { name: "method3" }),
        ]);

        let methods = registry.get_methods();
        assert_eq!(methods.len(), 3);
        assert!(methods.contains(&"method1".to_string()));
        assert!(methods.contains(&"method2".to_string()));
        assert!(methods.contains(&"method3".to_string()));
    }

    #[test]
    fn test_registry_method_count() {
        let registry = MethodRegistry::new(vec![
            Box::new(TestMethod { name: "m1" }),
            Box::new(TestMethod { name: "m2" }),
        ]);

        assert_eq!(registry.method_count(), 2);
    }

    #[tokio::test]
    async fn test_registry_call_method_not_found() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "exists" })]);

        let response = registry.call("not_exists", None, Some(json!(1))).await;
        assert!(response.error.is_some());
        let error = response.error.unwrap();
        assert_eq!(error.code, error_codes::METHOD_NOT_FOUND);
    }

    #[tokio::test]
    async fn test_registry_call_with_params() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let params = json!({"key": "value"});
        let response = registry.call("test", Some(params), Some(json!(1))).await;
        assert!(response.result.is_some());
    }

    #[tokio::test]
    async fn test_registry_call_with_context() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let ctx = crate::auth::ConnectionContext::default();
        let response = registry
            .call_with_context("test", None, Some(json!(1)), &ctx)
            .await;
        assert!(response.result.is_some());
    }

    #[tokio::test]
    async fn test_registry_call_with_context_auth_denied() {
        let auth = TestAuthPolicy {
            allowed_methods: vec!["allowed".to_string()],
        };

        let registry =
            MethodRegistry::new(vec![Box::new(TestMethod { name: "blocked" })]).with_auth(auth);

        let ctx = crate::auth::ConnectionContext::default();
        let response = registry
            .call_with_context("blocked", None, Some(json!(1)), &ctx)
            .await;
        assert!(response.error.is_some());
    }

    #[test]
    fn test_registry_generate_openapi_spec() {
        let registry = MethodRegistry::new(vec![
            Box::new(TestMethod { name: "method1" }),
            Box::new(TestMethod { name: "method2" }),
        ]);

        let spec = registry.generate_openapi_spec("Test API", "1.0.0");
        assert_eq!(spec.info.title, "Test API");
        assert_eq!(spec.info.version, "1.0.0");
        assert_eq!(spec.methods.len(), 2);
        assert!(spec.methods.contains_key("method1"));
        assert!(spec.methods.contains_key("method2"));
    }

    #[test]
    fn test_registry_generate_openapi_spec_with_info() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let servers = vec![OpenApiServer::new("http://localhost:8080")];

        let spec = registry.generate_openapi_spec_with_info(
            "API",
            "2.0.0",
            Some("Test description"),
            servers,
        );

        assert_eq!(spec.info.title, "API");
        assert_eq!(spec.info.version, "2.0.0");
        assert_eq!(spec.info.description, Some("Test description".to_string()));
        assert_eq!(spec.servers.len(), 1);
        assert_eq!(spec.servers[0].url, "http://localhost:8080");
    }

    #[test]
    fn test_registry_export_openapi_json() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let json_str = registry.export_openapi_json("API", "1.0").unwrap();
        assert!(json_str.contains("\"title\": \"API\""));
        assert!(json_str.contains("\"version\": \"1.0\""));
        assert!(json_str.contains("\"openapi\": \"3.0.3\""));
    }

    #[tokio::test]
    async fn test_registry_message_processor_request() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let response = registry
            .process_message(Message::Request(make_request("test", 1)))
            .await;
        assert!(response.is_some());
        assert!(response.unwrap().result.is_some());
    }

    #[tokio::test]
    async fn test_registry_message_processor_notification() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let response = registry
            .process_message(Message::Notification(make_notification("test")))
            .await;
        assert!(response.is_none());
    }

    #[tokio::test]
    async fn test_registry_message_processor_response() {
        let registry = MethodRegistry::new(vec![]);

        let response_msg = Response {
            jsonrpc: "2.0".to_string(),
            result: Some(json!(42)),
            error: None,
            id: Some(json!(1)),
            correlation_id: None,
        };

        let response = registry
            .process_message(Message::Response(response_msg))
            .await;
        assert!(response.is_none());
    }

    #[tokio::test]
    async fn test_registry_process_batch() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let messages = vec![
            Message::Request(make_request("test", 1)),
            Message::Request(make_request("test", 2)),
        ];

        let responses = registry.process_batch(messages).await;
        assert_eq!(responses.len(), 2);
    }

    #[tokio::test]
    async fn test_registry_process_batch_skips_notifications() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let messages = vec![
            Message::Notification(make_notification("test")),
            Message::Request(make_request("test", 7)),
            Message::Notification(make_notification("test")),
        ];

        let responses = registry.process_batch(messages).await;
        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0].id, Some(json!(7)));
    }

    #[tokio::test]
    async fn test_registry_process_batch_exceeds_limit() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);
        let max = MessageProcessor::get_capabilities(&registry)
            .max_batch_size
            .expect("registry advertises a batch size limit");

        // One message more than the advertised maximum.
        let messages: Vec<Message> = (0..=max)
            .map(|i| Message::Request(make_request("test", i)))
            .collect();

        let responses = registry.process_batch(messages).await;
        assert_eq!(responses.len(), 1);
        assert!(responses[0].result.is_none());
        let error = responses[0]
            .error
            .as_ref()
            .expect("oversized batch should yield an error");
        assert_eq!(error.code, error_codes::INVALID_REQUEST);
    }

    #[tokio::test]
    async fn test_registry_handler_impl() {
        let registry = MethodRegistry::new(vec![Box::new(TestMethod { name: "test" })]);

        let response = Handler::handle_request(&registry, make_request("test", 1)).await;
        assert!(response.result.is_some());
        assert_eq!(response.id, Some(json!(1)));

        // Notifications run the method but produce nothing to inspect.
        Handler::handle_notification(&registry, make_notification("test")).await;

        assert!(Handler::supports_method(&registry, "test"));
        assert!(!Handler::supports_method(&registry, "missing"));
        assert_eq!(
            Handler::get_supported_methods(&registry),
            vec!["test".to_string()]
        );
    }

    #[test]
    fn test_register_methods_macro() {
        let methods = register_methods![TestMethod { name: "m1" }, TestMethod { name: "m2" },];
        assert_eq!(methods.len(), 2);
    }
}