selene-db-graph 1.2.0

In-memory property-graph storage core (ArcSwap + imbl CoW, label/typed indexes, write funnel) for selene-db.
Documentation
use rayon::prelude::*;
use selene_core::{CancellationChecker, DbString};

use crate::graph::SeleneGraph;
use crate::parallel_scan::should_parallelize_scan;

use super::{VectorCandidateSet, VectorNeighborDirection, VectorSearchError};

/// Minimum number of root sets a batch must contain before its expansion is considered for
/// parallel execution (passed as the threshold to `should_parallelize_scan`). Test builds use a
/// far smaller value, presumably so small fixtures can reach the parallel path.
const VECTOR_EXPANDED_BATCH_PARALLEL_MIN_SETS: usize = if cfg!(test) { 2 } else { 16 };

/// Minimum extrapolated candidate count for a batch expansion to run in parallel. Test builds
/// use a far smaller value.
const VECTOR_EXPANDED_BATCH_PARALLEL_MIN_CANDIDATES: usize = if cfg!(test) { 8 } else { 8192 };

/// Number of leading root sets sampled when extrapolating a batch's expansion work.
const VECTOR_EXPANDED_BATCH_PARALLEL_ESTIMATE_SETS: usize = 4;

/// Largest batch for which duplicate root sets are grouped. Grouping compares sets pairwise,
/// so larger batches skip it.
const VECTOR_EXPANDED_BATCH_GROUP_MAX_SETS: usize = 128;

impl SeleneGraph {
    /// Expands every root set in `root_sets` along `edge_label` edges in `direction` (via
    /// `expand_vector_candidate_set_checked`). Returns one expanded set per input set, in
    /// input order.
    ///
    /// Strategies, tried in turn:
    /// 1. If every root set matches the first (per `candidate_sets_match`), the set is
    ///    expanded once and the result is cloned into every output slot.
    /// 2. Otherwise, if `repeated_root_set_groups` finds repeated root sets, each repeated set
    ///    is expanded once and the remaining sets are expanded individually.
    /// 3. Otherwise each set is expanded individually. This runs on the rayon pool when
    ///    `should_parallelize_expanded_candidate_batch` judges the batch large enough.
    ///
    /// `checker` is polled before each expansion.
    ///
    /// # Errors
    /// Propagates an error from a cancellation check or from
    /// `expand_vector_candidate_set_checked`. On the parallel path, which error is reported
    /// when several occur is not tied to input order.
    pub(super) fn expand_vector_candidate_sets_batch(
        &self,
        root_sets: &[VectorCandidateSet],
        edge_label: &DbString,
        direction: VectorNeighborDirection,
        k: usize,
        checker: CancellationChecker<'_>,
    ) -> Result<Vec<VectorCandidateSet>, VectorSearchError> {
        // Fast path: a batch of identical root sets needs a single expansion.
        if let Some(first_roots) = root_sets.first()
            && root_sets
                .iter()
                .skip(1)
                .all(|roots| candidate_sets_match(first_roots, roots))
        {
            checker.check()?;
            let expanded = self.expand_vector_candidate_set_checked(
                first_roots,
                edge_label,
                direction,
                checker,
            )?;
            return Ok(vec![expanded; root_sets.len()]);
        }

        let groups = repeated_root_set_groups(root_sets);
        if !groups.is_empty() {
            return self.expand_vector_candidate_sets_batch_grouped(
                root_sets, edge_label, direction, k, checker, groups,
            );
        }

        if self.should_parallelize_expanded_candidate_batch(root_sets, edge_label, direction, k) {
            return root_sets
                .par_iter()
                .map(|roots| {
                    checker.check()?;
                    self.expand_vector_candidate_set_checked(roots, edge_label, direction, checker)
                })
                .collect();
        }

        // Sequential path: collecting into `Result` stops at the first error.
        root_sets
            .iter()
            .map(|roots| {
                checker.check()?;
                self.expand_vector_candidate_set_checked(roots, edge_label, direction, checker)
            })
            .collect()
    }

    /// Expands a batch in which some root sets repeat.
    ///
    /// Each group in `groups` lists indices of matching root sets, as produced by
    /// `repeated_root_set_groups`. A group is expanded once, from its first member, and the
    /// result is shared by every member. The remaining (ungrouped) sets are expanded
    /// individually. They run in parallel only when the remaining work alone is judged large
    /// enough.
    fn expand_vector_candidate_sets_batch_grouped(
        &self,
        root_sets: &[VectorCandidateSet],
        edge_label: &DbString,
        direction: VectorNeighborDirection,
        k: usize,
        checker: CancellationChecker<'_>,
        groups: Vec<Vec<usize>>,
    ) -> Result<Vec<VectorCandidateSet>, VectorSearchError> {
        let mut expanded_sets: Vec<Option<VectorCandidateSet>> = vec![None; root_sets.len()];
        for group in groups {
            checker.check()?;
            let mut members = group.into_iter();
            let Some(representative) = members.next() else {
                continue;
            };
            let expanded = self.expand_vector_candidate_set_checked(
                &root_sets[representative],
                edge_label,
                direction,
                checker,
            )?;
            // The other members get clones; the representative takes ownership last.
            for index in members {
                expanded_sets[index] = Some(expanded.clone());
            }
            expanded_sets[representative] = Some(expanded);
        }

        // Every slot still empty belongs to a root set that did not repeat.
        let ungrouped: Vec<usize> = expanded_sets
            .iter()
            .enumerate()
            .filter_map(|(index, slot)| slot.is_none().then_some(index))
            .collect();

        // Decide on parallelism from the work that is actually left, not the whole batch:
        // grouped sets were already expanded above.
        let parallel = !ungrouped.is_empty()
            && self.should_parallelize_expanded_candidate_subset(
                ungrouped.len(),
                ungrouped.iter().map(|&index| &root_sets[index]),
                edge_label,
                direction,
                k,
            );

        let expand_at = |index: usize| -> Result<(usize, VectorCandidateSet), VectorSearchError> {
            checker.check()?;
            self.expand_vector_candidate_set_checked(
                &root_sets[index],
                edge_label,
                direction,
                checker,
            )
            .map(|expanded| (index, expanded))
        };
        let expanded_ungrouped = if parallel {
            ungrouped
                .par_iter()
                .map(|&index| expand_at(index))
                .collect::<Result<Vec<_>, VectorSearchError>>()?
        } else {
            ungrouped
                .iter()
                .map(|&index| expand_at(index))
                .collect::<Result<Vec<_>, VectorSearchError>>()?
        };
        for (index, expanded) in expanded_ungrouped {
            expanded_sets[index] = Some(expanded);
        }

        Ok(expanded_sets
            .into_iter()
            .map(|expanded| expanded.expect("batched expansion fills every root slot"))
            .collect())
    }

    /// Decides whether expanding every set in `root_sets` is worth spreading across the
    /// rayon pool.
    ///
    /// Delegates to `should_parallelize_expanded_candidate_subset` with the whole batch.
    fn should_parallelize_expanded_candidate_batch(
        &self,
        root_sets: &[VectorCandidateSet],
        edge_label: &DbString,
        direction: VectorNeighborDirection,
        k: usize,
    ) -> bool {
        self.should_parallelize_expanded_candidate_subset(
            root_sets.len(),
            root_sets,
            edge_label,
            direction,
            k,
        )
    }

    /// Decides whether expanding `set_count` root sets is worth spreading across the rayon
    /// pool. `sets` should yield those same sets.
    ///
    /// The decision has two steps:
    /// 1. The count must pass `should_parallelize_scan` with
    ///    `VECTOR_EXPANDED_BATCH_PARALLEL_MIN_SETS` as the threshold.
    /// 2. The total work is extrapolated from the first
    ///    `VECTOR_EXPANDED_BATCH_PARALLEL_ESTIMATE_SETS` sets yielded by `sets` (see
    ///    `expanded_candidate_work_estimate`). It is then compared against
    ///    `VECTOR_EXPANDED_BATCH_PARALLEL_MIN_CANDIDATES`.
    ///
    /// Returns `false` for zero sets. This also keeps the extrapolation from dividing by zero.
    fn should_parallelize_expanded_candidate_subset<'a, I>(
        &self,
        set_count: usize,
        sets: I,
        edge_label: &DbString,
        direction: VectorNeighborDirection,
        k: usize,
    ) -> bool
    where
        I: IntoIterator<Item = &'a VectorCandidateSet>,
    {
        // Nothing to sample: without this guard `div_ceil(0)` below would panic.
        if set_count == 0 {
            return false;
        }
        if !should_parallelize_scan(
            set_count as u64,
            k,
            VECTOR_EXPANDED_BATCH_PARALLEL_MIN_SETS as u64,
        ) {
            return false;
        }
        let sample_count = set_count.min(VECTOR_EXPANDED_BATCH_PARALLEL_ESTIMATE_SETS);
        let sampled_candidates = sets
            .into_iter()
            .take(sample_count)
            .map(|roots| self.expanded_candidate_work_estimate(roots, edge_label, direction))
            .sum::<usize>();
        let estimated_candidates = sampled_candidates
            .saturating_mul(set_count)
            .div_ceil(sample_count);

        estimated_candidates >= VECTOR_EXPANDED_BATCH_PARALLEL_MIN_CANDIDATES
    }

    /// Estimates the work of expanding `roots`. The estimate is the number of roots plus, for
    /// each root, the number of its `edge_label` edges in the requested direction(s).
    fn expanded_candidate_work_estimate(
        &self,
        roots: &VectorCandidateSet,
        edge_label: &DbString,
        direction: VectorNeighborDirection,
    ) -> usize {
        // Loop-invariant: which adjacency lists to consult.
        let count_outgoing = matches!(
            direction,
            VectorNeighborDirection::Outgoing | VectorNeighborDirection::Both
        );
        let count_incoming = matches!(
            direction,
            VectorNeighborDirection::Incoming | VectorNeighborDirection::Both
        );

        let mut candidate_count = roots.len();
        for root in roots.as_nodes().iter().copied() {
            if count_outgoing && let Some(entry) = self.outgoing_edges(root) {
                candidate_count += entry.iter_label(edge_label).count();
            }
            if count_incoming && let Some(entry) = self.incoming_edges(root) {
                candidate_count += entry.iter_label(edge_label).count();
            }
        }
        candidate_count
    }
}

/// Returns `true` when the `as_nodes()` views of the two candidate sets compare equal.
///
/// The length and endpoint comparisons run first as cheap rejects. Only when they all agree
/// is the full comparison performed.
fn candidate_sets_match(lhs: &VectorCandidateSet, rhs: &VectorCandidateSet) -> bool {
    let (left, right) = (lhs.as_nodes(), rhs.as_nodes());
    let same_shape =
        left.len() == right.len() && left.first() == right.first() && left.last() == right.last();
    same_shape && left == right
}

/// Groups the indices of root sets that occur more than once in `root_sets`.
///
/// Each returned group starts with its lowest index (the anchor). It is followed, in
/// ascending order, by every later, not-yet-grouped index whose set matches the anchor's
/// (per `candidate_sets_match`). Groups appear in anchor order, and sets that occur only
/// once are omitted.
///
/// Grouping is skipped (an empty vector is returned) for two or fewer sets, or for more than
/// `VECTOR_EXPANDED_BATCH_GROUP_MAX_SETS` sets. The latter bounds the pairwise comparison
/// below.
fn repeated_root_set_groups(root_sets: &[VectorCandidateSet]) -> Vec<Vec<usize>> {
    let total = root_sets.len();
    if !(3..=VECTOR_EXPANDED_BATCH_GROUP_MAX_SETS).contains(&total) {
        return Vec::new();
    }

    let mut claimed = vec![false; total];
    let mut groups = Vec::new();
    for (anchor, anchor_set) in root_sets.iter().enumerate() {
        if claimed[anchor] {
            continue;
        }
        let duplicates: Vec<usize> = (anchor + 1..total)
            .filter(|&other| {
                !claimed[other] && candidate_sets_match(anchor_set, &root_sets[other])
            })
            .collect();
        if duplicates.is_empty() {
            continue;
        }
        for &index in &duplicates {
            claimed[index] = true;
        }
        groups.push(std::iter::once(anchor).chain(duplicates).collect());
    }
    groups
}