atrium_common/resolver/throttled.rs

1use std::{hash::Hash, sync::Arc};
2
3use dashmap::{DashMap, Entry};
4use tokio::sync::broadcast::{channel, Sender};
5use tokio::sync::Mutex;
6
7use crate::types::throttled::Throttled;
8
9use super::Resolver;
10
11pub type SenderMap<R> =
12    DashMap<<R as Resolver>::Input, Arc<Mutex<Sender<Option<<R as Resolver>::Output>>>>>;
13
14pub type ThrottledResolver<R> = Throttled<R, SenderMap<R>>;
15
16impl<R> Resolver for Throttled<R, SenderMap<R>>
17where
18    R: Resolver + Send + Sync + 'static,
19    R::Input: Clone + Hash + Eq + Send + Sync + 'static,
20    R::Output: Clone + Send + Sync + 'static,
21{
22    type Input = R::Input;
23    type Output = Option<R::Output>;
24    type Error = R::Error;
25
26    async fn resolve(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
27        match self.pending.entry(input.clone()) {
28            Entry::Occupied(occupied) => {
29                let tx = occupied.get().lock().await.clone();
30                drop(occupied);
31                Ok(tx.subscribe().recv().await.expect("recv"))
32            }
33            Entry::Vacant(vacant) => {
34                let (tx, _) = channel(1);
35                vacant.insert(Arc::new(Mutex::new(tx.clone())));
36                let result = self.inner.resolve(input).await;
37                tx.send(result.as_ref().ok().cloned()).ok();
38                self.pending.remove(input);
39                result.map(Some)
40            }
41        }
42    }
43}