1use serde::de::DeserializeOwned;
4use serde::Serialize;
5use std::{
6 any::{Any, TypeId},
7 collections::HashMap,
8 fmt,
9 future::Future,
10 ops::Deref,
11 pin::Pin,
12 sync::{Arc, RwLock},
13};
14use tokio::sync::mpsc;
15use tokio_util::sync::CancellationToken;
16
/// Type-erased registry of shared state values, keyed by the [`TypeId`] of
/// each stored value.
pub type SharedStateMap = Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>;

/// Fetches the state value of type `S` from `states`.
///
/// # Errors
///
/// Returns a human-readable message when the lock is poisoned, when no entry
/// is registered under `TypeId::of::<S>()`, or when the registered entry is
/// not actually an `S`.
pub fn resolve_state<S: Send + Sync + 'static>(
    states: &SharedStateMap,
) -> Result<Arc<S>, String> {
    let type_name = std::any::type_name::<S>();
    let erased = {
        let guard = match states.read() {
            Ok(guard) => guard,
            Err(poisoned) => return Err(format!("State lock poisoned: {poisoned}")),
        };
        match guard.get(&TypeId::of::<S>()) {
            Some(entry) => Arc::clone(entry),
            None => {
                return Err(format!(
                    "State not found: {type_name}. Was with_state::<{type_name}>() or inject_state::<{type_name}>() called?"
                ))
            }
        }
    };
    // Only reachable if something stored a value under a TypeId that does
    // not match its concrete type.
    erased
        .downcast::<S>()
        .map_err(|_| format!("State type mismatch: expected {type_name}"))
}

/// Non-generic counterpart of [`resolve_state`]: looks an entry up by an
/// explicit [`TypeId`] and returns it without downcasting.
///
/// `type_name` is used only to build the error message.
///
/// # Errors
///
/// Returns a message when the lock is poisoned or no entry exists for `type_id`.
pub fn resolve_state_erased(
    states: &SharedStateMap,
    type_id: TypeId,
    type_name: &str,
) -> Result<Arc<dyn Any + Send + Sync>, String> {
    let guard = states
        .read()
        .map_err(|poisoned| format!("State lock poisoned: {poisoned}"))?;
    if let Some(entry) = guard.get(&type_id) {
        return Ok(Arc::clone(entry));
    }
    Err(format!(
        "State not found: {type_name}. Was with_state::<{type_name}>() or inject_state::<{type_name}>() called?"
    ))
}
70
/// Converts a handler's return value into the result handed back to the caller.
///
/// On success the `String` is the response body; on failure it is the error
/// message.
pub trait IntoHandlerResult: Send {
    /// Performs the conversion, consuming `self`.
    fn into_handler_result(self) -> Result<String, String>;
}

/// A `String` is already a response body. It is returned verbatim, with no
/// JSON encoding or quoting.
impl IntoHandlerResult for String {
    fn into_handler_result(self) -> Result<String, String> {
        Result::Ok(self)
    }
}
90
/// Wrapper marking a value to be serialized as JSON when it is returned from
/// a handler (`IntoHandlerResult`) or sent on a stream (`IntoStreamItem`).
///
/// The derived traits forward to `T`: for example, `Json<T>` is `Debug`
/// whenever `T` is, so wrapped values can be inspected and compared in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Json<T>(pub T);
96
97impl<T: Serialize + Send> IntoHandlerResult for Json<T> {
98 fn into_handler_result(self) -> Result<String, String> {
99 serde_json::to_string(&self.0)
100 .map_err(|e| format!("Failed to serialize response: {e}"))
101 }
102}
103
104impl<T: Serialize + Send, E: fmt::Display + Send> IntoHandlerResult for Result<T, E> {
106 fn into_handler_result(self) -> Result<String, String> {
107 match self {
108 Ok(value) => serde_json::to_string(&value)
109 .map_err(|e| format!("Failed to serialize response: {e}")),
110 Err(e) => Err(e.to_string()),
111 }
112 }
113}
114
/// Converts a value into the string form that is pushed onto a stream.
///
/// An `Err` carries an error message instead of an item.
pub trait IntoStreamItem: Send {
    /// Performs the conversion, consuming `self`.
    fn into_stream_item(self) -> Result<String, String>;
}

/// A `String` is sent verbatim, with no JSON encoding or quoting.
impl IntoStreamItem for String {
    fn into_stream_item(self) -> Result<String, String> {
        Result::Ok(self)
    }
}
131
132impl<T: Serialize + Send> IntoStreamItem for Json<T> {
134 fn into_stream_item(self) -> Result<String, String> {
135 serde_json::to_string(&self.0)
136 .map_err(|e| format!("Failed to serialize stream item: {e}"))
137 }
138}
139
140impl<T: Serialize + Send, E: fmt::Display + Send> IntoStreamItem for Result<T, E> {
142 fn into_stream_item(self) -> Result<String, String> {
143 match self {
144 Ok(value) => serde_json::to_string(&value)
145 .map_err(|e| format!("Failed to serialize stream item: {e}")),
146 Err(e) => Err(e.to_string()),
147 }
148 }
149}
150
/// Error returned by `StreamSender::send`.
// `Eq` is derived alongside `PartialEq`: both variants hold only `Eq` data,
// so the type can be used wherever total equality is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamError {
    /// The receiving half of the channel is gone, so the item was not delivered.
    Closed,
    /// Converting the item into its wire string failed; holds the conversion
    /// error message.
    Serialize(String),
}

impl fmt::Display for StreamError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            StreamError::Closed => write!(f, "stream closed: receiver dropped"),
            StreamError::Serialize(e) => write!(f, "stream serialization error: {e}"),
        }
    }
}

impl std::error::Error for StreamError {}
172
173pub const DEFAULT_STREAM_CAPACITY: usize = 64;
177
178#[derive(Clone)]
206pub struct StreamSender {
207 tx: mpsc::Sender<String>,
208 cancel: CancellationToken,
209}
210
211pub struct StreamReceiver {
217 rx: mpsc::Receiver<String>,
218 cancel: CancellationToken,
219}
220
221impl StreamReceiver {
222 pub async fn recv(&mut self) -> Option<String> {
224 self.rx.recv().await
225 }
226
227}
228
229impl Drop for StreamReceiver {
230 fn drop(&mut self) {
231 self.cancel.cancel();
232 }
233}
234
235impl fmt::Debug for StreamReceiver {
236 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
237 f.debug_struct("StreamReceiver")
238 .field("cancelled", &self.cancel.is_cancelled())
239 .finish()
240 }
241}
242
243impl StreamSender {
244 pub fn channel() -> (Self, StreamReceiver) {
249 Self::with_capacity(DEFAULT_STREAM_CAPACITY)
250 }
251
252 pub fn with_capacity(capacity: usize) -> (Self, StreamReceiver) {
256 let (tx, rx) = mpsc::channel(capacity);
257 let cancel = CancellationToken::new();
258 (
259 Self { tx, cancel: cancel.clone() },
260 StreamReceiver { rx, cancel },
261 )
262 }
263
264 pub fn cancellation_token(&self) -> CancellationToken {
277 self.cancel.clone()
278 }
279
280 pub fn cancel(&self) {
285 self.cancel.cancel();
286 }
287
288 pub async fn send(&self, item: impl IntoStreamItem) -> Result<(), StreamError> {
294 let serialized = item.into_stream_item().map_err(StreamError::Serialize)?;
295 self.tx
296 .send(serialized)
297 .await
298 .map_err(|_| StreamError::Closed)
299 }
300
301 pub fn is_closed(&self) -> bool {
310 self.tx.is_closed()
311 }
312}
313
314impl fmt::Debug for StreamSender {
315 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
316 f.debug_struct("StreamSender")
317 .field("closed", &self.is_closed())
318 .field("cancelled", &self.cancel.is_cancelled())
319 .finish()
320 }
321}
322
/// An asynchronous, non-streaming handler invoked with a raw argument string.
///
/// Implementations that take arguments parse `args` as JSON.
pub trait Handler: Send + Sync {
    /// Invokes the handler with the raw `args` string.
    ///
    /// The returned future resolves to the response body on success or to an
    /// error message on failure, and may borrow `self`.
    fn call(
        &self,
        args: &str,
    ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>>;
}
333
/// Extractor wrapper that hands a handler its shared state.
///
/// Implements `Deref` to the wrapped value, so `state.field` reaches through
/// to `S`.
#[derive(Debug, Clone)]
pub struct State<S>(pub S);

impl<S> Deref for State<S> {
    type Target = S;

    fn deref(&self) -> &S {
        let State(inner) = self;
        inner
    }
}
349
350pub struct HandlerFn<F, Fut, R>
365where
366 F: Fn() -> Fut + Send + Sync,
367 Fut: Future<Output = R> + Send,
368 R: IntoHandlerResult,
369{
370 func: F,
371 _marker: std::marker::PhantomData<fn() -> R>,
372}
373
374impl<F, Fut, R> HandlerFn<F, Fut, R>
375where
376 F: Fn() -> Fut + Send + Sync,
377 Fut: Future<Output = R> + Send,
378 R: IntoHandlerResult,
379{
380 pub fn new(func: F) -> Self {
382 Self {
383 func,
384 _marker: std::marker::PhantomData,
385 }
386 }
387}
388
389impl<F, Fut, R> Handler for HandlerFn<F, Fut, R>
390where
391 F: Fn() -> Fut + Send + Sync + 'static,
392 Fut: Future<Output = R> + Send + 'static,
393 R: IntoHandlerResult + 'static,
394{
395 fn call(&self, _args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
396 let fut = (self.func)();
397 Box::pin(async move { fut.await.into_handler_result() })
398 }
399}
400
401#[allow(clippy::type_complexity)]
408pub struct HandlerWithArgs<F, T, Fut, R>
409where
410 F: Fn(T) -> Fut + Send + Sync,
411 T: DeserializeOwned + Send,
412 Fut: Future<Output = R> + Send,
413 R: IntoHandlerResult,
414{
415 func: F,
416 _marker: std::marker::PhantomData<(fn() -> T, fn() -> R)>,
419}
420
421impl<F, T, Fut, R> HandlerWithArgs<F, T, Fut, R>
422where
423 F: Fn(T) -> Fut + Send + Sync,
424 T: DeserializeOwned + Send,
425 Fut: Future<Output = R> + Send,
426 R: IntoHandlerResult,
427{
428 pub fn new(func: F) -> Self {
430 Self {
431 func,
432 _marker: std::marker::PhantomData,
433 }
434 }
435}
436
437impl<F, T, Fut, R> Handler for HandlerWithArgs<F, T, Fut, R>
438where
439 F: Fn(T) -> Fut + Send + Sync + 'static,
440 T: DeserializeOwned + Send + 'static,
441 Fut: Future<Output = R> + Send + 'static,
442 R: IntoHandlerResult + 'static,
443{
444 fn call(&self, args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
445 let parsed: Result<T, _> = serde_json::from_str(args);
446 match parsed {
447 Ok(value) => {
448 let fut = (self.func)(value);
449 Box::pin(async move { fut.await.into_handler_result() })
450 }
451 Err(e) => Box::pin(async move {
452 Err(format!("Failed to deserialize args: {e}"))
453 }),
454 }
455 }
456}
457
458#[allow(clippy::type_complexity)]
465pub struct HandlerWithState<F, S, T, Fut, R>
466where
467 F: Fn(State<Arc<S>>, T) -> Fut + Send + Sync,
468 S: Send + Sync + 'static,
469 T: DeserializeOwned + Send,
470 Fut: Future<Output = R> + Send,
471 R: IntoHandlerResult,
472{
473 func: F,
474 states: SharedStateMap,
475 _marker: std::marker::PhantomData<(fn() -> S, fn() -> T, fn() -> R)>,
476}
477
478impl<F, S, T, Fut, R> HandlerWithState<F, S, T, Fut, R>
479where
480 F: Fn(State<Arc<S>>, T) -> Fut + Send + Sync,
481 S: Send + Sync + 'static,
482 T: DeserializeOwned + Send,
483 Fut: Future<Output = R> + Send,
484 R: IntoHandlerResult,
485{
486 pub fn new(func: F, states: SharedStateMap) -> Self {
488 Self {
489 func,
490 states,
491 _marker: std::marker::PhantomData,
492 }
493 }
494}
495
496impl<F, S, T, Fut, R> Handler for HandlerWithState<F, S, T, Fut, R>
497where
498 F: Fn(State<Arc<S>>, T) -> Fut + Send + Sync + 'static,
499 S: Send + Sync + 'static,
500 T: DeserializeOwned + Send + 'static,
501 Fut: Future<Output = R> + Send + 'static,
502 R: IntoHandlerResult + 'static,
503{
504 fn call(&self, args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
505 let state_arc = match resolve_state::<S>(&self.states) {
506 Ok(s) => s,
507 Err(msg) => return Box::pin(async move { Err(msg) }),
508 };
509
510 let parsed: Result<T, _> = serde_json::from_str(args);
511 match parsed {
512 Ok(value) => {
513 let fut = (self.func)(State(state_arc), value);
514 Box::pin(async move { fut.await.into_handler_result() })
515 }
516 Err(e) => Box::pin(async move {
517 Err(format!("Failed to deserialize args: {e}"))
518 }),
519 }
520 }
521}
522
523#[allow(clippy::type_complexity)]
530pub struct HandlerWithStateOnly<F, S, Fut, R>
531where
532 F: Fn(State<Arc<S>>) -> Fut + Send + Sync,
533 S: Send + Sync + 'static,
534 Fut: Future<Output = R> + Send,
535 R: IntoHandlerResult,
536{
537 func: F,
538 states: SharedStateMap,
539 _marker: std::marker::PhantomData<(fn() -> S, fn() -> R)>,
540}
541
542impl<F, S, Fut, R> HandlerWithStateOnly<F, S, Fut, R>
543where
544 F: Fn(State<Arc<S>>) -> Fut + Send + Sync,
545 S: Send + Sync + 'static,
546 Fut: Future<Output = R> + Send,
547 R: IntoHandlerResult,
548{
549 pub fn new(func: F, states: SharedStateMap) -> Self {
551 Self {
552 func,
553 states,
554 _marker: std::marker::PhantomData,
555 }
556 }
557}
558
559impl<F, S, Fut, R> Handler for HandlerWithStateOnly<F, S, Fut, R>
560where
561 F: Fn(State<Arc<S>>) -> Fut + Send + Sync + 'static,
562 S: Send + Sync + 'static,
563 Fut: Future<Output = R> + Send + 'static,
564 R: IntoHandlerResult + 'static,
565{
566 fn call(&self, _args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
567 let state_arc = match resolve_state::<S>(&self.states) {
568 Ok(s) => s,
569 Err(msg) => return Box::pin(async move { Err(msg) }),
570 };
571
572 let fut = (self.func)(State(state_arc));
573 Box::pin(async move { fut.await.into_handler_result() })
574 }
575}
576
/// Signature of a type-erased handler call.
///
/// It takes the raw argument string and returns a boxed `Send` future. The
/// future is `'static`, so it borrows neither the arguments nor the handler.
pub type HandlerCallFn = dyn Fn(&str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>>
    + Send
    + Sync;

/// A handler stored as a boxed [`HandlerCallFn`], which hides the concrete
/// function and argument types.
pub struct ErasedHandler(pub(crate) Box<HandlerCallFn>);
615
616impl ErasedHandler {
617 pub fn from_closure(f: Box<HandlerCallFn>) -> Self {
627 Self(f)
628 }
629
630 pub fn no_args<F, Fut, R>(handler: F) -> Self
632 where
633 F: Fn() -> Fut + Send + Sync + 'static,
634 Fut: Future<Output = R> + Send + 'static,
635 R: IntoHandlerResult + 'static,
636 {
637 Self(Box::new(move |_args: &str| {
638 let fut = handler();
639 Box::pin(async move { fut.await.into_handler_result() })
640 }))
641 }
642
643 pub fn with_args<F, T, Fut, R>(handler: F) -> Self
645 where
646 F: Fn(T) -> Fut + Send + Sync + 'static,
647 T: DeserializeOwned + Send + 'static,
648 Fut: Future<Output = R> + Send + 'static,
649 R: IntoHandlerResult + 'static,
650 {
651 Self(Box::new(move |args: &str| {
652 let parsed: Result<T, _> = serde_json::from_str(args);
653 match parsed {
654 Ok(value) => {
655 let fut = handler(value);
656 Box::pin(async move { fut.await.into_handler_result() })
657 }
658 Err(e) => Box::pin(async move {
659 Err(format!("Failed to deserialize args: {e}"))
660 }),
661 }
662 }))
663 }
664
665 pub fn with_state<F, S, T, Fut, R>(handler: F, states: SharedStateMap) -> Self
667 where
668 F: Fn(State<Arc<S>>, T) -> Fut + Send + Sync + 'static,
669 S: Send + Sync + 'static,
670 T: DeserializeOwned + Send + 'static,
671 Fut: Future<Output = R> + Send + 'static,
672 R: IntoHandlerResult + 'static,
673 {
674 Self(Box::new(move |args: &str| {
675 let state_arc = match resolve_state::<S>(&states) {
676 Ok(s) => s,
677 Err(msg) => {
678 return Box::pin(async move { Err(msg) })
679 as Pin<Box<dyn Future<Output = Result<String, String>> + Send>>
680 }
681 };
682 let parsed: Result<T, _> = serde_json::from_str(args);
683 match parsed {
684 Ok(value) => {
685 let fut = handler(State(state_arc), value);
686 Box::pin(async move { fut.await.into_handler_result() })
687 }
688 Err(e) => Box::pin(async move {
689 Err(format!("Failed to deserialize args: {e}"))
690 }),
691 }
692 }))
693 }
694
695 pub fn with_state_only<F, S, Fut, R>(handler: F, states: SharedStateMap) -> Self
697 where
698 F: Fn(State<Arc<S>>) -> Fut + Send + Sync + 'static,
699 S: Send + Sync + 'static,
700 Fut: Future<Output = R> + Send + 'static,
701 R: IntoHandlerResult + 'static,
702 {
703 Self(Box::new(move |_args: &str| {
704 let state_arc = match resolve_state::<S>(&states) {
705 Ok(s) => s,
706 Err(msg) => {
707 return Box::pin(async move { Err(msg) })
708 as Pin<Box<dyn Future<Output = Result<String, String>> + Send>>
709 }
710 };
711 let fut = handler(State(state_arc));
712 Box::pin(async move { fut.await.into_handler_result() })
713 }))
714 }
715}
716
717impl Handler for ErasedHandler {
718 fn call(
719 &self,
720 args: &str,
721 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
722 (self.0)(args)
723 }
724}
725
726pub trait StreamHandler: Send + Sync {
733 fn call_streaming(
738 &self,
739 args: &str,
740 tx: StreamSender,
741 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>>;
742}
743
744pub type StreamHandlerCallFn = dyn Fn(&str, StreamSender) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>>
750 + Send
751 + Sync;
752
753pub struct ErasedStreamHandler(pub(crate) Box<StreamHandlerCallFn>);
755
756impl ErasedStreamHandler {
757 pub fn from_closure(f: Box<StreamHandlerCallFn>) -> Self {
762 Self(f)
763 }
764
765 pub fn no_args<F, Fut, R>(handler: F) -> Self
767 where
768 F: Fn(StreamSender) -> Fut + Send + Sync + 'static,
769 Fut: Future<Output = R> + Send + 'static,
770 R: IntoHandlerResult + 'static,
771 {
772 Self(Box::new(move |_args: &str, tx: StreamSender| {
773 let fut = handler(tx);
774 Box::pin(async move { fut.await.into_handler_result() })
775 }))
776 }
777
778 pub fn with_args<F, T, Fut, R>(handler: F) -> Self
780 where
781 F: Fn(T, StreamSender) -> Fut + Send + Sync + 'static,
782 T: DeserializeOwned + Send + 'static,
783 Fut: Future<Output = R> + Send + 'static,
784 R: IntoHandlerResult + 'static,
785 {
786 Self(Box::new(move |args: &str, tx: StreamSender| {
787 let parsed: Result<T, _> = serde_json::from_str(args);
788 match parsed {
789 Ok(value) => {
790 let fut = handler(value, tx);
791 Box::pin(async move { fut.await.into_handler_result() })
792 }
793 Err(e) => Box::pin(async move {
794 Err(format!("Failed to deserialize args: {e}"))
795 }),
796 }
797 }))
798 }
799
800 pub fn with_state<F, S, T, Fut, R>(handler: F, states: SharedStateMap) -> Self
802 where
803 F: Fn(State<Arc<S>>, T, StreamSender) -> Fut + Send + Sync + 'static,
804 S: Send + Sync + 'static,
805 T: DeserializeOwned + Send + 'static,
806 Fut: Future<Output = R> + Send + 'static,
807 R: IntoHandlerResult + 'static,
808 {
809 Self(Box::new(move |args: &str, tx: StreamSender| {
810 let state_arc = match resolve_state::<S>(&states) {
811 Ok(s) => s,
812 Err(msg) => {
813 return Box::pin(async move { Err(msg) })
814 as Pin<Box<dyn Future<Output = Result<String, String>> + Send>>
815 }
816 };
817 let parsed: Result<T, _> = serde_json::from_str(args);
818 match parsed {
819 Ok(value) => {
820 let fut = handler(State(state_arc), value, tx);
821 Box::pin(async move { fut.await.into_handler_result() })
822 }
823 Err(e) => Box::pin(async move {
824 Err(format!("Failed to deserialize args: {e}"))
825 }),
826 }
827 }))
828 }
829
830 pub fn with_state_only<F, S, Fut, R>(handler: F, states: SharedStateMap) -> Self
832 where
833 F: Fn(State<Arc<S>>, StreamSender) -> Fut + Send + Sync + 'static,
834 S: Send + Sync + 'static,
835 Fut: Future<Output = R> + Send + 'static,
836 R: IntoHandlerResult + 'static,
837 {
838 Self(Box::new(move |_args: &str, tx: StreamSender| {
839 let state_arc = match resolve_state::<S>(&states) {
840 Ok(s) => s,
841 Err(msg) => {
842 return Box::pin(async move { Err(msg) })
843 as Pin<Box<dyn Future<Output = Result<String, String>> + Send>>
844 }
845 };
846 let fut = handler(State(state_arc), tx);
847 Box::pin(async move { fut.await.into_handler_result() })
848 }))
849 }
850}
851
852impl StreamHandler for ErasedStreamHandler {
853 fn call_streaming(
854 &self,
855 args: &str,
856 tx: StreamSender,
857 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
858 (self.0)(args, tx)
859 }
860}
861
862pub struct StreamingHandlerFn<F, Fut, R>
871where
872 F: Fn(StreamSender) -> Fut + Send + Sync,
873 Fut: Future<Output = R> + Send,
874 R: IntoHandlerResult,
875{
876 func: F,
877 _marker: std::marker::PhantomData<fn() -> R>,
878}
879
880impl<F, Fut, R> StreamingHandlerFn<F, Fut, R>
881where
882 F: Fn(StreamSender) -> Fut + Send + Sync,
883 Fut: Future<Output = R> + Send,
884 R: IntoHandlerResult,
885{
886 pub fn new(func: F) -> Self {
888 Self {
889 func,
890 _marker: std::marker::PhantomData,
891 }
892 }
893}
894
895impl<F, Fut, R> StreamHandler for StreamingHandlerFn<F, Fut, R>
896where
897 F: Fn(StreamSender) -> Fut + Send + Sync + 'static,
898 Fut: Future<Output = R> + Send + 'static,
899 R: IntoHandlerResult + 'static,
900{
901 fn call_streaming(
902 &self,
903 _args: &str,
904 tx: StreamSender,
905 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
906 let fut = (self.func)(tx);
907 Box::pin(async move { fut.await.into_handler_result() })
908 }
909}
910
911#[allow(clippy::type_complexity)]
916pub struct StreamingHandlerWithArgs<F, T, Fut, R>
917where
918 F: Fn(T, StreamSender) -> Fut + Send + Sync,
919 T: DeserializeOwned + Send,
920 Fut: Future<Output = R> + Send,
921 R: IntoHandlerResult,
922{
923 func: F,
924 _marker: std::marker::PhantomData<(fn() -> T, fn() -> R)>,
925}
926
927impl<F, T, Fut, R> StreamingHandlerWithArgs<F, T, Fut, R>
928where
929 F: Fn(T, StreamSender) -> Fut + Send + Sync,
930 T: DeserializeOwned + Send,
931 Fut: Future<Output = R> + Send,
932 R: IntoHandlerResult,
933{
934 pub fn new(func: F) -> Self {
936 Self {
937 func,
938 _marker: std::marker::PhantomData,
939 }
940 }
941}
942
943impl<F, T, Fut, R> StreamHandler for StreamingHandlerWithArgs<F, T, Fut, R>
944where
945 F: Fn(T, StreamSender) -> Fut + Send + Sync + 'static,
946 T: DeserializeOwned + Send + 'static,
947 Fut: Future<Output = R> + Send + 'static,
948 R: IntoHandlerResult + 'static,
949{
950 fn call_streaming(
951 &self,
952 args: &str,
953 tx: StreamSender,
954 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
955 let parsed: Result<T, _> = serde_json::from_str(args);
956 match parsed {
957 Ok(value) => {
958 let fut = (self.func)(value, tx);
959 Box::pin(async move { fut.await.into_handler_result() })
960 }
961 Err(e) => Box::pin(async move {
962 Err(format!("Failed to deserialize args: {e}"))
963 }),
964 }
965 }
966}
967
968#[allow(clippy::type_complexity)]
973pub struct StreamingHandlerWithState<F, S, T, Fut, R>
974where
975 F: Fn(State<Arc<S>>, T, StreamSender) -> Fut + Send + Sync,
976 S: Send + Sync + 'static,
977 T: DeserializeOwned + Send,
978 Fut: Future<Output = R> + Send,
979 R: IntoHandlerResult,
980{
981 func: F,
982 states: SharedStateMap,
983 _marker: std::marker::PhantomData<(fn() -> S, fn() -> T, fn() -> R)>,
984}
985
986impl<F, S, T, Fut, R> StreamingHandlerWithState<F, S, T, Fut, R>
987where
988 F: Fn(State<Arc<S>>, T, StreamSender) -> Fut + Send + Sync,
989 S: Send + Sync + 'static,
990 T: DeserializeOwned + Send,
991 Fut: Future<Output = R> + Send,
992 R: IntoHandlerResult,
993{
994 pub fn new(func: F, states: SharedStateMap) -> Self {
996 Self {
997 func,
998 states,
999 _marker: std::marker::PhantomData,
1000 }
1001 }
1002}
1003
1004impl<F, S, T, Fut, R> StreamHandler for StreamingHandlerWithState<F, S, T, Fut, R>
1005where
1006 F: Fn(State<Arc<S>>, T, StreamSender) -> Fut + Send + Sync + 'static,
1007 S: Send + Sync + 'static,
1008 T: DeserializeOwned + Send + 'static,
1009 Fut: Future<Output = R> + Send + 'static,
1010 R: IntoHandlerResult + 'static,
1011{
1012 fn call_streaming(
1013 &self,
1014 args: &str,
1015 tx: StreamSender,
1016 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
1017 let state_arc = match resolve_state::<S>(&self.states) {
1018 Ok(s) => s,
1019 Err(msg) => return Box::pin(async move { Err(msg) }),
1020 };
1021
1022 let parsed: Result<T, _> = serde_json::from_str(args);
1023 match parsed {
1024 Ok(value) => {
1025 let fut = (self.func)(State(state_arc), value, tx);
1026 Box::pin(async move { fut.await.into_handler_result() })
1027 }
1028 Err(e) => Box::pin(async move {
1029 Err(format!("Failed to deserialize args: {e}"))
1030 }),
1031 }
1032 }
1033}
1034
1035#[allow(clippy::type_complexity)]
1040pub struct StreamingHandlerWithStateOnly<F, S, Fut, R>
1041where
1042 F: Fn(State<Arc<S>>, StreamSender) -> Fut + Send + Sync,
1043 S: Send + Sync + 'static,
1044 Fut: Future<Output = R> + Send,
1045 R: IntoHandlerResult,
1046{
1047 func: F,
1048 states: SharedStateMap,
1049 _marker: std::marker::PhantomData<(fn() -> S, fn() -> R)>,
1050}
1051
1052impl<F, S, Fut, R> StreamingHandlerWithStateOnly<F, S, Fut, R>
1053where
1054 F: Fn(State<Arc<S>>, StreamSender) -> Fut + Send + Sync,
1055 S: Send + Sync + 'static,
1056 Fut: Future<Output = R> + Send,
1057 R: IntoHandlerResult,
1058{
1059 pub fn new(func: F, states: SharedStateMap) -> Self {
1061 Self {
1062 func,
1063 states,
1064 _marker: std::marker::PhantomData,
1065 }
1066 }
1067}
1068
1069impl<F, S, Fut, R> StreamHandler for StreamingHandlerWithStateOnly<F, S, Fut, R>
1070where
1071 F: Fn(State<Arc<S>>, StreamSender) -> Fut + Send + Sync + 'static,
1072 S: Send + Sync + 'static,
1073 Fut: Future<Output = R> + Send + 'static,
1074 R: IntoHandlerResult + 'static,
1075{
1076 fn call_streaming(
1077 &self,
1078 _args: &str,
1079 tx: StreamSender,
1080 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
1081 let state_arc = match resolve_state::<S>(&self.states) {
1082 Ok(s) => s,
1083 Err(msg) => return Box::pin(async move { Err(msg) }),
1084 };
1085
1086 let fut = (self.func)(State(state_arc), tx);
1087 Box::pin(async move { fut.await.into_handler_result() })
1088 }
1089}
1090
1091#[cfg(test)]
1092mod tests {
1093 use super::*;
1094
1095 fn state_map<S: Send + Sync + 'static>(value: S) -> SharedStateMap {
1097 let mut map = HashMap::new();
1098 map.insert(TypeId::of::<S>(), Arc::new(value) as Arc<dyn Any + Send + Sync>);
1099 Arc::new(RwLock::new(map))
1100 }
1101
1102 #[tokio::test]
1105 async fn test_handler_fn() {
1106 let handler = HandlerFn::new(|| async { "test".to_string() });
1107 let result = handler.call("{}").await;
1108 assert_eq!(result, Ok("test".to_string()));
1109 }
1110
1111 #[tokio::test]
1112 async fn test_handler_fn_ignores_args() {
1113 let handler = HandlerFn::new(|| async { "no-args".to_string() });
1114 let result = handler.call(r#"{"unexpected": true}"#).await;
1115 assert_eq!(result, Ok("no-args".to_string()));
1116 }
1117
1118 #[tokio::test]
1119 async fn test_handler_with_args() {
1120 #[derive(serde::Deserialize)]
1121 struct Input {
1122 name: String,
1123 }
1124
1125 let handler = HandlerWithArgs::new(|args: Input| async move {
1126 format!("hello {}", args.name)
1127 });
1128
1129 let result = handler.call(r#"{"name":"Alice"}"#).await;
1130 assert_eq!(result, Ok("hello Alice".to_string()));
1131 }
1132
1133 #[tokio::test]
1134 async fn test_handler_with_args_bad_json() {
1135 #[derive(serde::Deserialize)]
1136 struct Input {
1137 _name: String,
1138 }
1139
1140 let handler = HandlerWithArgs::new(|_args: Input| async move {
1141 "unreachable".to_string()
1142 });
1143
1144 let result = handler.call("not-json").await;
1145 assert!(result.is_err());
1146 assert!(result.unwrap_err().contains("Failed to deserialize args"));
1147 }
1148
1149 #[tokio::test]
1150 async fn test_handler_with_args_missing_field() {
1151 #[derive(serde::Deserialize)]
1152 struct Input {
1153 _name: String,
1154 }
1155
1156 let handler = HandlerWithArgs::new(|_args: Input| async move {
1157 "unreachable".to_string()
1158 });
1159
1160 let result = handler.call(r#"{"age": 30}"#).await;
1161 assert!(result.is_err());
1162 assert!(result.unwrap_err().contains("Failed to deserialize args"));
1163 }
1164
1165 #[tokio::test]
1166 async fn test_handler_with_state() {
1167 struct AppState {
1168 greeting: String,
1169 }
1170
1171 #[derive(serde::Deserialize)]
1172 struct Input {
1173 name: String,
1174 }
1175
1176 let states = state_map(AppState {
1177 greeting: "Hi".to_string(),
1178 });
1179
1180 let handler = HandlerWithState::new(
1181 |state: State<Arc<AppState>>, args: Input| async move {
1182 format!("{} {}", state.greeting, args.name)
1183 },
1184 states,
1185 );
1186
1187 let result = handler.call(r#"{"name":"Bob"}"#).await;
1188 assert_eq!(result, Ok("Hi Bob".to_string()));
1189 }
1190
1191 #[tokio::test]
1192 async fn test_handler_with_state_only() {
1193 struct AppState {
1194 value: i32,
1195 }
1196
1197 let states = state_map(AppState { value: 42 });
1198
1199 let handler = HandlerWithStateOnly::new(
1200 |state: State<Arc<AppState>>| async move {
1201 format!("value={}", state.value)
1202 },
1203 states,
1204 );
1205
1206 let result = handler.call("{}").await;
1207 assert_eq!(result, Ok("value=42".to_string()));
1208 }
1209
1210 #[tokio::test]
1211 async fn test_handler_with_state_deser_error() {
1212 struct AppState;
1213
1214 #[derive(serde::Deserialize)]
1215 struct Input {
1216 _x: i32,
1217 }
1218
1219 let states = state_map(AppState);
1220
1221 let handler = HandlerWithState::new(
1222 |_state: State<Arc<AppState>>, _args: Input| async move {
1223 "unreachable".to_string()
1224 },
1225 states,
1226 );
1227
1228 let result = handler.call("bad").await;
1229 assert!(result.is_err());
1230 assert!(result.unwrap_err().contains("Failed to deserialize args"));
1231 }
1232
1233 #[tokio::test]
1236 async fn test_json_handler_fn_struct() {
1237 #[derive(serde::Serialize)]
1238 struct User {
1239 id: u32,
1240 name: String,
1241 }
1242
1243 let handler = HandlerFn::new(|| async {
1244 Json(User {
1245 id: 1,
1246 name: "Alice".to_string(),
1247 })
1248 });
1249
1250 let result = handler.call("{}").await;
1251 assert_eq!(result, Ok(r#"{"id":1,"name":"Alice"}"#.to_string()));
1252 }
1253
1254 #[tokio::test]
1255 async fn test_json_handler_fn_vec() {
1256 let handler = HandlerFn::new(|| async { Json(vec![1, 2, 3]) });
1257 let result = handler.call("{}").await;
1258 assert_eq!(result, Ok("[1,2,3]".to_string()));
1259 }
1260
1261 #[tokio::test]
1262 async fn test_json_handler_with_args() {
1263 #[derive(serde::Deserialize)]
1264 struct Input {
1265 name: String,
1266 }
1267
1268 #[derive(serde::Serialize)]
1269 struct Output {
1270 greeting: String,
1271 }
1272
1273 let handler = HandlerWithArgs::new(|args: Input| async move {
1274 Json(Output {
1275 greeting: format!("Hello {}", args.name),
1276 })
1277 });
1278
1279 let result = handler.call(r#"{"name":"Bob"}"#).await;
1280 assert_eq!(result, Ok(r#"{"greeting":"Hello Bob"}"#.to_string()));
1281 }
1282
1283 #[tokio::test]
1284 async fn test_json_handler_with_args_bad_json() {
1285 #[derive(serde::Deserialize)]
1286 struct Input {
1287 _x: i32,
1288 }
1289
1290 let handler = HandlerWithArgs::new(|_: Input| async move { Json(42) });
1291 let result = handler.call("bad").await;
1292 assert!(result.is_err());
1293 assert!(result.unwrap_err().contains("Failed to deserialize args"));
1294 }
1295
1296 #[tokio::test]
1297 async fn test_json_handler_with_state() {
1298 struct AppState {
1299 prefix: String,
1300 }
1301
1302 #[derive(serde::Deserialize)]
1303 struct Input {
1304 name: String,
1305 }
1306
1307 #[derive(serde::Serialize)]
1308 struct Output {
1309 message: String,
1310 }
1311
1312 let states = state_map(AppState {
1313 prefix: "Hi".to_string(),
1314 });
1315
1316 let handler = HandlerWithState::new(
1317 |state: State<Arc<AppState>>, args: Input| async move {
1318 Json(Output {
1319 message: format!("{} {}", state.prefix, args.name),
1320 })
1321 },
1322 states,
1323 );
1324
1325 let result = handler.call(r#"{"name":"Charlie"}"#).await;
1326 assert_eq!(result, Ok(r#"{"message":"Hi Charlie"}"#.to_string()));
1327 }
1328
1329 #[tokio::test]
1330 async fn test_json_handler_with_state_only() {
1331 struct AppState {
1332 version: String,
1333 }
1334
1335 #[derive(serde::Serialize)]
1336 struct Info {
1337 version: String,
1338 }
1339
1340 let states = state_map(AppState {
1341 version: "1.0".to_string(),
1342 });
1343
1344 let handler = HandlerWithStateOnly::new(
1345 |state: State<Arc<AppState>>| async move {
1346 Json(Info {
1347 version: state.version.clone(),
1348 })
1349 },
1350 states,
1351 );
1352
1353 let result = handler.call("{}").await;
1354 assert_eq!(result, Ok(r#"{"version":"1.0"}"#.to_string()));
1355 }
1356
1357 #[tokio::test]
1360 async fn test_result_handler_fn_ok() {
1361 #[derive(serde::Serialize)]
1362 struct Data {
1363 value: i32,
1364 }
1365
1366 let handler = HandlerFn::new(|| async {
1367 Ok::<_, String>(Data { value: 42 })
1368 });
1369
1370 let result = handler.call("{}").await;
1371 assert_eq!(result, Ok(r#"{"value":42}"#.to_string()));
1372 }
1373
1374 #[tokio::test]
1375 async fn test_result_handler_fn_err() {
1376 #[derive(serde::Serialize)]
1377 struct Data {
1378 value: i32,
1379 }
1380
1381 let handler = HandlerFn::new(|| async {
1382 Err::<Data, String>("something went wrong".to_string())
1383 });
1384
1385 let result = handler.call("{}").await;
1386 assert_eq!(result, Err("something went wrong".to_string()));
1387 }
1388
1389 #[tokio::test]
1390 async fn test_result_handler_with_args_ok() {
1391 #[derive(serde::Deserialize)]
1392 struct Input {
1393 x: i32,
1394 }
1395
1396 #[derive(serde::Serialize)]
1397 struct Output {
1398 doubled: i32,
1399 }
1400
1401 let handler = HandlerWithArgs::new(|args: Input| async move {
1402 Ok::<_, String>(Output { doubled: args.x * 2 })
1403 });
1404
1405 let result = handler.call(r#"{"x":21}"#).await;
1406 assert_eq!(result, Ok(r#"{"doubled":42}"#.to_string()));
1407 }
1408
1409 #[tokio::test]
1410 async fn test_result_handler_with_args_err() {
1411 #[derive(serde::Deserialize)]
1412 struct Input {
1413 x: i32,
1414 }
1415
1416 let handler = HandlerWithArgs::new(|args: Input| async move {
1417 if args.x < 0 {
1418 Err::<i32, String>("negative".to_string())
1419 } else {
1420 Ok(args.x)
1421 }
1422 });
1423
1424 let result = handler.call(r#"{"x":-1}"#).await;
1425 assert_eq!(result, Err("negative".to_string()));
1426 }
1427
1428 #[tokio::test]
1429 async fn test_result_handler_with_state() {
1430 struct AppState {
1431 threshold: i32,
1432 }
1433
1434 #[derive(serde::Deserialize)]
1435 struct Input {
1436 value: i32,
1437 }
1438
1439 #[derive(serde::Serialize)]
1440 struct Output {
1441 accepted: bool,
1442 }
1443
1444 let states = state_map(AppState { threshold: 10 });
1445
1446 let handler = HandlerWithState::new(
1447 |state: State<Arc<AppState>>, args: Input| async move {
1448 if args.value >= state.threshold {
1449 Ok::<_, String>(Output { accepted: true })
1450 } else {
1451 Err("below threshold".to_string())
1452 }
1453 },
1454 states,
1455 );
1456
1457 let ok_result = handler.call(r#"{"value":15}"#).await;
1458 assert_eq!(ok_result, Ok(r#"{"accepted":true}"#.to_string()));
1459
1460 let err_result = handler.call(r#"{"value":5}"#).await;
1461 assert_eq!(err_result, Err("below threshold".to_string()));
1462 }
1463
1464 #[tokio::test]
1465 async fn test_result_handler_with_state_only() {
1466 struct AppState {
1467 ready: bool,
1468 }
1469
1470 #[derive(serde::Serialize)]
1471 struct Status {
1472 ok: bool,
1473 }
1474
1475 let states = state_map(AppState { ready: true });
1476
1477 let handler = HandlerWithStateOnly::new(
1478 |state: State<Arc<AppState>>| async move {
1479 if state.ready {
1480 Ok::<_, String>(Status { ok: true })
1481 } else {
1482 Err("not ready".to_string())
1483 }
1484 },
1485 states,
1486 );
1487
1488 let result = handler.call("{}").await;
1489 assert_eq!(result, Ok(r#"{"ok":true}"#.to_string()));
1490 }
1491
1492 #[test]
1495 fn test_into_stream_item_string() {
1496 let item = "hello".to_string();
1497 assert_eq!(item.into_stream_item(), Ok("hello".to_string()));
1498 }
1499
1500 #[test]
1501 fn test_into_stream_item_json() {
1502 #[derive(serde::Serialize)]
1503 struct Token {
1504 text: String,
1505 }
1506 let item = Json(Token {
1507 text: "hi".to_string(),
1508 });
1509 assert_eq!(
1510 item.into_stream_item(),
1511 Ok(r#"{"text":"hi"}"#.to_string())
1512 );
1513 }
1514
1515 #[test]
1516 fn test_into_stream_item_json_vec() {
1517 let item = Json(vec![1, 2, 3]);
1518 assert_eq!(item.into_stream_item(), Ok("[1,2,3]".to_string()));
1519 }
1520
1521 #[test]
1522 fn test_into_stream_item_result_ok() {
1523 #[derive(serde::Serialize)]
1524 struct Data {
1525 v: i32,
1526 }
1527 let item: Result<Data, String> = Ok(Data { v: 42 });
1528 assert_eq!(item.into_stream_item(), Ok(r#"{"v":42}"#.to_string()));
1529 }
1530
1531 #[test]
1532 fn test_into_stream_item_result_err() {
1533 let item: Result<i32, String> = Err("bad".to_string());
1534 assert_eq!(item.into_stream_item(), Err("bad".to_string()));
1535 }
1536
1537 #[test]
1540 fn test_stream_error_display_closed() {
1541 let err = StreamError::Closed;
1542 assert_eq!(err.to_string(), "stream closed: receiver dropped");
1543 }
1544
1545 #[test]
1546 fn test_stream_error_display_serialize() {
1547 let err = StreamError::Serialize("bad json".to_string());
1548 assert_eq!(err.to_string(), "stream serialization error: bad json");
1549 }
1550
1551 #[test]
1552 fn test_stream_error_is_std_error() {
1553 let err: Box<dyn std::error::Error> = Box::new(StreamError::Closed);
1554 assert!(err.to_string().contains("closed"));
1555 }
1556
1557 #[tokio::test]
1560 async fn test_stream_sender_send_and_receive() {
1561 let (tx, mut rx) = StreamSender::channel();
1562 tx.send("hello".to_string()).await.unwrap();
1563 tx.send("world".to_string()).await.unwrap();
1564 drop(tx);
1565
1566 assert_eq!(rx.recv().await, Some("hello".to_string()));
1567 assert_eq!(rx.recv().await, Some("world".to_string()));
1568 assert_eq!(rx.recv().await, None);
1569 }
1570
1571 #[tokio::test]
1572 async fn test_stream_sender_send_json() {
1573 #[derive(serde::Serialize)]
1574 struct Token {
1575 t: String,
1576 }
1577 let (tx, mut rx) = StreamSender::channel();
1578 tx.send(Json(Token {
1579 t: "hi".to_string(),
1580 }))
1581 .await
1582 .unwrap();
1583 drop(tx);
1584
1585 assert_eq!(rx.recv().await, Some(r#"{"t":"hi"}"#.to_string()));
1586 }
1587
1588 #[tokio::test]
1589 async fn test_stream_sender_closed_detection() {
1590 let (tx, rx) = StreamSender::channel();
1591 assert!(!tx.is_closed());
1592 drop(rx);
1593 assert!(tx.is_closed());
1594 }
1595
1596 #[tokio::test]
1597 async fn test_stream_sender_send_after_close() {
1598 let (tx, rx) = StreamSender::channel();
1599 drop(rx);
1600 let result = tx.send("late".to_string()).await;
1601 assert_eq!(result, Err(StreamError::Closed));
1602 }
1603
1604 #[tokio::test]
1605 async fn test_stream_sender_custom_capacity() {
1606 let (tx, mut rx) = StreamSender::with_capacity(2);
1607
1608 tx.send("a".to_string()).await.unwrap();
1610 tx.send("b".to_string()).await.unwrap();
1611
1612 assert_eq!(rx.recv().await, Some("a".to_string()));
1614 assert_eq!(rx.recv().await, Some("b".to_string()));
1615
1616 tx.send("c".to_string()).await.unwrap();
1618 assert_eq!(rx.recv().await, Some("c".to_string()));
1619 }
1620
1621 #[tokio::test]
1622 async fn test_stream_sender_default_capacity() {
1623 assert_eq!(DEFAULT_STREAM_CAPACITY, 64);
1624 }
1625
1626 #[tokio::test]
1627 async fn test_stream_sender_clone() {
1628 let (tx, mut rx) = StreamSender::channel();
1629 let tx2 = tx.clone();
1630
1631 tx.send("from-tx1".to_string()).await.unwrap();
1632 tx2.send("from-tx2".to_string()).await.unwrap();
1633 drop(tx);
1634 drop(tx2);
1635
1636 assert_eq!(rx.recv().await, Some("from-tx1".to_string()));
1637 assert_eq!(rx.recv().await, Some("from-tx2".to_string()));
1638 assert_eq!(rx.recv().await, None);
1639 }
1640
1641 #[test]
1642 fn test_stream_sender_debug() {
1643 let (tx, _rx) = StreamSender::channel();
1644 let debug = format!("{:?}", tx);
1645 assert!(debug.contains("StreamSender"));
1646 }
1647
1648 #[tokio::test]
1651 async fn test_cancellation_token_not_cancelled_initially() {
1652 let (tx, _rx) = StreamSender::channel();
1653 let token = tx.cancellation_token();
1654 assert!(!token.is_cancelled());
1655 }
1656
1657 #[tokio::test]
1658 async fn test_cancellation_token_cancelled_on_explicit_cancel() {
1659 let (tx, _rx) = StreamSender::channel();
1660 let token = tx.cancellation_token();
1661 assert!(!token.is_cancelled());
1662 tx.cancel();
1663 assert!(token.is_cancelled());
1664 }
1665
1666 #[tokio::test]
1667 async fn test_cancellation_token_cancelled_future_resolves() {
1668 let (tx, _rx) = StreamSender::channel();
1669 let token = tx.cancellation_token();
1670
1671 let tx2 = tx.clone();
1673 tokio::spawn(async move {
1674 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1675 tx2.cancel();
1676 });
1677
1678 tokio::time::timeout(std::time::Duration::from_secs(1), token.cancelled())
1680 .await
1681 .expect("cancelled future should resolve");
1682 }
1683
1684 #[tokio::test]
1685 async fn test_cancellation_token_shared_across_clones() {
1686 let (tx, _rx) = StreamSender::channel();
1687 let token1 = tx.cancellation_token();
1688 let token2 = tx.cancellation_token();
1689 let tx2 = tx.clone();
1690 let token3 = tx2.cancellation_token();
1691
1692 tx.cancel();
1693 assert!(token1.is_cancelled());
1694 assert!(token2.is_cancelled());
1695 assert!(token3.is_cancelled());
1696 }
1697
1698 #[tokio::test]
1699 async fn test_cancellation_token_auto_cancelled_on_receiver_drop() {
1700 let (tx, rx) = StreamSender::channel();
1701 let token = tx.cancellation_token();
1702
1703 assert!(!token.is_cancelled());
1704 drop(rx); assert!(token.is_cancelled());
1706 }
1707
1708 #[tokio::test]
1709 async fn test_cancellation_token_auto_cancel_future_resolves_on_drop() {
1710 let (tx, rx) = StreamSender::channel();
1711 let token = tx.cancellation_token();
1712
1713 tokio::spawn(async move {
1714 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1715 drop(rx);
1716 });
1717
1718 tokio::time::timeout(std::time::Duration::from_secs(1), token.cancelled())
1719 .await
1720 .expect("cancelled future should resolve when receiver is dropped");
1721 }
1722
1723 #[tokio::test]
1726 async fn test_streaming_handler_fn() {
1727 let handler = StreamingHandlerFn::new(|tx: StreamSender| async move {
1728 tx.send("item1".to_string()).await.ok();
1729 tx.send("item2".to_string()).await.ok();
1730 "done".to_string()
1731 });
1732
1733 let (tx, mut rx) = StreamSender::channel();
1734 let result = handler.call_streaming("{}", tx).await;
1735
1736 assert_eq!(result, Ok("done".to_string()));
1737 assert_eq!(rx.recv().await, Some("item1".to_string()));
1738 assert_eq!(rx.recv().await, Some("item2".to_string()));
1739 }
1740
1741 #[tokio::test]
1742 async fn test_streaming_handler_with_args() {
1743 #[derive(serde::Deserialize)]
1744 struct Input {
1745 count: usize,
1746 }
1747
1748 let handler =
1749 StreamingHandlerWithArgs::new(|args: Input, tx: StreamSender| async move {
1750 for i in 0..args.count {
1751 tx.send(format!("item-{i}")).await.ok();
1752 }
1753 format!("sent {}", args.count)
1754 });
1755
1756 let (tx, mut rx) = StreamSender::channel();
1757 let result = handler.call_streaming(r#"{"count":3}"#, tx).await;
1758
1759 assert_eq!(result, Ok("sent 3".to_string()));
1760 assert_eq!(rx.recv().await, Some("item-0".to_string()));
1761 assert_eq!(rx.recv().await, Some("item-1".to_string()));
1762 assert_eq!(rx.recv().await, Some("item-2".to_string()));
1763 }
1764
1765 #[tokio::test]
1766 async fn test_streaming_handler_with_args_bad_json() {
1767 #[derive(serde::Deserialize)]
1768 struct Input {
1769 _x: i32,
1770 }
1771
1772 let handler =
1773 StreamingHandlerWithArgs::new(|_args: Input, _tx: StreamSender| async move {
1774 "unreachable".to_string()
1775 });
1776
1777 let (tx, _rx) = StreamSender::channel();
1778 let result = handler.call_streaming("bad-json", tx).await;
1779 assert!(result.is_err());
1780 assert!(result.unwrap_err().contains("Failed to deserialize args"));
1781 }
1782
1783 #[tokio::test]
1784 async fn test_streaming_handler_with_state() {
1785 struct AppState {
1786 prefix: String,
1787 }
1788
1789 #[derive(serde::Deserialize)]
1790 struct Input {
1791 name: String,
1792 }
1793
1794 let states = state_map(AppState {
1795 prefix: "Hi".to_string(),
1796 });
1797
1798 let handler = StreamingHandlerWithState::new(
1799 |state: State<Arc<AppState>>, args: Input, tx: StreamSender| async move {
1800 tx.send(format!("{} {}", state.prefix, args.name))
1801 .await
1802 .ok();
1803 "done".to_string()
1804 },
1805 states,
1806 );
1807
1808 let (tx, mut rx) = StreamSender::channel();
1809 let result = handler.call_streaming(r#"{"name":"Alice"}"#, tx).await;
1810
1811 assert_eq!(result, Ok("done".to_string()));
1812 assert_eq!(rx.recv().await, Some("Hi Alice".to_string()));
1813 }
1814
1815 #[tokio::test]
1816 async fn test_streaming_handler_with_state_only() {
1817 struct AppState {
1818 items: Vec<String>,
1819 }
1820
1821 let states = state_map(AppState {
1822 items: vec!["a".to_string(), "b".to_string()],
1823 });
1824
1825 let handler = StreamingHandlerWithStateOnly::new(
1826 |state: State<Arc<AppState>>, tx: StreamSender| async move {
1827 for item in &state.items {
1828 tx.send(item.clone()).await.ok();
1829 }
1830 format!("sent {}", state.items.len())
1831 },
1832 states,
1833 );
1834
1835 let (tx, mut rx) = StreamSender::channel();
1836 let result = handler.call_streaming("{}", tx).await;
1837
1838 assert_eq!(result, Ok("sent 2".to_string()));
1839 assert_eq!(rx.recv().await, Some("a".to_string()));
1840 assert_eq!(rx.recv().await, Some("b".to_string()));
1841 }
1842
1843 #[tokio::test]
1844 async fn test_streaming_handler_with_state_type_mismatch() {
1845 struct WrongState;
1846 struct AppState;
1847
1848 let states = state_map(WrongState);
1849
1850 let handler = StreamingHandlerWithStateOnly::new(
1851 |_state: State<Arc<AppState>>, _tx: StreamSender| async move {
1852 "unreachable".to_string()
1853 },
1854 states,
1855 );
1856
1857 let (tx, _rx) = StreamSender::channel();
1858 let result = handler.call_streaming("{}", tx).await;
1859 assert!(result.is_err());
1860 assert!(result.unwrap_err().contains("State not found"));
1861 }
1862
1863 #[tokio::test]
1864 async fn test_streaming_handler_json_return() {
1865 #[derive(serde::Serialize)]
1866 struct Summary {
1867 count: usize,
1868 }
1869
1870 let handler = StreamingHandlerFn::new(|tx: StreamSender| async move {
1871 tx.send("item".to_string()).await.ok();
1872 Json(Summary { count: 1 })
1873 });
1874
1875 let (tx, mut rx) = StreamSender::channel();
1876 let result = handler.call_streaming("{}", tx).await;
1877
1878 assert_eq!(result, Ok(r#"{"count":1}"#.to_string()));
1879 assert_eq!(rx.recv().await, Some("item".to_string()));
1880 }
1881
1882 #[tokio::test]
1883 async fn test_streaming_handler_result_return() {
1884 let handler = StreamingHandlerFn::new(|tx: StreamSender| async move {
1885 tx.send("progress".to_string()).await.ok();
1886 Ok::<_, String>(42)
1887 });
1888
1889 let (tx, mut rx) = StreamSender::channel();
1890 let result = handler.call_streaming("{}", tx).await;
1891
1892 assert_eq!(result, Ok("42".to_string()));
1893 assert_eq!(rx.recv().await, Some("progress".to_string()));
1894 }
1895
1896 #[tokio::test]
1899 async fn test_erased_handler_from_closure_no_args() {
1900 let handler = ErasedHandler::from_closure(Box::new(|_args: &str| {
1901 Box::pin(async { Ok("hello".to_string()) })
1902 as Pin<Box<dyn Future<Output = Result<String, String>> + Send>>
1903 }));
1904 let result = handler.call("{}").await;
1905 assert_eq!(result, Ok("hello".to_string()));
1906 }
1907
1908 #[tokio::test]
1909 async fn test_erased_handler_from_closure_with_args() {
1910 #[derive(serde::Deserialize)]
1911 struct Input { name: String }
1912
1913 let handler = ErasedHandler::from_closure(Box::new(|args: &str| {
1914 let parsed: Result<Input, _> = serde_json::from_str(args);
1915 match parsed {
1916 Ok(input) => {
1917 Box::pin(async move { Ok(format!("hello {}", input.name)) })
1918 as Pin<Box<dyn Future<Output = Result<String, String>> + Send>>
1919 }
1920 Err(e) => Box::pin(async move { Err(e.to_string()) })
1921 as Pin<Box<dyn Future<Output = Result<String, String>> + Send>>,
1922 }
1923 }));
1924 let result = handler.call(r#"{"name":"Alice"}"#).await;
1925 assert_eq!(result, Ok("hello Alice".to_string()));
1926 }
1927
1928 #[tokio::test]
1929 async fn test_erased_handler_no_args_constructor() {
1930 let handler = ErasedHandler::no_args(|| async { "zero-arg".to_string() });
1931 let result = handler.call("ignored").await;
1932 assert_eq!(result, Ok("zero-arg".to_string()));
1933 }
1934
1935 #[tokio::test]
1936 async fn test_erased_handler_with_args_constructor() {
1937 #[derive(serde::Deserialize)]
1938 struct Input { name: String }
1939
1940 let handler = ErasedHandler::with_args(|input: Input| async move {
1941 format!("hi {}", input.name)
1942 });
1943 let result = handler.call(r#"{"name":"Bob"}"#).await;
1944 assert_eq!(result, Ok("hi Bob".to_string()));
1945 }
1946
1947 #[tokio::test]
1948 async fn test_erased_handler_with_state_constructor() {
1949 #[derive(serde::Deserialize)]
1950 struct Input { #[allow(dead_code)] name: String }
1951
1952 let states = state_map("shared-state".to_string());
1953 let handler =
1954 ErasedHandler::with_state(
1955 |state: State<Arc<String>>, _input: Input| async move {
1956 format!("state={}", *state)
1957 },
1958 states,
1959 );
1960 let result = handler.call(r#"{"name":"x"}"#).await;
1961 assert_eq!(result, Ok("state=shared-state".to_string()));
1962 }
1963
1964 #[tokio::test]
1965 async fn test_erased_handler_with_state_only_constructor() {
1966 let states = state_map(42u32);
1967 let handler =
1968 ErasedHandler::with_state_only(
1969 |state: State<Arc<u32>>| async move { format!("n={}", *state) },
1970 states,
1971 );
1972 let result = handler.call("{}").await;
1973 assert_eq!(result, Ok("n=42".to_string()));
1974 }
1975
1976 #[tokio::test]
1977 async fn test_erased_stream_handler_from_closure() {
1978 let handler = ErasedStreamHandler::from_closure(Box::new(
1979 |_args: &str, tx: StreamSender| {
1980 Box::pin(async move {
1981 tx.send("chunk".to_string()).await.ok();
1982 Ok("done".to_string())
1983 })
1984 as Pin<Box<dyn Future<Output = Result<String, String>> + Send>>
1985 },
1986 ));
1987 let (tx, mut rx) = StreamSender::channel();
1988 let result = handler.call_streaming("{}", tx).await;
1989 assert_eq!(result, Ok("done".to_string()));
1990 assert_eq!(rx.recv().await, Some("chunk".to_string()));
1991 }
1992
1993 #[test]
1996 fn test_resolve_state_erased_success() {
1997 let states = state_map(99u64);
1998 let type_id = TypeId::of::<u64>();
1999 let type_name = std::any::type_name::<u64>();
2000
2001 let any = resolve_state_erased(&states, type_id, type_name).unwrap();
2002 let val = any.downcast::<u64>().unwrap();
2003 assert_eq!(*val, 99u64);
2004 }
2005
2006 #[test]
2007 fn test_resolve_state_erased_missing() {
2008 let states: SharedStateMap = Arc::new(RwLock::new(HashMap::new()));
2009 let type_id = TypeId::of::<String>();
2010 let type_name = std::any::type_name::<String>();
2011
2012 let err = resolve_state_erased(&states, type_id, type_name).unwrap_err();
2013 assert!(err.contains("State not found"));
2014 assert!(err.contains(type_name));
2015 }
2016
2017 #[tokio::test]
2020 async fn test_erase_handler_with_state_macro() {
2021 let states = state_map("macro-state".to_string());
2022
2023 async fn handler(
2024 state: State<Arc<String>>,
2025 _args: serde_json::Value,
2026 ) -> String {
2027 format!("got={}", *state)
2028 }
2029
2030 let erased = crate::erase_handler_with_state!(handler, String, serde_json::Value, states);
2031 let result = erased.call("{}").await;
2032 assert_eq!(result, Ok("got=macro-state".to_string()));
2033 }
2034
2035 #[tokio::test]
2036 async fn test_erase_handler_with_state_only_macro() {
2037 let states = state_map(7u32);
2038
2039 async fn handler(state: State<Arc<u32>>) -> String {
2040 format!("n={}", *state)
2041 }
2042
2043 let erased = crate::erase_handler_with_state_only!(handler, u32, states);
2044 let result = erased.call("{}").await;
2045 assert_eq!(result, Ok("n=7".to_string()));
2046 }
2047
2048 #[tokio::test]
2049 async fn test_erase_streaming_handler_with_state_only_macro() {
2050 let states = state_map("stream-state".to_string());
2051
2052 async fn handler(
2053 state: State<Arc<String>>,
2054 tx: StreamSender,
2055 ) -> String {
2056 tx.send(format!("from={}", *state)).await.ok();
2057 "done".to_string()
2058 }
2059
2060 let erased = crate::erase_streaming_handler_with_state_only!(handler, String, states);
2061 let (tx, mut rx) = StreamSender::channel();
2062 let result = erased.call_streaming("{}", tx).await;
2063 assert_eq!(result, Ok("done".to_string()));
2064 assert_eq!(rx.recv().await, Some("from=stream-state".to_string()));
2065 }
2066}