use crate::complexity::{Complexity, ComplexityClass};
use crate::matrix::Matrix;
use alloc::vec::Vec;
use bit_set::BitSet;
/// Zero-sized marker type naming the index-closure operation.
///
/// It carries no data; it exists so that traits can be implemented for the
/// operation as a type.
#[derive(Clone, Copy, Debug, Default)]
pub struct ClosureIndicesOp;
// Associates the closure operation marker with a complexity class.
impl Complexity for ClosureIndicesOp {
// NOTE(review): the justification for `SubLinear` is not visible in this
// file — presumably it is measured against the full matrix size; confirm.
const CLASS: ComplexityClass = ComplexityClass::SubLinear;
}
/// Collects the indices reachable from `seeds` in at most `depth` expansion rounds.
///
/// Every in-range seed is part of the result from the start. Each round visits
/// the rows that were newly added in the previous round and adds the column
/// index of every entry yielded by `matrix.row_iter(row)` that has not been
/// seen before. Expansion ends after `depth` rounds, or earlier once a round
/// discovers nothing new.
///
/// Seeds and column indices that are not strictly less than `matrix.rows()`
/// are ignored, and duplicates are collapsed. An empty `seeds` slice or a
/// matrix with zero rows produces an empty vector.
///
/// The result lists the visited indices in the order produced by the
/// `BitSet` iterator (the tests in this file expect ascending order).
pub fn closure_indices(matrix: &dyn Matrix, seeds: &[usize], depth: usize) -> Vec<usize> {
    let limit = matrix.rows();
    if seeds.is_empty() || limit == 0 {
        return Vec::new();
    }

    // `seen` doubles as the de-duplication filter and the final result set.
    let mut seen = BitSet::with_capacity(limit);

    // Seed the first wave with every distinct, in-range seed.
    let mut current: Vec<usize> = Vec::with_capacity(seeds.len());
    current.extend(
        seeds
            .iter()
            .copied()
            .filter(|&seed| seed < limit && seen.insert(seed)),
    );

    // Scratch buffer for the next wave; swapped with `current` every round so
    // both allocations are reused.
    let mut upcoming: Vec<usize> = Vec::new();
    let mut rounds_left = depth;
    while rounds_left > 0 && !current.is_empty() {
        rounds_left -= 1;
        upcoming.clear();
        for row in current.iter().copied() {
            for (col, _) in matrix.row_iter(row) {
                let target = col as usize;
                if target >= limit {
                    continue;
                }
                // `insert` returns true only for indices not seen before.
                if seen.insert(target) {
                    upcoming.push(target);
                }
            }
        }
        core::mem::swap(&mut current, &mut upcoming);
    }

    let mut reached: Vec<usize> = Vec::with_capacity(seen.len());
    reached.extend(seen.iter());
    reached
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::SparseMatrix;

    /// A purely diagonal matrix links each index only to itself, so the
    /// closure must stay equal to the seed set regardless of depth.
    #[test]
    fn closure_on_diagonal_is_seeds_only() {
        let size = 8;
        let entries: Vec<_> = (0..size).map(|k| (k, k, 1.0)).collect();
        let diag = SparseMatrix::from_triplets(entries, size, size).unwrap();

        let reached = closure_indices(&diag, &[2, 5], 3);

        assert_eq!(reached, vec![2, 5], "diagonal matrix should not expand");
    }

    /// With a diagonal plus super-diagonal, every round reaches exactly one
    /// more index to the right of the seed.
    #[test]
    fn closure_grows_with_depth_on_bidiagonal() {
        let size = 10;
        let mut entries = Vec::new();
        for k in 0..size {
            entries.push((k, k, 2.0));
            if k + 1 < size {
                entries.push((k, k + 1, -1.0));
            }
        }
        let band = SparseMatrix::from_triplets(entries, size, size).unwrap();

        // (depth, expected closure when seeding with index 0)
        let cases = [
            (0, vec![0]),
            (1, vec![0, 1]),
            (2, vec![0, 1, 2]),
            (5, vec![0, 1, 2, 3, 4, 5]),
        ];
        for (depth, expected) in cases.iter() {
            assert_eq!(&closure_indices(&band, &[0], *depth), expected);
        }
    }

    /// Seeds beyond the matrix dimension are silently discarded.
    #[test]
    fn closure_drops_out_of_bound_seeds() {
        let size = 4;
        let entries: Vec<_> = (0..size).map(|k| (k, k, 1.0)).collect();
        let diag = SparseMatrix::from_triplets(entries, size, size).unwrap();

        let reached = closure_indices(&diag, &[1, 99, 3], 5);

        assert_eq!(reached, vec![1, 3]);
    }

    /// No seeds means nothing to expand from.
    #[test]
    fn closure_on_empty_seeds_is_empty() {
        let blank = SparseMatrix::from_triplets(Vec::new(), 4, 4).unwrap();
        let reached = closure_indices(&blank, &[], 5);
        assert!(reached.is_empty());
    }

    /// A 0x0 matrix has no valid indices, so every seed is out of range.
    #[test]
    fn closure_on_empty_matrix_is_empty() {
        let blank = SparseMatrix::from_triplets(Vec::new(), 0, 0).unwrap();
        let reached = closure_indices(&blank, &[0, 1, 2], 3);
        assert!(reached.is_empty());
    }

    /// Runtime check of the declared complexity class.
    #[test]
    fn closure_op_complexity_class() {
        let declared = <ClosureIndicesOp as Complexity>::CLASS;
        assert_eq!(declared, ComplexityClass::SubLinear);
    }

    /// Same check evaluated in a const context, so a change to the declared
    /// class fails at compile time.
    #[test]
    fn closure_op_compile_time_bound() {
        const _: () = assert!(matches!(
            <ClosureIndicesOp as Complexity>::CLASS,
            ComplexityClass::SubLinear
        ));
    }
}