use pin_project_lite::pin_project;
use std::{
future::Future,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
time::Instant,
};
use tokio::sync::OwnedSemaphorePermit;
use crate::Algorithm;
use crate::classifier::Classifier;
use crate::controller::Controller;
/// Per-request bookkeeping that reports to the shared controller when dropped.
///
/// The flags are plain fields so the owning future can adjust them up until
/// the moment the guard is dropped.
struct FutureGuard<A: Algorithm> {
    /// Shared controller that receives the outcome when the guard is dropped.
    controller: Arc<Mutex<Controller<A>>>,
    /// Forwarded to the controller as the "canceled" flag on drop.
    is_canceled: bool,
    /// Forwarded to the controller as the "error" flag on drop.
    is_error: bool,
    /// Semaphore permit held for the request; taken and released on drop.
    permit: Option<OwnedSemaphorePermit>,
    /// Reference instant; the time elapsed since it is reported on drop.
    start: Instant,
}
impl<A: Algorithm> Drop for FutureGuard<A> {
    /// Releases the permit, then reports the elapsed time and outcome flags to
    /// the controller.
    ///
    /// A poisoned controller mutex is recovered instead of being treated as
    /// fatal. Panicking inside a destructor aborts the process when the drop
    /// happens during unwinding, and would otherwise make every later request
    /// panic on drop once the lock had been poisoned a single time.
    fn drop(&mut self) {
        // The permit is released before the controller lock is taken.
        drop(self.permit.take());
        self.controller
            .lock()
            // Poisoning only records that some holder panicked; keep reporting
            // rather than panicking from within `drop`.
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .update(self.start.elapsed(), self.is_error, self.is_canceled);
    }
}
pin_project! {
    /// Future wrapper that classifies the wrapped future's result and carries a
    /// `FutureGuard` which reports to the controller when the wrapper is dropped.
    pub struct ResponseFuture<F, A: Algorithm, C> {
        // The wrapped future; pinned so it can be polled through the projection.
        #[pin]
        future: F,
        // Consulted once the wrapped future is ready to decide the error flag.
        classifier: C,
        // Holds the permit and outcome flags; does its reporting in `Drop`.
        guard: FutureGuard<A>,
    }
}
impl<F, A: Algorithm, C> ResponseFuture<F, A, C> {
    /// Wraps `future` together with the bookkeeping needed to report its
    /// outcome to `controller`.
    ///
    /// The guard starts out marked as canceled and not an error, so a wrapper
    /// that is dropped without ever completing is reported with those flags.
    /// `permit` is stored in the guard and `start` is the instant from which
    /// the reported elapsed time is measured.
    pub(super) fn new(
        future: F,
        controller: Arc<Mutex<Controller<A>>>,
        permit: OwnedSemaphorePermit,
        start: Instant,
        classifier: C,
    ) -> Self {
        let guard = FutureGuard {
            controller,
            is_canceled: true,
            is_error: false,
            permit: Some(permit),
            start,
        };
        Self {
            future,
            classifier,
            guard,
        }
    }
}
impl<F, T, E, A, C> Future for ResponseFuture<F, A, C>
where
    F: Future<Output = Result<T, E>>,
    A: Algorithm,
    C: Classifier<T, E>,
{
    type Output = Result<T, E>;

    /// Polls the wrapped future and passes its result through unchanged.
    ///
    /// When the wrapped future is ready, the guard's canceled flag is cleared
    /// and its error flag is set from the classifier's verdict on the result.
    /// While it is pending, the guard is left untouched.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        // Pull the plain references out first so the closure below does not
        // need to capture the whole projection.
        let guard = this.guard;
        let classifier = this.classifier;
        this.future.poll(cx).map(|result| {
            guard.is_canceled = false;
            guard.is_error = classifier.is_server_error(&result);
            result
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::classifier::DefaultClassifier;
    use crate::controller::Controller;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::Semaphore;

    /// Log of `(is_error, is_canceled)` pairs shared between a test and its
    /// algorithm.
    type UpdateLog = Arc<Mutex<Vec<(bool, bool)>>>;

    /// Algorithm stub with a fixed limit that records the flags of every
    /// `update` call it receives.
    struct RecordingAlgorithm {
        limit: usize,
        updates: UpdateLog,
    }

    impl RecordingAlgorithm {
        /// Returns the algorithm together with a handle to its update log.
        fn new(limit: usize) -> (Self, UpdateLog) {
            let updates = Arc::new(Mutex::new(Vec::new()));
            (
                Self {
                    limit,
                    updates: Arc::clone(&updates),
                },
                updates,
            )
        }
    }

    impl Algorithm for RecordingAlgorithm {
        fn max_concurrency(&self) -> usize {
            self.limit
        }

        fn update(
            &mut self,
            _rtt: Duration,
            _num_inflight: usize,
            is_error: bool,
            is_canceled: bool,
        ) {
            self.updates.lock().unwrap().push((is_error, is_canceled));
        }
    }

    /// Builds a controller around a `RecordingAlgorithm` and returns the
    /// controller, its semaphore, and the algorithm's update log.
    fn make_fixture(
        limit: usize,
    ) -> (
        Arc<Mutex<Controller<RecordingAlgorithm>>>,
        Arc<Semaphore>,
        UpdateLog,
    ) {
        let (algo, updates) = RecordingAlgorithm::new(limit);
        let controller = Controller::new(algo);
        let semaphore = controller.semaphore();
        (Arc::new(Mutex::new(controller)), semaphore, updates)
    }

    /// An `Ok` result with the default classifier records one update with both
    /// flags false.
    #[tokio::test]
    async fn success_reports_no_error_no_cancel() {
        let (controller, semaphore, updates) = make_fixture(10);
        let permit = semaphore.acquire_owned().await.unwrap();
        let fut = ResponseFuture::new(
            async { Ok::<_, ()>("ok") },
            controller,
            permit,
            Instant::now(),
            DefaultClassifier,
        );
        let result = fut.await;
        assert!(result.is_ok());
        let log = updates.lock().unwrap();
        assert_eq!(log.len(), 1);
        assert_eq!(log[0], (false, false));
    }

    /// An `Err` result with the default classifier records one update flagged
    /// as an error.
    #[tokio::test]
    async fn error_reports_is_error() {
        let (controller, semaphore, updates) = make_fixture(10);
        let permit = semaphore.acquire_owned().await.unwrap();
        let fut = ResponseFuture::new(
            async { Err::<(), _>("fail") },
            controller,
            permit,
            Instant::now(),
            DefaultClassifier,
        );
        let result = fut.await;
        assert!(result.is_err());
        let log = updates.lock().unwrap();
        assert_eq!(log.len(), 1);
        assert_eq!(log[0], (true, false));
    }

    /// Dropping the future before it is ever polled records one update flagged
    /// as canceled.
    #[tokio::test]
    async fn drop_before_completion_reports_canceled() {
        let (controller, semaphore, updates) = make_fixture(10);
        let permit = semaphore.acquire_owned().await.unwrap();
        let fut = ResponseFuture::new(
            std::future::pending::<Result<(), ()>>(),
            controller,
            permit,
            Instant::now(),
            DefaultClassifier,
        );
        drop(fut);
        let log = updates.lock().unwrap();
        assert_eq!(log.len(), 1);
        assert_eq!(log[0], (false, true));
    }

    /// After the future completes, the single permit is available again and
    /// exactly one update has been recorded.
    // NOTE(review): both effects are observed after the future is gone, so this
    // does not by itself prove their relative order.
    #[tokio::test]
    async fn permit_returned_before_controller_update() {
        let (controller, semaphore, updates) = make_fixture(1);
        let permit = semaphore.clone().acquire_owned().await.unwrap();
        assert_eq!(semaphore.available_permits(), 0);
        let fut = ResponseFuture::new(
            async { Ok::<_, ()>("ok") },
            controller,
            permit,
            Instant::now(),
            DefaultClassifier,
        );
        fut.await.unwrap();
        assert_eq!(semaphore.available_permits(), 1);
        let log = updates.lock().unwrap();
        assert_eq!(log.len(), 1);
    }

    /// A future that is pending once before completing still records exactly
    /// one update, with both flags false.
    #[tokio::test]
    async fn pending_then_ready() {
        let (controller, semaphore, updates) = make_fixture(10);
        let permit = semaphore.acquire_owned().await.unwrap();

        /// Returns `Pending` (after waking itself) on the first poll and
        /// `Ready(Ok("done"))` on every later poll.
        struct OnePending {
            polled: bool,
        }

        impl Future for OnePending {
            type Output = Result<&'static str, ()>;

            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                let this = self.get_mut();
                if !this.polled {
                    this.polled = true;
                    cx.waker().wake_by_ref();
                    Poll::Pending
                } else {
                    Poll::Ready(Ok("done"))
                }
            }
        }

        let fut = ResponseFuture::new(
            OnePending { polled: false },
            controller,
            permit,
            Instant::now(),
            DefaultClassifier,
        );
        let result = fut.await;
        assert_eq!(result.unwrap(), "done");
        let log = updates.lock().unwrap();
        assert_eq!(log.len(), 1);
        assert_eq!(log[0], (false, false));
    }

    /// A closure classifier can declare a particular `Err` value not to be an
    /// error.
    #[tokio::test]
    async fn custom_classifier_overrides_error_detection() {
        let (controller, semaphore, updates) = make_fixture(10);
        let permit = semaphore.acquire_owned().await.unwrap();
        let classifier = |result: &Result<(), &str>| match result {
            Err(e) => *e != "not_found",
            Ok(_) => false,
        };
        let fut = ResponseFuture::new(
            async { Err::<(), _>("not_found") },
            controller,
            permit,
            Instant::now(),
            classifier,
        );
        let result = fut.await;
        assert!(result.is_err());
        let log = updates.lock().unwrap();
        assert_eq!(log.len(), 1);
        assert_eq!(log[0], (false, false));
    }

    /// A closure classifier can flag an `Ok` value (status >= 500) as an error.
    #[tokio::test]
    async fn classifier_can_inspect_ok_variant() {
        let (controller, semaphore, updates) = make_fixture(10);
        let permit = semaphore.acquire_owned().await.unwrap();
        let classifier = |result: &Result<u16, &str>| match result {
            Ok(status) => *status >= 500,
            Err(_) => true,
        };
        let fut = ResponseFuture::new(
            async { Ok::<u16, &str>(503) },
            controller,
            permit,
            Instant::now(),
            classifier,
        );
        let result = fut.await;
        assert_eq!(result.unwrap(), 503);
        let log = updates.lock().unwrap();
        assert_eq!(log.len(), 1);
        assert_eq!(log[0], (true, false));
    }

    /// A hand-written `Classifier` implementation is consulted just like a
    /// closure; a 200 status is not reported as an error.
    #[tokio::test]
    async fn struct_classifier() {
        struct HttpClassifier;

        impl Classifier<u16, &'static str> for HttpClassifier {
            fn is_server_error(&self, result: &Result<u16, &'static str>) -> bool {
                match result {
                    Ok(status) => *status >= 500,
                    Err(_) => true,
                }
            }
        }

        let (controller, semaphore, updates) = make_fixture(10);
        let permit = semaphore.acquire_owned().await.unwrap();
        let fut = ResponseFuture::new(
            async { Ok::<u16, &str>(200) },
            controller,
            permit,
            Instant::now(),
            HttpClassifier,
        );
        let result = fut.await;
        assert_eq!(result.unwrap(), 200);
        let log = updates.lock().unwrap();
        // Check the count as the sibling tests do, so a duplicate report
        // cannot slip through here.
        assert_eq!(log.len(), 1);
        assert_eq!(log[0], (false, false));
    }
}