use crate::error::IntegrateError;
use scirs2_core::parallel_ops::*;
use std::collections::VecDeque;
/// One axis-aligned rectangular cell of the adaptive mesh.
///
/// `(x, y)` is the cell's minimum corner and `width`/`height` its extents, so
/// the cell covers `[x, x + width] × [y, y + height]`.
#[derive(Debug, Clone, PartialEq)]
pub struct AmrCell {
    /// Minimum x coordinate of the cell.
    pub x: f64,
    /// Minimum y coordinate of the cell.
    pub y: f64,
    /// Extent of the cell along x.
    pub width: f64,
    /// Extent of the cell along y.
    pub height: f64,
    /// Scalar stored on the cell; `refine_cell` copies it into the four children.
    pub value: f64,
}
impl AmrCell {
    /// Returns the geometric centre `(x + width / 2, y + height / 2)`.
    #[inline]
    pub fn centre(&self) -> (f64, f64) {
        (self.x + 0.5 * self.width, self.y + 0.5 * self.height)
    }
}
/// Mesh-balance constraint applied after refinement.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BalanceType {
    /// No balance constraint: neighbouring leaves may differ by any number of levels.
    None,
    /// Face-adjacent leaves may differ by at most one refinement level (2:1 balance).
    Face2to1,
}
/// Parameters controlling adaptive refinement.
#[derive(Debug, Clone, PartialEq)]
pub struct AmrConfig {
    /// Deepest refinement level allowed; root (uniform-grid) cells are level 0.
    pub max_level: u32,
    /// A leaf is refined when its indicator value is strictly greater than this.
    pub refine_threshold: f64,
    /// Threshold intended for coarsening.
    /// NOTE(review): not read anywhere in this file; no coarsening is implemented here.
    pub coarsen_threshold: f64,
    /// Balance constraint enforced after each refinement call.
    pub balance_type: BalanceType,
}
impl Default for AmrConfig {
    /// `max_level = 8`, `refine_threshold = 0.5`, `coarsen_threshold = 0.1`,
    /// with 2:1 face balancing.
    fn default() -> Self {
        AmrConfig {
            max_level: 8,
            refine_threshold: 0.5,
            coarsen_threshold: 0.1,
            balance_type: BalanceType::Face2to1,
        }
    }
}
/// Summary of how many leaves each worker would receive (see `AmrGrid2D::load_stats`).
#[derive(Debug, Clone, PartialEq)]
pub struct LoadStats {
    /// Smallest number of leaves assigned to a worker.
    pub min: usize,
    /// Largest number of leaves assigned to a worker.
    pub max: usize,
    /// Mean number of leaves per worker.
    pub mean: f64,
    /// Number of workers the statistics were computed over.
    pub n_threads: usize,
}
/// A quadtree node stored in `AmrGrid2D`'s node arena, wrapping one [`AmrCell`].
#[derive(Debug, Clone)]
struct TreeNode {
    /// Geometry and value of this node's cell.
    cell: AmrCell,
    /// Refinement depth; root (uniform-grid) cells are level 0.
    level: u32,
    /// Arena index of the parent, or `None` for a root cell.
    parent: Option<usize>,
    /// Arena indices of the four children once subdivided; `None` while a leaf.
    children: Option<[usize; 4]>,
    /// This node's own arena index.
    /// NOTE(review): set on creation but never read in this file.
    idx: usize,
}
impl TreeNode {
    /// A node is a leaf exactly when it has not been subdivided.
    fn is_leaf(&self) -> bool {
        let subdivided = self.children.is_some();
        !subdivided
    }
}
/// A 2-D adaptive mesh stored as a forest of quadtrees in a flat node arena.
///
/// The grid starts as a uniform `nx × ny` array of level-0 root cells (see
/// [`AmrGrid2D::new_uniform`]). Refining a leaf splits it into four equal
/// children. This module never deletes nodes, so refined parents remain in
/// the arena as interior nodes.
#[derive(Debug, Clone)]
pub struct AmrGrid2D {
    /// Arena of every node: roots, interior nodes and leaves.
    nodes: Vec<TreeNode>,
    /// Arena indices of the current leaves; rebuilt after structural changes.
    leaf_indices: Vec<usize>,
    /// Physical extent as `[x_min, x_max, y_min, y_max]`.
    domain: [f64; 4],
}
impl AmrGrid2D {
    /// Creates a uniform grid of `nx × ny` level-0 cells covering `domain`.
    ///
    /// `domain` is `[x_min, x_max, y_min, y_max]`. Cells are numbered
    /// row-major (`j * nx + i`), every value starts at `0.0`, and every cell
    /// starts as a leaf.
    ///
    /// # Panics
    ///
    /// Panics if `nx` or `ny` is zero, or unless `x_min < x_max` and
    /// `y_min < y_max` (this also rejects NaN bounds).
    pub fn new_uniform(nx: usize, ny: usize, domain: [f64; 4]) -> Self {
        assert!(nx > 0 && ny > 0, "nx and ny must be > 0");
        let [x_min, x_max, y_min, y_max] = domain;
        // Empty or inverted extents would yield cells of non-positive size,
        // which the face-adjacency test used for balancing cannot handle.
        assert!(
            x_min < x_max && y_min < y_max,
            "domain must satisfy x_min < x_max and y_min < y_max"
        );
        let cell_w = (x_max - x_min) / nx as f64;
        let cell_h = (y_max - y_min) / ny as f64;
        let total = nx * ny;
        let mut nodes = Vec::with_capacity(total);
        for j in 0..ny {
            for i in 0..nx {
                let idx = j * nx + i;
                nodes.push(TreeNode {
                    cell: AmrCell {
                        x: x_min + i as f64 * cell_w,
                        y: y_min + j as f64 * cell_h,
                        width: cell_w,
                        height: cell_h,
                        value: 0.0,
                    },
                    level: 0,
                    parent: None,
                    children: None,
                    idx,
                });
            }
        }
        // Every root is a leaf, listed in the same row-major order.
        let leaf_indices: Vec<usize> = (0..total).collect();
        AmrGrid2D {
            nodes,
            leaf_indices,
            domain,
        }
    }

    /// Number of current leaf cells.
    pub fn n_leaves(&self) -> usize {
        self.leaf_indices.len()
    }

    /// Total number of nodes in the arena, including refined (interior) nodes.
    pub fn n_cells(&self) -> usize {
        self.nodes.len()
    }

    /// References to the current leaf cells, in `leaf_indices` order.
    pub fn leaves(&self) -> Vec<&AmrCell> {
        self.leaf_indices
            .iter()
            .map(|&i| &self.nodes[i].cell)
            .collect()
    }

    /// Refines, by one level, every leaf whose indicator exceeds the threshold.
    ///
    /// `indicator` is evaluated through `parallel_map` once on each leaf that
    /// exists when the call starts. A leaf is split into four children when
    /// the indicator is strictly greater than `config.refine_threshold` and
    /// the leaf's level is below `config.max_level`. Newly created children
    /// are not re-evaluated, so one call refines at most one level; call
    /// repeatedly to refine deeper.
    ///
    /// When `config.balance_type` is [`BalanceType::Face2to1`], coarse
    /// neighbours are then refined (up to `config.max_level`) until 2:1 face
    /// balance holds. `config.coarsen_threshold` is not used.
    pub fn refine_parallel<G>(&mut self, indicator: G, config: &AmrConfig)
    where
        G: Fn(&AmrCell) -> f64 + Send + Sync,
    {
        let to_refine: Vec<usize> = {
            let nodes = &self.nodes;
            let leaves = &self.leaf_indices;
            parallel_map(leaves, |&leaf_idx| {
                let node = &nodes[leaf_idx];
                let ind_val = indicator(&node.cell);
                if ind_val > config.refine_threshold && node.level < config.max_level {
                    Some(leaf_idx)
                } else {
                    None
                }
            })
            .into_iter()
            .flatten()
            .collect()
        };
        for leaf_idx in to_refine {
            self.refine_cell(leaf_idx);
        }
        self.rebuild_leaf_indices();
        if config.balance_type == BalanceType::Face2to1 {
            // Leaves `leaf_indices` current; no extra rebuild needed.
            self.enforce_balance_loop(config.max_level);
        }
    }

    /// Splits node `leaf_idx` into four equal children appended to the arena.
    ///
    /// Children are created with offsets `(0, 0)`, `(w/2, 0)`, `(0, h/2)`,
    /// `(w/2, h/2)` from the parent's minimum corner. They sit one level
    /// deeper and inherit the parent's `value`. Does nothing if the node
    /// already has children. `leaf_indices` is not updated; callers must call
    /// `rebuild_leaf_indices` afterwards.
    fn refine_cell(&mut self, leaf_idx: usize) {
        let parent = &self.nodes[leaf_idx];
        if !parent.is_leaf() {
            return;
        }
        let child_level = parent.level + 1;
        let (px, py) = (parent.cell.x, parent.cell.y);
        let (hw, hh) = (parent.cell.width * 0.5, parent.cell.height * 0.5);
        let parent_val = parent.cell.value;

        let first_child = self.nodes.len();
        let child_ids = [first_child, first_child + 1, first_child + 2, first_child + 3];
        let offsets = [(0.0, 0.0), (hw, 0.0), (0.0, hh), (hw, hh)];
        for (&child_idx, &(dx, dy)) in child_ids.iter().zip(offsets.iter()) {
            self.nodes.push(TreeNode {
                cell: AmrCell {
                    x: px + dx,
                    y: py + dy,
                    width: hw,
                    height: hh,
                    value: parent_val,
                },
                level: child_level,
                parent: Some(leaf_idx),
                children: None,
                idx: child_idx,
            });
        }
        self.nodes[leaf_idx].children = Some(child_ids);
    }

    /// Recomputes `leaf_indices` by a depth-first walk of every root.
    ///
    /// Roots and children are pushed in reverse so they pop in natural order.
    /// Roots come out in the row-major order of `new_uniform`, and children in
    /// the order `refine_cell` created them. On an unrefined grid this
    /// reproduces the ordering produced by `new_uniform`.
    fn rebuild_leaf_indices(&mut self) {
        self.leaf_indices.clear();
        let mut stack: VecDeque<usize> = self
            .nodes
            .iter()
            .enumerate()
            .rev()
            .filter(|(_, node)| node.parent.is_none())
            .map(|(i, _)| i)
            .collect();
        while let Some(idx) = stack.pop_back() {
            let children = self.nodes[idx].children;
            match children {
                None => self.leaf_indices.push(idx),
                Some(kids) => stack.extend(kids.iter().rev().copied()),
            }
        }
    }

    /// Refines leaves until no two face-adjacent leaves differ by more than
    /// one level, with no cap on the refinement level.
    pub fn enforce_balance(&mut self) {
        // Leaves `leaf_indices` current on return.
        self.enforce_balance_loop(u32::MAX);
    }

    /// Restores 2:1 face balance by repeated refinement passes.
    ///
    /// Each pass scans all pairs of current leaves. It collects every leaf
    /// below `max_level` that is face-adjacent to a leaf more than one level
    /// finer. All collected leaves are then refined together and the leaf
    /// list is rebuilt once.
    ///
    /// A one-leaf-at-a-time strategy would have to refine each collected leaf
    /// as well. Batching the pass therefore cuts the number of full rescans
    /// without changing which regions end up refined. Passes repeat until
    /// one finds nothing to refine.
    ///
    /// Expects `leaf_indices` to be current on entry and leaves it current on return.
    fn enforce_balance_loop(&mut self, max_level: u32) {
        loop {
            let mut to_refine: Vec<usize> = Vec::new();
            for &fine in &self.leaf_indices {
                let fine_level = self.nodes[fine].level;
                // Only a leaf at level >= 2 can be more than one level finer than a neighbour.
                if fine_level < 2 {
                    continue;
                }
                for &coarse in &self.leaf_indices {
                    let coarse_level = self.nodes[coarse].level;
                    // `coarse_level < max_level` is checked first so `+ 1` cannot overflow.
                    if coarse_level < max_level
                        && fine_level > coarse_level + 1
                        && Self::are_face_neighbours(&self.nodes[fine], &self.nodes[coarse])
                    {
                        to_refine.push(coarse);
                    }
                }
            }
            if to_refine.is_empty() {
                break;
            }
            // A coarse leaf may be collected more than once. `refine_cell`
            // ignores nodes that already have children, so duplicates are harmless.
            for idx in to_refine {
                self.refine_cell(idx);
            }
            self.rebuild_leaf_indices();
        }
    }

    /// Returns `true` when the cells of `a` and `b` share an edge segment of
    /// positive length (touching only at a corner does not count).
    ///
    /// Every cell edge lies on the lattice of the finest cell involved, since
    /// all cells are repeated halvings of equally sized roots. Two distinct
    /// edge coordinates therefore differ by at least the smaller cell size
    /// on that axis, and two overlapping spans overlap by at least that much.
    ///
    /// The tolerances are a small fraction of that size, per axis, which keeps
    /// the test valid at any domain scale. A fixed absolute tolerance breaks
    /// for very small cells and for large coordinates, where rounding error
    /// exceeds it.
    fn are_face_neighbours(a: &TreeNode, b: &TreeNode) -> bool {
        const REL_TOL: f64 = 1e-6;
        let (a, b) = (&a.cell, &b.cell);
        let tol_x = REL_TOL * a.width.min(b.width);
        let tol_y = REL_TOL * a.height.min(b.height);

        let (ax_min, ax_max) = (a.x, a.x + a.width);
        let (ay_min, ay_max) = (a.y, a.y + a.height);
        let (bx_min, bx_max) = (b.x, b.x + b.width);
        let (by_min, by_max) = (b.y, b.y + b.height);

        // Coincident vertical edges and a y-overlap of positive length.
        let share_x_face = ((ax_max - bx_min).abs() < tol_x || (bx_max - ax_min).abs() < tol_x)
            && ay_min < by_max - tol_y
            && by_min < ay_max - tol_y;
        // Coincident horizontal edges and an x-overlap of positive length.
        let share_y_face = ((ay_max - by_min).abs() < tol_y || (by_max - ay_min).abs() < tol_y)
            && ax_min < bx_max - tol_x
            && bx_min < ax_max - tol_x;
        share_x_face || share_y_face
    }

    /// Estimates per-worker load for an even, contiguous split of the leaves.
    ///
    /// The leaves are divided over `num_threads()` workers (at least one) as
    /// evenly as possible: the first `n_leaves % n` workers get one extra.
    /// Workers that would receive no leaves are left out, so `n_threads` is
    /// the number of workers that get work. This is a static estimate, not
    /// a measurement.
    ///
    /// With an empty leaf list all loads are zero and `n_threads` is the pool
    /// size. That case is not reachable through `new_uniform`, which always
    /// creates at least one leaf.
    pub fn load_stats(&self) -> LoadStats {
        let n = num_threads().max(1);
        let total = self.leaf_indices.len();
        if total == 0 {
            return LoadStats {
                min: 0,
                max: 0,
                mean: 0.0,
                n_threads: n,
            };
        }
        let base = total / n;
        let rem = total % n;
        // With `total > 0` at least one chunk is non-empty, so `sizes` is never empty.
        let sizes: Vec<usize> = (0..n)
            .map(|i| if i < rem { base + 1 } else { base })
            .filter(|&s| s > 0)
            .collect();
        let min = sizes.iter().copied().min().unwrap_or(0);
        let max = sizes.iter().copied().max().unwrap_or(0);
        // Only empty chunks were filtered out, so the sizes still sum to `total`.
        let mean = total as f64 / sizes.len() as f64;
        LoadStats {
            min,
            max,
            mean,
            n_threads: sizes.len(),
        }
    }

    /// Physical extent as `[x_min, x_max, y_min, y_max]`.
    pub fn domain(&self) -> [f64; 4] {
        self.domain
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics if any two face-adjacent leaves of `grid` differ by more than one level.
    ///
    /// Deliberately re-implements the adjacency test instead of calling
    /// `AmrGrid2D::are_face_neighbours`, so a bug there cannot mask a violation.
    fn assert_2to1_balanced(grid: &AmrGrid2D) {
        let leaves: Vec<(u32, &AmrCell)> = grid
            .leaf_indices
            .iter()
            .map(|&i| (grid.nodes[i].level, &grid.nodes[i].cell))
            .collect();
        let eps = 1e-9;
        for (i, &(li, ci)) in leaves.iter().enumerate() {
            for (j, &(lj, cj)) in leaves.iter().enumerate().skip(i + 1) {
                let ax_max = ci.x + ci.width;
                let ay_max = ci.y + ci.height;
                let bx_max = cj.x + cj.width;
                let by_max = cj.y + cj.height;
                let share_x = ((ax_max - cj.x).abs() < eps || (bx_max - ci.x).abs() < eps)
                    && ci.y < by_max - eps
                    && cj.y < ay_max - eps;
                let share_y = ((ay_max - cj.y).abs() < eps || (by_max - ci.y).abs() < eps)
                    && ci.x < bx_max - eps
                    && cj.x < ax_max - eps;
                if share_x || share_y {
                    let diff = (li as i64 - lj as i64).unsigned_abs();
                    assert!(
                        diff <= 1,
                        "2:1 balance violated: leaf {} (level {}) and leaf {} (level {}) are face neighbours with level diff {}",
                        i, li, j, lj, diff
                    );
                }
            }
        }
    }

    #[test]
    fn test_uniform_grid_has_correct_n_leaves() {
        let grid = AmrGrid2D::new_uniform(4, 4, [0.0, 1.0, 0.0, 1.0]);
        assert_eq!(grid.n_leaves(), 16);
        assert_eq!(grid.n_cells(), 16);
    }

    #[test]
    fn test_uniform_grid_2x3_n_leaves() {
        let grid = AmrGrid2D::new_uniform(2, 3, [0.0, 2.0, 0.0, 3.0]);
        assert_eq!(grid.n_leaves(), 6);
    }

    #[test]
    fn test_refine_above_threshold_quadruples_leaves() {
        let mut grid = AmrGrid2D::new_uniform(2, 2, [0.0, 1.0, 0.0, 1.0]);
        for node in grid.nodes.iter_mut() {
            node.cell.value = 1.0;
        }
        let config = AmrConfig {
            max_level: 4,
            refine_threshold: 0.5,
            coarsen_threshold: 0.0,
            balance_type: BalanceType::None,
        };
        let n_before = grid.n_leaves();
        grid.refine_parallel(|cell| cell.value, &config);
        let n_after = grid.n_leaves();
        assert_eq!(n_before, 4);
        assert_eq!(n_after, 16);
    }

    #[test]
    fn test_no_refinement_below_threshold() {
        let mut grid = AmrGrid2D::new_uniform(3, 3, [0.0, 1.0, 0.0, 1.0]);
        let config = AmrConfig {
            max_level: 4,
            refine_threshold: 100.0,
            coarsen_threshold: 0.0,
            balance_type: BalanceType::None,
        };
        let n_before = grid.n_leaves();
        grid.refine_parallel(|_| 0.0, &config);
        assert_eq!(grid.n_leaves(), n_before);
    }

    #[test]
    fn test_max_level_respected() {
        let mut grid = AmrGrid2D::new_uniform(1, 1, [0.0, 1.0, 0.0, 1.0]);
        let config = AmrConfig {
            max_level: 1,
            refine_threshold: 0.0,
            coarsen_threshold: 0.0,
            balance_type: BalanceType::None,
        };
        grid.refine_parallel(|_| 1.0, &config);
        let n_after_first = grid.n_leaves();
        grid.refine_parallel(|_| 1.0, &config);
        let n_after_second = grid.n_leaves();
        assert_eq!(n_after_first, 4, "First refinement should yield 4 leaves");
        assert_eq!(
            n_after_second, 4,
            "Second refinement at max_level should be a no-op"
        );
    }

    #[test]
    fn test_balance_enforces_2to1() {
        // Marks only the top-right root cell; refined children inherit its value.
        let make_grid = || {
            let mut grid = AmrGrid2D::new_uniform(4, 4, [0.0, 1.0, 0.0, 1.0]);
            for node in grid.nodes.iter_mut() {
                if node.cell.x > 0.75 - 1e-9 && node.cell.y > 0.75 - 1e-9 {
                    node.cell.value = 10.0;
                }
            }
            grid
        };
        let config = AmrConfig {
            max_level: 3,
            refine_threshold: 5.0,
            coarsen_threshold: 0.0,
            balance_type: BalanceType::Face2to1,
        };
        let unbalanced_config = AmrConfig {
            balance_type: BalanceType::None,
            ..config.clone()
        };
        let mut balanced = make_grid();
        let mut unbalanced = make_grid();
        // A single pass only creates level-1 cells and can never violate 2:1,
        // so refine repeatedly to drive the corner down to max_level.
        for _ in 0..config.max_level {
            balanced.refine_parallel(|cell| cell.value, &config);
            unbalanced.refine_parallel(|cell| cell.value, &unbalanced_config);
        }
        let deepest = balanced
            .leaf_indices
            .iter()
            .map(|&i| balanced.nodes[i].level)
            .max();
        assert_eq!(deepest, Some(config.max_level), "corner should reach max_level");
        assert!(
            balanced.n_leaves() > unbalanced.n_leaves(),
            "balancing should have refined extra neighbour cells"
        );
        assert_2to1_balanced(&balanced);
    }

    #[test]
    fn test_enforce_balance_fixes_manual_unbalanced_tree() {
        let mut grid = AmrGrid2D::new_uniform(4, 4, [0.0, 1.0, 0.0, 1.0]);
        // Refine the top-right root (row-major index 15), then its first child,
        // without balancing. This leaves level-2 cells next to level-0 roots.
        grid.refine_cell(15);
        let first_child = grid.nodes[15].children.expect("cell 15 was just refined")[0];
        grid.refine_cell(first_child);
        grid.rebuild_leaf_indices();
        let n_unbalanced = grid.n_leaves();
        grid.enforce_balance();
        assert!(
            grid.n_leaves() > n_unbalanced,
            "balancing should refine the coarse neighbours"
        );
        assert_2to1_balanced(&grid);
    }

    #[test]
    fn test_load_stats_nonzero_on_refined_grid() {
        let mut grid = AmrGrid2D::new_uniform(4, 4, [0.0, 1.0, 0.0, 1.0]);
        let config = AmrConfig {
            max_level: 2,
            refine_threshold: -1.0,
            coarsen_threshold: -2.0,
            balance_type: BalanceType::None,
        };
        grid.refine_parallel(|_| 0.0, &config);
        let stats = grid.load_stats();
        assert!(stats.max > 0, "Max load should be > 0 after refinement");
        assert!(stats.mean > 0.0, "Mean load should be > 0 after refinement");
        assert!(stats.min <= stats.max, "Min must not exceed max");
    }

    #[test]
    fn test_amr_refine_parallel_matches_serial() {
        let mut grid1 = AmrGrid2D::new_uniform(3, 3, [0.0, 1.0, 0.0, 1.0]);
        let mut grid2 = AmrGrid2D::new_uniform(3, 3, [0.0, 1.0, 0.0, 1.0]);
        let config = AmrConfig {
            max_level: 2,
            refine_threshold: -1.0,
            coarsen_threshold: -2.0,
            balance_type: BalanceType::None,
        };
        grid1.refine_parallel(|_| 0.0, &config);
        let initial_leaves: Vec<usize> = grid2.leaf_indices.clone();
        for li in initial_leaves {
            grid2.refine_cell(li);
        }
        grid2.rebuild_leaf_indices();
        assert_eq!(
            grid1.n_leaves(),
            grid2.n_leaves(),
            "Parallel and serial refinement must produce same leaf count"
        );
    }

    #[test]
    fn test_leaves_accessor() {
        let grid = AmrGrid2D::new_uniform(2, 2, [0.0, 4.0, 0.0, 4.0]);
        let leaves = grid.leaves();
        assert_eq!(leaves.len(), 4);
        for leaf in &leaves {
            assert!(leaf.x >= 0.0 && leaf.x + leaf.width <= 4.0);
            assert!(leaf.y >= 0.0 && leaf.y + leaf.height <= 4.0);
        }
    }
}