Skip to main content

oxirs_physics/
mesh_refinement.rs

1//! Adaptive mesh refinement for FEM (Finite Element Method)
2//!
3//! This module implements 2D triangular mesh refinement using the
4//! 4-triangle subdivision (midpoint refinement) approach. Elements that
5//! exceed quality thresholds or have large error estimates are subdivided
6//! into 4 congruent sub-triangles by inserting midpoints on each edge.
7//!
8//! # Example
9//!
10//! ```rust
11//! use oxirs_physics::mesh_refinement::{Mesh2D, RefinementCriteria};
12//! use std::collections::HashMap;
13//!
14//! let mut mesh = Mesh2D::new();
15//! let n0 = mesh.add_node(0.0, 0.0);
16//! let n1 = mesh.add_node(1.0, 0.0);
17//! let n2 = mesh.add_node(0.5, 1.0);
18//! let _e0 = mesh.add_element(n0, n1, n2);
19//!
20//! let criteria = RefinementCriteria {
21//!     max_edge_length: 0.6,
22//!     max_aspect_ratio: 10.0,
23//!     min_angle_deg: 5.0,
24//!     error_threshold: 0.1,
25//! };
26//! let error_map: HashMap<u32, f64> = HashMap::new();
27//! let stats = mesh.refine_mesh(&criteria, &error_map);
28//! assert!(stats.elements_refined > 0);
29//! ```
30
31use std::collections::HashMap;
32
/// Maximum refinement level an element may reach.
///
/// Elements already at this level are never subdivided by `Mesh2D::refine_element` and are
/// never selected by `Mesh2D::should_refine`, which bounds the number of useful refinement
/// passes and prevents infinite subdivision.
pub const MAX_REFINEMENT_LEVEL: u8 = 8;
35
/// A vertex of a 2D mesh: an identifier plus Cartesian coordinates.
#[derive(Debug, Clone, PartialEq)]
pub struct MeshNode2D {
    /// Node identifier
    pub id: u32,
    /// X coordinate
    pub x: f64,
    /// Y coordinate
    pub y: f64,
}

impl MeshNode2D {
    /// Build a node from its identifier and its `(x, y)` position.
    pub fn new(id: u32, x: f64, y: f64) -> Self {
        MeshNode2D { id, x, y }
    }
}
50
/// A triangular mesh element defined by the IDs of its three vertex nodes.
#[derive(Debug, Clone, PartialEq)]
pub struct MeshElement {
    pub id: u32,
    /// Node IDs forming the triangle (counter-clockwise ordering preferred)
    pub nodes: [u32; 3],
    /// How many subdivisions separate this element from an original one (0 = original)
    pub refinement_level: u8,
}

impl MeshElement {
    /// Build an original (level 0) triangle over nodes `n0`, `n1`, `n2`.
    pub fn new(id: u32, n0: u32, n1: u32, n2: u32) -> Self {
        MeshElement {
            id,
            nodes: [n0, n1, n2],
            refinement_level: 0,
        }
    }

    /// Build a child triangle one level deeper than `parent_level`.
    ///
    /// The level saturates at `u8::MAX` instead of wrapping.
    pub fn child(id: u32, n0: u32, n1: u32, n2: u32, parent_level: u8) -> Self {
        let mut elem = MeshElement::new(id, n0, n1, n2);
        elem.refinement_level = parent_level.saturating_add(1);
        elem
    }
}
80
/// Criteria controlling when elements should be refined
#[derive(Debug, Clone)]
pub struct RefinementCriteria {
    /// Maximum allowed edge length before refinement is triggered
    pub max_edge_length: f64,
    /// Maximum allowed aspect ratio (longest/shortest edge)
    pub max_aspect_ratio: f64,
    /// Minimum allowed interior angle in degrees
    pub min_angle_deg: f64,
    /// Error threshold — elements with error > this value are refined
    pub error_threshold: f64,
}

impl RefinementCriteria {
    /// Default criteria suitable for general FEM meshes.
    ///
    /// These are the same values returned by [`RefinementCriteria::default`]:
    /// max edge length 1.0, max aspect ratio 5.0, min angle 10°, error threshold 0.05.
    pub fn default_criteria() -> Self {
        Self {
            max_edge_length: 1.0,
            max_aspect_ratio: 5.0,
            min_angle_deg: 10.0,
            error_threshold: 0.05,
        }
    }
}

impl Default for RefinementCriteria {
    /// Delegates to [`RefinementCriteria::default_criteria`], so generic code (e.g.
    /// `..Default::default()` struct-update syntax) gets the same thresholds.
    fn default() -> Self {
        Self::default_criteria()
    }
}
105
/// Statistics from a single refinement pass (one call to `Mesh2D::refine_mesh`)
#[derive(Debug, Clone, Default)]
pub struct RefinementStats {
    /// Number of elements that were subdivided during the pass
    pub elements_refined: u32,
    /// Number of new nodes inserted during the pass (growth of the node map)
    pub nodes_added: u32,
    /// Highest `refinement_level` among all elements in the mesh after this pass
    /// (0 when the mesh has no elements)
    pub max_level: u8,
}
116
/// 2D triangular mesh supporting adaptive refinement
///
/// Nodes and elements live in hash maps keyed by their IDs. New IDs are drawn from the
/// `next_node_id` / `next_elem_id` counters; the mesh's own methods only ever increment
/// them, so the ID of an element removed by refinement is not handed out again.
#[derive(Debug, Clone)]
pub struct Mesh2D {
    /// All nodes indexed by their ID
    pub nodes: HashMap<u32, MeshNode2D>,
    /// All elements indexed by their ID
    pub elements: HashMap<u32, MeshElement>,
    /// ID that the next added node will receive
    pub next_node_id: u32,
    /// ID that the next added element will receive
    pub next_elem_id: u32,
}
129
130impl Mesh2D {
131    /// Create an empty mesh
132    pub fn new() -> Self {
133        Self {
134            nodes: HashMap::new(),
135            elements: HashMap::new(),
136            next_node_id: 0,
137            next_elem_id: 0,
138        }
139    }
140
141    /// Add a node at (x, y) and return its assigned ID
142    pub fn add_node(&mut self, x: f64, y: f64) -> u32 {
143        let id = self.next_node_id;
144        self.nodes.insert(id, MeshNode2D::new(id, x, y));
145        self.next_node_id += 1;
146        id
147    }
148
149    /// Add a triangular element referencing three existing node IDs
150    /// Returns the new element ID
151    pub fn add_element(&mut self, n0: u32, n1: u32, n2: u32) -> u32 {
152        let id = self.next_elem_id;
153        self.elements.insert(id, MeshElement::new(id, n0, n1, n2));
154        self.next_elem_id += 1;
155        id
156    }
157
158    /// Compute the signed area of a triangle element
159    /// Returns 0.0 if any referenced node is missing
160    pub fn element_area(&self, elem: &MeshElement) -> f64 {
161        let (n0, n1, n2) = match (
162            self.nodes.get(&elem.nodes[0]),
163            self.nodes.get(&elem.nodes[1]),
164            self.nodes.get(&elem.nodes[2]),
165        ) {
166            (Some(a), Some(b), Some(c)) => (a, b, c),
167            _ => return 0.0,
168        };
169        let area = 0.5 * ((n1.x - n0.x) * (n2.y - n0.y) - (n2.x - n0.x) * (n1.y - n0.y));
170        area.abs()
171    }
172
173    /// Compute the three edge lengths [|e01|, |e12|, |e20|]
174    /// Returns `[0,0,0]` if any node is missing
175    pub fn element_edge_lengths(&self, elem: &MeshElement) -> [f64; 3] {
176        let (n0, n1, n2) = match (
177            self.nodes.get(&elem.nodes[0]),
178            self.nodes.get(&elem.nodes[1]),
179            self.nodes.get(&elem.nodes[2]),
180        ) {
181            (Some(a), Some(b), Some(c)) => (a, b, c),
182            _ => return [0.0, 0.0, 0.0],
183        };
184        let e01 = ((n1.x - n0.x).powi(2) + (n1.y - n0.y).powi(2)).sqrt();
185        let e12 = ((n2.x - n1.x).powi(2) + (n2.y - n1.y).powi(2)).sqrt();
186        let e20 = ((n0.x - n2.x).powi(2) + (n0.y - n2.y).powi(2)).sqrt();
187        [e01, e12, e20]
188    }
189
190    /// Compute the aspect ratio as longest_edge / shortest_edge
191    /// Returns f64::INFINITY for degenerate elements
192    pub fn element_aspect_ratio(&self, elem: &MeshElement) -> f64 {
193        let edges = self.element_edge_lengths(elem);
194        let min_e = edges.iter().cloned().fold(f64::INFINITY, f64::min);
195        let max_e = edges.iter().cloned().fold(0.0_f64, f64::max);
196        if min_e <= 0.0 {
197            f64::INFINITY
198        } else {
199            max_e / min_e
200        }
201    }
202
203    /// Compute the minimum interior angle in degrees
204    /// Uses the law of cosines; returns 0.0 for degenerate elements
205    pub fn element_min_angle_deg(&self, elem: &MeshElement) -> f64 {
206        let (n0, n1, n2) = match (
207            self.nodes.get(&elem.nodes[0]),
208            self.nodes.get(&elem.nodes[1]),
209            self.nodes.get(&elem.nodes[2]),
210        ) {
211            (Some(a), Some(b), Some(c)) => (a, b, c),
212            _ => return 0.0,
213        };
214
215        let edges = self.element_edge_lengths(elem);
216        let a = edges[0]; // opposite n2
217        let b = edges[1]; // opposite n0
218        let c = edges[2]; // opposite n1
219
220        // Avoid division by zero for degenerate triangles
221        if a <= 0.0 || b <= 0.0 || c <= 0.0 {
222            return 0.0;
223        }
224
225        // Angle at each vertex using law of cosines
226        let cos_n0 = ((b * b + c * c - a * a) / (2.0 * b * c)).clamp(-1.0, 1.0);
227        let cos_n1 = ((a * a + c * c - b * b) / (2.0 * a * c)).clamp(-1.0, 1.0);
228        let cos_n2 = ((a * a + b * b - c * c) / (2.0 * a * b)).clamp(-1.0, 1.0);
229
230        // Suppress unused variable warning — n0/n1/n2 only used for existence check above
231        let _ = (n0, n1, n2);
232
233        let angle_n0 = cos_n0.acos().to_degrees();
234        let angle_n1 = cos_n1.acos().to_degrees();
235        let angle_n2 = cos_n2.acos().to_degrees();
236
237        angle_n0.min(angle_n1).min(angle_n2)
238    }
239
240    /// Compute the midpoint coordinates between two nodes
241    /// Returns (0,0) if either node is missing
242    pub fn midpoint(&self, n1: u32, n2: u32) -> (f64, f64) {
243        match (self.nodes.get(&n1), self.nodes.get(&n2)) {
244            (Some(a), Some(b)) => ((a.x + b.x) / 2.0, (a.y + b.y) / 2.0),
245            _ => (0.0, 0.0),
246        }
247    }
248
249    /// Refine a single element by bisecting all three edges.
250    ///
251    /// Inserts midpoint nodes on each edge and creates 4 congruent sub-triangles:
252    /// ```text
253    ///        n2
254    ///       /  \
255    ///      m20--m12
256    ///     / \  / \
257    ///   n0--m01--n1
258    /// ```
259    /// Returns the IDs of the 4 new child elements.
260    /// Returns an empty array if the element is not found or already at MAX_REFINEMENT_LEVEL.
261    pub fn refine_element(&mut self, elem_id: u32) -> [u32; 4] {
262        // Clone what we need to avoid simultaneous borrow
263        let elem = match self.elements.get(&elem_id).cloned() {
264            Some(e) => e,
265            None => return [0; 4],
266        };
267
268        if elem.refinement_level >= MAX_REFINEMENT_LEVEL {
269            return [0; 4];
270        }
271
272        let [n0, n1, n2] = elem.nodes;
273
274        // Insert midpoint nodes
275        let (mx01, my01) = self.midpoint(n0, n1);
276        let (mx12, my12) = self.midpoint(n1, n2);
277        let (mx20, my20) = self.midpoint(n2, n0);
278
279        let m01 = self.add_node(mx01, my01);
280        let m12 = self.add_node(mx12, my12);
281        let m20 = self.add_node(mx20, my20);
282
283        let lvl = elem.refinement_level;
284
285        // Create 4 child elements
286        let c0_id = self.next_elem_id;
287        self.elements
288            .insert(c0_id, MeshElement::child(c0_id, n0, m01, m20, lvl));
289        self.next_elem_id += 1;
290
291        let c1_id = self.next_elem_id;
292        self.elements
293            .insert(c1_id, MeshElement::child(c1_id, m01, n1, m12, lvl));
294        self.next_elem_id += 1;
295
296        let c2_id = self.next_elem_id;
297        self.elements
298            .insert(c2_id, MeshElement::child(c2_id, m20, m12, n2, lvl));
299        self.next_elem_id += 1;
300
301        let c3_id = self.next_elem_id;
302        self.elements
303            .insert(c3_id, MeshElement::child(c3_id, m01, m12, m20, lvl));
304        self.next_elem_id += 1;
305
306        // Remove the parent element (replaced by 4 children)
307        self.elements.remove(&elem_id);
308
309        [c0_id, c1_id, c2_id, c3_id]
310    }
311
312    /// Determine whether an element should be refined given criteria and an error map
313    pub fn should_refine(
314        &self,
315        elem: &MeshElement,
316        criteria: &RefinementCriteria,
317        error_map: &HashMap<u32, f64>,
318    ) -> bool {
319        if elem.refinement_level >= MAX_REFINEMENT_LEVEL {
320            return false;
321        }
322
323        // Check user-supplied error estimate
324        if let Some(&err) = error_map.get(&elem.id) {
325            if err > criteria.error_threshold {
326                return true;
327            }
328        }
329
330        // Check geometric quality metrics
331        let edges = self.element_edge_lengths(elem);
332        let max_edge = edges.iter().cloned().fold(0.0_f64, f64::max);
333        if max_edge > criteria.max_edge_length {
334            return true;
335        }
336
337        let aspect = self.element_aspect_ratio(elem);
338        if aspect > criteria.max_aspect_ratio {
339            return true;
340        }
341
342        let min_angle = self.element_min_angle_deg(elem);
343        if min_angle < criteria.min_angle_deg && min_angle > 0.0 {
344            return true;
345        }
346
347        false
348    }
349
350    /// Perform one global refinement pass: refine all elements meeting the criteria.
351    ///
352    /// Returns statistics about this refinement pass.
353    pub fn refine_mesh(
354        &mut self,
355        criteria: &RefinementCriteria,
356        error_map: &HashMap<u32, f64>,
357    ) -> RefinementStats {
358        // Collect IDs of elements to refine (snapshot to avoid mutation during iteration)
359        let to_refine: Vec<u32> = self
360            .elements
361            .values()
362            .filter(|e| self.should_refine(e, criteria, error_map))
363            .map(|e| e.id)
364            .collect();
365
366        let elements_refined = to_refine.len() as u32;
367        let nodes_before = self.nodes.len() as u32;
368
369        for id in to_refine {
370            self.refine_element(id);
371        }
372
373        let nodes_added = self.nodes.len() as u32 - nodes_before;
374        let max_level = self
375            .elements
376            .values()
377            .map(|e| e.refinement_level)
378            .max()
379            .unwrap_or(0);
380
381        RefinementStats {
382            elements_refined,
383            nodes_added,
384            max_level,
385        }
386    }
387
388    /// Return the current number of nodes
389    pub fn node_count(&self) -> usize {
390        self.nodes.len()
391    }
392
393    /// Return the current number of elements
394    pub fn element_count(&self) -> usize {
395        self.elements.len()
396    }
397
398    /// Compute the total mesh area (sum of all element areas)
399    pub fn total_area(&self) -> f64 {
400        self.elements.values().map(|e| self.element_area(e)).sum()
401    }
402}
403
404impl Default for Mesh2D {
405    fn default() -> Self {
406        Self::new()
407    }
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a simple right triangle with vertices at (0,0), (1,0), (0,1)
    fn unit_triangle() -> (Mesh2D, u32, u32, u32, u32) {
        let mut mesh = Mesh2D::new();
        let n0 = mesh.add_node(0.0, 0.0);
        let n1 = mesh.add_node(1.0, 0.0);
        let n2 = mesh.add_node(0.0, 1.0);
        let e0 = mesh.add_element(n0, n1, n2);
        (mesh, n0, n1, n2, e0)
    }

    /// Build an equilateral triangle with side length 1
    fn equilateral() -> (Mesh2D, u32) {
        let mut mesh = Mesh2D::new();
        let n0 = mesh.add_node(0.0, 0.0);
        let n1 = mesh.add_node(1.0, 0.0);
        let n2 = mesh.add_node(0.5, 3_f64.sqrt() / 2.0);
        let e0 = mesh.add_element(n0, n1, n2);
        (mesh, e0)
    }

    // --- Node management ---

    #[test]
    fn test_add_node_returns_sequential_ids() {
        let mut mesh = Mesh2D::new();
        assert_eq!(mesh.add_node(0.0, 0.0), 0);
        assert_eq!(mesh.add_node(1.0, 0.0), 1);
        assert_eq!(mesh.add_node(0.5, 1.0), 2);
    }

    #[test]
    fn test_node_count_matches_insertions() {
        let mut mesh = Mesh2D::new();
        for i in 0..10 {
            mesh.add_node(i as f64, 0.0);
        }
        assert_eq!(mesh.node_count(), 10);
    }

    #[test]
    fn test_node_coordinates_stored_correctly() {
        let mut mesh = Mesh2D::new();
        let id = mesh.add_node(1.23, 2.72);
        let node = mesh.nodes.get(&id).expect("node should exist");
        assert!((node.x - 1.23).abs() < 1e-10);
        assert!((node.y - 2.72).abs() < 1e-10);
    }

    // --- Element management ---

    #[test]
    fn test_add_element_returns_sequential_ids() {
        let (mut mesh, _n0, n1, n2, _e0) = unit_triangle();
        let n3 = mesh.add_node(1.0, 1.0);
        let e1 = mesh.add_element(n1, n3, n2);
        assert_eq!(e1, 1);
    }

    #[test]
    fn test_element_count_matches_insertions() {
        let mut mesh = Mesh2D::new();
        let n0 = mesh.add_node(0.0, 0.0);
        let n1 = mesh.add_node(1.0, 0.0);
        let n2 = mesh.add_node(0.5, 1.0);
        let n3 = mesh.add_node(1.5, 1.0);
        mesh.add_element(n0, n1, n2);
        mesh.add_element(n1, n3, n2);
        assert_eq!(mesh.element_count(), 2);
    }

    #[test]
    fn test_element_nodes_stored_correctly() {
        let (mesh, n0, n1, n2, e0) = unit_triangle();
        let elem = mesh.elements.get(&e0).expect("element should exist");
        assert_eq!(elem.nodes, [n0, n1, n2]);
    }

    #[test]
    fn test_element_initial_level_is_zero() {
        let (mesh, _, _, _, e0) = unit_triangle();
        let elem = mesh.elements.get(&e0).expect("element should exist");
        assert_eq!(elem.refinement_level, 0);
    }

    // --- Area calculations ---

    #[test]
    fn test_unit_triangle_area() {
        let (mesh, _, _, _, e0) = unit_triangle();
        let elem = mesh.elements.get(&e0).expect("element should exist");
        let area = mesh.element_area(elem);
        assert!((area - 0.5).abs() < 1e-10, "Expected 0.5, got {}", area);
    }

    #[test]
    fn test_equilateral_area() {
        let (mesh, e0) = equilateral();
        let elem = mesh.elements.get(&e0).expect("element should exist");
        let expected = 3_f64.sqrt() / 4.0;
        let area = mesh.element_area(elem);
        assert!(
            (area - expected).abs() < 1e-10,
            "Expected {}, got {}",
            expected,
            area
        );
    }

    #[test]
    fn test_total_area_single_element() {
        let (mesh, _, _, _, _) = unit_triangle();
        assert!((mesh.total_area() - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_total_area_two_elements() {
        let mut mesh = Mesh2D::new();
        let n0 = mesh.add_node(0.0, 0.0);
        let n1 = mesh.add_node(1.0, 0.0);
        let n2 = mesh.add_node(1.0, 1.0);
        let n3 = mesh.add_node(0.0, 1.0);
        mesh.add_element(n0, n1, n2);
        mesh.add_element(n0, n2, n3);
        assert!((mesh.total_area() - 1.0).abs() < 1e-10);
    }

    // --- Edge lengths ---

    #[test]
    fn test_unit_triangle_edge_lengths() {
        let (mesh, _, _, _, e0) = unit_triangle();
        let elem = mesh.elements.get(&e0).expect("element should exist");
        let edges = mesh.element_edge_lengths(elem);
        // Edges: (0,0)-(1,0)=1, (1,0)-(0,1)=sqrt(2), (0,1)-(0,0)=1
        assert!((edges[0] - 1.0).abs() < 1e-10);
        assert!((edges[1] - 2_f64.sqrt()).abs() < 1e-10);
        assert!((edges[2] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_equilateral_edge_lengths_equal() {
        let (mesh, e0) = equilateral();
        let elem = mesh.elements.get(&e0).expect("element should exist");
        let edges = mesh.element_edge_lengths(elem);
        assert!((edges[0] - edges[1]).abs() < 1e-10);
        assert!((edges[1] - edges[2]).abs() < 1e-10);
    }

    // --- Aspect ratio ---

    #[test]
    fn test_equilateral_aspect_ratio_is_one() {
        let (mesh, e0) = equilateral();
        let elem = mesh.elements.get(&e0).expect("element should exist");
        let ar = mesh.element_aspect_ratio(elem);
        assert!((ar - 1.0).abs() < 1e-10, "Expected 1.0, got {}", ar);
    }

    #[test]
    fn test_right_triangle_aspect_ratio_gt_one() {
        let (mesh, _, _, _, e0) = unit_triangle();
        let elem = mesh.elements.get(&e0).expect("element should exist");
        let ar = mesh.element_aspect_ratio(elem);
        assert!(ar > 1.0, "Expected AR > 1, got {}", ar);
    }

    // --- Angle calculations ---

    #[test]
    fn test_equilateral_min_angle_60_degrees() {
        let (mesh, e0) = equilateral();
        let elem = mesh.elements.get(&e0).expect("element should exist");
        let min_angle = mesh.element_min_angle_deg(elem);
        assert!(
            (min_angle - 60.0).abs() < 1e-6,
            "Expected 60°, got {}",
            min_angle
        );
    }

    #[test]
    fn test_right_triangle_min_angle_45_degrees() {
        let (mesh, _, _, _, e0) = unit_triangle();
        let elem = mesh.elements.get(&e0).expect("element should exist");
        let min_angle = mesh.element_min_angle_deg(elem);
        assert!(
            (min_angle - 45.0).abs() < 1e-6,
            "Expected 45°, got {}",
            min_angle
        );
    }

    // --- Midpoint ---

    #[test]
    fn test_midpoint_basic() {
        let mut mesh = Mesh2D::new();
        let n0 = mesh.add_node(0.0, 0.0);
        let n1 = mesh.add_node(2.0, 4.0);
        let (mx, my) = mesh.midpoint(n0, n1);
        assert!((mx - 1.0).abs() < 1e-10);
        assert!((my - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_midpoint_missing_node_returns_origin() {
        let mesh = Mesh2D::new();
        let (mx, my) = mesh.midpoint(99, 100);
        assert_eq!(mx, 0.0);
        assert_eq!(my, 0.0);
    }

    // --- Single element refinement ---

    #[test]
    fn test_refine_element_produces_four_children() {
        let (mut mesh, _, _, _, e0) = unit_triangle();
        let children = mesh.refine_element(e0);
        // [0; 4] is the "not refined" sentinel, so a successful refinement differs from it
        assert_ne!(children, [0; 4]);
        // Every returned ID must name an element that now exists in the mesh
        for child_id in children.iter() {
            assert!(mesh.elements.contains_key(child_id));
        }
        // Parent element should be removed
        assert!(!mesh.elements.contains_key(&e0));
        // Mesh should now have 4 elements
        assert_eq!(mesh.element_count(), 4);
    }

    #[test]
    fn test_refine_element_adds_three_midpoint_nodes() {
        let (mut mesh, _, _, _, e0) = unit_triangle();
        let initial_nodes = mesh.node_count();
        mesh.refine_element(e0);
        assert_eq!(mesh.node_count(), initial_nodes + 3);
    }

    #[test]
    fn test_refine_element_child_levels_incremented() {
        let (mut mesh, _, _, _, e0) = unit_triangle();
        let children = mesh.refine_element(e0);
        for child_id in children.iter() {
            let child = mesh.elements.get(child_id).expect("child should exist");
            assert_eq!(child.refinement_level, 1);
        }
    }

    #[test]
    fn test_refine_element_preserves_total_area() {
        let (mut mesh, _, _, _, e0) = unit_triangle();
        let area_before = mesh.total_area();
        mesh.refine_element(e0);
        let area_after = mesh.total_area();
        assert!(
            (area_before - area_after).abs() < 1e-10,
            "Area should be preserved: {} vs {}",
            area_before,
            area_after
        );
    }

    #[test]
    fn test_refine_missing_element_returns_empty() {
        let mut mesh = Mesh2D::new();
        let result = mesh.refine_element(999);
        assert_eq!(result, [0; 4]);
    }

    #[test]
    fn test_refine_at_max_level_returns_empty() {
        let mut mesh = Mesh2D::new();
        let n0 = mesh.add_node(0.0, 0.0);
        let n1 = mesh.add_node(1.0, 0.0);
        let n2 = mesh.add_node(0.5, 1.0);
        let e_id = mesh.next_elem_id;
        mesh.elements.insert(
            e_id,
            MeshElement {
                id: e_id,
                nodes: [n0, n1, n2],
                refinement_level: MAX_REFINEMENT_LEVEL,
            },
        );
        mesh.next_elem_id += 1;
        let result = mesh.refine_element(e_id);
        assert_eq!(result, [0; 4]);
    }

    // --- should_refine ---

    #[test]
    fn test_should_refine_by_edge_length() {
        let (mesh, _, _, _, e0) = unit_triangle();
        let elem = mesh.elements.get(&e0).expect("element should exist");
        let criteria = RefinementCriteria {
            max_edge_length: 0.5, // all edges > 0.5
            max_aspect_ratio: 100.0,
            min_angle_deg: 0.0,
            error_threshold: f64::INFINITY,
        };
        let empty_map = HashMap::new();
        assert!(mesh.should_refine(elem, &criteria, &empty_map));
    }

    #[test]
    fn test_should_not_refine_below_thresholds() {
        let (mesh, _, _, _, e0) = unit_triangle();
        let elem = mesh.elements.get(&e0).expect("element should exist");
        let criteria = RefinementCriteria {
            max_edge_length: 10.0,
            max_aspect_ratio: 100.0,
            min_angle_deg: 0.0,
            error_threshold: f64::INFINITY,
        };
        let empty_map = HashMap::new();
        assert!(!mesh.should_refine(elem, &criteria, &empty_map));
    }

    #[test]
    fn test_should_refine_by_error_map() {
        let (mesh, _, _, _, e0) = unit_triangle();
        let elem = mesh.elements.get(&e0).expect("element should exist");
        let criteria = RefinementCriteria {
            max_edge_length: 100.0,
            max_aspect_ratio: 100.0,
            min_angle_deg: 0.0,
            error_threshold: 0.01,
        };
        let mut error_map = HashMap::new();
        error_map.insert(e0, 0.5); // exceeds threshold
        assert!(mesh.should_refine(elem, &criteria, &error_map));
    }

    #[test]
    fn test_should_refine_by_aspect_ratio() {
        // n0=(0,0), n1=(100,0), n2=(0.1, 0.0001)
        // e01=100, e12≈99.9, e20≈0.1  → AR≈100/0.1=1000 >> 2.0
        let mut mesh = Mesh2D::new();
        let n0 = mesh.add_node(0.0, 0.0);
        let n1 = mesh.add_node(100.0, 0.0);
        let n2 = mesh.add_node(0.1, 0.0001);
        let e0 = mesh.add_element(n0, n1, n2);
        let elem = mesh.elements.get(&e0).expect("element should exist");
        let criteria = RefinementCriteria {
            max_edge_length: 200.0,
            max_aspect_ratio: 2.0,
            min_angle_deg: 0.0,
            error_threshold: f64::INFINITY,
        };
        let empty_map = HashMap::new();
        assert!(
            mesh.should_refine(elem, &criteria, &empty_map),
            "AR of very thin triangle should trigger refinement"
        );
    }

    #[test]
    fn test_should_not_refine_at_max_level() {
        let mut mesh = Mesh2D::new();
        let n0 = mesh.add_node(0.0, 0.0);
        let n1 = mesh.add_node(0.001, 0.0);
        let n2 = mesh.add_node(0.0005, 0.001);
        let e_id = mesh.next_elem_id;
        mesh.elements.insert(
            e_id,
            MeshElement {
                id: e_id,
                nodes: [n0, n1, n2],
                refinement_level: MAX_REFINEMENT_LEVEL,
            },
        );
        mesh.next_elem_id += 1;
        let elem = mesh.elements.get(&e_id).expect("element should exist");
        // Every criterion would fire, including a huge error estimate, yet the level cap wins.
        let criteria = RefinementCriteria {
            max_edge_length: 0.00001,
            max_aspect_ratio: 1.0,
            min_angle_deg: 89.0,
            error_threshold: 0.0,
        };
        let mut error_map = HashMap::new();
        error_map.insert(e_id, 999.0);
        assert!(!mesh.should_refine(elem, &criteria, &error_map));
    }

    // --- Adaptive global refinement ---

    #[test]
    fn test_refine_mesh_refines_eligible_elements() {
        let (mut mesh, _, _, _, _) = unit_triangle();
        let criteria = RefinementCriteria {
            max_edge_length: 0.5,
            max_aspect_ratio: 100.0,
            min_angle_deg: 0.0,
            error_threshold: f64::INFINITY,
        };
        let empty_map = HashMap::new();
        let stats = mesh.refine_mesh(&criteria, &empty_map);
        assert_eq!(stats.elements_refined, 1);
        assert_eq!(stats.nodes_added, 3);
        assert_eq!(stats.max_level, 1);
        assert_eq!(mesh.element_count(), 4);
    }

    #[test]
    fn test_refine_mesh_stats_max_level() {
        let (mut mesh, _, _, _, _) = unit_triangle();
        let criteria = RefinementCriteria {
            max_edge_length: 0.01,
            max_aspect_ratio: 100.0,
            min_angle_deg: 0.0,
            error_threshold: f64::INFINITY,
        };
        let empty_map = HashMap::new();
        // Two passes: every element stays above the edge limit, so each pass adds a level
        let s1 = mesh.refine_mesh(&criteria, &empty_map);
        let s2 = mesh.refine_mesh(&criteria, &empty_map);
        assert_eq!(s1.max_level, 1);
        assert_eq!(s2.max_level, 2);
        assert_eq!(s2.elements_refined, 4);
    }

    #[test]
    fn test_refine_mesh_preserves_total_area() {
        let (mut mesh, _, _, _, _) = unit_triangle();
        let area_before = mesh.total_area();
        let criteria = RefinementCriteria {
            max_edge_length: 0.5,
            max_aspect_ratio: 100.0,
            min_angle_deg: 0.0,
            error_threshold: f64::INFINITY,
        };
        let empty_map = HashMap::new();
        mesh.refine_mesh(&criteria, &empty_map);
        let area_after = mesh.total_area();
        assert!(
            (area_before - area_after).abs() < 1e-9,
            "Area changed: {} -> {}",
            area_before,
            area_after
        );
    }

    #[test]
    fn test_refine_mesh_error_driven() {
        let mut mesh = Mesh2D::new();
        // 2-element unit square mesh (split along the n0-n2 diagonal)
        let n0 = mesh.add_node(0.0, 0.0);
        let n1 = mesh.add_node(1.0, 0.0);
        let n2 = mesh.add_node(1.0, 1.0);
        let n3 = mesh.add_node(0.0, 1.0);
        let e0 = mesh.add_element(n0, n1, n2);
        let e1 = mesh.add_element(n0, n2, n3);

        let criteria = RefinementCriteria {
            max_edge_length: 100.0,
            max_aspect_ratio: 100.0,
            min_angle_deg: 0.0,
            error_threshold: 0.1,
        };
        let mut error_map = HashMap::new();
        error_map.insert(e0, 0.5); // only e0 has high error
        error_map.insert(e1, 0.01);

        let stats = mesh.refine_mesh(&criteria, &error_map);
        assert_eq!(stats.elements_refined, 1);
        assert_eq!(mesh.element_count(), 5); // 1 original + 4 children
        assert!(mesh.elements.contains_key(&e1));
        assert!(!mesh.elements.contains_key(&e0));
    }

    #[test]
    fn test_refine_mesh_no_eligible_elements() {
        let (mut mesh, _, _, _, _) = unit_triangle();
        let criteria = RefinementCriteria {
            max_edge_length: 100.0,
            max_aspect_ratio: 100.0,
            min_angle_deg: 0.0,
            error_threshold: f64::INFINITY,
        };
        let empty_map = HashMap::new();
        let stats = mesh.refine_mesh(&criteria, &empty_map);
        assert_eq!(stats.elements_refined, 0);
        assert_eq!(stats.nodes_added, 0);
        assert_eq!(mesh.element_count(), 1);
    }

    #[test]
    fn test_double_refinement_grows_correctly() {
        let (mut mesh, _, _, _, _) = unit_triangle();
        let criteria = RefinementCriteria {
            max_edge_length: 0.5,
            max_aspect_ratio: 100.0,
            min_angle_deg: 0.0,
            error_threshold: f64::INFINITY,
        };
        let empty_map = HashMap::new();
        mesh.refine_mesh(&criteria, &empty_map);
        let area_after_1 = mesh.total_area();
        mesh.refine_mesh(&criteria, &empty_map);
        let area_after_2 = mesh.total_area();
        assert!((area_after_1 - area_after_2).abs() < 1e-9);
        assert!(mesh.element_count() > 4);
    }

    #[test]
    fn test_refinement_criteria_default() {
        let c = RefinementCriteria::default_criteria();
        assert!(c.max_edge_length > 0.0);
        assert!(c.max_aspect_ratio > 1.0);
        assert!(c.min_angle_deg > 0.0);
        assert!(c.error_threshold > 0.0);
    }

    #[test]
    fn test_mesh_default_is_empty() {
        let mesh = Mesh2D::default();
        assert_eq!(mesh.node_count(), 0);
        assert_eq!(mesh.element_count(), 0);
        assert_eq!(mesh.total_area(), 0.0);
    }

    #[test]
    fn test_multiple_sequential_refinements_bounded_by_max_level() {
        let (mut mesh, _, _, _, _) = unit_triangle();
        let criteria = RefinementCriteria {
            max_edge_length: 0.0001, // extremely small — will always trigger until max level
            max_aspect_ratio: 1.0001,
            min_angle_deg: 89.0,
            error_threshold: 0.0,
        };
        let empty_map = HashMap::new();
        for _ in 0..20 {
            mesh.refine_mesh(&criteria, &empty_map);
        }
        let max_level = mesh
            .elements
            .values()
            .map(|e| e.refinement_level)
            .max()
            .unwrap_or(0);
        assert!(max_level <= MAX_REFINEMENT_LEVEL);
    }
}