1use std::fmt::{self, Debug};
40use std::future::Future;
41use std::hash::BuildHasher;
42use std::marker::PhantomData;
43use std::pin::Pin;
44use std::task::{Context, Poll};
45
46mod group;
47mod unary;
48
49pub use group::*;
50pub use unary::*;
51
52use pin_project::{pin_project, pinned_drop};
53use std::collections::HashMap;
54use std::hash::Hash;
55use std::hash::RandomState;
56use tokio::sync::{watch, Mutex};
57
/// Progress of one in-flight computation, broadcast from the leader to
/// waiting callers through a `watch` channel.
#[derive(Clone)]
enum State<T> {
    /// Initial state: the leader has not published an outcome yet.
    Starting,
    /// The leader was dropped while the state was still `Starting`.
    LeaderDropped,
    /// The leader's fallible future finished with `Err`; the error value
    /// itself is not shared.
    LeaderFailed,
    /// The leader's future finished successfully with this value.
    Success(T),
}
65
66enum ChannelHandler<T> {
67 Sender(watch::Sender<State<T>>),
68 Receiver(watch::Receiver<State<T>>),
69}
70
/// Wraps the future run by the leader for a key and reports its progress on
/// `tx`.
///
/// When `fut` completes, the `Future` impls publish the outcome on `tx`. If
/// the wrapper is dropped while the state is still `State::Starting`, the
/// `PinnedDrop` impl switches it to `State::LeaderDropped`.
///
/// The `Output` parameter always equals `F::Output`. It appears to exist so
/// that the fallible (`Result<T, E>`) and infallible (`T`) `Future` impls can
/// coexist without overlapping.
#[pin_project(PinnedDrop)]
struct Leader<T, F, Output>
where
    T: Clone,
    F: Future<Output = Output>,
{
    /// The wrapped work; structurally pinned.
    #[pin]
    fut: F,
    /// Channel on which the leader's state is published.
    tx: watch::Sender<State<T>>,
}
81
82impl<T, F, Output> Leader<T, F, Output>
83where
84 T: Clone,
85 F: Future<Output = Output>,
86{
87 fn new(fut: F, tx: watch::Sender<State<T>>) -> Self {
88 Self { fut, tx }
89 }
90}
91
92#[pinned_drop]
93impl<T, F, Output> PinnedDrop for Leader<T, F, Output>
94where
95 T: Clone,
96 F: Future<Output = Output>,
97{
98 fn drop(self: Pin<&mut Self>) {
99 let this = self.project();
100 let _ = this.tx.send_if_modified(|s| {
101 if matches!(s, State::Starting) {
102 *s = State::LeaderDropped;
103 true
104 } else {
105 false
106 }
107 });
108 }
109}
110
111impl<T, E, F> Future for Leader<T, F, Result<T, E>>
112where
113 T: Clone,
114 F: Future<Output = Result<T, E>>,
115{
116 type Output = Result<T, E>;
117
118 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
119 let this = self.project();
120 let result = this.fut.poll(cx);
121 if let Poll::Ready(val) = &result {
122 let _send = match val {
123 Ok(v) => this.tx.send(State::Success(v.clone())),
124 Err(_) => this.tx.send(State::LeaderFailed),
125 };
126 }
127 result
128 }
129}
130
131impl<T, F> Future for Leader<T, F, T>
132where
133 T: Clone + Send + Sync,
134 F: Future<Output = T>,
135{
136 type Output = T;
137
138 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
139 let this = self.project();
140 let result = this.fut.poll(cx);
141 if let Poll::Ready(val) = &result {
142 let _send = this.tx.send(State::Success(val.clone()));
143 }
144 result
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use futures::future::join_all;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::oneshot;
    use tokio::task::JoinHandle;

    async fn return_res() -> Result<usize, ()> {
        Ok(7)
    }

    /// Sleeps for `delay` milliseconds, then succeeds with `RES`.
    async fn expensive_fn<const RES: usize>(delay: u64) -> Result<usize, ()> {
        tokio::time::sleep(Duration::from_millis(delay)).await;
        Ok(RES)
    }

    /// Sleeps for `delay` milliseconds, then returns `RES`.
    async fn expensive_unary_fn<const RES: usize>(delay: u64) -> usize {
        tokio::time::sleep(Duration::from_millis(delay)).await;
        RES
    }

    /// Awaits every handle and returns the task outputs in order.
    ///
    /// A panic inside a spawned task (such as a failed assertion) only shows
    /// up as a `JoinError`. Unwrapping it here makes the test fail instead of
    /// silently discarding the panic.
    async fn join_all_checked<T>(handles: Vec<JoinHandle<T>>) -> Vec<T> {
        join_all(handles)
            .await
            .into_iter()
            .map(|joined| joined.expect("spawned task should not panic"))
            .collect()
    }

    #[tokio::test]
    async fn test_simple() {
        let g = DefaultGroup::new();
        let res = g.work("key", return_res()).await;
        let r = res.unwrap();
        assert_eq!(r, 7);
    }

    /// Ten concurrent callers on the same key all observe the value 7.
    #[tokio::test]
    async fn test_multiple_threads() {
        let g = Arc::new(DefaultGroup::new());
        let mut handlers = Vec::with_capacity(10);
        for _ in 0..10 {
            let g = g.clone();
            handlers.push(tokio::spawn(async move {
                g.work("key", expensive_fn::<7>(300)).await.unwrap()
            }));
        }

        for r in join_all_checked(handlers).await {
            assert_eq!(r, 7);
        }
    }

    /// Same as `test_multiple_threads`, with an explicit `u64` key type.
    #[tokio::test]
    async fn test_multiple_threads_custom_type() {
        let g = Arc::new(Group::<u64, usize, ()>::new());
        let mut handlers = Vec::with_capacity(10);
        for _ in 0..10 {
            let g = g.clone();
            handlers.push(tokio::spawn(async move {
                g.work(&42, expensive_fn::<8>(300)).await.unwrap()
            }));
        }

        for r in join_all_checked(handlers).await {
            assert_eq!(r, 8);
        }
    }

    /// Ten concurrent callers of the infallible group all observe the value 8.
    #[tokio::test]
    async fn test_multiple_threads_unary() {
        let g = Arc::new(UnaryGroup::<u64, usize>::new());
        let mut handlers = Vec::with_capacity(10);
        for _ in 0..10 {
            let g = g.clone();
            handlers.push(tokio::spawn(async move {
                g.work(&42, expensive_unary_fn::<8>(300)).await
            }));
        }

        for res in join_all_checked(handlers).await {
            assert_eq!(res, 8);
        }
    }

    /// A follower whose leader is aborted runs its own future and gets 42.
    #[tokio::test]
    async fn test_drop_leader() {
        let group = Arc::new(DefaultGroup::new());

        // Fired once the leader's future has started running.
        let (ready_tx, ready_rx) = oneshot::channel::<()>();

        let leader_owned = group.clone();
        let leader = tokio::spawn(async move {
            let fut = async move {
                let _ = ready_tx.send(());
                tokio::time::sleep(Duration::from_millis(500)).await;
                Ok::<usize, ()>(7)
            };
            let _ = leader_owned.work("key", fut).await;
        });

        let _ = ready_rx.await;

        let follower_owned = group.clone();
        let follower = tokio::spawn(async move {
            follower_owned
                .work("key", async { Ok::<usize, ()>(42) })
                .await
        });

        tokio::task::yield_now().await;

        leader.abort();

        let res = tokio::time::timeout(Duration::from_secs(1), follower)
            .await
            .expect("follower should finish in time")
            .expect("follower task should not panic");

        assert_eq!(res, Ok(42));
    }

    /// After the leader is aborted, exactly one waiting follower runs its
    /// future and every follower gets 42. Repeated 200 times to expose races.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_leader_drop_single_new_leader() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use tokio::sync::Barrier;

        const NUM_FOLLOWERS: usize = 5;

        for iteration in 0..200 {
            let group = Arc::new(DefaultGroup::new());

            // Counts how many follower futures actually ran.
            let execute_count = Arc::new(AtomicUsize::new(0));

            // Fired once the leader's future has started running.
            let (leader_ready_tx, leader_ready_rx) = oneshot::channel::<()>();

            // Releases all followers and this task at the same moment.
            let barrier = Arc::new(Barrier::new(NUM_FOLLOWERS + 1));

            let leader_group = group.clone();
            let leader = tokio::spawn(async move {
                let fut = async move {
                    let _ = leader_ready_tx.send(());
                    tokio::time::sleep(Duration::from_secs(60)).await;
                    Ok::<usize, ()>(999)
                };
                let _ = leader_group.work("key", fut).await;
            });

            let _ = leader_ready_rx.await;

            let mut follower_handles = Vec::with_capacity(NUM_FOLLOWERS);

            for _ in 0..NUM_FOLLOWERS {
                let g = group.clone();
                let cnt = execute_count.clone();
                let b = barrier.clone();
                follower_handles.push(tokio::spawn(async move {
                    b.wait().await;

                    g.work("key", async move {
                        cnt.fetch_add(1, Ordering::SeqCst);
                        tokio::task::yield_now().await;
                        Ok::<usize, ()>(42)
                    })
                    .await
                }));
            }

            barrier.wait().await;

            // Presumably gives the followers time to start waiting on the
            // leader before it is aborted.
            tokio::time::sleep(Duration::from_millis(5)).await;

            leader.abort();

            for handle in follower_handles {
                let res = tokio::time::timeout(Duration::from_secs(5), handle)
                    .await
                    .expect("follower should finish in time")
                    .expect("follower task should not panic");
                assert_eq!(res, Ok(42), "follower should get the correct result");
            }

            let count = execute_count.load(Ordering::SeqCst);
            assert_eq!(
                count, 1,
                "Iteration {}: Expected exactly 1 work execution after leader drop, \
                 but got {}. This indicates multiple followers became leaders (issue #12).",
                iteration, count
            );
        }
    }

    /// With `work_no_retry`, a follower whose leader is aborted gets
    /// `GroupWorkError::LeaderDropped` instead of running its own future.
    #[tokio::test]
    async fn test_drop_leader_no_retry() {
        let group = Arc::new(DefaultGroup::<usize>::new());

        // Fired once the leader's future has started running.
        let (ready_tx, ready_rx) = oneshot::channel::<()>();

        let leader_owned = group.clone();
        let leader = tokio::spawn(async move {
            let fut = async move {
                let _ = ready_tx.send(());
                tokio::time::sleep(Duration::from_millis(500)).await;
                Ok::<usize, ()>(7)
            };
            let _ = leader_owned.work("key", fut).await;
        });

        let _ = ready_rx.await;

        let follower_owned = group.clone();
        let follower = tokio::spawn(async move {
            follower_owned
                .work_no_retry("key", async { Ok::<usize, ()>(42) })
                .await
        });

        tokio::task::yield_now().await;

        leader.abort();

        let res = tokio::time::timeout(Duration::from_secs(1), follower)
            .await
            .expect("follower should finish in time")
            .expect("follower task should not panic");

        assert_eq!(res, Err(GroupWorkError::LeaderDropped));
    }

    /// Unary counterpart of `test_leader_drop_single_new_leader`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_leader_drop_single_new_leader_unary() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use tokio::sync::Barrier;

        const NUM_FOLLOWERS: usize = 5;

        for iteration in 0..200 {
            let group = Arc::new(DefaultUnaryGroup::new());
            let execute_count = Arc::new(AtomicUsize::new(0));
            let (leader_ready_tx, leader_ready_rx) = oneshot::channel::<()>();
            let barrier = Arc::new(Barrier::new(NUM_FOLLOWERS + 1));

            let leader_group = group.clone();
            let leader = tokio::spawn(async move {
                let fut = async move {
                    let _ = leader_ready_tx.send(());
                    tokio::time::sleep(Duration::from_secs(60)).await;
                    999_usize
                };
                leader_group.work("key", fut).await
            });

            let _ = leader_ready_rx.await;

            let mut follower_handles = Vec::with_capacity(NUM_FOLLOWERS);
            for _ in 0..NUM_FOLLOWERS {
                let g = group.clone();
                let cnt = execute_count.clone();
                let b = barrier.clone();
                follower_handles.push(tokio::spawn(async move {
                    b.wait().await;
                    g.work("key", async move {
                        cnt.fetch_add(1, Ordering::SeqCst);
                        tokio::task::yield_now().await;
                        42_usize
                    })
                    .await
                }));
            }

            barrier.wait().await;
            tokio::time::sleep(Duration::from_millis(5)).await;
            leader.abort();

            for handle in follower_handles {
                let res = tokio::time::timeout(Duration::from_secs(5), handle)
                    .await
                    .expect("follower should finish in time")
                    .expect("follower task should not panic");
                assert_eq!(res, 42, "follower should get the correct result");
            }

            let count = execute_count.load(Ordering::SeqCst);
            assert_eq!(
                count, 1,
                "Iteration {}: Expected exactly 1 work execution after leader drop, \
                 but got {}. This indicates multiple followers became leaders (issue #12).",
                iteration, count
            );
        }
    }

    /// After an aborted leader and a follower takeover, a later caller on the
    /// same key runs its own future and gets 99.
    #[tokio::test]
    async fn test_fresh_caller_replaces_stale_entry() {
        let group = Arc::new(DefaultGroup::new());

        let (leader_ready_tx, leader_ready_rx) = oneshot::channel::<()>();
        let leader_group = group.clone();
        let leader = tokio::spawn(async move {
            let _ = leader_group
                .work("key", async move {
                    let _ = leader_ready_tx.send(());
                    tokio::time::sleep(Duration::from_secs(60)).await;
                    Ok::<usize, ()>(999)
                })
                .await;
        });
        let _ = leader_ready_rx.await;

        let follower_group = group.clone();
        let follower = tokio::spawn(async move {
            follower_group
                .work("key", async { Ok::<usize, ()>(42) })
                .await
        });
        tokio::task::yield_now().await;

        leader.abort();
        let res = tokio::time::timeout(Duration::from_secs(1), follower)
            .await
            .expect("follower should finish in time")
            .expect("follower task should not panic");
        assert_eq!(res, Ok(42));

        let res = group.work("key", async { Ok::<usize, ()>(99) }).await;
        assert_eq!(res, Ok(99));
    }

    /// `purge_stale` after an aborted leader leaves the key usable: a new
    /// caller runs its own future and gets 77.
    #[tokio::test]
    async fn test_purge_stale() {
        let group = Arc::new(DefaultGroup::new());

        let (leader_ready_tx, leader_ready_rx) = oneshot::channel::<()>();
        let leader_group = group.clone();
        let leader = tokio::spawn(async move {
            let _ = leader_group
                .work("key", async move {
                    let _ = leader_ready_tx.send(());
                    tokio::time::sleep(Duration::from_secs(60)).await;
                    Ok::<usize, ()>(999)
                })
                .await;
        });
        let _ = leader_ready_rx.await;

        let follower_group = group.clone();
        let follower = tokio::spawn(async move {
            follower_group
                .work("key", async { Ok::<usize, ()>(42) })
                .await
        });
        tokio::task::yield_now().await;

        leader.abort();
        let res = tokio::time::timeout(Duration::from_secs(1), follower)
            .await
            .expect("follower should finish in time")
            .expect("follower task should not panic");
        assert_eq!(res, Ok(42));

        group.purge_stale().await;

        let res = group.work("key", async { Ok::<usize, ()>(77) }).await;
        assert_eq!(res, Ok(77));
    }

    /// Unary counterpart of `test_purge_stale`.
    #[tokio::test]
    async fn test_purge_stale_unary() {
        let group = Arc::new(DefaultUnaryGroup::new());

        let (leader_ready_tx, leader_ready_rx) = oneshot::channel::<()>();
        let leader_group = group.clone();
        let leader = tokio::spawn(async move {
            let fut = async move {
                let _ = leader_ready_tx.send(());
                tokio::time::sleep(Duration::from_secs(60)).await;
                999_usize
            };
            leader_group.work("key", fut).await
        });
        let _ = leader_ready_rx.await;

        let follower_group = group.clone();
        let follower =
            tokio::spawn(async move { follower_group.work("key", async { 42_usize }).await });
        tokio::task::yield_now().await;

        leader.abort();
        let res = tokio::time::timeout(Duration::from_secs(1), follower)
            .await
            .expect("follower should finish in time")
            .expect("follower task should not panic");
        assert_eq!(res, 42);

        group.purge_stale().await;

        let res = group.work("key", async { 77_usize }).await;
        assert_eq!(res, 77);
    }
}