Skip to main content

ash_rpc/
stateful.rs

1//! # Stateful JSON-RPC Handlers
2//!
3//! Stateful JSON-RPC handlers with shared context support.
4//!
5//! This module extends ash-rpc-core with stateful method handlers that can access
6//! shared application state through a service context.
7//!
8
9use crate::{
10    ErrorBuilder, Message, MessageProcessor, Request, Response, ResponseBuilder, error_codes,
11};
12use std::sync::Arc;
13
/// Shared application state made available to every stateful handler.
///
/// Implementors pick the error type their handlers report; it must be a
/// thread-safe, `'static` [`std::error::Error`].
pub trait ServiceContext: Send + Sync + 'static {
    /// Error returned by handlers operating on this context.
    type Error: std::error::Error + Send + Sync + 'static;
}
18
19/// Async trait for stateful JSON-RPC method implementations with context
20#[async_trait::async_trait]
21pub trait StatefulJsonRPCMethod<C: ServiceContext>: Send + Sync {
22    /// Get the method name for runtime dispatch
23    fn method_name(&self) -> &'static str;
24
25    /// Execute the JSON-RPC method asynchronously with context
26    async fn call(
27        &self,
28        context: &C,
29        params: Option<serde_json::Value>,
30        id: Option<crate::RequestId>,
31    ) -> Result<Response, C::Error>;
32
33    /// Get `OpenAPI` components for this method
34    fn openapi_components(&self) -> crate::traits::OpenApiMethodSpec {
35        crate::traits::OpenApiMethodSpec::new(self.method_name())
36    }
37}
38
39/// Trait for stateful JSON-RPC handlers
40#[async_trait::async_trait]
41pub trait StatefulHandler<C: ServiceContext>: Send + Sync {
42    /// Handle a JSON-RPC request with access to shared context
43    async fn handle_request(&self, context: &C, request: Request) -> Result<Response, C::Error>;
44
45    /// Handle a JSON-RPC notification with access to shared context
46    async fn handle_notification(
47        &self,
48        context: &C,
49        notification: crate::Notification,
50    ) -> Result<(), C::Error> {
51        let _ = context;
52        let _ = notification;
53        Ok(())
54    }
55}
56
57/// Registry for organizing stateful JSON-RPC methods
58pub struct StatefulMethodRegistry<C: ServiceContext> {
59    methods: Vec<Box<dyn StatefulJsonRPCMethod<C>>>,
60}
61
62impl<C: ServiceContext> StatefulMethodRegistry<C> {
63    /// Create a new empty registry
64    #[must_use]
65    pub fn new() -> Self {
66        Self {
67            methods: Vec::new(),
68        }
69    }
70
71    /// Register a method handler
72    #[must_use]
73    pub fn register<M>(mut self, method: M) -> Self
74    where
75        M: StatefulJsonRPCMethod<C> + 'static,
76    {
77        tracing::trace!("registering stateful method");
78        self.methods.push(Box::new(method));
79        self
80    }
81
82    /// Call a registered method with context
83    ///
84    /// # Errors
85    /// Returns error if method handler fails
86    pub async fn call(
87        &self,
88        context: &C,
89        method: &str,
90        params: Option<serde_json::Value>,
91        id: Option<crate::RequestId>,
92    ) -> Result<Response, C::Error> {
93        // Generate match statement for all registered methods
94        for handler in &self.methods {
95            if handler.method_name() == method {
96                tracing::debug!(method = %method, "calling stateful method");
97                return handler.call(context, params, id).await;
98            }
99        }
100
101        tracing::warn!(method = %method, "stateful method not found");
102        // Method not found
103        Ok(ResponseBuilder::new()
104            .error(ErrorBuilder::new(error_codes::METHOD_NOT_FOUND, "Method not found").build())
105            .id(id)
106            .build())
107    }
108}
109
110impl<C: ServiceContext> Default for StatefulMethodRegistry<C> {
111    fn default() -> Self {
112        Self::new()
113    }
114}
115
116#[async_trait::async_trait]
117impl<C: ServiceContext> StatefulHandler<C> for StatefulMethodRegistry<C> {
118    async fn handle_request(&self, context: &C, request: Request) -> Result<Response, C::Error> {
119        self.call(context, &request.method, request.params, request.id)
120            .await
121    }
122
123    async fn handle_notification(
124        &self,
125        context: &C,
126        notification: crate::Notification,
127    ) -> Result<(), C::Error> {
128        let _ = self
129            .call(context, &notification.method, notification.params, None)
130            .await?;
131        Ok(())
132    }
133}
134
135/// Stateful message processor that wraps a context and handler
136pub struct StatefulProcessor<C: ServiceContext> {
137    context: Arc<C>,
138    handler: Arc<dyn StatefulHandler<C>>,
139}
140
141impl<C: ServiceContext> StatefulProcessor<C> {
142    /// Create a new stateful processor with context and handler
143    pub fn new<H>(context: C, handler: H) -> Self
144    where
145        H: StatefulHandler<C> + 'static,
146    {
147        Self {
148            context: Arc::new(context),
149            handler: Arc::new(handler),
150        }
151    }
152
153    /// Create a builder for configuring the processor
154    pub fn builder(context: C) -> StatefulProcessorBuilder<C> {
155        StatefulProcessorBuilder::new(context)
156    }
157}
158
159#[async_trait::async_trait]
160impl<C: ServiceContext> MessageProcessor for StatefulProcessor<C> {
161    async fn process_message(&self, message: Message) -> Option<Response> {
162        match message {
163            Message::Request(request) => {
164                let request_id = request.id.clone();
165                let correlation_id = request.correlation_id.clone();
166
167                match self.handler.handle_request(&self.context, request).await {
168                    Ok(response) => Some(response),
169                    Err(error) => {
170                        // Log the actual error with correlation tracking
171                        tracing::error!(
172                            error = %error,
173                            request_id = ?request_id,
174                            correlation_id = ?correlation_id,
175                            "stateful handler error"
176                        );
177
178                        // Return generic error that preserves request ID
179                        // Users can customize error handling by implementing their own MessageProcessor
180                        let generic_error = crate::Error::from_error_logged(&error);
181
182                        Some(
183                            ResponseBuilder::new()
184                                .error(generic_error)
185                                .id(request_id) // Preserve request ID for correlation
186                                .correlation_id(correlation_id) // Preserve correlation ID
187                                .build(),
188                        )
189                    }
190                }
191            }
192            Message::Notification(notification) => {
193                drop(
194                    self.handler
195                        .handle_notification(&self.context, notification)
196                        .await,
197                );
198                None
199            }
200            Message::Response(_) => None,
201        }
202    }
203}
204
205/// Builder for creating stateful processors
206pub struct StatefulProcessorBuilder<C: ServiceContext> {
207    context: C,
208    handler: Option<Arc<dyn StatefulHandler<C>>>,
209}
210
211impl<C: ServiceContext> StatefulProcessorBuilder<C> {
212    /// Create a new builder with the given context
213    pub fn new(context: C) -> Self {
214        Self {
215            context,
216            handler: None,
217        }
218    }
219
220    /// Set the handler for processing requests
221    #[must_use]
222    pub fn handler<H>(mut self, handler: H) -> Self
223    where
224        H: StatefulHandler<C> + 'static,
225    {
226        self.handler = Some(Arc::new(handler));
227        self
228    }
229
230    /// Set a method registry as the handler
231    #[must_use]
232    pub fn registry(mut self, registry: StatefulMethodRegistry<C>) -> Self {
233        self.handler = Some(Arc::new(registry));
234        self
235    }
236
237    /// Build the stateful processor
238    ///
239    /// # Errors
240    /// Returns error if handler is not set
241    pub fn build(self) -> Result<StatefulProcessor<C>, Box<dyn std::error::Error>> {
242        let handler = self.handler.ok_or("Handler not set")?;
243        Ok(StatefulProcessor {
244            context: Arc::new(self.context),
245            handler,
246        })
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Notification, RequestBuilder};
    use std::sync::atomic::{AtomicU32, Ordering};

    // Test context implementation
    #[derive(Debug)]
    struct TestError(String);

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    impl std::error::Error for TestError {}

    struct TestContext {
        counter: AtomicU32,
    }

    impl ServiceContext for TestContext {
        type Error = TestError;
    }

    impl TestContext {
        fn new() -> Self {
            Self {
                counter: AtomicU32::new(0),
            }
        }

        fn increment(&self) -> u32 {
            self.counter.fetch_add(1, Ordering::SeqCst) + 1
        }

        fn get_count(&self) -> u32 {
            self.counter.load(Ordering::SeqCst)
        }
    }

    /// Build a JSON-RPC 2.0 notification for `method` with no params.
    fn make_notification(method: &str) -> Notification {
        Notification {
            jsonrpc: "2.0".to_string(),
            method: method.to_string(),
            params: None,
        }
    }

    // Test method implementation
    struct IncrementMethod;

    #[async_trait::async_trait]
    impl StatefulJsonRPCMethod<TestContext> for IncrementMethod {
        fn method_name(&self) -> &'static str {
            "increment"
        }

        async fn call(
            &self,
            context: &TestContext,
            _params: Option<serde_json::Value>,
            id: Option<crate::RequestId>,
        ) -> Result<Response, TestError> {
            let count = context.increment();
            Ok(ResponseBuilder::new()
                .success(serde_json::json!({"count": count}))
                .id(id)
                .build())
        }
    }

    // Failing method for error tests
    struct FailingMethod;

    #[async_trait::async_trait]
    impl StatefulJsonRPCMethod<TestContext> for FailingMethod {
        fn method_name(&self) -> &'static str {
            "fail"
        }

        async fn call(
            &self,
            _context: &TestContext,
            _params: Option<serde_json::Value>,
            _id: Option<crate::RequestId>,
        ) -> Result<Response, TestError> {
            Err(TestError("intentional failure".to_string()))
        }
    }

    // Second handler registered under "increment"; it must never be reached
    // because the first registration wins.
    struct ShadowedIncrement;

    #[async_trait::async_trait]
    impl StatefulJsonRPCMethod<TestContext> for ShadowedIncrement {
        fn method_name(&self) -> &'static str {
            "increment"
        }

        async fn call(
            &self,
            _context: &TestContext,
            _params: Option<serde_json::Value>,
            _id: Option<crate::RequestId>,
        ) -> Result<Response, TestError> {
            Err(TestError("shadowed handler must never run".to_string()))
        }
    }

    // Handler that implements only `handle_request`, so the trait's default
    // `handle_notification` is used.
    struct RequestOnlyHandler;

    #[async_trait::async_trait]
    impl StatefulHandler<TestContext> for RequestOnlyHandler {
        async fn handle_request(
            &self,
            context: &TestContext,
            request: Request,
        ) -> Result<Response, TestError> {
            let count = context.increment();
            Ok(ResponseBuilder::new()
                .success(serde_json::json!({"count": count}))
                .id(request.id)
                .build())
        }
    }

    #[tokio::test]
    async fn test_stateful_registry_register_and_call() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new().register(IncrementMethod);

        let result = registry
            .call(&context, "increment", None, Some(serde_json::json!(1)))
            .await
            .unwrap();

        assert!(result.result.is_some());
        assert_eq!(context.get_count(), 1);
    }

    #[tokio::test]
    async fn test_stateful_registry_method_not_found() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::<TestContext>::new();

        let result = registry
            .call(&context, "unknown", None, Some(serde_json::json!(1)))
            .await
            .unwrap();

        assert!(result.error.is_some());
        let error = result.error.unwrap();
        assert_eq!(error.code, error_codes::METHOD_NOT_FOUND);
    }

    #[tokio::test]
    async fn test_stateful_registry_multiple_methods() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new()
            .register(IncrementMethod)
            .register(FailingMethod);

        // Call increment twice
        let _ = registry
            .call(&context, "increment", None, Some(serde_json::json!(1)))
            .await;
        let _ = registry
            .call(&context, "increment", None, Some(serde_json::json!(2)))
            .await;
        assert_eq!(context.get_count(), 2);

        // Call failing method
        let result = registry
            .call(&context, "fail", None, Some(serde_json::json!(3)))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_stateful_registry_first_registration_wins() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new()
            .register(IncrementMethod)
            .register(ShadowedIncrement);

        let result = registry
            .call(&context, "increment", None, Some(serde_json::json!(1)))
            .await;

        assert!(result.is_ok());
        assert_eq!(context.get_count(), 1);
        // Both handlers are stored; only the first is reachable.
        assert_eq!(registry.methods.len(), 2);
    }

    #[tokio::test]
    async fn test_stateful_handler_request() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new().register(IncrementMethod);

        let request = RequestBuilder::new("increment")
            .id(serde_json::json!(1))
            .build();

        let result = registry.handle_request(&context, request).await.unwrap();
        assert!(result.result.is_some());
    }

    #[tokio::test]
    async fn test_stateful_handler_notification() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new().register(IncrementMethod);

        let result = registry
            .handle_notification(&context, make_notification("increment"))
            .await;
        assert!(result.is_ok());
        assert_eq!(context.get_count(), 1);
    }

    #[tokio::test]
    async fn test_stateful_handler_notification_unknown_method() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::<TestContext>::new();

        // The "method not found" response is discarded for notifications.
        let result = registry
            .handle_notification(&context, make_notification("unknown"))
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_default_handle_notification_is_noop() {
        let context = TestContext::new();

        let result = RequestOnlyHandler
            .handle_notification(&context, make_notification("increment"))
            .await;
        assert!(result.is_ok());
        assert_eq!(context.get_count(), 0);
    }

    #[tokio::test]
    async fn test_stateful_processor_request() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new().register(IncrementMethod);
        let processor = StatefulProcessor::new(context, registry);

        let request = RequestBuilder::new("increment")
            .id(serde_json::json!(1))
            .build();

        let response = processor.process_message(Message::Request(request)).await;
        assert!(response.is_some());
        let response = response.unwrap();
        assert!(response.result.is_some());
    }

    #[tokio::test]
    async fn test_stateful_processor_notification() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new().register(IncrementMethod);
        let processor = StatefulProcessor::new(context, registry);

        let response = processor
            .process_message(Message::Notification(make_notification("increment")))
            .await;
        assert!(response.is_none());
    }

    #[tokio::test]
    async fn test_stateful_processor_notification_error_yields_no_response() {
        let registry = StatefulMethodRegistry::new().register(FailingMethod);
        let processor = StatefulProcessor::new(TestContext::new(), registry);

        // A failing notification handler must still never produce a response.
        let response = processor
            .process_message(Message::Notification(make_notification("fail")))
            .await;
        assert!(response.is_none());
    }

    #[tokio::test]
    async fn test_stateful_processor_error_handling() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new().register(FailingMethod);
        let processor = StatefulProcessor::new(context, registry);

        let request = RequestBuilder::new("fail").id(serde_json::json!(1)).build();

        let response = processor.process_message(Message::Request(request)).await;
        assert!(response.is_some());
        let response = response.unwrap();
        assert!(response.error.is_some());
        assert_eq!(response.id, Some(serde_json::json!(1)));
    }

    #[tokio::test]
    async fn test_stateful_processor_preserves_correlation_id() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new().register(FailingMethod);
        let processor = StatefulProcessor::new(context, registry);

        let correlation_id = uuid::Uuid::new_v4().to_string();
        let request = RequestBuilder::new("fail")
            .id(serde_json::json!(1))
            .correlation_id(correlation_id.clone())
            .build();

        let response = processor
            .process_message(Message::Request(request))
            .await
            .unwrap();
        assert_eq!(response.correlation_id, Some(correlation_id));
    }

    #[tokio::test]
    async fn test_stateful_processor_builder() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new().register(IncrementMethod);

        let processor = StatefulProcessor::builder(context)
            .registry(registry)
            .build()
            .unwrap();

        let request = RequestBuilder::new("increment")
            .id(serde_json::json!(1))
            .build();

        let response = processor.process_message(Message::Request(request)).await;
        assert!(response.is_some());
    }

    #[tokio::test]
    async fn test_stateful_processor_builder_custom_handler() {
        let processor = StatefulProcessor::builder(TestContext::new())
            .handler(RequestOnlyHandler)
            .build()
            .unwrap();

        let request = RequestBuilder::new("anything")
            .id(serde_json::json!(7))
            .build();

        let response = processor
            .process_message(Message::Request(request))
            .await
            .unwrap();
        assert!(response.result.is_some());
        assert_eq!(response.id, Some(serde_json::json!(7)));
    }

    #[tokio::test]
    async fn test_stateful_processor_builder_no_handler() {
        let context = TestContext::new();
        let result = StatefulProcessor::builder(context).build();
        assert!(result.is_err());
    }

    #[test]
    fn test_stateful_method_openapi_components() {
        let method = IncrementMethod;
        let spec = method.openapi_components();
        assert_eq!(spec.method_name, "increment");
    }

    #[test]
    fn test_stateful_registry_default() {
        let registry = StatefulMethodRegistry::<TestContext>::default();
        assert_eq!(registry.methods.len(), 0);
    }
}