egglog_core_relations/containers/
mod.rs

//! Support for containers
//!
//! Containers behave a lot like base values. They are implemented differently because
//! their ids share a space with other ids in the egraph and, as a result, their ids need to be
//! sparse.
//!
//! This is a relatively "eager" implementation of containers, reflecting egglog's current
//! semantics. One could imagine a variant of containers in which they behave more like egglog
//! functions than base values.
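//!
//! A minimal end-to-end sketch (`VecContainer` is a hypothetical container type, and
//! `id_counter`, `exec_state`, `x`, and `y` come from the surrounding database; see
//! [`ContainerValue`] for a sketch of the container definition itself):
//!
//! ```ignore
//! let mut containers = ContainerValues::new();
//! containers.register_type::<VecContainer>(id_counter, |_: &mut ExecutionState, a, b| {
//!     // Merge conflicting ids by deterministically picking one.
//!     std::cmp::min(a, b)
//! });
//! let v = containers.register_val(VecContainer(vec![x, y]), exec_state);
//! assert!(containers.get_val::<VecContainer>(v).is_some());
//! ```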

use std::{
    any::{Any, TypeId},
    hash::{Hash, Hasher},
    ops::Deref,
};

use crate::numeric_id::{DenseIdMap, IdVec, NumericId, define_id};
use crossbeam_queue::SegQueue;
use dashmap::SharedValue;
use rayon::{
    iter::{ParallelBridge, ParallelIterator},
    prelude::*,
};
use rustc_hash::FxHasher;

use crate::{
    ColumnId, CounterId, ExecutionState, Offset, SubsetRef, TableId, TaggedRowBuffer, Value,
    WrappedTable,
    common::{DashMap, IndexSet, InternTable, SubsetTracker},
    parallel_heuristics::{parallelize_inter_container_op, parallelize_intra_container_op},
    table_spec::Rebuilder,
};

#[cfg(test)]
mod tests;

define_id!(pub ContainerValueId, u32, "an identifier for containers");

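/// A function used to merge the ids of two containers that have become equal, for example
/// after a rebuild canonicalizes their contents.
///
/// A minimal sketch of a merge function (picking the smaller id is just an illustration,
/// not a policy prescribed by this module):
///
/// ```ignore
/// let merge = |_exec_state: &mut ExecutionState, a: Value, b: Value| -> Value {
///     // Deterministically pick one of the two conflicting ids.
///     std::cmp::min(a, b)
/// };
/// ```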
pub trait MergeFn:
    Fn(&mut ExecutionState, Value, Value) -> Value + dyn_clone::DynClone + Send + Sync
{
}
impl<T: Fn(&mut ExecutionState, Value, Value) -> Value + Clone + Send + Sync> MergeFn for T {}

// Implements `Clone` for `Box<dyn MergeFn>`.
dyn_clone::clone_trait_object!(MergeFn);

#[derive(Clone, Default)]
pub struct ContainerValues {
    /// Tracks recently-updated subsets of tables, used to drive incremental rebuilds.
    subset_tracker: SubsetTracker,
    /// Maps a container type's `TypeId` to its registered [`ContainerValueId`].
    container_ids: InternTable<TypeId, ContainerValueId>,
    /// The per-type container environments, indexed by [`ContainerValueId`].
    data: DenseIdMap<ContainerValueId, Box<dyn DynamicContainerEnv + Send + Sync>>,
}

impl ContainerValues {
    pub fn new() -> Self {
        Default::default()
    }

    fn get<C: ContainerValue>(&self) -> Option<&ContainerEnv<C>> {
        let id = self.container_ids.intern(&TypeId::of::<C>());
        let res = self.data.get(id)?.as_any();
        Some(res.downcast_ref::<ContainerEnv<C>>().unwrap())
    }

    /// Iterate over the containers of the given type.
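    ///
    /// A small usage sketch (`VecContainer` is a hypothetical container type):
    ///
    /// ```ignore
    /// containers.for_each::<VecContainer>(|container, id| {
    ///     println!("{id:?} has {} entries", container.0.len());
    /// });
    /// ```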
    pub fn for_each<C: ContainerValue>(&self, mut f: impl FnMut(&C, Value)) {
        let Some(env) = self.get::<C>() else {
            return;
        };
        for ent in env.to_id.iter() {
            f(ent.key(), *ent.value());
        }
    }

    /// Get the container associated with the value `val` in the database. The caller must know the
    /// type of the container.
    ///
    /// The return type of this function may contain lock guards. Attempts to modify the contents
    /// of the containers database may deadlock if the given guard has not been dropped.
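    ///
    /// A small usage sketch (`VecContainer` is a hypothetical container type and `v` a value
    /// previously returned by [`ContainerValues::register_val`]):
    ///
    /// ```ignore
    /// if let Some(guard) = containers.get_val::<VecContainer>(v) {
    ///     // `guard` derefs to the container. Drop it before mutating the database.
    ///     assert_eq!(guard.0.len(), 2);
    /// }
    /// ```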
    pub fn get_val<C: ContainerValue>(&self, val: Value) -> Option<impl Deref<Target = C> + '_> {
        self.get::<C>()?.get_container(val)
    }

    /// Register the given container, returning the value associated with it and minting a
    /// fresh id if the container has not been seen before.
    pub fn register_val<C: ContainerValue>(
        &self,
        container: C,
        exec_state: &mut ExecutionState,
    ) -> Value {
        let env = self
            .get::<C>()
            .expect("must register container type before registering a value");
        env.get_or_insert(&container, exec_state)
    }

    /// Apply the given rebuild to the contents of each container.
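    ///
    /// Returns `true` if any container was changed as a result.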
    pub fn rebuild_all(
        &mut self,
        table_id: TableId,
        table: &WrappedTable,
        exec_state: &mut ExecutionState,
    ) -> bool {
        let Some(rebuilder) = table.rebuilder(&[]) else {
            return false;
        };
        let to_scan = rebuilder.hint_col().map(|_| {
            // We may attempt an incremental rebuild.
            self.subset_tracker.recent_updates(table_id, table)
        });
        if parallelize_inter_container_op(self.data.next_id().index()) {
            self.data
                .iter_mut()
                .zip(std::iter::repeat_with(|| exec_state.clone()))
                .par_bridge()
                .map(|((_, env), mut exec_state)| {
                    env.apply_rebuild(
                        table,
                        &*rebuilder,
                        to_scan.as_ref().map(|x| x.as_ref()),
                        &mut exec_state,
                    )
                })
                .max()
                .unwrap_or(false)
        } else {
            let mut changed = false;
            for (_, env) in self.data.iter_mut() {
                changed |= env.apply_rebuild(
                    table,
                    &*rebuilder,
                    to_scan.as_ref().map(|x| x.as_ref()),
                    exec_state,
                );
            }
            changed
        }
    }

    /// Add a new container type to the given [`ContainerValues`] instance.
    ///
    /// Container types need a means of generating fresh ids (`id_counter`) along with a means of
    /// merging conflicting ids (`merge_fn`).
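    ///
    /// A small usage sketch (`VecContainer` and `id_counter` are hypothetical; see the
    /// module-level example):
    ///
    /// ```ignore
    /// let id = containers.register_type::<VecContainer>(
    ///     id_counter,
    ///     |_exec_state: &mut ExecutionState, a: Value, b: Value| std::cmp::min(a, b),
    /// );
    /// ```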
    pub fn register_type<C: ContainerValue>(
        &mut self,
        id_counter: CounterId,
        merge_fn: impl MergeFn + 'static,
    ) -> ContainerValueId {
        let id = self.container_ids.intern(&TypeId::of::<C>());
        self.data.get_or_insert(id, || {
            Box::new(ContainerEnv::<C>::new(Box::new(merge_fn), id_counter))
        });
        id
    }
}

/// A trait implemented by container types.
///
/// Containers behave a lot like base values, but they include extra trait methods to support
/// rebuilding of container contents and merging containers that become equal after a rebuild pass
/// has taken place.
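///
/// A minimal sketch of a container wrapping a vector of values (`VecContainer` is a
/// hypothetical type, not part of this module):
///
/// ```ignore
/// #[derive(Clone, PartialEq, Eq, Hash)]
/// struct VecContainer(Vec<Value>);
///
/// impl ContainerValue for VecContainer {
///     fn rebuild_contents(&mut self, rebuilder: &dyn Rebuilder) -> bool {
///         let mut changed = false;
///         for v in &mut self.0 {
///             // Canonicalize each value, recording whether anything changed.
///             let new = rebuilder.rebuild_val(*v);
///             changed |= new != *v;
///             *v = new;
///         }
///         changed
///     }
///
///     fn iter(&self) -> impl Iterator<Item = Value> + '_ {
///         self.0.iter().copied()
///     }
/// }
/// ```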
pub trait ContainerValue: Hash + Eq + Clone + Send + Sync + 'static {
    /// Rebuild the container in place according to the given [`Rebuilder`].
    ///
    /// If this method returns `false` then the container must not have been modified (i.e. it must
    /// hash to the same value, and compare equal to a copy of itself before the call).
    fn rebuild_contents(&mut self, rebuilder: &dyn Rebuilder) -> bool;

    /// Iterate over the contents of the container.
    ///
    /// Note that containers can be more structured than just a sequence of values. This iterator
    /// is used to populate an index that in turn is used to speed up rebuilds. If a value in the
    /// container is eligible for a rebuild and it is not mentioned by this iterator, the outer
    /// [`ContainerValues`] registry may skip rebuilding this container.
    fn iter(&self) -> impl Iterator<Item = Value> + '_;
}

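/// An object-safe wrapper around [`ContainerEnv`], allowing [`ContainerValues`] to store
/// environments for different container types uniformly and recover the concrete type via
/// `as_any`.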
pub trait DynamicContainerEnv: Any + dyn_clone::DynClone + Send + Sync {
    fn as_any(&self) -> &dyn Any;
    fn apply_rebuild(
        &mut self,
        table: &WrappedTable,
        rebuilder: &dyn Rebuilder,
        subset: Option<SubsetRef>,
        exec_state: &mut ExecutionState,
    ) -> bool;
}

// Implements `Clone` for `Box<dyn DynamicContainerEnv>`.
dyn_clone::clone_trait_object!(DynamicContainerEnv);

fn hash_container(container: &impl ContainerValue) -> u64 {
    let mut hasher = FxHasher::default();
    container.hash(&mut hasher);
    hasher.finish()
}

#[derive(Clone)]
struct ContainerEnv<C: Eq + Hash> {
    /// Resolves conflicts when two distinct ids map to the same container.
    merge_fn: Box<dyn MergeFn>,
    /// The counter used to mint fresh container ids.
    counter: CounterId,
    /// Map from a container to its id.
    to_id: DashMap<C, Value>,
    /// Map from an id to its container's hash code and the `to_id` shard holding it.
    to_container: DashMap<Value, (usize /* hash code */, usize /* map */)>,
    /// Map from a Value to the set of ids of containers that contain that value.
    val_index: DashMap<Value, IndexSet<Value>>,
}

impl<C: ContainerValue> DynamicContainerEnv for ContainerEnv<C> {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn apply_rebuild(
        &mut self,
        table: &WrappedTable,
        rebuilder: &dyn Rebuilder,
        subset: Option<SubsetRef>,
        exec_state: &mut ExecutionState,
    ) -> bool {
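        // Take the incremental path only when the table told us which column holds
        // container values (i.e. we were handed a subset to scan) and the heuristic judges
        // the update set small enough relative to the map.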
        if let Some(subset) = subset {
            if incremental_rebuild(
                subset.size(),
                self.to_id.len(),
                parallelize_intra_container_op(self.to_id.len()),
            ) {
                return self.apply_rebuild_incremental(
                    table,
                    rebuilder,
                    exec_state,
                    subset,
                    rebuilder.hint_col().unwrap(),
                );
            }
        }
        self.apply_rebuild_nonincremental(rebuilder, exec_state)
    }
}

impl<C: ContainerValue> ContainerEnv<C> {
    pub fn new(merge_fn: Box<dyn MergeFn>, counter: CounterId) -> Self {
        Self {
            merge_fn,
            counter,
            to_id: DashMap::default(),
            to_container: DashMap::default(),
            val_index: DashMap::default(),
        }
    }

    fn get_or_insert(&self, container: &C, exec_state: &mut ExecutionState) -> Value {
        if let Some(value) = self.to_id.get(container) {
            return *value;
        }

        // Time to insert a new mapping. First, insert into `to_container`: the moment that we
        // insert a new value into `to_id`, someone else can return it from another call to
        // `get_or_insert` and then feed that value to `get_container`.

        let value = Value::from_usize(exec_state.inc_counter(self.counter));
        let target_map = self.to_id.determine_map(container);
        // This assertion is here because parallel rebuilding uses `to_container` to compute
        // the intended shard for `to_id`: it holds a mutable borrow of `to_id` at that
        // point, so it cannot call `determine_map` on `to_id` directly.
        debug_assert_eq!(
            target_map,
            self.to_container
                .determine_shard(hash_container(container) as usize)
        );
        self.to_container
            .insert(value, (hash_container(container) as usize, target_map));

        // Now insert into `to_id`, handling the case where a different thread is doing the same
        // thing.
        match self.to_id.entry(container.clone()) {
            dashmap::Entry::Vacant(vac) => {
                // Common case: insert the mapping into `to_id` and update the index.
                vac.insert(value);
                for val in container.iter() {
                    self.val_index.entry(val).or_default().insert(value);
                }
                value
            }
            dashmap::Entry::Occupied(occ) => {
                // Someone inserted `container` into the mapping since we looked it up. Remove the
                // mapping that we inserted into `to_container` (we won't use it), and instead
                // return the "winning" value.
                let res = *occ.get();
                std::mem::drop(occ); // drop the lock.
                self.to_container.remove(&value);
                res
            }
        }
    }

    /// Insert `container` under the id `value`, running the merge function if an equal
    /// container is already mapped to a different id.
    fn insert_owned(&self, container: C, value: Value, exec_state: &mut ExecutionState) {
        let hc = hash_container(&container);
        let target_map = self.to_id.determine_map(&container);
        match self.to_id.entry(container) {
            dashmap::Entry::Occupied(mut occ) => {
                let result = (self.merge_fn)(exec_state, *occ.get(), value);
                let old_val = *occ.get();
                if result != old_val {
                    self.to_container.remove(&old_val);
                    self.to_container.insert(result, (hc as usize, target_map));
                    *occ.get_mut() = result;
                    for val in occ.key().iter() {
                        let mut index = self.val_index.entry(val).or_default();
                        index.swap_remove(&old_val);
                        index.insert(result);
                    }
                }
            }
            dashmap::Entry::Vacant(vacant_entry) => {
                self.to_container.insert(value, (hc as usize, target_map));
                for val in vacant_entry.key().iter() {
                    self.val_index.entry(val).or_default().insert(value);
                }
                vacant_entry.insert(value);
            }
        }
    }

    fn apply_rebuild_incremental(
        &mut self,
        table: &WrappedTable,
        rebuilder: &dyn Rebuilder,
        exec_state: &mut ExecutionState,
        to_scan: SubsetRef,
        search_col: ColumnId,
    ) -> bool {
        // NB: there is no parallel implementation as of now.
        //
        // Implementing one should be straightforward, but we should wait for a real benchmark that
        // requires it. It's possible that incremental rebuilding will only be profitable when the
        // total number of ids to rebuild is small, in which case the overhead of parallelism may
        // not be worth it in the first place.
        let mut changed = false;
        let mut buf = TaggedRowBuffer::new(1);
        table.scan_project(
            to_scan,
            &[search_col],
            Offset::new(0),
            usize::MAX,
            &[],
            &mut buf,
        );
        // For each value in the buffer, rebuild all containers that mention it.
        let mut to_rebuild = IndexSet::<Value>::default();
        for (_, row) in buf.iter() {
            to_rebuild.insert(row[0]);
            let Some(ids) = self.val_index.get(&row[0]) else {
                continue;
            };
            to_rebuild.extend(&*ids);
        }
        for id in to_rebuild {
            let Some((hc, target_map)) = self.to_container.get(&id).map(|x| *x) else {
                continue;
            };
            let shard_mut = self.to_id.shards_mut()[target_map].get_mut();
            let Some((mut container, _)) =
                shard_mut.remove_entry(hc as u64, |(_, v)| *v.get() == id)
            else {
                continue;
            };
            changed |= container.rebuild_contents(rebuilder);
            self.insert_owned(container, id, exec_state);
        }
        changed
    }

    fn apply_rebuild_nonincremental(
        &mut self,
        rebuilder: &dyn Rebuilder,
        exec_state: &mut ExecutionState,
    ) -> bool {
        if parallelize_intra_container_op(self.to_id.len()) {
            return self.apply_rebuild_nonincremental_parallel(rebuilder, exec_state);
        }
        let mut changed = false;
        let mut to_reinsert = Vec::new();
        let shards = self.to_id.shards_mut();
        for shard in shards.iter_mut() {
            let shard = shard.get_mut();
            // SAFETY: the iterator does not outlive `shard`.
            for bucket in unsafe { shard.iter() } {
                // SAFETY: the bucket is valid; we just got it from the iterator.
                let (container, val) = unsafe { bucket.as_mut() };
                let old_val = *val.get();
                let new_val = rebuilder.rebuild_val(old_val);
                let container_changed = container.rebuild_contents(rebuilder);
                if !container_changed && new_val == old_val {
                    // Nothing changed about this entry. Leave it in place.
                    continue;
                }
                changed = true;
                if container_changed {
                    // The container changed. Remove both map entries then reinsert.
                    // SAFETY: This is a valid bucket. Furthermore, iterators remain valid if
                    // buckets they have already yielded have been removed.
                    let ((container, _), _) = unsafe { shard.remove(bucket) };
                    self.to_container.remove(&old_val);
                    to_reinsert.push((container, new_val));
                } else {
                    // Just the value changed. Leave the container in place.
                    *val.get_mut() = new_val;
                    let prev = self.to_container.remove(&old_val).unwrap().1;
                    self.to_container.insert(new_val, prev);
                }
            }
        }
        for (container, val) in to_reinsert {
            self.insert_owned(container, val, exec_state);
        }
        changed
    }

    fn apply_rebuild_nonincremental_parallel(
        &mut self,
        rebuilder: &dyn Rebuilder,
        exec_state: &mut ExecutionState,
    ) -> bool {
        // This is very similar to the serial variant. The main difference is that
        // `to_reinsert` isn't a flat vector. It's instead a vector of queues, one per
        // destination map shard. This lets us do a bulk insertion in parallel without having
        // to grab a lock per container.
        let mut to_reinsert = IdVec::<usize /* to_id shard */, SegQueue<(C, Value)>>::default();
        to_reinsert.resize_with(self.to_id.shards().len(), Default::default);

        let shards = self.to_id.shards_mut();
        let changed = shards
            .par_iter_mut()
            .map(|shard| {
                let mut changed = false;
                let shard = shard.get_mut();
                // SAFETY: the iterator does not outlive `shard`.
                for bucket in unsafe { shard.iter() } {
                    // SAFETY: the bucket is valid; we just got it from the iterator.
                    let (container, val) = unsafe { bucket.as_mut() };
                    let old_val = *val.get();
                    let new_val = rebuilder.rebuild_val(old_val);
                    let container_changed = container.rebuild_contents(rebuilder);
                    if !container_changed && new_val == old_val {
                        // Nothing changed about this entry. Leave it in place.
                        continue;
                    }
                    changed = true;
                    if container_changed {
                        // The container changed. Remove both map entries then reinsert.
                        // SAFETY: This is a valid bucket. Furthermore, iterators remain valid if
                        // buckets they have already yielded have been removed.
                        let ((container, _), _) = unsafe { shard.remove(bucket) };
                        self.to_container.remove(&old_val);
                        // Spooky: we're using `to_container` to determine the shard for
                        // `to_id`. We are assuming that the shard determination is
                        // deterministic here. There is a debug assertion in `get_or_insert`
                        // that attempts to verify this.
                        let shard = self
                            .to_container
                            .determine_shard(hash_container(&container) as usize);
                        to_reinsert[shard].push((container, new_val));
                    } else {
                        // Just the value changed. Leave the container in place.
                        *val.get_mut() = new_val;
                        let prev = self.to_container.remove(&old_val).unwrap().1;
                        self.to_container.insert(new_val, prev);
                    }
                }
                changed
            })
            .max()
            .unwrap_or(false);

        shards
            .iter_mut()
            .enumerate()
            .map(|(i, shard)| (i, shard, exec_state.clone()))
            .par_bridge()
            .for_each(|(shard_id, shard, mut exec_state)| {
                // This bit is a real slog. Once DashMap updates from `RawTable` to `HashTable`
                // for the underlying shard, this will get a little better.
                //
                // NB: We are probably leaving some parallelism on the floor with these calls
                // to `to_container` and `val_index`.
                let shard = shard.get_mut();
                let queue = &to_reinsert[shard_id];
                while let Some((container, val)) = queue.pop() {
                    let hc = hash_container(&container);
                    let target_map = self.to_container.determine_shard(hc as usize);
                    match shard.find_or_find_insert_slot(
                        hc,
                        |(c, _)| c == &container,
                        |(c, _)| hash_container(c),
                    ) {
                        Ok(bucket) => {
                            // SAFETY: the bucket is valid; we just got it from the shard and
                            // we have not done any operations that can invalidate the bucket.
                            let (container, val_slot) = unsafe { bucket.as_mut() };
                            let old_val = *val_slot.get();
                            let result = (self.merge_fn)(&mut exec_state, old_val, val);
                            if result != old_val {
                                self.to_container.remove(&old_val);
                                self.to_container.insert(result, (hc as usize, target_map));
                                *val_slot.get_mut() = result;
                                for val in container.iter() {
                                    let mut index = self.val_index.entry(val).or_default();
                                    index.swap_remove(&old_val);
                                    index.insert(result);
                                }
                            }
                        }
                        Err(slot) => {
                            self.to_container.insert(val, (hc as usize, target_map));
                            for v in container.iter() {
                                self.val_index.entry(v).or_default().insert(val);
                            }
                            // SAFETY: We just got this slot from `find_or_find_insert_slot`
                            // and we have not mutated the map at all since then.
                            unsafe {
                                shard.insert_in_slot(hc, slot, (container, SharedValue::new(val)));
                            }
                        }
                    }
                }
            });
        changed
    }

    fn get_container(&self, value: Value) -> Option<impl Deref<Target = C> + '_> {
        let (hc, target_map) = *self.to_container.get(&value)?;
        let shard = &self.to_id.shards()[target_map];
        let read_guard = shard.read();
        let val_ptr: *const (C, _) = read_guard
            .find(hc as u64, |(_, v)| *v.get() == value)?
            .as_ptr();
        // A smart pointer that keeps the shard's read lock alive while handing out a
        // reference into it.
        struct ValueDeref<'a, T, Guard> {
            _guard: Guard,
            data: &'a T,
        }

        impl<T, Guard> Deref for ValueDeref<'_, T, Guard> {
            type Target = T;

            fn deref(&self) -> &T {
                self.data
            }
        }

        Some(ValueDeref {
            _guard: read_guard,
            // SAFETY: the value will remain valid for as long as `read_guard` is in scope.
            data: unsafe {
                let unwrapped: &(C, _) = &*val_ptr;
                &unwrapped.0
            },
        })
    }
}

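/// A heuristic for when an incremental rebuild is likely to pay off: the table must be
/// reasonably large and the set of updated ids small relative to it, with a much higher
/// bar when the nonincremental path would otherwise run in parallel.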
fn incremental_rebuild(_uf_size: usize, _table_size: usize, _parallel: bool) -> bool {
    #[cfg(debug_assertions)]
    {
        // In debug builds, flip a coin so that tests exercise both the incremental and
        // nonincremental code paths.
        use rand::Rng;
        rand::rng().random_bool(0.5)
    }
    #[cfg(not(debug_assertions))]
    {
        if _parallel {
            _table_size > 1000 && _uf_size * 512 <= _table_size
        } else {
            _table_size > 1000 && _uf_size * 8 <= _table_size
        }
    }
}