1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
use crate::{
  loader::{DataLoader, LocalLoader},
  task::{CompletionReceipt, PendingAssignment, Task, TaskAssignment, TaskHandler},
  Key,
};
use std::{collections::HashMap, sync::Arc};

/// A batch-oriented data source: a simpler alternative to implementing
/// [`TaskHandler`] directly.
///
/// Implementors supply the associated types and a single [`BatchLoader::load`]
/// function; wrapping the implementor in [`BatchHandler`] adapts it into a
/// `TaskHandler`.
#[async_trait::async_trait]
pub trait BatchLoader: Sized + Send + Sync + 'static {
  /// Forwarded unchanged as `TaskHandler::CORES_PER_WORKER_GROUP` by
  /// `BatchHandler`. Defaults to `4`.
  const CORES_PER_WORKER_GROUP: usize = 4;

  /// Key type used to request values.
  type Key: Key;

  /// Value type produced per key; loaded values are handed back wrapped in `Arc`.
  type Value: Send + Sync + Clone + 'static;

  /// Error type reported when a batch fails to load.
  type Error: Send + Sync + Clone + 'static;

  /// Loads all `keys` in one call, returning a map from key to shared value,
  /// or an error for the whole batch.
  async fn load(
    keys: Vec<Self::Key>,
  ) -> Result<HashMap<Self::Key, Arc<Self::Value>>, Self::Error>;
}

/// Adapter that turns a [`BatchLoader`] into a [`TaskHandler`].
///
/// `BatchHandler<T>` is used at the type level: the trait items implemented
/// for it in this module are associated functions and the wrapped value is not
/// read by them. Trait bounds live on the `impl` blocks rather than on the
/// struct, so naming `BatchHandler<T>` does not by itself require
/// `T: BatchLoader`.
pub struct BatchHandler<T>(T);

/// Exposes any [`BatchLoader`] as a [`TaskHandler`]: the associated types and
/// `CORES_PER_WORKER_GROUP` are taken straight from the wrapped loader.
#[async_trait::async_trait]
impl<T: BatchLoader> TaskHandler for BatchHandler<T> {
  type Key = T::Key;
  type Value = T::Value;
  type Error = T::Error;
  const CORES_PER_WORKER_GROUP: usize = T::CORES_PER_WORKER_GROUP;

  /// Processes one scheduled task.
  ///
  /// For a `LoadBatch` assignment, the batch's keys are passed to `T::load`
  /// and the batch is resolved with whatever that returns (success or error).
  /// For `NoAssignment`, the contained receipt is returned unchanged.
  async fn handle_task(task: Task<PendingAssignment<Self>>) -> Task<CompletionReceipt> {
    let batch = match task.get_assignment() {
      TaskAssignment::LoadBatch(batch) => batch,
      // Nothing to load: hand the receipt straight back.
      TaskAssignment::NoAssignment(receipt) => return receipt,
    };

    let outcome = T::load(batch.keys()).await;
    batch.resolve(outcome)
  }
}

/// Lets `BatchHandler<L>` stand in wherever a [`LocalLoader`] is expected by
/// delegating both the handler type and the thread-local loader accessor to
/// the wrapped type `L`.
impl<L: BatchLoader + LocalLoader> LocalLoader for BatchHandler<L> {
  type Handler = <L as LocalLoader>::Handler;

  fn loader() -> &'static std::thread::LocalKey<DataLoader<Self::Handler>> {
    <L as LocalLoader>::loader()
  }
}

#[cfg(test)]
mod tests {
  use super::*;
  use crate::LoadBy;
  use deque_loader_derive::{Loadable, Loader};
  use std::{collections::HashMap, iter};
  use tokio::try_join;

  #[derive(Loader)]
  #[data_loader(handler = "BatchHandler<BatchSizeLoader>", internal = true)]
  pub struct BatchSizeLoader {}

  /// Value reported for every key: the number of keys in the batch that loaded it.
  #[derive(Clone, Debug, PartialEq, Eq, Loadable)]
  #[data_loader(handler = "BatchHandler<BatchSizeLoader>", internal = true)]
  pub struct BatchSize(usize);

  #[async_trait::async_trait]
  impl BatchLoader for BatchSizeLoader {
    type Key = i32;
    type Value = BatchSize;
    type Error = ();

    /// Maps every requested key to one shared `BatchSize` holding the batch length.
    async fn load(
      keys: Vec<Self::Key>,
    ) -> Result<HashMap<Self::Key, Arc<Self::Value>>, Self::Error> {
      let batch_size = Arc::new(BatchSize(keys.len()));
      Ok(keys.into_iter().zip(iter::repeat(batch_size)).collect())
    }
  }

  #[tokio::test]
  async fn it_loads() -> Result<(), ()> {
    let loaded = BatchSize::load_by(1_i32).await?;
    assert!(loaded.is_some());
    Ok(())
  }

  #[tokio::test]
  async fn it_auto_batches() -> Result<(), ()> {
    let first = BatchSize::load_by(2_i32);
    let second = BatchSize::load_by(3_i32);
    let (first, second) = try_join!(first, second)?;

    // Both keys are requested concurrently, so they are expected to share a
    // batch: identical results, and a batch size covering at least both keys.
    assert_eq!(first, second);
    let size = first.unwrap().0;
    assert!(size >= 2);

    Ok(())
  }
}