Skip to main content

dbsp/operator/communication/
shard.rs

1use std::ops::Range;
2
3use crate::{
4    Circuit, Stream,
5    trace::BatchReaderFactories,
6    typed_batch::{Batch, Spine},
7};
8
9impl<C, IB> Stream<C, IB>
10where
11    C: Circuit,
12    IB: Batch<Time = ()>,
13    IB::InnerBatch: Send,
14{
15    /// Shard batches across multiple worker threads based on keys.
16    ///
17    /// # Theory
18    ///
19    /// We parallelize processing across `N` worker threads by creating a
20    /// replica of the same circuit per thread and sharding data across
21    /// replicas.  To ensure correctness (i.e., that the sum of outputs
22    /// produced by individual workers is equal to the output produced
23    /// by processing the entire dataset by one worker), sharding must satisfy
24    /// certain requirements determined by each operator.  In particular,
25    /// for `distinct`, and `aggregate` all tuples that share the same key
26    /// must be processed by the same worker.  For `join`, tuples from both
27    /// input streams with the same key must be processed by the same worker.
28    ///
29    /// Other operators, e.g., `filter` and `flat_map`, impose no restrictions
30    /// on the sharding scheme: as long as each tuple in a batch is
31    /// processed by some worker, the correct result will be produced.  This
32    /// is true for all linear operators.
33    ///
34    /// The `shard` operator shards input batches based on the hash of the key,
35    /// making sure that tuples with the same key always end up at the same
36    /// worker.  More precisely, the operator **re-shards** its input by
37    /// partitioning batches in the input stream of each worker based on the
38    /// hash of the key, distributing resulting fragments among peers
39    /// and re-assembling fragments at each peer:
40    ///
41    /// ```text
42    ///         ┌──────────────────┐
43    /// worker1 │                  │
44    /// ───────►├─────┬───────────►├──────►
45    ///         │     │            │
46    /// ───────►├─────┴───────────►├──────►
47    /// worker2 │                  │
48    ///         └──────────────────┘
49    /// ```
50    ///
51    /// # Usage
52    ///
53    /// Most users do not need to invoke `shard` directly (and doing so is
54    /// likely to lead to incorrect results unless you know exactly what you
55    /// are doing).  Instead, each operator re-shards its inputs as
56    /// necessary, e.g., `join` applies `shard` to both of its
57    /// input streams, while `filter` consumes its input directly without
58    /// re-sharding.
59    ///
60    /// # Performance considerations
61    ///
62    /// In the current implementation, the `shard` operator introduces a
63    /// synchronization barrier across all workers: its output at any worker
64    /// is only produced once input batches have been collected from all
65    /// workers.  This limits the scalability since a slow worker (e.g., running
66    /// on a busy CPU core or sharing the core with other workers) or uneven
67    /// sharding can slow down the whole system and reduce gains from
68    /// parallelization.
69    pub fn shard(&self) -> Stream<C, IB> {
70        let factories = BatchReaderFactories::new::<IB::Key, IB::Val, IB::R>();
71        self.inner().dyn_shard(&factories).typed()
72    }
73
74    /// Shard batch across just the specified range of `workers`.
75    ///
76    /// If `workers` contains just one worker, then [Stream::gather] is more
77    /// efficient.
78    pub fn shard_workers(&self, workers: Range<usize>) -> Stream<C, IB> {
79        let factories = BatchReaderFactories::new::<IB::Key, IB::Val, IB::R>();
80        self.inner().dyn_shard_workers(workers, &factories).typed()
81    }
82}
83
84impl<C, B> Stream<C, B>
85where
86    C: Circuit,
87    B: Batch<Time = ()>,
88{
89    #[track_caller]
90    pub fn shard_accumulate(&self) -> Stream<C, Option<Spine<B>>> {
91        let factories = BatchReaderFactories::new::<B::Key, B::Val, B::R>();
92
93        let result = self.inner().dyn_shard_accumulate(&factories);
94
95        unsafe { result.transmute_payload() }
96    }
97}