Skip to main content

dbsp/operator/dynamic/communication/
shard.rs

1//! Operators to shard batches across multiple worker threads based on keys
2//! and to gather sharded batches in one worker.
3
4// TODOs:
5// - different sharding modes.
6
7use rkyv::{archived_root, ser::Serializer as _};
8
9use crate::{
10    Circuit, Runtime, Stream,
11    circuit::circuit_builder::StreamId,
12    circuit_cache_key,
13    dynamic::{Data, DataTrait, DynPairs, Factory},
14    operator::communication::new_exchange_operators,
15    trace::{
16        Batch, BatchReader, Builder, Serializer, deserialize_indexed_wset, merge_batches,
17        serialize_indexed_wset,
18    },
19};
20
21use std::{ops::Range, panic::Location};
22
// Circuit-cache key for sharded streams: maps a stream id together with a
// range of worker indexes to a cached `Stream`.  In this file the cached value
// is either the output of the shard operator or a stream that was explicitly
// marked as sharded.
23circuit_cache_key!(ShardId<C, D>((StreamId, Range<usize>) => Stream<C, D>));
// Circuit-cache key for the reverse direction: maps the id of a sharded stream
// to the stream it was produced from.
24circuit_cache_key!(UnshardId<C, D>(StreamId => Stream<C, D>));
25
26fn all_workers() -> Range<usize> {
27    0..Runtime::num_workers()
28}
29
30impl<C, IB> Stream<C, IB>
31where
32    C: Circuit,
33    IB: BatchReader<Time = ()> + Clone,
34{
35    /// See [`Stream::shard`].
36    #[track_caller]
37    pub fn dyn_shard(&self, factories: &IB::Factories) -> Stream<C, IB>
38    where
39        IB: Batch + Send,
40    {
41        // `shard_generic` returns `None` if there is only one worker thread
42        // and hence sharding is a no-op.  In this case, we simply return the
43        // input stream.  This allows us to use `shard` unconditionally without
44        // incurring any overhead in the single-threaded case.
45        self.dyn_shard_generic(factories)
46            .unwrap_or_else(|| self.clone())
47    }
48
49    /// See [`Stream::shard_workers`].
50    #[track_caller]
51    pub fn dyn_shard_workers(
52        &self,
53        workers: Range<usize>,
54        factories: &IB::Factories,
55    ) -> Stream<C, IB>
56    where
57        IB: Batch + Send,
58    {
59        // `shard_generic_workers` returns `None` if there is only one worker
60        // thread and hence sharding is a no-op.  In this case, we simply return
61        // the input stream.  This allows us to use `shard_workers`
62        // unconditionally without incurring any overhead in the single-threaded
63        // case.
64        self.dyn_shard_generic_workers(workers, factories)
65            .unwrap_or_else(|| self.clone())
66    }
67
68    /// Like [`Self::dyn_shard`], but can assemble the results into any output batch
69    /// type `OB`.
70    ///
71    /// Returns `None` when the circuit is not running inside a multithreaded
72    /// runtime or is running in a runtime with a single worker thread.
73    #[track_caller]
74    pub fn dyn_shard_generic<OB>(&self, factories: &OB::Factories) -> Option<Stream<C, OB>>
75    where
76        OB: Batch<Key = IB::Key, Val = IB::Val, Time = (), R = IB::R> + Send,
77    {
78        self.dyn_shard_generic_workers(all_workers(), factories)
79    }
80
81    /// Like [`Self::dyn_shard`], but can assemble the results into any output batch
82    /// type `OB`.
83    ///
84    /// Returns `None` when the circuit is not running inside a multithreaded
85    /// runtime or is running in a runtime with a single worker thread.
86    #[track_caller]
87    pub fn dyn_shard_generic_workers<OB>(
88        &self,
89        workers: Range<usize>,
90        factories: &OB::Factories,
91    ) -> Option<Stream<C, OB>>
92    where
93        OB: Batch<Key = IB::Key, Val = IB::Val, Time = (), R = IB::R> + Send,
94    {
95        if Runtime::num_workers() == 1 {
96            return None;
97        }
98        let location = Location::caller();
99        let output = self
100            .circuit()
101            .cache_get_or_insert_with(
102                ShardId::new((self.stream_id(), workers.clone())),
103                move || {
104                    // As a minor optimization, we reuse this array across all invocations
105                    // of the sharding operator.
106                    let mut builders = Vec::with_capacity(Runtime::num_workers());
107                    let factories_clone2 = factories.clone();
108                    let factories_clone3 = factories.clone();
109                    let factories_clone4 = factories.clone();
110                    let workers_clone = workers.clone();
111                    let workers_clone2 = workers.clone();
112
113                    let output = self.circuit().region("shard", || {
114                        let (sender, receiver) = new_exchange_operators(
115                            Some(location),
116                            || Vec::new(),
117                            move |batch: IB, batches: &mut Vec<OB>| {
118                                shard_batch(
119                                    batch,
120                                    &workers_clone,
121                                    &mut builders,
122                                    batches,
123                                    &factories_clone3,
124                                );
125                            },
126                            |batch| serialize_indexed_wset(&batch),
127                            move |data| deserialize_indexed_wset(&factories_clone4, &data),
128                            |batches: &mut Vec<OB>, batch: OB| batches.push(batch),
129                        )
130                        .unwrap();
131
132                        self.circuit()
133                            .add_exchange(sender, receiver, self)
134                            .apply_owned_named("merge shards", move |batches| {
135                                merge_batches(&factories_clone2, batches, &None, &None)
136                            })
137                    });
138
139                    self.circuit().cache_insert(
140                        ShardId::new((output.stream_id(), workers_clone2)),
141                        output.clone(),
142                    );
143
144                    self.circuit()
145                        .cache_insert(UnshardId::new(output.stream_id()), self.clone());
146
147                    output.set_persistent_id(
148                        self.get_persistent_id()
149                            .map(|name| format!("{name}.shard"))
150                            .as_deref(),
151                    )
152                },
153            )
154            .clone();
155
156        Some(output)
157    }
158}
159
160impl<C, K, V> Stream<C, Vec<Box<DynPairs<K, V>>>>
161where
162    C: Circuit,
163    K: DataTrait + ?Sized,
164    V: DataTrait + ?Sized,
165{
166    #[track_caller]
167    pub fn dyn_shard_pairs(
168        &self,
169        pairs_factory: &'static dyn Factory<DynPairs<K, V>>,
170    ) -> Stream<C, Vec<Box<DynPairs<K, V>>>> {
171        if self.is_sharded() {
172            return self.clone();
173        }
174
175        let location = Location::caller();
176
177        let (sender, receiver) = new_exchange_operators(
178            Some(location),
179            Vec::new,
180            move |input_pairs: Vec<Box<DynPairs<K, V>>>,
181                  output_pairs: &mut Vec<Box<DynPairs<K, V>>>| {
182                shard_pairs(input_pairs, &all_workers(), output_pairs, pairs_factory);
183            },
184            |batch| {
185                let mut s = Serializer::default();
186                let offset = batch.serialize(&mut s).unwrap();
187                s.serialize_value(&offset).unwrap();
188                s.into_serializer().into_inner().into_vec()
189            },
190            move |data| {
191                let offset = unsafe { archived_root::<usize>(&data) };
192                let mut output = pairs_factory.default_box();
193
194                unsafe { output.deserialize_from_bytes(&data, *offset as usize) };
195                output
196            },
197            |output_pairs: &mut Vec<Box<DynPairs<K, V>>>, batch: Box<DynPairs<K, V>>| {
198                output_pairs.push(batch);
199            },
200        )
201        .unwrap();
202
203        let output = self.circuit().add_exchange(sender, receiver, self);
204
205        output.set_persistent_id(
206            self.get_persistent_id()
207                .map(|name| format!("{name}.shard"))
208                .as_deref(),
209        );
210        output
211    }
212}
213
214// Partitions the batch into shards covering `workers` (out of
215// `all_workers()`), based on the hash of the key.
216pub fn shard_batch<IB, OB>(
217    mut batch: IB,
218    workers: &Range<usize>,
219    builders: &mut Vec<OB::Builder>,
220    outputs: &mut Vec<OB>,
221    factories: &OB::Factories,
222) where
223    IB: BatchReader<Time = ()>,
224    OB: Batch<Key = IB::Key, Val = IB::Val, Time = (), R = IB::R>,
225{
226    builders.clear();
227
228    // XXX If `shards == 1` and `OB` and `IB` are the same, then we could
229    // implement this more efficiently, without copying.
230    let shards = workers.len();
231    for _ in 0..shards {
232        // We iterate over tuples in the batch in order; hence tuples added
233        // to each shard are also ordered, so we can use the more efficient
234        // `Builder` API (instead of `Batcher`) to construct output batches.
235        builders.push(OB::Builder::with_capacity(
236            factories,
237            batch.key_count() / shards,
238            batch.len() / shards,
239        ));
240    }
241
242    let mut cursor = batch.consuming_cursor(None, None);
243    if cursor.has_mut() {
244        while cursor.key_valid() {
245            let b = &mut builders[cursor.key().default_hash() as usize % shards];
246            while cursor.val_valid() {
247                b.push_diff_mut(cursor.weight_mut());
248                b.push_val_mut(cursor.val_mut());
249                cursor.step_val();
250            }
251            b.push_key_mut(cursor.key_mut());
252            cursor.step_key();
253        }
254    } else {
255        while cursor.key_valid() {
256            let b = &mut builders[cursor.key().default_hash() as usize % shards];
257            while cursor.val_valid() {
258                b.push_diff(cursor.weight());
259                b.push_val(cursor.val());
260                cursor.step_val();
261            }
262            b.push_key(cursor.key());
263            cursor.step_key();
264        }
265    }
266    for _ in 0..workers.start {
267        outputs.push(OB::dyn_empty(factories));
268    }
269    for builder in builders.drain(..) {
270        outputs.push(builder.done());
271    }
272    for _ in workers.end..Runtime::num_workers() {
273        outputs.push(OB::dyn_empty(factories));
274    }
275}
276
277// Partitions the batch into shards covering `workers` (out of
278// `all_workers()`), based on the hash of the key.
279pub fn shard_pairs<K, V>(
280    input_pairs: Vec<Box<DynPairs<K, V>>>,
281    workers: &Range<usize>,
282    output_pairs: &mut Vec<Box<DynPairs<K, V>>>,
283    pairs_factory: &'static dyn Factory<DynPairs<K, V>>,
284) where
285    K: DataTrait + ?Sized,
286    V: DataTrait + ?Sized,
287{
288    output_pairs.clear();
289    output_pairs.resize(workers.len(), pairs_factory.default_box());
290
291    for mut pairs in input_pairs {
292        for pair in pairs.dyn_iter_mut() {
293            let k = pair.fst();
294            let shard_index = k.default_hash() as usize % workers.len();
295            output_pairs[shard_index].push_val(pair);
296        }
297    }
298}
299
300impl<C, T> Stream<C, T>
301where
302    C: Circuit,
303    T: 'static,
304{
305    /// Marks the data within the current stream as sharded, meaning that all
306    /// further calls to `.shard()` will have no effect.
307    ///
308    /// This must only be used on streams of values that are properly sharded
309    /// across workers, otherwise this will cause the dataflow to yield
310    /// incorrect results
311    pub fn mark_sharded(&self) -> Self {
312        self.circuit().cache_insert(
313            ShardId::new((self.stream_id(), all_workers())),
314            self.clone(),
315        );
316        self.clone()
317    }
318
319    /// Returns `true` if a sharded version of the current stream exists
320    pub fn has_sharded_version(&self) -> bool {
321        self.circuit()
322            .cache_contains(&ShardId::<C, T>::new((self.stream_id(), all_workers())))
323    }
324
325    /// Returns the sharded version of the stream if it exists
326    /// (which may be the stream itself or the result of applying
327    /// the `shard` operator to it).  Otherwise, returns `self`.
328    pub fn try_sharded_version(&self) -> Self {
329        self.circuit()
330            .cache_get(&ShardId::new((self.stream_id(), all_workers())))
331            .unwrap_or_else(|| self.clone())
332    }
333
334    /// Returns the unsharded version of the stream if it exists, and otherwise
335    /// `self`.
336    pub fn try_unsharded_version(&self) -> Self {
337        self.circuit()
338            .cache_get(&UnshardId::new(self.stream_id()))
339            .unwrap_or_else(|| self.clone())
340    }
341
342    /// Returns `true` if this stream is sharded.
343    pub fn is_sharded(&self) -> bool {
344        if Runtime::num_workers() == 1 {
345            return true;
346        }
347
348        self.circuit()
349            .cache_get(&ShardId::<C, T>::new((self.stream_id(), all_workers())))
350            .is_some_and(|sharded| sharded.ptr_eq(self))
351    }
352
353    /// Marks `self` as sharded if `input` has a sharded version of itself
354    pub fn mark_sharded_if<C2, U>(&self, input: &Stream<C2, U>)
355    where
356        C2: Circuit,
357        U: 'static,
358    {
359        if input.has_sharded_version() {
360            self.mark_sharded();
361        }
362    }
363}
364
#[cfg(test)]
mod tests {
    use crate::{
        Circuit, RootCircuit, Runtime, operator::Generator, trace::BatchReader,
        typed_batch::OrdIndexedZSet, utils::Tup2,
    };

    /// Runs the shard-then-gather round trip with 2, 4 and 16 worker threads.
    #[test]
    fn test_shard() {
        for workers in [2, 4, 16] {
            do_test_shard(workers);
        }
    }

    /// Builds the slice of the test data set owned by `worker_index`: for
    /// every `n` in `0..1000` with `n % num_workers == worker_index`, the two
    /// tuples `(n, n)` and `(n, 1000 * n)`, each with weight 1.
    fn test_data(worker_index: usize, num_workers: usize) -> OrdIndexedZSet<u64, u64> {
        let mut tuples = Vec::new();
        for n in (0..1000).filter(|n| n % num_workers == worker_index) {
            tuples.push(Tup2(Tup2(n as u64, n as u64), 1i64));
            tuples.push(Tup2(Tup2(n as u64, 1000 * n as u64), 1));
        }
        <OrdIndexedZSet<u64, u64>>::from_tuples((), tuples)
    }

    /// Feeds every worker its own slice of the data set, shards the stream,
    /// gathers it at worker 0, and checks that worker 0 ends up with the
    /// complete data set while every other worker sees an empty batch.
    fn do_test_shard(workers: usize) {
        let runtime_handle = Runtime::run(workers, |_parker| {
            let circuit_handle = RootCircuit::build(move |circuit| {
                let source = circuit.add_source(Generator::new(|| {
                    test_data(Runtime::worker_index(), Runtime::num_workers())
                }));

                let gathered = source.shard().gather(0);
                gathered.inspect(|batch: &OrdIndexedZSet<u64, u64>| {
                    if Runtime::worker_index() == 0 {
                        // `test_data(0, 1)` is the whole, unpartitioned set.
                        assert_eq!(batch, &test_data(0, 1))
                    } else {
                        assert_eq!(batch.len(), 0);
                    }
                });

                Ok(())
            })
            .unwrap()
            .0;

            // The generator emits the same data on every step, so each of the
            // three transactions must pass the same checks.
            for _ in 0..3 {
                circuit_handle.transaction().unwrap();
            }
        })
        .expect("failed to run runtime");

        runtime_handle.join().unwrap();
    }
}
423}