// batch_loader/lib.rs

//! Query batching utility

extern crate batch_recv;
extern crate crossbeam_channel as chan;
extern crate futures;
extern crate itertools;
extern crate worker_sentinel;

use std::fmt::Debug;
use futures::sync::oneshot;
use itertools::Itertools;
use worker_sentinel::{Work, WorkFactory};
use batch_recv::BatchRecv;
/// Trait for values which are identifiable by a unique `Key`.
///
/// Values must be cloneable because a single value will be cloned to the respective
/// multiple callers if some callers request by the same key.
pub trait Value: Debug + Clone + Send {
    /// Key is used to route the values to the caller.
    ///
    /// `Ord` is required so queued requests can be sorted and grouped by key
    /// before being handed to the backend.
    type Key: Ord + Clone + Send + 'static;
    /// Returns a reference to this value's `Key`.
    fn key(&self) -> &Self::Key;
}

/// Trait for the querier backend.
pub trait Backend: Send + 'static {
    /// The value type this backend loads.
    type Value: Value;
    /// Error type; must be `Clone` because a single failure is fanned out to
    /// every caller waiting on the same batch.
    type Error: Debug + Clone + Send;
    /// This function provides the actual data fetching logic.
    ///
    /// `keys` yields the distinct keys of one batch. Keys with no value may
    /// simply be omitted from the returned `Vec`; their callers get `Ok(None)`.
    fn batch_load<'a, I>(&self, keys: I) -> Result<Vec<Self::Value>, Self::Error>
    where
        I: Iterator<Item = &'a <Self::Value as Value>::Key> + 'a;
}

/// Factory trait that produces a fresh `Backend` for each worker.
pub trait NewBackend: Send + Sync + 'static {
    type Backend: Backend;
    /// Builds a new backend instance (invoked when a worker is built).
    fn new_backend(&self) -> Self::Backend;
}
40impl<F, B> NewBackend for F
41where
42    B: Backend,
43    F: Fn() -> B + Send + Sync + 'static,
44{
45    type Backend = B;
46    fn new_backend(&self) -> Self::Backend {
47        self()
48    }
49}
50
/// Outcome of one load: `Ok(Some(v))` hit, `Ok(None)` miss, `Err(e)` backend failure.
type LoadResult<B> = Result<Option<<B as Backend>::Value>, <B as Backend>::Error>;
/// One queued request: the key plus the oneshot sender used to answer the caller.
type Message<B> = (
    <<B as Backend>::Value as Value>::Key,
    oneshot::Sender<LoadResult<B>>,
);
/// Sending half of the request queue (held by `Loader`).
type QueueTx<B> = chan::Sender<Message<B>>;
/// Receiving half of the request queue (held by worker threads).
type QueueRx<B> = chan::Receiver<Message<B>>;

/// Batched data loader interface
///
/// Loader is composed of the queue which is associated to the backend.
/// Cloning a `Loader` clones only the queue sender, so clones share the
/// same worker pool.
#[derive(Clone)]
pub struct Loader<B>
where
    B: Backend,
{
    // Sending side of the request queue consumed by the worker threads.
    queue_tx: QueueTx<B>,
}

70impl<B> Loader<B>
71where
72    B: Backend,
73{
74    /// Create new loader
75    ///
76    /// `concurrent` sets the number of threads which runs the backend.
77    /// `new_backend` will be called in spawning the new thread.
78    pub fn new<N>(new_backend: N, batch_size: usize, concurrent: usize) -> Loader<B>
79    where
80        N: NewBackend<Backend = B> + 'static,
81    {
82        let (queue_tx, queue_rx) = chan::unbounded();
83        let work_factory = BackendWorkFactory {
84            queue_rx,
85            new_backend,
86            batch_size,
87        };
88        worker_sentinel::spawn(concurrent, work_factory);
89        Loader { queue_tx }
90    }
91
92    /// Load value by key
93    ///
94    /// This function writes the key to the queue and returns a Future to wait the result.
95    pub fn load(
96        &self,
97        key: <B::Value as Value>::Key,
98    ) -> Result<oneshot::Receiver<LoadResult<B>>, chan::SendError<<B::Value as Value>::Key>> {
99        let (cb_tx, cb_rx) = oneshot::channel();
100        self.queue_tx.send((key, cb_tx)).map_err(|err| {
101            let (key, _) = err.into_inner();
102            chan::SendError(key)
103        })?;
104        Ok(cb_rx)
105    }
106}
107
// One worker's state: its private backend plus a clone of the shared queue.
struct BackendWork<B>
where
    B: Backend,
{
    // Receiving side of the shared request queue.
    queue_rx: QueueRx<B>,
    // Backend instance owned exclusively by this worker.
    backend: B,
    // Maximum number of requests drained per `work()` call.
    batch_size: usize,
}
116impl<B> Work for BackendWork<B>
117where
118    B: Backend,
119{
120    fn work(self) -> Option<Self> {
121        let mut requests: Vec<_> = self.queue_rx.batch_recv(self.batch_size).ok()?.collect();
122        requests.sort_by(|&(ref left, _), &(ref right, _)| left.cmp(&right));
123        let req_groups_by_key = requests.into_iter().group_by(|&(ref key, _)| key.clone());
124        let req_groups_by_key_vec: Vec<_> = req_groups_by_key.into_iter().collect();
125
126        let ret = {
127            let keys_iter = req_groups_by_key_vec.iter().map(|&(ref key, _)| key);
128            self.backend.batch_load(keys_iter)
129        };
130        let mut values = match ret {
131            Ok(values) => values,
132            Err(err) => {
133                for (_, req_group) in req_groups_by_key_vec {
134                    for (_, cb) in req_group {
135                        cb.send(Err(err.clone())).expect("return error as result");
136                    }
137                }
138                return Some(self);
139            }
140        };
141        values.sort_by(|ref left, ref right| left.key().cmp(right.key()));
142        let joined = req_groups_by_key_vec
143            .into_iter()
144            .merge_join_by(values.into_iter(), |&(ref key, _), value| {
145                key.cmp(value.key())
146            });
147        for pair in joined {
148            use itertools::EitherOrBoth::{Both, Left};
149            match pair {
150                Left((_, req_group)) => for (_, cb) in req_group {
151                    cb.send(Ok(None)).expect("respond to caller");
152                },
153                Both((_, req_group), value) => for (_, cb) in req_group {
154                    cb.send(Ok(Some(value.clone()))).expect("respond to caller");
155                },
156                _ => unreachable!(),
157            }
158        }
159        Some(self)
160    }
161}
162
// Shared factory from which worker_sentinel builds (and rebuilds) workers.
struct BackendWorkFactory<N>
where
    N: NewBackend,
{
    // Receiving side of the request queue; cloned into each worker.
    queue_rx: QueueRx<N::Backend>,
    // Produces one backend instance per built worker.
    new_backend: N,
    // Maximum batch size handed to each worker.
    batch_size: usize,
}
171impl<N> WorkFactory for BackendWorkFactory<N>
172where
173    N: NewBackend,
174{
175    type Work = BackendWork<N::Backend>;
176    fn build(&self) -> Self::Work {
177        let backend = self.new_backend.new_backend();
178        let queue_rx = self.queue_rx.clone();
179        let batch_size = self.batch_size;
180        BackendWork {
181            backend,
182            queue_rx,
183            batch_size,
184        }
185    }
186}
187
#[cfg(test)]
mod tests {
    use futures::{future, Future};
    use super::{Backend, Loader, Value};

    /// Test value: `half` is `key / 2`; produced only for even keys.
    #[derive(Debug, Clone, PartialEq)]
    struct HalfValue {
        key: u32,
        half: u32,
    }
    impl Value for HalfValue {
        type Key = u32;
        fn key(&self) -> &u32 {
            &self.key
        }
    }

    /// Backend that yields a value only for even keys; odd keys load as `None`.
    struct HalfBackend;
    impl Backend for HalfBackend {
        type Value = HalfValue;
        type Error = ();
        fn batch_load<'a, I>(&self, keys: I) -> Result<Vec<Self::Value>, Self::Error>
        where
            I: Iterator<Item = &'a <Self::Value as Value>::Key> + 'a,
        {
            let ret = keys.filter_map(|&key| {
                if key % 2 == 0 {
                    Some(HalfValue { key, half: key / 2 })
                } else {
                    None
                }
            }).collect();
            Ok(ret)
        }
    }

    #[test]
    fn test_loader() {
        let loader = Loader::new(|| HalfBackend, 10, 1);

        // Odd keys resolve to Ok(None); even keys resolve to key / 2.
        let f1 = loader
            .load(1)
            .unwrap()
            .map(|v| assert!(v.unwrap().is_none()));
        let f3 = loader
            .load(3)
            .unwrap()
            .map(|v| assert!(v.unwrap().is_none()));
        let f2 = loader.load(2).unwrap().map(|v| {
            assert_eq!(v.unwrap().unwrap(), HalfValue { key: 2, half: 1 })
        });
        let f4 = loader.load(4).unwrap().map(|v| {
            assert_eq!(v.unwrap().unwrap(), HalfValue { key: 4, half: 2 })
        });
        future::join_all(vec![
            Box::new(f1) as Box<Future<Item = _, Error = _>>,
            Box::new(f2) as Box<_>,
            Box::new(f3) as Box<_>,
            Box::new(f4) as Box<_>,
        ]).wait()
            .unwrap();
    }
}