use std::any::{Any, TypeId};
use std::collections::{HashMap, HashSet};
use std::hash::Hash;
use std::time::Duration;
use futures_channel::oneshot;
use futures_timer::Delay;
use futures_util::lock::Mutex;
use fnv::FnvHashMap;
#[allow(clippy::type_complexity)]
struct ResSender<K: Send + Hash + Eq + Clone + 'static, T: Loader<K>>(
oneshot::Sender<Result<HashMap<K, T::Value>, T::Error>>,
);
/// Load requests for one key type that have been queued but not yet handed to
/// the loader.
struct Requests<K, T>
where
    K: Send + Hash + Eq + Clone + 'static,
    T: Loader<K>,
{
    /// Union of all keys asked for by the callers recorded in `pending`.
    keys: HashSet<K>,
    /// One entry per waiting caller: the keys that caller asked for, paired
    /// with the channel on which it expects its answer.
    pending: Vec<(HashSet<K>, ResSender<K, T>)>,
}
impl<K: Send + Hash + Eq + Clone + 'static, T: Loader<K>> Default for Requests<K, T> {
fn default() -> Self {
Self {
keys: Default::default(),
pending: Default::default(),
}
}
}
impl<K, T> Requests<K, T>
where
    K: Send + Hash + Eq + Clone + 'static,
    T: Loader<K>,
{
    /// Dispatches the whole batch to `loader` in a single call and answers
    /// every waiting caller.
    ///
    /// On success each caller receives a map holding only the keys it asked
    /// for; keys the loader did not return are simply absent from that map.
    /// On failure each caller receives its own clone of the loader's error.
    /// A caller whose receiver has already gone away is silently skipped.
    async fn load(self, loader: &T) {
        let Self { keys, pending } = self;
        let batch: Vec<K> = keys.into_iter().collect();
        let outcome = loader.load(&batch).await;

        for (wanted, sender) in pending {
            let reply = match &outcome {
                Ok(loaded) => Ok(wanted
                    .iter()
                    .filter_map(|key| loaded.get(key).map(|value| (key.clone(), value.clone())))
                    .collect::<HashMap<_, _>>()),
                Err(err) => Err(err.clone()),
            };
            // A send error only means the receiver was dropped; nothing to do.
            sender.0.send(reply).ok();
        }
    }
}
/// A batch-loading backend for keys of type `K`.
///
/// One type may implement `Loader` for several key types; each implementation
/// chooses its own `Value` and `Error`.
#[async_trait::async_trait]
pub trait Loader<K: Send + Hash + Eq + Clone + 'static>: Send + Sync + 'static {
/// Value produced for a key. It is cloned when handed out to callers.
type Value: Send + Clone + 'static;
/// Error produced by a failed batch. It is cloned once per waiting caller.
type Error: Send + Clone + 'static;
/// Loads the values for `keys` in one batch.
///
/// Returns a map from key to value. A requested key may be left out of the
/// map; callers then see that key as missing rather than as an error.
async fn load(&self, keys: &[K]) -> Result<HashMap<K, Self::Value>, Self::Error>;
}
/// Collects individual load requests into batches and forwards each batch to
/// the wrapped loader `T`.
pub struct DataLoader<T> {
// Pending batches, one per key type, keyed by the key type's `TypeId`.
// Each boxed value is a `Requests<K, T>` for that key type.
requests: Mutex<FnvHashMap<TypeId, Box<dyn Any + Sync + Send>>>,
// How long the caller that opens a new batch waits before dispatching it.
delay: Duration,
// Pending-key count at which a batch is dispatched without waiting.
max_batch_size: usize,
// The user-supplied loader that performs the actual batch loads.
loader: T,
}
impl<T> DataLoader<T> {
    /// Creates a data loader around `loader` with a batching delay of 1 ms and
    /// a maximum batch size of 1000 keys.
    pub fn new(loader: T) -> Self {
        Self {
            requests: Default::default(),
            delay: Duration::from_millis(1),
            max_batch_size: 1000,
            loader,
        }
    }

    /// Sets how long the caller that opens a new batch waits for further
    /// requests before dispatching it.
    pub fn delay(self, delay: Duration) -> Self {
        Self { delay, ..self }
    }

    /// Sets the number of pending keys at which a batch is dispatched
    /// immediately instead of waiting for the delay to elapse.
    pub fn max_batch_size(self, max_batch_size: usize) -> Self {
        Self {
            max_batch_size,
            ..self
        }
    }

    /// Returns a reference to the wrapped loader.
    #[inline]
    pub fn loader(&self) -> &T {
        &self.loader
    }

    /// Loads the value for a single key through the batching machinery.
    ///
    /// Returns `Ok(None)` when the loader's result does not contain `key`.
    ///
    /// # Errors
    ///
    /// Returns the loader's error if the batch containing `key` failed.
    pub async fn load_one<K>(&self, key: K) -> Result<Option<T::Value>, T::Error>
    where
        K: Send + Sync + Hash + Eq + Clone + 'static,
        T: Loader<K>,
    {
        let mut values = self.load_many(std::iter::once(key.clone())).await?;
        Ok(values.remove(&key))
    }

    /// Loads the values for several keys, batching them with concurrent
    /// requests for the same key type.
    ///
    /// Duplicate keys are collapsed. The returned map contains only the keys
    /// the loader produced a value for. An empty `keys` iterator yields an
    /// empty map immediately, without touching the loader.
    ///
    /// # Errors
    ///
    /// Returns the loader's error if the batch containing these keys failed.
    ///
    /// # Panics
    ///
    /// Panics if the batch these keys joined is dropped before it delivers a
    /// result (for example, if the future driving that batch is dropped while
    /// the loader is still running).
    pub async fn load_many<K>(
        &self,
        keys: impl Iterator<Item = K>,
    ) -> Result<HashMap<K, T::Value>, T::Error>
    where
        K: Send + Sync + Hash + Eq + Clone + 'static,
        T: Loader<K>,
    {
        // Drain the caller's iterator before taking the lock so arbitrary
        // iterator code never runs while other callers are blocked.
        let keys = keys.collect::<HashSet<_>>();

        // A request without keys must not be queued: it would never trigger a
        // fetch on its own, so the caller could wait forever (or have its
        // sender dropped by an unrelated flush of a key-less batch).
        if keys.is_empty() {
            return Ok(HashMap::new());
        }

        let tid = TypeId::of::<K>();
        let (start_fetch, rx) = {
            let mut requests = self.requests.lock().await;
            let typed_requests = requests
                .entry(tid)
                .or_insert_with(|| Box::new(Requests::<K, T>::default()))
                .downcast_mut::<Requests<K, T>>()
                .expect("entry for TypeId::of::<K>() always holds a Requests<K, T>");
            let prev_count = typed_requests.keys.len();
            typed_requests.keys.extend(keys.iter().cloned());
            let (tx, rx) = oneshot::channel();
            typed_requests.pending.push((keys, ResSender(tx)));
            if typed_requests.keys.len() >= self.max_batch_size {
                // The batch is full: take it and dispatch right away, releasing
                // the lock first so other callers can start a fresh batch.
                let r = std::mem::take(&mut *typed_requests);
                drop(requests);
                r.load(&self.loader).await;
                (false, rx)
            } else {
                // The caller that adds the first keys to an empty batch is
                // responsible for dispatching it once the delay has elapsed.
                (prev_count == 0, rx)
            }
        };
        if start_fetch {
            Delay::new(self.delay).await;
            let mut requests = self.requests.lock().await;
            let typed_requests = requests
                .get_mut(&tid)
                .expect("entry was inserted before the delay and is never removed")
                .downcast_mut::<Requests<K, T>>()
                .expect("entry for TypeId::of::<K>() always holds a Requests<K, T>");
            let typed_requests = std::mem::take(typed_requests);
            drop(requests);
            // The batch may already have been flushed by hitting the size limit.
            if !typed_requests.keys.is_empty() {
                typed_requests.load(&self.loader).await;
            }
        }
        rx.await
            .expect("batch was dropped before delivering its result")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Fixture loader that maps every key to itself and asserts that no batch
    /// it receives holds more than ten keys.
    struct MyLoader;

    #[async_trait::async_trait]
    impl Loader<i32> for MyLoader {
        type Value = i32;
        type Error = ();

        async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
            assert!(keys.len() <= 10);
            Ok(keys.iter().map(|&k| (k, k)).collect())
        }
    }

    #[async_trait::async_trait]
    impl Loader<i64> for MyLoader {
        type Value = i64;
        type Error = ();

        async fn load(&self, keys: &[i64]) -> Result<HashMap<i64, Self::Value>, Self::Error> {
            assert!(keys.len() <= 10);
            Ok(keys.iter().map(|&k| (k, k)).collect())
        }
    }

    /// A hundred concurrent single-key loads per key type all resolve, in
    /// order, to their own key.
    #[tokio::test]
    async fn test_dataloader() {
        let loader = Arc::new(DataLoader::new(MyLoader).max_batch_size(10));

        let i32_tasks = (0..100i32).map(|n| {
            let loader = Arc::clone(&loader);
            async move { loader.load_one(n).await }
        });
        let i32_values = futures_util::future::try_join_all(i32_tasks)
            .await
            .unwrap();
        assert_eq!(i32_values, (0..100).map(Option::Some).collect::<Vec<_>>());

        let i64_tasks = (0..100i64).map(|n| {
            let loader = Arc::clone(&loader);
            async move { loader.load_one(n).await }
        });
        let i64_values = futures_util::future::try_join_all(i64_tasks)
            .await
            .unwrap();
        assert_eq!(i64_values, (0..100).map(Option::Some).collect::<Vec<_>>());
    }

    /// Concurrent loads that repeat some keys each still get their own answer.
    #[tokio::test]
    async fn test_duplicate_keys() {
        let loader = Arc::new(DataLoader::new(MyLoader).max_batch_size(10));
        let keys: [i32; 8] = [1, 3, 5, 1, 7, 8, 3, 7];

        let tasks = keys.iter().copied().map(|n| {
            let loader = Arc::clone(&loader);
            async move { loader.load_one(n).await }
        });
        let values = futures_util::future::try_join_all(tasks).await.unwrap();

        let expected = keys.iter().copied().map(Option::Some).collect::<Vec<_>>();
        assert_eq!(values, expected);
    }
}