1use serde::de::DeserializeOwned;
4use std::{
5 any::Any,
6 future::Future,
7 ops::Deref,
8 pin::Pin,
9 sync::Arc,
10};
11
/// A type-erased async handler that is invoked with a raw argument string.
///
/// Implementors receive `args` untouched (the argument-taking adapters in this
/// module decode it as JSON) and return a boxed, `Send` future that resolves to
/// `Ok(output)` on success or `Err(message)` on failure. The method takes no
/// generic parameters, so handlers of different concrete types can be stored
/// behind `dyn Handler`.
pub trait Handler: Send + Sync {
    /// Begins handling one call with the given raw `args`.
    ///
    /// The returned future may borrow from `self` (its lifetime is tied to the
    /// `&self` borrow).
    fn call(
        &self,
        args: &str,
    ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>>;
}
20
21pub struct HandlerFn<F, Fut>
23where
24 F: Fn() -> Fut + Send + Sync,
25 Fut: Future<Output = String> + Send,
26{
27 func: F,
28}
29
30impl<F, Fut> HandlerFn<F, Fut>
31where
32 F: Fn() -> Fut + Send + Sync,
33 Fut: Future<Output = String> + Send,
34{
35 pub fn new(func: F) -> Self {
37 Self { func }
38 }
39}
40
41impl<F, Fut> Handler for HandlerFn<F, Fut>
42where
43 F: Fn() -> Fut + Send + Sync + 'static,
44 Fut: Future<Output = String> + Send + 'static,
45{
46 fn call(&self, _args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
47 let fut = (self.func)();
48 Box::pin(async move { Ok(fut.await) })
49 }
50}
51
52pub struct HandlerWithArgs<F, T, Fut>
54where
55 F: Fn(T) -> Fut + Send + Sync,
56 T: DeserializeOwned + Send,
57 Fut: Future<Output = String> + Send,
58{
59 func: F,
60 _marker: std::marker::PhantomData<fn() -> T>,
63}
64
65impl<F, T, Fut> HandlerWithArgs<F, T, Fut>
66where
67 F: Fn(T) -> Fut + Send + Sync,
68 T: DeserializeOwned + Send,
69 Fut: Future<Output = String> + Send,
70{
71 pub fn new(func: F) -> Self {
73 Self {
74 func,
75 _marker: std::marker::PhantomData,
76 }
77 }
78}
79
80impl<F, T, Fut> Handler for HandlerWithArgs<F, T, Fut>
81where
82 F: Fn(T) -> Fut + Send + Sync + 'static,
83 T: DeserializeOwned + Send + 'static,
84 Fut: Future<Output = String> + Send + 'static,
85{
86 fn call(&self, args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
87 let parsed: Result<T, _> = serde_json::from_str(args);
88 match parsed {
89 Ok(value) => {
90 let fut = (self.func)(value);
91 Box::pin(async move { Ok(fut.await) })
92 }
93 Err(e) => Box::pin(async move {
94 Err(format!("Failed to deserialize args: {e}"))
95 }),
96 }
97 }
98}
99
/// Wrapper through which handlers receive shared application state.
///
/// The state-aware adapters pass state as `State<Arc<S>>`. The wrapper
/// dereferences to its inner value, so fields of the state can be reached
/// through auto-deref (e.g. `state.field`).
#[derive(Clone, Debug)]
pub struct State<S>(pub S);

impl<S> Deref for State<S> {
    type Target = S;

    /// Borrows the wrapped value.
    fn deref(&self) -> &S {
        let State(inner) = self;
        inner
    }
}
113
114pub struct HandlerWithState<F, S, T, Fut>
116where
117 F: Fn(State<Arc<S>>, T) -> Fut + Send + Sync,
118 S: Send + Sync + 'static,
119 T: DeserializeOwned + Send,
120 Fut: Future<Output = String> + Send,
121{
122 func: F,
123 state: Arc<dyn Any + Send + Sync>,
124 _marker: std::marker::PhantomData<(fn() -> S, fn() -> T)>,
126}
127
128impl<F, S, T, Fut> HandlerWithState<F, S, T, Fut>
129where
130 F: Fn(State<Arc<S>>, T) -> Fut + Send + Sync,
131 S: Send + Sync + 'static,
132 T: DeserializeOwned + Send,
133 Fut: Future<Output = String> + Send,
134{
135 pub fn new(func: F, state: Arc<dyn Any + Send + Sync>) -> Self {
137 Self {
138 func,
139 state,
140 _marker: std::marker::PhantomData,
141 }
142 }
143}
144
145impl<F, S, T, Fut> Handler for HandlerWithState<F, S, T, Fut>
146where
147 F: Fn(State<Arc<S>>, T) -> Fut + Send + Sync + 'static,
148 S: Send + Sync + 'static,
149 T: DeserializeOwned + Send + 'static,
150 Fut: Future<Output = String> + Send + 'static,
151{
152 fn call(&self, args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
153 let state_arc = match self.state.clone().downcast::<S>() {
154 Ok(s) => s,
155 Err(_) => {
156 let msg = format!(
157 "State type mismatch: expected {}",
158 std::any::type_name::<S>()
159 );
160 return Box::pin(async move { Err(msg) });
161 }
162 };
163
164 let parsed: Result<T, _> = serde_json::from_str(args);
165 match parsed {
166 Ok(value) => {
167 let fut = (self.func)(State(state_arc), value);
168 Box::pin(async move { Ok(fut.await) })
169 }
170 Err(e) => Box::pin(async move {
171 Err(format!("Failed to deserialize args: {e}"))
172 }),
173 }
174 }
175}
176
177pub struct HandlerWithStateOnly<F, S, Fut>
179where
180 F: Fn(State<Arc<S>>) -> Fut + Send + Sync,
181 S: Send + Sync + 'static,
182 Fut: Future<Output = String> + Send,
183{
184 func: F,
185 state: Arc<dyn Any + Send + Sync>,
186 _marker: std::marker::PhantomData<fn() -> S>,
187}
188
189impl<F, S, Fut> HandlerWithStateOnly<F, S, Fut>
190where
191 F: Fn(State<Arc<S>>) -> Fut + Send + Sync,
192 S: Send + Sync + 'static,
193 Fut: Future<Output = String> + Send,
194{
195 pub fn new(func: F, state: Arc<dyn Any + Send + Sync>) -> Self {
197 Self {
198 func,
199 state,
200 _marker: std::marker::PhantomData,
201 }
202 }
203}
204
205impl<F, S, Fut> Handler for HandlerWithStateOnly<F, S, Fut>
206where
207 F: Fn(State<Arc<S>>) -> Fut + Send + Sync + 'static,
208 S: Send + Sync + 'static,
209 Fut: Future<Output = String> + Send + 'static,
210{
211 fn call(&self, _args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
212 let state_arc = match self.state.clone().downcast::<S>() {
213 Ok(s) => s,
214 Err(_) => {
215 let msg = format!(
216 "State type mismatch: expected {}",
217 std::any::type_name::<S>()
218 );
219 return Box::pin(async move { Err(msg) });
220 }
221 };
222
223 let fut = (self.func)(State(state_arc));
224 Box::pin(async move { Ok(fut.await) })
225 }
226}
227
#[cfg(test)]
mod tests {
    // Unit tests for the handler adapters. The tests exercise success paths,
    // argument-decoding failures, state type mismatches and trait-object use.
    use super::*;

    #[tokio::test]
    async fn test_handler_fn() {
        let handler = HandlerFn::new(|| async { "test".to_string() });
        let result = handler.call("{}").await;
        assert_eq!(result, Ok("test".to_string()));
    }

    #[tokio::test]
    async fn test_handler_fn_ignores_args() {
        let handler = HandlerFn::new(|| async { "no-args".to_string() });
        let result = handler.call(r#"{"unexpected": true}"#).await;
        assert_eq!(result, Ok("no-args".to_string()));
    }

    #[tokio::test]
    async fn test_handler_with_args() {
        #[derive(serde::Deserialize)]
        struct Input {
            name: String,
        }

        let handler = HandlerWithArgs::new(|args: Input| async move {
            format!("hello {}", args.name)
        });

        let result = handler.call(r#"{"name":"Alice"}"#).await;
        assert_eq!(result, Ok("hello Alice".to_string()));
    }

    #[tokio::test]
    async fn test_handler_with_args_bad_json() {
        #[derive(serde::Deserialize)]
        struct Input {
            _name: String,
        }

        let handler = HandlerWithArgs::new(|_args: Input| async move {
            "unreachable".to_string()
        });

        let result = handler.call("not-json").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Failed to deserialize args"));
    }

    #[tokio::test]
    async fn test_handler_with_args_missing_field() {
        #[derive(serde::Deserialize)]
        struct Input {
            _name: String,
        }

        let handler = HandlerWithArgs::new(|_args: Input| async move {
            "unreachable".to_string()
        });

        let result = handler.call(r#"{"age": 30}"#).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Failed to deserialize args"));
    }

    #[tokio::test]
    async fn test_handler_with_state() {
        struct AppState {
            greeting: String,
        }

        #[derive(serde::Deserialize)]
        struct Input {
            name: String,
        }

        let state: Arc<dyn Any + Send + Sync> = Arc::new(AppState {
            greeting: "Hi".to_string(),
        });

        let handler = HandlerWithState::new(
            |state: State<Arc<AppState>>, args: Input| async move {
                format!("{} {}", state.greeting, args.name)
            },
            state,
        );

        let result = handler.call(r#"{"name":"Bob"}"#).await;
        assert_eq!(result, Ok("Hi Bob".to_string()));
    }

    #[tokio::test]
    async fn test_handler_with_state_only() {
        struct AppState {
            value: i32,
        }

        let state: Arc<dyn Any + Send + Sync> = Arc::new(AppState { value: 42 });

        let handler = HandlerWithStateOnly::new(
            |state: State<Arc<AppState>>| async move {
                format!("value={}", state.value)
            },
            state,
        );

        let result = handler.call("{}").await;
        assert_eq!(result, Ok("value=42".to_string()));
    }

    #[tokio::test]
    async fn test_handler_with_state_deser_error() {
        struct AppState;

        #[derive(serde::Deserialize)]
        struct Input {
            _x: i32,
        }

        let state: Arc<dyn Any + Send + Sync> = Arc::new(AppState);

        let handler = HandlerWithState::new(
            |_state: State<Arc<AppState>>, _args: Input| async move {
                "unreachable".to_string()
            },
            state,
        );

        let result = handler.call("bad").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Failed to deserialize args"));
    }

    #[tokio::test]
    async fn test_handler_with_state_type_mismatch() {
        // The stored state is an `i32`, but the handler asks for a `String`.
        let state: Arc<dyn Any + Send + Sync> = Arc::new(7_i32);

        let handler = HandlerWithState::new(
            |_state: State<Arc<String>>, _n: u32| async move { "unreachable".to_string() },
            state,
        );

        // The state check runs before argument decoding, so even malformed
        // args surface the state error.
        let err = handler.call("not-json").await.unwrap_err();
        assert!(err.starts_with("State type mismatch: expected"));
        assert!(err.contains("String"));
    }

    #[tokio::test]
    async fn test_handler_with_state_only_type_mismatch() {
        let state: Arc<dyn Any + Send + Sync> = Arc::new(7_i32);

        let handler = HandlerWithStateOnly::new(
            |_state: State<Arc<String>>| async move { "unreachable".to_string() },
            state,
        );

        let err = handler.call("{}").await.unwrap_err();
        assert!(err.starts_with("State type mismatch: expected"));
        assert!(err.contains("String"));
    }

    #[tokio::test]
    async fn test_handlers_behind_trait_objects() {
        let handlers: Vec<Box<dyn Handler>> = vec![
            Box::new(HandlerFn::new(|| async { "fixed".to_string() })),
            Box::new(HandlerWithArgs::new(|n: u32| async move { format!("n={n}") })),
        ];

        let mut outputs = Vec::with_capacity(handlers.len());
        for handler in &handlers {
            outputs.push(handler.call("5").await);
        }
        assert_eq!(outputs, vec![Ok("fixed".to_string()), Ok("n=5".to_string())]);
    }

    #[test]
    fn test_state_derefs_to_inner_value() {
        let state = State(Arc::new(3_i32));
        assert_eq!(**state, 3);
        assert_eq!(format!("{:?}", state.clone()), "State(3)");
    }
}