1use crate::{
10 cancellations::{cancellations, CanceledRequests, RequestCancellation},
11 context::{self},
12 trace, ClientMessage, Request, Response, Transport,
13};
14#[cfg(feature = "opentelemetry")]
15use crate::context::SpanExt;
16use ::tokio::sync::mpsc;
17use futures::{
18 future::{AbortRegistration, Abortable},
19 prelude::*,
20 ready,
21 stream::Fuse,
22 task::*,
23};
24use in_flight_requests::{AlreadyExistsError, InFlightRequests};
25use pin_project::pin_project;
26use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin};
27use tracing::{info_span, instrument::Instrument, Span};
28
// Tracking of requests this channel has started but not yet finished (`InFlightRequests`),
// plus the `AlreadyExistsError` returned for duplicate request IDs.
mod in_flight_requests;
// Test-only helpers.
#[cfg(test)]
mod testing;

// Channel limits; `Channel::max_concurrent_requests` returns a
// `limits::requests_per_channel::MaxRequests`.
pub mod limits;

pub mod incoming;

// Tokio integration; provides the `TokioChannelExecutor` returned by `Channel::execute`.
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
pub mod tokio;
43
/// Settings that control the behavior of a [`BaseChannel`].
#[derive(Clone, Debug)]
pub struct Config {
    /// Capacity of the queue that holds responses produced by request handlers until they are
    /// written to the transport.
    ///
    /// This value is passed to `tokio::sync::mpsc::channel` when a channel's `requests()`
    /// stream is created. That function panics on a capacity of zero, so this must be at
    /// least 1.
    pub pending_response_buffer: usize,
}

impl Default for Config {
    /// Returns a config with `pending_response_buffer` set to 100.
    fn default() -> Self {
        const DEFAULT_PENDING_RESPONSE_BUFFER: usize = 100;
        Self {
            pending_response_buffer: DEFAULT_PENDING_RESPONSE_BUFFER,
        }
    }
}
60
61impl Config {
62 pub fn channel<Req, Resp, T>(self, transport: T) -> BaseChannel<Req, Resp, T>
64 where
65 T: Transport<Response<Resp>, ClientMessage<Req>>,
66 {
67 BaseChannel::new(self, transport)
68 }
69}
70
71pub trait Serve<Req> {
73 type Resp;
75
76 type Fut: Future<Output = Self::Resp>;
78
79 fn method(&self, _request: &Req) -> Option<&'static str> {
81 None
82 }
83
84 fn serve(self, ctx: context::Context, req: Req) -> Self::Fut;
86}
87
88impl<Req, Resp, Fut, F> Serve<Req> for F
89where
90 F: FnOnce(context::Context, Req) -> Fut,
91 Fut: Future<Output = Resp>,
92{
93 type Resp = Resp;
94 type Fut = Fut;
95
96 fn serve(self, ctx: context::Context, req: Req) -> Self::Fut {
97 self(ctx, req)
98 }
99}
100
/// A [`Stream`] of [`TrackedRequest`]s read from a transport, and a [`Sink`] of responses that
/// are written back to that transport.
///
/// The channel tracks which requests are in flight. Its `Sink` impl drops responses for
/// requests that are no longer tracked.
#[pin_project]
pub struct BaseChannel<Req, Resp, T> {
    // Settings for this channel.
    config: Config,
    // The transport, fused via `StreamExt::fuse` so it is not polled again after it ends.
    #[pin]
    transport: Fuse<T>,
    // Receiving half of the cancellation queue. `poll_next` drains it and stops tracking the
    // IDs it yields.
    #[pin]
    canceled_requests: CanceledRequests,
    // Sending half created together with `canceled_requests` by `cancellations()`. A clone
    // goes into every `ResponseGuard`.
    request_cancellation: RequestCancellation,
    // Requests that have been started and not yet removed (answered, canceled, or expired).
    in_flight_requests: InFlightRequests,
    // Type marker only; no `Req`/`Resp` values are stored.
    ghost: PhantomData<(Req, Resp)>,
}
127
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
where
    T: Transport<Response<Resp>, ClientMessage<Req>>,
{
    /// Creates a new channel backed by `transport` and configured with `config`.
    pub fn new(config: Config, transport: T) -> Self {
        // The two halves of the cancellation queue. `request_cancellation` is cloned into each
        // request's `ResponseGuard`; `canceled_requests` is drained in `poll_next`.
        let (request_cancellation, canceled_requests) = cancellations();
        BaseChannel {
            config,
            transport: transport.fuse(),
            canceled_requests,
            request_cancellation,
            in_flight_requests: InFlightRequests::default(),
            ghost: PhantomData,
        }
    }

    /// Creates a new channel backed by `transport` and configured with [`Config::default`].
    pub fn with_defaults(transport: T) -> Self {
        Self::new(Config::default(), transport)
    }

    /// Returns a reference to the inner transport.
    pub fn get_ref(&self) -> &T {
        self.transport.get_ref()
    }

    /// Returns a pinned mutable reference to the inner transport.
    pub fn get_pin_ref(self: Pin<&mut Self>) -> Pin<&mut T> {
        self.project().transport.get_pin_mut()
    }

    // Pin-projection helper: mutable access to the in-flight request set.
    fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests {
        self.as_mut().project().in_flight_requests
    }

    // Pin-projection helper: pinned access to the cancellation receiver.
    fn canceled_requests_pin_mut<'a>(
        self: &'a mut Pin<&mut Self>,
    ) -> Pin<&'a mut CanceledRequests> {
        self.as_mut().project().canceled_requests
    }

    // Pin-projection helper: pinned access to the fused transport.
    fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<T>> {
        self.as_mut().project().transport
    }

    /// Starts tracking `request`.
    ///
    /// This creates the request's "RPC" tracing span and replaces the request's trace context
    /// with one derived from that span. It then registers the request (ID, deadline, span) in
    /// the in-flight set.
    ///
    /// On success, returns the request bundled with its abort registration, its span, and a
    /// `ResponseGuard` that starts disarmed (`cancel: false`). Returns `AlreadyExistsError`
    /// if a request with the same ID is already in flight.
    fn start_request(
        mut self: Pin<&mut Self>,
        mut request: Request<Req>,
    ) -> Result<TrackedRequest<Req>, AlreadyExistsError> {
        let span = info_span!(
            "RPC",
            rpc.trace_id = %request.context.trace_id(),
            rpc.deadline = %humantime::format_rfc3339(request.context.deadline),
            otel.kind = "server",
            otel.name = tracing::field::Empty,
        );
        #[cfg(feature = "opentelemetry")]
        span.set_context(&request.context);
        // If no trace context can be read back from the span, fall back to `new_child()` of
        // the incoming trace context. The log message describes this as an unsampled child.
        request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
            tracing::trace!(
                "OpenTelemetry subscriber not installed; making unsampled \
                child context."
            );
            request.context.trace_context.new_child()
        });
        // Stay inside the span so the log events below are recorded within it.
        let entered = span.enter();
        tracing::info!("ReceiveRequest");
        let start = self.in_flight_requests_mut().start_request(
            request.id,
            request.context.deadline,
            span.clone(),
        );
        match start {
            Ok(abort_registration) => {
                // `entered` borrows `span`; it must end before `span` is moved below.
                drop(entered);
                Ok(TrackedRequest {
                    abort_registration,
                    span,
                    response_guard: ResponseGuard {
                        request_id: request.id,
                        request_cancellation: self.request_cancellation.clone(),
                        // Disarmed here; `Requests::pump_read` arms it.
                        cancel: false,
                    },
                    request,
                })
            }
            Err(AlreadyExistsError) => {
                // Still inside the span, so this event is attributed to the duplicate request.
                tracing::trace!("DuplicateRequest");
                Err(AlreadyExistsError)
            }
        }
    }
}
222
223impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
224 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
225 write!(f, "BaseChannel")
226 }
227}
228
/// A request yielded by a [`BaseChannel`], together with the state used to execute and track it.
#[derive(Debug)]
pub struct TrackedRequest<Req> {
    /// The request sent by the client.
    pub request: Request<Req>,
    /// Registration paired with an abort handle held by the channel. Futures built with it are
    /// aborted when the channel processes a cancellation for the request or the request's
    /// deadline expires.
    pub abort_registration: AbortRegistration,
    /// The tracing span created for this request.
    pub span: Span,
    /// Guard that, while armed, tells the channel to stop tracking the request when dropped.
    pub response_guard: ResponseGuard,
}
242
243pub trait Channel
269where
270 Self: Transport<Response<<Self as Channel>::Resp>, TrackedRequest<<Self as Channel>::Req>>,
271{
272 type Req;
274
275 type Resp;
277
278 type Transport;
280
281 fn config(&self) -> &Config;
283
284 fn in_flight_requests(&self) -> usize;
286
287 fn transport(&self) -> &Self::Transport;
289
290 fn max_concurrent_requests(
298 self,
299 limit: usize,
300 ) -> limits::requests_per_channel::MaxRequests<Self>
301 where
302 Self: Sized,
303 {
304 limits::requests_per_channel::MaxRequests::new(self, limit)
305 }
306
307 fn requests(self) -> Requests<Self>
314 where
315 Self: Sized,
316 {
317 let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer);
318
319 Requests {
320 channel: self,
321 pending_responses: responses,
322 responses_tx,
323 }
324 }
325
326 #[cfg(feature = "tokio1")]
330 #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
331 fn execute<S>(self, serve: S) -> self::tokio::TokioChannelExecutor<Requests<Self>, S>
332 where
333 Self: Sized,
334 S: Serve<Self::Req, Resp = Self::Resp> + Send + 'static,
335 S::Fut: Send,
336 Self::Req: Send + 'static,
337 Self::Resp: Send + 'static,
338 {
339 self.requests().execute(serve)
340 }
341}
342
/// Errors produced while reading from, or writing to, a [`BaseChannel`].
#[derive(thiserror::Error, Debug)]
pub enum ChannelError<E>
where
    E: Error + Send + Sync + 'static,
{
    /// An error reported by the underlying transport. It covers both polling for incoming
    /// messages and sink operations (ready / send / flush / close).
    #[error("an error occurred in the transport: {0}")]
    Transport(#[source] E),
    /// A tokio timer error, described by its message as occurring while polling expired
    /// requests. NOTE(review): not constructed anywhere in this file; confirm it is still used.
    #[error("an error occurred while polling expired requests: {0}")]
    Timer(#[source] ::tokio::time::error::Error),
}
356
/// Yields new requests from the transport.
///
/// Each poll also processes three kinds of events:
/// - dropped-while-armed `ResponseGuard`s (from the cancellation queue),
/// - expired deadlines,
/// - `Cancel` messages from the client.
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
where
    T: Transport<Response<Resp>, ClientMessage<Req>>,
{
    type Item = Result<TrackedRequest<Req>, ChannelError<T::Error>>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        // Outcome of polling one of the channel's input sources.
        #[derive(Clone, Copy, Debug)]
        enum ReceiverStatus {
            Ready,
            Pending,
            Closed,
        }

        impl ReceiverStatus {
            // Merges two statuses:
            // - any `Ready` means another loop iteration is worthwhile;
            // - the result is `Closed` only if both are `Closed`;
            // - otherwise it is `Pending`.
            fn combine(self, other: Self) -> Self {
                use ReceiverStatus::*;
                match (self, other) {
                    (Ready, _) | (_, Ready) => Ready,
                    (Closed, Closed) => Closed,
                    (Pending, Closed) | (Closed, Pending) | (Pending, Pending) => Pending,
                }
            }
        }

        use ReceiverStatus::*;

        loop {
            // Source 1: requests whose `ResponseGuard` was dropped while armed. Stop tracking
            // them.
            let cancellation_status = match self.canceled_requests_pin_mut().poll_recv(cx) {
                Poll::Ready(Some(request_id)) => {
                    if let Some(span) = self.in_flight_requests_mut().remove_request(request_id) {
                        let _entered = span.enter();
                        tracing::info!("ResponseCancelled");
                    }
                    Ready
                }
                // Deliberately `Closed` rather than `Pending`: an idle cancellation queue never
                // keeps the channel open by itself, so closure is decided by the other two
                // sources. `Ready(None)` is presumably unreachable while `self` holds
                // `request_cancellation`.
                Poll::Pending | Poll::Ready(None) => Closed,
            };

            // Source 2: requests whose deadline has passed. Nothing is sent for them.
            let expiration_status = match self.in_flight_requests_mut().poll_expired(cx) {
                Poll::Ready(Some(_)) => Ready,
                Poll::Ready(None) => Closed,
                Poll::Pending => Pending,
            };

            // Source 3: messages from the client. Transport errors are returned immediately.
            let request_status = match self
                .transport_pin_mut()
                .poll_next(cx)
                .map_err(ChannelError::Transport)?
            {
                Poll::Ready(Some(message)) => match message {
                    ClientMessage::Request(request) => {
                        match self.as_mut().start_request(request) {
                            // A new request is yielded immediately.
                            Ok(request) => return Poll::Ready(Some(Ok(request))),
                            Err(AlreadyExistsError) => {
                                // A request with this ID is already in flight. Ignore the
                                // duplicate instead of closing the channel. Loop again rather
                                // than return `Pending`: the transport just returned `Ready`,
                                // so it has not registered a wakeup.
                                continue;
                            }
                        }
                    }
                    ClientMessage::Cancel {
                        trace_context,
                        request_id,
                    } => {
                        if !self.in_flight_requests_mut().cancel_request(request_id) {
                            tracing::trace!(
                                rpc.trace_id = %trace_context.trace_id,
                                "Received cancellation, but response handler is already complete.",
                            );
                        }
                        Ready
                    }
                },
                Poll::Ready(None) => Closed,
                Poll::Pending => Pending,
            };

            tracing::trace!(
                "Expired requests: {:?}, Inbound: {:?}",
                expiration_status,
                request_status
            );
            match cancellation_status
                .combine(expiration_status)
                .combine(request_status)
            {
                // Some source made progress; poll all of them again.
                Ready => continue,
                Closed => return Poll::Ready(None),
                Pending => return Poll::Pending,
            }
        }
    }
}
461
462impl<Req, Resp, T> Sink<Response<Resp>> for BaseChannel<Req, Resp, T>
463where
464 T: Transport<Response<Resp>, ClientMessage<Req>>,
465 T::Error: Error,
466{
467 type Error = ChannelError<T::Error>;
468
469 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
470 self.project()
471 .transport
472 .poll_ready(cx)
473 .map_err(ChannelError::Transport)
474 }
475
476 fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
477 if let Some(span) = self
478 .in_flight_requests_mut()
479 .remove_request(response.request_id)
480 {
481 let _entered = span.enter();
482 tracing::info!("SendResponse");
483 self.project()
484 .transport
485 .start_send(response)
486 .map_err(ChannelError::Transport)
487 } else {
488 Ok(())
490 }
491 }
492
493 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
494 tracing::trace!("poll_flush");
495 self.project()
496 .transport
497 .poll_flush(cx)
498 .map_err(ChannelError::Transport)
499 }
500
501 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
502 self.project()
503 .transport
504 .poll_close(cx)
505 .map_err(ChannelError::Transport)
506 }
507}
508
509impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
510 fn as_ref(&self) -> &T {
511 self.transport.get_ref()
512 }
513}
514
515impl<Req, Resp, T> Channel for BaseChannel<Req, Resp, T>
516where
517 T: Transport<Response<Resp>, ClientMessage<Req>>,
518{
519 type Req = Req;
520 type Resp = Resp;
521 type Transport = T;
522
523 fn config(&self) -> &Config {
524 &self.config
525 }
526
527 fn in_flight_requests(&self) -> usize {
528 self.in_flight_requests.len()
529 }
530
531 fn transport(&self) -> &Self::Transport {
532 self.get_ref()
533 }
534}
535
/// A stream of requests read from a channel.
///
/// `Requests` also writes buffered responses back to the channel as part of polling. It must
/// therefore keep being polled for responses to make progress.
#[pin_project]
pub struct Requests<C>
where
    C: Channel,
{
    // The channel requests are read from and responses are written to.
    #[pin]
    channel: C,
    // Responses produced by request handlers, waiting to be written to `channel`.
    pending_responses: mpsc::Receiver<Response<C::Resp>>,
    // Sending half of `pending_responses`; cloned into every `InFlightRequest`.
    responses_tx: mpsc::Sender<Response<C::Resp>>,
}
550
impl<C> Requests<C>
where
    C: Channel,
{
    /// Returns a reference to the inner channel.
    pub fn channel(&self) -> &C {
        &self.channel
    }

    /// Returns a pinned mutable reference to the inner channel.
    pub fn channel_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> {
        self.as_mut().project().channel
    }

    /// Returns a mutable reference to the queue of responses waiting to be written.
    pub fn pending_responses_mut<'a>(
        self: &'a mut Pin<&mut Self>,
    ) -> &'a mut mpsc::Receiver<Response<C::Resp>> {
        self.as_mut().project().pending_responses
    }

    // Polls the channel for its next request and wraps it in an `InFlightRequest` that can
    // send its response back through `responses_tx`.
    fn pump_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<InFlightRequest<C::Req, C::Resp>, C::Error>>> {
        self.channel_pin_mut().poll_next(cx).map_ok(
            |TrackedRequest {
                 request,
                 abort_registration,
                 span,
                 mut response_guard,
             }| {
                // Arm the guard. From here on, dropping the `InFlightRequest` before it
                // disarms the guard tells the channel to stop tracking the request.
                response_guard.cancel = true;
                InFlightRequest {
                    request,
                    abort_registration,
                    span,
                    response_guard,
                    response_tx: self.responses_tx.clone(),
                }
            },
        )
    }

    // Writes at most one buffered response to the channel.
    //
    // Returns:
    // - `Ready(Some(Ok(())))` after a response is handed to the channel;
    // - `Ready(None)` once no more responses can arrive, or when `read_half_closed` is set and
    //   the channel has no in-flight requests (in both cases only after a complete flush);
    // - `Pending` otherwise.
    fn pump_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        read_half_closed: bool,
    ) -> Poll<Option<Result<(), C::Error>>> {
        match self.as_mut().poll_next_response(cx)? {
            Poll::Ready(Some(response)) => {
                // `poll_next_response` only yields after `ensure_writeable` saw the channel
                // ready, so `start_send` is not called on a full sink.
                self.channel_pin_mut().start_send(response)?;
                Poll::Ready(Some(Ok(())))
            }
            Poll::Ready(None) => {
                // Finish flushing what was already written before reporting completion.
                ready!(self.channel_pin_mut().poll_flush(cx)?);
                Poll::Ready(None)
            }
            Poll::Pending => {
                // Nothing to write right now; flush whatever the transport has buffered.
                ready!(self.channel_pin_mut().poll_flush(cx)?);

                // Everything written is flushed. If no more requests can arrive and none are
                // in flight, no more responses can be produced, so the write half is done.
                if read_half_closed && self.channel.in_flight_requests() == 0 {
                    Poll::Ready(None)
                } else {
                    Poll::Pending
                }
            }
        }
    }

    // Returns the next buffered response, but only once the channel is ready to accept a
    // write.
    fn poll_next_response(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Response<C::Resp>, C::Error>>> {
        ready!(self.ensure_writeable(cx)?);

        match ready!(self.pending_responses_mut().poll_recv(cx)) {
            Some(response) => Poll::Ready(Some(Ok(response))),
            None => {
                // All senders are gone. This is presumably rare, since `self` holds
                // `responses_tx`.
                Poll::Ready(None)
            }
        }
    }

    // Returns `Ready` once the channel reports it can accept a write, flushing it while it is
    // not ready. Returns `Pending` if a flush is still in progress.
    fn ensure_writeable<'a>(
        self: &'a mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(), C::Error>>> {
        while self.channel_pin_mut().poll_ready(cx)?.is_pending() {
            ready!(self.channel_pin_mut().poll_flush(cx)?);
        }
        Poll::Ready(Some(Ok(())))
    }
}
660
661impl<C> fmt::Debug for Requests<C>
662where
663 C: Channel,
664{
665 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
666 write!(fmt, "Requests")
667 }
668}
669
670#[derive(Debug)]
673pub struct ResponseGuard {
674 request_cancellation: RequestCancellation,
675 request_id: u64,
676 cancel: bool,
677}
678
679impl Drop for ResponseGuard {
680 fn drop(&mut self) {
681 if self.cancel {
682 self.request_cancellation.cancel(self.request_id);
683 }
684 }
685}
686
/// A request read from a [`Requests`] stream, ready to be run with
/// [`InFlightRequest::execute`].
///
/// Its response guard is armed, so dropping it without executing tells the channel to stop
/// tracking the request.
#[derive(Debug)]
pub struct InFlightRequest<Req, Res> {
    // The client's request.
    request: Request<Req>,
    // Used to make the handler future abortable by the channel.
    abort_registration: AbortRegistration,
    // Armed guard; disarmed by `execute` once the channel no longer needs notifying.
    response_guard: ResponseGuard,
    // Span the handler runs in.
    span: Span,
    // Where the handler's response is buffered for `Requests` to write.
    response_tx: mpsc::Sender<Response<Res>>,
}
699
700impl<Req, Res> InFlightRequest<Req, Res> {
701 pub fn get(&self) -> &Request<Req> {
703 &self.request
704 }
705
706 pub async fn execute<S>(self, serve: S)
721 where
722 S: Serve<Req, Resp = Res>,
723 {
724 let Self {
725 response_tx,
726 mut response_guard,
727 abort_registration,
728 span,
729 request:
730 Request {
731 context,
732 message,
733 id: request_id,
734 },
735 } = self;
736 let method = serve.method(&message);
737 #[allow(clippy::needless_borrow)]
740 span.record("otel.name", &method.unwrap_or(""));
741 let _ = Abortable::new(
742 async move {
743 tracing::info!("BeginRequest");
744 let response = serve.serve(context, message).await;
745
746 tracing::info!("CompleteRequest");
747 if context.discard_response {
748 tracing::info!("DiscardingResponse");
749 } else {
750 let response = Response {
751 request_id,
752 message: Ok(response),
753 };
754 let _ = response_tx.send(response).await;
755 tracing::info!("BufferResponse");
756 }
757 },
758 abort_registration,
759 )
760 .instrument(span)
761 .await;
762 response_guard.cancel = false;
766 }
767}
768
769impl<C> Stream for Requests<C>
770where
771 C: Channel,
772{
773 type Item = Result<InFlightRequest<C::Req, C::Resp>, C::Error>;
774
775 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
776 loop {
777 let read = self.as_mut().pump_read(cx)?;
778 let read_closed = matches!(read, Poll::Ready(None));
779 match (read, self.as_mut().pump_write(cx, read_closed)?) {
780 (Poll::Ready(None), Poll::Ready(None)) => {
781 return Poll::Ready(None);
782 }
783 (Poll::Ready(Some(request_handler)), _) => {
784 return Poll::Ready(Some(Ok(request_handler)));
785 }
786 (_, Poll::Ready(Some(()))) => {}
787 _ => {
788 return Poll::Pending;
789 }
790 }
791 }
792 }
793}
794
#[cfg(test)]
mod tests {
    use super::{in_flight_requests::AlreadyExistsError, BaseChannel, Channel, Config, Requests};
    use crate::{
        context, trace,
        transport::channel::{self, UnboundedChannel},
        ClientMessage, Request, Response,
    };
    use assert_matches::assert_matches;
    use futures::{
        future::{pending, AbortRegistration, Abortable, Aborted},
        prelude::*,
        Future,
    };
    use futures_test::task::noop_context;
    use std::{pin::Pin, task::Poll};

    // Builds a pinned `BaseChannel` over an unbounded in-memory transport and returns it with
    // the client end of that transport.
    fn test_channel<Req, Resp>() -> (
        Pin<Box<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>,
        UnboundedChannel<Response<Resp>, ClientMessage<Req>>,
    ) {
        let (tx, rx) = crate::transport::channel::unbounded();
        (Box::pin(BaseChannel::new(Config::default(), rx)), tx)
    }

    // Same as `test_channel`, but wraps the channel in `Requests`.
    fn test_requests<Req, Resp>() -> (
        Pin<
            Box<
                Requests<
                    BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>,
                >,
            >,
        >,
        UnboundedChannel<Response<Resp>, ClientMessage<Req>>,
    ) {
        let (tx, rx) = crate::transport::channel::unbounded();
        (
            Box::pin(BaseChannel::new(Config::default(), rx).requests()),
            tx,
        )
    }

    // Builds `Requests` over a bounded in-memory transport of the given capacity. The response
    // buffer is `capacity + 1`, presumably so that `capacity == 0` still yields a non-zero
    // mpsc buffer (tokio's `mpsc::channel` panics on 0).
    fn test_bounded_requests<Req, Resp>(
        capacity: usize,
    ) -> (
        Pin<
            Box<
                Requests<
                    BaseChannel<Req, Resp, channel::Channel<ClientMessage<Req>, Response<Resp>>>,
                >,
            >,
        >,
        channel::Channel<Response<Resp>, ClientMessage<Req>>,
    ) {
        let (tx, rx) = crate::transport::channel::bounded(capacity);
        let config = Config {
            pending_response_buffer: capacity + 1,
        };
        (Box::pin(BaseChannel::new(config, rx).requests()), tx)
    }

    // A client request message with ID 0 and the current context.
    fn fake_request<Req>(req: Req) -> ClientMessage<Req> {
        ClientMessage::Request(Request {
            context: context::current(),
            id: 0,
            message: req,
        })
    }

    // A future that never completes on its own. It resolves to `Err(Aborted)` once the abort
    // handle paired with `abort_registration` has been triggered.
    fn test_abortable(
        abort_registration: AbortRegistration,
    ) -> impl Future<Output = Result<(), Aborted>> {
        Abortable::new(pending(), abort_registration)
    }

    // Starting two requests with the same ID fails the second time.
    #[tokio::test]
    async fn base_channel_start_send_duplicate_request_returns_error() {
        let (mut channel, _tx) = test_channel::<(), ()>();

        channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        assert_matches!(
            channel.as_mut().start_request(Request {
                id: 0,
                context: context::current(),
                message: ()
            }),
            Err(AlreadyExistsError)
        );
    }

    // Advancing the paused clock by 1000s (assumed to exceed the default deadline of
    // `context::current()`) makes a single poll abort every expired request.
    #[tokio::test]
    async fn base_channel_poll_next_aborts_multiple_requests() {
        let (mut channel, _tx) = test_channel::<(), ()>();

        tokio::time::pause();
        let req0 = channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        let req1 = channel
            .as_mut()
            .start_request(Request {
                id: 1,
                context: context::current(),
                message: (),
            })
            .unwrap();
        tokio::time::advance(std::time::Duration::from_secs(1000)).await;

        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Pending
        );
        assert_matches!(test_abortable(req0.abort_registration).await, Err(Aborted));
        assert_matches!(test_abortable(req1.abort_registration).await, Err(Aborted));
    }

    // A client `Cancel` message aborts the matching in-flight request.
    #[tokio::test]
    async fn base_channel_poll_next_aborts_canceled_request() {
        let (mut channel, mut tx) = test_channel::<(), ()>();

        tokio::time::pause();
        let req = channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();

        tx.send(ClientMessage::Cancel {
            trace_context: trace::Context::default(),
            request_id: 0,
        })
        .await
        .unwrap();

        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Pending
        );

        assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted));
    }

    // While a request is still in flight, a closed transport does not close the channel.
    #[tokio::test]
    async fn base_channel_with_closed_transport_and_in_flight_request_returns_pending() {
        let (mut channel, tx) = test_channel::<(), ()>();

        tokio::time::pause();
        let _abort_registration = channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();

        drop(tx);
        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Pending
        );
    }

    // With nothing in flight, a closed transport closes the channel.
    #[tokio::test]
    async fn base_channel_with_closed_transport_and_no_in_flight_requests_returns_closed() {
        let (mut channel, tx) = test_channel::<(), ()>();
        drop(tx);
        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Ready(None)
        );
    }

    // A request sent by the client is yielded by `poll_next`.
    #[tokio::test]
    async fn base_channel_poll_next_yields_request() {
        let (mut channel, mut tx) = test_channel::<(), ()>();
        tx.send(fake_request(())).await.unwrap();

        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Ready(Some(Ok(_)))
        );
    }

    // A single poll both aborts an expired request and yields a newly arrived one.
    #[tokio::test]
    async fn base_channel_poll_next_aborts_request_and_yields_request() {
        let (mut channel, mut tx) = test_channel::<(), ()>();

        tokio::time::pause();
        let req = channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        tokio::time::advance(std::time::Duration::from_secs(1000)).await;

        tx.send(fake_request(())).await.unwrap();

        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Ready(Some(Ok(_)))
        );
        assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted));
    }

    // Sending a response stops tracking its request.
    #[tokio::test]
    async fn base_channel_start_send_removes_in_flight_request() {
        let (mut channel, _tx) = test_channel::<(), ()>();

        channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        assert_eq!(channel.in_flight_requests(), 1);
        channel
            .as_mut()
            .start_send(Response {
                request_id: 0,
                message: Ok(()),
            })
            .unwrap();
        assert_eq!(channel.in_flight_requests(), 0);
    }

    // Dropping an unexecuted `InFlightRequest` (armed guard) makes the channel stop tracking it.
    #[tokio::test]
    async fn in_flight_request_drop_cancels_request() {
        let (mut requests, mut tx) = test_requests::<(), ()>();
        tx.send(fake_request(())).await.unwrap();

        let request = match requests.as_mut().poll_next(&mut noop_context()) {
            Poll::Ready(Some(Ok(request))) => request,
            result => panic!("Unexpected result: {:?}", result),
        };
        drop(request);

        let poll = requests
            .as_mut()
            .channel_pin_mut()
            .poll_next(&mut noop_context());
        assert!(poll.is_pending());
        let in_flight_requests = requests.channel().in_flight_requests();
        assert_eq!(in_flight_requests, 0);
    }

    // A request whose handler completes disarms its guard, so no cancellation is queued.
    #[tokio::test]
    async fn in_flight_requests_successful_execute_doesnt_cancel_request() {
        let (mut requests, mut tx) = test_requests::<(), ()>();
        tx.send(fake_request(())).await.unwrap();

        let request = match requests.as_mut().poll_next(&mut noop_context()) {
            Poll::Ready(Some(Ok(request))) => request,
            result => panic!("Unexpected result: {:?}", result),
        };
        request.execute(|_, _| async {}).await;
        assert!(requests
            .as_mut()
            .channel_pin_mut()
            .canceled_requests
            .poll_recv(&mut noop_context())
            .is_pending());
    }

    // With a zero-capacity transport already holding a response, the next buffered response
    // cannot be returned yet.
    #[tokio::test]
    async fn requests_poll_next_response_returns_pending_when_buffer_full() {
        let (mut requests, _tx) = test_bounded_requests::<(), ()>(0);

        // Fill the transport's buffer.
        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        requests
            .as_mut()
            .channel_pin_mut()
            .start_send(Response {
                request_id: 0,
                message: Ok(()),
            })
            .unwrap();

        // Queue a second response in the pending-response buffer.
        requests
            .as_mut()
            .project()
            .responses_tx
            .send(Response {
                request_id: 1,
                message: Ok(()),
            })
            .await
            .unwrap();

        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(Request {
                id: 1,
                context: context::current(),
                message: (),
            })
            .unwrap();

        assert_matches!(
            requests.as_mut().poll_next_response(&mut noop_context()),
            Poll::Pending
        );
    }

    // `pump_write` returns `Pending` when the transport is full, even with the read half
    // closed, and leaves the queued response in the buffer.
    #[tokio::test]
    async fn requests_pump_write_returns_pending_when_buffer_full() {
        let (mut requests, _tx) = test_bounded_requests::<(), ()>(0);

        // Fill the transport's buffer.
        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        requests
            .as_mut()
            .channel_pin_mut()
            .start_send(Response {
                request_id: 0,
                message: Ok(()),
            })
            .unwrap();

        // Track a second request and queue its response.
        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(Request {
                id: 1,
                context: context::current(),
                message: (),
            })
            .unwrap();
        requests
            .as_mut()
            .project()
            .responses_tx
            .send(Response {
                request_id: 1,
                message: Ok(()),
            })
            .await
            .unwrap();

        assert_matches!(
            requests.as_mut().pump_write(&mut noop_context(), true),
            Poll::Pending
        );
        // The response was not consumed.
        assert_matches!(
            requests.as_mut().pending_responses_mut().recv().await,
            Some(_)
        );
    }

    // `pump_read` yields the client's request and the channel tracks it.
    #[tokio::test]
    async fn requests_pump_read() {
        let (mut requests, mut tx) = test_requests::<(), ()>();

        tx.send(fake_request(())).await.unwrap();

        assert_matches!(
            requests.as_mut().pump_read(&mut noop_context()),
            Poll::Ready(Some(Ok(_)))
        );
        assert_eq!(requests.channel.in_flight_requests(), 1);
    }
}