1use crate::{
10 ErrorBuilder, Message, MessageProcessor, Request, Response, ResponseBuilder, error_codes,
11};
12use std::sync::Arc;
13
/// Shared state made available to every stateful handler and method.
///
/// Implementors must be shareable across threads for the whole program lifetime,
/// and they choose the error type their methods report through [`ServiceContext::Error`].
pub trait ServiceContext
where
    Self: Send + Sync + 'static,
{
    /// Error type produced by handlers and methods that operate on this context.
    type Error: std::error::Error + Send + Sync + 'static;
}
18
19#[async_trait::async_trait]
21pub trait StatefulJsonRPCMethod<C: ServiceContext>: Send + Sync {
22 fn method_name(&self) -> &'static str;
24
25 async fn call(
27 &self,
28 context: &C,
29 params: Option<serde_json::Value>,
30 id: Option<crate::RequestId>,
31 ) -> Result<Response, C::Error>;
32
33 fn openapi_components(&self) -> crate::traits::OpenApiMethodSpec {
35 crate::traits::OpenApiMethodSpec::new(self.method_name())
36 }
37}
38
39#[async_trait::async_trait]
41pub trait StatefulHandler<C: ServiceContext>: Send + Sync {
42 async fn handle_request(&self, context: &C, request: Request) -> Result<Response, C::Error>;
44
45 async fn handle_notification(
47 &self,
48 context: &C,
49 notification: crate::Notification,
50 ) -> Result<(), C::Error> {
51 let _ = context;
52 let _ = notification;
53 Ok(())
54 }
55}
56
57pub struct StatefulMethodRegistry<C: ServiceContext> {
59 methods: Vec<Box<dyn StatefulJsonRPCMethod<C>>>,
60}
61
62impl<C: ServiceContext> StatefulMethodRegistry<C> {
63 #[must_use]
65 pub fn new() -> Self {
66 Self {
67 methods: Vec::new(),
68 }
69 }
70
71 #[must_use]
73 pub fn register<M>(mut self, method: M) -> Self
74 where
75 M: StatefulJsonRPCMethod<C> + 'static,
76 {
77 tracing::trace!("registering stateful method");
78 self.methods.push(Box::new(method));
79 self
80 }
81
82 pub async fn call(
87 &self,
88 context: &C,
89 method: &str,
90 params: Option<serde_json::Value>,
91 id: Option<crate::RequestId>,
92 ) -> Result<Response, C::Error> {
93 for handler in &self.methods {
95 if handler.method_name() == method {
96 tracing::debug!(method = %method, "calling stateful method");
97 return handler.call(context, params, id).await;
98 }
99 }
100
101 tracing::warn!(method = %method, "stateful method not found");
102 Ok(ResponseBuilder::new()
104 .error(ErrorBuilder::new(error_codes::METHOD_NOT_FOUND, "Method not found").build())
105 .id(id)
106 .build())
107 }
108}
109
110impl<C: ServiceContext> Default for StatefulMethodRegistry<C> {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
116#[async_trait::async_trait]
117impl<C: ServiceContext> StatefulHandler<C> for StatefulMethodRegistry<C> {
118 async fn handle_request(&self, context: &C, request: Request) -> Result<Response, C::Error> {
119 self.call(context, &request.method, request.params, request.id)
120 .await
121 }
122
123 async fn handle_notification(
124 &self,
125 context: &C,
126 notification: crate::Notification,
127 ) -> Result<(), C::Error> {
128 let _ = self
129 .call(context, ¬ification.method, notification.params, None)
130 .await?;
131 Ok(())
132 }
133}
134
135pub struct StatefulProcessor<C: ServiceContext> {
137 context: Arc<C>,
138 handler: Arc<dyn StatefulHandler<C>>,
139}
140
141impl<C: ServiceContext> StatefulProcessor<C> {
142 pub fn new<H>(context: C, handler: H) -> Self
144 where
145 H: StatefulHandler<C> + 'static,
146 {
147 Self {
148 context: Arc::new(context),
149 handler: Arc::new(handler),
150 }
151 }
152
153 pub fn builder(context: C) -> StatefulProcessorBuilder<C> {
155 StatefulProcessorBuilder::new(context)
156 }
157}
158
159#[async_trait::async_trait]
160impl<C: ServiceContext> MessageProcessor for StatefulProcessor<C> {
161 async fn process_message(&self, message: Message) -> Option<Response> {
162 match message {
163 Message::Request(request) => {
164 let request_id = request.id.clone();
165 let correlation_id = request.correlation_id.clone();
166
167 match self.handler.handle_request(&self.context, request).await {
168 Ok(response) => Some(response),
169 Err(error) => {
170 tracing::error!(
172 error = %error,
173 request_id = ?request_id,
174 correlation_id = ?correlation_id,
175 "stateful handler error"
176 );
177
178 let generic_error = crate::Error::from_error_logged(&error);
181
182 Some(
183 ResponseBuilder::new()
184 .error(generic_error)
185 .id(request_id) .correlation_id(correlation_id) .build(),
188 )
189 }
190 }
191 }
192 Message::Notification(notification) => {
193 drop(
194 self.handler
195 .handle_notification(&self.context, notification)
196 .await,
197 );
198 None
199 }
200 Message::Response(_) => None,
201 }
202 }
203}
204
205pub struct StatefulProcessorBuilder<C: ServiceContext> {
207 context: C,
208 handler: Option<Arc<dyn StatefulHandler<C>>>,
209}
210
211impl<C: ServiceContext> StatefulProcessorBuilder<C> {
212 pub fn new(context: C) -> Self {
214 Self {
215 context,
216 handler: None,
217 }
218 }
219
220 #[must_use]
222 pub fn handler<H>(mut self, handler: H) -> Self
223 where
224 H: StatefulHandler<C> + 'static,
225 {
226 self.handler = Some(Arc::new(handler));
227 self
228 }
229
230 #[must_use]
232 pub fn registry(mut self, registry: StatefulMethodRegistry<C>) -> Self {
233 self.handler = Some(Arc::new(registry));
234 self
235 }
236
237 pub fn build(self) -> Result<StatefulProcessor<C>, Box<dyn std::error::Error>> {
242 let handler = self.handler.ok_or("Handler not set")?;
243 Ok(StatefulProcessor {
244 context: Arc::new(self.context),
245 handler,
246 })
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Notification, RequestBuilder};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Error type for the test context; its `Display` output is the wrapped message.
    #[derive(Debug)]
    struct TestError(String);

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    impl std::error::Error for TestError {}

    /// Context holding an atomic counter, so tests can observe that a method actually ran.
    struct TestContext {
        counter: AtomicU32,
    }

    impl ServiceContext for TestContext {
        type Error = TestError;
    }

    impl TestContext {
        fn new() -> Self {
            Self {
                counter: AtomicU32::new(0),
            }
        }

        /// Bumps the counter and returns the value after the increment.
        fn increment(&self) -> u32 {
            self.counter.fetch_add(1, Ordering::SeqCst) + 1
        }

        /// Current counter value.
        fn get_count(&self) -> u32 {
            self.counter.load(Ordering::SeqCst)
        }
    }

    /// Method `"increment"`: bumps the context counter and answers `{"count": <new value>}`.
    struct IncrementMethod;

    #[async_trait::async_trait]
    impl StatefulJsonRPCMethod<TestContext> for IncrementMethod {
        fn method_name(&self) -> &'static str {
            "increment"
        }

        async fn call(
            &self,
            context: &TestContext,
            _params: Option<serde_json::Value>,
            id: Option<crate::RequestId>,
        ) -> Result<Response, TestError> {
            let count = context.increment();
            Ok(ResponseBuilder::new()
                .success(serde_json::json!({"count": count}))
                .id(id)
                .build())
        }
    }

    /// Method `"fail"`: always returns `Err`, used to exercise error paths.
    struct FailingMethod;

    #[async_trait::async_trait]
    impl StatefulJsonRPCMethod<TestContext> for FailingMethod {
        fn method_name(&self) -> &'static str {
            "fail"
        }

        async fn call(
            &self,
            _context: &TestContext,
            _params: Option<serde_json::Value>,
            _id: Option<crate::RequestId>,
        ) -> Result<Response, TestError> {
            Err(TestError("intentional failure".to_string()))
        }
    }

    // A registered method is found by name and mutates the shared context.
    #[tokio::test]
    async fn test_stateful_registry_register_and_call() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new().register(IncrementMethod);

        let result = registry
            .call(&context, "increment", None, Some(serde_json::json!(1)))
            .await
            .unwrap();

        assert!(result.result.is_some());
        assert_eq!(context.get_count(), 1);
    }

    // An unknown method is an `Ok` response carrying a METHOD_NOT_FOUND error, not an `Err`.
    #[tokio::test]
    async fn test_stateful_registry_method_not_found() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::<TestContext>::new();

        let result = registry
            .call(&context, "unknown", None, Some(serde_json::json!(1)))
            .await
            .unwrap();

        assert!(result.error.is_some());
        let error = result.error.unwrap();
        assert_eq!(error.code, error_codes::METHOD_NOT_FOUND);
    }

    // Several methods coexist. A method's `Err` surfaces as `Err` from `call`.
    #[tokio::test]
    async fn test_stateful_registry_multiple_methods() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new()
            .register(IncrementMethod)
            .register(FailingMethod);

        let _ = registry
            .call(&context, "increment", None, Some(serde_json::json!(1)))
            .await;
        let _ = registry
            .call(&context, "increment", None, Some(serde_json::json!(2)))
            .await;
        assert_eq!(context.get_count(), 2);

        let result = registry
            .call(&context, "fail", None, Some(serde_json::json!(3)))
            .await;
        assert!(result.is_err());
    }

    // The registry's `StatefulHandler::handle_request` dispatches by `request.method`.
    #[tokio::test]
    async fn test_stateful_handler_request() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new().register(IncrementMethod);

        let request = RequestBuilder::new("increment")
            .id(serde_json::json!(1))
            .build();

        let result = registry.handle_request(&context, request).await.unwrap();
        assert!(result.result.is_some());
    }

    // Notifications still execute the method (the counter moves) and return `Ok(())`.
    #[tokio::test]
    async fn test_stateful_handler_notification() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new().register(IncrementMethod);

        let notification = Notification {
            jsonrpc: "2.0".to_string(),
            method: "increment".to_string(),
            params: None,
        };

        let result = registry.handle_notification(&context, notification).await;
        assert!(result.is_ok());
        assert_eq!(context.get_count(), 1);
    }

    // A successful request through the processor yields a response with a result.
    #[tokio::test]
    async fn test_stateful_processor_request() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new().register(IncrementMethod);
        let processor = StatefulProcessor::new(context, registry);

        let request = RequestBuilder::new("increment")
            .id(serde_json::json!(1))
            .build();

        let response = processor.process_message(Message::Request(request)).await;
        assert!(response.is_some());
        let response = response.unwrap();
        assert!(response.result.is_some());
    }

    // The processor never produces a response for a notification.
    #[tokio::test]
    async fn test_stateful_processor_notification() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new().register(IncrementMethod);
        let processor = StatefulProcessor::new(context, registry);

        let notification = Notification {
            jsonrpc: "2.0".to_string(),
            method: "increment".to_string(),
            params: None,
        };

        let response = processor
            .process_message(Message::Notification(notification))
            .await;
        assert!(response.is_none());
    }

    // A handler `Err` becomes an error response that keeps the request id.
    #[tokio::test]
    async fn test_stateful_processor_error_handling() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new().register(FailingMethod);
        let processor = StatefulProcessor::new(context, registry);

        let request = RequestBuilder::new("fail").id(serde_json::json!(1)).build();

        let response = processor.process_message(Message::Request(request)).await;
        assert!(response.is_some());
        let response = response.unwrap();
        assert!(response.error.is_some());
        assert_eq!(response.id, Some(serde_json::json!(1)));
    }

    // The error response echoes the request's correlation id.
    #[tokio::test]
    async fn test_stateful_processor_preserves_correlation_id() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new().register(FailingMethod);
        let processor = StatefulProcessor::new(context, registry);

        let correlation_id = uuid::Uuid::new_v4().to_string();
        let request = RequestBuilder::new("fail")
            .id(serde_json::json!(1))
            .correlation_id(correlation_id.clone())
            .build();

        let response = processor
            .process_message(Message::Request(request))
            .await
            .unwrap();
        assert_eq!(response.correlation_id, Some(correlation_id));
    }

    // A processor assembled via the builder behaves like one built with `new`.
    #[tokio::test]
    async fn test_stateful_processor_builder() {
        let context = TestContext::new();
        let registry = StatefulMethodRegistry::new().register(IncrementMethod);

        let processor = StatefulProcessor::builder(context)
            .registry(registry)
            .build()
            .unwrap();

        let request = RequestBuilder::new("increment")
            .id(serde_json::json!(1))
            .build();

        let response = processor.process_message(Message::Request(request)).await;
        assert!(response.is_some());
    }

    // `build` fails when no handler or registry was configured.
    #[tokio::test]
    async fn test_stateful_processor_builder_no_handler() {
        let context = TestContext::new();
        let result = StatefulProcessor::builder(context).build();
        assert!(result.is_err());
    }

    // The default OpenAPI spec is keyed by the method name.
    #[test]
    fn test_stateful_method_openapi_components() {
        let method = IncrementMethod;
        let spec = method.openapi_components();
        assert_eq!(spec.method_name, "increment");
    }

    // `Default` yields an empty registry. The private field is readable from this child module.
    #[test]
    fn test_stateful_registry_default() {
        let registry = StatefulMethodRegistry::<TestContext>::default();
        assert_eq!(registry.methods.len(), 0);
    }
}