1use serde::de::DeserializeOwned;
4use serde::Serialize;
5use std::{
6 any::Any,
7 fmt,
8 future::Future,
9 ops::Deref,
10 pin::Pin,
11 sync::Arc,
12};
13use tokio::sync::mpsc;
14use tokio_util::sync::CancellationToken;
15
/// Conversion from a handler's return value into the `Result<String, String>`
/// produced by [`Handler::call`].
///
/// `Ok` carries the response payload; `Err` carries an error message.
pub trait IntoHandlerResult: Send {
    /// Consumes the value and produces the handler's final result.
    fn into_handler_result(self) -> Result<String, String>;
}

/// Plain strings are returned verbatim as the payload (no JSON encoding is applied).
impl IntoHandlerResult for String {
    fn into_handler_result(self) -> Result<String, String> {
        Ok(self)
    }
}
35
/// Wrapper marking a value to be serialized with `serde_json` when it is
/// returned from a handler or sent on a stream.
///
/// The standard derives make the wrapper usable in `assert_eq!`, debug output
/// and collections. Each derive only applies when `T` itself implements the
/// corresponding trait.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Json<T>(pub T);
41
42impl<T: Serialize + Send> IntoHandlerResult for Json<T> {
43 fn into_handler_result(self) -> Result<String, String> {
44 serde_json::to_string(&self.0)
45 .map_err(|e| format!("Failed to serialize response: {e}"))
46 }
47}
48
49impl<T: Serialize + Send, E: fmt::Display + Send> IntoHandlerResult for Result<T, E> {
51 fn into_handler_result(self) -> Result<String, String> {
52 match self {
53 Ok(value) => serde_json::to_string(&value)
54 .map_err(|e| format!("Failed to serialize response: {e}")),
55 Err(e) => Err(e.to_string()),
56 }
57 }
58}
59
/// Conversion from a value passed to [`StreamSender::send`] into the string
/// that is queued on the stream channel.
///
/// An `Err` message is surfaced to the sender as [`StreamError::Serialize`].
pub trait IntoStreamItem: Send {
    /// Consumes the value and produces the wire string (or a conversion error).
    fn into_stream_item(self) -> Result<String, String>;
}

/// Plain strings are queued verbatim (no JSON encoding is applied).
impl IntoStreamItem for String {
    fn into_stream_item(self) -> Result<String, String> {
        Ok(self)
    }
}
76
77impl<T: Serialize + Send> IntoStreamItem for Json<T> {
79 fn into_stream_item(self) -> Result<String, String> {
80 serde_json::to_string(&self.0)
81 .map_err(|e| format!("Failed to serialize stream item: {e}"))
82 }
83}
84
85impl<T: Serialize + Send, E: fmt::Display + Send> IntoStreamItem for Result<T, E> {
87 fn into_stream_item(self) -> Result<String, String> {
88 match self {
89 Ok(value) => serde_json::to_string(&value)
90 .map_err(|e| format!("Failed to serialize stream item: {e}")),
91 Err(e) => Err(e.to_string()),
92 }
93 }
94}
95
/// Failure modes of [`StreamSender::send`].
#[derive(Debug, Clone, PartialEq)]
pub enum StreamError {
    /// The receiving half was dropped, so the item could not be delivered.
    Closed,
    /// The item could not be converted to its wire string; holds the conversion message.
    Serialize(String),
}

impl fmt::Display for StreamError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Closed => f.write_str("stream closed: receiver dropped"),
            Self::Serialize(detail) => write!(f, "stream serialization error: {detail}"),
        }
    }
}

impl std::error::Error for StreamError {}
117
/// Number of items buffered by the channel created by [`StreamSender::channel`].
pub const DEFAULT_STREAM_CAPACITY: usize = 64;
122
/// Producer half of a handler output stream.
///
/// Items are converted to strings (see [`IntoStreamItem`]) and queued on a
/// bounded channel. Clones share both the channel and the cancellation token.
#[derive(Clone)]
pub struct StreamSender {
    // Bounded tokio mpsc sender carrying already-converted items.
    tx: mpsc::Sender<String>,
    // Shared with the paired `StreamReceiver`, which cancels it when dropped.
    cancel: CancellationToken,
}

/// Consumer half of a handler output stream, created together with a
/// [`StreamSender`] by `StreamSender::channel` / `StreamSender::with_capacity`.
///
/// Dropping the receiver cancels the shared token so producers can stop early.
pub struct StreamReceiver {
    rx: mpsc::Receiver<String>,
    cancel: CancellationToken,
}
165
166impl StreamReceiver {
167 pub async fn recv(&mut self) -> Option<String> {
169 self.rx.recv().await
170 }
171
172}
173
174impl Drop for StreamReceiver {
175 fn drop(&mut self) {
176 self.cancel.cancel();
177 }
178}
179
180impl fmt::Debug for StreamReceiver {
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 f.debug_struct("StreamReceiver")
183 .field("cancelled", &self.cancel.is_cancelled())
184 .finish()
185 }
186}
187
188impl StreamSender {
189 pub fn channel() -> (Self, StreamReceiver) {
194 Self::with_capacity(DEFAULT_STREAM_CAPACITY)
195 }
196
197 pub fn with_capacity(capacity: usize) -> (Self, StreamReceiver) {
201 let (tx, rx) = mpsc::channel(capacity);
202 let cancel = CancellationToken::new();
203 (
204 Self { tx, cancel: cancel.clone() },
205 StreamReceiver { rx, cancel },
206 )
207 }
208
209 pub fn cancellation_token(&self) -> CancellationToken {
222 self.cancel.clone()
223 }
224
225 pub fn cancel(&self) {
230 self.cancel.cancel();
231 }
232
233 pub async fn send(&self, item: impl IntoStreamItem) -> Result<(), StreamError> {
239 let serialized = item.into_stream_item().map_err(StreamError::Serialize)?;
240 self.tx
241 .send(serialized)
242 .await
243 .map_err(|_| StreamError::Closed)
244 }
245
246 pub fn is_closed(&self) -> bool {
255 self.tx.is_closed()
256 }
257}
258
259impl fmt::Debug for StreamSender {
260 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261 f.debug_struct("StreamSender")
262 .field("closed", &self.is_closed())
263 .field("cancelled", &self.cancel.is_cancelled())
264 .finish()
265 }
266}
267
/// Type-erased request handler: receives the raw argument string and resolves
/// to either a response string (`Ok`) or an error message (`Err`).
pub trait Handler: Send + Sync {
    /// Invokes the handler with the raw `args` string.
    ///
    /// How `args` is interpreted is up to the implementor; the returned future
    /// may borrow `self`.
    fn call(&self, args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>>;
}
278
/// Wrapper used to hand shared application state to a handler function.
///
/// Because it implements [`Deref`], members of the inner value can be reached
/// directly through the wrapper (e.g. `state.field`).
#[derive(Debug, Clone)]
pub struct State<S>(pub S);

impl<S> Deref for State<S> {
    type Target = S;

    fn deref(&self) -> &S {
        let State(inner) = self;
        inner
    }
}
294
295pub struct HandlerFn<F, Fut, R>
299where
300 F: Fn() -> Fut + Send + Sync,
301 Fut: Future<Output = R> + Send,
302 R: IntoHandlerResult,
303{
304 func: F,
305 _marker: std::marker::PhantomData<fn() -> R>,
306}
307
308impl<F, Fut, R> HandlerFn<F, Fut, R>
309where
310 F: Fn() -> Fut + Send + Sync,
311 Fut: Future<Output = R> + Send,
312 R: IntoHandlerResult,
313{
314 pub fn new(func: F) -> Self {
316 Self {
317 func,
318 _marker: std::marker::PhantomData,
319 }
320 }
321}
322
323impl<F, Fut, R> Handler for HandlerFn<F, Fut, R>
324where
325 F: Fn() -> Fut + Send + Sync + 'static,
326 Fut: Future<Output = R> + Send + 'static,
327 R: IntoHandlerResult + 'static,
328{
329 fn call(&self, _args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
330 let fut = (self.func)();
331 Box::pin(async move { fut.await.into_handler_result() })
332 }
333}
334
335#[allow(clippy::type_complexity)]
337pub struct HandlerWithArgs<F, T, Fut, R>
338where
339 F: Fn(T) -> Fut + Send + Sync,
340 T: DeserializeOwned + Send,
341 Fut: Future<Output = R> + Send,
342 R: IntoHandlerResult,
343{
344 func: F,
345 _marker: std::marker::PhantomData<(fn() -> T, fn() -> R)>,
348}
349
350impl<F, T, Fut, R> HandlerWithArgs<F, T, Fut, R>
351where
352 F: Fn(T) -> Fut + Send + Sync,
353 T: DeserializeOwned + Send,
354 Fut: Future<Output = R> + Send,
355 R: IntoHandlerResult,
356{
357 pub fn new(func: F) -> Self {
359 Self {
360 func,
361 _marker: std::marker::PhantomData,
362 }
363 }
364}
365
366impl<F, T, Fut, R> Handler for HandlerWithArgs<F, T, Fut, R>
367where
368 F: Fn(T) -> Fut + Send + Sync + 'static,
369 T: DeserializeOwned + Send + 'static,
370 Fut: Future<Output = R> + Send + 'static,
371 R: IntoHandlerResult + 'static,
372{
373 fn call(&self, args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
374 let parsed: Result<T, _> = serde_json::from_str(args);
375 match parsed {
376 Ok(value) => {
377 let fut = (self.func)(value);
378 Box::pin(async move { fut.await.into_handler_result() })
379 }
380 Err(e) => Box::pin(async move {
381 Err(format!("Failed to deserialize args: {e}"))
382 }),
383 }
384 }
385}
386
387#[allow(clippy::type_complexity)]
389pub struct HandlerWithState<F, S, T, Fut, R>
390where
391 F: Fn(State<Arc<S>>, T) -> Fut + Send + Sync,
392 S: Send + Sync + 'static,
393 T: DeserializeOwned + Send,
394 Fut: Future<Output = R> + Send,
395 R: IntoHandlerResult,
396{
397 func: F,
398 state: Arc<dyn Any + Send + Sync>,
399 _marker: std::marker::PhantomData<(fn() -> S, fn() -> T, fn() -> R)>,
400}
401
402impl<F, S, T, Fut, R> HandlerWithState<F, S, T, Fut, R>
403where
404 F: Fn(State<Arc<S>>, T) -> Fut + Send + Sync,
405 S: Send + Sync + 'static,
406 T: DeserializeOwned + Send,
407 Fut: Future<Output = R> + Send,
408 R: IntoHandlerResult,
409{
410 pub fn new(func: F, state: Arc<dyn Any + Send + Sync>) -> Self {
412 Self {
413 func,
414 state,
415 _marker: std::marker::PhantomData,
416 }
417 }
418}
419
420impl<F, S, T, Fut, R> Handler for HandlerWithState<F, S, T, Fut, R>
421where
422 F: Fn(State<Arc<S>>, T) -> Fut + Send + Sync + 'static,
423 S: Send + Sync + 'static,
424 T: DeserializeOwned + Send + 'static,
425 Fut: Future<Output = R> + Send + 'static,
426 R: IntoHandlerResult + 'static,
427{
428 fn call(&self, args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
429 let state_arc = match self.state.clone().downcast::<S>() {
430 Ok(s) => s,
431 Err(_) => {
432 let msg = format!(
433 "State type mismatch: expected {}",
434 std::any::type_name::<S>()
435 );
436 return Box::pin(async move { Err(msg) });
437 }
438 };
439
440 let parsed: Result<T, _> = serde_json::from_str(args);
441 match parsed {
442 Ok(value) => {
443 let fut = (self.func)(State(state_arc), value);
444 Box::pin(async move { fut.await.into_handler_result() })
445 }
446 Err(e) => Box::pin(async move {
447 Err(format!("Failed to deserialize args: {e}"))
448 }),
449 }
450 }
451}
452
453#[allow(clippy::type_complexity)]
455pub struct HandlerWithStateOnly<F, S, Fut, R>
456where
457 F: Fn(State<Arc<S>>) -> Fut + Send + Sync,
458 S: Send + Sync + 'static,
459 Fut: Future<Output = R> + Send,
460 R: IntoHandlerResult,
461{
462 func: F,
463 state: Arc<dyn Any + Send + Sync>,
464 _marker: std::marker::PhantomData<(fn() -> S, fn() -> R)>,
465}
466
467impl<F, S, Fut, R> HandlerWithStateOnly<F, S, Fut, R>
468where
469 F: Fn(State<Arc<S>>) -> Fut + Send + Sync,
470 S: Send + Sync + 'static,
471 Fut: Future<Output = R> + Send,
472 R: IntoHandlerResult,
473{
474 pub fn new(func: F, state: Arc<dyn Any + Send + Sync>) -> Self {
476 Self {
477 func,
478 state,
479 _marker: std::marker::PhantomData,
480 }
481 }
482}
483
484impl<F, S, Fut, R> Handler for HandlerWithStateOnly<F, S, Fut, R>
485where
486 F: Fn(State<Arc<S>>) -> Fut + Send + Sync + 'static,
487 S: Send + Sync + 'static,
488 Fut: Future<Output = R> + Send + 'static,
489 R: IntoHandlerResult + 'static,
490{
491 fn call(&self, _args: &str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
492 let state_arc = match self.state.clone().downcast::<S>() {
493 Ok(s) => s,
494 Err(_) => {
495 let msg = format!(
496 "State type mismatch: expected {}",
497 std::any::type_name::<S>()
498 );
499 return Box::pin(async move { Err(msg) });
500 }
501 };
502
503 let fut = (self.func)(State(state_arc));
504 Box::pin(async move { fut.await.into_handler_result() })
505 }
506}
507
/// Handler variant that can emit intermediate items through a [`StreamSender`]
/// before resolving to its final result.
pub trait StreamHandler: Send + Sync {
    /// Invokes the handler with the raw `args` string and the sending half of
    /// the output stream.
    ///
    /// The returned future resolves to the final response (`Ok`) or an error
    /// message (`Err`); it may borrow `self`.
    fn call_streaming(
        &self,
        args: &str,
        tx: StreamSender,
    ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>>;
}
525
526pub struct StreamingHandlerFn<F, Fut, R>
530where
531 F: Fn(StreamSender) -> Fut + Send + Sync,
532 Fut: Future<Output = R> + Send,
533 R: IntoHandlerResult,
534{
535 func: F,
536 _marker: std::marker::PhantomData<fn() -> R>,
537}
538
539impl<F, Fut, R> StreamingHandlerFn<F, Fut, R>
540where
541 F: Fn(StreamSender) -> Fut + Send + Sync,
542 Fut: Future<Output = R> + Send,
543 R: IntoHandlerResult,
544{
545 pub fn new(func: F) -> Self {
547 Self {
548 func,
549 _marker: std::marker::PhantomData,
550 }
551 }
552}
553
554impl<F, Fut, R> StreamHandler for StreamingHandlerFn<F, Fut, R>
555where
556 F: Fn(StreamSender) -> Fut + Send + Sync + 'static,
557 Fut: Future<Output = R> + Send + 'static,
558 R: IntoHandlerResult + 'static,
559{
560 fn call_streaming(
561 &self,
562 _args: &str,
563 tx: StreamSender,
564 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
565 let fut = (self.func)(tx);
566 Box::pin(async move { fut.await.into_handler_result() })
567 }
568}
569
570#[allow(clippy::type_complexity)]
572pub struct StreamingHandlerWithArgs<F, T, Fut, R>
573where
574 F: Fn(T, StreamSender) -> Fut + Send + Sync,
575 T: DeserializeOwned + Send,
576 Fut: Future<Output = R> + Send,
577 R: IntoHandlerResult,
578{
579 func: F,
580 _marker: std::marker::PhantomData<(fn() -> T, fn() -> R)>,
581}
582
583impl<F, T, Fut, R> StreamingHandlerWithArgs<F, T, Fut, R>
584where
585 F: Fn(T, StreamSender) -> Fut + Send + Sync,
586 T: DeserializeOwned + Send,
587 Fut: Future<Output = R> + Send,
588 R: IntoHandlerResult,
589{
590 pub fn new(func: F) -> Self {
592 Self {
593 func,
594 _marker: std::marker::PhantomData,
595 }
596 }
597}
598
599impl<F, T, Fut, R> StreamHandler for StreamingHandlerWithArgs<F, T, Fut, R>
600where
601 F: Fn(T, StreamSender) -> Fut + Send + Sync + 'static,
602 T: DeserializeOwned + Send + 'static,
603 Fut: Future<Output = R> + Send + 'static,
604 R: IntoHandlerResult + 'static,
605{
606 fn call_streaming(
607 &self,
608 args: &str,
609 tx: StreamSender,
610 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
611 let parsed: Result<T, _> = serde_json::from_str(args);
612 match parsed {
613 Ok(value) => {
614 let fut = (self.func)(value, tx);
615 Box::pin(async move { fut.await.into_handler_result() })
616 }
617 Err(e) => Box::pin(async move {
618 Err(format!("Failed to deserialize args: {e}"))
619 }),
620 }
621 }
622}
623
624#[allow(clippy::type_complexity)]
626pub struct StreamingHandlerWithState<F, S, T, Fut, R>
627where
628 F: Fn(State<Arc<S>>, T, StreamSender) -> Fut + Send + Sync,
629 S: Send + Sync + 'static,
630 T: DeserializeOwned + Send,
631 Fut: Future<Output = R> + Send,
632 R: IntoHandlerResult,
633{
634 func: F,
635 state: Arc<dyn Any + Send + Sync>,
636 _marker: std::marker::PhantomData<(fn() -> S, fn() -> T, fn() -> R)>,
637}
638
639impl<F, S, T, Fut, R> StreamingHandlerWithState<F, S, T, Fut, R>
640where
641 F: Fn(State<Arc<S>>, T, StreamSender) -> Fut + Send + Sync,
642 S: Send + Sync + 'static,
643 T: DeserializeOwned + Send,
644 Fut: Future<Output = R> + Send,
645 R: IntoHandlerResult,
646{
647 pub fn new(func: F, state: Arc<dyn Any + Send + Sync>) -> Self {
649 Self {
650 func,
651 state,
652 _marker: std::marker::PhantomData,
653 }
654 }
655}
656
657impl<F, S, T, Fut, R> StreamHandler for StreamingHandlerWithState<F, S, T, Fut, R>
658where
659 F: Fn(State<Arc<S>>, T, StreamSender) -> Fut + Send + Sync + 'static,
660 S: Send + Sync + 'static,
661 T: DeserializeOwned + Send + 'static,
662 Fut: Future<Output = R> + Send + 'static,
663 R: IntoHandlerResult + 'static,
664{
665 fn call_streaming(
666 &self,
667 args: &str,
668 tx: StreamSender,
669 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
670 let state_arc = match self.state.clone().downcast::<S>() {
671 Ok(s) => s,
672 Err(_) => {
673 let msg = format!(
674 "State type mismatch: expected {}",
675 std::any::type_name::<S>()
676 );
677 return Box::pin(async move { Err(msg) });
678 }
679 };
680
681 let parsed: Result<T, _> = serde_json::from_str(args);
682 match parsed {
683 Ok(value) => {
684 let fut = (self.func)(State(state_arc), value, tx);
685 Box::pin(async move { fut.await.into_handler_result() })
686 }
687 Err(e) => Box::pin(async move {
688 Err(format!("Failed to deserialize args: {e}"))
689 }),
690 }
691 }
692}
693
694#[allow(clippy::type_complexity)]
696pub struct StreamingHandlerWithStateOnly<F, S, Fut, R>
697where
698 F: Fn(State<Arc<S>>, StreamSender) -> Fut + Send + Sync,
699 S: Send + Sync + 'static,
700 Fut: Future<Output = R> + Send,
701 R: IntoHandlerResult,
702{
703 func: F,
704 state: Arc<dyn Any + Send + Sync>,
705 _marker: std::marker::PhantomData<(fn() -> S, fn() -> R)>,
706}
707
708impl<F, S, Fut, R> StreamingHandlerWithStateOnly<F, S, Fut, R>
709where
710 F: Fn(State<Arc<S>>, StreamSender) -> Fut + Send + Sync,
711 S: Send + Sync + 'static,
712 Fut: Future<Output = R> + Send,
713 R: IntoHandlerResult,
714{
715 pub fn new(func: F, state: Arc<dyn Any + Send + Sync>) -> Self {
717 Self {
718 func,
719 state,
720 _marker: std::marker::PhantomData,
721 }
722 }
723}
724
725impl<F, S, Fut, R> StreamHandler for StreamingHandlerWithStateOnly<F, S, Fut, R>
726where
727 F: Fn(State<Arc<S>>, StreamSender) -> Fut + Send + Sync + 'static,
728 S: Send + Sync + 'static,
729 Fut: Future<Output = R> + Send + 'static,
730 R: IntoHandlerResult + 'static,
731{
732 fn call_streaming(
733 &self,
734 _args: &str,
735 tx: StreamSender,
736 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
737 let state_arc = match self.state.clone().downcast::<S>() {
738 Ok(s) => s,
739 Err(_) => {
740 let msg = format!(
741 "State type mismatch: expected {}",
742 std::any::type_name::<S>()
743 );
744 return Box::pin(async move { Err(msg) });
745 }
746 };
747
748 let fut = (self.func)(State(state_arc), tx);
749 Box::pin(async move { fut.await.into_handler_result() })
750 }
751}
752
753#[cfg(test)]
754mod tests {
755 use super::*;
756
757 #[tokio::test]
760 async fn test_handler_fn() {
761 let handler = HandlerFn::new(|| async { "test".to_string() });
762 let result = handler.call("{}").await;
763 assert_eq!(result, Ok("test".to_string()));
764 }
765
766 #[tokio::test]
767 async fn test_handler_fn_ignores_args() {
768 let handler = HandlerFn::new(|| async { "no-args".to_string() });
769 let result = handler.call(r#"{"unexpected": true}"#).await;
770 assert_eq!(result, Ok("no-args".to_string()));
771 }
772
773 #[tokio::test]
774 async fn test_handler_with_args() {
775 #[derive(serde::Deserialize)]
776 struct Input {
777 name: String,
778 }
779
780 let handler = HandlerWithArgs::new(|args: Input| async move {
781 format!("hello {}", args.name)
782 });
783
784 let result = handler.call(r#"{"name":"Alice"}"#).await;
785 assert_eq!(result, Ok("hello Alice".to_string()));
786 }
787
788 #[tokio::test]
789 async fn test_handler_with_args_bad_json() {
790 #[derive(serde::Deserialize)]
791 struct Input {
792 _name: String,
793 }
794
795 let handler = HandlerWithArgs::new(|_args: Input| async move {
796 "unreachable".to_string()
797 });
798
799 let result = handler.call("not-json").await;
800 assert!(result.is_err());
801 assert!(result.unwrap_err().contains("Failed to deserialize args"));
802 }
803
804 #[tokio::test]
805 async fn test_handler_with_args_missing_field() {
806 #[derive(serde::Deserialize)]
807 struct Input {
808 _name: String,
809 }
810
811 let handler = HandlerWithArgs::new(|_args: Input| async move {
812 "unreachable".to_string()
813 });
814
815 let result = handler.call(r#"{"age": 30}"#).await;
816 assert!(result.is_err());
817 assert!(result.unwrap_err().contains("Failed to deserialize args"));
818 }
819
820 #[tokio::test]
821 async fn test_handler_with_state() {
822 struct AppState {
823 greeting: String,
824 }
825
826 #[derive(serde::Deserialize)]
827 struct Input {
828 name: String,
829 }
830
831 let state: Arc<dyn Any + Send + Sync> = Arc::new(AppState {
832 greeting: "Hi".to_string(),
833 });
834
835 let handler = HandlerWithState::new(
836 |state: State<Arc<AppState>>, args: Input| async move {
837 format!("{} {}", state.greeting, args.name)
838 },
839 state,
840 );
841
842 let result = handler.call(r#"{"name":"Bob"}"#).await;
843 assert_eq!(result, Ok("Hi Bob".to_string()));
844 }
845
846 #[tokio::test]
847 async fn test_handler_with_state_only() {
848 struct AppState {
849 value: i32,
850 }
851
852 let state: Arc<dyn Any + Send + Sync> = Arc::new(AppState { value: 42 });
853
854 let handler = HandlerWithStateOnly::new(
855 |state: State<Arc<AppState>>| async move {
856 format!("value={}", state.value)
857 },
858 state,
859 );
860
861 let result = handler.call("{}").await;
862 assert_eq!(result, Ok("value=42".to_string()));
863 }
864
865 #[tokio::test]
866 async fn test_handler_with_state_deser_error() {
867 struct AppState;
868
869 #[derive(serde::Deserialize)]
870 struct Input {
871 _x: i32,
872 }
873
874 let state: Arc<dyn Any + Send + Sync> = Arc::new(AppState);
875
876 let handler = HandlerWithState::new(
877 |_state: State<Arc<AppState>>, _args: Input| async move {
878 "unreachable".to_string()
879 },
880 state,
881 );
882
883 let result = handler.call("bad").await;
884 assert!(result.is_err());
885 assert!(result.unwrap_err().contains("Failed to deserialize args"));
886 }
887
888 #[tokio::test]
891 async fn test_json_handler_fn_struct() {
892 #[derive(serde::Serialize)]
893 struct User {
894 id: u32,
895 name: String,
896 }
897
898 let handler = HandlerFn::new(|| async {
899 Json(User {
900 id: 1,
901 name: "Alice".to_string(),
902 })
903 });
904
905 let result = handler.call("{}").await;
906 assert_eq!(result, Ok(r#"{"id":1,"name":"Alice"}"#.to_string()));
907 }
908
909 #[tokio::test]
910 async fn test_json_handler_fn_vec() {
911 let handler = HandlerFn::new(|| async { Json(vec![1, 2, 3]) });
912 let result = handler.call("{}").await;
913 assert_eq!(result, Ok("[1,2,3]".to_string()));
914 }
915
916 #[tokio::test]
917 async fn test_json_handler_with_args() {
918 #[derive(serde::Deserialize)]
919 struct Input {
920 name: String,
921 }
922
923 #[derive(serde::Serialize)]
924 struct Output {
925 greeting: String,
926 }
927
928 let handler = HandlerWithArgs::new(|args: Input| async move {
929 Json(Output {
930 greeting: format!("Hello {}", args.name),
931 })
932 });
933
934 let result = handler.call(r#"{"name":"Bob"}"#).await;
935 assert_eq!(result, Ok(r#"{"greeting":"Hello Bob"}"#.to_string()));
936 }
937
938 #[tokio::test]
939 async fn test_json_handler_with_args_bad_json() {
940 #[derive(serde::Deserialize)]
941 struct Input {
942 _x: i32,
943 }
944
945 let handler = HandlerWithArgs::new(|_: Input| async move { Json(42) });
946 let result = handler.call("bad").await;
947 assert!(result.is_err());
948 assert!(result.unwrap_err().contains("Failed to deserialize args"));
949 }
950
951 #[tokio::test]
952 async fn test_json_handler_with_state() {
953 struct AppState {
954 prefix: String,
955 }
956
957 #[derive(serde::Deserialize)]
958 struct Input {
959 name: String,
960 }
961
962 #[derive(serde::Serialize)]
963 struct Output {
964 message: String,
965 }
966
967 let state: Arc<dyn Any + Send + Sync> = Arc::new(AppState {
968 prefix: "Hi".to_string(),
969 });
970
971 let handler = HandlerWithState::new(
972 |state: State<Arc<AppState>>, args: Input| async move {
973 Json(Output {
974 message: format!("{} {}", state.prefix, args.name),
975 })
976 },
977 state,
978 );
979
980 let result = handler.call(r#"{"name":"Charlie"}"#).await;
981 assert_eq!(result, Ok(r#"{"message":"Hi Charlie"}"#.to_string()));
982 }
983
984 #[tokio::test]
985 async fn test_json_handler_with_state_only() {
986 struct AppState {
987 version: String,
988 }
989
990 #[derive(serde::Serialize)]
991 struct Info {
992 version: String,
993 }
994
995 let state: Arc<dyn Any + Send + Sync> = Arc::new(AppState {
996 version: "1.0".to_string(),
997 });
998
999 let handler = HandlerWithStateOnly::new(
1000 |state: State<Arc<AppState>>| async move {
1001 Json(Info {
1002 version: state.version.clone(),
1003 })
1004 },
1005 state,
1006 );
1007
1008 let result = handler.call("{}").await;
1009 assert_eq!(result, Ok(r#"{"version":"1.0"}"#.to_string()));
1010 }
1011
1012 #[tokio::test]
1015 async fn test_result_handler_fn_ok() {
1016 #[derive(serde::Serialize)]
1017 struct Data {
1018 value: i32,
1019 }
1020
1021 let handler = HandlerFn::new(|| async {
1022 Ok::<_, String>(Data { value: 42 })
1023 });
1024
1025 let result = handler.call("{}").await;
1026 assert_eq!(result, Ok(r#"{"value":42}"#.to_string()));
1027 }
1028
1029 #[tokio::test]
1030 async fn test_result_handler_fn_err() {
1031 #[derive(serde::Serialize)]
1032 struct Data {
1033 value: i32,
1034 }
1035
1036 let handler = HandlerFn::new(|| async {
1037 Err::<Data, String>("something went wrong".to_string())
1038 });
1039
1040 let result = handler.call("{}").await;
1041 assert_eq!(result, Err("something went wrong".to_string()));
1042 }
1043
1044 #[tokio::test]
1045 async fn test_result_handler_with_args_ok() {
1046 #[derive(serde::Deserialize)]
1047 struct Input {
1048 x: i32,
1049 }
1050
1051 #[derive(serde::Serialize)]
1052 struct Output {
1053 doubled: i32,
1054 }
1055
1056 let handler = HandlerWithArgs::new(|args: Input| async move {
1057 Ok::<_, String>(Output { doubled: args.x * 2 })
1058 });
1059
1060 let result = handler.call(r#"{"x":21}"#).await;
1061 assert_eq!(result, Ok(r#"{"doubled":42}"#.to_string()));
1062 }
1063
1064 #[tokio::test]
1065 async fn test_result_handler_with_args_err() {
1066 #[derive(serde::Deserialize)]
1067 struct Input {
1068 x: i32,
1069 }
1070
1071 let handler = HandlerWithArgs::new(|args: Input| async move {
1072 if args.x < 0 {
1073 Err::<i32, String>("negative".to_string())
1074 } else {
1075 Ok(args.x)
1076 }
1077 });
1078
1079 let result = handler.call(r#"{"x":-1}"#).await;
1080 assert_eq!(result, Err("negative".to_string()));
1081 }
1082
1083 #[tokio::test]
1084 async fn test_result_handler_with_state() {
1085 struct AppState {
1086 threshold: i32,
1087 }
1088
1089 #[derive(serde::Deserialize)]
1090 struct Input {
1091 value: i32,
1092 }
1093
1094 #[derive(serde::Serialize)]
1095 struct Output {
1096 accepted: bool,
1097 }
1098
1099 let state: Arc<dyn Any + Send + Sync> = Arc::new(AppState { threshold: 10 });
1100
1101 let handler = HandlerWithState::new(
1102 |state: State<Arc<AppState>>, args: Input| async move {
1103 if args.value >= state.threshold {
1104 Ok::<_, String>(Output { accepted: true })
1105 } else {
1106 Err("below threshold".to_string())
1107 }
1108 },
1109 state,
1110 );
1111
1112 let ok_result = handler.call(r#"{"value":15}"#).await;
1113 assert_eq!(ok_result, Ok(r#"{"accepted":true}"#.to_string()));
1114
1115 let err_result = handler.call(r#"{"value":5}"#).await;
1116 assert_eq!(err_result, Err("below threshold".to_string()));
1117 }
1118
1119 #[tokio::test]
1120 async fn test_result_handler_with_state_only() {
1121 struct AppState {
1122 ready: bool,
1123 }
1124
1125 #[derive(serde::Serialize)]
1126 struct Status {
1127 ok: bool,
1128 }
1129
1130 let state: Arc<dyn Any + Send + Sync> = Arc::new(AppState { ready: true });
1131
1132 let handler = HandlerWithStateOnly::new(
1133 |state: State<Arc<AppState>>| async move {
1134 if state.ready {
1135 Ok::<_, String>(Status { ok: true })
1136 } else {
1137 Err("not ready".to_string())
1138 }
1139 },
1140 state,
1141 );
1142
1143 let result = handler.call("{}").await;
1144 assert_eq!(result, Ok(r#"{"ok":true}"#.to_string()));
1145 }
1146
1147 #[test]
1150 fn test_into_stream_item_string() {
1151 let item = "hello".to_string();
1152 assert_eq!(item.into_stream_item(), Ok("hello".to_string()));
1153 }
1154
1155 #[test]
1156 fn test_into_stream_item_json() {
1157 #[derive(serde::Serialize)]
1158 struct Token {
1159 text: String,
1160 }
1161 let item = Json(Token {
1162 text: "hi".to_string(),
1163 });
1164 assert_eq!(
1165 item.into_stream_item(),
1166 Ok(r#"{"text":"hi"}"#.to_string())
1167 );
1168 }
1169
1170 #[test]
1171 fn test_into_stream_item_json_vec() {
1172 let item = Json(vec![1, 2, 3]);
1173 assert_eq!(item.into_stream_item(), Ok("[1,2,3]".to_string()));
1174 }
1175
1176 #[test]
1177 fn test_into_stream_item_result_ok() {
1178 #[derive(serde::Serialize)]
1179 struct Data {
1180 v: i32,
1181 }
1182 let item: Result<Data, String> = Ok(Data { v: 42 });
1183 assert_eq!(item.into_stream_item(), Ok(r#"{"v":42}"#.to_string()));
1184 }
1185
1186 #[test]
1187 fn test_into_stream_item_result_err() {
1188 let item: Result<i32, String> = Err("bad".to_string());
1189 assert_eq!(item.into_stream_item(), Err("bad".to_string()));
1190 }
1191
1192 #[test]
1195 fn test_stream_error_display_closed() {
1196 let err = StreamError::Closed;
1197 assert_eq!(err.to_string(), "stream closed: receiver dropped");
1198 }
1199
1200 #[test]
1201 fn test_stream_error_display_serialize() {
1202 let err = StreamError::Serialize("bad json".to_string());
1203 assert_eq!(err.to_string(), "stream serialization error: bad json");
1204 }
1205
1206 #[test]
1207 fn test_stream_error_is_std_error() {
1208 let err: Box<dyn std::error::Error> = Box::new(StreamError::Closed);
1209 assert!(err.to_string().contains("closed"));
1210 }
1211
1212 #[tokio::test]
1215 async fn test_stream_sender_send_and_receive() {
1216 let (tx, mut rx) = StreamSender::channel();
1217 tx.send("hello".to_string()).await.unwrap();
1218 tx.send("world".to_string()).await.unwrap();
1219 drop(tx);
1220
1221 assert_eq!(rx.recv().await, Some("hello".to_string()));
1222 assert_eq!(rx.recv().await, Some("world".to_string()));
1223 assert_eq!(rx.recv().await, None);
1224 }
1225
1226 #[tokio::test]
1227 async fn test_stream_sender_send_json() {
1228 #[derive(serde::Serialize)]
1229 struct Token {
1230 t: String,
1231 }
1232 let (tx, mut rx) = StreamSender::channel();
1233 tx.send(Json(Token {
1234 t: "hi".to_string(),
1235 }))
1236 .await
1237 .unwrap();
1238 drop(tx);
1239
1240 assert_eq!(rx.recv().await, Some(r#"{"t":"hi"}"#.to_string()));
1241 }
1242
1243 #[tokio::test]
1244 async fn test_stream_sender_closed_detection() {
1245 let (tx, rx) = StreamSender::channel();
1246 assert!(!tx.is_closed());
1247 drop(rx);
1248 assert!(tx.is_closed());
1249 }
1250
1251 #[tokio::test]
1252 async fn test_stream_sender_send_after_close() {
1253 let (tx, rx) = StreamSender::channel();
1254 drop(rx);
1255 let result = tx.send("late".to_string()).await;
1256 assert_eq!(result, Err(StreamError::Closed));
1257 }
1258
1259 #[tokio::test]
1260 async fn test_stream_sender_custom_capacity() {
1261 let (tx, mut rx) = StreamSender::with_capacity(2);
1262
1263 tx.send("a".to_string()).await.unwrap();
1265 tx.send("b".to_string()).await.unwrap();
1266
1267 assert_eq!(rx.recv().await, Some("a".to_string()));
1269 assert_eq!(rx.recv().await, Some("b".to_string()));
1270
1271 tx.send("c".to_string()).await.unwrap();
1273 assert_eq!(rx.recv().await, Some("c".to_string()));
1274 }
1275
1276 #[tokio::test]
1277 async fn test_stream_sender_default_capacity() {
1278 assert_eq!(DEFAULT_STREAM_CAPACITY, 64);
1279 }
1280
1281 #[tokio::test]
1282 async fn test_stream_sender_clone() {
1283 let (tx, mut rx) = StreamSender::channel();
1284 let tx2 = tx.clone();
1285
1286 tx.send("from-tx1".to_string()).await.unwrap();
1287 tx2.send("from-tx2".to_string()).await.unwrap();
1288 drop(tx);
1289 drop(tx2);
1290
1291 assert_eq!(rx.recv().await, Some("from-tx1".to_string()));
1292 assert_eq!(rx.recv().await, Some("from-tx2".to_string()));
1293 assert_eq!(rx.recv().await, None);
1294 }
1295
1296 #[test]
1297 fn test_stream_sender_debug() {
1298 let (tx, _rx) = StreamSender::channel();
1299 let debug = format!("{:?}", tx);
1300 assert!(debug.contains("StreamSender"));
1301 }
1302
1303 #[tokio::test]
1306 async fn test_cancellation_token_not_cancelled_initially() {
1307 let (tx, _rx) = StreamSender::channel();
1308 let token = tx.cancellation_token();
1309 assert!(!token.is_cancelled());
1310 }
1311
1312 #[tokio::test]
1313 async fn test_cancellation_token_cancelled_on_explicit_cancel() {
1314 let (tx, _rx) = StreamSender::channel();
1315 let token = tx.cancellation_token();
1316 assert!(!token.is_cancelled());
1317 tx.cancel();
1318 assert!(token.is_cancelled());
1319 }
1320
1321 #[tokio::test]
1322 async fn test_cancellation_token_cancelled_future_resolves() {
1323 let (tx, _rx) = StreamSender::channel();
1324 let token = tx.cancellation_token();
1325
1326 let tx2 = tx.clone();
1328 tokio::spawn(async move {
1329 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1330 tx2.cancel();
1331 });
1332
1333 tokio::time::timeout(std::time::Duration::from_secs(1), token.cancelled())
1335 .await
1336 .expect("cancelled future should resolve");
1337 }
1338
1339 #[tokio::test]
1340 async fn test_cancellation_token_shared_across_clones() {
1341 let (tx, _rx) = StreamSender::channel();
1342 let token1 = tx.cancellation_token();
1343 let token2 = tx.cancellation_token();
1344 let tx2 = tx.clone();
1345 let token3 = tx2.cancellation_token();
1346
1347 tx.cancel();
1348 assert!(token1.is_cancelled());
1349 assert!(token2.is_cancelled());
1350 assert!(token3.is_cancelled());
1351 }
1352
1353 #[tokio::test]
1354 async fn test_cancellation_token_auto_cancelled_on_receiver_drop() {
1355 let (tx, rx) = StreamSender::channel();
1356 let token = tx.cancellation_token();
1357
1358 assert!(!token.is_cancelled());
1359 drop(rx); assert!(token.is_cancelled());
1361 }
1362
1363 #[tokio::test]
1364 async fn test_cancellation_token_auto_cancel_future_resolves_on_drop() {
1365 let (tx, rx) = StreamSender::channel();
1366 let token = tx.cancellation_token();
1367
1368 tokio::spawn(async move {
1369 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1370 drop(rx);
1371 });
1372
1373 tokio::time::timeout(std::time::Duration::from_secs(1), token.cancelled())
1374 .await
1375 .expect("cancelled future should resolve when receiver is dropped");
1376 }
1377
1378 #[tokio::test]
1381 async fn test_streaming_handler_fn() {
1382 let handler = StreamingHandlerFn::new(|tx: StreamSender| async move {
1383 tx.send("item1".to_string()).await.ok();
1384 tx.send("item2".to_string()).await.ok();
1385 "done".to_string()
1386 });
1387
1388 let (tx, mut rx) = StreamSender::channel();
1389 let result = handler.call_streaming("{}", tx).await;
1390
1391 assert_eq!(result, Ok("done".to_string()));
1392 assert_eq!(rx.recv().await, Some("item1".to_string()));
1393 assert_eq!(rx.recv().await, Some("item2".to_string()));
1394 }
1395
1396 #[tokio::test]
1397 async fn test_streaming_handler_with_args() {
1398 #[derive(serde::Deserialize)]
1399 struct Input {
1400 count: usize,
1401 }
1402
1403 let handler =
1404 StreamingHandlerWithArgs::new(|args: Input, tx: StreamSender| async move {
1405 for i in 0..args.count {
1406 tx.send(format!("item-{i}")).await.ok();
1407 }
1408 format!("sent {}", args.count)
1409 });
1410
1411 let (tx, mut rx) = StreamSender::channel();
1412 let result = handler.call_streaming(r#"{"count":3}"#, tx).await;
1413
1414 assert_eq!(result, Ok("sent 3".to_string()));
1415 assert_eq!(rx.recv().await, Some("item-0".to_string()));
1416 assert_eq!(rx.recv().await, Some("item-1".to_string()));
1417 assert_eq!(rx.recv().await, Some("item-2".to_string()));
1418 }
1419
1420 #[tokio::test]
1421 async fn test_streaming_handler_with_args_bad_json() {
1422 #[derive(serde::Deserialize)]
1423 struct Input {
1424 _x: i32,
1425 }
1426
1427 let handler =
1428 StreamingHandlerWithArgs::new(|_args: Input, _tx: StreamSender| async move {
1429 "unreachable".to_string()
1430 });
1431
1432 let (tx, _rx) = StreamSender::channel();
1433 let result = handler.call_streaming("bad-json", tx).await;
1434 assert!(result.is_err());
1435 assert!(result.unwrap_err().contains("Failed to deserialize args"));
1436 }
1437
1438 #[tokio::test]
1439 async fn test_streaming_handler_with_state() {
1440 struct AppState {
1441 prefix: String,
1442 }
1443
1444 #[derive(serde::Deserialize)]
1445 struct Input {
1446 name: String,
1447 }
1448
1449 let state: Arc<dyn Any + Send + Sync> = Arc::new(AppState {
1450 prefix: "Hi".to_string(),
1451 });
1452
1453 let handler = StreamingHandlerWithState::new(
1454 |state: State<Arc<AppState>>, args: Input, tx: StreamSender| async move {
1455 tx.send(format!("{} {}", state.prefix, args.name))
1456 .await
1457 .ok();
1458 "done".to_string()
1459 },
1460 state,
1461 );
1462
1463 let (tx, mut rx) = StreamSender::channel();
1464 let result = handler.call_streaming(r#"{"name":"Alice"}"#, tx).await;
1465
1466 assert_eq!(result, Ok("done".to_string()));
1467 assert_eq!(rx.recv().await, Some("Hi Alice".to_string()));
1468 }
1469
1470 #[tokio::test]
1471 async fn test_streaming_handler_with_state_only() {
1472 struct AppState {
1473 items: Vec<String>,
1474 }
1475
1476 let state: Arc<dyn Any + Send + Sync> = Arc::new(AppState {
1477 items: vec!["a".to_string(), "b".to_string()],
1478 });
1479
1480 let handler = StreamingHandlerWithStateOnly::new(
1481 |state: State<Arc<AppState>>, tx: StreamSender| async move {
1482 for item in &state.items {
1483 tx.send(item.clone()).await.ok();
1484 }
1485 format!("sent {}", state.items.len())
1486 },
1487 state,
1488 );
1489
1490 let (tx, mut rx) = StreamSender::channel();
1491 let result = handler.call_streaming("{}", tx).await;
1492
1493 assert_eq!(result, Ok("sent 2".to_string()));
1494 assert_eq!(rx.recv().await, Some("a".to_string()));
1495 assert_eq!(rx.recv().await, Some("b".to_string()));
1496 }
1497
1498 #[tokio::test]
1499 async fn test_streaming_handler_with_state_type_mismatch() {
1500 struct WrongState;
1501 struct AppState;
1502
1503 let state: Arc<dyn Any + Send + Sync> = Arc::new(WrongState);
1504
1505 let handler = StreamingHandlerWithStateOnly::new(
1506 |_state: State<Arc<AppState>>, _tx: StreamSender| async move {
1507 "unreachable".to_string()
1508 },
1509 state,
1510 );
1511
1512 let (tx, _rx) = StreamSender::channel();
1513 let result = handler.call_streaming("{}", tx).await;
1514 assert!(result.is_err());
1515 assert!(result.unwrap_err().contains("State type mismatch"));
1516 }
1517
1518 #[tokio::test]
1519 async fn test_streaming_handler_json_return() {
1520 #[derive(serde::Serialize)]
1521 struct Summary {
1522 count: usize,
1523 }
1524
1525 let handler = StreamingHandlerFn::new(|tx: StreamSender| async move {
1526 tx.send("item".to_string()).await.ok();
1527 Json(Summary { count: 1 })
1528 });
1529
1530 let (tx, mut rx) = StreamSender::channel();
1531 let result = handler.call_streaming("{}", tx).await;
1532
1533 assert_eq!(result, Ok(r#"{"count":1}"#.to_string()));
1534 assert_eq!(rx.recv().await, Some("item".to_string()));
1535 }
1536
1537 #[tokio::test]
1538 async fn test_streaming_handler_result_return() {
1539 let handler = StreamingHandlerFn::new(|tx: StreamSender| async move {
1540 tx.send("progress".to_string()).await.ok();
1541 Ok::<_, String>(42)
1542 });
1543
1544 let (tx, mut rx) = StreamSender::channel();
1545 let result = handler.call_streaming("{}", tx).await;
1546
1547 assert_eq!(result, Ok("42".to_string()));
1548 assert_eq!(rx.recv().await, Some("progress".to_string()));
1549 }
1550}