use std::{
future::Future,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
time::Duration,
};
use futures::{StreamExt, stream::FuturesUnordered};
use tokio::sync::watch;
pub(super) async fn try_priority_groups<Item, T, E, F, Fut>(
groups: Vec<Vec<Item>>,
f: F,
per_group_delay: Duration,
) -> Result<(Item, T), Vec<(Item, E)>>
where
Item: Clone + Send + 'static,
T: Send + 'static,
E: Send + 'static,
F: Fn(Item) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Result<T, E>> + Send + 'static,
{
let groups: Vec<Vec<Item>> = groups
.into_iter()
.filter(|group| !group.is_empty())
.collect();
let group_done: Vec<Arc<watch::Sender<bool>>> = (0..groups.len())
.map(|_| Arc::new(watch::Sender::new(false)))
.collect();
let mut futures: FuturesUnordered<_> = FuturesUnordered::new();
for (group_idx, group) in groups.into_iter().enumerate() {
let remaining = Arc::new(AtomicUsize::new(group.len()));
let my_done_tx = group_done[group_idx].clone();
let prev_done_rx = (group_idx > 0).then(|| group_done[group_idx - 1].subscribe());
let delay = per_group_delay * group_idx as u32;
for item in group {
let f = f.clone();
let remaining = remaining.clone();
let my_done_tx = my_done_tx.clone();
let mut prev_done_rx = prev_done_rx.clone();
futures.push(async move {
if let Some(ref mut rx) = prev_done_rx {
tokio::select! {
_ = tokio::time::sleep(delay) => {}
_ = rx.wait_for(|&done| done) => {}
}
}
let item_for_call = item.clone();
let result = f(item_for_call).await;
if remaining.fetch_sub(1, Ordering::Relaxed) == 1 {
let _ = my_done_tx.send(true);
}
(item, result)
});
}
}
let mut errors = Vec::new();
while let Some((item, result)) = futures.next().await {
match result {
Ok(value) => return Ok((item, value)),
Err(e) => errors.push((item, e)),
}
}
Err(errors)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Scripted behaviour of one candidate: after the given delay, succeed or fail.
    #[derive(Clone)]
    enum Outcome {
        Ok(Duration),
        Err(Duration),
    }

    fn ok_after(ms: u64) -> Outcome {
        Outcome::Ok(Duration::from_millis(ms))
    }

    fn err_after(ms: u64) -> Outcome {
        Outcome::Err(Duration::from_millis(ms))
    }

    /// Runs `try_priority_groups` over scripted outcomes.
    ///
    /// Each item is tagged `g{group}a{index}` by its position in `specs`, assigned before empty
    /// groups are dropped. Returns the winner's tag, or every failing tag if all fail.
    async fn run_groups(
        specs: Vec<Vec<Outcome>>,
        per_group_delay_ms: u64,
    ) -> Result<String, Vec<String>> {
        let groups: Vec<Vec<(String, Outcome)>> = specs
            .into_iter()
            .enumerate()
            .map(|(g, apis)| {
                apis.into_iter()
                    .enumerate()
                    .map(|(a, outcome)| (format!("g{g}a{a}"), outcome))
                    .collect()
            })
            .collect();
        let result = try_priority_groups(
            groups,
            |(tag, outcome)| {
                async move {
                    match outcome {
                        Outcome::Ok(delay) => {
                            tokio::time::sleep(delay).await;
                            Ok(tag)
                        }
                        Outcome::Err(delay) => {
                            tokio::time::sleep(delay).await;
                            Err(tag)
                        }
                    }
                }
            },
            Duration::from_millis(per_group_delay_ms),
        )
        .await;
        match result {
            Ok((_item, winner)) => Ok(winner),
            Err(errors) => Err(errors.into_iter().map(|(_item, tag)| tag).collect()),
        }
    }

    /// Like `run_groups`, additionally returning how much virtual time the race took.
    ///
    /// Callers pause the clock first. A paused clock auto-advances when idle, so checking the
    /// winner alone cannot show *when* a group started; the elapsed time can.
    async fn run_groups_timed(
        specs: Vec<Vec<Outcome>>,
        per_group_delay_ms: u64,
    ) -> (Result<String, Vec<String>>, Duration) {
        let start = tokio::time::Instant::now();
        let result = run_groups(specs, per_group_delay_ms).await;
        (result, start.elapsed())
    }

    /// Asserts `elapsed` lies in `[min_ms, max_ms)`. The upper bound is deliberately loose
    /// because the tokio timer works at millisecond granularity.
    #[track_caller]
    fn assert_elapsed(elapsed: Duration, min_ms: u64, max_ms: u64) {
        assert!(
            (Duration::from_millis(min_ms)..Duration::from_millis(max_ms)).contains(&elapsed),
            "elapsed {elapsed:?} not in [{min_ms}ms, {max_ms}ms)"
        );
    }

    #[tokio::test]
    async fn no_items_yields_empty_error() {
        assert_eq!(run_groups(Vec::new(), 500).await, Err(Vec::new()));
        assert_eq!(run_groups(vec![vec![], vec![]], 500).await, Err(Vec::new()));
    }

    #[tokio::test]
    async fn empty_groups_are_skipped() {
        tokio::time::pause();
        // Were empty groups kept, the last group would wait for the 1500 ms delay.
        let (result, elapsed) = run_groups_timed(
            vec![vec![], vec![err_after(0)], vec![], vec![ok_after(0)]],
            500,
        )
        .await;
        assert_eq!(result.unwrap(), "g3a0");
        assert_elapsed(elapsed, 0, 500);
    }

    #[tokio::test]
    async fn group0_wins() {
        tokio::time::pause();
        let (result, elapsed) =
            run_groups_timed(vec![vec![ok_after(0)], vec![ok_after(0)]], 500).await;
        assert_eq!(result.unwrap(), "g0a0");
        assert_elapsed(elapsed, 0, 500);
    }

    #[tokio::test]
    async fn group1_wins_when_group0_fails_fast() {
        tokio::time::pause();
        let (result, elapsed) =
            run_groups_timed(vec![vec![err_after(0)], vec![ok_after(0)]], 500).await;
        assert_eq!(result.unwrap(), "g1a0");
        // Group 1 starts on group 0's completion, not after the delay.
        assert_elapsed(elapsed, 0, 500);
    }

    #[tokio::test]
    async fn group1_wins_when_group0_is_slow() {
        tokio::time::pause();
        let (result, elapsed) =
            run_groups_timed(vec![vec![err_after(10_000)], vec![ok_after(0)]], 500).await;
        assert_eq!(result.unwrap(), "g1a0");
        // Group 1 starts after the delay, not after group 0's 10 s failure.
        assert_elapsed(elapsed, 500, 1_000);
    }

    #[tokio::test]
    async fn all_groups_fail() {
        tokio::time::pause();
        let (result, elapsed) =
            run_groups_timed(vec![vec![err_after(0)], vec![err_after(0)]], 500).await;
        let mut errors = result.unwrap_err();
        errors.sort();
        assert_eq!(errors, &["g0a0", "g1a0"]);
        assert_elapsed(elapsed, 0, 500);
    }

    #[tokio::test]
    async fn group2_wins_when_groups01_fail_fast() {
        tokio::time::pause();
        let (result, elapsed) = run_groups_timed(
            vec![vec![err_after(0)], vec![err_after(0)], vec![ok_after(0)]],
            500,
        )
        .await;
        assert_eq!(result.unwrap(), "g2a0");
        assert_elapsed(elapsed, 0, 500);
    }

    #[tokio::test]
    async fn group2_wins_when_groups01_are_slow() {
        tokio::time::pause();
        let (result, elapsed) = run_groups_timed(
            vec![
                vec![err_after(10_000)],
                vec![err_after(10_000)],
                vec![ok_after(0)],
            ],
            500,
        )
        .await;
        assert_eq!(result.unwrap(), "g2a0");
        assert_elapsed(elapsed, 1_000, 1_500);
    }

    #[tokio::test]
    async fn delays_are_measured_from_race_start() {
        tokio::time::pause();
        // Group 1 starts at ~0 ms (group 0 fails at once) but is slow; group 2 still starts at
        // 2 * 500 ms after the race began, not 500 ms after group 1 started.
        let (result, elapsed) = run_groups_timed(
            vec![vec![err_after(0)], vec![ok_after(10_000)], vec![ok_after(0)]],
            500,
        )
        .await;
        assert_eq!(result.unwrap(), "g2a0");
        assert_elapsed(elapsed, 1_000, 1_500);
    }

    #[tokio::test]
    async fn group0_wins_multi() {
        tokio::time::pause();
        let (result, elapsed) = run_groups_timed(
            vec![
                vec![ok_after(0), err_after(0)],
                vec![ok_after(0), err_after(0)],
            ],
            500,
        )
        .await;
        assert_eq!(result.unwrap(), "g0a0");
        assert_elapsed(elapsed, 0, 500);
    }

    #[tokio::test]
    async fn group1_wins_when_group0_fails_fast_multi() {
        tokio::time::pause();
        let (result, elapsed) = run_groups_timed(
            vec![
                vec![err_after(0), err_after(0)],
                vec![ok_after(0), err_after(0)],
            ],
            500,
        )
        .await;
        assert_eq!(result.unwrap(), "g1a0");
        assert_elapsed(elapsed, 0, 500);
    }

    #[tokio::test]
    async fn group1_wins_when_group0_is_slow_multi() {
        tokio::time::pause();
        let (result, elapsed) = run_groups_timed(
            vec![
                vec![err_after(10_000), err_after(0)],
                vec![ok_after(0), err_after(0)],
            ],
            500,
        )
        .await;
        assert_eq!(result.unwrap(), "g1a0");
        assert_elapsed(elapsed, 500, 1_000);
    }

    #[tokio::test]
    async fn all_groups_fail_multi() {
        tokio::time::pause();
        let (result, elapsed) = run_groups_timed(
            vec![
                vec![err_after(0), err_after(0)],
                vec![err_after(0), err_after(0)],
            ],
            500,
        )
        .await;
        let mut errors = result.unwrap_err();
        errors.sort();
        assert_eq!(errors, &["g0a0", "g0a1", "g1a0", "g1a1"]);
        assert_elapsed(elapsed, 0, 500);
    }

    #[tokio::test]
    async fn group2_wins_when_groups01_fail_fast_multi() {
        tokio::time::pause();
        let (result, elapsed) = run_groups_timed(
            vec![
                vec![err_after(0), err_after(0)],
                vec![err_after(0), err_after(0)],
                vec![ok_after(0), err_after(0)],
            ],
            500,
        )
        .await;
        assert_eq!(result.unwrap(), "g2a0");
        assert_elapsed(elapsed, 0, 500);
    }

    #[tokio::test]
    async fn group2_wins_when_groups01_are_slow_multi() {
        tokio::time::pause();
        let (result, elapsed) = run_groups_timed(
            vec![
                vec![err_after(10_000), err_after(0)],
                vec![err_after(10_000), err_after(0)],
                vec![ok_after(0), err_after(0)],
            ],
            500,
        )
        .await;
        assert_eq!(result.unwrap(), "g2a0");
        assert_elapsed(elapsed, 1_000, 1_500);
    }
}