1use serde::de::DeserializeOwned;
4use serde::Serialize;
5use std::{
6 any::{Any, TypeId},
7 collections::HashMap,
8 fmt,
9 future::Future,
10 ops::Deref,
11 pin::Pin,
12 sync::{Arc, RwLock},
13};
14use tokio::sync::mpsc;
15use tokio_util::sync::CancellationToken;
16
/// Type-keyed registry of shared application state.
///
/// Handlers look values up by the [`TypeId`] of the requested type and downcast the stored
/// `Arc<dyn Any + Send + Sync>` back to that concrete type (see `resolve_state`). The outer
/// `Arc<RwLock<..>>` lets every handler hold the same map and take read locks concurrently.
pub type SharedStateMap = Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>;
23
24fn resolve_state<S: Send + Sync + 'static>(
26 states: &SharedStateMap,
27) -> Result<Arc<S>, String> {
28 let map = states.read().map_err(|e| format!("State lock poisoned: {e}"))?;
29 let any = map
30 .get(&TypeId::of::<S>())
31 .ok_or_else(|| {
32 format!(
33 "State not found: {}. Was with_state::<{0}>() or inject_state::<{0}>() called?",
34 std::any::type_name::<S>()
35 )
36 })?
37 .clone();
38 any.downcast::<S>().map_err(|_| {
39 format!("State type mismatch: expected {}", std::any::type_name::<S>())
40 })
41}
42
/// Converts a handler's return value into the final call result.
///
/// `Ok` carries the response text; `Err` carries an error message.
pub trait IntoHandlerResult: Send {
    /// Consumes `self` and produces either the response text or an error message.
    fn into_handler_result(self) -> Result<String, String>;
}

/// A `String` is returned verbatim as the response text; it is not JSON-encoded.
impl IntoHandlerResult for String {
    fn into_handler_result(self) -> Result<String, String> {
        Ok(self)
    }
}
62
/// Marker wrapper: the inner value is serialized with `serde_json` when returned from a
/// handler or sent as a stream item.
///
/// Derives the same common traits as the sibling `State` wrapper so wrapped values can be
/// debugged, cloned, and compared (each derive applies only when `T` supports it).
#[derive(Debug, Clone, PartialEq)]
pub struct Json<T>(pub T);
68
69impl<T: Serialize + Send> IntoHandlerResult for Json<T> {
70 fn into_handler_result(self) -> Result<String, String> {
71 serde_json::to_string(&self.0)
72 .map_err(|e| format!("Failed to serialize response: {e}"))
73 }
74}
75
76impl<T: Serialize + Send, E: fmt::Display + Send> IntoHandlerResult for Result<T, E> {
78 fn into_handler_result(self) -> Result<String, String> {
79 match self {
80 Ok(value) => serde_json::to_string(&value)
81 .map_err(|e| format!("Failed to serialize response: {e}")),
82 Err(e) => Err(e.to_string()),
83 }
84 }
85}
86
/// Converts a value into the string form that is queued on a stream channel.
///
/// `Err` carries an error message describing why the item could not be produced.
pub trait IntoStreamItem: Send {
    /// Consumes `self` and produces the item text or an error message.
    fn into_stream_item(self) -> Result<String, String>;
}

/// A `String` is sent verbatim; it is not JSON-encoded.
impl IntoStreamItem for String {
    fn into_stream_item(self) -> Result<String, String> {
        Ok(self)
    }
}
103
104impl<T: Serialize + Send> IntoStreamItem for Json<T> {
106 fn into_stream_item(self) -> Result<String, String> {
107 serde_json::to_string(&self.0)
108 .map_err(|e| format!("Failed to serialize stream item: {e}"))
109 }
110}
111
112impl<T: Serialize + Send, E: fmt::Display + Send> IntoStreamItem for Result<T, E> {
114 fn into_stream_item(self) -> Result<String, String> {
115 match self {
116 Ok(value) => serde_json::to_string(&value)
117 .map_err(|e| format!("Failed to serialize stream item: {e}")),
118 Err(e) => Err(e.to_string()),
119 }
120 }
121}
122
/// Error returned by `StreamSender::send`.
///
/// Derives `Eq` in addition to `PartialEq`: the only payload is a `String`, so equality is
/// total and the type can be used wherever `Eq` is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamError {
    /// The receiving half is gone, so the item could not be delivered.
    Closed,
    /// Converting the item to its string form failed; carries that conversion's message.
    Serialize(String),
}

impl fmt::Display for StreamError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            StreamError::Closed => f.write_str("stream closed: receiver dropped"),
            StreamError::Serialize(e) => write!(f, "stream serialization error: {e}"),
        }
    }
}

impl std::error::Error for StreamError {}
144
/// Buffer size used by `StreamSender::channel`: how many serialized items can be queued
/// before `StreamSender::send` has to wait for the receiver to make room.
pub const DEFAULT_STREAM_CAPACITY: usize = 64;
149
/// Producing half of a handler output stream.
///
/// Items are converted to `String` before being queued on a bounded tokio channel. Cloning
/// yields another producer for the same channel that shares the same cancellation token.
#[derive(Clone)]
pub struct StreamSender {
    // Bounded channel carrying already-serialized items to the receiver.
    tx: mpsc::Sender<String>,
    // Token shared by every sender clone and by the receiver (which cancels it on drop).
    cancel: CancellationToken,
}
182
183pub struct StreamReceiver {
189 rx: mpsc::Receiver<String>,
190 cancel: CancellationToken,
191}
192
193impl StreamReceiver {
194 pub async fn recv(&mut self) -> Option<String> {
196 self.rx.recv().await
197 }
198
199}
200
201impl Drop for StreamReceiver {
202 fn drop(&mut self) {
203 self.cancel.cancel();
204 }
205}
206
207impl fmt::Debug for StreamReceiver {
208 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209 f.debug_struct("StreamReceiver")
210 .field("cancelled", &self.cancel.is_cancelled())
211 .finish()
212 }
213}
214
215impl StreamSender {
216 pub fn channel() -> (Self, StreamReceiver) {
221 Self::with_capacity(DEFAULT_STREAM_CAPACITY)
222 }
223
224 pub fn with_capacity(capacity: usize) -> (Self, StreamReceiver) {
228 let (tx, rx) = mpsc::channel(capacity);
229 let cancel = CancellationToken::new();
230 (
231 Self { tx, cancel: cancel.clone() },
232 StreamReceiver { rx, cancel },
233 )
234 }
235
236 pub fn cancellation_token(&self) -> CancellationToken {
249 self.cancel.clone()
250 }
251
252 pub fn cancel(&self) {
257 self.cancel.cancel();
258 }
259
260 pub async fn send(&self, item: impl IntoStreamItem) -> Result<(), StreamError> {
266 let serialized = item.into_stream_item().map_err(StreamError::Serialize)?;
267 self.tx
268 .send(serialized)
269 .await
270 .map_err(|_| StreamError::Closed)
271 }
272
273 pub fn is_closed(&self) -> bool {
282 self.tx.is_closed()
283 }
284}
285
286impl fmt::Debug for StreamSender {
287 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288 f.debug_struct("StreamSender")
289 .field("closed", &self.is_closed())
290 .field("cancelled", &self.cancel.is_cancelled())
291 .finish()
292 }
293}
294
/// Object-safe, type-erased request handler.
///
/// Implementations in this module either parse `args` as JSON into their argument type or
/// ignore it entirely.
pub trait Handler: Send + Sync {
    /// Invokes the handler with the raw argument text.
    ///
    /// The returned future yields the response text on success or an error message on
    /// failure (including argument-parsing and state-lookup failures).
    fn call(&self, args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>>;
}
305
/// Wrapper that hands shared application state to a handler.
///
/// Handlers in this module receive it as `State<Arc<S>>`. Through `Deref` (and auto-deref
/// through the `Arc`), the handler body can use the wrapped value's fields and methods
/// directly.
#[derive(Debug, Clone)]
pub struct State<S>(pub S);

impl<S> Deref for State<S> {
    type Target = S;

    fn deref(&self) -> &S {
        let State(inner) = self;
        inner
    }
}
321
322pub struct HandlerFn<F, Fut, R>
326where
327 F: Fn() -> Fut + Send + Sync,
328 Fut: Future<Output = R> + Send,
329 R: IntoHandlerResult,
330{
331 func: F,
332 _marker: std::marker::PhantomData<fn() -> R>,
333}
334
335impl<F, Fut, R> HandlerFn<F, Fut, R>
336where
337 F: Fn() -> Fut + Send + Sync,
338 Fut: Future<Output = R> + Send,
339 R: IntoHandlerResult,
340{
341 pub fn new(func: F) -> Self {
343 Self {
344 func,
345 _marker: std::marker::PhantomData,
346 }
347 }
348}
349
350impl<F, Fut, R> Handler for HandlerFn<F, Fut, R>
351where
352 F: Fn() -> Fut + Send + Sync + 'static,
353 Fut: Future<Output = R> + Send + 'static,
354 R: IntoHandlerResult + 'static,
355{
356 fn call(&self, _args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
357 let fut = (self.func)();
358 Box::pin(async move { fut.await.into_handler_result() })
359 }
360}
361
362#[allow(clippy::type_complexity)]
364pub struct HandlerWithArgs<F, T, Fut, R>
365where
366 F: Fn(T) -> Fut + Send + Sync,
367 T: DeserializeOwned + Send,
368 Fut: Future<Output = R> + Send,
369 R: IntoHandlerResult,
370{
371 func: F,
372 _marker: std::marker::PhantomData<(fn() -> T, fn() -> R)>,
375}
376
377impl<F, T, Fut, R> HandlerWithArgs<F, T, Fut, R>
378where
379 F: Fn(T) -> Fut + Send + Sync,
380 T: DeserializeOwned + Send,
381 Fut: Future<Output = R> + Send,
382 R: IntoHandlerResult,
383{
384 pub fn new(func: F) -> Self {
386 Self {
387 func,
388 _marker: std::marker::PhantomData,
389 }
390 }
391}
392
393impl<F, T, Fut, R> Handler for HandlerWithArgs<F, T, Fut, R>
394where
395 F: Fn(T) -> Fut + Send + Sync + 'static,
396 T: DeserializeOwned + Send + 'static,
397 Fut: Future<Output = R> + Send + 'static,
398 R: IntoHandlerResult + 'static,
399{
400 fn call(&self, args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
401 let parsed: Result<T, _> = serde_json::from_str(args);
402 match parsed {
403 Ok(value) => {
404 let fut = (self.func)(value);
405 Box::pin(async move { fut.await.into_handler_result() })
406 }
407 Err(e) => Box::pin(async move {
408 Err(format!("Failed to deserialize args: {e}"))
409 }),
410 }
411 }
412}
413
414#[allow(clippy::type_complexity)]
416pub struct HandlerWithState<F, S, T, Fut, R>
417where
418 F: Fn(State<Arc<S>>, T) -> Fut + Send + Sync,
419 S: Send + Sync + 'static,
420 T: DeserializeOwned + Send,
421 Fut: Future<Output = R> + Send,
422 R: IntoHandlerResult,
423{
424 func: F,
425 states: SharedStateMap,
426 _marker: std::marker::PhantomData<(fn() -> S, fn() -> T, fn() -> R)>,
427}
428
429impl<F, S, T, Fut, R> HandlerWithState<F, S, T, Fut, R>
430where
431 F: Fn(State<Arc<S>>, T) -> Fut + Send + Sync,
432 S: Send + Sync + 'static,
433 T: DeserializeOwned + Send,
434 Fut: Future<Output = R> + Send,
435 R: IntoHandlerResult,
436{
437 pub fn new(func: F, states: SharedStateMap) -> Self {
439 Self {
440 func,
441 states,
442 _marker: std::marker::PhantomData,
443 }
444 }
445}
446
447impl<F, S, T, Fut, R> Handler for HandlerWithState<F, S, T, Fut, R>
448where
449 F: Fn(State<Arc<S>>, T) -> Fut + Send + Sync + 'static,
450 S: Send + Sync + 'static,
451 T: DeserializeOwned + Send + 'static,
452 Fut: Future<Output = R> + Send + 'static,
453 R: IntoHandlerResult + 'static,
454{
455 fn call(&self, args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
456 let state_arc = match resolve_state::<S>(&self.states) {
457 Ok(s) => s,
458 Err(msg) => return Box::pin(async move { Err(msg) }),
459 };
460
461 let parsed: Result<T, _> = serde_json::from_str(args);
462 match parsed {
463 Ok(value) => {
464 let fut = (self.func)(State(state_arc), value);
465 Box::pin(async move { fut.await.into_handler_result() })
466 }
467 Err(e) => Box::pin(async move {
468 Err(format!("Failed to deserialize args: {e}"))
469 }),
470 }
471 }
472}
473
474#[allow(clippy::type_complexity)]
476pub struct HandlerWithStateOnly<F, S, Fut, R>
477where
478 F: Fn(State<Arc<S>>) -> Fut + Send + Sync,
479 S: Send + Sync + 'static,
480 Fut: Future<Output = R> + Send,
481 R: IntoHandlerResult,
482{
483 func: F,
484 states: SharedStateMap,
485 _marker: std::marker::PhantomData<(fn() -> S, fn() -> R)>,
486}
487
488impl<F, S, Fut, R> HandlerWithStateOnly<F, S, Fut, R>
489where
490 F: Fn(State<Arc<S>>) -> Fut + Send + Sync,
491 S: Send + Sync + 'static,
492 Fut: Future<Output = R> + Send,
493 R: IntoHandlerResult,
494{
495 pub fn new(func: F, states: SharedStateMap) -> Self {
497 Self {
498 func,
499 states,
500 _marker: std::marker::PhantomData,
501 }
502 }
503}
504
505impl<F, S, Fut, R> Handler for HandlerWithStateOnly<F, S, Fut, R>
506where
507 F: Fn(State<Arc<S>>) -> Fut + Send + Sync + 'static,
508 S: Send + Sync + 'static,
509 Fut: Future<Output = R> + Send + 'static,
510 R: IntoHandlerResult + 'static,
511{
512 fn call(&self, _args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
513 let state_arc = match resolve_state::<S>(&self.states) {
514 Ok(s) => s,
515 Err(msg) => return Box::pin(async move { Err(msg) }),
516 };
517
518 let fut = (self.func)(State(state_arc));
519 Box::pin(async move { fut.await.into_handler_result() })
520 }
521}
522
/// Object-safe, type-erased handler that can emit intermediate items while it runs.
///
/// Implementations in this module either parse `args` as JSON into their argument type or
/// ignore it entirely.
pub trait StreamHandler: Send + Sync {
    /// Invokes the handler with the raw argument text and the sending half of a stream.
    ///
    /// Intermediate items go through `tx`. The returned future yields the final response
    /// text, or an error message (including argument-parsing and state-lookup failures).
    /// When the call fails before the user function runs, `tx` is dropped unused.
    fn call_streaming(
        &self,
        args: &str,
        tx: StreamSender,
    ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>>;
}
540
541pub struct StreamingHandlerFn<F, Fut, R>
545where
546 F: Fn(StreamSender) -> Fut + Send + Sync,
547 Fut: Future<Output = R> + Send,
548 R: IntoHandlerResult,
549{
550 func: F,
551 _marker: std::marker::PhantomData<fn() -> R>,
552}
553
554impl<F, Fut, R> StreamingHandlerFn<F, Fut, R>
555where
556 F: Fn(StreamSender) -> Fut + Send + Sync,
557 Fut: Future<Output = R> + Send,
558 R: IntoHandlerResult,
559{
560 pub fn new(func: F) -> Self {
562 Self {
563 func,
564 _marker: std::marker::PhantomData,
565 }
566 }
567}
568
569impl<F, Fut, R> StreamHandler for StreamingHandlerFn<F, Fut, R>
570where
571 F: Fn(StreamSender) -> Fut + Send + Sync + 'static,
572 Fut: Future<Output = R> + Send + 'static,
573 R: IntoHandlerResult + 'static,
574{
575 fn call_streaming(
576 &self,
577 _args: &str,
578 tx: StreamSender,
579 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
580 let fut = (self.func)(tx);
581 Box::pin(async move { fut.await.into_handler_result() })
582 }
583}
584
585#[allow(clippy::type_complexity)]
587pub struct StreamingHandlerWithArgs<F, T, Fut, R>
588where
589 F: Fn(T, StreamSender) -> Fut + Send + Sync,
590 T: DeserializeOwned + Send,
591 Fut: Future<Output = R> + Send,
592 R: IntoHandlerResult,
593{
594 func: F,
595 _marker: std::marker::PhantomData<(fn() -> T, fn() -> R)>,
596}
597
598impl<F, T, Fut, R> StreamingHandlerWithArgs<F, T, Fut, R>
599where
600 F: Fn(T, StreamSender) -> Fut + Send + Sync,
601 T: DeserializeOwned + Send,
602 Fut: Future<Output = R> + Send,
603 R: IntoHandlerResult,
604{
605 pub fn new(func: F) -> Self {
607 Self {
608 func,
609 _marker: std::marker::PhantomData,
610 }
611 }
612}
613
614impl<F, T, Fut, R> StreamHandler for StreamingHandlerWithArgs<F, T, Fut, R>
615where
616 F: Fn(T, StreamSender) -> Fut + Send + Sync + 'static,
617 T: DeserializeOwned + Send + 'static,
618 Fut: Future<Output = R> + Send + 'static,
619 R: IntoHandlerResult + 'static,
620{
621 fn call_streaming(
622 &self,
623 args: &str,
624 tx: StreamSender,
625 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
626 let parsed: Result<T, _> = serde_json::from_str(args);
627 match parsed {
628 Ok(value) => {
629 let fut = (self.func)(value, tx);
630 Box::pin(async move { fut.await.into_handler_result() })
631 }
632 Err(e) => Box::pin(async move {
633 Err(format!("Failed to deserialize args: {e}"))
634 }),
635 }
636 }
637}
638
639#[allow(clippy::type_complexity)]
641pub struct StreamingHandlerWithState<F, S, T, Fut, R>
642where
643 F: Fn(State<Arc<S>>, T, StreamSender) -> Fut + Send + Sync,
644 S: Send + Sync + 'static,
645 T: DeserializeOwned + Send,
646 Fut: Future<Output = R> + Send,
647 R: IntoHandlerResult,
648{
649 func: F,
650 states: SharedStateMap,
651 _marker: std::marker::PhantomData<(fn() -> S, fn() -> T, fn() -> R)>,
652}
653
654impl<F, S, T, Fut, R> StreamingHandlerWithState<F, S, T, Fut, R>
655where
656 F: Fn(State<Arc<S>>, T, StreamSender) -> Fut + Send + Sync,
657 S: Send + Sync + 'static,
658 T: DeserializeOwned + Send,
659 Fut: Future<Output = R> + Send,
660 R: IntoHandlerResult,
661{
662 pub fn new(func: F, states: SharedStateMap) -> Self {
664 Self {
665 func,
666 states,
667 _marker: std::marker::PhantomData,
668 }
669 }
670}
671
672impl<F, S, T, Fut, R> StreamHandler for StreamingHandlerWithState<F, S, T, Fut, R>
673where
674 F: Fn(State<Arc<S>>, T, StreamSender) -> Fut + Send + Sync + 'static,
675 S: Send + Sync + 'static,
676 T: DeserializeOwned + Send + 'static,
677 Fut: Future<Output = R> + Send + 'static,
678 R: IntoHandlerResult + 'static,
679{
680 fn call_streaming(
681 &self,
682 args: &str,
683 tx: StreamSender,
684 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
685 let state_arc = match resolve_state::<S>(&self.states) {
686 Ok(s) => s,
687 Err(msg) => return Box::pin(async move { Err(msg) }),
688 };
689
690 let parsed: Result<T, _> = serde_json::from_str(args);
691 match parsed {
692 Ok(value) => {
693 let fut = (self.func)(State(state_arc), value, tx);
694 Box::pin(async move { fut.await.into_handler_result() })
695 }
696 Err(e) => Box::pin(async move {
697 Err(format!("Failed to deserialize args: {e}"))
698 }),
699 }
700 }
701}
702
703#[allow(clippy::type_complexity)]
705pub struct StreamingHandlerWithStateOnly<F, S, Fut, R>
706where
707 F: Fn(State<Arc<S>>, StreamSender) -> Fut + Send + Sync,
708 S: Send + Sync + 'static,
709 Fut: Future<Output = R> + Send,
710 R: IntoHandlerResult,
711{
712 func: F,
713 states: SharedStateMap,
714 _marker: std::marker::PhantomData<(fn() -> S, fn() -> R)>,
715}
716
717impl<F, S, Fut, R> StreamingHandlerWithStateOnly<F, S, Fut, R>
718where
719 F: Fn(State<Arc<S>>, StreamSender) -> Fut + Send + Sync,
720 S: Send + Sync + 'static,
721 Fut: Future<Output = R> + Send,
722 R: IntoHandlerResult,
723{
724 pub fn new(func: F, states: SharedStateMap) -> Self {
726 Self {
727 func,
728 states,
729 _marker: std::marker::PhantomData,
730 }
731 }
732}
733
734impl<F, S, Fut, R> StreamHandler for StreamingHandlerWithStateOnly<F, S, Fut, R>
735where
736 F: Fn(State<Arc<S>>, StreamSender) -> Fut + Send + Sync + 'static,
737 S: Send + Sync + 'static,
738 Fut: Future<Output = R> + Send + 'static,
739 R: IntoHandlerResult + 'static,
740{
741 fn call_streaming(
742 &self,
743 _args: &str,
744 tx: StreamSender,
745 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
746 let state_arc = match resolve_state::<S>(&self.states) {
747 Ok(s) => s,
748 Err(msg) => return Box::pin(async move { Err(msg) }),
749 };
750
751 let fut = (self.func)(State(state_arc), tx);
752 Box::pin(async move { fut.await.into_handler_result() })
753 }
754}
755
756#[cfg(test)]
757mod tests {
758 use super::*;
759
760 fn state_map<S: Send + Sync + 'static>(value: S) -> SharedStateMap {
762 let mut map = HashMap::new();
763 map.insert(TypeId::of::<S>(), Arc::new(value) as Arc<dyn Any + Send + Sync>);
764 Arc::new(RwLock::new(map))
765 }
766
767 #[tokio::test]
770 async fn test_handler_fn() {
771 let handler = HandlerFn::new(|| async { "test".to_string() });
772 let result = handler.call("{}").await;
773 assert_eq!(result, Ok("test".to_string()));
774 }
775
776 #[tokio::test]
777 async fn test_handler_fn_ignores_args() {
778 let handler = HandlerFn::new(|| async { "no-args".to_string() });
779 let result = handler.call(r#"{"unexpected": true}"#).await;
780 assert_eq!(result, Ok("no-args".to_string()));
781 }
782
783 #[tokio::test]
784 async fn test_handler_with_args() {
785 #[derive(serde::Deserialize)]
786 struct Input {
787 name: String,
788 }
789
790 let handler = HandlerWithArgs::new(|args: Input| async move {
791 format!("hello {}", args.name)
792 });
793
794 let result = handler.call(r#"{"name":"Alice"}"#).await;
795 assert_eq!(result, Ok("hello Alice".to_string()));
796 }
797
798 #[tokio::test]
799 async fn test_handler_with_args_bad_json() {
800 #[derive(serde::Deserialize)]
801 struct Input {
802 _name: String,
803 }
804
805 let handler = HandlerWithArgs::new(|_args: Input| async move {
806 "unreachable".to_string()
807 });
808
809 let result = handler.call("not-json").await;
810 assert!(result.is_err());
811 assert!(result.unwrap_err().contains("Failed to deserialize args"));
812 }
813
814 #[tokio::test]
815 async fn test_handler_with_args_missing_field() {
816 #[derive(serde::Deserialize)]
817 struct Input {
818 _name: String,
819 }
820
821 let handler = HandlerWithArgs::new(|_args: Input| async move {
822 "unreachable".to_string()
823 });
824
825 let result = handler.call(r#"{"age": 30}"#).await;
826 assert!(result.is_err());
827 assert!(result.unwrap_err().contains("Failed to deserialize args"));
828 }
829
830 #[tokio::test]
831 async fn test_handler_with_state() {
832 struct AppState {
833 greeting: String,
834 }
835
836 #[derive(serde::Deserialize)]
837 struct Input {
838 name: String,
839 }
840
841 let states = state_map(AppState {
842 greeting: "Hi".to_string(),
843 });
844
845 let handler = HandlerWithState::new(
846 |state: State<Arc<AppState>>, args: Input| async move {
847 format!("{} {}", state.greeting, args.name)
848 },
849 states,
850 );
851
852 let result = handler.call(r#"{"name":"Bob"}"#).await;
853 assert_eq!(result, Ok("Hi Bob".to_string()));
854 }
855
856 #[tokio::test]
857 async fn test_handler_with_state_only() {
858 struct AppState {
859 value: i32,
860 }
861
862 let states = state_map(AppState { value: 42 });
863
864 let handler = HandlerWithStateOnly::new(
865 |state: State<Arc<AppState>>| async move {
866 format!("value={}", state.value)
867 },
868 states,
869 );
870
871 let result = handler.call("{}").await;
872 assert_eq!(result, Ok("value=42".to_string()));
873 }
874
875 #[tokio::test]
876 async fn test_handler_with_state_deser_error() {
877 struct AppState;
878
879 #[derive(serde::Deserialize)]
880 struct Input {
881 _x: i32,
882 }
883
884 let states = state_map(AppState);
885
886 let handler = HandlerWithState::new(
887 |_state: State<Arc<AppState>>, _args: Input| async move {
888 "unreachable".to_string()
889 },
890 states,
891 );
892
893 let result = handler.call("bad").await;
894 assert!(result.is_err());
895 assert!(result.unwrap_err().contains("Failed to deserialize args"));
896 }
897
898 #[tokio::test]
901 async fn test_json_handler_fn_struct() {
902 #[derive(serde::Serialize)]
903 struct User {
904 id: u32,
905 name: String,
906 }
907
908 let handler = HandlerFn::new(|| async {
909 Json(User {
910 id: 1,
911 name: "Alice".to_string(),
912 })
913 });
914
915 let result = handler.call("{}").await;
916 assert_eq!(result, Ok(r#"{"id":1,"name":"Alice"}"#.to_string()));
917 }
918
919 #[tokio::test]
920 async fn test_json_handler_fn_vec() {
921 let handler = HandlerFn::new(|| async { Json(vec![1, 2, 3]) });
922 let result = handler.call("{}").await;
923 assert_eq!(result, Ok("[1,2,3]".to_string()));
924 }
925
926 #[tokio::test]
927 async fn test_json_handler_with_args() {
928 #[derive(serde::Deserialize)]
929 struct Input {
930 name: String,
931 }
932
933 #[derive(serde::Serialize)]
934 struct Output {
935 greeting: String,
936 }
937
938 let handler = HandlerWithArgs::new(|args: Input| async move {
939 Json(Output {
940 greeting: format!("Hello {}", args.name),
941 })
942 });
943
944 let result = handler.call(r#"{"name":"Bob"}"#).await;
945 assert_eq!(result, Ok(r#"{"greeting":"Hello Bob"}"#.to_string()));
946 }
947
948 #[tokio::test]
949 async fn test_json_handler_with_args_bad_json() {
950 #[derive(serde::Deserialize)]
951 struct Input {
952 _x: i32,
953 }
954
955 let handler = HandlerWithArgs::new(|_: Input| async move { Json(42) });
956 let result = handler.call("bad").await;
957 assert!(result.is_err());
958 assert!(result.unwrap_err().contains("Failed to deserialize args"));
959 }
960
961 #[tokio::test]
962 async fn test_json_handler_with_state() {
963 struct AppState {
964 prefix: String,
965 }
966
967 #[derive(serde::Deserialize)]
968 struct Input {
969 name: String,
970 }
971
972 #[derive(serde::Serialize)]
973 struct Output {
974 message: String,
975 }
976
977 let states = state_map(AppState {
978 prefix: "Hi".to_string(),
979 });
980
981 let handler = HandlerWithState::new(
982 |state: State<Arc<AppState>>, args: Input| async move {
983 Json(Output {
984 message: format!("{} {}", state.prefix, args.name),
985 })
986 },
987 states,
988 );
989
990 let result = handler.call(r#"{"name":"Charlie"}"#).await;
991 assert_eq!(result, Ok(r#"{"message":"Hi Charlie"}"#.to_string()));
992 }
993
994 #[tokio::test]
995 async fn test_json_handler_with_state_only() {
996 struct AppState {
997 version: String,
998 }
999
1000 #[derive(serde::Serialize)]
1001 struct Info {
1002 version: String,
1003 }
1004
1005 let states = state_map(AppState {
1006 version: "1.0".to_string(),
1007 });
1008
1009 let handler = HandlerWithStateOnly::new(
1010 |state: State<Arc<AppState>>| async move {
1011 Json(Info {
1012 version: state.version.clone(),
1013 })
1014 },
1015 states,
1016 );
1017
1018 let result = handler.call("{}").await;
1019 assert_eq!(result, Ok(r#"{"version":"1.0"}"#.to_string()));
1020 }
1021
1022 #[tokio::test]
1025 async fn test_result_handler_fn_ok() {
1026 #[derive(serde::Serialize)]
1027 struct Data {
1028 value: i32,
1029 }
1030
1031 let handler = HandlerFn::new(|| async {
1032 Ok::<_, String>(Data { value: 42 })
1033 });
1034
1035 let result = handler.call("{}").await;
1036 assert_eq!(result, Ok(r#"{"value":42}"#.to_string()));
1037 }
1038
1039 #[tokio::test]
1040 async fn test_result_handler_fn_err() {
1041 #[derive(serde::Serialize)]
1042 struct Data {
1043 value: i32,
1044 }
1045
1046 let handler = HandlerFn::new(|| async {
1047 Err::<Data, String>("something went wrong".to_string())
1048 });
1049
1050 let result = handler.call("{}").await;
1051 assert_eq!(result, Err("something went wrong".to_string()));
1052 }
1053
1054 #[tokio::test]
1055 async fn test_result_handler_with_args_ok() {
1056 #[derive(serde::Deserialize)]
1057 struct Input {
1058 x: i32,
1059 }
1060
1061 #[derive(serde::Serialize)]
1062 struct Output {
1063 doubled: i32,
1064 }
1065
1066 let handler = HandlerWithArgs::new(|args: Input| async move {
1067 Ok::<_, String>(Output { doubled: args.x * 2 })
1068 });
1069
1070 let result = handler.call(r#"{"x":21}"#).await;
1071 assert_eq!(result, Ok(r#"{"doubled":42}"#.to_string()));
1072 }
1073
1074 #[tokio::test]
1075 async fn test_result_handler_with_args_err() {
1076 #[derive(serde::Deserialize)]
1077 struct Input {
1078 x: i32,
1079 }
1080
1081 let handler = HandlerWithArgs::new(|args: Input| async move {
1082 if args.x < 0 {
1083 Err::<i32, String>("negative".to_string())
1084 } else {
1085 Ok(args.x)
1086 }
1087 });
1088
1089 let result = handler.call(r#"{"x":-1}"#).await;
1090 assert_eq!(result, Err("negative".to_string()));
1091 }
1092
1093 #[tokio::test]
1094 async fn test_result_handler_with_state() {
1095 struct AppState {
1096 threshold: i32,
1097 }
1098
1099 #[derive(serde::Deserialize)]
1100 struct Input {
1101 value: i32,
1102 }
1103
1104 #[derive(serde::Serialize)]
1105 struct Output {
1106 accepted: bool,
1107 }
1108
1109 let states = state_map(AppState { threshold: 10 });
1110
1111 let handler = HandlerWithState::new(
1112 |state: State<Arc<AppState>>, args: Input| async move {
1113 if args.value >= state.threshold {
1114 Ok::<_, String>(Output { accepted: true })
1115 } else {
1116 Err("below threshold".to_string())
1117 }
1118 },
1119 states,
1120 );
1121
1122 let ok_result = handler.call(r#"{"value":15}"#).await;
1123 assert_eq!(ok_result, Ok(r#"{"accepted":true}"#.to_string()));
1124
1125 let err_result = handler.call(r#"{"value":5}"#).await;
1126 assert_eq!(err_result, Err("below threshold".to_string()));
1127 }
1128
1129 #[tokio::test]
1130 async fn test_result_handler_with_state_only() {
1131 struct AppState {
1132 ready: bool,
1133 }
1134
1135 #[derive(serde::Serialize)]
1136 struct Status {
1137 ok: bool,
1138 }
1139
1140 let states = state_map(AppState { ready: true });
1141
1142 let handler = HandlerWithStateOnly::new(
1143 |state: State<Arc<AppState>>| async move {
1144 if state.ready {
1145 Ok::<_, String>(Status { ok: true })
1146 } else {
1147 Err("not ready".to_string())
1148 }
1149 },
1150 states,
1151 );
1152
1153 let result = handler.call("{}").await;
1154 assert_eq!(result, Ok(r#"{"ok":true}"#.to_string()));
1155 }
1156
1157 #[test]
1160 fn test_into_stream_item_string() {
1161 let item = "hello".to_string();
1162 assert_eq!(item.into_stream_item(), Ok("hello".to_string()));
1163 }
1164
1165 #[test]
1166 fn test_into_stream_item_json() {
1167 #[derive(serde::Serialize)]
1168 struct Token {
1169 text: String,
1170 }
1171 let item = Json(Token {
1172 text: "hi".to_string(),
1173 });
1174 assert_eq!(
1175 item.into_stream_item(),
1176 Ok(r#"{"text":"hi"}"#.to_string())
1177 );
1178 }
1179
1180 #[test]
1181 fn test_into_stream_item_json_vec() {
1182 let item = Json(vec![1, 2, 3]);
1183 assert_eq!(item.into_stream_item(), Ok("[1,2,3]".to_string()));
1184 }
1185
1186 #[test]
1187 fn test_into_stream_item_result_ok() {
1188 #[derive(serde::Serialize)]
1189 struct Data {
1190 v: i32,
1191 }
1192 let item: Result<Data, String> = Ok(Data { v: 42 });
1193 assert_eq!(item.into_stream_item(), Ok(r#"{"v":42}"#.to_string()));
1194 }
1195
1196 #[test]
1197 fn test_into_stream_item_result_err() {
1198 let item: Result<i32, String> = Err("bad".to_string());
1199 assert_eq!(item.into_stream_item(), Err("bad".to_string()));
1200 }
1201
1202 #[test]
1205 fn test_stream_error_display_closed() {
1206 let err = StreamError::Closed;
1207 assert_eq!(err.to_string(), "stream closed: receiver dropped");
1208 }
1209
1210 #[test]
1211 fn test_stream_error_display_serialize() {
1212 let err = StreamError::Serialize("bad json".to_string());
1213 assert_eq!(err.to_string(), "stream serialization error: bad json");
1214 }
1215
1216 #[test]
1217 fn test_stream_error_is_std_error() {
1218 let err: Box<dyn std::error::Error> = Box::new(StreamError::Closed);
1219 assert!(err.to_string().contains("closed"));
1220 }
1221
1222 #[tokio::test]
1225 async fn test_stream_sender_send_and_receive() {
1226 let (tx, mut rx) = StreamSender::channel();
1227 tx.send("hello".to_string()).await.unwrap();
1228 tx.send("world".to_string()).await.unwrap();
1229 drop(tx);
1230
1231 assert_eq!(rx.recv().await, Some("hello".to_string()));
1232 assert_eq!(rx.recv().await, Some("world".to_string()));
1233 assert_eq!(rx.recv().await, None);
1234 }
1235
1236 #[tokio::test]
1237 async fn test_stream_sender_send_json() {
1238 #[derive(serde::Serialize)]
1239 struct Token {
1240 t: String,
1241 }
1242 let (tx, mut rx) = StreamSender::channel();
1243 tx.send(Json(Token {
1244 t: "hi".to_string(),
1245 }))
1246 .await
1247 .unwrap();
1248 drop(tx);
1249
1250 assert_eq!(rx.recv().await, Some(r#"{"t":"hi"}"#.to_string()));
1251 }
1252
1253 #[tokio::test]
1254 async fn test_stream_sender_closed_detection() {
1255 let (tx, rx) = StreamSender::channel();
1256 assert!(!tx.is_closed());
1257 drop(rx);
1258 assert!(tx.is_closed());
1259 }
1260
1261 #[tokio::test]
1262 async fn test_stream_sender_send_after_close() {
1263 let (tx, rx) = StreamSender::channel();
1264 drop(rx);
1265 let result = tx.send("late".to_string()).await;
1266 assert_eq!(result, Err(StreamError::Closed));
1267 }
1268
1269 #[tokio::test]
1270 async fn test_stream_sender_custom_capacity() {
1271 let (tx, mut rx) = StreamSender::with_capacity(2);
1272
1273 tx.send("a".to_string()).await.unwrap();
1275 tx.send("b".to_string()).await.unwrap();
1276
1277 assert_eq!(rx.recv().await, Some("a".to_string()));
1279 assert_eq!(rx.recv().await, Some("b".to_string()));
1280
1281 tx.send("c".to_string()).await.unwrap();
1283 assert_eq!(rx.recv().await, Some("c".to_string()));
1284 }
1285
1286 #[tokio::test]
1287 async fn test_stream_sender_default_capacity() {
1288 assert_eq!(DEFAULT_STREAM_CAPACITY, 64);
1289 }
1290
1291 #[tokio::test]
1292 async fn test_stream_sender_clone() {
1293 let (tx, mut rx) = StreamSender::channel();
1294 let tx2 = tx.clone();
1295
1296 tx.send("from-tx1".to_string()).await.unwrap();
1297 tx2.send("from-tx2".to_string()).await.unwrap();
1298 drop(tx);
1299 drop(tx2);
1300
1301 assert_eq!(rx.recv().await, Some("from-tx1".to_string()));
1302 assert_eq!(rx.recv().await, Some("from-tx2".to_string()));
1303 assert_eq!(rx.recv().await, None);
1304 }
1305
1306 #[test]
1307 fn test_stream_sender_debug() {
1308 let (tx, _rx) = StreamSender::channel();
1309 let debug = format!("{:?}", tx);
1310 assert!(debug.contains("StreamSender"));
1311 }
1312
1313 #[tokio::test]
1316 async fn test_cancellation_token_not_cancelled_initially() {
1317 let (tx, _rx) = StreamSender::channel();
1318 let token = tx.cancellation_token();
1319 assert!(!token.is_cancelled());
1320 }
1321
1322 #[tokio::test]
1323 async fn test_cancellation_token_cancelled_on_explicit_cancel() {
1324 let (tx, _rx) = StreamSender::channel();
1325 let token = tx.cancellation_token();
1326 assert!(!token.is_cancelled());
1327 tx.cancel();
1328 assert!(token.is_cancelled());
1329 }
1330
1331 #[tokio::test]
1332 async fn test_cancellation_token_cancelled_future_resolves() {
1333 let (tx, _rx) = StreamSender::channel();
1334 let token = tx.cancellation_token();
1335
1336 let tx2 = tx.clone();
1338 tokio::spawn(async move {
1339 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1340 tx2.cancel();
1341 });
1342
1343 tokio::time::timeout(std::time::Duration::from_secs(1), token.cancelled())
1345 .await
1346 .expect("cancelled future should resolve");
1347 }
1348
1349 #[tokio::test]
1350 async fn test_cancellation_token_shared_across_clones() {
1351 let (tx, _rx) = StreamSender::channel();
1352 let token1 = tx.cancellation_token();
1353 let token2 = tx.cancellation_token();
1354 let tx2 = tx.clone();
1355 let token3 = tx2.cancellation_token();
1356
1357 tx.cancel();
1358 assert!(token1.is_cancelled());
1359 assert!(token2.is_cancelled());
1360 assert!(token3.is_cancelled());
1361 }
1362
1363 #[tokio::test]
1364 async fn test_cancellation_token_auto_cancelled_on_receiver_drop() {
1365 let (tx, rx) = StreamSender::channel();
1366 let token = tx.cancellation_token();
1367
1368 assert!(!token.is_cancelled());
1369 drop(rx); assert!(token.is_cancelled());
1371 }
1372
1373 #[tokio::test]
1374 async fn test_cancellation_token_auto_cancel_future_resolves_on_drop() {
1375 let (tx, rx) = StreamSender::channel();
1376 let token = tx.cancellation_token();
1377
1378 tokio::spawn(async move {
1379 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1380 drop(rx);
1381 });
1382
1383 tokio::time::timeout(std::time::Duration::from_secs(1), token.cancelled())
1384 .await
1385 .expect("cancelled future should resolve when receiver is dropped");
1386 }
1387
1388 #[tokio::test]
1391 async fn test_streaming_handler_fn() {
1392 let handler = StreamingHandlerFn::new(|tx: StreamSender| async move {
1393 tx.send("item1".to_string()).await.ok();
1394 tx.send("item2".to_string()).await.ok();
1395 "done".to_string()
1396 });
1397
1398 let (tx, mut rx) = StreamSender::channel();
1399 let result = handler.call_streaming("{}", tx).await;
1400
1401 assert_eq!(result, Ok("done".to_string()));
1402 assert_eq!(rx.recv().await, Some("item1".to_string()));
1403 assert_eq!(rx.recv().await, Some("item2".to_string()));
1404 }
1405
1406 #[tokio::test]
1407 async fn test_streaming_handler_with_args() {
1408 #[derive(serde::Deserialize)]
1409 struct Input {
1410 count: usize,
1411 }
1412
1413 let handler =
1414 StreamingHandlerWithArgs::new(|args: Input, tx: StreamSender| async move {
1415 for i in 0..args.count {
1416 tx.send(format!("item-{i}")).await.ok();
1417 }
1418 format!("sent {}", args.count)
1419 });
1420
1421 let (tx, mut rx) = StreamSender::channel();
1422 let result = handler.call_streaming(r#"{"count":3}"#, tx).await;
1423
1424 assert_eq!(result, Ok("sent 3".to_string()));
1425 assert_eq!(rx.recv().await, Some("item-0".to_string()));
1426 assert_eq!(rx.recv().await, Some("item-1".to_string()));
1427 assert_eq!(rx.recv().await, Some("item-2".to_string()));
1428 }
1429
1430 #[tokio::test]
1431 async fn test_streaming_handler_with_args_bad_json() {
1432 #[derive(serde::Deserialize)]
1433 struct Input {
1434 _x: i32,
1435 }
1436
1437 let handler =
1438 StreamingHandlerWithArgs::new(|_args: Input, _tx: StreamSender| async move {
1439 "unreachable".to_string()
1440 });
1441
1442 let (tx, _rx) = StreamSender::channel();
1443 let result = handler.call_streaming("bad-json", tx).await;
1444 assert!(result.is_err());
1445 assert!(result.unwrap_err().contains("Failed to deserialize args"));
1446 }
1447
1448 #[tokio::test]
1449 async fn test_streaming_handler_with_state() {
1450 struct AppState {
1451 prefix: String,
1452 }
1453
1454 #[derive(serde::Deserialize)]
1455 struct Input {
1456 name: String,
1457 }
1458
1459 let states = state_map(AppState {
1460 prefix: "Hi".to_string(),
1461 });
1462
1463 let handler = StreamingHandlerWithState::new(
1464 |state: State<Arc<AppState>>, args: Input, tx: StreamSender| async move {
1465 tx.send(format!("{} {}", state.prefix, args.name))
1466 .await
1467 .ok();
1468 "done".to_string()
1469 },
1470 states,
1471 );
1472
1473 let (tx, mut rx) = StreamSender::channel();
1474 let result = handler.call_streaming(r#"{"name":"Alice"}"#, tx).await;
1475
1476 assert_eq!(result, Ok("done".to_string()));
1477 assert_eq!(rx.recv().await, Some("Hi Alice".to_string()));
1478 }
1479
1480 #[tokio::test]
1481 async fn test_streaming_handler_with_state_only() {
1482 struct AppState {
1483 items: Vec<String>,
1484 }
1485
1486 let states = state_map(AppState {
1487 items: vec!["a".to_string(), "b".to_string()],
1488 });
1489
1490 let handler = StreamingHandlerWithStateOnly::new(
1491 |state: State<Arc<AppState>>, tx: StreamSender| async move {
1492 for item in &state.items {
1493 tx.send(item.clone()).await.ok();
1494 }
1495 format!("sent {}", state.items.len())
1496 },
1497 states,
1498 );
1499
1500 let (tx, mut rx) = StreamSender::channel();
1501 let result = handler.call_streaming("{}", tx).await;
1502
1503 assert_eq!(result, Ok("sent 2".to_string()));
1504 assert_eq!(rx.recv().await, Some("a".to_string()));
1505 assert_eq!(rx.recv().await, Some("b".to_string()));
1506 }
1507
1508 #[tokio::test]
1509 async fn test_streaming_handler_with_state_type_mismatch() {
1510 struct WrongState;
1511 struct AppState;
1512
1513 let states = state_map(WrongState);
1514
1515 let handler = StreamingHandlerWithStateOnly::new(
1516 |_state: State<Arc<AppState>>, _tx: StreamSender| async move {
1517 "unreachable".to_string()
1518 },
1519 states,
1520 );
1521
1522 let (tx, _rx) = StreamSender::channel();
1523 let result = handler.call_streaming("{}", tx).await;
1524 assert!(result.is_err());
1525 assert!(result.unwrap_err().contains("State not found"));
1526 }
1527
1528 #[tokio::test]
1529 async fn test_streaming_handler_json_return() {
1530 #[derive(serde::Serialize)]
1531 struct Summary {
1532 count: usize,
1533 }
1534
1535 let handler = StreamingHandlerFn::new(|tx: StreamSender| async move {
1536 tx.send("item".to_string()).await.ok();
1537 Json(Summary { count: 1 })
1538 });
1539
1540 let (tx, mut rx) = StreamSender::channel();
1541 let result = handler.call_streaming("{}", tx).await;
1542
1543 assert_eq!(result, Ok(r#"{"count":1}"#.to_string()));
1544 assert_eq!(rx.recv().await, Some("item".to_string()));
1545 }
1546
1547 #[tokio::test]
1548 async fn test_streaming_handler_result_return() {
1549 let handler = StreamingHandlerFn::new(|tx: StreamSender| async move {
1550 tx.send("progress".to_string()).await.ok();
1551 Ok::<_, String>(42)
1552 });
1553
1554 let (tx, mut rx) = StreamSender::channel();
1555 let result = handler.call_streaming("{}", tx).await;
1556
1557 assert_eq!(result, Ok("42".to_string()));
1558 assert_eq!(rx.recv().await, Some("progress".to_string()));
1559 }
1560}