selene-db-graph 1.3.0

In-memory property-graph storage core (ArcSwap + imbl CoW, label/typed indexes, write funnel) for selene-db.
Documentation
use super::*;

/// Approximate number of rows each parallel task covers in the fused slot-order batch
/// scan (rounded up to whole blocks by the caller).
///
/// The value is much smaller under `cfg(test)`, presumably so that small test indexes
/// still split into several parallel chunks.
const TURBO_QUANT_BATCH_PARALLEL_CHUNK_ENTRIES: usize = if cfg!(test) { 4 } else { 256 };

/// Per-query state shared by the LUT-based batch scans.
struct PreparedTurboQuantQuery {
    /// Starting value of every row's accumulated dot product for this query.
    query_bias: f64,
    /// Flattened lookup table: entry `byte * 256 + code` is the contribution added when
    /// byte position `byte` of a stored row holds `code`.
    byte_lut: Vec<f64>,
}

impl TurboQuantVectorIndex {
    /// Returns approximate candidates for every query in `queries`, in query order.
    ///
    /// Returns an empty outer vector for an empty batch, and one empty hit list per
    /// query when `k == 0` or the index has no live entries. When the batch cannot be
    /// fused (see [`Self::should_fuse_batch_scan`]), each query is delegated to the
    /// single-query `candidates` method. Otherwise all queries share one scan whose
    /// per-query top-k buffers are sized to `max(search_width, k)`, capped at
    /// `live_entries`.
    ///
    /// # Errors
    ///
    /// Only the non-fused path can fail: it propagates the first error returned by the
    /// single-query `candidates` call. The fused path always returns `Ok`.
    pub(crate) fn candidates_batch(
        &self,
        queries: &[VectorValue],
        k: usize,
        search_width: usize,
    ) -> CoreResult<Vec<Vec<TurboQuantVectorHit>>> {
        if queries.is_empty() {
            return Ok(Vec::new());
        }
        if k == 0 || self.live_entries == 0 {
            return Ok(vec![Vec::new(); queries.len()]);
        }
        if !self.should_fuse_batch_scan(queries.len()) {
            return queries
                .iter()
                .map(|query| self.candidates(query, k, search_width))
                .collect();
        }
        let candidate_limit = search_width.max(k).min(self.live_entries);
        // Both LUT-based scan paths need the same per-query preparation. It is only
        // invoked when one of those paths is taken, so a successful fast-scan
        // preparation never builds the byte lookup tables.
        let prepare_lut_queries = || {
            queries
                .iter()
                .map(|query| self.prepare_query(query))
                .collect::<Vec<_>>()
        };
        let candidates = if self.should_scan_by_slot_order() {
            match self.prepare_fast_scan_queries(queries) {
                Some(prepared) => {
                    self.slot_order_candidates_fast_scan_batch(&prepared, candidate_limit)
                }
                None => {
                    self.slot_order_candidates_batch(&prepare_lut_queries(), candidate_limit)
                }
            }
        } else {
            self.live_map_candidates_batch(&prepare_lut_queries(), candidate_limit)
        };

        Ok(candidates
            .into_iter()
            .map(|query_candidates| {
                query_candidates
                    .into_hits()
                    .into_iter()
                    .map(|hit| TurboQuantVectorHit {
                        row: hit.key,
                        distance: hit.distance,
                    })
                    .collect()
            })
            .collect())
    }

    /// Reports whether a batch of `query_count` queries should share one fused scan.
    ///
    /// Fusing requires at least two queries and support for the fast-scan accumulator.
    pub(crate) fn should_fuse_batch_scan(&self, query_count: usize) -> bool {
        query_count > 1 && self.supports_fast_scan_accumulator()
    }

    /// Rotates `query` into the index's space and precomputes its bias and byte
    /// lookup table for the LUT-based scans.
    fn prepare_query(&self, query: &VectorValue) -> PreparedTurboQuantQuery {
        let rotated_query = rotated_unit_vector(query, self.dimension);
        PreparedTurboQuantQuery {
            query_bias: query_bias(&rotated_query, &self.shift),
            byte_lut: self.byte_lut(&rotated_query),
        }
    }

    /// Scans every block in slot order for all `queries`.
    ///
    /// Uses the parallel scan when `should_parallelize_slot_scan(candidate_limit)`
    /// says so, and a single sequential pass otherwise.
    fn slot_order_candidates_batch(
        &self,
        queries: &[PreparedTurboQuantQuery],
        candidate_limit: usize,
    ) -> Vec<TurboQuantCandidateTopK> {
        if self.should_parallelize_slot_scan(candidate_limit) {
            return self.slot_order_candidates_batch_parallel(queries, candidate_limit);
        }
        self.slot_order_candidates_batch_blocks(
            0,
            self.codes.block_count(),
            queries,
            candidate_limit,
        )
    }

    /// Parallel version of the slot-order scan.
    ///
    /// Each task scans a contiguous run of whole blocks and produces its own
    /// per-query top-k. The partial results are then merged pairwise with
    /// `merge_candidate_top_k_batch`.
    fn slot_order_candidates_batch_parallel(
        &self,
        queries: &[PreparedTurboQuantQuery],
        candidate_limit: usize,
    ) -> Vec<TurboQuantCandidateTopK> {
        // Blocks per task: enough whole blocks to cover at least
        // `TURBO_QUANT_BATCH_PARALLEL_CHUNK_ENTRIES` rows (at least one block).
        let chunk_blocks = TURBO_QUANT_BATCH_PARALLEL_CHUNK_ENTRIES
            .div_ceil(TURBO_QUANT_BLOCK_ROWS)
            .max(1);
        let block_count = self.codes.block_count();
        // Each chunk's block range is computed directly, which gives the same contiguous
        // partition as rayon's `chunks` without allocating a `Vec` of block indices per
        // chunk. Only the last chunk may be shorter than `chunk_blocks`.
        (0..block_count.div_ceil(chunk_blocks))
            .into_par_iter()
            .map(|chunk| {
                let start_block = chunk * chunk_blocks;
                let end_block = (start_block + chunk_blocks).min(block_count);
                self.slot_order_candidates_batch_blocks(
                    start_block,
                    end_block,
                    queries,
                    candidate_limit,
                )
            })
            .reduce(
                || candidate_top_k_batch(queries.len(), candidate_limit),
                merge_candidate_top_k_batch,
            )
    }

    /// Scans blocks `start_block..end_block` for all `queries` at once.
    ///
    /// For each block, every query's per-lane dot product starts at its `query_bias`.
    /// It then accumulates, for each byte position, the LUT entry selected by that
    /// lane's code. Lanes whose slot holds no live row are skipped.
    fn slot_order_candidates_batch_blocks(
        &self,
        start_block: usize,
        end_block: usize,
        queries: &[PreparedTurboQuantQuery],
        candidate_limit: usize,
    ) -> Vec<TurboQuantCandidateTopK> {
        let mut candidates = candidate_top_k_batch(queries.len(), candidate_limit);
        // One scratch accumulator row per query, reused for every block.
        let mut dots = vec![[0.0; TURBO_QUANT_BLOCK_ROWS]; queries.len()];
        for block in start_block..end_block {
            let block_len = self.codes.block_len(block);
            for (query_dots, query) in dots.iter_mut().zip(queries) {
                query_dots[..block_len].fill(query.query_bias);
            }
            // Byte-major iteration: each byte column of the block is loaded once and
            // applied to every query before moving on to the next byte.
            for byte in 0..self.bytes_per_row {
                let lut_base = byte * 256;
                let codes = self.codes.block_byte(block, byte);
                for (query_dots, query) in dots.iter_mut().zip(queries) {
                    for lane in 0..block_len {
                        query_dots[lane] += query.byte_lut[lut_base + usize::from(codes[lane])];
                    }
                }
            }
            let base_slot = block * TURBO_QUANT_BLOCK_ROWS;
            for lane in 0..block_len {
                let slot = base_slot + lane;
                let Some(row) = self.live_row_at_slot(slot) else {
                    continue;
                };
                push_batch_block_distances(&mut candidates, &dots, lane, slot, row, self);
            }
        }
        candidates
    }

    /// Scans by walking `row_to_entry` instead of slot order.
    ///
    /// Entries whose slot is out of range in `rows`, or whose slot now stores a
    /// different row, are skipped. Presumably these are stale mappings; confirm
    /// against the writer.
    fn live_map_candidates_batch(
        &self,
        queries: &[PreparedTurboQuantQuery],
        candidate_limit: usize,
    ) -> Vec<TurboQuantCandidateTopK> {
        let mut candidates = candidate_top_k_batch(queries.len(), candidate_limit);
        // Scratch buffer reused for every row: one distance per query.
        let mut distances = vec![0.0; queries.len()];
        for (&row, &slot_key) in &self.row_to_entry {
            let slot = super::slot_index(slot_key);
            let Some(stored_row) = self.rows.get(slot).copied() else {
                continue;
            };
            if stored_row != row {
                continue;
            }
            self.approx_distances_lut_batch(slot, queries, &mut distances);
            push_batch_distances(&mut candidates, &distances, row);
        }
        candidates
    }

    /// Writes, for every query, the approximate distance of the row stored at `slot`
    /// into the matching entry of `distances`.
    ///
    /// The distance is `-(query_bias + Σ LUT[byte, code]) * row_scales[slot]`: the
    /// accumulated dot product is scaled and negated.
    ///
    /// # Panics
    ///
    /// Panics if a byte of `slot` is out of bounds in the code storage. Debug builds
    /// also assert that `distances` has one entry per query.
    fn approx_distances_lut_batch(
        &self,
        slot: usize,
        queries: &[PreparedTurboQuantQuery],
        distances: &mut [f64],
    ) {
        debug_assert_eq!(queries.len(), distances.len());
        for (distance, query) in distances.iter_mut().zip(queries) {
            *distance = query.query_bias;
        }

        for byte in 0..self.bytes_per_row {
            let lut_offset = byte * 256
                + usize::from(
                    self.codes
                        .row_byte(slot, byte)
                        .expect("TurboQuant slot byte is in bounds"),
                );
            for (distance, query) in distances.iter_mut().zip(queries) {
                *distance += query.byte_lut[lut_offset];
            }
        }

        let scale = f64::from(self.row_scales[slot]);
        for distance in distances {
            *distance = -(*distance * scale);
        }
    }
}

/// Builds `query_count` independent, empty top-`candidate_limit` accumulators, one per
/// query of a batch.
fn candidate_top_k_batch(
    query_count: usize,
    candidate_limit: usize,
) -> Vec<TurboQuantCandidateTopK> {
    std::iter::repeat_with(|| TurboQuantCandidateTopK::new(candidate_limit))
        .take(query_count)
        .collect()
}

/// Folds every hit of `rhs` into the matching per-query accumulator of `lhs` and
/// returns `lhs`.
///
/// This is the combine step of the parallel slot-order scan's `reduce`, so both sides
/// must describe the same batch of queries, index for index.
fn merge_candidate_top_k_batch(
    mut lhs: Vec<TurboQuantCandidateTopK>,
    rhs: Vec<TurboQuantCandidateTopK>,
) -> Vec<TurboQuantCandidateTopK> {
    // `zip` stops at the shorter side, so a length mismatch would silently drop every
    // candidate belonging to the surplus queries. Surface that in debug builds.
    debug_assert_eq!(
        lhs.len(),
        rhs.len(),
        "merged TurboQuant candidate batches must cover the same queries"
    );
    for (lhs_query, rhs_query) in lhs.iter_mut().zip(rhs) {
        for hit in rhs_query.into_hits() {
            lhs_query.push_distance(hit.key, hit.distance);
        }
    }
    lhs
}

/// Offers `row` to every query's accumulator. The `i`-th accumulator receives the
/// `i`-th entry of `distances`.
fn push_batch_distances(candidates: &mut [TurboQuantCandidateTopK], distances: &[f64], row: u32) {
    candidates
        .iter_mut()
        .zip(distances)
        .for_each(|(top_k, &distance)| top_k.push_distance(row, distance));
}

/// Offers `row` (stored at `slot`) to every query's accumulator, using the dot product
/// accumulated in `lane` of that query's block scratch row.
///
/// Each distance is `-(dot * row_scales[slot])`.
fn push_batch_block_distances(
    candidates: &mut [TurboQuantCandidateTopK],
    dots: &[[f64; TURBO_QUANT_BLOCK_ROWS]],
    lane: usize,
    slot: usize,
    row: u32,
    index: &TurboQuantVectorIndex,
) {
    let row_scale = f64::from(index.row_scales[slot]);
    let lane_dots = dots.iter().map(|query_dots| query_dots[lane]);
    candidates
        .iter_mut()
        .zip(lane_dots)
        .for_each(|(top_k, dot)| top_k.push_distance(row, -(dot * row_scale)));
}