1use crate::{
2 element::{ApproxOperator, Monomials, Uniform, Values},
3 geometry::{IndexSpace, Split},
4};
5use faer::{ColRef, Mat};
6use std::array;
7
/// A tensor-product element over `N` axes, storing a coarse grid, a refined grid
/// (coarse nodes plus cell midpoints) and the 1D stencils used to prolong coarse
/// nodal values onto the refined grid.
///
/// As built by [`Element::uniform`], both grids span `[-1, 1]` along every axis.
#[derive(Clone, Debug)]
pub struct Element<const N: usize> {
    /// Order handed to the approximation basis (`order + 1` monomials) when the
    /// stencils are built; `uniform` requires `width >= order`.
    order: usize,
    /// Number of coarse cells along each axis (`width + 1` coarse nodes per axis).
    width: usize,
    /// Coordinates of the `width + 1` coarse nodes along a single axis.
    grid: Vec<f64>,
    /// Coordinates of the `2 * width + 1` refined points along a single axis;
    /// even entries coincide with `grid`, odd entries are cell midpoints.
    grid_refined: Vec<f64>,
    /// `(width + 1) x width` weight matrix: column `j` holds the weights applied to
    /// the coarse nodes along one axis to produce the value at the midpoint of cell `j`.
    stencils: Mat<f64>,
}
26
27impl<const N: usize> Element<N> {
28 pub fn uniform(width: usize, order: usize) -> Self {
31 assert!(width >= order);
32 let (grid, grid_refined) = Self::uniform_grid(width);
33 debug_assert!(grid.len() == width + 1);
34 debug_assert!(grid_refined.len() == 2 * width + 1);
35
36 let spacing = 2.0 / width as f64;
37 let points = Vec::from_iter(
38 (0..width)
39 .into_iter()
40 .map(|j| [-1.0 + spacing * (j as f64 + 0.5)]),
41 );
42
43 let mut approx = ApproxOperator::default();
44 approx
45 .build(
46 &Uniform::new([width + 1]),
47 &Monomials::new([order + 1]),
48 &Values(points.as_slice()),
49 )
50 .unwrap();
51
52 let stencils = approx.shape().to_owned();
53
54 debug_assert!(stencils.nrows() == width + 1);
55 debug_assert!(stencils.ncols() == width);
56
57 Self {
58 width,
59 order,
60 grid,
61 grid_refined,
62 stencils,
63 }
64 }
65
66 fn uniform_grid(width: usize) -> (Vec<f64>, Vec<f64>) {
67 let spacing = 2.0 / width as f64;
68
69 let grid = (0..=width)
70 .map(|i| i as f64 * spacing - 1.0)
71 .collect::<Vec<_>>();
72
73 let grid_refined = (0..=2 * width)
74 .map(|i| i as f64 * spacing / 2.0 - 1.0)
75 .collect::<Vec<_>>();
76
77 (grid, grid_refined)
78 }
79
80 pub fn support(&self) -> usize {
82 (self.width + 1).pow(N as u32)
83 }
84
85 pub fn support_refined(&self) -> usize {
87 (2 * self.width + 1).pow(N as u32)
88 }
89
90 pub fn order(&self) -> usize {
92 self.order
93 }
94
95 pub fn width(&self) -> usize {
97 self.width
98 }
99
100 pub fn grid(&self) -> &[f64] {
102 &self.grid
103 }
104
105 pub fn prolong_stencil(&self, target: usize) -> ColRef<'_, f64> {
106 self.stencils.col(target)
107 }
108
109 pub fn position(&self, index: [usize; N]) -> [f64; N] {
111 array::from_fn(|axis| self.grid[index[axis]])
112 }
113
114 pub fn position_refined(&self, index: [usize; N]) -> [f64; N] {
115 array::from_fn(|axis| self.grid_refined[index[axis]])
116 }
117
118 pub fn space(&self) -> IndexSpace<N> {
123 IndexSpace::new([self.width + 1; N])
124 }
125
126 pub fn space_refined(&self) -> IndexSpace<N> {
127 IndexSpace::new([2 * self.width + 1; N])
128 }
129
130 pub fn nodal_indices(&self) -> impl Iterator<Item = [usize; N]> {
132 IndexSpace::new([self.width + 1; N])
133 .iter()
134 .map(|v| array::from_fn(|axis| v[axis] * 2))
135 }
136
137 pub fn diagonal_indices(&self) -> impl Iterator<Item = [usize; N]> {
139 IndexSpace::new([self.width; N])
140 .iter()
141 .map(|v| array::from_fn(|axis| v[axis] * 2 + 1))
142 }
143
144 pub fn diagonal_int_indices(&self, buffer: usize) -> impl Iterator<Item = [usize; N]> {
147 debug_assert!(buffer % 2 == 0);
148
149 IndexSpace::new([self.width - buffer; N])
150 .iter()
151 .map(move |v| array::from_fn(|axis| 2 * v[axis] + 1 + buffer))
152 }
153
154 pub fn detail_indices(&self) -> impl Iterator<Item = [usize; N]> {
156 let cells = IndexSpace::new([self.width; N]).iter();
157
158 cells.flat_map(|index| {
159 Split::<N>::enumerate().skip(1).map(move |mask| {
160 let mut point = index;
161
162 for axis in 0..N {
163 point[axis] *= 2;
164
165 if mask.is_set(axis) {
166 point[axis] += 1;
167 }
168 }
169
170 point
171 })
172 })
173 }
174
175 pub fn nodal_points(&self) -> impl Iterator<Item = usize> {
176 let space = IndexSpace::new([2 * self.width + 1; N]);
177 self.nodal_indices()
178 .map(move |index| space.linear_from_cartesian(index))
179 }
180
181 pub fn diagonal_points(&self) -> impl Iterator<Item = usize> {
182 let space = IndexSpace::new([2 * self.width + 1; N]);
183 self.diagonal_indices()
184 .map(move |index| space.linear_from_cartesian(index))
185 }
186
187 pub fn diagonal_int_points(&self, buffer: usize) -> impl Iterator<Item = usize> {
188 let space = IndexSpace::new([2 * self.width + 1; N]);
189 self.diagonal_int_indices(buffer)
190 .map(move |index| space.linear_from_cartesian(index))
191 }
192
193 pub fn detail_points(&self) -> impl Iterator<Item = usize> {
194 let space = IndexSpace::new([2 * self.width + 1; N]);
195 self.detail_indices()
196 .map(move |index| space.linear_from_cartesian(index))
197 }
198
199 pub fn prolong(&self, source: &[f64], dest: &mut [f64]) {
205 self.inject(source, dest);
206 self.prolong_in_place(dest);
207 }
208
209 pub fn prolong_in_place(&self, dest: &mut [f64]) {
212 debug_assert!(dest.len() == self.support_refined());
213
214 let space = IndexSpace::new([2 * self.width + 1; N]);
215
216 for axis in (0..N).rev() {
218 let mut psize = [0; N];
219
220 for i in 0..axis {
221 psize[i] = self.width + 1;
222 }
223 psize[axis] = self.width;
224 for i in (axis + 1)..N {
225 psize[i] = 2 * self.width + 1;
226 }
227
228 for mut point in IndexSpace::new(psize).iter() {
229 for i in 0..axis {
230 point[i] *= 2;
231 }
232
233 let stencil = self.stencils.col(point[axis]);
234
235 point[axis] *= 2;
236 point[axis] += 1;
237
238 let center = space.linear_from_cartesian(point);
239 dest[center] = 0.0;
240
241 for i in 0..=self.width {
242 point[axis] = 2 * i;
243 dest[center] += stencil[i] * dest[space.linear_from_cartesian(point)];
244 }
245 }
246 }
247 }
248
249 pub fn inject(&self, source: &[f64], dest: &mut [f64]) {
255 debug_assert!(source.len() == self.support());
256 debug_assert!(dest.len() == self.support_refined());
257
258 let source_space = IndexSpace::new([self.width + 1; N]);
260 let dest_space = IndexSpace::new([2 * self.width + 1; N]);
261
262 for (pindex, point) in source_space.iter().enumerate() {
263 let refined: [_; N] = array::from_fn(|axis| 2 * point[axis]);
264 let rindex = dest_space.linear_from_cartesian(refined);
265 dest[rindex] = source[pindex];
266 }
267 }
268
269 pub fn restrict(&self, source: &[f64], dest: &mut [f64]) {
271 debug_assert!(source.len() == self.support_refined());
272 debug_assert!(dest.len() == self.support());
273
274 let source_space = IndexSpace::new([2 * self.width + 1; N]);
275 let dest_space = IndexSpace::new([self.width + 1; N]);
276
277 for (pindex, point) in source_space.iter().enumerate() {
278 let refined: [_; N] = array::from_fn(|axis| point[axis] / 2);
279 let rindex = dest_space.linear_from_cartesian(refined);
280 dest[rindex] = source[pindex];
281 }
282 }
283
284 pub fn wavelet(&self, source: &[f64], dest: &mut [f64]) {
290 debug_assert!(source.len() == self.support_refined());
291 debug_assert!(dest.len() == self.support_refined());
292
293 for point in self.nodal_points() {
295 dest[point] = source[point];
296 }
297
298 self.prolong_in_place(dest);
299
300 for point in self.detail_points() {
302 dest[point] -= source[point];
303 }
304 }
305
306 pub fn wavelet_rel_error(&self, coefs: &[f64]) -> f64 {
309 let scale = self
310 .nodal_points()
311 .map(|v| coefs[v].abs())
312 .max_by(|a, b| a.total_cmp(b))
313 .unwrap();
314
315 self.wavelet_abs_error(coefs) / scale
316 }
317
318 pub fn wavelet_abs_error(&self, coefs: &[f64]) -> f64 {
321 self.diagonal_points()
322 .map(|v| coefs[v].abs())
323 .max_by(|a, b| a.total_cmp(b))
324 .unwrap()
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::Element;

    #[test]
    fn iteration() {
        let element = Element::<2>::uniform(2, 1);

        assert_eq!(
            element.nodal_indices().collect::<Vec<_>>(),
            [
                [0, 0],
                [2, 0],
                [4, 0],
                [0, 2],
                [2, 2],
                [4, 2],
                [0, 4],
                [2, 4],
                [4, 4],
            ]
        );

        assert_eq!(
            element.diagonal_indices().collect::<Vec<_>>(),
            [[1, 1], [3, 1], [1, 3], [3, 3]]
        );

        // With no buffer, the interior diagonal points are all diagonal points.
        assert_eq!(
            element.diagonal_int_indices(0).collect::<Vec<_>>(),
            [[1, 1], [3, 1], [1, 3], [3, 3]]
        );

        // Detail points are grouped per coarse cell (first axis fastest), three per cell.
        assert_eq!(
            element.detail_indices().collect::<Vec<_>>(),
            [
                [1, 0],
                [0, 1],
                [1, 1],
                [3, 0],
                [2, 1],
                [3, 1],
                [1, 2],
                [0, 3],
                [1, 3],
                [3, 2],
                [2, 3],
                [3, 3],
            ]
        );
    }

    #[test]
    fn interior_indices() {
        let element = Element::<1>::uniform(6, 4);
        assert_eq!(
            element.diagonal_int_points(2).collect::<Vec<_>>(),
            [3, 5, 7, 9]
        );

        let width = 6;
        let ghost = 3;

        // Round the ghost count down to an even buffer, then size the element
        // (in coarse cells) as (width + 2 * buffer) / 2.
        let buffer = 2 * (ghost / 2);
        let support = (width + 2 * buffer) / 2;
        let element = Element::<1>::uniform(support, 4);

        assert_eq!(
            element.diagonal_int_points(buffer).collect::<Vec<_>>(),
            [3, 5, 7]
        );
    }

    /// Largest cell-centre wavelet coefficient of `sin(h x) * exp(h y)` sampled on a
    /// 2D element with 6 cells per axis and order 4.
    fn wavelet_error(h: f64) -> f64 {
        let element = Element::<2>::uniform(6, 4);
        let space = element.space_refined();

        let mut values = vec![0.0; element.support_refined()];
        for index in space.iter() {
            let [x, y] = element.position_refined(index);
            values[space.linear_from_cartesian(index)] = (x * h).sin() * (y * h).exp();
        }

        let mut coefs = vec![0.0; element.support_refined()];
        element.wavelet(&values, &mut coefs);
        element.wavelet_abs_error(&coefs)
    }

    #[test]
    fn convergence() {
        // Each halving of h must shrink the error by a factor of at least 32.
        let errors = [0.1, 0.05, 0.025, 0.0125].map(wavelet_error);

        for pair in errors.windows(2) {
            let ratio = pair[0] / pair[1];
            assert!(
                ratio >= 32.0,
                "error ratio {ratio} is below 32 (errors: {errors:?})"
            );
        }
    }
}