1use rkyv::{archived_root, ser::Serializer as _};
8
9use crate::{
10 Circuit, Runtime, Stream,
11 circuit::circuit_builder::StreamId,
12 circuit_cache_key,
13 dynamic::{Data, DataTrait, DynPairs, Factory},
14 operator::communication::new_exchange_operators,
15 trace::{
16 Batch, BatchReader, Builder, Serializer, deserialize_indexed_wset, merge_batches,
17 serialize_indexed_wset,
18 },
19};
20
21use std::{ops::Range, panic::Location};
22
// Circuit-cache key for sharded streams: a `(StreamId, Range<usize>)` pair
// (a stream and a range of worker indexes) maps to a `Stream<C, D>`.
// In this file the value stored is the sharded version of the keyed stream.
circuit_cache_key!(ShardId<C, D>((StreamId, Range<usize>) => Stream<C, D>));
// Circuit-cache key mapping a `StreamId` to a `Stream<C, D>`. In this file it
// is populated with the original (pre-shard) stream, keyed by the id of the
// sharded stream that was derived from it.
circuit_cache_key!(UnshardId<C, D>(StreamId => Stream<C, D>));
25
26fn all_workers() -> Range<usize> {
27 0..Runtime::num_workers()
28}
29
30impl<C, IB> Stream<C, IB>
31where
32 C: Circuit,
33 IB: BatchReader<Time = ()> + Clone,
34{
35 #[track_caller]
37 pub fn dyn_shard(&self, factories: &IB::Factories) -> Stream<C, IB>
38 where
39 IB: Batch + Send,
40 {
41 self.dyn_shard_generic(factories)
46 .unwrap_or_else(|| self.clone())
47 }
48
49 #[track_caller]
51 pub fn dyn_shard_workers(
52 &self,
53 workers: Range<usize>,
54 factories: &IB::Factories,
55 ) -> Stream<C, IB>
56 where
57 IB: Batch + Send,
58 {
59 self.dyn_shard_generic_workers(workers, factories)
65 .unwrap_or_else(|| self.clone())
66 }
67
68 #[track_caller]
74 pub fn dyn_shard_generic<OB>(&self, factories: &OB::Factories) -> Option<Stream<C, OB>>
75 where
76 OB: Batch<Key = IB::Key, Val = IB::Val, Time = (), R = IB::R> + Send,
77 {
78 self.dyn_shard_generic_workers(all_workers(), factories)
79 }
80
81 #[track_caller]
87 pub fn dyn_shard_generic_workers<OB>(
88 &self,
89 workers: Range<usize>,
90 factories: &OB::Factories,
91 ) -> Option<Stream<C, OB>>
92 where
93 OB: Batch<Key = IB::Key, Val = IB::Val, Time = (), R = IB::R> + Send,
94 {
95 if Runtime::num_workers() == 1 {
96 return None;
97 }
98 let location = Location::caller();
99 let output = self
100 .circuit()
101 .cache_get_or_insert_with(
102 ShardId::new((self.stream_id(), workers.clone())),
103 move || {
104 let mut builders = Vec::with_capacity(Runtime::num_workers());
107 let factories_clone2 = factories.clone();
108 let factories_clone3 = factories.clone();
109 let factories_clone4 = factories.clone();
110 let workers_clone = workers.clone();
111 let workers_clone2 = workers.clone();
112
113 let output = self.circuit().region("shard", || {
114 let (sender, receiver) = new_exchange_operators(
115 Some(location),
116 || Vec::new(),
117 move |batch: IB, batches: &mut Vec<OB>| {
118 shard_batch(
119 batch,
120 &workers_clone,
121 &mut builders,
122 batches,
123 &factories_clone3,
124 );
125 },
126 |batch| serialize_indexed_wset(&batch),
127 move |data| deserialize_indexed_wset(&factories_clone4, &data),
128 |batches: &mut Vec<OB>, batch: OB| batches.push(batch),
129 )
130 .unwrap();
131
132 self.circuit()
133 .add_exchange(sender, receiver, self)
134 .apply_owned_named("merge shards", move |batches| {
135 merge_batches(&factories_clone2, batches, &None, &None)
136 })
137 });
138
139 self.circuit().cache_insert(
140 ShardId::new((output.stream_id(), workers_clone2)),
141 output.clone(),
142 );
143
144 self.circuit()
145 .cache_insert(UnshardId::new(output.stream_id()), self.clone());
146
147 output.set_persistent_id(
148 self.get_persistent_id()
149 .map(|name| format!("{name}.shard"))
150 .as_deref(),
151 )
152 },
153 )
154 .clone();
155
156 Some(output)
157 }
158}
159
160impl<C, K, V> Stream<C, Vec<Box<DynPairs<K, V>>>>
161where
162 C: Circuit,
163 K: DataTrait + ?Sized,
164 V: DataTrait + ?Sized,
165{
166 #[track_caller]
167 pub fn dyn_shard_pairs(
168 &self,
169 pairs_factory: &'static dyn Factory<DynPairs<K, V>>,
170 ) -> Stream<C, Vec<Box<DynPairs<K, V>>>> {
171 if self.is_sharded() {
172 return self.clone();
173 }
174
175 let location = Location::caller();
176
177 let (sender, receiver) = new_exchange_operators(
178 Some(location),
179 Vec::new,
180 move |input_pairs: Vec<Box<DynPairs<K, V>>>,
181 output_pairs: &mut Vec<Box<DynPairs<K, V>>>| {
182 shard_pairs(input_pairs, &all_workers(), output_pairs, pairs_factory);
183 },
184 |batch| {
185 let mut s = Serializer::default();
186 let offset = batch.serialize(&mut s).unwrap();
187 s.serialize_value(&offset).unwrap();
188 s.into_serializer().into_inner().into_vec()
189 },
190 move |data| {
191 let offset = unsafe { archived_root::<usize>(&data) };
192 let mut output = pairs_factory.default_box();
193
194 unsafe { output.deserialize_from_bytes(&data, *offset as usize) };
195 output
196 },
197 |output_pairs: &mut Vec<Box<DynPairs<K, V>>>, batch: Box<DynPairs<K, V>>| {
198 output_pairs.push(batch);
199 },
200 )
201 .unwrap();
202
203 let output = self.circuit().add_exchange(sender, receiver, self);
204
205 output.set_persistent_id(
206 self.get_persistent_id()
207 .map(|name| format!("{name}.shard"))
208 .as_deref(),
209 );
210 output
211 }
212}
213
214pub fn shard_batch<IB, OB>(
217 mut batch: IB,
218 workers: &Range<usize>,
219 builders: &mut Vec<OB::Builder>,
220 outputs: &mut Vec<OB>,
221 factories: &OB::Factories,
222) where
223 IB: BatchReader<Time = ()>,
224 OB: Batch<Key = IB::Key, Val = IB::Val, Time = (), R = IB::R>,
225{
226 builders.clear();
227
228 let shards = workers.len();
231 for _ in 0..shards {
232 builders.push(OB::Builder::with_capacity(
236 factories,
237 batch.key_count() / shards,
238 batch.len() / shards,
239 ));
240 }
241
242 let mut cursor = batch.consuming_cursor(None, None);
243 if cursor.has_mut() {
244 while cursor.key_valid() {
245 let b = &mut builders[cursor.key().default_hash() as usize % shards];
246 while cursor.val_valid() {
247 b.push_diff_mut(cursor.weight_mut());
248 b.push_val_mut(cursor.val_mut());
249 cursor.step_val();
250 }
251 b.push_key_mut(cursor.key_mut());
252 cursor.step_key();
253 }
254 } else {
255 while cursor.key_valid() {
256 let b = &mut builders[cursor.key().default_hash() as usize % shards];
257 while cursor.val_valid() {
258 b.push_diff(cursor.weight());
259 b.push_val(cursor.val());
260 cursor.step_val();
261 }
262 b.push_key(cursor.key());
263 cursor.step_key();
264 }
265 }
266 for _ in 0..workers.start {
267 outputs.push(OB::dyn_empty(factories));
268 }
269 for builder in builders.drain(..) {
270 outputs.push(builder.done());
271 }
272 for _ in workers.end..Runtime::num_workers() {
273 outputs.push(OB::dyn_empty(factories));
274 }
275}
276
277pub fn shard_pairs<K, V>(
280 input_pairs: Vec<Box<DynPairs<K, V>>>,
281 workers: &Range<usize>,
282 output_pairs: &mut Vec<Box<DynPairs<K, V>>>,
283 pairs_factory: &'static dyn Factory<DynPairs<K, V>>,
284) where
285 K: DataTrait + ?Sized,
286 V: DataTrait + ?Sized,
287{
288 output_pairs.clear();
289 output_pairs.resize(workers.len(), pairs_factory.default_box());
290
291 for mut pairs in input_pairs {
292 for pair in pairs.dyn_iter_mut() {
293 let k = pair.fst();
294 let shard_index = k.default_hash() as usize % workers.len();
295 output_pairs[shard_index].push_val(pair);
296 }
297 }
298}
299
300impl<C, T> Stream<C, T>
301where
302 C: Circuit,
303 T: 'static,
304{
305 pub fn mark_sharded(&self) -> Self {
312 self.circuit().cache_insert(
313 ShardId::new((self.stream_id(), all_workers())),
314 self.clone(),
315 );
316 self.clone()
317 }
318
319 pub fn has_sharded_version(&self) -> bool {
321 self.circuit()
322 .cache_contains(&ShardId::<C, T>::new((self.stream_id(), all_workers())))
323 }
324
325 pub fn try_sharded_version(&self) -> Self {
329 self.circuit()
330 .cache_get(&ShardId::new((self.stream_id(), all_workers())))
331 .unwrap_or_else(|| self.clone())
332 }
333
334 pub fn try_unsharded_version(&self) -> Self {
337 self.circuit()
338 .cache_get(&UnshardId::new(self.stream_id()))
339 .unwrap_or_else(|| self.clone())
340 }
341
342 pub fn is_sharded(&self) -> bool {
344 if Runtime::num_workers() == 1 {
345 return true;
346 }
347
348 self.circuit()
349 .cache_get(&ShardId::<C, T>::new((self.stream_id(), all_workers())))
350 .is_some_and(|sharded| sharded.ptr_eq(self))
351 }
352
353 pub fn mark_sharded_if<C2, U>(&self, input: &Stream<C2, U>)
355 where
356 C2: Circuit,
357 U: 'static,
358 {
359 if input.has_sharded_version() {
360 self.mark_sharded();
361 }
362 }
363}
364
#[cfg(test)]
mod tests {
    use crate::{
        Circuit, RootCircuit, Runtime, operator::Generator, trace::BatchReader,
        typed_batch::OrdIndexedZSet, utils::Tup2,
    };

    /// Runs the shard-then-gather check with 2, 4 and 16 workers.
    #[test]
    fn test_shard() {
        for workers in [2, 4, 16] {
            do_test_shard(workers);
        }
    }

    /// Builds the input for one worker: for every `n` in `0..1000` with
    /// `n % num_workers == worker_index`, the tuples `(n, n)` and
    /// `(n, 1000 * n)`, each with weight 1.
    fn test_data(worker_index: usize, num_workers: usize) -> OrdIndexedZSet<u64, u64> {
        let mut tuples = Vec::new();
        for n in (0..1000).filter(|n| n % num_workers == worker_index) {
            let key = n as u64;
            tuples.push(Tup2(Tup2(key, key), 1i64));
            tuples.push(Tup2(Tup2(key, 1000 * key), 1));
        }
        <OrdIndexedZSet<u64, u64>>::from_tuples((), tuples)
    }

    /// Each worker generates its slice of the data; the stream is sharded and
    /// then gathered on worker 0, which must observe the complete data set
    /// while every other worker observes an empty batch.
    fn do_test_shard(workers: usize) {
        let hruntime = Runtime::run(workers, |_parker| {
            let built = RootCircuit::build(move |circuit| {
                let source = Generator::new(|| {
                    test_data(Runtime::worker_index(), Runtime::num_workers())
                });
                let input = circuit.add_source(source);
                let gathered = input.shard().gather(0);

                gathered.inspect(|batch: &OrdIndexedZSet<u64, u64>| {
                    if Runtime::worker_index() != 0 {
                        assert_eq!(batch.len(), 0);
                    } else {
                        // `test_data(0, 1)` is the union of all workers' inputs.
                        assert_eq!(batch, &test_data(0, 1))
                    }
                });
                Ok(())
            });
            let circuit = built.unwrap().0;

            for _step in 0..3 {
                circuit.transaction().unwrap();
            }
        })
        .expect("failed to run runtime");

        hruntime.join().unwrap();
    }
}