use crate::cx::Cx;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
/// A poll-based, tower-style service: report readiness, then accept a request.
///
/// Back-pressure is signalled through [`Service::poll_ready`]; the helpers in this
/// module (e.g. `Oneshot`) wait for `Poll::Ready(Ok(()))` before invoking
/// [`Service::call`].
pub trait Service<Request> {
    /// Value produced when a request succeeds.
    type Response;

    /// Value produced when readiness or the request itself fails.
    type Error;

    /// Future returned by [`Service::call`], resolving to the response or error.
    type Future: Future<Output = Result<Self::Response, Self::Error>>;

    /// Polls whether the service can accept a request right now.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>;

    /// Dispatches `req` and returns a future for its outcome.
    fn call(&mut self, req: Request) -> Self::Future;
}
/// Convenience combinators available on every [`Service`].
pub trait ServiceExt<Request>: Service<Request> {
    /// Returns a future that resolves as soon as `poll_ready` yields a result
    /// (readiness or an error), borrowing the service mutably meanwhile.
    fn ready(&mut self) -> Ready<'_, Self, Request>
    where
        Self: Sized,
    {
        Ready::new(self)
    }

    /// Consumes the service and returns a future that waits for readiness,
    /// dispatches `req` exactly once, and resolves with the response.
    fn oneshot(self, req: Request) -> Oneshot<Self, Request>
    where
        Self: Sized + Unpin,
        Request: Unpin,
        Self::Future: Unpin,
    {
        Oneshot::new(self, req)
    }
}

// Every `Service`, sized or not, picks up the extension methods automatically.
impl<Request, T: Service<Request> + ?Sized> ServiceExt<Request> for T {}
/// A context-aware async service: every call receives the caller's [`Cx`].
///
/// `call` takes `&self` and the trait requires `Send + Sync`, so a single
/// instance can be shared between callers.
// `async fn` in a public trait prevents callers from naming or bounding the
// returned future (e.g. requiring `Send`); that trade-off is accepted here.
#[allow(async_fn_in_trait)]
pub trait AsupersyncService<Request>: Send + Sync {
    /// Value produced on success.
    type Response;

    /// Value produced on failure.
    type Error;

    /// Handles `request` under the context `cx`.
    async fn call(&self, cx: &Cx, request: Request) -> Result<Self::Response, Self::Error>;
}
/// Combinators available on every [`AsupersyncService`].
pub trait AsupersyncServiceExt<Request>: AsupersyncService<Request> {
    /// Wraps the service so that every successful response is passed through `f`.
    fn map_response<F, NewResponse>(self, f: F) -> MapResponse<Self, F>
    where
        Self: Sized,
        F: Fn(Self::Response) -> NewResponse + Send + Sync,
    {
        MapResponse::new(self, f)
    }

    /// Wraps the service so that every error is passed through `f`.
    fn map_err<F, NewError>(self, f: F) -> MapErr<Self, F>
    where
        Self: Sized,
        F: Fn(Self::Error) -> NewError + Send + Sync,
    {
        MapErr::new(self, f)
    }

    /// Exposes the service as a `tower::Service` whose requests are `(Cx, Request)` tuples.
    #[cfg(feature = "tower")]
    fn into_tower(self) -> TowerAdapter<Self>
    where
        Self: Sized,
    {
        TowerAdapter::new(self)
    }

    /// Exposes the service as a `tower::Service` taking bare requests; the `Cx`
    /// is obtained from a `ThreadLocalCxProvider` on each call.
    #[cfg(feature = "tower")]
    fn into_tower_with_provider(self) -> TowerAdapterWithProvider<Self, ThreadLocalCxProvider>
    where
        Self: Sized,
    {
        TowerAdapterWithProvider::new(self)
    }
}

// Blanket impl: every `AsupersyncService` gets the combinators for free.
impl<Request, T: AsupersyncService<Request> + ?Sized> AsupersyncServiceExt<Request> for T {}
/// Service returned by `AsupersyncServiceExt::map_response`: it forwards to the
/// inner service and transforms successful responses.
pub struct MapResponse<S, F> {
    service: S,
    map: F,
}

impl<S, F> MapResponse<S, F> {
    /// Pairs the inner `service` with the response mapper `map`.
    fn new(service: S, map: F) -> Self {
        MapResponse { service, map }
    }
}

impl<S, F, Request, NewResponse> AsupersyncService<Request> for MapResponse<S, F>
where
    S: AsupersyncService<Request>,
    F: Fn(S::Response) -> NewResponse + Send + Sync,
{
    type Response = NewResponse;
    type Error = S::Error;

    /// Errors pass through unchanged; successes are run through `map`.
    async fn call(&self, cx: &Cx, request: Request) -> Result<Self::Response, Self::Error> {
        let inner = self.service.call(cx, request);
        inner.await.map(|response| (self.map)(response))
    }
}
/// Service returned by `AsupersyncServiceExt::map_err`: it forwards to the inner
/// service and transforms its errors.
pub struct MapErr<S, F> {
    service: S,
    map: F,
}

impl<S, F> MapErr<S, F> {
    /// Pairs the inner `service` with the error mapper `map`.
    fn new(service: S, map: F) -> Self {
        MapErr { service, map }
    }
}

impl<S, F, Request, NewError> AsupersyncService<Request> for MapErr<S, F>
where
    S: AsupersyncService<Request>,
    F: Fn(S::Error) -> NewError + Send + Sync,
{
    type Response = S::Response;
    type Error = NewError;

    /// Successes pass through unchanged; errors are run through `map`.
    async fn call(&self, cx: &Cx, request: Request) -> Result<Self::Response, Self::Error> {
        match self.service.call(cx, request).await {
            Ok(response) => Ok(response),
            Err(error) => Err((self.map)(error)),
        }
    }
}
/// Any `Fn(&Cx, Request) -> Fut` whose future yields a `Result` is itself a service.
impl<F, Fut, Request, Response, Error> AsupersyncService<Request> for F
where
    F: Fn(&Cx, Request) -> Fut + Send + Sync,
    Fut: Future<Output = Result<Response, Error>> + Send,
{
    type Response = Response;
    type Error = Error;

    /// Invokes the function and awaits the future it returns.
    async fn call(&self, cx: &Cx, request: Request) -> Result<Self::Response, Self::Error> {
        let future = self(cx, request);
        future.await
    }
}
/// Supplies the `Cx` for tower adapters that receive requests without one.
#[cfg(feature = "tower")]
pub trait CxProvider: Send + Sync {
    /// Returns the context for the next call, or `None` when none is available.
    fn current_cx(&self) -> Option<Cx>;
}

/// [`CxProvider`] that simply returns whatever `Cx::current()` yields.
#[cfg(feature = "tower")]
#[derive(Clone, Copy, Debug, Default)]
pub struct ThreadLocalCxProvider;

#[cfg(feature = "tower")]
impl CxProvider for ThreadLocalCxProvider {
    fn current_cx(&self) -> Option<Cx> {
        Cx::current()
    }
}
/// [`CxProvider`] that hands out clones of one pre-selected `Cx`.
#[cfg(feature = "tower")]
#[derive(Clone, Debug)]
pub struct FixedCxProvider {
    cx: Cx,
}

#[cfg(feature = "tower")]
impl FixedCxProvider {
    /// Creates a provider that always returns clones of `cx`.
    #[must_use]
    pub fn new(cx: Cx) -> Self {
        FixedCxProvider { cx }
    }

    /// Creates a provider around `Cx::for_testing()`.
    #[must_use]
    pub fn for_testing() -> Self {
        let cx = Cx::for_testing();
        FixedCxProvider { cx }
    }
}

#[cfg(feature = "tower")]
impl CxProvider for FixedCxProvider {
    fn current_cx(&self) -> Option<Cx> {
        let cx = Cx::clone(&self.cx);
        Some(cx)
    }
}
/// How the tower-backed `AsupersyncAdapter` reacts to cancellation of the caller's `Cx`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CancellationMode {
    /// Checks for cancellation before dispatching, then awaits the tower future to completion.
    #[default]
    BestEffort,
    /// Like `BestEffort`, but reports `CancellationIgnored` when the `Cx` was
    /// cancelled while the tower future was running.
    Strict,
    /// Bounds both readiness and the call by `fallback_timeout`, when a timeout
    /// is configured and the `Cx` has a timer driver.
    TimeoutFallback,
}

/// Tuning knobs for the tower-backed `AsupersyncAdapter`.
#[derive(Clone, Debug)]
pub struct AdapterConfig {
    /// Cancellation strategy; defaults to [`CancellationMode::BestEffort`].
    pub cancellation_mode: CancellationMode,
    /// Deadline applied only in [`CancellationMode::TimeoutFallback`]; `None` disables it.
    pub fallback_timeout: Option<std::time::Duration>,
    /// Calls are rejected as overloaded when the `Cx` budget's `poll_quota` is below this.
    pub min_budget_for_wait: u64,
}

impl Default for AdapterConfig {
    fn default() -> Self {
        const DEFAULT_MIN_BUDGET_FOR_WAIT: u64 = 10;
        Self {
            cancellation_mode: CancellationMode::default(),
            fallback_timeout: None,
            min_budget_for_wait: DEFAULT_MIN_BUDGET_FOR_WAIT,
        }
    }
}

impl AdapterConfig {
    /// Same as [`AdapterConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Replaces the cancellation strategy.
    #[must_use]
    pub fn cancellation_mode(self, mode: CancellationMode) -> Self {
        Self {
            cancellation_mode: mode,
            ..self
        }
    }

    /// Sets the fallback timeout used by [`CancellationMode::TimeoutFallback`].
    #[must_use]
    pub fn fallback_timeout(self, timeout: std::time::Duration) -> Self {
        Self {
            fallback_timeout: Some(timeout),
            ..self
        }
    }

    /// Sets the minimum `poll_quota` a caller must have for the call to proceed.
    #[must_use]
    pub fn min_budget_for_wait(self, budget: u64) -> Self {
        Self {
            min_budget_for_wait: budget,
            ..self
        }
    }
}
/// Converts errors between a tower-side error type and an asupersync-side error type.
pub trait ErrorAdapter: Send + Sync {
    /// Error type used on the tower side.
    type TowerError;
    /// Error type used on the asupersync side.
    type AsupersyncError;

    /// Converts a tower error into its asupersync counterpart.
    fn to_asupersync(&self, err: Self::TowerError) -> Self::AsupersyncError;

    /// Converts an asupersync error into its tower counterpart.
    fn to_tower(&self, err: Self::AsupersyncError) -> Self::TowerError;
}

/// Identity [`ErrorAdapter`]: both sides use the same error type `E`.
///
/// The trait impls below are written by hand rather than derived. `derive`
/// would add `E: Debug/Clone/Copy/Default` bounds even though only a
/// `PhantomData<E>` is stored. That would, for example, make
/// `DefaultErrorAdapter<String>` non-`Copy` and `DefaultErrorAdapter<io::Error>`
/// non-`Clone`/non-`Default`.
pub struct DefaultErrorAdapter<E> {
    _marker: PhantomData<E>,
}

impl<E> DefaultErrorAdapter<E> {
    /// Creates the identity adapter.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<E> std::fmt::Debug for DefaultErrorAdapter<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the former derived output.
        f.debug_struct("DefaultErrorAdapter")
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<E> Clone for DefaultErrorAdapter<E> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<E> Copy for DefaultErrorAdapter<E> {}

impl<E> Default for DefaultErrorAdapter<E> {
    fn default() -> Self {
        Self::new()
    }
}

// Only `Send + Sync` is required (for the trait's supertraits, since the marker
// holds `PhantomData<E>`); the identity conversion never clones, so no `E: Clone`.
impl<E> ErrorAdapter for DefaultErrorAdapter<E>
where
    E: Send + Sync,
{
    type TowerError = E;
    type AsupersyncError = E;

    fn to_asupersync(&self, err: Self::TowerError) -> Self::AsupersyncError {
        err
    }

    fn to_tower(&self, err: Self::AsupersyncError) -> Self::TowerError {
        err
    }
}
/// Errors produced when a tower service is driven through the asupersync adapter.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TowerAdapterError<E> {
    /// The wrapped tower service reported an error.
    Service(E),
    /// The caller's context was cancelled before the request was dispatched.
    Cancelled,
    /// The fallback timeout elapsed.
    Timeout,
    /// The caller's budget was below the configured minimum.
    Overloaded,
    /// Strict mode: the context was cancelled while the tower future kept running.
    CancellationIgnored,
}

impl<E: std::fmt::Display> std::fmt::Display for TowerAdapterError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::Service(inner) => return write!(f, "service error: {inner}"),
            Self::Cancelled => "operation cancelled",
            Self::Timeout => "operation timed out",
            Self::Overloaded => "service overloaded, insufficient budget",
            Self::CancellationIgnored => "service ignored cancellation request (strict mode)",
        };
        f.write_str(text)
    }
}

impl<E> std::error::Error for TowerAdapterError<E> where E: std::fmt::Debug + std::fmt::Display {}
/// Tower-facing wrapper around an [`AsupersyncService`]; requests are `(Cx, Request)` tuples.
///
/// The service is held in an `Arc` so each call's future can own a handle to it.
#[cfg(feature = "tower")]
pub struct TowerAdapter<S> {
    service: std::sync::Arc<S>,
}

#[cfg(feature = "tower")]
impl<S> TowerAdapter<S> {
    /// Wraps `service`; reached through `AsupersyncServiceExt::into_tower`.
    fn new(service: S) -> Self {
        Self {
            service: std::sync::Arc::new(service),
        }
    }
}

// Cloning shares the same underlying service (no `S: Clone` needed). This is
// consistent with `TowerAdapterWithProvider` and lets the adapter be used where
// tower/hyper-style stacks require `Clone` services.
#[cfg(feature = "tower")]
impl<S> Clone for TowerAdapter<S> {
    fn clone(&self) -> Self {
        Self {
            service: std::sync::Arc::clone(&self.service),
        }
    }
}
/// Tower interface where the caller supplies the `Cx` alongside each request.
#[cfg(feature = "tower")]
impl<S, Request> tower::Service<(Cx, Request)> for TowerAdapter<S>
where
    S: AsupersyncService<Request> + Send + Sync + 'static,
    Request: Send + 'static,
    S::Response: Send + 'static,
    S::Error: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    // Not `Send`: presumably because `AsupersyncService::call`'s future carries no `Send` bound.
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

    /// Always ready: the wrapped service exposes no readiness hook.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: (Cx, Request)) -> Self::Future {
        let (cx, request) = req;
        let service = std::sync::Arc::clone(&self.service);
        let response = async move { service.call(&cx, request).await };
        Box::pin(response)
    }
}
/// The [`CxProvider`] had no context to hand out for this call.
#[cfg(feature = "tower")]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct NoCxAvailable;

#[cfg(feature = "tower")]
impl std::fmt::Display for NoCxAvailable {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("no Cx available from provider (not running within asupersync runtime?)")
    }
}

#[cfg(feature = "tower")]
impl std::error::Error for NoCxAvailable {}

/// Error type of `TowerAdapterWithProvider`.
#[cfg(feature = "tower")]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ProviderAdapterError<E> {
    /// No `Cx` could be obtained, so the service was never called.
    NoCx(NoCxAvailable),
    /// The wrapped service returned an error.
    Service(E),
}

#[cfg(feature = "tower")]
impl<E: std::fmt::Display> std::fmt::Display for ProviderAdapterError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Service(inner) => write!(f, "service error: {inner}"),
            Self::NoCx(missing) => write!(f, "{missing}"),
        }
    }
}

#[cfg(feature = "tower")]
impl<E> std::error::Error for ProviderAdapterError<E> where E: std::fmt::Debug + std::fmt::Display {}
/// Tower-facing wrapper that accepts bare requests and asks a [`CxProvider`]
/// for the `Cx` on every call.
#[cfg(feature = "tower")]
pub struct TowerAdapterWithProvider<S, P = ThreadLocalCxProvider> {
    service: std::sync::Arc<S>,
    provider: P,
}

#[cfg(feature = "tower")]
impl<S> TowerAdapterWithProvider<S, ThreadLocalCxProvider> {
    /// Wraps `service`, resolving the `Cx` through [`ThreadLocalCxProvider`].
    #[must_use]
    pub fn new(service: S) -> Self {
        Self::with_provider(service, ThreadLocalCxProvider)
    }
}

#[cfg(feature = "tower")]
impl<S, P> TowerAdapterWithProvider<S, P> {
    /// Wraps `service`, resolving the `Cx` through `provider`.
    #[must_use]
    pub fn with_provider(service: S, provider: P) -> Self {
        let service = std::sync::Arc::new(service);
        Self { service, provider }
    }

    /// Returns the provider used to obtain the `Cx`.
    #[must_use]
    pub fn provider(&self) -> &P {
        &self.provider
    }
}

// Clones share the same service; only the provider itself must be `Clone`.
#[cfg(feature = "tower")]
impl<S, P: Clone> Clone for TowerAdapterWithProvider<S, P> {
    fn clone(&self) -> Self {
        let service = std::sync::Arc::clone(&self.service);
        let provider = self.provider.clone();
        Self { service, provider }
    }
}
/// Tower interface for bare requests; the `Cx` comes from the configured provider.
#[cfg(feature = "tower")]
impl<S, P, Request> tower::Service<Request> for TowerAdapterWithProvider<S, P>
where
    S: AsupersyncService<Request> + Send + Sync + 'static,
    P: CxProvider,
    Request: Send + 'static,
    S::Response: 'static,
    S::Error: 'static,
{
    type Response = S::Response;
    type Error = ProviderAdapterError<S::Error>;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

    /// Always ready: a missing `Cx` is reported from `call`, not here.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, request: Request) -> Self::Future {
        // The provider is consulted synchronously, before any future is created.
        let cx = match self.provider.current_cx() {
            Some(cx) => cx,
            None => {
                let missing = ProviderAdapterError::NoCx(NoCxAvailable);
                return Box::pin(std::future::ready(Err(missing)));
            }
        };
        let service = std::sync::Arc::clone(&self.service);
        let response = async move {
            let outcome = service.call(&cx, request).await;
            outcome.map_err(ProviderAdapterError::Service)
        };
        Box::pin(response)
    }
}
/// Adapts a `tower::Service` so it can be used as an [`AsupersyncService`].
///
/// The tower service lives behind a `parking_lot::Mutex`: tower's
/// `poll_ready`/`call` take `&mut self`, while `AsupersyncService::call` takes `&self`.
#[cfg(feature = "tower")]
pub struct AsupersyncAdapter<S> {
    inner: parking_lot::Mutex<S>,
    config: AdapterConfig,
}

#[cfg(feature = "tower")]
impl<S> AsupersyncAdapter<S> {
    /// Wraps `service` using [`AdapterConfig::default`].
    #[must_use]
    pub fn new(service: S) -> Self {
        Self::with_config(service, AdapterConfig::default())
    }

    /// Wraps `service` using an explicit `config`.
    #[must_use]
    pub fn with_config(service: S, config: AdapterConfig) -> Self {
        Self {
            inner: parking_lot::Mutex::new(service),
            config,
        }
    }

    /// Returns the configuration this adapter was built with.
    #[must_use]
    pub fn config(&self) -> &AdapterConfig {
        &self.config
    }
}
/// Drives a wrapped `tower::Service` as an [`AsupersyncService`].
///
/// Per call:
/// 1. Returns `Cancelled` if `cx` is already cancelled, and `Overloaded` if the
///    budget's `poll_quota` is below `config.min_budget_for_wait`.
/// 2. In `TimeoutFallback` mode, computes an absolute deadline from
///    `fallback_timeout` and the `Cx` timer driver. There is no deadline if
///    either is absent.
/// 3. Locks the inner service and awaits `poll_ready` (bounded by the deadline,
///    if any). It then re-checks cancellation, calls `call`, and releases the lock.
/// 4. Awaits the tower future according to `config.cancellation_mode`.
#[cfg(feature = "tower")]
impl<S, Request> AsupersyncService<Request> for AsupersyncAdapter<S>
where
    S: tower::Service<Request> + Send + 'static,
    Request: Send + 'static,
    S::Response: Send + 'static,
    S::Error: Send + std::fmt::Debug + std::fmt::Display + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = TowerAdapterError<S::Error>;
    // NOTE(review): the `parking_lot` guard from `self.inner.lock()` stays alive
    // across the `poll_ready` await below; this is what the two lint allowances
    // acknowledge. `lock()` blocks the OS thread. So if one call is parked on
    // `Pending` readiness, a second concurrent call on the same adapter blocks
    // inside `lock()` rather than yielding. That can stall the executor, or
    // deadlock a single-threaded one. Confirm one adapter is never shared by
    // concurrently running tasks, or move to an async-aware lock.
    #[allow(clippy::await_holding_lock)]
    #[allow(clippy::future_not_send)]
    async fn call(&self, cx: &Cx, request: Request) -> Result<Self::Response, Self::Error> {
        use std::future::poll_fn;
        // Fail fast before touching the tower service at all.
        if cx.checkpoint().is_err() {
            return Err(TowerAdapterError::Cancelled);
        }
        let budget = cx.budget();
        if u64::from(budget.poll_quota) < self.config.min_budget_for_wait {
            return Err(TowerAdapterError::Overloaded);
        }
        // Only `TimeoutFallback` uses a deadline. It stays `None` when no
        // `fallback_timeout` is configured or `cx` has no timer driver.
        let timeout_deadline = match self.config.cancellation_mode {
            CancellationMode::TimeoutFallback => self
                .config
                .fallback_timeout
                .and_then(|timeout| cx.timer_driver().map(|timer| timer.now() + timeout)),
            _ => None,
        };
        // The guard is kept until after `call`, so the readiness observed by
        // `poll_ready` is consumed by this same caller.
        let mut service = self.inner.lock();
        let ready_result = if let Some(deadline) = timeout_deadline {
            // `?` returns `Timeout` if the deadline passes before readiness.
            crate::time::timeout_at(deadline, poll_fn(|poll_cx| service.poll_ready(poll_cx)))
                .await
                .map_err(|_| TowerAdapterError::Timeout)?
        } else {
            poll_fn(|poll_cx| service.poll_ready(poll_cx)).await
        };
        if let Err(e) = ready_result {
            return Err(TowerAdapterError::Service(e));
        }
        // Waiting for readiness may have taken a while; re-check before dispatching.
        if cx.checkpoint().is_err() {
            return Err(TowerAdapterError::Cancelled);
        }
        let future = service.call(request);
        // Release the lock before awaiting the response.
        drop(service);
        match self.config.cancellation_mode {
            CancellationMode::BestEffort => {
                future.await.map_err(TowerAdapterError::Service)
            }
            CancellationMode::Strict => {
                // The tower future always runs to completion. Cancellation is
                // only detected afterwards, and then the result is discarded.
                let result = future.await.map_err(TowerAdapterError::Service);
                if cx.checkpoint().is_err() {
                    return Err(TowerAdapterError::CancellationIgnored);
                }
                result
            }
            CancellationMode::TimeoutFallback => {
                if let Some(deadline) = timeout_deadline {
                    // Boxed presumably because `S::Future` need not be `Unpin`
                    // (TODO confirm against `timeout_at`'s bounds).
                    crate::time::timeout_at(deadline, Box::pin(future))
                        .await
                        .map_or_else(
                            |_| Err(TowerAdapterError::Timeout),
                            |output| output.map_err(TowerAdapterError::Service),
                        )
                } else {
                    // No deadline could be computed: behaves like `BestEffort`.
                    future.await.map_err(TowerAdapterError::Service)
                }
            }
        }
    }
}
/// Future returned by `ServiceExt::ready`. It resolves with the next result of
/// the borrowed service's `poll_ready`.
#[derive(Debug)]
pub struct Ready<'a, S: ?Sized, Request> {
    service: &'a mut S,
    // `fn(Request)` ties the type to `Request` without requiring `Request` to
    // be `Send`, `Sync`, or `Unpin`.
    _marker: PhantomData<fn(Request)>,
}

impl<'a, S: ?Sized, Request> Ready<'a, S, Request> {
    /// Borrows `service` until the returned future is dropped.
    fn new(service: &'a mut S) -> Self {
        Ready {
            service,
            _marker: PhantomData,
        }
    }
}

impl<S, Request> Future for Ready<'_, S, Request>
where
    S: Service<Request> + ?Sized,
{
    type Output = Result<(), S::Error>;

    /// Delegates straight to `poll_ready`; `Ready` holds only a reference, so it is `Unpin`.
    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let Self { service, .. } = Pin::into_inner(self);
        service.poll_ready(cx)
    }
}
/// Error produced by [`Oneshot`].
#[derive(Debug)]
pub enum OneshotError<E> {
    /// The wrapped service failed, either in `poll_ready` or in its call future.
    Inner(E),
    /// The future was polled again after resolving, or after a panic inside the
    /// wrapped service left it in its terminal state.
    PolledAfterCompletion,
}

impl<E: std::fmt::Display> std::fmt::Display for OneshotError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Inner(err) => write!(f, "inner service error: {err}"),
            Self::PolledAfterCompletion => write!(f, "oneshot future polled after completion"),
        }
    }
}

impl<E: std::error::Error + 'static> std::error::Error for OneshotError<E> {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Inner(err) => Some(err),
            Self::PolledAfterCompletion => None,
        }
    }
}

/// Future that waits for a service to become ready, sends it exactly one
/// request, and resolves with the response. It is created by `ServiceExt::oneshot`.
pub struct Oneshot<S, Request>
where
    S: Service<Request>,
{
    state: OneshotState<S, Request>,
}

/// State machine behind [`Oneshot`].
enum OneshotState<S, Request>
where
    S: Service<Request>,
{
    /// Waiting for `poll_ready`; still owns the service and the pending request.
    ///
    /// The request is stored directly (not as `Option`). This state is always
    /// built with a request, and it is left exactly once, when the request is
    /// moved into `call`.
    Ready { service: S, request: Request },
    /// Request dispatched; driving the service's response future.
    Calling { future: S::Future },
    /// Terminal: resolved, failed, or interrupted by a panic mid-transition.
    Done,
}

impl<S, Request> Oneshot<S, Request>
where
    S: Service<Request>,
{
    /// Creates a oneshot that will send `request` to `service` once it is ready.
    pub fn new(service: S, request: Request) -> Self {
        Self {
            state: OneshotState::Ready { service, request },
        }
    }
}

impl<S, Request> Future for Oneshot<S, Request>
where
    S: Service<Request> + Unpin,
    Request: Unpin,
    S::Future: Unpin,
{
    type Output = Result<S::Response, OneshotError<S::Error>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        loop {
            // Move the state out and leave `Done` in its place. If `poll_ready`,
            // `call`, or the inner future panics, the oneshot stays terminal
            // instead of holding a half-consumed state.
            match std::mem::replace(&mut this.state, OneshotState::Done) {
                OneshotState::Ready {
                    mut service,
                    request,
                } => match service.poll_ready(cx) {
                    Poll::Pending => {
                        this.state = OneshotState::Ready { service, request };
                        return Poll::Pending;
                    }
                    Poll::Ready(Err(err)) => {
                        return Poll::Ready(Err(OneshotError::Inner(err)));
                    }
                    Poll::Ready(Ok(())) => {
                        // Only the response future is kept; the service is
                        // dropped at the end of this arm.
                        this.state = OneshotState::Calling {
                            future: service.call(request),
                        };
                    }
                },
                OneshotState::Calling { mut future } => {
                    let result = Pin::new(&mut future).poll(cx);
                    if result.is_pending() {
                        this.state = OneshotState::Calling { future };
                    }
                    return result.map_err(OneshotError::Inner);
                }
                OneshotState::Done => {
                    return Poll::Ready(Err(OneshotError::PolledAfterCompletion));
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    // Everything the top-level tests need from the parent module. The nested,
    // feature-gated modules import their own items via `super::super`.
    use super::{
        AdapterConfig, AsupersyncService, AsupersyncServiceExt, CancellationMode,
        DefaultErrorAdapter, ErrorAdapter, OneshotError, OneshotState, Service, ServiceExt,
        TowerAdapterError,
    };
    use crate::test_utils::run_test_with_cx;
    use std::cell::Cell;
    use std::panic::{AssertUnwindSafe, catch_unwind};
    use std::pin::Pin;
    use std::task::{Context, Poll, Waker};

    fn noop_waker() -> Waker {
        std::task::Waker::noop().clone()
    }

    // ----- Fixture services for the poll-based `Service` / `Oneshot` tests -----

    /// Ready immediately, but panics while building the call future.
    #[derive(Clone, Debug)]
    struct PanicOnCallService;

    impl Service<u32> for PanicOnCallService {
        type Response = ();
        type Error = ();
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: u32) -> Self::Future {
            panic!("panic during oneshot call construction");
        }
    }

    /// Ready immediately; echoes the request back.
    #[derive(Clone, Debug)]
    struct EchoU32Service;

    impl Service<u32> for EchoU32Service {
        type Response = u32;
        type Error = ();
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: u32) -> Self::Future {
            std::future::ready(Ok(req))
        }
    }

    /// Returns `Pending` on its first poll (waking itself), then `Ok(value)`.
    #[derive(Debug)]
    struct PendingThenReadyFuture {
        value: u32,
        first_poll: bool,
    }

    impl Future for PendingThenReadyFuture {
        type Output = Result<u32, ()>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if self.first_poll {
                self.first_poll = false;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(Ok(self.value))
        }
    }

    /// `poll_ready` is `Pending` once, then ready; its call future is also pending once.
    #[derive(Debug)]
    struct PendingThenReadyService {
        ready_polls: Cell<u8>,
    }

    impl PendingThenReadyService {
        fn new() -> Self {
            Self {
                ready_polls: Cell::new(0),
            }
        }
    }

    impl Service<u32> for PendingThenReadyService {
        type Response = u32;
        type Error = ();
        type Future = PendingThenReadyFuture;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            if self.ready_polls.get() == 0 {
                self.ready_polls.set(1);
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: u32) -> Self::Future {
            PendingThenReadyFuture {
                value: req,
                first_poll: true,
            }
        }
    }

    /// Ready immediately; every call fails with `Err(())`.
    #[derive(Clone, Debug)]
    struct ErrorOnCallService;

    impl Service<u32> for ErrorOnCallService {
        type Response = u32;
        type Error = ();
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: u32) -> Self::Future {
            std::future::ready(Err(()))
        }
    }

    // ----- AsupersyncService combinators -----

    #[test]
    fn function_service_call_works() {
        run_test_with_cx(|cx| async move {
            let svc = |_: &crate::cx::Cx, req: i32| async move { Ok::<_, ()>(req + 1) };
            let result = AsupersyncService::call(&svc, &cx, 41).await.unwrap();
            assert_eq!(result, 42);
        });
    }

    #[test]
    fn map_response_and_map_err() {
        run_test_with_cx(|cx| async move {
            let svc = |_: &crate::cx::Cx, req: i32| async move { Ok::<_, &str>(req) };
            let svc = svc.map_response(|v| v + 1).map_err(|e| format!("err:{e}"));
            let result = AsupersyncService::call(&svc, &cx, 41).await.unwrap();
            assert_eq!(result, 42);
            let fail = |_: &crate::cx::Cx, _: i32| async move { Err::<i32, &str>("nope") };
            let fail = fail.map_err(|e| format!("err:{e}"));
            let err = AsupersyncService::call(&fail, &cx, 0).await.unwrap_err();
            assert_eq!(err, "err:nope");
        });
    }

    // ----- Oneshot state machine -----

    #[test]
    fn oneshot_second_poll_fails_closed_after_success() {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut fut = EchoU32Service.oneshot(7);
        assert!(matches!(
            Pin::new(&mut fut).poll(&mut cx),
            Poll::Ready(Ok(7))
        ));
        assert!(matches!(
            Pin::new(&mut fut).poll(&mut cx),
            Poll::Ready(Err(OneshotError::PolledAfterCompletion))
        ));
    }

    #[test]
    fn oneshot_pending_then_completion_then_repoll_fails_closed() {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut fut = PendingThenReadyService::new().oneshot(9);
        assert!(matches!(Pin::new(&mut fut).poll(&mut cx), Poll::Pending));
        assert!(matches!(Pin::new(&mut fut).poll(&mut cx), Poll::Pending));
        assert!(matches!(
            Pin::new(&mut fut).poll(&mut cx),
            Poll::Ready(Ok(9))
        ));
        assert!(matches!(
            Pin::new(&mut fut).poll(&mut cx),
            Poll::Ready(Err(OneshotError::PolledAfterCompletion))
        ));
    }

    #[test]
    fn oneshot_repoll_after_inner_error_fails_closed() {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut fut = ErrorOnCallService.oneshot(7);
        assert!(matches!(
            Pin::new(&mut fut).poll(&mut cx),
            Poll::Ready(Err(OneshotError::Inner(())))
        ));
        assert!(matches!(
            Pin::new(&mut fut).poll(&mut cx),
            Poll::Ready(Err(OneshotError::PolledAfterCompletion))
        ));
    }

    #[test]
    fn oneshot_call_panic_leaves_terminal_state() {
        let mut fut = PanicOnCallService.oneshot(7);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let first_panic = catch_unwind(AssertUnwindSafe(|| {
            let _ = Pin::new(&mut fut).poll(&mut cx);
        }));
        assert!(first_panic.is_err(), "first poll should propagate panic");
        assert!(
            matches!(fut.state, OneshotState::Done),
            "panic path must leave Oneshot in Done state"
        );
        assert!(
            matches!(
                Pin::new(&mut fut).poll(&mut cx),
                Poll::Ready(Err(OneshotError::PolledAfterCompletion))
            ),
            "repoll should fail closed after panic left the future terminal"
        );
    }

    #[test]
    fn oneshot_error_display_and_source() {
        use std::error::Error;
        let inner = OneshotError::Inner(std::io::Error::other("boom"));
        assert_eq!(format!("{inner}"), "inner service error: boom");
        assert!(inner.source().is_some());
        let done: OneshotError<std::io::Error> = OneshotError::PolledAfterCompletion;
        assert_eq!(format!("{done}"), "oneshot future polled after completion");
        assert!(done.source().is_none());
    }

    // ----- Adapter configuration and error types -----

    #[test]
    fn cancellation_mode_default_is_best_effort() {
        let mode = CancellationMode::default();
        assert_eq!(mode, CancellationMode::BestEffort);
    }

    #[test]
    fn adapter_config_builder_pattern() {
        let config = AdapterConfig::new()
            .cancellation_mode(CancellationMode::Strict)
            .fallback_timeout(std::time::Duration::from_secs(30))
            .min_budget_for_wait(100);
        assert_eq!(config.cancellation_mode, CancellationMode::Strict);
        assert_eq!(
            config.fallback_timeout,
            Some(std::time::Duration::from_secs(30))
        );
        assert_eq!(config.min_budget_for_wait, 100);
    }

    #[test]
    fn adapter_config_default_values() {
        let config = AdapterConfig::default();
        assert_eq!(config.cancellation_mode, CancellationMode::BestEffort);
        assert!(config.fallback_timeout.is_none());
        assert_eq!(config.min_budget_for_wait, 10);
    }

    #[test]
    fn default_error_adapter_round_trip() {
        let adapter = DefaultErrorAdapter::<String>::new();
        let original = "test error".to_string();
        let converted = adapter.to_asupersync(original.clone());
        assert_eq!(converted, original);
        let back = adapter.to_tower(converted);
        assert_eq!(back, original);
    }

    #[test]
    fn tower_adapter_error_display() {
        let service_err: TowerAdapterError<&str> = TowerAdapterError::Service("inner error");
        assert_eq!(format!("{service_err}"), "service error: inner error");
        let cancelled: TowerAdapterError<&str> = TowerAdapterError::Cancelled;
        assert_eq!(format!("{cancelled}"), "operation cancelled");
        let timeout: TowerAdapterError<&str> = TowerAdapterError::Timeout;
        assert_eq!(format!("{timeout}"), "operation timed out");
        let overloaded: TowerAdapterError<&str> = TowerAdapterError::Overloaded;
        assert_eq!(
            format!("{overloaded}"),
            "service overloaded, insufficient budget"
        );
        let ignored: TowerAdapterError<&str> = TowerAdapterError::CancellationIgnored;
        assert_eq!(
            format!("{ignored}"),
            "service ignored cancellation request (strict mode)"
        );
    }

    #[test]
    fn tower_adapter_error_equality() {
        let err1: TowerAdapterError<i32> = TowerAdapterError::Service(42);
        let err2: TowerAdapterError<i32> = TowerAdapterError::Service(42);
        let err3: TowerAdapterError<i32> = TowerAdapterError::Service(43);
        assert_eq!(err1, err2);
        assert_ne!(err1, err3);
        assert_eq!(
            TowerAdapterError::<i32>::Cancelled,
            TowerAdapterError::Cancelled
        );
        assert_ne!(
            TowerAdapterError::<i32>::Cancelled,
            TowerAdapterError::Timeout
        );
    }

    #[test]
    fn cancellation_mode_all_variants() {
        let best_effort = CancellationMode::BestEffort;
        let strict = CancellationMode::Strict;
        let timeout = CancellationMode::TimeoutFallback;
        assert_ne!(best_effort, strict);
        assert_ne!(best_effort, timeout);
        assert_ne!(strict, timeout);
    }

    // ----- CxProvider implementations (tower feature) -----

    #[cfg(feature = "tower")]
    mod cx_provider_tests {
        use super::super::{CxProvider, FixedCxProvider, ThreadLocalCxProvider};
        use crate::Cx;

        #[test]
        fn thread_local_provider_returns_none_when_not_set() {
            let provider = ThreadLocalCxProvider;
            assert!(provider.current_cx().is_none());
        }

        #[test]
        fn fixed_provider_returns_cx() {
            let cx: Cx = Cx::for_testing();
            let provider = FixedCxProvider::new(cx.clone());
            let retrieved = provider.current_cx();
            assert!(retrieved.is_some());
            assert_eq!(retrieved.unwrap().task_id(), cx.task_id());
        }

        #[test]
        fn fixed_provider_for_testing_convenience() {
            let provider = FixedCxProvider::for_testing();
            assert!(provider.current_cx().is_some());
        }

        #[test]
        fn thread_local_provider_default() {
            let provider = ThreadLocalCxProvider;
            let _ = provider.current_cx();
        }

        #[test]
        fn fixed_provider_is_cloneable() {
            let provider = FixedCxProvider::for_testing();
            let cloned = provider.clone();
            assert!(provider.current_cx().is_some());
            assert!(cloned.current_cx().is_some());
        }
    }

    // ----- TowerAdapterWithProvider (tower feature) -----

    #[cfg(feature = "tower")]
    mod tower_provider_tests {
        use super::super::{
            AsupersyncService, CxProvider, FixedCxProvider, NoCxAvailable, ProviderAdapterError,
            TowerAdapterWithProvider,
        };
        use crate::Cx;

        struct AddOneService;

        impl AsupersyncService<i32> for AddOneService {
            type Response = i32;
            type Error = std::convert::Infallible;

            async fn call(&self, _cx: &Cx, req: i32) -> Result<Self::Response, Self::Error> {
                Ok(req + 1)
            }
        }

        #[test]
        fn adapter_with_fixed_provider_works() {
            let provider = FixedCxProvider::for_testing();
            let adapter = TowerAdapterWithProvider::with_provider(AddOneService, provider);
            assert!(adapter.provider().current_cx().is_some());
        }

        #[test]
        fn adapter_new_uses_thread_local_provider() {
            let adapter = TowerAdapterWithProvider::new(AddOneService);
            assert!(adapter.provider().current_cx().is_none());
        }

        #[test]
        fn adapter_is_cloneable_with_clone_provider() {
            let provider = FixedCxProvider::for_testing();
            let adapter = TowerAdapterWithProvider::with_provider(AddOneService, provider);
            // Actually exercise `Clone` (previously this was only a move); the
            // service itself (`AddOneService`) is not `Clone`.
            let cloned = adapter.clone();
            assert!(cloned.provider().current_cx().is_some());
            assert!(adapter.provider().current_cx().is_some());
        }

        #[test]
        fn no_cx_available_error_display() {
            let err = NoCxAvailable;
            let msg = format!("{err}");
            assert!(msg.contains("no Cx available"));
        }

        #[test]
        fn provider_adapter_error_display() {
            let no_cx: ProviderAdapterError<&str> = ProviderAdapterError::NoCx(NoCxAvailable);
            assert!(format!("{no_cx}").contains("no Cx available"));
            let service_err: ProviderAdapterError<&str> =
                ProviderAdapterError::Service("test error");
            assert_eq!(format!("{service_err}"), "service error: test error");
        }

        #[test]
        fn provider_adapter_error_equality() {
            let err1: ProviderAdapterError<i32> = ProviderAdapterError::Service(42);
            let err2: ProviderAdapterError<i32> = ProviderAdapterError::Service(42);
            let err3: ProviderAdapterError<i32> = ProviderAdapterError::Service(43);
            assert_eq!(err1, err2);
            assert_ne!(err1, err3);
            let no_cx1: ProviderAdapterError<i32> = ProviderAdapterError::NoCx(NoCxAvailable);
            let no_cx2: ProviderAdapterError<i32> = ProviderAdapterError::NoCx(NoCxAvailable);
            assert_eq!(no_cx1, no_cx2);
        }
    }

    // ----- AsupersyncAdapter timeout behaviour (tower feature, virtual clock) -----

    #[cfg(feature = "tower")]
    mod tower_adapter_timeout_tests {
        use super::super::{
            AdapterConfig, AsupersyncAdapter, AsupersyncService, CancellationMode,
            TowerAdapterError,
        };
        use crate::Cx;
        use crate::time::{TimerDriverHandle, VirtualClock};
        use crate::types::{Budget, RegionId, TaskId, Time};
        use std::future::{Future, pending};
        use std::pin::pin;
        use std::sync::Arc;
        use std::task::{Context, Poll, Waker};
        use std::time::Duration;

        fn noop_waker() -> Waker {
            std::task::Waker::noop().clone()
        }

        /// Builds a `Cx` whose timer driver runs on a manually advanced virtual clock.
        fn test_cx_with_timer() -> (Cx, Arc<VirtualClock>, TimerDriverHandle) {
            let clock = Arc::new(VirtualClock::starting_at(Time::ZERO));
            let timer = TimerDriverHandle::with_virtual_clock(clock.clone());
            let cx = Cx::new_with_drivers(
                RegionId::new_for_test(1, 0),
                TaskId::new_for_test(1, 0),
                Budget::INFINITE,
                None,
                None,
                None,
                Some(timer.clone()),
                None,
            );
            (cx, clock, timer)
        }

        /// Ready immediately, but its call future never completes.
        #[derive(Clone)]
        struct PendingService;

        /// Never becomes ready.
        #[derive(Clone)]
        struct PendingReadyService;

        #[derive(Debug)]
        struct TestError;

        impl std::fmt::Display for TestError {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "test error")
            }
        }

        impl tower::Service<()> for PendingService {
            type Response = ();
            type Error = TestError;
            type Future = std::future::Pending<Result<(), TestError>>;

            fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: ()) -> Self::Future {
                pending()
            }
        }

        impl tower::Service<()> for PendingReadyService {
            type Response = ();
            type Error = TestError;
            type Future = std::future::Ready<Result<(), TestError>>;

            fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Pending
            }

            fn call(&mut self, _req: ()) -> Self::Future {
                std::future::ready(Ok(()))
            }
        }

        #[test]
        fn timeout_fallback_triggers_timeout_error() {
            crate::test_utils::init_test_logging();
            crate::test_phase!("timeout_fallback_triggers_timeout_error");
            let (cx, clock, timer) = test_cx_with_timer();
            let _guard = Cx::set_current(Some(cx.clone()));
            let config = AdapterConfig::new()
                .cancellation_mode(CancellationMode::TimeoutFallback)
                .fallback_timeout(Duration::from_millis(5));
            let adapter = AsupersyncAdapter::with_config(PendingService, config);
            let mut fut = pin!(adapter.call(&cx, ()));
            let waker = noop_waker();
            let mut context = Context::from_waker(&waker);
            let first = fut.as_mut().poll(&mut context);
            assert!(first.is_pending());
            clock.advance(Time::from_millis(6).as_nanos());
            let _ = timer.process_timers();
            let result = fut.as_mut().poll(&mut context);
            let timed_out = matches!(result, Poll::Ready(Err(TowerAdapterError::Timeout)));
            crate::assert_with_log!(timed_out, "timeout error", true, timed_out);
            crate::test_complete!("timeout_fallback_triggers_timeout_error");
        }

        #[test]
        fn timeout_fallback_times_out_while_waiting_for_ready() {
            crate::test_utils::init_test_logging();
            crate::test_phase!("timeout_fallback_times_out_while_waiting_for_ready");
            let (cx, clock, timer) = test_cx_with_timer();
            let _guard = Cx::set_current(Some(cx.clone()));
            let config = AdapterConfig::new()
                .cancellation_mode(CancellationMode::TimeoutFallback)
                .fallback_timeout(Duration::from_millis(5));
            let adapter = AsupersyncAdapter::with_config(PendingReadyService, config);
            let mut fut = pin!(adapter.call(&cx, ()));
            let waker = noop_waker();
            let mut context = Context::from_waker(&waker);
            let first = fut.as_mut().poll(&mut context);
            assert!(first.is_pending());
            clock.advance(Time::from_millis(6).as_nanos());
            let _ = timer.process_timers();
            let result = fut.as_mut().poll(&mut context);
            let timed_out = matches!(result, Poll::Ready(Err(TowerAdapterError::Timeout)));
            crate::assert_with_log!(timed_out, "readiness timeout error", true, timed_out);
            crate::test_complete!("timeout_fallback_times_out_while_waiting_for_ready");
        }
    }

    // ----- Derived-trait sanity checks -----

    #[test]
    fn cancellation_mode_debug_clone_copy_default_eq() {
        let m = CancellationMode::default();
        assert_eq!(m, CancellationMode::BestEffort);
        let dbg = format!("{m:?}");
        assert!(dbg.contains("BestEffort"));
        let m2 = m;
        assert_eq!(m, m2);
        let m3 = m;
        assert_eq!(m, m3);
        assert_ne!(CancellationMode::BestEffort, CancellationMode::Strict);
    }

    #[test]
    fn tower_adapter_error_debug_clone_eq() {
        let e: TowerAdapterError<String> = TowerAdapterError::Cancelled;
        let dbg = format!("{e:?}");
        assert!(dbg.contains("Cancelled"));
        let e2 = e.clone();
        assert_eq!(e, e2);
        assert_ne!(
            TowerAdapterError::<String>::Cancelled,
            TowerAdapterError::Timeout
        );
    }
}