Skip to main content

aeon_tk/kernel/
element.rs

1use crate::{
2    element::{ApproxOperator, Monomials, Uniform, Values},
3    geometry::{IndexSpace, Split},
4};
5use faer::{ColRef, Mat};
6use std::array;
7
8/// A reference element defined on [-1, 1]^N. This element implements code
9/// for wavelet transformations and general operator approximation, whereas
10/// `NodeSpace` implements a much stricter subset of operations with fixed
11/// precomputed weights.
12#[derive(Clone, Debug)]
13pub struct Element<const N: usize> {
14    /// Order of basis functions used in element.
15    order: usize,
16    /// Width of element in nodes
17    width: usize,
18    /// Grid of points within element.
19    grid: Vec<f64>,
20    /// Positions of points within element
21    /// after refinement
22    grid_refined: Vec<f64>,
23    /// Interpolation stencils.
24    stencils: Mat<f64>,
25}
26
27impl<const N: usize> Element<N> {
28    /// Constructs a reference element with uniformly placed
29    /// support points with `width + 1` points along each axis.
30    pub fn uniform(width: usize, order: usize) -> Self {
31        assert!(width >= order);
32        let (grid, grid_refined) = Self::uniform_grid(width);
33        debug_assert!(grid.len() == width + 1);
34        debug_assert!(grid_refined.len() == 2 * width + 1);
35
36        let spacing = 2.0 / width as f64;
37        let points = Vec::from_iter(
38            (0..width)
39                .into_iter()
40                .map(|j| [-1.0 + spacing * (j as f64 + 0.5)]),
41        );
42
43        let mut approx = ApproxOperator::default();
44        approx
45            .build(
46                &Uniform::new([width + 1]),
47                &Monomials::new([order + 1]),
48                &Values(points.as_slice()),
49            )
50            .unwrap();
51
52        let stencils = approx.shape().to_owned();
53
54        debug_assert!(stencils.nrows() == width + 1);
55        debug_assert!(stencils.ncols() == width);
56
57        Self {
58            width,
59            order,
60            grid,
61            grid_refined,
62            stencils,
63        }
64    }
65
66    fn uniform_grid(width: usize) -> (Vec<f64>, Vec<f64>) {
67        let spacing = 2.0 / width as f64;
68
69        let grid = (0..=width)
70            .map(|i| i as f64 * spacing - 1.0)
71            .collect::<Vec<_>>();
72
73        let grid_refined = (0..=2 * width)
74            .map(|i| i as f64 * spacing / 2.0 - 1.0)
75            .collect::<Vec<_>>();
76
77        (grid, grid_refined)
78    }
79
80    /// Number of support points in the reference element.
81    pub fn support(&self) -> usize {
82        (self.width + 1).pow(N as u32)
83    }
84
85    /// Number of supports points after the element has been refined.
86    pub fn support_refined(&self) -> usize {
87        (2 * self.width + 1).pow(N as u32)
88    }
89
90    /// Retrieves the order of the element.
91    pub fn order(&self) -> usize {
92        self.order
93    }
94
95    /// Retrieves the width of the element.
96    pub fn width(&self) -> usize {
97        self.width
98    }
99
100    /// Retrieves the grid along one axis.
101    pub fn grid(&self) -> &[f64] {
102        &self.grid
103    }
104
105    pub fn prolong_stencil(&self, target: usize) -> ColRef<'_, f64> {
106        self.stencils.col(target)
107    }
108
109    /// Iterates the position of a support point in the element.
110    pub fn position(&self, index: [usize; N]) -> [f64; N] {
111        array::from_fn(|axis| self.grid[index[axis]])
112    }
113
114    pub fn position_refined(&self, index: [usize; N]) -> [f64; N] {
115        array::from_fn(|axis| self.grid_refined[index[axis]])
116    }
117
118    // *********************
119    // Point Iteration *****
120    // *********************
121
122    pub fn space(&self) -> IndexSpace<N> {
123        IndexSpace::new([self.width + 1; N])
124    }
125
126    pub fn space_refined(&self) -> IndexSpace<N> {
127        IndexSpace::new([2 * self.width + 1; N])
128    }
129
130    /// Iterates over all nodal coefficients in a wavelet representation on this element.
131    pub fn nodal_indices(&self) -> impl Iterator<Item = [usize; N]> {
132        IndexSpace::new([self.width + 1; N])
133            .iter()
134            .map(|v| array::from_fn(|axis| v[axis] * 2))
135    }
136
137    /// Iterates over all diagonal detail coefficients in a wavelet representation on this element.
138    pub fn diagonal_indices(&self) -> impl Iterator<Item = [usize; N]> {
139        IndexSpace::new([self.width; N])
140            .iter()
141            .map(|v| array::from_fn(|axis| v[axis] * 2 + 1))
142    }
143
144    /// Iterates over diagonal detail coefficients in a wavelet representation of this element,
145    /// ignoring elements within a `buffer` around the edge of the refined support.
146    pub fn diagonal_int_indices(&self, buffer: usize) -> impl Iterator<Item = [usize; N]> {
147        debug_assert!(buffer % 2 == 0);
148
149        IndexSpace::new([self.width - buffer; N])
150            .iter()
151            .map(move |v| array::from_fn(|axis| 2 * v[axis] + 1 + buffer))
152    }
153
154    /// Iterates over all detail coefficients in a wavelet representation on this element.
155    pub fn detail_indices(&self) -> impl Iterator<Item = [usize; N]> {
156        let cells = IndexSpace::new([self.width; N]).iter();
157
158        cells.flat_map(|index| {
159            Split::<N>::enumerate().skip(1).map(move |mask| {
160                let mut point = index;
161
162                for axis in 0..N {
163                    point[axis] *= 2;
164
165                    if mask.is_set(axis) {
166                        point[axis] += 1;
167                    }
168                }
169
170                point
171            })
172        })
173    }
174
175    pub fn nodal_points(&self) -> impl Iterator<Item = usize> {
176        let space = IndexSpace::new([2 * self.width + 1; N]);
177        self.nodal_indices()
178            .map(move |index| space.linear_from_cartesian(index))
179    }
180
181    pub fn diagonal_points(&self) -> impl Iterator<Item = usize> {
182        let space = IndexSpace::new([2 * self.width + 1; N]);
183        self.diagonal_indices()
184            .map(move |index| space.linear_from_cartesian(index))
185    }
186
187    pub fn diagonal_int_points(&self, buffer: usize) -> impl Iterator<Item = usize> {
188        let space = IndexSpace::new([2 * self.width + 1; N]);
189        self.diagonal_int_indices(buffer)
190            .map(move |index| space.linear_from_cartesian(index))
191    }
192
193    pub fn detail_points(&self) -> impl Iterator<Item = usize> {
194        let space = IndexSpace::new([2 * self.width + 1; N]);
195        self.detail_indices()
196            .map(move |index| space.linear_from_cartesian(index))
197    }
198
199    // ****************************
200    // Prolongation ***************
201    // ****************************
202
203    /// Prolongs data from the element to a refined version of the element.
204    pub fn prolong(&self, source: &[f64], dest: &mut [f64]) {
205        self.inject(source, dest);
206        self.prolong_in_place(dest);
207    }
208
209    /// Fills in-between points on dest using interpolation, assuming that nodal
210    /// points on dest have been properly filled.
211    pub fn prolong_in_place(&self, dest: &mut [f64]) {
212        debug_assert!(dest.len() == self.support_refined());
213
214        let space = IndexSpace::new([2 * self.width + 1; N]);
215
216        // And now perform interpolation
217        for axis in (0..N).rev() {
218            let mut psize = [0; N];
219
220            for i in 0..axis {
221                psize[i] = self.width + 1;
222            }
223            psize[axis] = self.width;
224            for i in (axis + 1)..N {
225                psize[i] = 2 * self.width + 1;
226            }
227
228            for mut point in IndexSpace::new(psize).iter() {
229                for i in 0..axis {
230                    point[i] *= 2;
231                }
232
233                let stencil = self.stencils.col(point[axis]);
234
235                point[axis] *= 2;
236                point[axis] += 1;
237
238                let center = space.linear_from_cartesian(point);
239                dest[center] = 0.0;
240
241                for i in 0..=self.width {
242                    point[axis] = 2 * i;
243                    dest[center] += stencil[i] * dest[space.linear_from_cartesian(point)];
244                }
245            }
246        }
247    }
248
249    // *******************************
250    // Injection *********************
251    // *******************************
252
253    /// Performs injection from source to a refined dest.
254    pub fn inject(&self, source: &[f64], dest: &mut [f64]) {
255        debug_assert!(source.len() == self.support());
256        debug_assert!(dest.len() == self.support_refined());
257
258        // Perform injection
259        let source_space = IndexSpace::new([self.width + 1; N]);
260        let dest_space = IndexSpace::new([2 * self.width + 1; N]);
261
262        for (pindex, point) in source_space.iter().enumerate() {
263            let refined: [_; N] = array::from_fn(|axis| 2 * point[axis]);
264            let rindex = dest_space.linear_from_cartesian(refined);
265            dest[rindex] = source[pindex];
266        }
267    }
268
269    /// Restricts refined data from source onto dest.
270    pub fn restrict(&self, source: &[f64], dest: &mut [f64]) {
271        debug_assert!(source.len() == self.support_refined());
272        debug_assert!(dest.len() == self.support());
273
274        let source_space = IndexSpace::new([2 * self.width + 1; N]);
275        let dest_space = IndexSpace::new([self.width + 1; N]);
276
277        for (pindex, point) in source_space.iter().enumerate() {
278            let refined: [_; N] = array::from_fn(|axis| point[axis] / 2);
279            let rindex = dest_space.linear_from_cartesian(refined);
280            dest[rindex] = source[pindex];
281        }
282    }
283
284    // *******************************
285    // Wavelet Expansion *************
286    // *******************************
287
288    /// Computes the wavelet coefficients for the given function.
289    pub fn wavelet(&self, source: &[f64], dest: &mut [f64]) {
290        debug_assert!(source.len() == self.support_refined());
291        debug_assert!(dest.len() == self.support_refined());
292
293        // Copies data from the source to noal points on dest.
294        for point in self.nodal_points() {
295            dest[point] = source[point];
296        }
297
298        self.prolong_in_place(dest);
299
300        // Iterates over the detail coefficients.
301        for point in self.detail_points() {
302            dest[point] -= source[point];
303        }
304    }
305
306    /// Computes the relative error between the wavelet's representation
307    /// and a nodal approximation.
308    pub fn wavelet_rel_error(&self, coefs: &[f64]) -> f64 {
309        let scale = self
310            .nodal_points()
311            .map(|v| coefs[v].abs())
312            .max_by(|a, b| a.total_cmp(b))
313            .unwrap();
314
315        self.wavelet_abs_error(coefs) / scale
316    }
317
318    /// Computes the absolute error between the wavelet's representation
319    /// and a nodal approximation.
320    pub fn wavelet_abs_error(&self, coefs: &[f64]) -> f64 {
321        self.diagonal_points()
322            .map(|v| coefs[v].abs())
323            .max_by(|a, b| a.total_cmp(b))
324            .unwrap()
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::Element;

    /// Checks the order and contents of every index iterator on a small
    /// two-dimensional element (2 cells per axis, 5 refined points per axis).
    #[test]
    fn iteration() {
        let element = Element::<2>::uniform(2, 1);

        // Nodal points: both components even.
        let nodal: Vec<[usize; 2]> = element.nodal_indices().collect();
        let expected: Vec<[usize; 2]> = vec![
            [0, 0],
            [2, 0],
            [4, 0],
            [0, 2],
            [2, 2],
            [4, 2],
            [0, 4],
            [2, 4],
            [4, 4],
        ];
        assert_eq!(nodal, expected);

        // Diagonal points: both components odd.
        let diagonal: Vec<[usize; 2]> = element.diagonal_indices().collect();
        let expected: Vec<[usize; 2]> = vec![[1, 1], [3, 1], [1, 3], [3, 3]];
        assert_eq!(diagonal, expected);

        // A zero buffer must not exclude anything.
        let diagonal_int: Vec<[usize; 2]> = element.diagonal_int_indices(0).collect();
        assert_eq!(diagonal_int, expected);

        // Detail points: three per cell, grouped cell by cell.
        let detail: Vec<[usize; 2]> = element.detail_indices().collect();
        let expected: Vec<[usize; 2]> = vec![
            // Cell (0, 0)
            [1, 0],
            [0, 1],
            [1, 1],
            // Cell (1, 0)
            [3, 0],
            [2, 1],
            [3, 1],
            // Cell (0, 1)
            [1, 2],
            [0, 3],
            [1, 3],
            // Cell (1, 1)
            [3, 2],
            [2, 3],
            [3, 3],
        ];
        assert_eq!(detail, expected);
    }

    /// Checks which diagonal points survive when a buffer is excluded from
    /// the edges of a one-dimensional element.
    #[test]
    fn interior_indices() {
        let element = Element::<1>::uniform(6, 4);

        let indices: Vec<usize> = element.diagonal_int_points(2).collect();
        assert_eq!(indices, vec![3, 5, 7, 9]);

        // Width 6, ghost 3
        let width = 6;
        let ghost = 3;

        let buffer = 2 * (ghost / 2); // 2
        let support = (width + 2 * buffer) / 2; // 5

        let element = Element::<1>::uniform(support, 4);

        let indices: Vec<usize> = element.diagonal_int_points(buffer).collect();
        assert_eq!(indices, vec![3, 5, 7]);
    }

    /// Samples `sin(x * h) * exp(y * h)` on the refined grid of a fourth-order
    /// element and returns the absolute wavelet error of that sample.
    fn wavelet_error(h: f64) -> f64 {
        let element = Element::<2>::uniform(6, 4);

        let space = element.space_refined();

        let mut values = vec![0.0; element.support_refined()];
        let mut coefs = vec![0.0; element.support_refined()];

        for index in space.iter() {
            let [x, y] = element.position_refined(index);
            let point = space.linear_from_cartesian(index);
            values[point] = (x * h).sin() * (y * h).exp();
        }

        element.wavelet(&values, &mut coefs);
        element.wavelet_abs_error(&coefs)
    }

    /// Each halving of `h` must reduce the wavelet error by a factor of at
    /// least 32.
    #[test]
    fn convergence() {
        let errors = [0.1, 0.05, 0.025, 0.0125].map(wavelet_error);

        for pair in errors.windows(2) {
            let ratio = pair[0] / pair[1];
            assert!(
                ratio >= 32.,
                "error ratio {} fell below the expected factor of 32",
                ratio
            );
        }
    }
}