1use std::collections::HashMap;
22use std::future::Future;
23use std::marker::PhantomData;
24use std::pin::Pin;
25
26use serde::de::DeserializeOwned;
27
28use super::RequestContext;
29use crate::codec::MsgPackCodec;
30use crate::control::{InitSchema, ResponseType};
31use crate::error::{ProcwireError, Result};
32
33pub type HandlerResult = Result<()>;
35
/// A pinned, heap-allocated, `Send` future producing `T` that may borrow data for `'fut`.
///
/// Used so that object-safe trait methods such as [`Handler::call`] can return futures.
pub type BoxFuture<'fut, T> = Pin<Box<dyn Future<Output = T> + Send + 'fut>>;
38
39pub trait Handler: Send + Sync + 'static {
41 fn call(&self, data: &[u8], ctx: RequestContext) -> BoxFuture<'static, HandlerResult>;
43}
44
/// Adapter that exposes a strongly-typed async callback as a [`Handler`].
///
/// `F` is the user callback, `T` the payload type it accepts, and `Fut` the
/// future it returns. The trait bounds live on the `impl` blocks that need
/// them rather than on the struct, so merely naming the type does not require
/// restating every bound.
pub struct TypedHandler<F, T, Fut> {
    /// The wrapped user callback.
    handler: F,
    // `fn(T) -> Fut` records `T` and `Fut` without owning values of either, so
    // the adapter's `Send`/`Sync` depend only on `F` (fn pointers are always
    // `Send + Sync`).
    _phantom: PhantomData<fn(T) -> Fut>,
}
55
56impl<F, T, Fut> TypedHandler<F, T, Fut>
57where
58 F: Fn(T, RequestContext) -> Fut + Send + Sync + 'static,
59 T: DeserializeOwned + Send + 'static,
60 Fut: Future<Output = HandlerResult> + Send + 'static,
61{
62 pub fn new(handler: F) -> Self {
64 Self {
65 handler,
66 _phantom: PhantomData,
67 }
68 }
69}
70
71impl<F, T, Fut> Handler for TypedHandler<F, T, Fut>
72where
73 F: Fn(T, RequestContext) -> Fut + Send + Sync + 'static,
74 T: DeserializeOwned + Send + 'static,
75 Fut: Future<Output = HandlerResult> + Send + 'static,
76{
77 fn call(&self, data: &[u8], ctx: RequestContext) -> BoxFuture<'static, HandlerResult> {
78 let parsed: T = match MsgPackCodec::decode(data) {
80 Ok(v) => v,
81 Err(e) => return Box::pin(async move { Err(e) }),
82 };
83
84 let fut = (self.handler)(parsed, ctx);
85 Box::pin(fut)
86 }
87}
88
89struct MethodEntry {
91 handler: Box<dyn Handler>,
93 response_type: ResponseType,
95 id: u16,
97}
98
99pub struct HandlerRegistry {
101 methods: HashMap<String, MethodEntry>,
103 events: HashMap<String, u16>,
105 next_method_id: u16,
107 next_event_id: u16,
109 id_to_name: HashMap<u16, String>,
111}
112
113impl HandlerRegistry {
114 pub fn new() -> Self {
116 Self {
117 methods: HashMap::new(),
118 events: HashMap::new(),
119 next_method_id: 1, next_event_id: 1,
121 id_to_name: HashMap::new(),
122 }
123 }
124
125 pub fn register<F, T, Fut>(&mut self, name: &str, response_type: ResponseType, handler: F)
133 where
134 F: Fn(T, RequestContext) -> Fut + Send + Sync + 'static,
135 T: DeserializeOwned + Send + 'static,
136 Fut: Future<Output = HandlerResult> + Send + 'static,
137 {
138 let id = self.next_method_id;
139 self.next_method_id += 1;
140
141 let typed = TypedHandler::new(handler);
142 self.methods.insert(
143 name.to_string(),
144 MethodEntry {
145 handler: Box::new(typed),
146 response_type,
147 id,
148 },
149 );
150 self.id_to_name.insert(id, name.to_string());
151 }
152
153 pub fn register_event(&mut self, name: &str) {
155 let id = self.next_event_id;
156 self.next_event_id += 1;
157 self.events.insert(name.to_string(), id);
158 }
159
160 pub fn get_handler(&self, name: &str) -> Option<&dyn Handler> {
162 self.methods.get(name).map(|e| e.handler.as_ref())
163 }
164
165 pub fn get_handler_by_id(&self, id: u16) -> Option<&dyn Handler> {
167 self.id_to_name
168 .get(&id)
169 .and_then(|name| self.methods.get(name))
170 .map(|e| e.handler.as_ref())
171 }
172
173 pub fn get_method_name(&self, id: u16) -> Option<&str> {
175 self.id_to_name.get(&id).map(|s| s.as_str())
176 }
177
178 pub fn get_method_id(&self, name: &str) -> Option<u16> {
180 self.methods.get(name).map(|e| e.id)
181 }
182
183 pub fn get_event_id(&self, name: &str) -> Option<u16> {
185 self.events.get(name).copied()
186 }
187
188 pub fn get_response_type(&self, name: &str) -> Option<ResponseType> {
190 self.methods.get(name).map(|e| e.response_type)
191 }
192
193 pub fn build_schema(&self) -> InitSchema {
195 let mut schema = InitSchema::new();
196
197 for (name, entry) in &self.methods {
198 schema.add_method(name, entry.id, entry.response_type);
199 }
200
201 for (name, &id) in &self.events {
202 schema.add_event(name, id);
203 }
204
205 schema
206 }
207
208 pub async fn dispatch(
216 &self,
217 method_id: u16,
218 payload: &[u8],
219 ctx: RequestContext,
220 ) -> Result<()> {
221 let handler = self
222 .get_handler_by_id(method_id)
223 .ok_or(ProcwireError::HandlerNotFound(method_id))?;
224
225 handler.call(payload, ctx).await
226 }
227}
228
229impl Default for HandlerRegistry {
230 fn default() -> Self {
231 Self::new()
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers a do-nothing, always-succeeding handler whose payload type is `T`.
    fn register_noop<T>(registry: &mut HandlerRegistry, name: &str, response_type: ResponseType)
    where
        T: DeserializeOwned + Send + 'static,
    {
        registry.register(name, response_type, |_payload: T, _ctx| async { Ok(()) });
    }

    #[test]
    fn test_register_method() {
        let mut registry = HandlerRegistry::new();
        register_noop::<String>(&mut registry, "echo", ResponseType::Result);

        assert!(registry.get_handler("echo").is_some());
        assert_eq!(registry.get_method_id("echo"), Some(1));
        assert_eq!(registry.get_method_name(1), Some("echo"));
    }

    #[test]
    fn test_id_assignment_sequential() {
        let mut registry = HandlerRegistry::new();
        let specs = [
            ("method1", ResponseType::Result),
            ("method2", ResponseType::Stream),
            ("method3", ResponseType::Ack),
        ];
        for (name, response_type) in specs.iter().copied() {
            register_noop::<()>(&mut registry, name, response_type);
        }

        assert_eq!(registry.get_method_id("method1"), Some(1));
        assert_eq!(registry.get_method_id("method2"), Some(2));
        assert_eq!(registry.get_method_id("method3"), Some(3));
    }

    #[test]
    fn test_register_event() {
        let mut registry = HandlerRegistry::new();
        for event in ["progress", "status"].iter() {
            registry.register_event(event);
        }

        assert_eq!(registry.get_event_id("progress"), Some(1));
        assert_eq!(registry.get_event_id("status"), Some(2));
    }

    #[test]
    fn test_build_schema() {
        let mut registry = HandlerRegistry::new();
        register_noop::<String>(&mut registry, "echo", ResponseType::Result);
        register_noop::<i32>(&mut registry, "generate", ResponseType::Stream);
        registry.register_event("progress");

        let schema = registry.build_schema();

        let echo = schema.get_method("echo").unwrap();
        assert_eq!(echo.id, 1);
        assert_eq!(echo.response, ResponseType::Result);

        let generate = schema.get_method("generate").unwrap();
        assert_eq!(generate.id, 2);
        assert_eq!(generate.response, ResponseType::Stream);

        assert_eq!(schema.get_event("progress").unwrap().id, 1);
    }

    #[test]
    fn test_handler_not_found() {
        let registry = HandlerRegistry::new();

        assert!(registry.get_handler("nonexistent").is_none());
        assert!(registry.get_handler_by_id(99).is_none());
    }

    #[test]
    fn test_response_type() {
        let mut registry = HandlerRegistry::new();
        register_noop::<()>(&mut registry, "result_method", ResponseType::Result);
        register_noop::<()>(&mut registry, "stream_method", ResponseType::Stream);

        assert_eq!(
            registry.get_response_type("result_method"),
            Some(ResponseType::Result)
        );
        assert_eq!(
            registry.get_response_type("stream_method"),
            Some(ResponseType::Stream)
        );
    }
}