use crate::cancellations::{cancellations, CanceledRequests, RequestCancellation};
use crate::context::{self, SpanExt};
use crate::util::TimeUntil;
use crate::{trace, ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport};
use ::tokio::sync::mpsc;
use futures::future::{AbortRegistration, Abortable};
use futures::prelude::*;
use futures::ready;
use futures::stream::Fuse;
use futures::task::*;
use in_flight_requests::{AlreadyExistsError, InFlightRequests};
use pin_project::pin_project;
use std::convert::TryFrom;
use std::error::Error;
use std::fmt;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::time::SystemTime;
use tracing::instrument::Instrument;
use tracing::{info_span, Span};
mod in_flight_requests;
pub mod request_hook;
#[cfg(test)]
mod testing;
pub mod limits;
pub mod incoming;
/// Settings that control how a [`BaseChannel`] buffers outbound traffic.
#[derive(Clone, Debug)]
pub struct Config {
    /// Capacity of the queue holding responses that request handlers have produced but that
    /// have not yet been written to the transport. Used as the bound of the `mpsc` channel
    /// created by [`Channel::requests`].
    pub pending_response_buffer: usize,
}

impl Default for Config {
    /// Returns a configuration whose pending-response buffer holds 100 responses.
    fn default() -> Self {
        let pending_response_buffer = 100;
        Self { pending_response_buffer }
    }
}
impl Config {
    /// Consumes this configuration and wraps `transport` in a [`BaseChannel`] that uses it.
    ///
    /// Equivalent to `BaseChannel::new(self, transport)`.
    pub fn channel<Req, Resp, T>(self, transport: T) -> BaseChannel<Req, Resp, T>
    where
        T: Transport<Response<Resp>, ClientMessage<Req>>,
    {
        let config = self;
        BaseChannel::<Req, Resp, T>::new(config, transport)
    }
}
/// An RPC handler that turns a request message and its context into a response.
///
/// Closures can be used as handlers by wrapping them with [`serve`].
// The `async_fn_in_trait` lint warns that callers cannot add a `Send` bound to the future an
// `async fn` in a public trait returns; it is allowed here deliberately.
#[allow(async_fn_in_trait)]
pub trait Serve {
    /// The request message type. It must provide a name, which is used to label the RPC span.
    type Req: RequestName;

    /// The response message type.
    type Resp;

    /// Handles a single request, consuming the handler.
    async fn serve(self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError>;
}
/// A [`Serve`] implementation backed by a closure; build one with [`serve`].
///
/// `Req` and `Resp` appear only inside a `PhantomData<fn(Req) -> Resp>`, so no values of those
/// types are stored, and the wrapper's auto traits depend only on `F`.
#[derive(Debug)]
pub struct ServeFn<Req, Resp, F> {
    f: F,
    data: PhantomData<fn(Req) -> Resp>,
}

// Written by hand so that cloning requires only `F: Clone`. `#[derive(Clone)]` would also demand
// `Req: Clone` and `Resp: Clone`.
impl<Req, Resp, F: Clone> Clone for ServeFn<Req, Resp, F> {
    fn clone(&self) -> Self {
        let f = F::clone(&self.f);
        ServeFn { f, data: PhantomData }
    }
}

// Likewise, copying requires only `F: Copy`.
impl<Req, Resp, F: Copy> Copy for ServeFn<Req, Resp, F> {}
/// Wraps the closure `f` so it can be used wherever a [`Serve`] is expected.
///
/// `f` receives the request context and message and returns a future that resolves to the
/// response (or a [`ServerError`]).
pub fn serve<Req, Resp, Fut, F>(f: F) -> ServeFn<Req, Resp, F>
where
    F: FnOnce(context::Context, Req) -> Fut,
    Fut: Future<Output = Result<Resp, ServerError>>,
{
    let data = PhantomData;
    ServeFn { data, f }
}
impl<Req, Resp, Fut, F> Serve for ServeFn<Req, Resp, F>
where
Req: RequestName,
F: FnOnce(context::Context, Req) -> Fut,
Fut: Future<Output = Result<Resp, ServerError>>,
{
type Req = Req;
type Resp = Resp;
async fn serve(self, ctx: context::Context, req: Req) -> Result<Resp, ServerError> {
(self.f)(ctx, req).await
}
}
/// Tracks the requests received over a transport of [`ClientMessage`]s and writes their
/// [`Response`]s back to it.
///
/// As a [`Stream`], it yields newly started requests. While polling, it also processes client
/// cancellations, handler-side cancellations and deadline expirations. As a [`Sink`], it writes
/// responses only for requests that are still tracked.
#[pin_project]
pub struct BaseChannel<Req, Resp, T> {
    /// Settings supplied at construction.
    config: Config,
    /// The underlying transport. It is wrapped in `Fuse`, which makes it safe to poll again
    /// after it has finished.
    #[pin]
    transport: Fuse<T>,
    /// Receives the ids of requests canceled through a [`ResponseGuard`]. It is created together
    /// with `request_cancellation` by `cancellations()`.
    #[pin]
    canceled_requests: CanceledRequests,
    /// Handle that is cloned into every [`ResponseGuard`] this channel creates.
    request_cancellation: RequestCancellation,
    /// Bookkeeping for requests that have started and not yet been removed (by a sent response,
    /// a cancellation, or, presumably, deadline expiry).
    in_flight_requests: InFlightRequests,
    /// Records the request and response types without storing any values of them.
    ghost: PhantomData<(fn() -> Req, fn(Resp))>,
}
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
{
pub fn new(config: Config, transport: T) -> Self {
let (request_cancellation, canceled_requests) = cancellations();
BaseChannel {
config,
transport: transport.fuse(),
canceled_requests,
request_cancellation,
in_flight_requests: InFlightRequests::default(),
ghost: PhantomData,
}
}
pub fn with_defaults(transport: T) -> Self {
Self::new(Config::default(), transport)
}
pub fn get_ref(&self) -> &T {
self.transport.get_ref()
}
pub fn get_pin_ref(self: Pin<&mut Self>) -> Pin<&mut T> {
self.project().transport.get_pin_mut()
}
fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests {
self.as_mut().project().in_flight_requests
}
fn canceled_requests_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut CanceledRequests> {
self.as_mut().project().canceled_requests
}
fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<T>> {
self.as_mut().project().transport
}
fn start_request(mut self: Pin<&mut Self>, mut request: Request<Req>) -> Result<TrackedRequest<Req>, AlreadyExistsError> {
let span = info_span!(
"RPC",
rpc.trace_id = %request.context.trace_id(),
rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + request.context.deadline.time_until()),
otel.kind = "server",
otel.name = tracing::field::Empty,
);
span.set_context(&request.context);
request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
tracing::trace!(
"OpenTelemetry subscriber not installed; making unsampled \
child context."
);
request.context.trace_context.new_child()
});
let entered = span.enter();
tracing::info!("ReceiveRequest");
let start = self.in_flight_requests_mut().start_request(request.id, request.context.deadline, span.clone());
match start {
Ok(abort_registration) => {
drop(entered);
Ok(TrackedRequest {
abort_registration,
span,
response_guard: ResponseGuard {
request_id: request.id,
request_cancellation: self.request_cancellation.clone(),
cancel: false,
},
request,
})
},
Err(AlreadyExistsError) => {
tracing::trace!("DuplicateRequest");
Err(AlreadyExistsError)
},
}
}
}
impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
    /// Prints only the type name, so neither the transport nor the request types need `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("BaseChannel")
    }
}
/// A request that a [`BaseChannel`] has started tracking, together with the handles needed to
/// process it.
#[derive(Debug)]
pub struct TrackedRequest<Req> {
    /// The request received from the client. Its trace context has been replaced by one
    /// derived from `span` (or by an unsampled child context).
    pub request: Request<Req>,
    /// Registration through which the channel can abort the request's handler, for example
    /// on cancellation or deadline expiry.
    pub abort_registration: AbortRegistration,
    /// The per-RPC tracing span.
    pub span: Span,
    /// Guard that cancels the request when dropped, once armed. It is created disarmed.
    pub response_guard: ResponseGuard,
}
pub trait Channel
where
Self: Transport<Response<<Self as Channel>::Resp>, TrackedRequest<<Self as Channel>::Req>>,
{
type Req;
type Resp;
type Transport;
fn config(&self) -> &Config;
fn in_flight_requests(&self) -> usize;
fn transport(&self) -> &Self::Transport;
fn max_concurrent_requests(self, limit: usize) -> limits::requests_per_channel::MaxRequests<Self>
where
Self: Sized,
{
limits::requests_per_channel::MaxRequests::new(self, limit)
}
fn requests(self) -> Requests<Self>
where
Self: Sized,
{
let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer);
Requests {
channel: self,
pending_responses: responses,
responses_tx,
}
}
fn execute<S>(self, serve: S) -> impl Stream<Item = impl Future<Output = ()>>
where
Self: Sized,
Self::Req: RequestName,
S: Serve<Req = Self::Req, Resp = Self::Resp> + Clone,
{
self.requests().execute(serve)
}
}
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
where
    T: Transport<Response<Resp>, ClientMessage<Req>>,
{
    /// A newly started request, or a transport read error.
    type Item = Result<TrackedRequest<Req>, ChannelError<T::Error>>;

    /// Returns the next newly started request.
    ///
    /// Each loop iteration polls three sources once, in order: handler-side cancellations,
    /// expired in-flight requests, and the transport. Cancellations and expirations are handled
    /// internally, and a new request is returned as soon as it is read. The loop repeats while
    /// any source made progress. It returns `Ready(None)` once every source is closed and
    /// `Pending` otherwise. Transport read errors are returned as [`ChannelError::Read`].
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        // Outcome of polling one of the channel's input sources.
        #[derive(Clone, Copy, Debug)]
        enum ReceiverStatus {
            // The source yielded an item, so it is worth polling everything again right away.
            Ready,
            // The source has nothing right now.
            Pending,
            // The source does not keep the channel open.
            Closed,
        }
        impl ReceiverStatus {
            // Merges two statuses. Any `Ready` wins. Otherwise the result is `Closed` only if
            // both are closed, and `Pending` if not.
            fn combine(self, other: Self) -> Self {
                use ReceiverStatus::*;
                match (self, other) {
                    (Ready, _) | (_, Ready) => Ready,
                    (Closed, Closed) => Closed,
                    (Pending, Closed) | (Closed, Pending) | (Pending, Pending) => Pending,
                }
            }
        }
        use ReceiverStatus::*;
        loop {
            // Handler-side cancellations (armed `ResponseGuard`s that were dropped): stop
            // tracking the request. `Pending` is reported as `Closed`, so outstanding
            // cancellations never keep the channel open on their own. `Ready(None)` presumably
            // cannot occur while `self.request_cancellation` is alive.
            let cancellation_status = match self.canceled_requests_pin_mut().poll_recv(cx) {
                Poll::Ready(Some(request_id)) => {
                    if let Some(span) = self.in_flight_requests_mut().remove_request(request_id) {
                        let _entered = span.enter();
                        tracing::info!("ResponseCancelled");
                    }
                    Ready
                },
                Poll::Pending | Poll::Ready(None) => Closed,
            };
            // An expired request needs no response, so there is nothing more to do with it here.
            // NOTE(review): `Ready(None)` appears to mean that no requests are in flight (compare
            // the closed-transport tests). Confirm against `InFlightRequests::poll_expired`.
            let expiration_status = match self.in_flight_requests_mut().poll_expired(cx) {
                Poll::Ready(Some(_)) => Ready,
                Poll::Ready(None) => Closed,
                Poll::Pending => Pending,
            };
            // A transport read error is returned immediately as `ChannelError::Read`.
            let request_status = match self.transport_pin_mut().poll_next(cx).map_err(|e| ChannelError::Read(Arc::new(e)))? {
                Poll::Ready(Some(message)) => match message {
                    ClientMessage::Request(request) => {
                        match self.as_mut().start_request(request) {
                            Ok(request) => return Poll::Ready(Some(Ok(request))),
                            Err(AlreadyExistsError) => {
                                // A duplicate request id is ignored rather than failing the
                                // channel. The loop continues instead of returning `Pending`,
                                // because the transport just returned `Ready` and may hold more
                                // messages without having scheduled a wakeup.
                                continue;
                            },
                        }
                    },
                    ClientMessage::Cancel { trace_context, request_id } => {
                        // A client-initiated cancellation. If there is nothing to cancel, the
                        // handler has already completed; that case is only traced.
                        if !self.in_flight_requests_mut().cancel_request(request_id) {
                            tracing::trace!(
                                rpc.trace_id = %trace_context.trace_id,
                                "Received cancellation, but response handler is already complete.",
                            );
                        }
                        Ready
                    },
                },
                Poll::Ready(None) => Closed,
                Poll::Pending => Pending,
            };
            let status = cancellation_status.combine(expiration_status).combine(request_status);
            tracing::trace!(
                "Cancellations: {cancellation_status:?}, \
                Expired requests: {expiration_status:?}, \
                Inbound: {request_status:?}, \
                Overall: {status:?}",
            );
            match status {
                // Some source made progress, so poll everything again.
                Ready => continue,
                // Every source is closed.
                Closed => return Poll::Ready(None),
                // Only the expiration and transport sources can be `Pending`, and both registered
                // a wakeup with `cx` when they returned `Poll::Pending`.
                Pending => return Poll::Pending,
            }
        }
    }
}
impl<Req, Resp, T> Sink<Response<Resp>> for BaseChannel<Req, Resp, T>
where
    T: Transport<Response<Resp>, ClientMessage<Req>>,
    T::Error: Error,
{
    type Error = ChannelError<T::Error>;

    /// Delegates to the transport and wraps any failure in [`ChannelError::Ready`].
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        let transport = self.project().transport;
        transport
            .poll_ready(cx)
            .map_err(|err| ChannelError::Ready(Arc::new(err)))
    }

    /// Writes `response` to the transport if its request is still being tracked.
    ///
    /// If the request is no longer tracked (for example because it was canceled), the
    /// response is dropped and `Ok(())` is returned.
    fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
        let Some(span) = self
            .in_flight_requests_mut()
            .remove_request(response.request_id)
        else {
            return Ok(());
        };
        let _entered = span.enter();
        tracing::info!("SendResponse");
        self.project()
            .transport
            .start_send(response)
            .map_err(|err| ChannelError::Write(Arc::new(err)))
    }

    /// Delegates to the transport and wraps any failure in [`ChannelError::Flush`].
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        tracing::trace!("poll_flush");
        let transport = self.project().transport;
        transport
            .poll_flush(cx)
            .map_err(|err| ChannelError::Flush(Arc::new(err)))
    }

    /// Delegates to the transport and wraps any failure in [`ChannelError::Close`].
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        let transport = self.project().transport;
        transport
            .poll_close(cx)
            .map_err(|err| ChannelError::Close(Arc::new(err)))
    }
}
impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
    /// Borrows the transport underlying this channel.
    fn as_ref(&self) -> &T {
        let fused = &self.transport;
        fused.get_ref()
    }
}
impl<Req, Resp, T> Channel for BaseChannel<Req, Resp, T>
where
    T: Transport<Response<Resp>, ClientMessage<Req>>,
{
    type Req = Req;
    type Resp = Resp;
    type Transport = T;

    /// The configuration passed to [`BaseChannel::new`].
    fn config(&self) -> &Config {
        &self.config
    }

    /// Number of entries in this channel's in-flight request table.
    fn in_flight_requests(&self) -> usize {
        let tracked = &self.in_flight_requests;
        tracked.len()
    }

    /// Borrows the transport underneath the `Fuse` wrapper.
    fn transport(&self) -> &Self::Transport {
        self.transport.get_ref()
    }
}
/// A stream of [`InFlightRequest`]s read from a [`Channel`]. While it is polled, it also writes
/// the responses those requests produce back to the channel.
#[pin_project]
pub struct Requests<C>
where
    C: Channel,
{
    /// The channel that requests are read from and responses are written to.
    #[pin]
    channel: C,
    /// Responses produced by request handlers that are waiting to be written to `channel`.
    pending_responses: mpsc::Receiver<Response<C::Resp>>,
    /// Sender cloned into each [`InFlightRequest`] so its handler can enqueue a response.
    responses_tx: mpsc::Sender<Response<C::Resp>>,
}
impl<C> Requests<C>
where
    C: Channel,
{
    /// Borrows the underlying channel.
    pub fn channel(&self) -> &C {
        &self.channel
    }

    /// Pin-projects to the underlying channel.
    pub fn channel_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> {
        self.as_mut().project().channel
    }

    /// Mutably borrows the receiver of responses waiting to be written.
    pub fn pending_responses_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut mpsc::Receiver<Response<C::Resp>> {
        self.as_mut().project().pending_responses
    }

    /// Polls the channel for its next request and turns it into an [`InFlightRequest`].
    ///
    /// Arms the request's response guard, so dropping the resulting `InFlightRequest` before
    /// it finishes executing cancels the request.
    fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<InFlightRequest<C::Req, C::Resp>, C::Error>>> {
        self.channel_pin_mut().poll_next(cx).map_ok(
            |TrackedRequest {
                 request,
                 abort_registration,
                 span,
                 mut response_guard,
             }| {
                // From here on, dropping the guard cancels the request.
                response_guard.cancel = true;
                {
                    // Enter the span only for this event.
                    let _entered = span.enter();
                    tracing::info!("BeginRequest");
                }
                InFlightRequest {
                    request,
                    abort_registration,
                    span,
                    response_guard,
                    response_tx: self.responses_tx.clone(),
                }
            },
        )
    }

    /// Writes at most one buffered response to the channel.
    ///
    /// Returns:
    /// - `Ready(Some(Ok(())))` after starting to send a response.
    /// - `Ready(None)`, once the channel has been flushed, if either the response queue is
    ///   closed, or `read_half_closed` is set and no requests remain in flight.
    /// - `Pending` otherwise.
    ///
    /// Channel errors are propagated.
    fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, read_half_closed: bool) -> Poll<Option<Result<(), C::Error>>> {
        match self.as_mut().poll_next_response(cx)? {
            Poll::Ready(Some(response)) => {
                // `poll_next_response` only yields a response after the channel's `poll_ready`
                // succeeded, so this send should not hit a full buffer.
                self.channel_pin_mut().start_send(response)?;
                Poll::Ready(Some(Ok(())))
            },
            Poll::Ready(None) => {
                // Finish flushing what was already written before reporting the end.
                ready!(self.channel_pin_mut().poll_flush(cx)?);
                Poll::Ready(None)
            },
            Poll::Pending => {
                // No response is ready, so flush anything buffered in the transport.
                ready!(self.channel_pin_mut().poll_flush(cx)?);
                // Everything written so far is flushed. If no more requests will be read and
                // none are in flight, there is nothing left to write.
                if read_half_closed && self.channel.in_flight_requests() == 0 {
                    Poll::Ready(None)
                } else {
                    Poll::Pending
                }
            },
        }
    }

    /// Waits until the channel can accept a write, then receives the next buffered response.
    ///
    /// Returns `Ready(None)` if the response queue is closed.
    fn poll_next_response(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Response<C::Resp>, C::Error>>> {
        ready!(self.ensure_writeable(cx)?);
        match ready!(self.pending_responses_mut().poll_recv(cx)) {
            Some(response) => Poll::Ready(Some(Ok(response))),
            None => {
                // Not expected while `self.responses_tx` is alive: tokio's receiver yields `None`
                // only once every sender is dropped or the receiver is closed.
                Poll::Ready(None)
            },
        }
    }

    /// Returns `Ready` once the channel reports that it can accept a write. While the channel is
    /// not ready, it is flushed; `Pending` is returned if that flush has not completed.
    fn ensure_writeable<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<(), C::Error>>> {
        while self.channel_pin_mut().poll_ready(cx)?.is_pending() {
            ready!(self.channel_pin_mut().poll_flush(cx)?);
        }
        Poll::Ready(Some(Ok(())))
    }

    /// Turns each incoming request into a future that handles it with a clone of `serve`.
    ///
    /// The stream ends at the first error from this `Requests` stream. That error is logged at
    /// `warn` level and not yielded.
    pub fn execute<S>(self, serve: S) -> impl Stream<Item = impl Future<Output = ()>>
    where
        C::Req: RequestName,
        S: Serve<Req = C::Req, Resp = C::Resp> + Clone,
    {
        self.take_while(|result| {
            if let Err(e) = result {
                tracing::warn!("Requests stream errored out: {}", e);
            }
            futures::future::ready(result.is_ok())
        })
        // After `take_while`, every item is `Ok`; this only unwraps them.
        .filter_map(|result| async move { result.ok() })
        .map(move |request| {
            let serve = serve.clone();
            request.execute(serve)
        })
    }
}
impl<C> fmt::Debug for Requests<C>
where
    C: Channel,
{
    /// Prints only the type name, so the channel does not need to implement `Debug`.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.write_str("Requests")
    }
}
/// Cancels its request when dropped, but only while armed (`cancel == true`).
#[derive(Debug)]
pub struct ResponseGuard {
    /// Handle used to report the cancellation to the channel that created this guard.
    request_cancellation: RequestCancellation,
    /// Id of the request this guard belongs to.
    request_id: u64,
    /// Whether dropping the guard should cancel the request.
    cancel: bool,
}

impl Drop for ResponseGuard {
    fn drop(&mut self) {
        // A disarmed guard does nothing.
        if !self.cancel {
            return;
        }
        self.request_cancellation.cancel(self.request_id);
    }
}
/// A request read from a [`Requests`] stream that is ready to be handled.
///
/// Handle it with [`InFlightRequest::execute`]. If it is dropped before `execute` completes, its
/// (armed) response guard cancels the request.
#[derive(Debug)]
pub struct InFlightRequest<Req, Res> {
    /// The request being handled.
    request: Request<Req>,
    /// Lets the channel abort the handler.
    abort_registration: AbortRegistration,
    /// Cancels the request if dropped while armed.
    response_guard: ResponseGuard,
    /// The per-RPC tracing span that the handler runs in.
    span: Span,
    /// Queue into which the handler's response is buffered for the channel to send.
    response_tx: mpsc::Sender<Response<Res>>,
}
impl<Req, Res> InFlightRequest<Req, Res> {
pub fn get(&self) -> &Request<Req> {
&self.request
}
pub async fn execute<S>(self, serve: S)
where
Req: RequestName,
S: Serve<Req = Req, Resp = Res>,
{
let Self {
response_tx,
mut response_guard,
abort_registration,
span,
request: Request { context, message, id: request_id },
} = self;
span.record("otel.name", message.name());
let _ = Abortable::new(
async move {
let message = serve.serve(context, message).await;
tracing::info!("CompleteRequest");
let response = Response { request_id, message };
let _ = response_tx.send(response).await;
tracing::info!("BufferResponse");
},
abort_registration,
)
.instrument(span)
.await;
response_guard.cancel = false;
}
}
/// Renders `e` followed by each error in its `source()` chain, separated by `": "`.
fn print_err(e: &(dyn Error + 'static)) -> String {
    let mut rendered = String::new();
    for (depth, cause) in anyhow::Chain::new(e).enumerate() {
        if depth > 0 {
            rendered.push_str(": ");
        }
        rendered.push_str(&cause.to_string());
    }
    rendered
}
impl<C> Stream for Requests<C>
where
    C: Channel,
{
    /// A request ready to be handled, or a channel error.
    type Item = Result<InFlightRequest<C::Req, C::Resp>, C::Error>;

    /// Yields the next request to handle, writing buffered responses along the way.
    ///
    /// Each iteration pumps the read side once, then the write side once, telling the write side
    /// whether the read side is exhausted. The outcome of an iteration is:
    /// - a new request is returned immediately;
    /// - after a response is written, the loop runs again;
    /// - once both sides are done, the stream ends;
    /// - otherwise, `Pending` is returned.
    ///
    /// Errors from either side are traced and returned.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            let read = self.as_mut().pump_read(cx).map_err(|e| {
                tracing::trace!("read: {}", print_err(&e));
                e
            })?;
            // The write side may finish only after the read side has finished.
            let read_closed = matches!(read, Poll::Ready(None));
            let write = self.as_mut().pump_write(cx, read_closed).map_err(|e| {
                tracing::trace!("write: {}", print_err(&e));
                e
            })?;
            match (read, write) {
                // Both sides are done.
                (Poll::Ready(None), Poll::Ready(None)) => {
                    tracing::trace!("read: Poll::Ready(None), write: Poll::Ready(None)");
                    return Poll::Ready(None);
                },
                // A new request takes priority. The write side was still pumped once above.
                (Poll::Ready(Some(request_handler)), _) => {
                    tracing::trace!("read: Poll::Ready(Some), write: _");
                    return Poll::Ready(Some(Ok(request_handler)));
                },
                // A response was written, so try to make more progress.
                (_, Poll::Ready(Some(()))) => {
                    tracing::trace!("read: _, write: Poll::Ready(Some)");
                },
                // At least one side is pending and neither produced anything.
                (read @ Poll::Pending, write) | (read, write @ Poll::Pending) => {
                    tracing::trace!("read pending: {}, write pending: {}", read.is_pending(), write.is_pending());
                    return Poll::Pending;
                },
            }
        }
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `Serve` and request hooks, `BaseChannel`, and `Requests`. Most tests drive
    // the streams by hand with a no-op waker (`noop_context`). Tests that depend on deadlines
    // use tokio's paused clock.
    use super::in_flight_requests::AlreadyExistsError;
    use super::request_hook::{AfterRequest, BeforeRequest, RequestHook};
    use super::{serve, BaseChannel, Channel, Config, Requests, Serve};
    use crate::transport::channel::{self, UnboundedChannel};
    use crate::{context, trace, ClientMessage, Request, Response, ServerError};
    use assert_matches::assert_matches;
    use futures::future::{pending, AbortRegistration, Abortable, Aborted};
    use futures::prelude::*;
    use futures::Future;
    use futures_test::task::noop_context;
    use std::io;
    use std::pin::Pin;
    use std::task::Poll;
    use std::time::{Duration, Instant};

    /// Creates a pinned `BaseChannel` over an unbounded in-memory transport, together with the
    /// client end of that transport.
    fn test_channel<Req, Resp>() -> (
        Pin<Box<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>,
        UnboundedChannel<Response<Resp>, ClientMessage<Req>>,
    ) {
        let (tx, rx) = crate::transport::channel::unbounded();
        (Box::pin(BaseChannel::new(Config::default(), rx)), tx)
    }

    /// Like `test_channel`, but wraps the channel in a `Requests` stream.
    fn test_requests<Req, Resp>() -> (
        Pin<Box<Requests<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>>,
        UnboundedChannel<Response<Resp>, ClientMessage<Req>>,
    ) {
        let (tx, rx) = crate::transport::channel::unbounded();
        (Box::pin(BaseChannel::new(Config::default(), rx).requests()), tx)
    }

    /// Creates a `Requests` stream over a bounded in-memory transport with the given `capacity`.
    ///
    /// The pending-response buffer is `capacity + 1`. This also keeps it non-zero, which tokio's
    /// bounded mpsc channel requires.
    fn test_bounded_requests<Req, Resp>(
        capacity: usize,
    ) -> (
        Pin<Box<Requests<BaseChannel<Req, Resp, channel::Channel<ClientMessage<Req>, Response<Resp>>>>>>,
        channel::Channel<Response<Resp>, ClientMessage<Req>>,
    ) {
        let (tx, rx) = crate::transport::channel::bounded(capacity);
        let config = Config {
            pending_response_buffer: capacity + 1,
        };
        (Box::pin(BaseChannel::new(config, rx).requests()), tx)
    }

    /// Wraps `req` in a client request message with id 0 and the current context.
    fn fake_request<Req>(req: Req) -> ClientMessage<Req> {
        ClientMessage::Request(Request {
            context: context::rpc_current(),
            id: 0,
            message: req,
        })
    }

    /// Returns a future that never completes on its own. It resolves to `Err(Aborted)` only
    /// when `abort_registration` is aborted.
    fn test_abortable(abort_registration: AbortRegistration) -> impl Future<Output = Result<(), Aborted>> {
        Abortable::new(pending(), abort_registration)
    }

    // A closure wrapped by `serve` handles a request.
    #[tokio::test]
    async fn test_serve() {
        let serve = serve(|_, i| async move { Ok(i) });
        assert_matches!(serve.serve(context::rpc_current(), 7).await, Ok(7));
    }

    // A `before` hook's changes to the context are visible to the handler.
    #[tokio::test]
    async fn serve_before_mutates_context() -> anyhow::Result<()> {
        struct SetDeadline(Instant);
        impl<Req> BeforeRequest<Req> for SetDeadline {
            async fn before(&mut self, ctx: &mut context::Context, _: &Req) -> Result<(), ServerError> {
                ctx.deadline = self.0;
                Ok(())
            }
        }
        let some_time = Instant::now() + Duration::from_secs(37);
        let some_other_time = Instant::now() + Duration::from_secs(83);
        let serve = serve(move |ctx: context::Context, i| async move {
            assert_eq!(ctx.deadline, some_time);
            Ok(i)
        });
        let deadline_hook = serve.before(SetDeadline(some_time));
        let mut ctx = context::rpc_current();
        ctx.deadline = some_other_time;
        deadline_hook.serve(ctx, 7).await?;
        Ok(())
    }

    // A hook implementing both `BeforeRequest` and `AfterRequest` wraps the handler.
    #[tokio::test]
    async fn serve_before_and_after() -> anyhow::Result<()> {
        let _ = tracing_subscriber::fmt::try_init();
        struct PrintLatency {
            start: Instant,
        }
        impl PrintLatency {
            fn new() -> Self {
                Self { start: Instant::now() }
            }
        }
        impl<Req> BeforeRequest<Req> for PrintLatency {
            async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> {
                self.start = Instant::now();
                Ok(())
            }
        }
        impl<Resp> AfterRequest<Resp> for PrintLatency {
            async fn after(&mut self, _: &mut context::Context, _: &mut Result<Resp, ServerError>) {
                tracing::info!("Elapsed: {:?}", self.start.elapsed());
            }
        }
        let serve = serve(move |_: context::Context, i| async move { Ok(i) });
        serve.before_and_after(PrintLatency::new()).serve(context::rpc_current(), 7).await?;
        Ok(())
    }

    // An error from a `before` hook is returned without running the handler.
    #[tokio::test]
    async fn serve_before_error_aborts_request() -> anyhow::Result<()> {
        let serve = serve(|_, _| async { panic!("Shouldn't get here") });
        let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) });
        let resp: Result<i32, _> = deadline_hook.serve(context::rpc_current(), 7).await;
        assert_matches!(resp, Err(_));
        Ok(())
    }

    // Starting a second request with an id that is already in flight is rejected.
    #[tokio::test]
    async fn base_channel_start_send_duplicate_request_returns_error() {
        let (mut channel, _tx) = test_channel::<(), ()>();
        channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::rpc_current(),
                message: (),
            })
            .unwrap();
        assert_matches!(
            channel.as_mut().start_request(Request {
                id: 0,
                context: context::rpc_current(),
                message: ()
            }),
            Err(AlreadyExistsError)
        );
    }

    // Advancing the paused clock by 1000s (assumed to be past the default deadline of
    // `context::rpc_current`) makes `poll_next` expire and abort both requests.
    #[tokio::test]
    async fn base_channel_poll_next_aborts_multiple_requests() {
        let (mut channel, _tx) = test_channel::<(), ()>();
        tokio::time::pause();
        let req0 = channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::rpc_current(),
                message: (),
            })
            .unwrap();
        let req1 = channel
            .as_mut()
            .start_request(Request {
                id: 1,
                context: context::rpc_current(),
                message: (),
            })
            .unwrap();
        tokio::time::advance(std::time::Duration::from_secs(1000)).await;
        assert_matches!(channel.as_mut().poll_next(&mut noop_context()), Poll::Pending);
        assert_matches!(test_abortable(req0.abort_registration).await, Err(Aborted));
        assert_matches!(test_abortable(req1.abort_registration).await, Err(Aborted));
    }

    // A client `Cancel` message aborts the matching in-flight request.
    #[tokio::test]
    async fn base_channel_poll_next_aborts_canceled_request() {
        let (mut channel, mut tx) = test_channel::<(), ()>();
        tokio::time::pause();
        let req = channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::rpc_current(),
                message: (),
            })
            .unwrap();
        tx.send(ClientMessage::Cancel {
            trace_context: trace::Context::default(),
            request_id: 0,
        })
        .await
        .unwrap();
        assert_matches!(channel.as_mut().poll_next(&mut noop_context()), Poll::Pending);
        assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted));
    }

    // A closed transport does not end the stream while a request is still in flight.
    #[tokio::test]
    async fn base_channel_with_closed_transport_and_in_flight_request_returns_pending() {
        let (mut channel, tx) = test_channel::<(), ()>();
        tokio::time::pause();
        let _abort_registration = channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::rpc_current(),
                message: (),
            })
            .unwrap();
        drop(tx);
        assert_matches!(channel.as_mut().poll_next(&mut noop_context()), Poll::Pending);
    }

    // A closed transport with nothing in flight ends the stream.
    #[tokio::test]
    async fn base_channel_with_closed_transport_and_no_in_flight_requests_returns_closed() {
        let (mut channel, tx) = test_channel::<(), ()>();
        drop(tx);
        assert_matches!(channel.as_mut().poll_next(&mut noop_context()), Poll::Ready(None));
    }

    // A request sent by the client is yielded by `poll_next`.
    #[tokio::test]
    async fn base_channel_poll_next_yields_request() {
        let (mut channel, mut tx) = test_channel::<(), ()>();
        tx.send(fake_request(())).await.unwrap();
        assert_matches!(channel.as_mut().poll_next(&mut noop_context()), Poll::Ready(Some(Ok(_))));
    }

    // In a single poll, the expired request is aborted and the new one is yielded. The new
    // request reuses id 0 and is accepted, presumably because expirations are processed before
    // the transport is read.
    #[tokio::test]
    async fn base_channel_poll_next_aborts_request_and_yields_request() {
        let (mut channel, mut tx) = test_channel::<(), ()>();
        tokio::time::pause();
        let req = channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::rpc_current(),
                message: (),
            })
            .unwrap();
        tokio::time::advance(std::time::Duration::from_secs(1000)).await;
        tx.send(fake_request(())).await.unwrap();
        assert_matches!(channel.as_mut().poll_next(&mut noop_context()), Poll::Ready(Some(Ok(_))));
        assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted));
    }

    // Sending a response stops tracking its request.
    #[tokio::test]
    async fn base_channel_start_send_removes_in_flight_request() {
        let (mut channel, _tx) = test_channel::<(), ()>();
        channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::rpc_current(),
                message: (),
            })
            .unwrap();
        assert_eq!(channel.in_flight_requests(), 1);
        channel.as_mut().start_send(Response { request_id: 0, message: Ok(()) }).unwrap();
        assert_eq!(channel.in_flight_requests(), 0);
    }

    // Dropping an un-executed `InFlightRequest` cancels it. The next channel poll processes the
    // cancellation and stops tracking the request.
    #[tokio::test]
    async fn in_flight_request_drop_cancels_request() {
        let (mut requests, mut tx) = test_requests::<(), ()>();
        tx.send(fake_request(())).await.unwrap();
        let request = match requests.as_mut().poll_next(&mut noop_context()) {
            Poll::Ready(Some(Ok(request))) => request,
            result => panic!("Unexpected result: {:?}", result),
        };
        drop(request);
        let poll = requests.as_mut().channel_pin_mut().poll_next(&mut noop_context());
        assert!(poll.is_pending());
        let in_flight_requests = requests.channel().in_flight_requests();
        assert_eq!(in_flight_requests, 0);
    }

    // A completed `execute` disarms the guard, so no cancellation is queued.
    #[tokio::test]
    async fn in_flight_requests_successful_execute_doesnt_cancel_request() {
        let (mut requests, mut tx) = test_requests::<(), ()>();
        tx.send(fake_request(())).await.unwrap();
        let request = match requests.as_mut().poll_next(&mut noop_context()) {
            Poll::Ready(Some(Ok(request))) => request,
            result => panic!("Unexpected result: {:?}", result),
        };
        request.execute(serve(|_, _| async { Ok(()) })).await;
        assert!(requests.as_mut().channel_pin_mut().canceled_requests.poll_recv(&mut noop_context()).is_pending());
    }

    // Response 0 fills the capacity-0 transport. The client never reads, so presumably the
    // transport cannot accept response 1, and `poll_next_response` stays pending.
    #[tokio::test]
    async fn requests_poll_next_response_returns_pending_when_buffer_full() {
        let (mut requests, _tx) = test_bounded_requests::<(), ()>(0);
        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(Request {
                id: 0,
                context: context::rpc_current(),
                message: (),
            })
            .unwrap();
        requests.as_mut().channel_pin_mut().start_send(Response { request_id: 0, message: Ok(()) }).unwrap();
        requests.as_mut().project().responses_tx.send(Response { request_id: 1, message: Ok(()) }).await.unwrap();
        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(Request {
                id: 1,
                context: context::rpc_current(),
                message: (),
            })
            .unwrap();
        assert_matches!(requests.as_mut().poll_next_response(&mut noop_context()), Poll::Pending);
    }

    // Same setup as above, but through `pump_write` with the read half closed. The write side
    // stays pending, and the queued response is still in the buffer afterwards.
    #[tokio::test]
    async fn requests_pump_write_returns_pending_when_buffer_full() {
        let (mut requests, _tx) = test_bounded_requests::<(), ()>(0);
        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(Request {
                id: 0,
                context: context::rpc_current(),
                message: (),
            })
            .unwrap();
        requests.as_mut().channel_pin_mut().start_send(Response { request_id: 0, message: Ok(()) }).unwrap();
        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(Request {
                id: 1,
                context: context::rpc_current(),
                message: (),
            })
            .unwrap();
        requests.as_mut().project().responses_tx.send(Response { request_id: 1, message: Ok(()) }).await.unwrap();
        assert_matches!(requests.as_mut().pump_write(&mut noop_context(), true), Poll::Pending);
        assert_matches!(requests.as_mut().pending_responses_mut().recv().await, Some(_));
    }

    // `pump_read` yields the incoming request, and the channel tracks it.
    #[tokio::test]
    async fn requests_pump_read() {
        let (mut requests, mut tx) = test_requests::<(), ()>();
        tx.send(fake_request(())).await.unwrap();
        assert_matches!(requests.as_mut().pump_read(&mut noop_context()), Poll::Ready(Some(Ok(_))));
        assert_eq!(requests.channel.in_flight_requests(), 1);
    }
}