Skip to main content

procwire_client/handler/
registry.rs

1//! Handler registry for dispatching requests by method ID.
2//!
//! The registry maps method names to handlers and assigns their numeric IDs.
//! Method IDs and event IDs come from separate counters; each is assigned
//! sequentially starting from 1 (0 is reserved).
5//!
6//! # Example
7//!
8//! ```ignore
9//! use procwire_client::handler::{HandlerRegistry, RequestContext};
10//! use procwire_client::control::ResponseType;
11//!
12//! let mut registry = HandlerRegistry::new();
13//!
14//! registry.register("echo", ResponseType::Result, |data: String, ctx| async move {
15//!     ctx.respond(&data).await
16//! });
17//!
18//! let schema = registry.build_schema();
19//! ```
20
21use std::collections::HashMap;
22use std::future::Future;
23use std::marker::PhantomData;
24use std::pin::Pin;
25
26use serde::de::DeserializeOwned;
27
28use super::RequestContext;
29use crate::codec::MsgPackCodec;
30use crate::control::{InitSchema, ResponseType};
31use crate::error::{ProcwireError, Result};
32
33/// Result type for handler functions.
34pub type HandlerResult = Result<()>;
35
36/// Boxed future for handler results.
37pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
38
39/// Trait for handler functions.
40pub trait Handler: Send + Sync + 'static {
41    /// Handle a request with raw payload bytes.
42    fn call(&self, data: &[u8], ctx: RequestContext) -> BoxFuture<'static, HandlerResult>;
43}
44
/// Adapter that turns a typed async function into a [`Handler`].
///
/// The wrapped function receives an already-deserialized request value of type `T`
/// plus the [`RequestContext`]; the `Handler` impl decodes the raw payload into `T`
/// with MsgPack before invoking it.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the struct,
/// so the type can be named without repeating them.
pub struct TypedHandler<F, T, Fut> {
    /// The user-supplied handler function.
    handler: F,
    /// Records `T` and `Fut` without storing values of those types. A `fn(T) -> Fut`
    /// pointer type is always `Send + Sync`, so this marker never restricts the
    /// auto traits of `TypedHandler`.
    _phantom: PhantomData<fn(T) -> Fut>,
}
55
56impl<F, T, Fut> TypedHandler<F, T, Fut>
57where
58    F: Fn(T, RequestContext) -> Fut + Send + Sync + 'static,
59    T: DeserializeOwned + Send + 'static,
60    Fut: Future<Output = HandlerResult> + Send + 'static,
61{
62    /// Create a new typed handler.
63    pub fn new(handler: F) -> Self {
64        Self {
65            handler,
66            _phantom: PhantomData,
67        }
68    }
69}
70
71impl<F, T, Fut> Handler for TypedHandler<F, T, Fut>
72where
73    F: Fn(T, RequestContext) -> Fut + Send + Sync + 'static,
74    T: DeserializeOwned + Send + 'static,
75    Fut: Future<Output = HandlerResult> + Send + 'static,
76{
77    fn call(&self, data: &[u8], ctx: RequestContext) -> BoxFuture<'static, HandlerResult> {
78        // Deserialize payload using MsgPack
79        let parsed: T = match MsgPackCodec::decode(data) {
80            Ok(v) => v,
81            Err(e) => return Box::pin(async move { Err(e) }),
82        };
83
84        let fut = (self.handler)(parsed, ctx);
85        Box::pin(fut)
86    }
87}
88
89/// Entry for a registered method.
90struct MethodEntry {
91    /// The handler function.
92    handler: Box<dyn Handler>,
93    /// Expected response type.
94    response_type: ResponseType,
95    /// Assigned method ID.
96    id: u16,
97}
98
99/// Registry mapping method names to handlers.
100pub struct HandlerRegistry {
101    /// Methods by name.
102    methods: HashMap<String, MethodEntry>,
103    /// Events by name (just track IDs, no handlers).
104    events: HashMap<String, u16>,
105    /// Next method ID to assign.
106    next_method_id: u16,
107    /// Next event ID to assign.
108    next_event_id: u16,
109    /// Method ID to name mapping (for dispatch).
110    id_to_name: HashMap<u16, String>,
111}
112
113impl HandlerRegistry {
114    /// Create a new empty registry.
115    pub fn new() -> Self {
116        Self {
117            methods: HashMap::new(),
118            events: HashMap::new(),
119            next_method_id: 1, // Start from 1, 0 is reserved
120            next_event_id: 1,
121            id_to_name: HashMap::new(),
122        }
123    }
124
125    /// Register a method handler.
126    ///
127    /// # Arguments
128    ///
129    /// * `name` - Method name
130    /// * `response_type` - Expected response type
131    /// * `handler` - Handler function that takes (T, RequestContext) and returns Result<()>
132    pub fn register<F, T, Fut>(&mut self, name: &str, response_type: ResponseType, handler: F)
133    where
134        F: Fn(T, RequestContext) -> Fut + Send + Sync + 'static,
135        T: DeserializeOwned + Send + 'static,
136        Fut: Future<Output = HandlerResult> + Send + 'static,
137    {
138        let id = self.next_method_id;
139        self.next_method_id += 1;
140
141        let typed = TypedHandler::new(handler);
142        self.methods.insert(
143            name.to_string(),
144            MethodEntry {
145                handler: Box::new(typed),
146                response_type,
147                id,
148            },
149        );
150        self.id_to_name.insert(id, name.to_string());
151    }
152
153    /// Register an event (no handler, just ID assignment).
154    pub fn register_event(&mut self, name: &str) {
155        let id = self.next_event_id;
156        self.next_event_id += 1;
157        self.events.insert(name.to_string(), id);
158    }
159
160    /// Get a handler by method name.
161    pub fn get_handler(&self, name: &str) -> Option<&dyn Handler> {
162        self.methods.get(name).map(|e| e.handler.as_ref())
163    }
164
165    /// Get a handler by method ID.
166    pub fn get_handler_by_id(&self, id: u16) -> Option<&dyn Handler> {
167        self.id_to_name
168            .get(&id)
169            .and_then(|name| self.methods.get(name))
170            .map(|e| e.handler.as_ref())
171    }
172
173    /// Get method name by ID.
174    pub fn get_method_name(&self, id: u16) -> Option<&str> {
175        self.id_to_name.get(&id).map(|s| s.as_str())
176    }
177
178    /// Get method ID by name.
179    pub fn get_method_id(&self, name: &str) -> Option<u16> {
180        self.methods.get(name).map(|e| e.id)
181    }
182
183    /// Get event ID by name.
184    pub fn get_event_id(&self, name: &str) -> Option<u16> {
185        self.events.get(name).copied()
186    }
187
188    /// Get response type for a method.
189    pub fn get_response_type(&self, name: &str) -> Option<ResponseType> {
190        self.methods.get(name).map(|e| e.response_type)
191    }
192
193    /// Build an InitSchema from the registered methods and events.
194    pub fn build_schema(&self) -> InitSchema {
195        let mut schema = InitSchema::new();
196
197        for (name, entry) in &self.methods {
198            schema.add_method(name, entry.id, entry.response_type);
199        }
200
201        for (name, &id) in &self.events {
202            schema.add_event(name, id);
203        }
204
205        schema
206    }
207
208    /// Dispatch a request to the appropriate handler.
209    ///
210    /// # Arguments
211    ///
212    /// * `method_id` - Method ID from frame header
213    /// * `payload` - Raw payload bytes
214    /// * `ctx` - Request context for responding
215    pub async fn dispatch(
216        &self,
217        method_id: u16,
218        payload: &[u8],
219        ctx: RequestContext,
220    ) -> Result<()> {
221        let handler = self
222            .get_handler_by_id(method_id)
223            .ok_or(ProcwireError::HandlerNotFound(method_id))?;
224
225        handler.call(payload, ctx).await
226    }
227}
228
229impl Default for HandlerRegistry {
230    fn default() -> Self {
231        Self::new()
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a registry holding a no-op, `()`-payload method for each `(name, response)` pair,
    /// registered in slice order.
    fn registry_with_unit_methods(specs: &[(&str, ResponseType)]) -> HandlerRegistry {
        let mut reg = HandlerRegistry::new();
        for &(name, response) in specs {
            reg.register(name, response, |_: (), _ctx| async { Ok(()) });
        }
        reg
    }

    #[test]
    fn test_register_method() {
        let mut reg = HandlerRegistry::new();
        reg.register("echo", ResponseType::Result, |_payload: String, _ctx| async {
            Ok(())
        });

        assert!(reg.get_handler("echo").is_some());
        assert_eq!(reg.get_method_id("echo"), Some(1));
        assert_eq!(reg.get_method_name(1), Some("echo"));
    }

    #[test]
    fn test_id_assignment_sequential() {
        let reg = registry_with_unit_methods(&[
            ("method1", ResponseType::Result),
            ("method2", ResponseType::Stream),
            ("method3", ResponseType::Ack),
        ]);

        for (name, expected) in [("method1", 1), ("method2", 2), ("method3", 3)] {
            assert_eq!(reg.get_method_id(name), Some(expected));
        }
    }

    #[test]
    fn test_register_event() {
        let mut reg = HandlerRegistry::new();
        for event in ["progress", "status"] {
            reg.register_event(event);
        }

        assert_eq!(reg.get_event_id("progress"), Some(1));
        assert_eq!(reg.get_event_id("status"), Some(2));
    }

    #[test]
    fn test_build_schema() {
        let mut reg = HandlerRegistry::new();
        reg.register("echo", ResponseType::Result, |_: String, _ctx| async { Ok(()) });
        reg.register("generate", ResponseType::Stream, |_: i32, _ctx| async { Ok(()) });
        reg.register_event("progress");

        let schema = reg.build_schema();

        let echo = schema.get_method("echo").unwrap();
        assert_eq!(echo.id, 1);
        assert_eq!(echo.response, ResponseType::Result);

        let generate = schema.get_method("generate").unwrap();
        assert_eq!(generate.id, 2);
        assert_eq!(generate.response, ResponseType::Stream);

        assert_eq!(schema.get_event("progress").unwrap().id, 1);
    }

    #[test]
    fn test_handler_not_found() {
        let reg = HandlerRegistry::new();

        assert!(reg.get_handler("nonexistent").is_none());
        assert!(reg.get_handler_by_id(99).is_none());
    }

    #[test]
    fn test_response_type() {
        let reg = registry_with_unit_methods(&[
            ("result_method", ResponseType::Result),
            ("stream_method", ResponseType::Stream),
        ]);

        assert_eq!(
            reg.get_response_type("result_method"),
            Some(ResponseType::Result)
        );
        assert_eq!(
            reg.get_response_type("stream_method"),
            Some(ResponseType::Stream)
        );
    }
}