use std::sync::Mutex;
use crate::parallel_solver::{AssemblyTask, CsrMatrix};
/// Outcome of an element coloring, as returned by [`color_elements`].
#[derive(Debug, Clone)]
pub struct ElementColoringResult {
/// Color assigned to each element; `colors[e]` is the color of element `e`.
pub colors: Vec<usize>,
/// Number of colors in use: one more than the largest assigned color,
/// or `0` when there are no elements.
pub n_colors: usize,
/// Element indices grouped by color; `buckets[c]` lists every element
/// whose color is `c`, in increasing element order.
pub buckets: Vec<Vec<usize>>,
}
/// Maximum number of distinct colors `color_elements` hands out. An element
/// that would need a color at or beyond this limit is placed in the last
/// color (`MAX_COLORS - 1`) instead.
const MAX_COLORS: usize = 64;
/// Greedily colors elements so that elements sharing a DOF receive different colors.
///
/// Elements are visited in index order and each one takes the smallest color
/// not already used by a lower-indexed element that shares at least one DOF
/// with it.
///
/// # Parameters
/// * `n_elements` - number of elements to color; the result's `colors` has
///   exactly this length.
/// * `element_dofs` - `element_dofs[e]` lists the global DOFs of element `e`.
///   Only the first `n_elements` entries are considered. If the slice is
///   shorter than `n_elements`, the missing elements are treated as having no
///   DOFs (and therefore no neighbours) instead of panicking.
///
/// # Returns
/// The per-element colors, the number of colors in use, and the elements
/// grouped per color.
///
/// # Color limit
/// At most `MAX_COLORS` colors are used. If an element's neighbours already
/// occupy every color, a diagnostic is written to stderr and the element is
/// placed in the last color (`MAX_COLORS - 1`). In that case the last bucket
/// is no longer guaranteed to be conflict-free, so consumers must not rely on
/// it for unsynchronized writes.
pub fn color_elements(n_elements: usize, element_dofs: &[Vec<usize>]) -> ElementColoringResult {
    if n_elements == 0 {
        return ElementColoringResult {
            colors: Vec::new(),
            n_colors: 0,
            buckets: Vec::new(),
        };
    }

    // Restrict to the elements that are actually being colored so that extra
    // entries in `element_dofs` can never be referenced as neighbours.
    let known = &element_dofs[..n_elements.min(element_dofs.len())];

    // Inverse connectivity: for every DOF, the elements that touch it.
    let n_dofs = known.iter().flatten().copied().max().map_or(0, |d| d + 1);
    let mut dof_to_elems: Vec<Vec<usize>> = vec![Vec::new(); n_dofs];
    for (e, dofs) in known.iter().enumerate() {
        for &d in dofs {
            dof_to_elems[d].push(e);
        }
    }

    let mut colors: Vec<usize> = Vec::with_capacity(n_elements);
    let mut n_colors_used = 0usize;
    for e in 0..n_elements {
        // Every stored color is < MAX_COLORS, so a fixed-size table replaces a
        // per-element hash set.
        let mut used = [false; MAX_COLORS];
        let dofs = known.get(e).map(Vec::as_slice).unwrap_or_default();
        for &d in dofs {
            for &adj in &dof_to_elems[d] {
                // Elements are colored in index order, so exactly the
                // lower-indexed ones already carry a color.
                if adj < e {
                    used[colors[adj]] = true;
                }
            }
        }

        // Smallest free color; `MAX_COLORS` signals that all are taken.
        let mut chosen_color = used
            .iter()
            .position(|&taken| !taken)
            .unwrap_or(MAX_COLORS);
        if chosen_color >= MAX_COLORS {
            eprintln!(
                "assembly_coloring: element {e} needs color {chosen_color} >= MAX_COLORS={MAX_COLORS}, \
                 using fallback serial bucket"
            );
            chosen_color = MAX_COLORS - 1;
        }

        colors.push(chosen_color);
        n_colors_used = n_colors_used.max(chosen_color + 1);
    }

    // Group element indices by color.
    let mut buckets: Vec<Vec<usize>> = vec![Vec::new(); n_colors_used];
    for (e, &c) in colors.iter().enumerate() {
        buckets[c].push(e);
    }

    ElementColoringResult {
        colors,
        n_colors: n_colors_used,
        buckets,
    }
}
/// Assembles element matrices into a square `ndofs x ndofs` CSR matrix,
/// processing the elements of each color bucket in parallel.
///
/// # Parameters
/// * `ndofs` - number of rows and columns of the assembled matrix.
/// * `tasks` - element contributions; each task's `ke` is read as a row-major
///   `ndof x ndof` block indexed by the positions in `global_dofs`.
/// * `coloring` - buckets of indices into `tasks`; buckets are handled one
///   after another, the elements inside a bucket concurrently.
///
/// # Returns
/// A CSR matrix whose sparsity pattern is the union of all element DOF
/// couplings found in `tasks` (column indices sorted within each row) and
/// whose values are the sums of the contributions of the elements listed in
/// `coloring.buckets`.
///
/// # Panics
/// Panics if a DOF in `tasks` is `>= ndofs`, if a bucket refers to an element
/// index outside `tasks`, or if a task's `ke` is shorter than the indices read.
pub fn assemble_colored_csr(
    ndofs: usize,
    tasks: &[AssemblyTask],
    coloring: &ElementColoringResult,
) -> CsrMatrix {
    use rayon::prelude::*;
    use std::collections::BTreeSet;

    // Sparsity pattern: the ordered set of coupled columns for every row.
    let mut pattern: Vec<BTreeSet<usize>> = vec![BTreeSet::new(); ndofs];
    for task in tasks {
        for &row_dof in &task.global_dofs {
            pattern[row_dof].extend(task.global_dofs.iter().copied());
        }
    }

    // Flatten the pattern into CSR offsets and column indices.
    let mut row_offsets: Vec<usize> = Vec::with_capacity(ndofs + 1);
    row_offsets.push(0);
    let mut col_indices: Vec<usize> = Vec::new();
    for cols in &pattern {
        col_indices.extend(cols.iter().copied());
        row_offsets.push(col_indices.len());
    }

    // One lock per stored entry so concurrent element scatters stay safe even
    // if two elements in a bucket touch the same entry.
    let slots: Vec<Mutex<f64>> = std::iter::repeat_with(|| Mutex::new(0.0f64))
        .take(col_indices.len())
        .collect();

    for bucket in &coloring.buckets {
        bucket.par_iter().for_each(|&e| {
            let task = &tasks[e];
            let ndof = task.ndof();
            for (li, &row) in task.global_dofs.iter().enumerate() {
                let start = row_offsets[row];
                // Columns of a row are sorted (they come from a BTreeSet), so
                // the CSR slot can be located by binary search.
                let cols_in_row = &col_indices[start..row_offsets[row + 1]];
                for (lj, &col) in task.global_dofs.iter().enumerate() {
                    let ke_val = task.ke[li * ndof + lj];
                    if let Ok(pos) = cols_in_row.binary_search(&col) {
                        // A poisoned lock still holds a usable value; recover it.
                        let mut cell = slots[start + pos]
                            .lock()
                            .unwrap_or_else(|poisoned| poisoned.into_inner());
                        *cell += ke_val;
                    }
                }
            }
        });
    }

    let values: Vec<f64> = slots
        .into_iter()
        .map(|slot| slot.into_inner().unwrap_or_else(|poisoned| poisoned.into_inner()))
        .collect();

    CsrMatrix {
        nrows: ndofs,
        ncols: ndofs,
        row_offsets,
        col_indices,
        values,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parallel_solver::ParallelAssembler;
    use std::collections::HashSet;

    /// Four two-DOF elements chained over DOFs 0..=4, each with the same
    /// 2x2 element matrix.
    fn make_test_tasks() -> (Vec<AssemblyTask>, Vec<Vec<usize>>) {
        let ke = vec![1.0, -1.0, -1.0, 1.0];
        let dofs: Vec<Vec<usize>> = (0..4).map(|first| vec![first, first + 1]).collect();
        let mut tasks = Vec::with_capacity(dofs.len());
        for d in &dofs {
            tasks.push(AssemblyTask::new(d.clone(), ke.clone()));
        }
        (tasks, dofs)
    }

    /// DOFs of element `e` as a set.
    fn dof_set(dofs: &[Vec<usize>], e: usize) -> HashSet<usize> {
        dofs[e].iter().copied().collect()
    }

    /// No two elements of the same color may share a DOF on the chain mesh.
    #[test]
    fn test_coloring_valid() {
        let (tasks, dofs) = make_test_tasks();
        let coloring = color_elements(tasks.len(), &dofs);
        for (color, bucket) in coloring.buckets.iter().enumerate() {
            for &e1 in bucket {
                for &e2 in bucket.iter().filter(|&&other| other != e1) {
                    let overlap: HashSet<usize> = dof_set(&dofs, e1)
                        .intersection(&dof_set(&dofs, e2))
                        .copied()
                        .collect();
                    assert!(
                        overlap.is_empty(),
                        "Color {color}: elements {e1} and {e2} share DOFs {overlap:?}"
                    );
                }
            }
        }
    }

    /// Colored assembly must reproduce the reference assembler entry by entry.
    #[test]
    fn test_colored_assembly_matches_serial() {
        let ndofs = 5;
        let (tasks, dofs) = make_test_tasks();
        let mat_serial = ParallelAssembler::new(ndofs).assemble(&tasks);
        let coloring = color_elements(tasks.len(), &dofs);
        let mat_colored = assemble_colored_csr(ndofs, &tasks, &coloring);

        assert_eq!(mat_serial.nrows, mat_colored.nrows);
        assert_eq!(mat_serial.ncols, mat_colored.ncols);
        assert_eq!(mat_serial.nnz(), mat_colored.nnz());

        // Value stored at (row, col), or 0.0 when the entry is not present.
        let get_val = |mat: &CsrMatrix, row: usize, col: usize| -> f64 {
            (mat.row_offsets[row]..mat.row_offsets[row + 1])
                .find(|&k| mat.col_indices[k] == col)
                .map_or(0.0, |k| mat.values[k])
        };

        for i in 0..ndofs {
            let row_span = mat_serial.row_offsets[i]..mat_serial.row_offsets[i + 1];
            for k in row_span {
                let j = mat_serial.col_indices[k];
                let v_ser = mat_serial.values[k];
                let v_col = get_val(&mat_colored, i, j);
                assert!(
                    (v_ser - v_col).abs() < 1e-14,
                    "Mismatch at ({i},{j}): serial={v_ser}, colored={v_col}"
                );
            }
        }
    }

    /// A closed ring of four elements needs at least two colors and must be
    /// colored without conflicts.
    #[test]
    fn test_coloring_small_mesh() {
        let dofs = vec![vec![0usize, 1], vec![1, 2], vec![2, 3], vec![0, 3]];
        let coloring = color_elements(4, &dofs);
        assert_eq!(coloring.colors.len(), 4);
        assert!(coloring.n_colors >= 2, "Expected at least 2 colors");
        for bucket in &coloring.buckets {
            for &e1 in bucket {
                for &e2 in bucket.iter().filter(|&&other| other != e1) {
                    let overlap: Vec<usize> = dof_set(&dofs, e1)
                        .intersection(&dof_set(&dofs, e2))
                        .copied()
                        .collect();
                    assert!(
                        overlap.is_empty(),
                        "Same-color elements {e1} and {e2} share DOFs {overlap:?}"
                    );
                }
            }
        }
    }

    /// Coloring zero elements yields no colors and no buckets.
    #[test]
    fn test_coloring_empty() {
        let result = color_elements(0, &[]);
        assert_eq!(result.n_colors, 0);
        assert_eq!(result.buckets.len(), 0);
    }
}