Skip to main content

aeon_tk/geometry/tree/
mod.rs

1#![allow(clippy::needless_range_loop)]
2
3use crate::{
4    geometry::{Region, Side, Split, regions},
5    prelude::{Face, HyperBox, IndexSpace},
6};
7use bitvec::{order::Lsb0, slice::BitSlice, vec::BitVec};
8use datasize::DataSize;
9use std::{array, ops::Range, slice};
10
11mod blocks;
12mod interfaces;
13mod neighbors;
14
15pub use blocks::{BlockId, TreeBlocks};
16pub use interfaces::{TreeInterface, TreeInterfaces};
17pub use neighbors::{NeighborId, TreeBlockNeighbor, TreeCellNeighbor, TreeNeighbors};
18
/// Null index, used internally to make storage of `Option<usize>` more efficient.
///
/// Link fields such as `Cell::parent` and `Cell::children` hold this value when
/// the link is absent.
const NULL: usize = usize::MAX;
21
/// Index into active cells in tree.
///
/// This is the primary representation of cells in a `Tree`, as degrees
/// of freedom are only assigned to active cells. Can be converted to a generic
/// `CellId` via `tree.cell_from_active_index(active)`.
#[derive(
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Debug,
    serde::Serialize,
    serde::Deserialize,
    DataSize,
)]
pub struct ActiveCellId(pub usize);
41
/// Index into cells in a tree (both active and non-active ones).
///
/// A tree stores non-active cells to facilitate O(log n) point -> cell and cell -> neighbor
/// searches. These cells are regenerated after refinement/coarsening and are therefore not
/// the "source of truth" for the dataset.
#[derive(
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Debug,
    serde::Serialize,
    serde::Deserialize,
    DataSize,
)]
pub struct CellId(pub usize);
61
62impl CellId {
63    /// The root cell in a tree is also stored at index 0.
64    pub const ROOT: CellId = CellId(0);
65
66    pub fn child<const N: usize>(offset: Self, split: Split<N>) -> Self {
67        Self(offset.0 + split.to_linear())
68    }
69}
70
/// A single node of the cell hierarchy (active or not), as rebuilt by `Tree::build`.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
struct Cell<const N: usize> {
    /// Physical bounds of this node.
    bounds: HyperBox<N>,
    /// Index of the parent node, or `NULL` for the root.
    parent: usize,
    /// Index of the first child node, or `NULL` if this node has no children.
    children: usize,
    /// Index of the first active cell covered by this cell.
    active_offset: usize,
    /// Number of active cells covered by this cell (1 when the cell itself is active).
    active_count: usize,
    /// Level of this cell; the root is at level 0.
    level: usize,
}
86
/// Reports a `Cell` as owning no heap memory.
///
/// NOTE(review): this assumes `HyperBox<N>` stores its data inline — confirm
/// against its definition.
impl<const N: usize> DataSize for Cell<N> {
    const IS_DYNAMIC: bool = false;
    const STATIC_HEAP_SIZE: usize = 0;

    fn estimate_heap_size(&self) -> usize {
        0
    }
}
95
/// An `N`-dimensional hypertree, which subdivides each axis in two in
/// each refinement step.
///
/// Used as a basis for axis-aligned adaptive finite difference
/// meshes. The tree is described by its active cells (stored as z-values); the
/// full cell hierarchy is derived from them by `build`.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(from = "TreeSer<N>", into = "TreeSer<N>")]
pub struct Tree<const N: usize> {
    /// Bounds of the root cell.
    domain: HyperBox<N>,
    /// Whether each axis wraps around at the domain boundary.
    periodic: [bool; N],
    // *********************
    // Active Cells
    //
    /// Stores structure of the quadtree using `zindex` ordering.
    active_values: BitVec<usize, Lsb0>,
    /// Offsets into `active_values` (stride of `N` bits). Holds one entry per
    /// active cell plus a trailing end offset.
    active_offsets: Vec<usize>,
    /// Map from active cell index to general cells.
    active_to_cell: Vec<usize>,
    // *********************
    // All cells
    //
    /// Map from level to cells: level `l` owns `cells[level_offsets[l]..level_offsets[l + 1]]`.
    level_offsets: Vec<usize>,
    /// Every cell in the tree (bounds, links and active ranges), grouped by level.
    cells: Vec<Cell<N>>,
}
123
124impl<const N: usize> Tree<N> {
125    /// Constructs a new tree consisting of a single root cell, covering the given
126    /// domain.
127    pub fn new(domain: HyperBox<N>) -> Self {
128        let mut result = Self {
129            domain,
130            periodic: [false; N],
131            active_values: BitVec::new(),
132            active_offsets: vec![0, 0],
133            active_to_cell: Vec::new(),
134            level_offsets: Vec::new(),
135            cells: Vec::new(),
136        };
137        result.build();
138        result
139    }
140
141    pub fn set_periodic(&mut self, axis: usize, periodic: bool) {
142        self.periodic[axis] = periodic;
143    }
144
145    pub fn domain(&self) -> HyperBox<N> {
146        self.domain
147    }
148
149    /// The number of active (leaf) cells in this tree.
150    pub fn num_active_cells(&self) -> usize {
151        self.active_offsets.len() - 1
152    }
153
154    /// The total number of cells in this tree (including )
155    pub fn num_cells(&self) -> usize {
156        self.cells.len()
157    }
158
159    /// The maximum depth of this tree.
160    pub fn num_levels(&self) -> usize {
161        self.level_offsets.len() - 1
162    }
163
164    pub fn level_cells(&self, level: usize) -> impl Iterator<Item = CellId> + ExactSizeIterator {
165        (self.level_offsets[level]..self.level_offsets[level + 1]).map(CellId)
166    }
167
168    pub fn cell_indices(&self) -> impl Iterator<Item = CellId> {
169        (0..self.num_cells()).map(CellId)
170    }
171
172    pub fn active_cell_indices(&self) -> impl Iterator<Item = ActiveCellId> {
173        (0..self.num_active_cells()).map(ActiveCellId)
174    }
175
176    /// Returns the numerical bounds of a given cell.
177    pub fn bounds(&self, cell: CellId) -> HyperBox<N> {
178        self.cells[cell.0].bounds
179    }
180
181    pub fn active_bounds(&self, active: ActiveCellId) -> HyperBox<N> {
182        self.bounds(self.cell_from_active_index(active))
183    }
184
185    /// Returns the level of a given cell.
186    pub fn level(&self, cell: CellId) -> usize {
187        self.cells[cell.0].level
188    }
189
190    pub fn active_level(&self, cell: ActiveCellId) -> usize {
191        self.active_offsets[cell.0 + 1] - self.active_offsets[cell.0]
192    }
193
194    /// Returns the children of a given node. Node must not be leaf.
195    pub fn children(&self, cell: CellId) -> Option<CellId> {
196        if self.cells[cell.0].children == NULL {
197            return None;
198        }
199        Some(CellId(self.cells[cell.0].children))
200    }
201
202    /// Returns a child of a give node.
203    pub fn child(&self, cell: CellId, child: Split<N>) -> Option<CellId> {
204        if self.cells[cell.0].children == NULL {
205            return None;
206        }
207        Some(CellId(self.cells[cell.0].children + child.to_linear()))
208    }
209
210    /// The parent node of a given node.
211    pub fn parent(&self, cell: CellId) -> Option<CellId> {
212        if self.cells[cell.0].parent == NULL {
213            return None;
214        }
215
216        Some(CellId(self.cells[cell.0].parent))
217    }
218
219    /// Returns the zvalue of the given active cell.
220    pub fn active_zvalue(&self, active: ActiveCellId) -> &BitSlice<usize, Lsb0> {
221        &self.active_values
222            [N * self.active_offsets[active.0]..N * self.active_offsets[active.0 + 1]]
223    }
224
225    pub fn active_split(&self, active: ActiveCellId, level: usize) -> Split<N> {
226        Split::pack(array::from_fn(|axis| {
227            self.active_zvalue(active)[N * level + axis]
228        }))
229    }
230
231    pub fn most_recent_active_split(&self, active: ActiveCellId) -> Option<Split<N>> {
232        if self.num_cells() == 1 {
233            return None;
234        }
235
236        Some(self.active_split(active, self.active_level(active) - 1))
237    }
238
239    /// Checks whether the given refinement flags are balanced.
240    pub fn check_refine_flags(&self, flags: &[bool]) -> bool {
241        assert!(flags.len() == self.num_active_cells());
242
243        for cell in self.active_cell_indices() {
244            if !flags[cell.0] {
245                continue;
246            }
247
248            for coarse in self.active_coarse_neighborhood(cell) {
249                if !flags[coarse.0] {
250                    return false;
251                }
252            }
253        }
254
255        true
256    }
257
258    /// Balances the given refinement flags, flagging additional cells
259    /// for refinement to preserve the 2:1 fine coarse ratio between every
260    /// two neighbors.
261    pub fn balance_refine_flags(&self, flags: &mut [bool]) {
262        assert!(flags.len() == self.num_active_cells());
263
264        loop {
265            let mut is_balanced = true;
266
267            for cell in self.active_cell_indices() {
268                if !flags[cell.0] {
269                    continue;
270                }
271
272                for coarse in self.active_coarse_neighborhood(cell) {
273                    if !flags[coarse.0] {
274                        is_balanced = false;
275                        flags[coarse.0] = true;
276                    }
277                }
278            }
279
280            if is_balanced {
281                break;
282            }
283        }
284    }
285
286    /// Fills the map with updated indices after refinement is performed.
287    /// If a cell is refined, this will point to the base cell in that new subdivision.
288    pub fn refine_active_index_map(&self, flags: &[bool], map: &mut [ActiveCellId]) {
289        assert!(flags.len() == self.num_active_cells());
290        assert!(map.len() == self.num_active_cells());
291
292        let mut cursor = 0;
293
294        for cell in 0..self.num_active_cells() {
295            map[cell] = ActiveCellId(cursor);
296
297            if flags[cell] {
298                cursor += Split::<N>::COUNT;
299            } else {
300                cursor += 1;
301            }
302        }
303    }
304
305    pub fn refine(&mut self, flags: &[bool]) {
306        assert!(self.num_active_cells() == flags.len());
307
308        let num_flags = flags.iter().copied().filter(|&p| p).count();
309        let total_active_cells = self.num_active_cells() + (Split::<N>::COUNT - 1) * num_flags;
310
311        let mut active_values = BitVec::with_capacity(total_active_cells * N);
312        let mut active_offsets = Vec::with_capacity(total_active_cells);
313        active_offsets.push(0);
314
315        for active in 0..self.num_active_cells() {
316            if flags[active] {
317                for split in Split::<N>::enumerate() {
318                    active_values.extend_from_bitslice(self.active_zvalue(ActiveCellId(active)));
319                    for axis in 0..N {
320                        active_values.push(split.is_set(axis));
321                    }
322                    active_offsets.push(active_values.len() / N);
323                }
324            } else {
325                active_values.extend_from_bitslice(self.active_zvalue(ActiveCellId(active)));
326                active_offsets.push(active_values.len() / N);
327            }
328        }
329
330        self.active_values.clone_from(&active_values);
331        self.active_offsets.clone_from(&active_offsets);
332
333        self.build();
334    }
335
336    /// Checks that the given coarsening flags are balanced and valid.
337    pub fn check_coarsen_flags(&self, flags: &[bool]) -> bool {
338        assert!(flags.len() == self.num_active_cells());
339
340        if flags.len() == 1 {
341            return true;
342        }
343
344        // Short circuit if this mesh only has two levels.
345        if flags.len() == Split::<N>::COUNT {
346            return flags.iter().all(|&b| !b);
347        }
348
349        // First if any flagging would break 2:1 border, unmark it
350        for cell in self.active_cell_indices() {
351            if !flags[cell.0] {
352                for neighbor in self.active_coarse_neighborhood(cell) {
353                    // Set any coarser cells to not be coarsened further.
354                    if flags[neighbor.0] {
355                        return false;
356                    }
357                }
358            }
359        }
360
361        // Make sure only cells that can be coarsened are coarsened. And that every single child of such a cell
362        // is flagged.
363        let mut cell = 0;
364
365        while cell < self.num_active_cells() {
366            if !flags[cell] {
367                cell += 1;
368                continue;
369            }
370
371            // if flags[cell] {
372            let level = self.active_level(ActiveCellId(cell));
373            let split = self.most_recent_active_split(ActiveCellId(cell)).unwrap();
374
375            if split != Split::<N>::empty() {
376                return false;
377            }
378
379            for offset in 0..Split::<N>::COUNT {
380                if self.active_level(ActiveCellId(cell + offset)) != level {
381                    return false;
382                }
383            }
384
385            if !flags[cell..cell + Split::<N>::COUNT].iter().all(|&b| b) {
386                return false;
387            }
388            // Skip forwards. We have considered all cases.
389            cell += Split::<N>::COUNT;
390        }
391
392        true
393    }
394
395    /// Balances the given coarsening flags
396    pub fn balance_coarsen_flags(&self, flags: &mut [bool]) {
397        assert!(flags.len() == self.num_active_cells());
398
399        if flags.len() == 1 {
400            return;
401        }
402
403        // Short circuit if this mesh only has two levels.
404        if flags.len() == Split::<N>::COUNT {
405            flags.fill(false);
406        }
407
408        loop {
409            let mut is_balanced = true;
410
411            // First if any flagging would break 2:1 border, unmark it
412            for cell in self.active_cell_indices() {
413                if !flags[cell.0] {
414                    for neighbor in self.active_coarse_neighborhood(cell) {
415                        // Set any coarser cells to not be coarsened further.
416                        if flags[neighbor.0] {
417                            is_balanced = false;
418                        }
419                        flags[neighbor.0] = false;
420                    }
421                }
422            }
423
424            // Make sure only cells that can be coarsened are coarsened. And that every single child of such a cell
425            // is flagged.
426            let mut cell = 0;
427
428            while cell < self.num_active_cells() {
429                if !flags[cell] {
430                    cell += 1;
431                    continue;
432                }
433
434                // if flags[cell] {
435                let level = self.active_level(ActiveCellId(cell));
436                let split = self.most_recent_active_split(ActiveCellId(cell)).unwrap();
437
438                if split != Split::<N>::empty() {
439                    flags[cell] = false;
440                    is_balanced = false;
441                    cell += 1;
442                    continue;
443                }
444
445                for offset in 0..Split::<N>::COUNT {
446                    if self.active_level(ActiveCellId(cell + offset)) != level {
447                        flags[cell] = false;
448                        is_balanced = false;
449                        cell += 1;
450                        continue;
451                    }
452                }
453
454                if !flags[cell..cell + Split::<N>::COUNT].iter().all(|&b| b) {
455                    flags[cell..cell + Split::<N>::COUNT].fill(false);
456                    is_balanced = false;
457                }
458                // Skip forwards. We have considered all cases.
459                cell += Split::<N>::COUNT;
460            }
461
462            if is_balanced {
463                break;
464            }
465        }
466    }
467
468    /// Maps current cells to indices after coarsening is performed.
469    pub fn coarsen_active_index_map(&self, flags: &[bool], map: &mut [ActiveCellId]) {
470        assert!(flags.len() == self.num_active_cells());
471        assert!(map.len() == self.num_active_cells());
472
473        let mut cursor = 0;
474        let mut cell = 0;
475
476        while cell < self.num_active_cells() {
477            if flags[cell] {
478                map[cell..cell + Split::<N>::COUNT].fill(ActiveCellId(cursor));
479                cell += Split::<N>::COUNT;
480            } else {
481                map[cell] = ActiveCellId(cursor);
482                cell += 1;
483            }
484
485            cursor += 1;
486        }
487    }
488
489    pub fn coarsen(&mut self, flags: &[bool]) {
490        assert!(flags.len() == self.num_active_cells());
491
492        // Compute number of cells after coarsening
493        let num_flags = flags.iter().copied().filter(|&p| p).count();
494        debug_assert!(num_flags % Split::<N>::COUNT == 0);
495        let total_active = self.num_active_cells() - num_flags / Split::<N>::COUNT;
496
497        let mut active_values = BitVec::with_capacity(total_active * N);
498        let mut active_offsets = Vec::new();
499        active_offsets.push(0);
500
501        // Loop over cells
502        let mut cursor = 0;
503
504        while cursor < self.num_active_cells() {
505            // Retrieve zvalue of cursor
506            let zvalue = self.active_zvalue(ActiveCellId(cursor));
507
508            if flags[cursor] {
509                #[cfg(debug_assertions)]
510                for split in Split::<N>::enumerate() {
511                    assert!(flags[cursor + split.to_linear()])
512                }
513
514                active_values.extend_from_bitslice(&zvalue[0..zvalue.len().saturating_sub(N)]);
515                // Skip next `Count` cells
516                cursor += Split::<N>::COUNT;
517            } else {
518                active_values.extend_from_bitslice(zvalue);
519                cursor += 1;
520            }
521
522            active_offsets.push(active_values.len() / N);
523        }
524
525        self.active_values.clone_from(&active_values);
526        self.active_offsets.clone_from(&active_offsets);
527
528        self.build();
529    }
530
531    pub fn build(&mut self) {
532        // Reset tree
533        self.active_to_cell.resize(self.num_active_cells(), 0);
534        self.level_offsets.clear();
535        self.cells.clear();
536
537        // Add root cell
538        self.cells.push(Cell {
539            bounds: self.domain,
540            parent: NULL,
541            children: NULL,
542            active_offset: 0,
543            active_count: self.num_active_cells(),
544            level: 0,
545        });
546        self.level_offsets.push(0);
547        self.level_offsets.push(1);
548
549        // Recursively subdivide existing nodes using `active_indices`.
550        loop {
551            let level = self.level_offsets.len() - 2;
552            let level_cells = self.level_offsets[level]..self.level_offsets[level + 1];
553
554            // First node on current level
555            let next_level_start = self.cells.len();
556            // Loop over nodes on the current level
557            for parent in level_cells {
558                if self.cells[parent].active_count == 1 {
559                    debug_assert!(
560                        self.active_level(ActiveCellId(self.cells[parent].active_offset)) == level
561                    );
562                    self.active_to_cell[self.cells[parent].active_offset] = parent;
563                    continue;
564                }
565
566                // Update parent's children
567                self.cells[parent].children = self.cells.len();
568                // Iterate over constituent active cells
569                let active_start = self.cells[parent].active_offset;
570                let active_end = active_start + self.cells[parent].active_count;
571
572                let mut cursor = active_start;
573
574                debug_assert!(self.active_level(ActiveCellId(cursor)) > level);
575
576                let bounds = self.cells[parent].bounds;
577
578                for mask in Split::<N>::enumerate() {
579                    let child_cell_start = cursor;
580
581                    while cursor < active_end
582                        && mask == self.active_split(ActiveCellId(cursor), level)
583                    {
584                        cursor += 1;
585                    }
586
587                    let child_cell_end = cursor;
588
589                    self.cells.push(Cell {
590                        bounds: bounds.subdivide(mask),
591                        parent,
592                        children: NULL,
593                        active_offset: child_cell_start,
594                        active_count: child_cell_end - child_cell_start,
595                        level: level + 1,
596                    });
597                }
598            }
599
600            let next_level_end = self.cells.len();
601
602            if next_level_start >= next_level_end {
603                break;
604            }
605
606            self.level_offsets.push(next_level_end);
607        }
608
609        #[cfg(debug_assertions)]
610        for cell in self.cell_indices() {
611            let active = ActiveCellId(self.cells[cell.0].active_offset);
612            assert!(self.active_level(active) >= self.level(cell));
613        }
614    }
615
616    /// Computes the cell index corresponding to an active cell.
617    pub fn cell_from_active_index(&self, active: ActiveCellId) -> CellId {
618        debug_assert!(
619            active.0 < self.num_active_cells(),
620            "Active cell index is expected to be less that the number of active cells."
621        );
622        CellId(self.active_to_cell[active.0])
623    }
624
625    /// Computes active cell index from a cell, returning None if `cell` is
626    /// not active.
627    pub fn active_index_from_cell(&self, cell: CellId) -> Option<ActiveCellId> {
628        debug_assert!(
629            cell.0 < self.num_cells(),
630            "Cell index is expected to be less that the number of cells."
631        );
632
633        if self.cells[cell.0].active_count != 1 {
634            return None;
635        }
636
637        Some(ActiveCellId(self.cells[cell.0].active_offset))
638    }
639
640    /// Returns an iterator over active cells that are children of the given cell.
641    /// If `is_active(cell) = true` then this iterator will be a singleton
642    /// returning the same value as `tree.active_index_from_cell(cell)`.
643    pub fn active_children(
644        &self,
645        cell: CellId,
646    ) -> impl Iterator<Item = ActiveCellId> + ExactSizeIterator {
647        let (offset, count) = (
648            self.cells[cell.0].active_offset,
649            self.cells[cell.0].active_count,
650        );
651
652        (offset..offset + count).map(ActiveCellId)
653    }
654
655    /// True if a node has no children.
656    pub fn is_active(&self, node: CellId) -> bool {
657        let result = self.cells[node.0].children == NULL;
658        debug_assert!(!result || self.cells[node.0].active_count == 1);
659        result
660    }
661
662    /// Returns the cell which owns the given point.
663    /// Performs in O(log N).
664    pub fn cell_from_point(&self, point: [f64; N]) -> CellId {
665        debug_assert!(self.domain.contains(point));
666
667        let mut node = CellId(0);
668
669        while let Some(children) = self.children(node) {
670            let bounds = self.bounds(node);
671            let center = bounds.center();
672            node = CellId::child(
673                children,
674                Split::<N>::pack(array::from_fn(|axis| point[axis] > center[axis])),
675            );
676        }
677
678        node
679    }
680
681    /// Returns the node which owns the given point, shortening this search
682    /// with an initial guess. Rather than operating in O(log N) time, this approaches
683    /// O(1) if the guess is sufficiently close.
684    pub fn cell_from_point_cached(&self, point: [f64; N], mut cache: CellId) -> CellId {
685        debug_assert!(self.domain.contains(point));
686
687        while !self.bounds(cache).contains(point) {
688            cache = self.parent(cache).unwrap();
689        }
690
691        let mut node = cache;
692
693        while let Some(children) = self.children(node) {
694            let bounds = self.bounds(node);
695            let center = bounds.center();
696            node = CellId::child(
697                children,
698                Split::<N>::pack(array::from_fn(|axis| point[axis] > center[axis])),
699            )
700        }
701
702        node
703    }
704
705    /// Returns the neighboring cell along the given face. If the neighboring cell is more refined, this
706    /// returns the cell index of the adjacent cell with `tree.level(neighbor) == tree.level(cell)`.
707    /// If this passes over a nonperiodic boundary then it returns `None`.
708    pub fn neighbor(&self, cell: CellId, face: Face<N>) -> Option<CellId> {
709        let mut region = Region::CENTRAL;
710        region.set_side(face.axis, if face.side { Side::Right } else { Side::Left });
711        self.neighbor_region(cell, region)
712    }
713
714    /// Returns the neighboring cell in the given region. If the neighboring cell is more refined, this
715    /// returns the cell index of the adjacent cell with `tree.level(neighbor) == tree.level(cell)`.
716    /// If this passes over a nonperiodic boundary then it returns `None`.
717    pub fn neighbor_region(&self, cell: CellId, region: Region<N>) -> Option<CellId> {
718        let active_indices = ActiveCellId(self.cells[cell.0].active_offset);
719        debug_assert!(self.active_level(active_indices) >= self.level(cell));
720
721        let is_periodic = (0..N)
722            .map(|axis| region.side(axis) == Side::Middle || self.periodic[axis])
723            .all(|b| b);
724
725        if cell == CellId::ROOT && is_periodic {
726            return Some(CellId::ROOT);
727        }
728
729        let parent = self.parent(cell)?;
730        debug_assert!(self.level(cell) > 0 && self.level(cell) == self.level(parent) + 1);
731        let split = self.active_split(active_indices, self.level(parent));
732        if split.is_inner_region(region) {
733            let children = self.children(parent).unwrap();
734            return Some(CellId::child(children, split.as_outer_region(region)));
735        }
736
737        let mut parent_region = region;
738
739        for axis in 0..N {
740            // If on inside, set parent region to middle
741            match (region.side(axis), split.is_set(axis)) {
742                (Side::Left, true) | (Side::Right, false) => {
743                    parent_region.set_side(axis, Side::Middle);
744                }
745                _ => {}
746            }
747        }
748
749        let parent_neighbor = self.neighbor_region(parent, parent_region)?;
750
751        let Some(parent_neighbor_children) = self.children(parent_neighbor) else {
752            return Some(parent_neighbor);
753        };
754
755        let mut neighbor_split = split;
756
757        for axis in 0..N {
758            match (region.side(axis), split.is_set(axis)) {
759                (Side::Left, false) | (Side::Right, true) => {
760                    neighbor_split = neighbor_split.toggled(axis);
761                }
762                (Side::Left, true) | (Side::Right, false) => {
763                    neighbor_split = neighbor_split.toggled(axis);
764                }
765                _ => {}
766            }
767        }
768
769        Some(CellId::child(parent_neighbor_children, neighbor_split))
770    }
771
    /// Iterative variant of the neighbor search.
    ///
    /// Returns the neighboring cell in the given region. If the neighboring cell is more refined, this
    /// returns the cell index of the adjacent cell with `tree.level(neighbor) == tree.level(cell)`.
    /// If this passes over a nonperiodic boundary then it returns `None`.
    ///
    /// NOTE(review): this looks like an alternative to `neighbor_region`; the leading
    /// underscore suggests it is not in active use. Its equivalence to `neighbor_region`
    /// is not established here — confirm before relying on it.
    pub fn _neighbor_region2(&self, cell: CellId, region: Region<N>) -> Option<CellId> {
        // True if every axis the region moves along is periodic.
        let is_periodic = (0..N)
            .map(|axis| region.side(axis) == Side::Middle || self.periodic[axis])
            .all(|b| b);

        if cell == CellId::ROOT && is_periodic {
            return Some(CellId::ROOT);
        }

        // Retrieve first active cell owned by `cell`.
        let active_index = ActiveCellId(self.cells[cell.0].active_offset);
        // Start at this cell
        let mut cursor = cell;
        // While this cell has a parent, walk towards the root and check whether the
        // split at that ancestor is an inner one for the region. If so, stop there.
        while let Some(parent) = self.parent(cursor) {
            cursor = parent;
            if self
                .active_split(active_index, self.level(cursor))
                .is_inner_region(region)
            {
                break;
            }
        }

        // If the walk stopped at an ancestor whose split is inner for the region,
        // step to the corresponding child of that ancestor.
        if self.children(cursor).is_some() {
            let split = self.active_split(active_index, self.level(cursor));

            if split.is_inner_region(region) {
                cursor = CellId::child(
                    self.children(cursor).unwrap(),
                    split.as_outer_region(region),
                )
            }
        }

        // If we are still at the root, the region leaves the domain: either give up
        // (non-periodic) or re-enter through one of the root's children.
        if cursor == CellId::ROOT {
            if !is_periodic {
                return None;
            }

            debug_assert!(self.level(cell) > 0);

            let split = self.active_split(active_index, self.level(cursor));
            cursor = CellId::child(
                self.children(cursor).unwrap(),
                split.as_inner_region(region),
            );
        }

        // Walk back down until the level of `cell` is reached or a childless
        // cell stops the descent.
        while self.level(cursor) < self.level(cell) {
            let Some(children) = self.children(cursor) else {
                break;
            };

            let split = self
                .active_split(active_index, self.level(cursor))
                .as_inner_region(region);
            cursor = CellId::child(children, split);
        }

        // Algorithm complete
        Some(cursor)
    }
841
842    /// Iterates over
843    pub fn active_neighbors_in_region(
844        &self,
845        cell: CellId,
846        region: Region<N>,
847    ) -> impl Iterator<Item = ActiveCellId> + '_ {
848        let level = self.level(cell);
849
850        self.neighbor_region(cell, region)
851            .into_iter()
852            .flat_map(move |neighbor| {
853                self.active_children(neighbor).filter(move |&active| {
854                    for l in level..self.active_level(active) {
855                        if !region
856                            .reverse()
857                            .is_split_adjacent(self.active_split(active, l))
858                        {
859                            return false;
860                        }
861                    }
862
863                    true
864                })
865            })
866    }
867
868    pub fn active_neighborhood(
869        &self,
870        cell: ActiveCellId,
871    ) -> impl Iterator<Item = ActiveCellId> + '_ {
872        regions().flat_map(move |region| {
873            self.active_neighbors_in_region(self.cell_from_active_index(cell), region)
874        })
875    }
876
877    pub fn active_coarse_neighborhood(
878        &self,
879        cell: ActiveCellId,
880    ) -> impl Iterator<Item = ActiveCellId> + '_ {
881        regions().flat_map(move |region| {
882            let neighbor = self.neighbor_region(self.cell_from_active_index(cell), region)?;
883            if self.level(neighbor) < self.active_level(cell) {
884                return self.active_index_from_cell(neighbor);
885            }
886            None
887        })
888    }
889
890    /// Returns true if a face lies on a boundary.
891    pub fn is_boundary_face(&self, cell: CellId, face: Face<N>) -> bool {
892        let mut region = Region::CENTRAL;
893        region.set_side(face.axis, if face.side { Side::Right } else { Side::Left });
894        self.boundary_region(cell, region) != Region::CENTRAL
895    }
896
897    /// Given a neighboring region to a cell, determines which global region that
898    /// belongs to (usually)
899    pub fn boundary_region(&self, cell: CellId, region: Region<N>) -> Region<N> {
900        // Get the active cell owned by this cell.
901        let Some(active) = self.active_index_from_cell(cell) else {
902            return region;
903        };
904
905        let mut result = region;
906        let mut level = self.level(cell);
907
908        while level > 0 && result != Region::CENTRAL {
909            let split = self.active_split(active, level - 1);
910
911            // Mask region by
912            for axis in 0..N {
913                match (result.side(axis), split.is_set(axis)) {
914                    (Side::Left, true) => result.set_side(axis, Side::Middle),
915                    (Side::Right, false) => result.set_side(axis, Side::Middle),
916                    _ => {}
917                }
918            }
919
920            level -= 1;
921        }
922
923        result
924    }
925}
926
927impl<const N: usize> DataSize for Tree<N> {
928    const IS_DYNAMIC: bool = true;
929    const STATIC_HEAP_SIZE: usize = 0;
930
931    fn estimate_heap_size(&self) -> usize {
932        self.active_offsets.estimate_heap_size()
933            + self.active_values.capacity() / size_of::<usize>()
934            + self.active_to_cell.estimate_heap_size()
935            + self.level_offsets.estimate_heap_size()
936            + self.cells.estimate_heap_size()
937    }
938}
939
/// Helper struct for serializing a tree while avoiding saving redundant data.
///
/// Only the `Tree` fields that the `From<TreeSer<N>>` conversion copies verbatim are kept
/// here; the remaining `Tree` fields start out empty in that conversion and are filled in
/// by `Tree::build`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct TreeSer<const N: usize> {
    /// Copied to and from `Tree::domain`.
    domain: HyperBox<N>,
    /// One flag per axis, copied to and from `Tree::periodic`. (De)serialized through the
    /// `crate::array` helper module.
    #[serde(with = "crate::array")]
    periodic: [bool; N],
    /// Bit storage copied to and from `Tree::active_values`.
    active_values: BitVec<usize, Lsb0>,
    /// Copied to and from `Tree::active_offsets`.
    active_offsets: Vec<usize>,
}
949
950impl<const N: usize> From<TreeSer<N>> for Tree<N> {
951    fn from(value: TreeSer<N>) -> Self {
952        let mut result = Tree {
953            domain: value.domain,
954            periodic: value.periodic,
955            active_values: value.active_values,
956            active_offsets: value.active_offsets,
957            active_to_cell: Vec::default(),
958            level_offsets: Vec::default(),
959            cells: Vec::default(),
960        };
961        result.build();
962        result
963    }
964}
965
966impl<const N: usize> From<Tree<N>> for TreeSer<N> {
967    fn from(value: Tree<N>) -> Self {
968        Self {
969            domain: value.domain,
970            periodic: value.periodic,
971            active_values: value.active_values,
972            active_offsets: value.active_offsets,
973        }
974    }
975}
976
977impl<const N: usize> Default for TreeSer<N> {
978    fn default() -> Self {
979        Self {
980            domain: HyperBox::UNIT,
981            periodic: [false; N],
982            active_values: BitVec::default(),
983            active_offsets: Vec::default(),
984        }
985    }
986}
987
988#[cfg(test)]
989mod tests {
990    use super::*;
991
992    #[test]
993    fn neighbors() {
994        let mut tree = Tree::<2>::new(HyperBox::UNIT);
995
996        assert_eq!(tree.bounds(CellId::ROOT), HyperBox::UNIT);
997        assert_eq!(tree.num_cells(), 1);
998        assert_eq!(tree.num_active_cells(), 1);
999        assert_eq!(tree.num_levels(), 1);
1000
1001        assert_eq!(tree.neighbor(CellId::ROOT, Face::negative(0)), None);
1002
1003        tree.refine(&[true]);
1004        tree.build();
1005
1006        assert_eq!(tree.num_cells(), 5);
1007        assert_eq!(tree.num_active_cells(), 4);
1008        assert_eq!(tree.num_levels(), 2);
1009        for split in Split::enumerate() {
1010            assert_eq!(tree.active_split(ActiveCellId(split.to_linear()), 0), split);
1011        }
1012        for i in 0..4 {
1013            assert_eq!(tree.cell_from_active_index(ActiveCellId(i)), CellId(i + 1));
1014        }
1015
1016        tree.refine(&[true, false, false, false]);
1017        tree.build();
1018
1019        assert_eq!(tree.cell_from_active_index(ActiveCellId(0)), CellId(5));
1020
1021        assert!(tree.is_boundary_face(CellId(5), Face::negative(0)));
1022        assert!(tree.is_boundary_face(CellId(5), Face::negative(1)));
1023        assert_eq!(
1024            tree.boundary_region(CellId(5), Region::new([Side::Left, Side::Right])),
1025            Region::new([Side::Left, Side::Middle])
1026        );
1027
1028        assert_eq!(
1029            tree.neighbor_region(CellId(5), Region::new([Side::Right, Side::Right])),
1030            Some(CellId(8))
1031        );
1032
1033        assert_eq!(
1034            tree.neighbor_region(CellId(4), Region::new([Side::Left, Side::Left])),
1035            Some(CellId(1))
1036        );
1037    }
1038
1039    #[test]
1040    fn periodic_neighbors() {
1041        let mut tree = Tree::<2>::new(HyperBox::UNIT);
1042        tree.set_periodic(0, true);
1043        tree.set_periodic(1, true);
1044        assert_eq!(
1045            tree.neighbor(CellId::ROOT, Face::negative(0)),
1046            Some(CellId::ROOT)
1047        );
1048
1049        // Refine tree
1050        tree.refine(&[true]);
1051        tree.refine(&[true, false, false, false]);
1052        tree.build();
1053
1054        assert_eq!(tree.neighbor(CellId(5), Face::negative(0)), Some(CellId(2)));
1055        assert_eq!(
1056            tree.neighbor_region(CellId(5), Region::new([Side::Left, Side::Left])),
1057            Some(CellId(4))
1058        );
1059    }
1060
1061    #[test]
1062    fn active_neighbors_in_region() {
1063        let mut tree = Tree::<2>::new(HyperBox::UNIT);
1064        // Refine tree
1065        tree.refine(&[true]);
1066        tree.refine(&[true, false, false, false]);
1067        tree.build();
1068
1069        assert!(
1070            tree.active_neighbors_in_region(CellId(2), Region::new([Side::Left, Side::Middle]))
1071                .eq([ActiveCellId(1), ActiveCellId(3)].into_iter())
1072        );
1073
1074        assert!(
1075            tree.active_neighbors_in_region(CellId(3), Region::new([Side::Middle, Side::Left]))
1076                .eq([ActiveCellId(2), ActiveCellId(3)].into_iter())
1077        );
1078
1079        assert!(
1080            tree.active_neighbors_in_region(CellId(4), Region::new([Side::Left, Side::Left]))
1081                .eq([ActiveCellId(3)].into_iter())
1082        );
1083
1084        assert!(
1085            tree.active_neighbors_in_region(CellId(6), Region::new([Side::Right, Side::Right]))
1086                .eq([ActiveCellId(4)].into_iter())
1087        );
1088    }
1089
1090    #[test]
1091    fn refinement_and_coarsening() {
1092        let mut tree = Tree::<2>::new(HyperBox::UNIT);
1093        tree.refine(&[true]);
1094        // Make initially asymmetric.
1095        tree.refine(&[true, false, false, false]);
1096
1097        for _ in 0..1 {
1098            let mut flags: Vec<bool> = vec![true; tree.num_active_cells()];
1099            tree.balance_refine_flags(&mut flags);
1100            tree.refine(&flags);
1101        }
1102
1103        for _ in 0..2 {
1104            let mut flags = vec![true; tree.num_active_cells()];
1105            tree.balance_coarsen_flags(&mut flags);
1106            let mut coarsen_map = vec![ActiveCellId(0); tree.num_active_cells()];
1107            tree.coarsen_active_index_map(&flags, &mut coarsen_map);
1108            tree.coarsen(&flags);
1109        }
1110
1111        let mut other_tree = Tree::<2>::new(HyperBox::UNIT);
1112        other_tree.refine(&[true]);
1113
1114        assert_eq!(tree, other_tree);
1115    }
1116
1117    use rand::Rng;
1118
1119    #[test]
1120    fn fuzz_serialize() -> eyre::Result<()> {
1121        let mut tree = Tree::<2>::new(HyperBox::UNIT);
1122
1123        // Randomly refine tree
1124        let mut rng = rand::rng();
1125        for _ in 0..4 {
1126            let mut flags = vec![false; tree.num_active_cells()];
1127            rng.fill(flags.as_mut_slice());
1128
1129            tree.balance_coarsen_flags(&mut flags);
1130            tree.refine(&mut flags);
1131        }
1132
1133        // Serialize tree
1134        let data = ron::to_string(&tree)?;
1135        let tree2: Tree<2> = ron::from_str(data.as_str())?;
1136
1137        assert_eq!(tree, tree2);
1138
1139        Ok(())
1140    }
1141
1142    #[test]
1143    fn cell_from_point() -> eyre::Result<()> {
1144        let mut tree = Tree::<2>::new(HyperBox::UNIT);
1145        tree.refine(&[true]);
1146        tree.refine(&[true, false, false, false]);
1147
1148        assert_eq!(tree.cell_from_point([0.0, 0.0]), CellId(5));
1149        assert_eq!(
1150            tree.active_index_from_cell(CellId(5)),
1151            Some(ActiveCellId(0))
1152        );
1153
1154        assert_eq!(tree.cell_from_point([0.51, 0.67]), CellId(4));
1155        assert_eq!(
1156            tree.active_index_from_cell(CellId(4)),
1157            Some(ActiveCellId(6))
1158        );
1159
1160        let mut rng = rand::rng();
1161        for _ in 0..50 {
1162            let x: f64 = rng.random_range(0.0..1.0);
1163            let y: f64 = rng.random_range(0.0..1.0);
1164
1165            let cache: usize = rng.random_range(..tree.num_cells());
1166
1167            assert_eq!(
1168                tree.cell_from_point_cached([x, y], CellId(cache)),
1169                tree.cell_from_point([x, y])
1170            );
1171        }
1172
1173        Ok(())
1174    }
1175}