use std::fmt::{self, Debug};
use std::future::Future;
use std::hash::BuildHasher;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
mod group;
mod unary;
pub use group::*;
pub use unary::*;
use pin_project::{pin_project, pinned_drop};
use std::collections::HashMap;
use std::hash::Hash;
use std::hash::RandomState;
use tokio::sync::{watch, Mutex};
/// Outcome of a unit of work, broadcast from the leader to followers over a
/// `tokio::sync::watch` channel.
#[derive(Clone)]
enum State<T> {
/// No outcome has been recorded yet. The leader's drop handler only rewrites
/// the state while it is still `Starting`. Presumably this is the value the
/// channel is created with; confirm in the group modules.
Starting,
/// Set by the leader's drop handler when the leader is dropped while the
/// channel still holds `Starting` (e.g. its task was aborted).
LeaderDropped,
/// Published when the leader's fallible work future resolved to `Err`. The
/// error value itself is not forwarded to followers.
LeaderFailed,
/// Published with a clone of the value the leader's work future produced.
Success(T),
}
/// One end of a leader/follower `watch` channel carrying `State<T>`.
///
/// NOTE(review): presumably the group hands the `Sender` variant to the caller
/// that becomes the leader, and the `Receiver` variant to callers that wait on
/// an in-flight leader. Confirm in the group modules, which are not visible here.
enum ChannelHandler<T> {
/// Sending half: used to publish the work outcome.
Sender(watch::Sender<State<T>>),
/// Receiving half: used to observe the published outcome.
Receiver(watch::Receiver<State<T>>),
}
/// Wraps the leader's work future and publishes its outcome on a `watch` channel.
///
/// `Output` appears only in the `F: Future<Output = Output>` bound. It lets the
/// fallible (`Result<T, E>`) and infallible (`T`) `Future` impls for this type
/// be selected separately. A custom `PinnedDrop` impl is registered so that a
/// leader dropped mid-flight can update the shared state.
#[pin_project(PinnedDrop)]
struct Leader<T, F, Output>
where
T: Clone,
F: Future<Output = Output>,
{
/// The work future. It is structurally pinned so it can be polled through `Pin`.
#[pin]
fut: F,
/// Sender used to publish the work's outcome as a `State`.
tx: watch::Sender<State<T>>,
}
impl<T, F, Output> Leader<T, F, Output>
where
T: Clone,
F: Future<Output = Output>,
{
fn new(fut: F, tx: watch::Sender<State<T>>) -> Self {
Self { fut, tx }
}
}
// Drop hook for a leader: if the leader goes away while the channel still
// holds `State::Starting`, the state is rewritten to `State::LeaderDropped` and
// receivers are notified. An already-recorded `Success` or `LeaderFailed` is
// left untouched, and in that case no notification is sent.
#[pinned_drop]
impl<T, F, Output> PinnedDrop for Leader<T, F, Output>
where
    T: Clone,
    F: Future<Output = Output>,
{
    fn drop(self: Pin<&mut Self>) {
        let projected = self.project();
        // The closure's return value tells `send_if_modified` whether to notify.
        let _ = projected.tx.send_if_modified(|state| match *state {
            State::Starting => {
                *state = State::LeaderDropped;
                true
            }
            _ => false,
        });
    }
}
/// Drives a fallible work future and publishes its outcome on the channel.
///
/// When the inner future resolves:
/// - `Ok(v)` publishes `State::Success(v.clone())`;
/// - `Err(_)` publishes `State::LeaderFailed` (the error is not shared).
///
/// The original result is always returned to the leader's own caller unchanged.
impl<T, E, F> Future for Leader<T, F, Result<T, E>>
where
    T: Clone,
    F: Future<Output = Result<T, E>>,
{
    type Output = Result<T, E>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let result = this.fut.poll(cx);
        if let Poll::Ready(val) = &result {
            let outcome = match val {
                Ok(v) => State::Success(v.clone()),
                Err(_) => State::LeaderFailed,
            };
            // `send_replace` stores the outcome even when no receiver is
            // subscribed at this instant. Plain `send` returns an error in that
            // case and leaves the channel at `State::Starting`. The drop hook
            // would then mislabel a finished leader as `LeaderDropped`, and a
            // later subscriber would never observe the real outcome.
            // Currently subscribed receivers are notified exactly as with `send`.
            let _ = this.tx.send_replace(outcome);
        }
        result
    }
}
/// Drives an infallible work future and publishes its value on the channel.
///
/// When the inner future resolves, a clone of its value is published as
/// `State::Success`, and the original value is returned to the leader's caller.
impl<T, F> Future for Leader<T, F, T>
where
    T: Clone + Send + Sync,
    F: Future<Output = T>,
{
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let result = this.fut.poll(cx);
        if let Poll::Ready(val) = &result {
            // `send_replace` stores the value even when no receiver is
            // subscribed at this instant. Plain `send` returns an error in that
            // case and leaves the channel at `State::Starting`. The drop hook
            // would then mislabel a finished leader as `LeaderDropped`, and a
            // later subscriber would never observe the value.
            // Currently subscribed receivers are notified exactly as with `send`.
            let _ = this.tx.send_replace(State::Success(val.clone()));
        }
        result
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::future::join_all;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::oneshot;

    async fn return_res() -> Result<usize, ()> {
        Ok(7)
    }

    async fn expensive_fn<const RES: usize>(delay: u64) -> Result<usize, ()> {
        tokio::time::sleep(Duration::from_millis(delay)).await;
        Ok(RES)
    }

    async fn expensive_unary_fn<const RES: usize>(delay: u64) -> usize {
        tokio::time::sleep(Duration::from_millis(delay)).await;
        RES
    }

    /// Awaits every spawned worker and fails the test if any of them panicked.
    ///
    /// `join_all` yields one `Result<_, JoinError>` per task. Discarding those
    /// results would silently swallow assertion failures raised inside the
    /// spawned tasks, so the test could never fail.
    async fn join_checked<R>(handlers: Vec<tokio::task::JoinHandle<R>>) {
        for res in join_all(handlers).await {
            res.expect("worker task panicked");
        }
    }

    #[tokio::test]
    async fn test_simple() {
        let g = DefaultGroup::new();
        let res = g.work("key", return_res()).await;
        let r = res.unwrap();
        assert_eq!(r, 7);
    }

    #[tokio::test]
    async fn test_multiple_threads() {
        let g = Arc::new(DefaultGroup::new());
        let mut handlers = Vec::with_capacity(10);
        for _ in 0..10 {
            let g = g.clone();
            handlers.push(tokio::spawn(async move {
                let res = g.work("key", expensive_fn::<7>(300)).await;
                let r = res.unwrap();
                assert_eq!(r, 7);
            }));
        }
        join_checked(handlers).await;
    }

    #[tokio::test]
    async fn test_multiple_threads_custom_type() {
        let g = Arc::new(Group::<u64, usize, ()>::new());
        let mut handlers = Vec::with_capacity(10);
        for _ in 0..10 {
            let g = g.clone();
            handlers.push(tokio::spawn(async move {
                let res = g.work(&42, expensive_fn::<8>(300)).await;
                let r = res.unwrap();
                assert_eq!(r, 8);
            }));
        }
        join_checked(handlers).await;
    }

    #[tokio::test]
    async fn test_multiple_threads_unary() {
        let g = Arc::new(UnaryGroup::<u64, usize>::new());
        let mut handlers = Vec::with_capacity(10);
        for _ in 0..10 {
            let g = g.clone();
            handlers.push(tokio::spawn(async move {
                let res = g.work(&42, expensive_unary_fn::<8>(300)).await;
                assert_eq!(res, 8);
            }));
        }
        join_checked(handlers).await;
    }

    #[tokio::test]
    async fn test_drop_leader() {
        let group = Arc::new(DefaultGroup::new());
        let (ready_tx, ready_rx) = oneshot::channel::<()>();
        let leader_owned = group.clone();
        let leader = tokio::spawn(async move {
            let fut = async move {
                let _ = ready_tx.send(());
                tokio::time::sleep(Duration::from_millis(500)).await;
                Ok::<usize, ()>(7)
            };
            let _ = leader_owned.work("key", fut).await;
        });
        let _ = ready_rx.await;
        let follower_owned = group.clone();
        let follower = tokio::spawn(async move {
            follower_owned
                .work("key", async { Ok::<usize, ()>(42) })
                .await
        });
        tokio::task::yield_now().await;
        leader.abort();
        let res = tokio::time::timeout(Duration::from_secs(1), follower)
            .await
            .expect("follower should finish in time")
            .expect("follower task should not panic");
        assert_eq!(res, Ok(42));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_leader_drop_single_new_leader() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use tokio::sync::Barrier;
        const NUM_FOLLOWERS: usize = 5;
        for iteration in 0..200 {
            let group = Arc::new(DefaultGroup::new());
            let execute_count = Arc::new(AtomicUsize::new(0));
            let (leader_ready_tx, leader_ready_rx) = oneshot::channel::<()>();
            let barrier = Arc::new(Barrier::new(NUM_FOLLOWERS + 1));
            let leader_group = group.clone();
            let leader = tokio::spawn(async move {
                let fut = async move {
                    let _ = leader_ready_tx.send(());
                    tokio::time::sleep(Duration::from_secs(60)).await;
                    Ok::<usize, ()>(999)
                };
                let _ = leader_group.work("key", fut).await;
            });
            let _ = leader_ready_rx.await;
            let mut follower_handles = Vec::with_capacity(NUM_FOLLOWERS);
            for _ in 0..NUM_FOLLOWERS {
                let g = group.clone();
                let cnt = execute_count.clone();
                let b = barrier.clone();
                follower_handles.push(tokio::spawn(async move {
                    b.wait().await;
                    g.work("key", async move {
                        cnt.fetch_add(1, Ordering::SeqCst);
                        tokio::task::yield_now().await;
                        Ok::<usize, ()>(42)
                    })
                    .await
                }));
            }
            barrier.wait().await;
            tokio::time::sleep(Duration::from_millis(5)).await;
            leader.abort();
            for handle in follower_handles {
                let res = tokio::time::timeout(Duration::from_secs(5), handle)
                    .await
                    .expect("follower should finish in time")
                    .expect("follower task should not panic");
                assert_eq!(res, Ok(42), "follower should get the correct result");
            }
            let count = execute_count.load(Ordering::SeqCst);
            assert_eq!(
                count, 1,
                "Iteration {}: Expected exactly 1 work execution after leader drop, \
                but got {}. This indicates multiple followers became leaders (issue #12).",
                iteration, count
            );
        }
    }

    #[tokio::test]
    async fn test_drop_leader_no_retry() {
        let group = Arc::new(DefaultGroup::<usize>::new());
        let (ready_tx, ready_rx) = oneshot::channel::<()>();
        let leader_owned = group.clone();
        let leader = tokio::spawn(async move {
            let fut = async move {
                let _ = ready_tx.send(());
                tokio::time::sleep(Duration::from_millis(500)).await;
                Ok::<usize, ()>(7)
            };
            let _ = leader_owned.work("key", fut).await;
        });
        let _ = ready_rx.await;
        let follower_owned = group.clone();
        let follower = tokio::spawn(async move {
            follower_owned
                .work_no_retry("key", async { Ok::<usize, ()>(42) })
                .await
        });
        tokio::task::yield_now().await;
        leader.abort();
        let res = tokio::time::timeout(Duration::from_secs(1), follower)
            .await
            .expect("follower should finish in time")
            .expect("follower task should not panic");
        assert_eq!(res, Err(GroupWorkError::LeaderDropped));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_leader_drop_single_new_leader_unary() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use tokio::sync::Barrier;
        const NUM_FOLLOWERS: usize = 5;
        for iteration in 0..200 {
            let group = Arc::new(DefaultUnaryGroup::new());
            let execute_count = Arc::new(AtomicUsize::new(0));
            let (leader_ready_tx, leader_ready_rx) = oneshot::channel::<()>();
            let barrier = Arc::new(Barrier::new(NUM_FOLLOWERS + 1));
            let leader_group = group.clone();
            let leader = tokio::spawn(async move {
                let fut = async move {
                    let _ = leader_ready_tx.send(());
                    tokio::time::sleep(Duration::from_secs(60)).await;
                    999_usize
                };
                leader_group.work("key", fut).await
            });
            let _ = leader_ready_rx.await;
            let mut follower_handles = Vec::with_capacity(NUM_FOLLOWERS);
            for _ in 0..NUM_FOLLOWERS {
                let g = group.clone();
                let cnt = execute_count.clone();
                let b = barrier.clone();
                follower_handles.push(tokio::spawn(async move {
                    b.wait().await;
                    g.work("key", async move {
                        cnt.fetch_add(1, Ordering::SeqCst);
                        tokio::task::yield_now().await;
                        42_usize
                    })
                    .await
                }));
            }
            barrier.wait().await;
            tokio::time::sleep(Duration::from_millis(5)).await;
            leader.abort();
            for handle in follower_handles {
                let res = tokio::time::timeout(Duration::from_secs(5), handle)
                    .await
                    .expect("follower should finish in time")
                    .expect("follower task should not panic");
                assert_eq!(res, 42, "follower should get the correct result");
            }
            let count = execute_count.load(Ordering::SeqCst);
            assert_eq!(
                count, 1,
                "Iteration {}: Expected exactly 1 work execution after leader drop, \
                but got {}. This indicates multiple followers became leaders (issue #12).",
                iteration, count
            );
        }
    }

    #[tokio::test]
    async fn test_fresh_caller_replaces_stale_entry() {
        let group = Arc::new(DefaultGroup::new());
        let (leader_ready_tx, leader_ready_rx) = oneshot::channel::<()>();
        let leader_group = group.clone();
        let leader = tokio::spawn(async move {
            let _ = leader_group
                .work("key", async move {
                    let _ = leader_ready_tx.send(());
                    tokio::time::sleep(Duration::from_secs(60)).await;
                    Ok::<usize, ()>(999)
                })
                .await;
        });
        let _ = leader_ready_rx.await;
        let follower_group = group.clone();
        let follower = tokio::spawn(async move {
            follower_group
                .work("key", async { Ok::<usize, ()>(42) })
                .await
        });
        tokio::task::yield_now().await;
        leader.abort();
        let res = follower.await.unwrap();
        assert_eq!(res, Ok(42));
        let res = group.work("key", async { Ok::<usize, ()>(99) }).await;
        assert_eq!(res, Ok(99));
    }

    #[tokio::test]
    async fn test_purge_stale() {
        let group = Arc::new(DefaultGroup::new());
        let (leader_ready_tx, leader_ready_rx) = oneshot::channel::<()>();
        let leader_group = group.clone();
        let leader = tokio::spawn(async move {
            let _ = leader_group
                .work("key", async move {
                    let _ = leader_ready_tx.send(());
                    tokio::time::sleep(Duration::from_secs(60)).await;
                    Ok::<usize, ()>(999)
                })
                .await;
        });
        let _ = leader_ready_rx.await;
        let follower_group = group.clone();
        let follower = tokio::spawn(async move {
            follower_group
                .work("key", async { Ok::<usize, ()>(42) })
                .await
        });
        tokio::task::yield_now().await;
        leader.abort();
        let res = follower.await.unwrap();
        assert_eq!(res, Ok(42));
        group.purge_stale().await;
        let res = group.work("key", async { Ok::<usize, ()>(77) }).await;
        assert_eq!(res, Ok(77));
    }

    #[tokio::test]
    async fn test_purge_stale_unary() {
        let group = Arc::new(DefaultUnaryGroup::new());
        let (leader_ready_tx, leader_ready_rx) = oneshot::channel::<()>();
        let leader_group = group.clone();
        let leader = tokio::spawn(async move {
            let fut = async move {
                let _ = leader_ready_tx.send(());
                tokio::time::sleep(Duration::from_secs(60)).await;
                999_usize
            };
            leader_group.work("key", fut).await
        });
        let _ = leader_ready_rx.await;
        let follower_group = group.clone();
        let follower =
            tokio::spawn(async move { follower_group.work("key", async { 42_usize }).await });
        tokio::task::yield_now().await;
        leader.abort();
        let res = follower.await.unwrap();
        assert_eq!(res, 42);
        group.purge_stale().await;
        let res = group.work("key", async { 77_usize }).await;
        assert_eq!(res, 77);
    }
}