use crate::error::OptimizeError;
use scirs2_core::random::{rngs::StdRng, seq::SliceRandom, SeedableRng};
use scirs2_sparse::csr_array::CsrArray;
use std::collections::{HashMap, HashSet};
/// Returns the column indices stored for `row` of a CSR matrix.
///
/// The row's entries occupy positions `indptr[row]..indptr[row + 1]` of the
/// matrix's index array; those column indices are copied out in stored order.
/// Entry values are never inspected, so every stored position is reported.
///
/// Indexing is unchecked beyond the usual `[]` bounds check, so an
/// out-of-range `row` is expected to panic.
#[allow(dead_code)]
fn get_nonzero_cols_in_row(matrix: &CsrArray<f64>, row: usize) -> Vec<usize> {
    let first = matrix.get_indptr()[row];
    let past_last = matrix.get_indptr()[row + 1];
    let col_indices = matrix.get_indices();
    (first..past_last).map(|pos| col_indices[pos]).collect()
}
/// Groups the columns of a sparsity pattern so that no two columns in the same
/// group have a stored entry in the same row.
///
/// Two columns *conflict* when some row stores an entry in both of them. The
/// columns are coloured greedily: they are visited in order of increasing
/// conflict count (ties shuffled by a seeded RNG) and each one receives the
/// smallest colour not already used by a conflicting column. Every colour
/// becomes one group.
///
/// # Arguments
///
/// * `sparsity` - CSR sparsity pattern; only its shape and stored positions
///   are read.
/// * `seed` - Seed for the tie-breaking shuffle. `None` falls back to the
///   fixed seed `0`, so the result is reproducible either way.
/// * `max_group_size` - Optional cap on the number of columns per group. When
///   the cap is smaller than the number of columns, larger groups are split
///   into consecutive chunks of at most that many columns. A cap of `0` is
///   treated as `1`, since a group must hold at least one column.
///
/// # Returns
///
/// The non-empty groups, ordered by colour, each listing its columns in
/// ascending index order. Every column index in `0..n` appears exactly once.
///
/// # Errors
///
/// This implementation has no failing path; the `Result` is kept for the
/// existing interface.
#[allow(dead_code)]
pub fn determine_column_groups(
    sparsity: &CsrArray<f64>,
    seed: Option<u64>,
    max_group_size: Option<usize>,
) -> Result<Vec<Vec<usize>>, OptimizeError> {
    let (m, n) = sparsity.shape();

    // Build the conflict graph: an edge joins two columns that share a row.
    let mut conflicts: Vec<HashSet<usize>> = vec![HashSet::new(); n];
    for row in 0..m {
        let cols = get_nonzero_cols_in_row(sparsity, row);
        for &col1 in &cols {
            for &col2 in &cols {
                // The double loop visits both (a, b) and (b, a), so a single
                // insert per ordered pair already makes the graph symmetric.
                if col1 != col2 {
                    conflicts[col1].insert(col2);
                }
            }
        }
    }

    // Visit columns by increasing conflict count (stable sort).
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by_key(|&v| conflicts[v].len());

    let mut rng = match seed {
        Some(s) => StdRng::seed_from_u64(s),
        None => StdRng::seed_from_u64(0),
    };

    // Shuffle each run of equal-degree columns so ties are broken by the seed
    // rather than by column index.
    let mut i = 0;
    while i < order.len() {
        let degree = conflicts[order[i]].len();
        let mut j = i + 1;
        while j < order.len() && conflicts[order[j]].len() == degree {
            j += 1;
        }
        order[i..j].shuffle(&mut rng);
        i = j;
    }

    // Greedy colouring: smallest colour not used by an already-coloured neighbour.
    let mut vertex_colors: HashMap<usize, usize> = HashMap::with_capacity(n);
    let mut num_colors = 0;
    for &v in &order {
        let neighbor_colors: HashSet<usize> = conflicts[v]
            .iter()
            .filter_map(|neighbor| vertex_colors.get(neighbor).copied())
            .collect();
        let mut color = 0;
        while neighbor_colors.contains(&color) {
            color += 1;
        }
        vertex_colors.insert(v, color);
        num_colors = num_colors.max(color + 1);
    }

    // Assign columns to groups in ascending column index. Iterating the map
    // directly would yield an arbitrary order that differs between runs, which
    // would make the output (and the chunking below) non-reproducible even
    // with a fixed seed.
    let mut color_groups: Vec<Vec<usize>> = vec![Vec::new(); num_colors];
    for v in 0..n {
        if let Some(&color) = vertex_colors.get(&v) {
            color_groups[color].push(v);
        }
    }

    let max_size = max_group_size.unwrap_or(usize::MAX);
    if max_size >= n {
        // No group can exceed the cap. Greedy colouring uses every colour in
        // `0..num_colors`, so there are no empty groups to drop.
        return Ok(color_groups);
    }

    // `chunks` panics on a zero chunk size, so clamp the cap to at least 1.
    let limit = max_size.max(1);
    let mut final_groups = Vec::with_capacity(color_groups.len());
    for group in color_groups {
        // A group within the cap yields exactly one chunk: itself.
        final_groups.extend(group.chunks(limit).map(|chunk| chunk.to_vec()));
    }
    Ok(final_groups)
}