use crate::common::IntegrateFloat;
use crate::error::{IntegrateError, IntegrateResult};
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::{HashMap, HashSet};
/// Adaptive mesh refinement driver.
///
/// Owns the mesh hierarchy together with the refinement criteria and the
/// optional load balancer that `adapt_mesh` uses on each adaptation cycle.
pub struct AdvancedAMRManager<F: IntegrateFloat> {
    /// All refinement levels plus parent/child and ghost-cell bookkeeping.
    pub mesh_hierarchy: MeshHierarchy<F>,
    /// Criteria whose weighted average becomes each active cell's error estimate.
    pub refinement_criteria: Vec<Box<dyn RefinementCriterion<F>>>,
    /// Optional balancer run after refinement and coarsening.
    pub load_balancer: Option<Box<dyn LoadBalancer<F>>>,
    /// A cell is only flagged for refinement while its level index is strictly below this value.
    pub max_levels: usize,
    /// A cell is only flagged for refinement while its size is strictly above this value.
    pub min_cell_size: F,
    /// Active cells on levels > 0 whose error estimate falls below this are flagged to coarsen
    /// (`new` sets 0.1).
    pub coarsening_tolerance: F,
    /// Active cells whose error estimate exceeds this are flagged to refine (`new` sets 1.0).
    pub refinement_tolerance: F,
    /// The mesh is only adapted on steps that are a multiple of this value (`new` sets 1).
    pub adaptation_frequency: usize,
    /// Number of `adapt_mesh` calls so far; incremented before the frequency check.
    current_step: usize,
}
/// All levels of the adaptive mesh and the maps that link them.
#[derive(Debug, Clone)]
pub struct MeshHierarchy<F: IntegrateFloat> {
    /// Mesh levels; the adaptation code looks cells up via `levels[id.level]`, so
    /// `levels[k]` is expected to hold the cells whose `CellId::level` is `k`.
    pub levels: Vec<AdaptiveMeshLevel<F>>,
    /// Maps a refined parent cell to the ids of the children created from it.
    pub hierarchy_map: HashMap<CellId, Vec<CellId>>,
    /// Ghost-cell ids per level index, rebuilt from scratch after adaptation. These ids are
    /// synthetic (offset indices) and do not correspond to entries in any level's `cells`.
    pub ghost_cells: HashMap<usize, Vec<CellId>>,
}
/// One refinement level of the adaptive mesh.
#[derive(Debug, Clone)]
pub struct AdaptiveMeshLevel<F: IntegrateFloat> {
    /// Level index; the initial mesh passed to `AdvancedAMRManager::new` is level 0.
    pub level: usize,
    /// Cells on this level keyed by id, including inactive (already refined) cells.
    pub cells: HashMap<CellId, AdaptiveCell<F>>,
    /// Bucket width for neighbour search and scale of the ghost-lookup tolerance; levels
    /// created by refinement get half of the previous finest level's spacing.
    pub grid_spacing: F,
    /// Cells that are always treated as boundary cells when ghost cells are rebuilt.
    pub boundary_cells: HashSet<CellId>,
}
/// A single cell of the adaptive mesh.
#[derive(Debug, Clone)]
pub struct AdaptiveCell<F: IntegrateFloat> {
    /// Identifier of this cell (its level plus its index within that level).
    pub id: CellId,
    /// Coordinates of the cell centre; its length is used as the spatial dimension
    /// (refinement creates `2^len` children).
    pub center: Array1<F>,
    /// Cell width; refinement gives each child half of its parent's size.
    pub size: F,
    /// Solution values stored on the cell.
    pub solution: Array1<F>,
    /// Weighted average of the refinement criteria from the last evaluation.
    pub error_estimate: F,
    /// Adaptation decision for the current cycle.
    pub refinement_flag: RefinementFlag,
    /// `false` once the cell has been refined into children.
    pub is_active: bool,
    /// Ids of neighbouring cells; may include cells on the adjacent coarser/finer levels.
    pub neighbors: Vec<CellId>,
    /// Cell this one was created from; set to `Some` for cells produced by refinement.
    pub parent: Option<CellId>,
    /// Ids of the cells created by refining this one; cleared when it is coarsened again.
    pub children: Vec<CellId>,
}
/// Identifier of a mesh cell, used as the key of `AdaptiveMeshLevel::cells`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CellId {
    /// Refinement level the cell lives on (index into `MeshHierarchy::levels`).
    pub level: usize,
    /// Position of the cell within its level.
    pub index: usize,
}
/// Adaptation decision attached to a cell for the current cycle.
///
/// Besides equality, the flag is hashable so callers can group or count
/// cells by decision in sets and maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RefinementFlag {
    /// No change requested.
    None,
    /// Split the cell into children during the next refinement pass.
    Refine,
    /// Merge the cell back into its parent once all siblings are flagged too.
    Coarsen,
    /// Generic marker; NOTE(review): not set or read by the adaptation logic in this module.
    Tagged,
}
/// Pluggable error indicator evaluated for every active cell during adaptation.
///
/// The manager combines all installed criteria into a weighted average; a cell
/// whose average exceeds `refinement_tolerance` is flagged for refinement.
pub trait RefinementCriterion<F: IntegrateFloat>: Send + Sync {
    /// Returns this criterion's indicator for `cell` given its neighbouring cells;
    /// larger values push the cell towards refinement.
    fn evaluate(&self, cell: &AdaptiveCell<F>, neighbors: &[&AdaptiveCell<F>]) -> F;
    /// Short identifier of the criterion.
    fn name(&self) -> &'static str;
    /// Weight of this criterion in the weighted average (default: 1).
    fn weight(&self) -> F {
        F::one()
    }
}
/// Refinement criterion based on the solution jump between a cell and its neighbours.
pub struct GradientRefinementCriterion<F: IntegrateFloat> {
    /// Solution component to inspect; `None` uses the Euclidean norm of the full difference.
    pub component: Option<usize>,
    /// NOTE(review): not read by this type's `evaluate` implementation in this module.
    pub threshold: F,
    /// NOTE(review): not read by this type's `evaluate` implementation in this module.
    pub relative_tolerance: F,
}
/// Refinement criterion that scores the presence of selected solution features.
pub struct FeatureDetectionCriterion<F: IntegrateFloat> {
    /// Minimum magnitude difference used by the local-extremum test.
    pub threshold: F,
    /// Features to look for; each detected feature adds to the score.
    pub feature_types: Vec<FeatureType>,
    /// NOTE(review): not read by this type's `evaluate` implementation in this module.
    pub window_size: usize,
}
/// Refinement criterion based on a discrete second-difference (curvature) estimate.
pub struct CurvatureRefinementCriterion<F: IntegrateFloat> {
    /// NOTE(review): not read by this type's `evaluate` implementation in this module.
    pub threshold: F,
    /// NOTE(review): not read by this type's `evaluate` implementation in this module.
    pub approximation_order: usize,
}
/// Redistributes work across the mesh hierarchy.
pub trait LoadBalancer<F: IntegrateFloat>: Send + Sync {
    /// Rebalances `hierarchy` in place; called by the manager after refinement and coarsening.
    fn balance(&self, hierarchy: &mut MeshHierarchy<F>) -> IntegrateResult<()>;
}
/// Configuration for a geometric partitioner.
///
/// NOTE(review): no `LoadBalancer` implementation for this type is visible in this module.
pub struct GeometricLoadBalancer<F: IntegrateFloat> {
    /// Number of partitions to split the mesh into.
    pub num_partitions: usize,
    /// Acceptable load imbalance between partitions.
    pub imbalance_tolerance: F,
    /// Partitioning algorithm to use.
    pub method: PartitioningMethod,
}
/// Solution features that `FeatureDetectionCriterion` can look for.
///
/// Only `SharpGradient` and `LocalExtremum` contribute to the score in this
/// module; the remaining variants are accepted but currently add nothing.
/// The type is hashable so feature lists can be deduplicated via sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FeatureType {
    /// Largest neighbour jump much larger than the average jump.
    SharpGradient,
    /// Not evaluated yet.
    Discontinuity,
    /// Cell magnitude above or below all of its neighbours.
    LocalExtremum,
    /// Not evaluated yet.
    Oscillation,
    /// Not evaluated yet.
    BoundaryLayer,
}
/// Partitioning algorithm selector for geometric load balancing.
///
/// Derives equality and hashing so callers can compare or dispatch on the
/// configured method.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PartitioningMethod {
    /// Presumably recursive coordinate bisection — NOTE(review): confirm with the balancer impl.
    RCB,
    /// Presumably space-filling-curve ordering — NOTE(review): confirm with the balancer impl.
    SFC,
    /// Presumably graph-based partitioning — NOTE(review): confirm with the balancer impl.
    Graph,
}
/// Summary of one call to `AdvancedAMRManager::adapt_mesh`.
///
/// Derives `Debug` and `Clone` so results can be logged and stored.
#[derive(Debug, Clone)]
pub struct AMRAdaptationResult<F: IntegrateFloat> {
    /// Number of cells split into children.
    pub cells_refined: usize,
    /// Number of parents re-activated by merging their children.
    pub cells_coarsened: usize,
    /// Active cells after adaptation.
    pub total_active_cells: usize,
    /// Load-balance score in `[0, 1]` (1 when no balancer is installed or the step was skipped).
    pub load_balance_quality: F,
    /// Change in active-cell count times 8: a rough estimate, not a measured byte count.
    pub memory_change: i64,
    /// Wall-clock time spent in the call.
    pub adaptation_time: std::time::Duration,
}
impl<F: IntegrateFloat> AdvancedAMRManager<F> {
/// Creates a manager whose hierarchy starts with `initial_mesh` as level 0.
///
/// No refinement criteria or load balancer are installed. Tolerances default
/// to 0.1 (coarsening) and 1.0 (refinement), and the mesh is adapted on
/// every step.
pub fn new(initial_mesh: AdaptiveMeshLevel<F>, max_levels: usize, min_cellsize: F) -> Self {
    let constant = |value: f64| F::from(value).expect("Failed to convert constant to float");
    let mesh_hierarchy = MeshHierarchy {
        levels: vec![initial_mesh],
        hierarchy_map: HashMap::new(),
        ghost_cells: HashMap::new(),
    };
    let coarsening_tolerance = constant(0.1);
    let refinement_tolerance = constant(1.0);
    Self {
        mesh_hierarchy,
        refinement_criteria: Vec::new(),
        load_balancer: None,
        max_levels,
        min_cell_size: min_cellsize,
        coarsening_tolerance,
        refinement_tolerance,
        adaptation_frequency: 1,
        current_step: 0,
    }
}
/// Appends a refinement criterion; all installed criteria are averaged by weight.
pub fn add_criterion(&mut self, criterion: Box<dyn RefinementCriterion<F>>) {
    let criteria = &mut self.refinement_criteria;
    criteria.push(criterion);
}
/// Installs `balancer`, replacing (and dropping) any previously installed one.
pub fn set_load_balancer(&mut self, balancer: Box<dyn LoadBalancer<F>>) {
    let previous = self.load_balancer.replace(balancer);
    drop(previous);
}
pub fn adapt_mesh(&mut self, solution: &Array2<F>) -> IntegrateResult<AMRAdaptationResult<F>> {
let start_time = std::time::Instant::now();
let initial_cells = self.count_active_cells();
self.current_step += 1;
if !self.current_step.is_multiple_of(self.adaptation_frequency) {
return Ok(AMRAdaptationResult {
cells_refined: 0,
cells_coarsened: 0,
total_active_cells: initial_cells,
load_balance_quality: F::one(),
memory_change: 0,
adaptation_time: start_time.elapsed(),
});
}
self.update_cell_solutions(solution)?;
self.evaluate_refinement_criteria()?;
let _refine_count_coarsen_count = self.flag_cells_for_adaptation()?;
let cells_refined = self.refine_flagged_cells()?;
let cells_coarsened = self.coarsen_flagged_cells()?;
let load_balance_quality = if let Some(ref balancer) = self.load_balancer {
balancer.balance(&mut self.mesh_hierarchy)?;
self.assess_load_balance()
} else {
F::one()
};
self.update_ghost_cells()?;
let final_cells = self.count_active_cells();
let memory_change = (final_cells as i64 - initial_cells as i64) * 8;
Ok(AMRAdaptationResult {
cells_refined,
cells_coarsened,
total_active_cells: final_cells,
load_balance_quality,
memory_change,
adaptation_time: start_time.elapsed(),
})
}
fn update_cell_solutions(&mut self, solution: &Array2<F>) -> IntegrateResult<()> {
for level in &mut self.mesh_hierarchy.levels {
for cell in level.cells.values_mut() {
if cell.is_active {
let i = (cell.center[0] * F::from(solution.nrows()).expect("Operation failed"))
.to_usize()
.unwrap_or(0)
.min(solution.nrows() - 1);
let j = if solution.ncols() > 1 && cell.center.len() > 1 {
(cell.center[1] * F::from(solution.ncols()).expect("Operation failed"))
.to_usize()
.unwrap_or(0)
.min(solution.ncols() - 1)
} else {
0
};
if cell.solution.len() == 1 {
cell.solution[0] = solution[[i, j]];
}
}
}
}
Ok(())
}
/// Recomputes `error_estimate` for every active cell.
///
/// The estimate is the weight-averaged value of all installed criteria, or
/// zero when the total weight is not positive (e.g. no criteria installed).
///
/// Neighbours are resolved in the level their own `CellId` names. The previous
/// lookup only searched the cell's own level, which silently dropped every
/// inter-level neighbour. Only active neighbours are passed to the criteria,
/// because inactive cells are not refreshed by `update_cell_solutions` and
/// carry stale values. All estimates are computed before any is written, so
/// the result does not depend on HashMap iteration order.
fn evaluate_refinement_criteria(&mut self) -> IntegrateResult<()> {
    let levels = &self.mesh_hierarchy.levels;
    let mut estimates: Vec<(usize, CellId, F)> = Vec::new();
    for (level_pos, level) in levels.iter().enumerate() {
        for (cellid, cell) in &level.cells {
            if !cell.is_active {
                continue;
            }
            let neighbor_cells: Vec<&AdaptiveCell<F>> = cell
                .neighbors
                .iter()
                .filter_map(|neighbor_id| {
                    levels
                        .get(neighbor_id.level)
                        .and_then(|neighbor_level| neighbor_level.cells.get(neighbor_id))
                })
                .filter(|neighbor| neighbor.is_active)
                .collect();
            let mut total_error = F::zero();
            let mut total_weight = F::zero();
            for criterion in &self.refinement_criteria {
                let error = criterion.evaluate(cell, &neighbor_cells);
                let weight = criterion.weight();
                total_error += error * weight;
                total_weight += weight;
            }
            let error_estimate = if total_weight > F::zero() {
                total_error / total_weight
            } else {
                F::zero()
            };
            estimates.push((level_pos, *cellid, error_estimate));
        }
    }
    for (level_pos, cellid, error_estimate) in estimates {
        if let Some(cell) = self
            .mesh_hierarchy
            .levels
            .get_mut(level_pos)
            .and_then(|level| level.cells.get_mut(&cellid))
        {
            cell.error_estimate = error_estimate;
        }
    }
    Ok(())
}
/// Sets `refinement_flag` on every cell and returns `(refine_count, coarsen_count)`.
///
/// An active cell is flagged:
/// - `Refine` when its error exceeds `refinement_tolerance`, its level is below
///   `max_levels` and its size is above `min_cell_size`;
/// - otherwise `Coarsen` when its error is below `coarsening_tolerance` and it is
///   not on level 0;
/// - otherwise `None`.
///
/// Inactive (already refined) cells are reset to `None`. Otherwise a stale
/// `Refine` flag from an earlier cycle would make the refinement pass split
/// the same cell again.
fn flag_cells_for_adaptation(&mut self) -> IntegrateResult<(usize, usize)> {
    let refine_tol = self.refinement_tolerance;
    let coarsen_tol = self.coarsening_tolerance;
    let max_levels = self.max_levels;
    let min_size = self.min_cell_size;
    let mut refine_count = 0;
    let mut coarsen_count = 0;
    for level in &mut self.mesh_hierarchy.levels {
        let level_num = level.level;
        for cell in level.cells.values_mut() {
            cell.refinement_flag = if !cell.is_active {
                RefinementFlag::None
            } else if cell.error_estimate > refine_tol
                && level_num < max_levels
                && cell.size > min_size
            {
                refine_count += 1;
                RefinementFlag::Refine
            } else if cell.error_estimate < coarsen_tol && level_num > 0 {
                coarsen_count += 1;
                RefinementFlag::Coarsen
            } else {
                RefinementFlag::None
            };
        }
    }
    Ok((refine_count, coarsen_count))
}
/// Returns `true` when `cell` has a recorded parent and every sibling that still
/// exists is flagged `Coarsen`.
///
/// Cells without a parent, or whose parent has no entry in `hierarchy_map`,
/// cannot be coarsened.
fn can_coarsen_cell(&self, cell: &AdaptiveCell<F>) -> bool {
    let Some(parent_id) = cell.parent else {
        return false;
    };
    let Some(siblings) = self.mesh_hierarchy.hierarchy_map.get(&parent_id) else {
        return false;
    };
    siblings.iter().all(|sibling_id| {
        // Siblings that can no longer be found do not block coarsening.
        self.mesh_hierarchy
            .levels
            .get(sibling_id.level)
            .and_then(|level| level.cells.get(sibling_id))
            .is_none_or(|sibling| sibling.refinement_flag == RefinementFlag::Coarsen)
    })
}
/// Refines every active cell flagged `Refine` and returns how many were refined.
///
/// Levels are scanned coarse to fine. The level count is fixed before the
/// scan, so levels created during this pass are not visited.
///
/// Only active cells are considered. An inactive cell has already been split,
/// and refining it again would insert a second set of children and overwrite
/// its `hierarchy_map` entry, orphaning the first set.
fn refine_flagged_cells(&mut self) -> IntegrateResult<usize> {
    let mut cells_refined = 0;
    for level_idx in 0..self.mesh_hierarchy.levels.len() {
        let cells_to_refine: Vec<CellId> = self.mesh_hierarchy.levels[level_idx]
            .cells
            .values()
            .filter(|cell| cell.is_active && cell.refinement_flag == RefinementFlag::Refine)
            .map(|cell| cell.id)
            .collect();
        for cellid in cells_to_refine {
            self.refine_cell(cellid)?;
            cells_refined += 1;
        }
    }
    Ok(cells_refined)
}
/// Splits cell `cellid` into `2^d` children on the next finer level, where `d`
/// is the length of the cell centre.
///
/// Missing finer levels are created, each with half the previous finest
/// level's `grid_spacing`. Each child:
/// - has half the parent's size;
/// - is centred a quarter of the parent's size away from the parent centre
///   along every axis;
/// - starts with a copy of the parent's solution.
///
/// Afterwards the parent is deactivated, its flag is cleared and its children
/// are recorded. Neighbour lists of the child level are then rebuilt.
///
/// Child indices start after the largest index already on the child level.
/// The previous `cells.len()` scheme collided with surviving cells once
/// coarsening had removed some, and `insert` then silently overwrote them.
///
/// # Errors
/// `IntegrateError::ValueError` if the cell's level or the cell itself does not exist.
fn refine_cell(&mut self, cellid: CellId) -> IntegrateResult<()> {
    let parent_cell = self
        .mesh_hierarchy
        .levels
        .get(cellid.level)
        .ok_or_else(|| IntegrateError::ValueError("Invalid cell level".to_string()))?
        .cells
        .get(&cellid)
        .cloned()
        .ok_or_else(|| IntegrateError::ValueError("Cell not found".to_string()))?;
    let two = F::from(2.0).expect("Failed to convert constant to float");
    let child_level = cellid.level + 1;
    while self.mesh_hierarchy.levels.len() <= child_level {
        let grid_spacing = match self.mesh_hierarchy.levels.last() {
            Some(last_level) => last_level.grid_spacing / two,
            None => F::one(),
        };
        let new_level = AdaptiveMeshLevel {
            level: self.mesh_hierarchy.levels.len(),
            cells: HashMap::new(),
            grid_spacing,
            boundary_cells: HashSet::new(),
        };
        self.mesh_hierarchy.levels.push(new_level);
    }
    let dim = parent_cell.center.len();
    let num_children = 2_usize.pow(dim as u32);
    let child_size = parent_cell.size / two;
    let offset = child_size / two;
    let child_cells = &mut self.mesh_hierarchy.levels[child_level].cells;
    let mut next_index = child_cells
        .keys()
        .map(|id| id.index + 1)
        .max()
        .unwrap_or(0);
    let mut child_ids = Vec::with_capacity(num_children);
    for child_idx in 0..num_children {
        let child_id = CellId {
            level: child_level,
            index: next_index,
        };
        next_index += 1;
        let mut child_center = parent_cell.center.clone();
        // Bit `axis` of `child_idx` selects the upper (+) or lower (-) half along that axis.
        for axis in 0..dim {
            if (child_idx >> axis) & 1 == 1 {
                child_center[axis] += offset;
            } else {
                child_center[axis] -= offset;
            }
        }
        child_cells.insert(
            child_id,
            AdaptiveCell {
                id: child_id,
                center: child_center,
                size: child_size,
                solution: parent_cell.solution.clone(),
                error_estimate: F::zero(),
                refinement_flag: RefinementFlag::None,
                is_active: true,
                neighbors: Vec::new(),
                parent: Some(cellid),
                children: Vec::new(),
            },
        );
        child_ids.push(child_id);
    }
    self.mesh_hierarchy
        .hierarchy_map
        .insert(cellid, child_ids.clone());
    if let Some(parent) = self.mesh_hierarchy.levels[cellid.level]
        .cells
        .get_mut(&cellid)
    {
        parent.is_active = false;
        // Clear the request that has just been satisfied, so the inactive parent
        // cannot be picked up for refinement again in a later cycle.
        parent.refinement_flag = RefinementFlag::None;
        parent.children = child_ids;
    }
    self.update_neighbor_relationships(child_level)?;
    Ok(())
}
/// Merges children flagged `Coarsen` back into their parents, from the finest
/// level down to level 1. Returns the number of parents re-activated.
///
/// Each distinct parent of a flagged cell is tried once per level; the actual
/// merge (and the all-siblings check) happens in `coarsen_to_parent`.
fn coarsen_flagged_cells(&mut self) -> IntegrateResult<usize> {
    let level_count = self.mesh_hierarchy.levels.len();
    let mut merged = 0;
    for level_idx in (1..level_count).rev() {
        let mut candidate_parents = HashSet::new();
        for cell in self.mesh_hierarchy.levels[level_idx].cells.values() {
            if cell.refinement_flag != RefinementFlag::Coarsen {
                continue;
            }
            if let Some(parent_id) = cell.parent {
                candidate_parents.insert(parent_id);
            }
        }
        for parent_id in candidate_parents {
            if self.coarsen_to_parent(parent_id)? {
                merged += 1;
            }
        }
    }
    Ok(merged)
}
fn coarsen_to_parent(&mut self, parentid: CellId) -> IntegrateResult<bool> {
let child_ids = if let Some(children) = self.mesh_hierarchy.hierarchy_map.get(&parentid) {
children.clone()
} else {
return Ok(false);
};
for &child_id in &child_ids {
if let Some(level) = self.mesh_hierarchy.levels.get(child_id.level) {
if let Some(child) = level.cells.get(&child_id) {
if child.refinement_flag != RefinementFlag::Coarsen {
return Ok(false);
}
}
}
}
let mut avg_solution = Array1::zeros(child_ids.len());
if !child_ids.is_empty() {
if let Some(first_child_level) = self.mesh_hierarchy.levels.get(child_ids[0].level) {
if let Some(first_child) = first_child_level.cells.get(&child_ids[0]) {
avg_solution = Array1::zeros(first_child.solution.len());
for &child_id in &child_ids {
if let Some(child_level) = self.mesh_hierarchy.levels.get(child_id.level) {
if let Some(child) = child_level.cells.get(&child_id) {
avg_solution = &avg_solution + &child.solution;
}
}
}
avg_solution /= F::from(child_ids.len()).expect("Operation failed");
}
}
}
if let Some(parent_level) = self.mesh_hierarchy.levels.get_mut(parentid.level) {
if let Some(parent) = parent_level.cells.get_mut(&parentid) {
parent.is_active = true;
parent.solution = avg_solution;
parent.children.clear();
parent.refinement_flag = RefinementFlag::None;
}
}
for &child_id in &child_ids {
if let Some(child_level) = self.mesh_hierarchy.levels.get_mut(child_id.level) {
child_level.cells.remove(&child_id);
}
}
self.mesh_hierarchy.hierarchy_map.remove(&parentid);
Ok(true)
}
fn update_neighbor_relationships(&mut self, level: usize) -> IntegrateResult<()> {
let mut all_neighbor_relationships: Vec<(CellId, Vec<CellId>)> = Vec::new();
if let Some(mesh_level) = self.mesh_hierarchy.levels.get(level) {
let cellids: Vec<CellId> = mesh_level.cells.keys().cloned().collect();
let mut spatial_hash: HashMap<(i32, i32, i32), Vec<CellId>> = HashMap::new();
let grid_spacing = mesh_level.grid_spacing;
for cellid in &cellids {
if let Some(cell) = mesh_level.cells.get(cellid) {
if cell.center.len() >= 3 {
let hash_x = (cell.center[0] / grid_spacing)
.floor()
.to_i32()
.unwrap_or(0);
let hash_y = (cell.center[1] / grid_spacing)
.floor()
.to_i32()
.unwrap_or(0);
let hash_z = (cell.center[2] / grid_spacing)
.floor()
.to_i32()
.unwrap_or(0);
spatial_hash
.entry((hash_x, hash_y, hash_z))
.or_default()
.push(*cellid);
}
}
}
for cellid in &cellids {
if let Some(cell) = mesh_level.cells.get(cellid) {
let mut neighbors = Vec::new();
if cell.center.len() >= 3 {
let hash_x = (cell.center[0] / grid_spacing)
.floor()
.to_i32()
.unwrap_or(0);
let hash_y = (cell.center[1] / grid_spacing)
.floor()
.to_i32()
.unwrap_or(0);
let hash_z = (cell.center[2] / grid_spacing)
.floor()
.to_i32()
.unwrap_or(0);
for dx in -1..=1 {
for dy in -1..=1 {
for dz in -1..=1 {
let hash_key = (hash_x + dx, hash_y + dy, hash_z + dz);
if let Some(potential_neighbors) = spatial_hash.get(&hash_key) {
for &neighbor_id in potential_neighbors {
if neighbor_id != *cellid {
if let Some(neighbor_cell) =
mesh_level.cells.get(&neighbor_id)
{
if self.are_cells_neighbors(cell, neighbor_cell)
{
neighbors.push(neighbor_id);
}
}
}
}
}
}
}
}
}
all_neighbor_relationships.push((*cellid, neighbors));
}
}
}
if let Some(mesh_level) = self.mesh_hierarchy.levels.get_mut(level) {
for (cellid, neighbors) in all_neighbor_relationships {
if let Some(cell) = mesh_level.cells.get_mut(&cellid) {
cell.neighbors = neighbors;
}
}
}
let cellids: Vec<CellId> = if let Some(mesh_level) = self.mesh_hierarchy.levels.get(level) {
mesh_level.cells.keys().cloned().collect()
} else {
Vec::new()
};
for cellid in cellids {
self.update_interlevel_neighbors(cellid, level)?;
}
Ok(())
}
fn are_cells_neighbors(&self, cell1: &AdaptiveCell<F>, cell2: &AdaptiveCell<F>) -> bool {
if cell1.center.len() != cell2.center.len() || cell1.center.len() < 3 {
return false;
}
let max_size = cell1.size.max(cell2.size);
let tolerance = max_size * F::from(1.1).expect("Failed to convert constant to float");
let mut distance_squared = F::zero();
for i in 0..cell1.center.len() {
let diff = cell1.center[i] - cell2.center[i];
distance_squared += diff * diff;
}
let distance = distance_squared.sqrt();
let expected_distance =
(cell1.size + cell2.size) / F::from(2.0).expect("Failed to convert constant to float");
distance <= tolerance
&& distance
>= expected_distance * F::from(0.7).expect("Failed to convert constant to float")
}
/// Appends to `cellid`'s neighbour list every cell on the next coarser and next
/// finer level that `are_cells_neighbors` accepts.
///
/// Coarser matches are appended before finer ones, and ids already present are
/// skipped. Missing levels or a missing cell make this a no-op.
fn update_interlevel_neighbors(&mut self, cellid: CellId, level: usize) -> IntegrateResult<()> {
    let discovered: Vec<CellId> = {
        let levels = &self.mesh_hierarchy.levels;
        let scan = |other_level: Option<usize>| -> Vec<CellId> {
            let (Some(here), Some(there)) =
                (levels.get(level), other_level.and_then(|idx| levels.get(idx)))
            else {
                return Vec::new();
            };
            let Some(current) = here.cells.get(&cellid) else {
                return Vec::new();
            };
            there
                .cells
                .iter()
                .filter(|(_, candidate)| self.are_cells_neighbors(current, candidate))
                .map(|(id, _)| *id)
                .collect()
        };
        let mut found = scan(level.checked_sub(1));
        found.extend(scan(level.checked_add(1)));
        found
    };
    if let Some(cell) = self
        .mesh_hierarchy
        .levels
        .get_mut(level)
        .and_then(|mesh_level| mesh_level.cells.get_mut(&cellid))
    {
        for id in discovered {
            if !cell.neighbors.contains(&id) {
                cell.neighbors.push(id);
            }
        }
    }
    Ok(())
}
fn update_ghost_cells(&mut self) -> IntegrateResult<()> {
self.mesh_hierarchy.ghost_cells.clear();
for level_idx in 0..self.mesh_hierarchy.levels.len() {
let mut ghost_cells_for_level = Vec::new();
let mut boundary_cells = HashSet::new();
let expected_neighbors = self.calculate_expected_neighbors();
if let Some(mesh_level) = self.mesh_hierarchy.levels.get(level_idx) {
for (cellid, cell) in &mesh_level.cells {
if cell.neighbors.len() < expected_neighbors
|| mesh_level.boundary_cells.contains(cellid)
{
boundary_cells.insert(*cellid);
}
}
for boundary_cellid in &boundary_cells {
if let Some(boundary_cell) = mesh_level.cells.get(boundary_cellid) {
let ghost_cells =
self.create_ghost_cells_for_boundary(boundary_cell, level_idx)?;
ghost_cells_for_level.extend(ghost_cells);
}
}
self.create_interlevel_ghost_cells(level_idx, &mut ghost_cells_for_level)?;
}
self.mesh_hierarchy
.ghost_cells
.insert(level_idx, ghost_cells_for_level);
}
Ok(())
}
fn calculate_expected_neighbors(&self) -> usize {
6
}
fn create_ghost_cells_for_boundary(
&self,
boundary_cell: &AdaptiveCell<F>,
level: usize,
) -> IntegrateResult<Vec<CellId>> {
let mut ghost_cells = Vec::new();
if boundary_cell.center.len() >= 3 {
let cell_size = boundary_cell.size;
let directions = [
[F::one(), F::zero(), F::zero()], [-F::one(), F::zero(), F::zero()], [F::zero(), F::one(), F::zero()], [F::zero(), -F::one(), F::zero()], [F::zero(), F::zero(), F::one()], [F::zero(), F::zero(), -F::one()], ];
for (dir_idx, direction) in directions.iter().enumerate() {
let mut ghost_center = boundary_cell.center.clone();
for i in 0..3 {
ghost_center[i] += direction[i] * cell_size;
}
if !self.cell_exists_at_position(&ghost_center, level) {
let ghost_id = CellId {
level,
index: 1_000_000 + boundary_cell.id.index * 10 + dir_idx,
};
ghost_cells.push(ghost_id);
}
}
}
Ok(ghost_cells)
}
/// Appends synthetic inter-level ghost ids for `level` to `ghost_cells`.
///
/// Two kinds of cell produce ghosts:
/// - On levels above 0, every cell without a parent yields a ghost on the
///   coarser level with index `2_000_000 + cell index`.
/// - When a finer level exists, every cell with children yields a ghost on the
///   finer level with index `3_000_000 + cell index`.
///
/// All coarser-level ghosts are appended before the finer-level ones.
///
/// (Fixes the mis-encoded `¤t_level` borrow, which did not compile, to
/// `&current_level`.)
fn create_interlevel_ghost_cells(
    &self,
    level: usize,
    ghost_cells: &mut Vec<CellId>,
) -> IntegrateResult<()> {
    let Some(current_level) = self.mesh_hierarchy.levels.get(level) else {
        return Ok(());
    };
    if level > 0 {
        for (cellid, cell) in &current_level.cells {
            if cell.parent.is_none() {
                ghost_cells.push(CellId {
                    level: level - 1,
                    index: 2_000_000 + cellid.index,
                });
            }
        }
    }
    if level + 1 < self.mesh_hierarchy.levels.len() {
        for (cellid, cell) in &current_level.cells {
            if !cell.children.is_empty() {
                ghost_cells.push(CellId {
                    level: level + 1,
                    index: 3_000_000 + cellid.index,
                });
            }
        }
    }
    Ok(())
}
/// Returns `true` if some cell on `level` has a centre of the same dimension as
/// `position`, lying strictly within `0.1 * grid_spacing` of it.
fn cell_exists_at_position(&self, position: &Array1<F>, level: usize) -> bool {
    let Some(mesh_level) = self.mesh_hierarchy.levels.get(level) else {
        return false;
    };
    let tolerance =
        mesh_level.grid_spacing * F::from(0.1).expect("Failed to convert constant to float");
    mesh_level
        .cells
        .values()
        .filter(|cell| cell.center.len() == position.len())
        .any(|cell| {
            let squared = position
                .iter()
                .zip(cell.center.iter())
                .fold(F::zero(), |acc, (&p, &c)| {
                    let diff = p - c;
                    acc + diff * diff
                });
            squared.sqrt() < tolerance
        })
}
/// Number of active cells across all levels.
fn count_active_cells(&self) -> usize {
    let mut active = 0;
    for level in &self.mesh_hierarchy.levels {
        active += level.cells.values().filter(|cell| cell.is_active).count();
    }
    active
}
/// Combines four per-level balance metrics into one score clamped to `[0, 1]`.
///
/// The weights are cells 0.3, computation 0.4, communication 0.2 and memory
/// 0.1. A mesh without active cells scores 1.
fn assess_load_balance(&self) -> F {
    if self.count_active_cells() == 0 {
        return F::one();
    }
    let cells = self.calculate_cell_distribution_balance();
    let compute = self.calculate_computational_load_balance();
    let comm = self.calculate_communication_balance();
    let memory = self.calculate_memory_balance();
    let weight = |w: f64| F::from(w).expect("Failed to convert constant to float");
    let combined = weight(0.3) * cells
        + weight(0.4) * compute
        + weight(0.2) * comm
        + weight(0.1) * memory;
    combined.min(F::one()).max(F::zero())
}
fn calculate_cell_distribution_balance(&self) -> F {
if self.mesh_hierarchy.levels.is_empty() {
return F::one();
}
let mut cells_per_level: Vec<usize> = Vec::new();
let mut total_cells = 0;
for level in &self.mesh_hierarchy.levels {
let active_cells = level.cells.values().filter(|c| c.is_active).count();
cells_per_level.push(active_cells);
total_cells += active_cells;
}
if total_cells == 0 {
return F::one();
}
let mean_cells = total_cells as f64 / cells_per_level.len() as f64;
let variance: f64 = cells_per_level
.iter()
.map(|&count| {
let diff = count as f64 - mean_cells;
diff * diff
})
.sum::<f64>()
/ cells_per_level.len() as f64;
let std_dev = variance.sqrt();
let coefficient_of_variation = if mean_cells > 0.0 {
std_dev / mean_cells
} else {
0.0
};
let balance = (1.0 - coefficient_of_variation.min(1.0)).max(0.0);
F::from(balance).unwrap_or(F::zero())
}
fn calculate_computational_load_balance(&self) -> F {
let mut level_computational_loads: Vec<F> = Vec::new();
let mut total_load = F::zero();
for level in &self.mesh_hierarchy.levels {
let mut level_load = F::zero();
for cell in level.cells.values() {
if cell.is_active {
let cell_cost = cell.error_estimate * cell.size * cell.size; level_load += cell_cost;
}
}
level_computational_loads.push(level_load);
total_load += level_load;
}
if total_load <= F::zero() {
return F::one();
}
let mean_load =
total_load / F::from(level_computational_loads.len()).expect("Operation failed");
let mut variance = F::zero();
for &load in &level_computational_loads {
let diff = load - mean_load;
variance += diff * diff;
}
variance /= F::from(level_computational_loads.len()).expect("Operation failed");
let std_dev = variance.sqrt();
let coeff_var = if mean_load > F::zero() {
std_dev / mean_load
} else {
F::zero()
};
let balance = F::one() - coeff_var.min(F::one());
balance.max(F::zero())
}
fn calculate_communication_balance(&self) -> F {
let mut level_comm_costs: Vec<F> = Vec::new();
let mut total_comm_cost = F::zero();
for (level_idx, level) in self.mesh_hierarchy.levels.iter().enumerate() {
let active_cells = level.cells.values().filter(|c| c.is_active).count();
let ghost_cells = self
.mesh_hierarchy
.ghost_cells
.get(&level_idx)
.map(|ghosts| ghosts.len())
.unwrap_or(0);
let comm_cost = if active_cells > 0 {
F::from(ghost_cells as f64 / active_cells as f64).unwrap_or(F::zero())
} else {
F::zero()
};
level_comm_costs.push(comm_cost);
total_comm_cost += comm_cost;
}
if level_comm_costs.is_empty() || total_comm_cost <= F::zero() {
return F::one();
}
let mean_comm =
total_comm_cost / F::from(level_comm_costs.len()).expect("Operation failed");
let mut variance = F::zero();
for &cost in &level_comm_costs {
let diff = cost - mean_comm;
variance += diff * diff;
}
variance /= F::from(level_comm_costs.len()).expect("Operation failed");
let std_dev = variance.sqrt();
let coeff_var = if mean_comm > F::zero() {
std_dev / mean_comm
} else {
F::zero()
};
let balance = F::one() - coeff_var.min(F::one());
balance.max(F::zero())
}
fn calculate_memory_balance(&self) -> F {
let mut level_memory_usage: Vec<F> = Vec::new();
let mut total_memory = F::zero();
for level in &self.mesh_hierarchy.levels {
let cell_count = level.cells.len();
let total_neighbors: usize = level.cells.values().map(|c| c.neighbors.len()).sum();
let solution_size: usize = level.cells.values().map(|c| c.solution.len()).sum();
let memory_estimate = F::from(cell_count + total_neighbors + solution_size)
.expect("Failed to convert to float");
level_memory_usage.push(memory_estimate);
total_memory += memory_estimate;
}
if level_memory_usage.is_empty() || total_memory <= F::zero() {
return F::one();
}
let mean_memory =
total_memory / F::from(level_memory_usage.len()).expect("Operation failed");
let mut variance = F::zero();
for &memory in &level_memory_usage {
let diff = memory - mean_memory;
variance += diff * diff;
}
variance /= F::from(level_memory_usage.len()).expect("Operation failed");
let std_dev = variance.sqrt();
let coeff_var = if mean_memory > F::zero() {
std_dev / mean_memory
} else {
F::zero()
};
let balance = F::one() - coeff_var.min(F::one());
balance.max(F::zero())
}
}
impl<F: IntegrateFloat + Send + Sync> RefinementCriterion<F> for GradientRefinementCriterion<F> {
    /// Largest solution jump to any neighbour, divided by the cell size and, when
    /// non-zero, by the cell's own solution magnitude.
    ///
    /// With `component: Some(c)` only component `c` is compared; an out-of-range
    /// component contributes zero. With `None`, the Euclidean norm of the full
    /// difference is used, and the magnitude is the sum of absolute values.
    /// Returns zero when there are no neighbours.
    fn evaluate(&self, cell: &AdaptiveCell<F>, neighbors: &[&AdaptiveCell<F>]) -> F {
        if neighbors.is_empty() {
            return F::zero();
        }
        let jump = |other: &AdaptiveCell<F>| -> F {
            match self.component {
                Some(comp) if comp < cell.solution.len() && comp < other.solution.len() => {
                    (cell.solution[comp] - other.solution[comp]).abs() / cell.size
                }
                Some(_) => F::zero(),
                None => {
                    let diff = &cell.solution - &other.solution;
                    diff.mapv(|x| x.powi(2)).sum().sqrt() / cell.size
                }
            }
        };
        let steepest = neighbors
            .iter()
            .fold(F::zero(), |best, &other| best.max(jump(other)));
        let magnitude = match self.component {
            Some(comp) => cell.solution.get(comp).map_or(F::zero(), |x| x.abs()),
            None => cell.solution.mapv(|x| x.abs()).sum(),
        };
        if magnitude > F::zero() {
            steepest / magnitude
        } else {
            steepest
        }
    }
    fn name(&self) -> &'static str {
        "Gradient"
    }
}
impl<F: IntegrateFloat + Send + Sync> RefinementCriterion<F> for FeatureDetectionCriterion<F> {
    /// Sums one score contribution per requested feature type.
    ///
    /// - `SharpGradient` (needs at least two neighbours): adds max/mean of the
    ///   L1 solution jumps to the neighbours, when the mean is positive.
    /// - `LocalExtremum` (needs at least one neighbour): adds 1 when the cell's
    ///   magnitude (sum of absolute values) is at least `threshold` above every
    ///   neighbour's, or at least `threshold` below every neighbour's.
    ///   Previously a cell counted as an extremum whenever it differed from
    ///   each neighbour by `threshold` in any direction, which also fired on
    ///   monotone slopes and on cells with no neighbours.
    /// - `Discontinuity`, `Oscillation`, `BoundaryLayer`: not implemented yet;
    ///   they contribute nothing.
    fn evaluate(&self, cell: &AdaptiveCell<F>, neighbors: &[&AdaptiveCell<F>]) -> F {
        let magnitude = |c: &AdaptiveCell<F>| c.solution.mapv(|x| x.abs()).sum();
        let mut feature_score = F::zero();
        for &feature_type in &self.feature_types {
            match feature_type {
                FeatureType::SharpGradient => {
                    if neighbors.len() < 2 {
                        continue;
                    }
                    let gradients: Vec<F> = neighbors
                        .iter()
                        .map(|n| (&cell.solution - &n.solution).mapv(|x| x.abs()).sum())
                        .collect();
                    let max_grad = gradients.iter().fold(F::zero(), |acc, &x| acc.max(x));
                    let avg_grad = gradients.iter().fold(F::zero(), |acc, &x| acc + x)
                        / F::from(gradients.len()).expect("Operation failed");
                    if avg_grad > F::zero() {
                        feature_score += max_grad / avg_grad;
                    }
                }
                FeatureType::LocalExtremum => {
                    if neighbors.is_empty() {
                        continue;
                    }
                    let cell_value = magnitude(cell);
                    let is_maximum = neighbors
                        .iter()
                        .all(|&n| cell_value - magnitude(n) >= self.threshold);
                    let is_minimum = neighbors
                        .iter()
                        .all(|&n| magnitude(n) - cell_value >= self.threshold);
                    if is_maximum || is_minimum {
                        feature_score += F::one();
                    }
                }
                FeatureType::Discontinuity
                | FeatureType::Oscillation
                | FeatureType::BoundaryLayer => {}
            }
        }
        feature_score
    }
    fn name(&self) -> &'static str {
        "FeatureDetection"
    }
}
impl<F: IntegrateFloat + Send + Sync> RefinementCriterion<F> for CurvatureRefinementCriterion<F> {
    /// Sums, over solution components, `|mean(neighbour values) - cell value| / size²`.
    ///
    /// Returns zero with fewer than two neighbours. A component is skipped when
    /// fewer than two neighbours provide a value for it.
    fn evaluate(&self, cell: &AdaptiveCell<F>, neighbors: &[&AdaptiveCell<F>]) -> F {
        if neighbors.len() < 2 {
            return F::zero();
        }
        (0..cell.solution.len()).fold(F::zero(), |curvature, component| {
            let samples: Vec<F> = neighbors
                .iter()
                .filter_map(|n| n.solution.get(component).copied())
                .collect();
            if samples.len() < 2 {
                return curvature;
            }
            let mean = samples.iter().fold(F::zero(), |acc, &v| acc + v)
                / F::from(samples.len()).expect("Operation failed");
            curvature + (mean - cell.solution[component]).abs() / (cell.size * cell.size)
        })
    }
    fn name(&self) -> &'static str {
        "Curvature"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an active, unrefined level-0 cell of size 0.1 holding a single value.
    fn leaf_cell(index: usize, center: Vec<f64>, value: f64) -> AdaptiveCell<f64> {
        AdaptiveCell {
            id: CellId { level: 0, index },
            center: Array1::from_vec(center),
            size: 0.1,
            solution: Array1::from_vec(vec![value]),
            error_estimate: 0.0,
            refinement_flag: RefinementFlag::None,
            is_active: true,
            neighbors: vec![],
            parent: None,
            children: vec![],
        }
    }

    #[test]
    fn test_amr_manager_creation() {
        let root = AdaptiveMeshLevel {
            level: 0,
            cells: HashMap::new(),
            grid_spacing: 1.0,
            boundary_cells: HashSet::new(),
        };
        let manager = AdvancedAMRManager::new(root, 5, 0.01);
        assert_eq!(manager.max_levels, 5);
        assert_eq!(manager.mesh_hierarchy.levels.len(), 1);
    }

    #[test]
    fn test_gradient_criterion() {
        let cell = leaf_cell(0, vec![0.5, 0.5], 1.0);
        let neighbor = leaf_cell(1, vec![0.6, 0.5], 2.0);
        let criterion = GradientRefinementCriterion {
            component: Some(0),
            threshold: 1.0,
            relative_tolerance: 0.1,
        };
        let score = criterion.evaluate(&cell, &[&neighbor]);
        assert!(score > 0.0);
    }
}