swh-graph-stdlib 13.0.0

Library of algorithms and data structures for swh-graph
Documentation
// Copyright (C) 2026  The Software Heritage developers
// See the AUTHORS file at the top-level directory of this distribution
// License: GNU General Public License version 3, or any later version
// See top-level LICENSE file for more information

use std::borrow::{Borrow, BorrowMut};
use std::collections::HashSet;

use rayon::iter::ParallelIterator;
use swh_graph_stdlib::labeling::{
    BoxLabel, Labels, SliceLabel, StridableLabel, StriddenLabels, StriddenLabelsConfig,
};

/// Builds a [`HashSet`] from a comma-separated list of expressions.
///
/// A trailing comma is accepted, and equal values collapse into one entry.
macro_rules! set {
    ($($item:expr),* $(,)?) => {{
        HashSet::from([$($item),*])
    }};
}

/********************************************************
 * SliceLabels
 ********************************************************/

/// Copies `words` into a freshly allocated, owned [`BoxLabel`].
fn make_label(words: &[u64]) -> BoxLabel<u64> {
    let owned: Box<[u64]> = Box::from(words);
    BoxLabel(owned)
}

/// Creates a `u64` slice-label store for `num_nodes` nodes, configured with
/// `num_words` words.
fn new_store(num_nodes: usize, num_words: usize) -> StriddenLabels<SliceLabel<u64>> {
    let config = StriddenLabelsConfig { num_words };
    StriddenLabels::new(num_nodes, config)
}

/// Inserted labels can be read back, and nodes that were never labelled
/// report no label.
#[test]
fn test_insert_and_get() {
    let mut labels = new_store(5, 2);
    assert!(labels.is_empty());

    // The first insertion on a node has no previous label to hand back.
    let previous = labels.insert(0, make_label(&[10, 20]));
    assert!(previous.is_none());
    assert!(!labels.is_empty());

    let previous = labels.insert(3, make_label(&[30, 40]));
    assert!(previous.is_none());

    assert_eq!(labels.get(0).unwrap().0, [10, 20]);
    assert_eq!(labels.get(3).unwrap().0, [30, 40]);
    for unlabelled in [1, 2, 4] {
        assert_eq!(labels.get(unlabelled), None);
    }

    let expected_presence = [(0, true), (1, false), (2, false), (3, true), (4, false)];
    for (node, present) in expected_presence {
        assert_eq!(labels.contains_key(node), present);
    }
}

/// Re-inserting on a labelled node returns the displaced label and stores
/// the new one.
#[test]
fn test_insert_overwrites() {
    let mut labels = new_store(3, 2);

    let first = labels.insert(1, make_label(&[1, 2]));
    assert!(first.is_none());

    // Each step overwrites node 1 and must hand back what the previous step stored.
    let steps = [([3u64, 4], [1u64, 2]), ([5u64, 6], [3u64, 4])];
    for (incoming, displaced) in steps {
        let previous = labels.insert(1, make_label(&incoming)).unwrap();
        assert_eq!(previous.0.as_ref(), displaced);
        assert_eq!(labels.get(1).unwrap().0, incoming);
    }

    assert!(!labels.is_empty());
}

/// Removing a label returns it exactly once and leaves other nodes untouched.
#[test]
fn test_remove() {
    let mut labels = new_store(4, 2);

    labels.insert(1, make_label(&[10, 20]));
    labels.insert(2, make_label(&[30, 40]));

    // Node 0 was never labelled, so there is nothing to take out.
    assert!(labels.remove(0).is_none());

    let taken = labels.remove(1).unwrap();
    assert_eq!(taken.0.as_ref(), [10, 20]);
    assert!(labels.get(1).is_none());
    assert!(!labels.contains_key(1));
    // A second removal of the same node finds nothing.
    assert!(labels.remove(1).is_none());

    // Node 2 still holds its label, so the store is not empty yet.
    assert_eq!(labels.get(2).unwrap().0, [30, 40]);
    assert!(!labels.is_empty());

    let taken = labels.remove(2).unwrap();
    assert_eq!(taken.0.as_ref(), [30, 40]);
    assert!(labels.is_empty());
}

/// `iter` yields one slot per node, while `iter_labeled` and
/// `par_iter_labeled` yield only the labelled nodes.
#[test]
fn test_iterators() {
    let mut labels = new_store(4, 2);
    labels.insert(1, make_label(&[1, 2]));
    labels.insert(3, make_label(&[5, 6]));

    // One entry per node, `None` where no label was inserted.
    let per_node: Vec<Option<&[u64]>> = labels
        .iter()
        .map(|slot| slot.map(|label| &label.0))
        .collect();
    assert_eq!(
        per_node,
        vec![None, Some(&[1u64, 2][..]), None, Some(&[5u64, 6][..])]
    );

    let sequential: HashSet<(usize, &[u64])> = labels
        .iter_labeled()
        .map(|(node, label)| (node, &label.0))
        .collect();
    assert_eq!(sequential, set![(1, &[1u64, 2][..]), (3, &[5u64, 6][..])]);

    // The parallel iterator must produce the same pairs as the sequential one.
    let parallel: HashSet<(usize, &[u64])> = labels
        .par_iter_labeled()
        .map(|(node, label)| (node, &label.0))
        .collect();
    assert_eq!(parallel, sequential);
}

/// Same round trip as the other tests, but with three words per label.
#[test]
fn test_stride_3_words() {
    let mut labels = new_store(3, 3);

    labels.insert(0, make_label(&[100, 200, 300]));
    labels.insert(2, make_label(&[400, 500, 600]));

    let first = labels.get(0).unwrap();
    assert_eq!(first.0, [100, 200, 300]);
    assert_eq!(labels.get(1), None);
    let last = labels.get(2).unwrap();
    assert_eq!(last.0, [400, 500, 600]);

    let pairs: HashSet<(usize, &[u64])> = labels
        .iter_labeled()
        .map(|(node, label)| (node, &label.0))
        .collect();
    assert_eq!(
        pairs,
        set![(0, &[100u64, 200, 300][..]), (2, &[400u64, 500, 600][..]),]
    );
}

/********************************************************
 * custom label
 ********************************************************/

/// Custom test label viewed directly over a stride of `u32` words.
///
/// The first two words are `count` and `epoch`; every remaining word is
/// reinterpreted as one `i32` entry of `offsets`.
#[derive(Debug, PartialEq, Eq, Hash)]
#[repr(C)]
struct Timestamps {
    count: u32,
    epoch: u32,
    offsets: [i32],
}

impl Timestamps {
    /// Number of leading stride words taken by the sized `count` and `epoch` fields.
    const HEADER_WORDS: usize = 2;

    /// Returns how many `offsets` entries a stride of `stride_len` words holds.
    ///
    /// # Panics
    ///
    /// Panics if the stride is too short to contain `count` and `epoch`.
    /// Without this check the subtraction would wrap in builds without
    /// overflow checks and the unsafe casts below would receive a bogus length.
    fn offsets_len(stride_len: usize) -> usize {
        stride_len
            .checked_sub(Self::HEADER_WORDS)
            .expect("stride must hold at least 2 words (count and epoch)")
    }

    /// Reinterprets `stride` as a shared `Timestamps` view without copying.
    ///
    /// # Panics
    ///
    /// Panics if `stride` has fewer than 2 words.
    fn from_stride(stride: &[u32]) -> &Self {
        let offsets_len = Self::offsets_len(stride.len());
        // SAFETY: Timestamps is #[repr(C)] with u32 at offsets 0 and 4, and [i32]
        // (same size/align as [u32]) starting at offset 8. The fat pointer metadata
        // is the element count of the `offsets` field, and `offsets_len` was checked
        // above, so the view covers exactly the words of `stride`.
        unsafe {
            &*(std::ptr::slice_from_raw_parts(stride.as_ptr(), offsets_len) as *const Timestamps)
        }
    }

    /// Reinterprets `stride` as an exclusive `Timestamps` view without copying.
    ///
    /// # Panics
    ///
    /// Panics if `stride` has fewer than 2 words.
    fn from_stride_mut(stride: &mut [u32]) -> &mut Self {
        let offsets_len = Self::offsets_len(stride.len());
        // SAFETY: same as from_stride, with exclusive access guaranteed by &mut
        unsafe {
            &mut *(std::ptr::slice_from_raw_parts_mut(stride.as_mut_ptr(), offsets_len)
                as *mut Timestamps)
        }
    }
}

/// Lets a borrowed `Timestamps` view be copied into an `OwnedTimestamps`.
impl ToOwned for Timestamps {
    type Owned = OwnedTimestamps;

    /// Builds an owned value from this view's `count`, `epoch` and `offsets`.
    fn to_owned(&self) -> Self::Owned {
        let (count, epoch) = (self.count, self.epoch);
        OwnedTimestamps::new(count, epoch, &self.offsets)
    }
}

/// Heap-allocated, owned counterpart of a `Timestamps` view.
#[derive(Debug, PartialEq, Eq)]
struct OwnedTimestamps(Box<Timestamps>);

impl OwnedTimestamps {
    /// Allocates a label holding `count`, `epoch` and a copy of `offsets`.
    fn new(count: u32, epoch: u32, offsets: &[i32]) -> Self {
        // Lay the label out as plain words first: count, epoch, then each
        // offset bit-reinterpreted as u32.
        let words: Vec<u32> = [count, epoch]
            .iter()
            .copied()
            .chain(offsets.iter().map(|&offset| offset as u32))
            .collect();
        let raw: *mut Timestamps = Timestamps::from_stride_mut(Box::leak(words.into_boxed_slice()));
        // SAFETY: `raw` comes from a boxed slice released with Box::leak, so the
        // allocation can be handed back to a Box. Layout matches because
        // Timestamps is #[repr(C)] with two u32 fields followed by [i32]
        // (same size/align as [u32]).
        let boxed = unsafe { Box::from_raw(raw) };
        OwnedTimestamps(boxed)
    }
}

/// Exposes the boxed label as a shared `Timestamps` reference.
impl Borrow<Timestamps> for OwnedTimestamps {
    fn borrow(&self) -> &Timestamps {
        &*self.0
    }
}

/// Exposes the boxed label as an exclusive `Timestamps` reference.
impl BorrowMut<Timestamps> for OwnedTimestamps {
    fn borrow_mut(&mut self) -> &mut Timestamps {
        &mut *self.0
    }
}

/// Makes `Timestamps` usable as a label backed by strides of `u32` words.
impl StridableLabel for Timestamps {
    type Word = u32;

    /// Delegates to the inherent `Timestamps::from_stride`.
    fn from_stride(stride: &[u32]) -> &Self {
        Timestamps::from_stride(stride)
    }

    /// Exchanges this label's contents with the words of `stride`, field by field.
    fn swap_with_stride(&mut self, stride: &mut [u32]) {
        use std::mem::swap;

        // The two header words map onto `count` and `epoch`.
        swap(&mut stride[0], &mut self.count);
        swap(&mut stride[1], &mut self.epoch);
        // Everything after the header is viewed as i32 so it can be swapped
        // with `offsets` in one go.
        let tail: &mut [i32] = bytemuck::cast_slice_mut(&mut stride[2..]);
        tail.swap_with_slice(&mut self.offsets);
    }
}

/// Creates a `Timestamps` label store for `num_nodes` nodes, configured with
/// `num_words` words.
fn new_custom_store(num_nodes: usize, num_words: usize) -> StriddenLabels<Timestamps> {
    let config = StriddenLabelsConfig { num_words };
    StriddenLabels::new(num_nodes, config)
}

/// Inserted `Timestamps` labels can be read back field by field, and nodes
/// that were never labelled report no label.
#[test]
fn test_custom_insert_and_get() {
    let mut labels = new_custom_store(5, 4);
    assert!(labels.is_empty());

    // The first insertion on a node has no previous label to hand back.
    let previous = labels.insert(0, OwnedTimestamps::new(2, 1000, &[100, 200]));
    assert!(previous.is_none());
    assert!(!labels.is_empty());

    let previous = labels.insert(3, OwnedTimestamps::new(2, 2000, &[300, 400]));
    assert!(previous.is_none());

    let first = labels.get(0).unwrap();
    assert_eq!(
        (first.count, first.epoch, &first.offsets),
        (2, 1000, &[100, 200][..])
    );

    let second = labels.get(3).unwrap();
    assert_eq!(
        (second.count, second.epoch, &second.offsets),
        (2, 2000, &[300, 400][..])
    );

    for unlabelled in [1, 2, 4] {
        assert_eq!(labels.get(unlabelled), None);
    }

    let expected_presence = [(0, true), (1, false), (2, false), (3, true), (4, false)];
    for (node, present) in expected_presence {
        assert_eq!(labels.contains_key(node), present);
    }
}

/// Re-inserting on a labelled node returns the displaced `Timestamps` label
/// and stores the new one.
#[test]
fn test_custom_insert_overwrites() {
    let mut store = new_custom_store(3, 4);

    assert!(
        store
            .insert(1, OwnedTimestamps::new(1, 100, &[10, 20]))
            .is_none()
    );

    let prev = store
        .insert(1, OwnedTimestamps::new(1, 200, &[30, 40]))
        .unwrap();
    assert_eq!(
        (prev.0.count, prev.0.epoch, &prev.0.offsets),
        (1, 100, &[10, 20][..])
    );

    let ts = store.get(1).unwrap();
    assert_eq!((ts.count, ts.epoch, &ts.offsets), (1, 200, &[30, 40][..]));

    let prev = store
        .insert(1, OwnedTimestamps::new(1, 300, &[50, 60]))
        .unwrap();
    assert_eq!(
        (prev.0.count, prev.0.epoch, &prev.0.offsets),
        (1, 200, &[30, 40][..])
    );

    // Also check what the second overwrite left in the store, as the
    // slice-label variant of this test does.
    let ts = store.get(1).unwrap();
    assert_eq!((ts.count, ts.epoch, &ts.offsets), (1, 300, &[50, 60][..]));

    assert!(!store.is_empty());
}

/// Removing a `Timestamps` label returns it exactly once and leaves other
/// nodes untouched.
#[test]
fn test_custom_remove() {
    let mut labels = new_custom_store(4, 4);

    labels.insert(1, OwnedTimestamps::new(2, 100, &[10, 20]));
    labels.insert(2, OwnedTimestamps::new(2, 200, &[30, 40]));

    // Node 0 was never labelled, so there is nothing to take out.
    assert!(labels.remove(0).is_none());

    let taken = labels.remove(1).unwrap();
    assert_eq!(
        (taken.0.count, taken.0.epoch, &taken.0.offsets),
        (2, 100, &[10, 20][..])
    );
    assert!(labels.get(1).is_none());
    assert!(!labels.contains_key(1));
    // A second removal of the same node finds nothing.
    assert!(labels.remove(1).is_none());

    // Node 2 still holds its label, so the store is not empty yet.
    let remaining = labels.get(2).unwrap();
    assert_eq!(
        (remaining.count, remaining.epoch, &remaining.offsets),
        (2, 200, &[30, 40][..])
    );
    assert!(!labels.is_empty());

    let taken = labels.remove(2).unwrap();
    assert_eq!(
        (taken.0.count, taken.0.epoch, &taken.0.offsets),
        (2, 200, &[30, 40][..])
    );
    assert!(labels.is_empty());
}

/// `iter` yields one slot per node, while `iter_labeled` and
/// `par_iter_labeled` yield only the labelled nodes of a `Timestamps` store.
#[test]
fn test_custom_iterators() {
    let mut labels = new_custom_store(4, 4);
    labels.insert(1, OwnedTimestamps::new(1, 10, &[100, 200]));
    labels.insert(3, OwnedTimestamps::new(1, 30, &[500, 600]));

    // One entry per node, `None` where no label was inserted.
    let per_node: Vec<Option<(u32, u32, &[i32])>> = labels
        .iter()
        .map(|slot| slot.map(|label| (label.count, label.epoch, &label.offsets)))
        .collect();
    assert_eq!(
        per_node,
        vec![
            None,
            Some((1, 10, &[100, 200][..])),
            None,
            Some((1, 30, &[500, 600][..])),
        ]
    );

    let sequential: HashSet<(usize, u32, u32)> = labels
        .iter_labeled()
        .map(|(node, label)| (node, label.count, label.epoch))
        .collect();
    assert_eq!(sequential, set![(1, 1, 10), (3, 1, 30)]);

    // The parallel iterator must produce the same triples as the sequential one.
    let parallel: HashSet<(usize, u32, u32)> = labels
        .par_iter_labeled()
        .map(|(node, label)| (node, label.count, label.epoch))
        .collect();
    assert_eq!(parallel, sequential);
}