1use std::fmt::{self, Debug};
40use std::future::Future;
41use std::marker::PhantomData;
42use std::pin::Pin;
43use std::task::{Context, Poll};
44
45use futures::future::BoxFuture;
46use hashbrown::HashMap;
47use parking_lot::Mutex;
48use pin_project::{pin_project, pinned_drop};
49use tokio::sync::watch;
50
51pub struct Group<T, E>
54where
55 T: Clone,
56{
57 m: Mutex<HashMap<String, watch::Receiver<State<T>>>>,
58 _marker: PhantomData<fn(E)>,
59}
60
61impl<T, E> Debug for Group<T, E>
62where
63 T: Clone,
64{
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 f.debug_struct("Group").finish()
67 }
68}
69
70impl<T, E> Default for Group<T, E>
71where
72 T: Clone,
73{
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
/// Progress of one owner's run, broadcast to waiting callers through a `watch` channel.
#[derive(Clone)]
enum State<T>
where
    T: Clone,
{
    /// The owner is still driving its future.
    Starting,
    /// The owner was dropped before its future completed.
    LeaderDropped,
    /// The owner finished; `None` means its future resolved to an error.
    Done(Option<T>),
}
85
86impl<T, E> Group<T, E>
87where
88 T: Clone,
89{
90 #[must_use]
92 pub fn new() -> Group<T, E> {
93 Self {
94 m: Mutex::new(HashMap::new()),
95 _marker: PhantomData,
96 }
97 }
98
99 pub async fn work(
105 &self,
106 key: &str,
107 fut: impl Future<Output = Result<T, E>>,
108 ) -> (Option<T>, Option<E>, bool) {
109 use hashbrown::hash_map::EntryRef;
110
111 let tx_or_rx = match self.m.lock().entry_ref(key) {
112 EntryRef::Occupied(mut entry) => {
113 let state = entry.get().borrow().clone();
114 match state {
115 State::Starting => Err(entry.get().clone()),
116 State::LeaderDropped => {
117 let (tx, rx) = watch::channel(State::Starting);
119 entry.insert(rx);
120 Ok(tx)
121 }
122 State::Done(val) => return (val, None, false),
123 }
124 }
125 EntryRef::Vacant(entry) => {
126 let (tx, rx) = watch::channel(State::Starting);
127 entry.insert(rx);
128 Ok(tx)
129 }
130 };
131
132 match tx_or_rx {
133 Ok(tx) => {
134 let fut = Leader { fut, tx };
135 let result = fut.await;
136 self.m.lock().remove(key);
137 match result {
138 Ok(val) => (Some(val), None, true),
139 Err(err) => (None, Some(err), true),
140 }
141 }
142 Err(mut rx) => {
143 let mut state = rx.borrow_and_update().clone();
144 if matches!(state, State::Starting) {
145 let _changed = rx.changed().await;
146 state = rx.borrow().clone();
147 }
148 match state {
149 State::Starting => (None, None, false), State::LeaderDropped => {
151 self.m.lock().remove(key);
152 (None, None, false)
153 }
154 State::Done(val) => (val, None, false),
155 }
156 }
157 }
158 }
159}
160
161#[pin_project(PinnedDrop)]
162struct Leader<T: Clone, F> {
163 #[pin]
164 fut: F,
165 tx: watch::Sender<State<T>>,
166}
167
168impl<T, E, F> Future for Leader<T, F>
169where
170 T: Clone,
171 F: Future<Output = Result<T, E>>,
172{
173 type Output = Result<T, E>;
174
175 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
176 let this = self.project();
177 let result = this.fut.poll(cx);
178 if let Poll::Ready(val) = &result {
179 let _send = this.tx.send(State::Done(val.as_ref().ok().cloned()));
180 }
181 result
182 }
183}
184
185#[pinned_drop]
186impl<T, F> PinnedDrop for Leader<T, F>
187where
188 T: Clone,
189{
190 fn drop(self: Pin<&mut Self>) {
191 let this = self.project();
192 let _ = this.tx.send_if_modified(|s| {
193 if matches!(s, State::Starting) {
194 *s = State::LeaderDropped;
195 true
196 } else {
197 false
198 }
199 });
200 }
201}
202
203pub struct UnaryGroup<T>
206where
207 T: Clone,
208{
209 m: Mutex<HashMap<String, watch::Receiver<UnaryState<T>>>>,
210}
211
212impl<T> Debug for UnaryGroup<T>
213where
214 T: Clone,
215{
216 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217 f.debug_struct("UnaryGroup").finish()
218 }
219}
220
221impl<T> Default for UnaryGroup<T>
222where
223 T: Clone + Send + Sync,
224{
225 fn default() -> Self {
226 Self::new()
227 }
228}
229
/// Progress of one owner's infallible run, broadcast to waiting callers via `watch`.
#[derive(Clone)]
enum UnaryState<T>
where
    T: Clone,
{
    /// The owner is still driving its future.
    Starting,
    /// The owner was dropped before its future completed.
    LeaderDropped,
    /// The owner finished with this value.
    Done(T),
}
236
237impl<T> UnaryGroup<T>
238where
239 T: Clone + Send + Sync,
240{
241 #[must_use]
243 pub fn new() -> UnaryGroup<T> {
244 Self {
245 m: Mutex::new(HashMap::new()),
246 }
247 }
248
249 pub fn work<'s>(
255 &'s self,
256 key: &'s str,
257 fut: impl Future<Output = T> + Send + 's,
258 ) -> BoxFuture<'s, (T, bool)> {
259 use hashbrown::hash_map::EntryRef;
260 Box::pin(async move {
261 let tx_or_rx = match self.m.lock().entry_ref(key) {
262 EntryRef::Occupied(mut entry) => {
263 let state = entry.get().borrow().clone();
264 match state {
265 UnaryState::Starting => Err(entry.get().clone()),
266 UnaryState::LeaderDropped => {
267 let (tx, rx) = watch::channel(UnaryState::Starting);
269 entry.insert(rx);
270 Ok(tx)
271 }
272 UnaryState::Done(val) => return (val, false),
273 }
274 }
275 EntryRef::Vacant(entry) => {
276 let (tx, rx) = watch::channel(UnaryState::Starting);
277 entry.insert(rx);
278 Ok(tx)
279 }
280 };
281
282 match tx_or_rx {
283 Ok(tx) => {
284 let fut = UnaryLeader { fut, tx };
285 let result = fut.await;
286 self.m.lock().remove(key);
287 (result, true)
288 }
289 Err(mut rx) => {
290 let mut state = rx.borrow_and_update().clone();
291 if matches!(state, UnaryState::Starting) {
292 let _changed = rx.changed().await;
293 state = rx.borrow().clone();
294 }
295 match state {
296 UnaryState::Starting => unreachable!(), UnaryState::LeaderDropped => {
298 self.m.lock().remove(key);
299 self.work(key, fut).await
301 }
302 UnaryState::Done(val) => (val, false),
303 }
304 }
305 }
306 })
307 }
308}
309
310#[pin_project(PinnedDrop)]
311struct UnaryLeader<T: Clone, F> {
312 #[pin]
313 fut: F,
314 tx: watch::Sender<UnaryState<T>>,
315}
316
317impl<T, F> Future for UnaryLeader<T, F>
318where
319 T: Clone + Send + Sync,
320 F: Future<Output = T>,
321{
322 type Output = T;
323
324 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
325 let this = self.project();
326 let result = this.fut.poll(cx);
327 if let Poll::Ready(val) = &result {
328 let _send = this.tx.send(UnaryState::Done(val.clone()));
329 }
330 result
331 }
332}
333
334#[pinned_drop]
335impl<T, F> PinnedDrop for UnaryLeader<T, F>
336where
337 T: Clone,
338{
339 fn drop(self: Pin<&mut Self>) {
340 let this = self.project();
341 let _ = this.tx.send_if_modified(|s| {
342 if matches!(s, UnaryState::Starting) {
343 *s = UnaryState::LeaderDropped;
344 true
345 } else {
346 false
347 }
348 });
349 }
350}
351
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use futures::future::join_all;

    use super::{Group, UnaryGroup};

    const RES: usize = 7;

    async fn return_res() -> Result<usize, ()> {
        Ok(RES)
    }

    async fn expensive_fn() -> Result<usize, ()> {
        tokio::time::sleep(Duration::from_millis(500)).await;
        Ok(RES)
    }

    #[tokio::test]
    async fn test_simple() {
        let g = Group::new();
        let res = g.work("key", return_res()).await.0;
        assert_eq!(res, Some(RES));
    }

    #[tokio::test]
    async fn test_multiple_threads() {
        let g = Arc::new(Group::new());
        let mut handlers = Vec::new();
        for _ in 0..10 {
            let g = Arc::clone(&g);
            handlers.push(tokio::spawn(async move {
                let (res, err, owner) = g.work("key", expensive_fn()).await;
                assert_eq!(res, Some(RES));
                assert!(err.is_none());
                owner
            }));
        }

        // A panicking task shows up as `Err(JoinError)`; unwrap it so failed assertions
        // inside the spawned tasks fail the test instead of being silently discarded.
        let owners = join_all(handlers)
            .await
            .into_iter()
            .map(|joined| joined.expect("worker task panicked"))
            .filter(|&owner| owner)
            .count();
        assert_eq!(owners, 1, "exactly one caller should run the work");
    }

    #[tokio::test]
    async fn test_drop_leader() {
        let g = Group::new();
        {
            tokio::time::timeout(Duration::from_millis(50), g.work("key", expensive_fn()))
                .await
                .expect_err("owner should be running and cancelled");
        }
        assert_eq!(
            tokio::time::timeout(Duration::from_secs(1), g.work("key", expensive_fn())).await,
            Ok((Some(RES), None, true)),
        );
    }

    #[tokio::test]
    async fn test_unary_simple() {
        let g = UnaryGroup::new();
        assert_eq!(g.work("key", async { RES }).await, (RES, true));
    }

    #[tokio::test]
    async fn test_unary_follower_takes_over_dropped_leader() {
        let g = UnaryGroup::new();
        let (leader, follower) = tokio::join!(
            tokio::time::timeout(
                Duration::from_millis(50),
                g.work("key", async {
                    tokio::time::sleep(Duration::from_millis(500)).await;
                    0
                }),
            ),
            async {
                // Let the first call register as owner before joining as a follower.
                tokio::time::sleep(Duration::from_millis(10)).await;
                g.work("key", async { RES }).await
            },
        );
        assert!(leader.is_err(), "owner should be cancelled by the timeout");
        assert_eq!(follower, (RES, true));
    }
}