1use std::{
12 any::{Any, TypeId},
13 hash::{Hash, Hasher},
14 ops::Deref,
15};
16
17use crate::numeric_id::{DenseIdMap, IdVec, NumericId, define_id};
18use crossbeam_queue::SegQueue;
19use dashmap::SharedValue;
20use rayon::{
21 iter::{ParallelBridge, ParallelIterator},
22 prelude::*,
23};
24use rustc_hash::FxHasher;
25
26use crate::{
27 ColumnId, CounterId, ExecutionState, Offset, SubsetRef, TableId, TaggedRowBuffer, Value,
28 WrappedTable,
29 common::{DashMap, IndexSet, InternTable, SubsetTracker},
30 parallel_heuristics::{parallelize_inter_container_op, parallelize_intra_container_op},
31 table_spec::Rebuilder,
32};
33
// Unit tests for this module live in a sibling `tests` module.
#[cfg(test)]
mod tests;

// Identifier handed out per registered container *type*: `ContainerValues::register_type`
// interns the type's `TypeId` to one of these and uses it to key the per-type environments.
define_id!(pub ContainerValueId, u32, "an identifier for containers");
38
39pub trait MergeFn:
40 Fn(&mut ExecutionState, Value, Value) -> Value + dyn_clone::DynClone + Send + Sync
41{
42}
43impl<T: Fn(&mut ExecutionState, Value, Value) -> Value + Clone + Send + Sync> MergeFn for T {}
44
45dyn_clone::clone_trait_object!(MergeFn);
47
/// Registry of every container type together with the interned containers of each type.
#[derive(Clone, Default)]
pub struct ContainerValues {
    // Source of the "recently updated rows" subset used to drive incremental rebuilds
    // (queried via `recent_updates`; presumably rows changed since the previous call — confirm).
    subset_tracker: SubsetTracker,
    // Maps the Rust type of each registered container to its `ContainerValueId`.
    container_ids: InternTable<TypeId, ContainerValueId>,
    // Type-erased per-type environment; downcast back to `ContainerEnv<C>` on lookup.
    data: DenseIdMap<ContainerValueId, Box<dyn DynamicContainerEnv + Send + Sync>>,
}
54
55impl ContainerValues {
56 pub fn new() -> Self {
57 Default::default()
58 }
59
60 fn get<C: ContainerValue>(&self) -> Option<&ContainerEnv<C>> {
61 let id = self.container_ids.intern(&TypeId::of::<C>());
62 let res = self.data.get(id)?.as_any();
63 Some(res.downcast_ref::<ContainerEnv<C>>().unwrap())
64 }
65
66 pub fn for_each<C: ContainerValue>(&self, mut f: impl FnMut(&C, Value)) {
68 let Some(env) = self.get::<C>() else {
69 return;
70 };
71 for ent in env.to_id.iter() {
72 f(ent.key(), *ent.value());
73 }
74 }
75
76 pub fn get_val<C: ContainerValue>(&self, val: Value) -> Option<impl Deref<Target = C> + '_> {
82 self.get::<C>()?.get_container(val)
83 }
84
85 pub fn register_val<C: ContainerValue>(
86 &self,
87 container: C,
88 exec_state: &mut ExecutionState,
89 ) -> Value {
90 let env = self
91 .get::<C>()
92 .expect("must register container type before registering a value");
93 env.get_or_insert(&container, exec_state)
94 }
95
96 pub fn rebuild_all(
98 &mut self,
99 table_id: TableId,
100 table: &WrappedTable,
101 exec_state: &mut ExecutionState,
102 ) -> bool {
103 let Some(rebuilder) = table.rebuilder(&[]) else {
104 return false;
105 };
106 let to_scan = rebuilder.hint_col().map(|_| {
107 self.subset_tracker.recent_updates(table_id, table)
109 });
110 if parallelize_inter_container_op(self.data.next_id().index()) {
111 self.data
112 .iter_mut()
113 .zip(std::iter::repeat_with(|| exec_state.clone()))
114 .par_bridge()
115 .map(|((_, env), mut exec_state)| {
116 env.apply_rebuild(
117 table,
118 &*rebuilder,
119 to_scan.as_ref().map(|x| x.as_ref()),
120 &mut exec_state,
121 )
122 })
123 .max()
124 .unwrap_or(false)
125 } else {
126 let mut changed = false;
127 for (_, env) in self.data.iter_mut() {
128 changed |= env.apply_rebuild(
129 table,
130 &*rebuilder,
131 to_scan.as_ref().map(|x| x.as_ref()),
132 exec_state,
133 );
134 }
135 changed
136 }
137 }
138
139 pub fn register_type<C: ContainerValue>(
144 &mut self,
145 id_counter: CounterId,
146 merge_fn: impl MergeFn + 'static,
147 ) -> ContainerValueId {
148 let id = self.container_ids.intern(&TypeId::of::<C>());
149 self.data.get_or_insert(id, || {
150 Box::new(ContainerEnv::<C>::new(Box::new(merge_fn), id_counter))
151 });
152 id
153 }
154}
155
/// A hashable value that can be interned as a container and rebuilt when the
/// values it holds change.
pub trait ContainerValue: Hash + Eq + Clone + Send + Sync + 'static {
    /// Rebuilds the values stored in this container using `rebuilder`
    /// (presumably canonicalizing each one — confirm against implementors).
    ///
    /// Returns `true` if the contents changed. Callers then treat the container's
    /// hash as stale and remove and reinsert it.
    fn rebuild_contents(&mut self, rebuilder: &dyn Rebuilder) -> bool;

    /// Iterates over the values held by this container.
    ///
    /// These values key the reverse index used to find the containers affected by an
    /// incremental rebuild.
    fn iter(&self) -> impl Iterator<Item = Value> + '_;
}
176
/// Object-safe view of a per-type container environment, so environments for
/// different container types can be stored in one map.
pub trait DynamicContainerEnv: Any + dyn_clone::DynClone + Send + Sync {
    /// Returns `self` as `&dyn Any` so callers can downcast to the concrete environment.
    fn as_any(&self) -> &dyn Any;
    /// Rebuilds this environment's containers against `table`.
    ///
    /// `subset`, when present, is the set of recently updated rows an incremental
    /// rebuild may scan instead of visiting every container. Returns whether anything
    /// changed.
    fn apply_rebuild(
        &mut self,
        table: &WrappedTable,
        rebuilder: &dyn Rebuilder,
        subset: Option<SubsetRef>,
        exec_state: &mut ExecutionState,
    ) -> bool;
}

// Makes `Box<dyn DynamicContainerEnv + ...>` clonable, which `ContainerValues`' derived
// `Clone` relies on.
dyn_clone::clone_trait_object!(DynamicContainerEnv);
190
191fn hash_container(container: &impl ContainerValue) -> u64 {
192 let mut hasher = FxHasher::default();
193 container.hash(&mut hasher);
194 hasher.finish()
195}
196
/// Interning state for a single container type `C`.
#[derive(Clone)]
struct ContainerEnv<C: Eq + Hash> {
    // Chooses the surviving id when an id is attached to a container equal to one
    // already present.
    merge_fn: Box<dyn MergeFn>,
    // Counter that fresh container ids are drawn from.
    counter: CounterId,
    // Container -> its id.
    to_id: DashMap<C, Value>,
    // Id -> (hash of the container, index of the `to_id` shard holding it), so a
    // container can be located from its id without a full scan.
    to_container: DashMap<Value, (usize, usize)>,
    // Contained value -> ids of containers holding it. May contain stale ids; readers
    // skip ids that no longer resolve through `to_container`.
    val_index: DashMap<Value, IndexSet<Value>>,
}
206
207impl<C: ContainerValue> DynamicContainerEnv for ContainerEnv<C> {
208 fn as_any(&self) -> &dyn Any {
209 self
210 }
211
212 fn apply_rebuild(
213 &mut self,
214 table: &WrappedTable,
215 rebuilder: &dyn Rebuilder,
216 subset: Option<SubsetRef>,
217 exec_state: &mut ExecutionState,
218 ) -> bool {
219 if let Some(subset) = subset {
220 if incremental_rebuild(
221 subset.size(),
222 self.to_id.len(),
223 parallelize_intra_container_op(self.to_id.len()),
224 ) {
225 return self.apply_rebuild_incremental(
226 table,
227 rebuilder,
228 exec_state,
229 subset,
230 rebuilder.hint_col().unwrap(),
231 );
232 }
233 }
234 self.apply_rebuild_nonincremental(rebuilder, exec_state)
235 }
236}
237
238impl<C: ContainerValue> ContainerEnv<C> {
239 pub fn new(merge_fn: Box<dyn MergeFn>, counter: CounterId) -> Self {
240 Self {
241 merge_fn,
242 counter,
243 to_id: DashMap::default(),
244 to_container: DashMap::default(),
245 val_index: DashMap::default(),
246 }
247 }
248
249 fn get_or_insert(&self, container: &C, exec_state: &mut ExecutionState) -> Value {
250 if let Some(value) = self.to_id.get(container) {
251 return *value;
252 }
253
254 let value = Value::from_usize(exec_state.inc_counter(self.counter));
259 let target_map = self.to_id.determine_map(container);
260 debug_assert_eq!(
264 target_map,
265 self.to_container
266 .determine_shard(hash_container(container) as usize)
267 );
268 self.to_container
269 .insert(value, (hash_container(container) as usize, target_map));
270
271 match self.to_id.entry(container.clone()) {
274 dashmap::Entry::Vacant(vac) => {
275 vac.insert(value);
277 for val in container.iter() {
278 self.val_index.entry(val).or_default().insert(value);
279 }
280 value
281 }
282 dashmap::Entry::Occupied(occ) => {
283 let res = *occ.get();
287 std::mem::drop(occ); self.to_container.remove(&value);
289 res
290 }
291 }
292 }
293
294 fn insert_owned(&self, container: C, value: Value, exec_state: &mut ExecutionState) {
295 let hc = hash_container(&container);
296 let target_map = self.to_id.determine_map(&container);
297 match self.to_id.entry(container) {
298 dashmap::Entry::Occupied(mut occ) => {
299 let result = (self.merge_fn)(exec_state, *occ.get(), value);
300 let old_val = *occ.get();
301 if result != old_val {
302 self.to_container.remove(&old_val);
303 self.to_container.insert(result, (hc as usize, target_map));
304 *occ.get_mut() = result;
305 for val in occ.key().iter() {
306 let mut index = self.val_index.entry(val).or_default();
307 index.swap_remove(&old_val);
308 index.insert(result);
309 }
310 }
311 }
312 dashmap::Entry::Vacant(vacant_entry) => {
313 self.to_container.insert(value, (hc as usize, target_map));
314 for val in vacant_entry.key().iter() {
315 self.val_index.entry(val).or_default().insert(value);
316 }
317 vacant_entry.insert(value);
318 }
319 }
320 }
321 fn apply_rebuild_incremental(
322 &mut self,
323 table: &WrappedTable,
324 rebuilder: &dyn Rebuilder,
325 exec_state: &mut ExecutionState,
326 to_scan: SubsetRef,
327 search_col: ColumnId,
328 ) -> bool {
329 let mut changed = false;
336 let mut buf = TaggedRowBuffer::new(1);
337 table.scan_project(
338 to_scan,
339 &[search_col],
340 Offset::new(0),
341 usize::MAX,
342 &[],
343 &mut buf,
344 );
345 let mut to_rebuild = IndexSet::<Value>::default();
347 for (_, row) in buf.iter() {
348 to_rebuild.insert(row[0]);
349 let Some(ids) = self.val_index.get(&row[0]) else {
350 continue;
351 };
352 to_rebuild.extend(&*ids);
353 }
354 for id in to_rebuild {
355 let Some((hc, target_map)) = self.to_container.get(&id).map(|x| *x) else {
356 continue;
357 };
358 let shard_mut = self.to_id.shards_mut()[target_map].get_mut();
359 let Some((mut container, _)) =
360 shard_mut.remove_entry(hc as u64, |(_, v)| *v.get() == id)
361 else {
362 continue;
363 };
364 changed |= container.rebuild_contents(rebuilder);
365 self.insert_owned(container, id, exec_state);
366 }
367 changed
368 }
369
370 fn apply_rebuild_nonincremental(
371 &mut self,
372 rebuilder: &dyn Rebuilder,
373 exec_state: &mut ExecutionState,
374 ) -> bool {
375 if parallelize_inter_container_op(self.to_id.len()) {
376 return self.apply_rebuild_nonincremental_parallel(rebuilder, exec_state);
377 }
378 let mut changed = false;
379 let mut to_reinsert = Vec::new();
380 let shards = self.to_id.shards_mut();
381 for shard in shards.iter_mut() {
382 let shard = shard.get_mut();
383 for bucket in unsafe { shard.iter() } {
385 let (container, val) = unsafe { bucket.as_mut() };
387 let old_val = *val.get();
388 let new_val = rebuilder.rebuild_val(old_val);
389 let container_changed = container.rebuild_contents(rebuilder);
390 if !container_changed && new_val == old_val {
391 continue;
393 }
394 changed = true;
395 if container_changed {
396 let ((container, _), _) = unsafe { shard.remove(bucket) };
400 self.to_container.remove(&old_val);
401 to_reinsert.push((container, new_val));
402 } else {
403 *val.get_mut() = new_val;
405 let prev = self.to_container.remove(&old_val).unwrap().1;
406 self.to_container.insert(new_val, prev);
407 }
408 }
409 }
410 for (container, val) in to_reinsert {
411 self.insert_owned(container, val, exec_state);
412 }
413 changed
414 }
415
416 fn apply_rebuild_nonincremental_parallel(
417 &mut self,
418 rebuilder: &dyn Rebuilder,
419 exec_state: &mut ExecutionState,
420 ) -> bool {
421 let mut to_reinsert = IdVec::<usize , SegQueue<(C, Value)>>::default();
426 to_reinsert.resize_with(self.to_id.shards().len(), Default::default);
427
428 let shards = self.to_id.shards_mut();
429 let changed = shards
430 .par_iter_mut()
431 .map(|shard| {
432 let mut changed = false;
433 let shard = shard.get_mut();
434 for bucket in unsafe { shard.iter() } {
436 let (container, val) = unsafe { bucket.as_mut() };
438 let old_val = *val.get();
439 let new_val = rebuilder.rebuild_val(old_val);
440 let container_changed = container.rebuild_contents(rebuilder);
441 if !container_changed && new_val == old_val {
442 continue;
444 }
445 changed = true;
446 if container_changed {
447 let ((container, _), _) = unsafe { shard.remove(bucket) };
451 self.to_container.remove(&old_val);
452 let shard = self
457 .to_container
458 .determine_shard(hash_container(&container) as usize);
459 to_reinsert[shard].push((container, new_val));
460 } else {
461 *val.get_mut() = new_val;
463 let prev = self.to_container.remove(&old_val).unwrap().1;
464 self.to_container.insert(new_val, prev);
465 }
466 }
467 changed
468 })
469 .max()
470 .unwrap_or(false);
471
472 shards
473 .iter_mut()
474 .enumerate()
475 .map(|(i, shard)| (i, shard, exec_state.clone()))
476 .par_bridge()
477 .for_each(|(shard_id, shard, mut exec_state)| {
478 let shard = shard.get_mut();
484 let queue = &to_reinsert[shard_id];
485 while let Some((container, val)) = queue.pop() {
486 let hc = hash_container(&container);
487 let target_map = self.to_container.determine_shard(hc as usize);
488 match shard.find_or_find_insert_slot(
489 hc,
490 |(c, _)| c == &container,
491 |(c, _)| hash_container(c),
492 ) {
493 Ok(bucket) => {
494 let (container, val_slot) = unsafe { bucket.as_mut() };
497 let old_val = *val_slot.get();
498 let result = (self.merge_fn)(&mut exec_state, old_val, val);
499 if result != old_val {
500 self.to_container.remove(&old_val);
501 self.to_container.insert(result, (hc as usize, target_map));
502 *val_slot.get_mut() = result;
503 for val in container.iter() {
504 let mut index = self.val_index.entry(val).or_default();
505 index.swap_remove(&old_val);
506 index.insert(result);
507 }
508 }
509 }
510 Err(slot) => {
511 self.to_container.insert(val, (hc as usize, target_map));
512 for v in container.iter() {
513 self.val_index.entry(v).or_default().insert(val);
514 }
515 unsafe {
518 shard.insert_in_slot(hc, slot, (container, SharedValue::new(val)));
519 }
520 }
521 }
522 }
523 });
524 changed
525 }
526
527 fn get_container(&self, value: Value) -> Option<impl Deref<Target = C> + '_> {
528 let (hc, target_map) = *self.to_container.get(&value)?;
529 let shard = &self.to_id.shards()[target_map];
530 let read_guard = shard.read();
531 let val_ptr: *const (C, _) = shard
532 .read()
533 .find(hc as u64, |(_, v)| *v.get() == value)?
534 .as_ptr();
535 struct ValueDeref<'a, T, Guard> {
536 _guard: Guard,
537 data: &'a T,
538 }
539
540 impl<T, Guard> Deref for ValueDeref<'_, T, Guard> {
541 type Target = T;
542
543 fn deref(&self) -> &T {
544 self.data
545 }
546 }
547
548 Some(ValueDeref {
549 _guard: read_guard,
550 data: unsafe {
552 let unwrapped: &(C, _) = &*val_ptr;
553 &unwrapped.0
554 },
555 })
556 }
557}
558
559fn incremental_rebuild(_uf_size: usize, _table_size: usize, _parallel: bool) -> bool {
560 #[cfg(debug_assertions)]
561 {
562 use rand::Rng;
563 rand::rng().random_bool(0.5)
564 }
565 #[cfg(not(debug_assertions))]
566 {
567 if _parallel {
568 _table_size > 1000 && _uf_size * 512 <= _table_size
569 } else {
570 _table_size > 1000 && _uf_size * 8 <= _table_size
571 }
572 }
573}