dataload_rs/
loader.rs

1use std::ops::Drop;
2use std::{collections::HashMap, fmt::Debug};
3
4use tokio::sync::{mpsc, oneshot};
5
6use crate::{
7    batch_function::BatchFunction,
8    loader_op::{LoadRequest, LoaderOp},
9    loader_worker::LoaderWorker,
10};
11
12/// Batch loads values from some expensive resource, primarily intended for mitigating GraphQL's
13/// N+1 problem.
14///
15/// Users can call [`Loader::load`] and [`Loader::load_many`] to fetch values from the underlying resource or
16/// cache. The cache can be cleared with calls to [`Loader::clear`] and [`Loader::clear_many`], and values can be
17/// added to the cache out-of-band through the use of [`Loader::prime`] and [`Loader::prime_many`].
18///
19/// The `Loader` struct acts as an intermediary between the async domain in which `load` calls are
20/// invoked and the pseudo-single-threaded domain of the `LoaderWorker`. Callers can invoke the
21/// `Loader` from multiple parallel tasks, and the loader will enqueue the requested operations on
22/// the request queue for processing by its `LoaderWorker`. The worker processes the requests
23/// sequentially and provides results via response oneshot channels back to the Loader.
24pub struct Loader<K, V>
25where
26    K: 'static + Eq + Debug + Copy + Send,
27    V: 'static + Send + Debug + Clone,
28{
29    request_tx: mpsc::UnboundedSender<LoaderOp<K, V>>,
30    load_task_handle: tokio::task::JoinHandle<()>,
31}
32
33impl<K, V> Drop for Loader<K, V>
34where
35    K: 'static + Eq + Debug + Copy + Send,
36    V: 'static + Send + Debug + Clone,
37{
38    fn drop(&mut self) {
39        self.load_task_handle.abort();
40    }
41}
42
43impl<K, V> Loader<K, V>
44where
45    K: 'static + Eq + Debug + Ord + Copy + std::hash::Hash + Send + Sync,
46    V: 'static + Send + Debug + Clone,
47{
48    /// Creates a new Loader for the provided BatchFunction and Context type.
49    ///
50    /// Note: the batch function is passed in as a marker for type inference.
51    pub fn new<F, ContextT>(_: F, context: ContextT) -> Self
52    where
53        ContextT: Send + Sync + 'static,
54        F: 'static + BatchFunction<K, V, Context = ContextT> + Send,
55    {
56        let (tx, rx) = mpsc::unbounded_channel();
57        Self {
58            request_tx: tx,
59            load_task_handle: tokio::task::spawn(
60                LoaderWorker::<K, V, F, HashMap<K, V>, ContextT>::new(HashMap::new(), rx, context)
61                    .start(),
62            ),
63        }
64    }
65}
66
67impl<K, V> Loader<K, V>
68where
69    K: 'static + Eq + Debug + Ord + Copy + Send + Sync,
70    V: 'static + Send + Debug + Clone,
71{
72    /// Loads a value from the underlying resource.
73    ///
74    /// Returns None if the value could not be loaded by the BatchFunction.
75    ///
76    /// If the value is already in the loader cache, it is returned as soon as it is processed.
77    /// Otherwise, the requested key is enqueued for batch loading in the next loader execution
78    /// frame.
79    pub async fn load(&self, key: K) -> Option<V> {
80        let (response_tx, response_rx) = oneshot::channel();
81        self.request_tx
82            .send(LoaderOp::Load(LoadRequest::One(key, response_tx)))
83            .unwrap();
84        response_rx.await.unwrap()
85    }
86
87    /// Loads many values at once.
88    ///
89    /// Returns None for values that could not be loaded by the BatchFunction.
90    ///
91    /// If all the values are already present in the laoder cache, they are returned as soon as the
92    /// request is processed by the worker. Otherwise, the keys is enqueue for batch loading in the
93    /// next loader execution frame.
94    pub async fn load_many(&self, keys: Vec<K>) -> Vec<Option<V>> {
95        let (response_tx, response_rx) = oneshot::channel();
96        self.request_tx
97            .send(LoaderOp::Load(LoadRequest::Many(keys, response_tx)))
98            .unwrap();
99        response_rx.await.unwrap()
100    }
101
102    /// Adds a value to the cache.
103    pub async fn prime(&self, key: K, value: V) {
104        self.request_tx.send(LoaderOp::Prime(key, value)).unwrap();
105    }
106
107    /// Adds many values to the cache at once.
108    pub async fn prime_many(&self, key_vals: Vec<(K, V)>) {
109        self.request_tx.send(LoaderOp::PrimeMany(key_vals)).unwrap();
110    }
111
112    /// Removes a value from the cache.
113    ///
114    /// This key will be reloaded when it is next requested.
115    pub async fn clear(&self, key: K) {
116        self.request_tx.send(LoaderOp::Clear(key)).unwrap();
117    }
118
119    /// Removes multiple values from the cache at once.
120    ///
121    /// These keys will be reloaded when requested.
122    pub async fn clear_many(&self, keys: Vec<K>) {
123        self.request_tx.send(LoaderOp::ClearMany(keys)).unwrap();
124    }
125}