// aeon_tk/mesh/evaluate.rs

1use std::convert::Infallible;
2use std::{array, ops::Range};
3
4use crate::geometry::{BlockId, Face, FaceMask, IndexSpace};
5use crate::image::ImageShared;
6use crate::kernel::{is_boundary_compatible, Derivative, Dissipation, Kernel, SecondDerivative, SystemBoundaryConds};
7use crate::{
8    kernel::{
9        BoundaryConds as _, BoundaryKind, Hessian, NodeSpace, node_from_vertex, vertex_from_node,
10    },
11};
12use reborrow::{Reborrow, ReborrowMut as _};
13
14use crate::{
15    mesh::{Engine, Function, Projection},
16    shared::SharedSlice,
17    image::{ImageRef, ImageMut},
18};
19
20use super::{Mesh, MeshStore};
21
22/// A finite difference engine of a given order, but potentially bordering a free boundary.
23struct FdEngine<'store, const N: usize, const ORDER: usize> {
24    space: NodeSpace<N>,
25    store: &'store MeshStore,
26    range: Range<usize>,
27}
28
29impl<'store, const N: usize, const ORDER: usize> FdEngine<'store, N, ORDER> {
30    fn evaluate_axis(
31        &self,
32        field: &[f64],
33        axis: usize,
34        kernel: impl Kernel,
35        vertex: [usize; N],
36    ) -> f64 {
37        self.space
38            .evaluate_axis(kernel, node_from_vertex(vertex), field, axis)
39    }
40}
41
42impl<'store, const N: usize, const ORDER: usize> Engine<N> for FdEngine<'store, N, ORDER> {
43    fn space(&self) -> &NodeSpace<N> {
44        &self.space
45    }
46
47    fn node_range(&self) -> Range<usize> {
48        self.range.clone()
49    }
50
51    fn alloc<T: Default>(&self, len: usize) -> &mut [T] {
52        self.store.scratch(len)
53    }
54
55    fn value(&self, field: &[f64], vertex: [usize; N]) -> f64 {
56        let index = self.space.index_from_vertex(vertex);
57        field[index]
58    }
59
60    fn derivative(&self, field: &[f64], axis: usize, vertex: [usize; N]) -> f64 {
61        self.evaluate_axis(field, axis, Derivative::<ORDER>, vertex)
62    }
63
64    fn second_derivative(&self, field: &[f64], axis: usize, vertex: [usize; N]) -> f64 {
65        self.evaluate_axis(field, axis, SecondDerivative::<ORDER>, vertex)
66    }
67
68    fn mixed_derivative(&self, field: &[f64], i: usize, j: usize, vertex: [usize; N]) -> f64 {
69        self.space
70            .evaluate(Hessian::<ORDER>::new(i, j), node_from_vertex(vertex), field)
71    }
72
73    fn dissipation(&self, field: &[f64], axis: usize, vertex: [usize; N]) -> f64 {
74        self.evaluate_axis(field, axis, Dissipation::<ORDER>, vertex)
75    }
76}
77
78/// A finite difference engine that only every relies on interior support (and can thus use better optimized stencils).
79struct FdIntEngine<'store, const N: usize, const ORDER: usize> {
80    space: NodeSpace<N>,
81    store: &'store MeshStore,
82    range: Range<usize>,
83}
84
85impl<'store, const N: usize, const ORDER: usize> FdIntEngine<'store, N, ORDER> {
86    fn evaluate(&self, field: &[f64], axis: usize, kernel: impl Kernel, vertex: [usize; N]) -> f64 {
87        self.space
88            .evaluate_axis_interior(kernel, node_from_vertex(vertex), field, axis)
89    }
90}
91
92impl<'store, const N: usize, const ORDER: usize> Engine<N> for FdIntEngine<'store, N, ORDER> {
93    fn space(&self) -> &NodeSpace<N> {
94        &self.space
95    }
96
97    fn node_range(&self) -> Range<usize> {
98        self.range.clone()
99    }
100
101    fn alloc<T: Default>(&self, len: usize) -> &mut [T] {
102        self.store.scratch(len)
103    }
104
105    fn value(&self, field: &[f64], vertex: [usize; N]) -> f64 {
106        let index = self.space.index_from_vertex(vertex);
107        field[index]
108    }
109
110    fn derivative(&self, field: &[f64], axis: usize, vertex: [usize; N]) -> f64 {
111        self.evaluate(field, axis, Derivative::<ORDER>, vertex)
112    }
113
114    fn second_derivative(&self, field: &[f64], axis: usize, vertex: [usize; N]) -> f64 {
115        self.evaluate(field, axis, SecondDerivative::<ORDER>, vertex)
116    }
117
118    fn mixed_derivative(&self, field: &[f64], i: usize, j: usize, vertex: [usize; N]) -> f64 {
119        self.space
120            .evaluate_interior(Hessian::<ORDER>::new(i, j), node_from_vertex(vertex), field)
121    }
122
123    fn dissipation(&self, field: &[f64], axis: usize, vertex: [usize; N]) -> f64 {
124        self.evaluate(field, axis, Dissipation::<ORDER>, vertex)
125    }
126}
127
128/// Transforms a projection into a function.
129#[derive(Clone)]
130struct ProjectionAsFunction<P>(P);
131
132impl<const N: usize, P: Projection<N>> Function<N> for ProjectionAsFunction<P> {
133    type Error = Infallible;
134
135    fn evaluate(
136        &self,
137        engine: impl Engine<N>,
138        _input: ImageRef,
139        mut output: ImageMut,
140    ) -> Result<(), Infallible> {
141        let dest = output.channel_mut(0);
142
143        for vertex in IndexSpace::new(engine.vertex_size()).iter() {
144            let index = engine.index_from_vertex(vertex);
145            dest[index] = self.0.project(engine.position(vertex))
146        }
147
148        Ok(())
149    }
150}
151
152impl<const N: usize> Mesh<N> {
153    /// Applies the projection to `source`, and stores the result in `dest`.
154    pub fn evaluate<P: Function<N> + Sync>(
155        &mut self,
156        order: usize,
157        function: P,
158        source: ImageRef,
159        dest: ImageMut,
160    ) -> Result<(), P::Error>
161    where
162        P::Error: Send,
163    {
164        assert!(dest.num_nodes() == source.num_nodes() || source.num_channels() == 0);
165        assert_eq!(dest.num_nodes(), self.num_nodes());
166
167        // Make sure order is valid.
168        assert!(matches!(order, 2 | 4 | 6));
169
170        let dest = ImageShared::from(dest);
171
172        self.try_block_compute(|mesh, store, block| {
173            let space = mesh.block_space(block);
174            let nodes = mesh.block_nodes(block);
175
176            let block_source = source.slice(nodes.clone());
177            let block_dest = unsafe { dest.slice_mut(nodes.clone()) };
178
179            if mesh.is_block_in_interior(block) {
180                macro_rules! evaluate_int {
181                    ($order:literal) => {
182                        function.evaluate(
183                            FdIntEngine::<N, $order> {
184                                space: space.clone(),
185                                store,
186                                range: nodes.clone(),
187                            },
188                            block_source,
189                            block_dest,
190                        )
191                    };
192                }
193
194                match order {
195                    2 => evaluate_int!(2),
196                    4 => evaluate_int!(4),
197                    6 => evaluate_int!(6),
198                    _ => unreachable!(),
199                }
200            } else {
201                macro_rules! evaluate {
202                    ($order:literal) => {
203                        function.evaluate(
204                            FdEngine::<N, $order> {
205                                space: space.clone(),
206                                store,
207                                range: nodes.clone(),
208                            },
209                            block_source,
210                            block_dest,
211                        )
212                    };
213                }
214
215                match order {
216                    2 => evaluate!(2),
217                    4 => evaluate!(4),
218                    6 => evaluate!(6),
219                    _ => unreachable!(),
220                }
221            }
222        })
223    }
224
225    /// Checks if the all neighbors of the block have strongly enforced boundary conditions 
226    /// (and thus can use centered stencils).
227    pub fn is_block_in_interior(&self, block: BlockId) -> bool {
228        let boundary = self.block_boundary_classes(block);
229
230        let mut result = true;
231
232        for axis in 0..N {
233            result &= boundary[Face::negative(axis)].has_ghost();
234            result &= boundary[Face::positive(axis)].has_ghost();
235        }
236
237        result
238    }
239
240    /// Evaluates the given function on a system in place.
241    fn evaluate_mut<
242        const ORDER: usize,
243        P: Function<N> + Sync,
244    >(
245        &mut self,
246        function: P,
247        dest: ImageMut,
248    ) -> Result<(), P::Error>
249    where
250        P::Error: Send,
251    {
252        let dest = ImageShared::from(dest);
253
254        self.try_block_compute(|mesh, store, block| {
255            let space = mesh.block_space(block);
256            let nodes = mesh.block_nodes(block);
257
258            let block_dest = unsafe { dest.slice_mut(nodes.clone()) };
259            let mut block_source =
260                ImageMut::from_storage(store.scratch(block_dest.num_nodes() * block_dest.num_channels()), block_dest.num_channels());
261
262            for field in dest.channels() {
263                block_source
264                    .channel_mut(field)
265                    .copy_from_slice(block_dest.channel(field));
266            }
267
268            if mesh.is_block_in_interior(block) {
269                let engine = FdIntEngine::<N, ORDER> {
270                    space: space.clone(),
271                    store,
272                    range: nodes.clone(),
273                };
274
275                function.evaluate(engine, block_source.rb(), block_dest)
276            } else {
277                let engine = FdEngine::<N, ORDER> {
278                    space: space.clone(),
279                    store,
280                    range: nodes.clone(),
281                };
282
283                function.evaluate(engine, block_source.rb(), block_dest)
284            }
285        })
286    }
287
288    /// Applies an operator to a system in place, enforcing both strong and weak boundary conditions
289    /// and running necessary preprocessing.
290    pub fn apply<
291        C: SystemBoundaryConds<N> + Sync,
292        P: Function<N> + Sync,
293    >(
294        &mut self,
295        order: usize,
296        bcs: C,
297        mut op: P,
298        mut f: ImageMut<'_>,
299    ) -> Result<(), P::Error>
300    where
301        P::Error: Send,
302    {
303        assert_eq!(f.num_nodes(), self.num_nodes());
304
305        for field in f.channels() {
306            assert!(
307                is_boundary_compatible(&self.boundary, &bcs.field(field)),
308                "Boundary Conditions incompatible with set boundary classes"
309            )
310        }
311
312        // Strong boundary condition
313        self.fill_boundary(order, bcs.clone(), f.rb_mut());
314        // Preprocess data
315        op.preprocess(self, f.rb_mut())?;
316
317        let f: ImageShared = f.into();
318
319        self.try_block_compute(|mesh, store, block| {
320            let space = mesh.block_space(block);
321            let nodes = mesh.block_nodes(block);
322            let bcs = mesh.block_bcs(block, bcs.clone());
323
324            let mut block_dest = unsafe { f.slice_mut(nodes.clone()) };
325
326            let mut block_source =
327                ImageMut::from_storage(store.scratch(block_dest.num_nodes() * block_dest.num_channels()), block_dest.num_channels());
328
329            for field in f.channels() {
330                block_source
331                    .channel_mut(field)
332                    .copy_from_slice(block_dest.channel(field));
333            }
334
335            if mesh.is_block_in_interior(block) {
336                macro_rules! evaluate_int {
337                    ($order:literal) => {
338                        op.evaluate(
339                            FdIntEngine::<N, $order> {
340                                space: space.clone(),
341                                store,
342                                range: nodes.clone(),
343                            },
344                            block_source.rb(),
345                            block_dest.rb_mut(),
346                        )
347                    };
348                }
349
350                match order {
351                    2 => evaluate_int!(2),
352                    4 => evaluate_int!(4),
353                    6 => evaluate_int!(6),
354                    _ => unreachable!(),
355                }
356            } else {
357                macro_rules! evaluate {
358                    ($order:literal) => {
359                        op.evaluate(
360                            FdEngine::<N, $order> {
361                                space: space.clone(),
362                                store,
363                                range: nodes.clone(),
364                            },
365                            block_source.rb(),
366                            block_dest.rb_mut(),
367                        )?
368                    };
369                }
370
371                match order {
372                    2 => evaluate!(2),
373                    4 => evaluate!(4),
374                    6 => evaluate!(6),
375                    _ => unreachable!(),
376                }
377
378                // Weak boundary conditions.
379                for face in Face::<N>::iterate() {
380                    for field in f.channels() {
381                        let boundary = bcs.field(field);
382                        let source = block_source.channel(field);
383                        let dest = block_dest.channel_mut(field);
384
385                        // Apply weak dirichlet boundary conditions
386                        if boundary.kind(face) == BoundaryKind::WeakDirichlet {
387                            for node in space.face_window_disjoint(face) {
388                                let index = space.index_from_node(node);
389                                let position = space.position(node);
390                                let dirichlet = boundary.dirichlet(position);
391                                dest[index] =
392                                    dirichlet.strength * (dirichlet.target - source[index])
393                            }
394                        }
395
396                        // Apply radiative condition
397                        if boundary.kind(face) != BoundaryKind::Radiative {
398                            continue;
399                        }
400
401                        // Sommerfeld radiative boundary conditions.
402                        for node in space.face_window(face) {
403                            let vertex = vertex_from_node(node);
404                            // *************************
405                            // At vertex
406
407                            let position: [f64; N] = space.position(node);
408                            let r = position.iter().map(|&v| v * v).sum::<f64>().sqrt();
409                            let index = space.index_from_node(node);
410
411                            macro_rules! inner {
412                                ($order:literal) => {
413                                    // *************************
414                                    // Inner
415
416                                    let engine = FdEngine::<N, $order> {
417                                        space: space.clone(),
418                                        store,
419                                        range: nodes.clone(),
420                                    };
421
422                                    let mut inner = vertex;
423
424                                    // Find innter vertex for approximating higher order r dependence
425                                    for axis in 0..N {
426                                        if boundary.kind(Face::negative(axis)) == BoundaryKind::Radiative
427                                            && vertex[axis] == 0
428                                        {
429                                            inner[axis] += 1;
430                                        }
431                                    
432                                        if boundary.kind(Face::positive(axis)) == BoundaryKind::Radiative
433                                            && vertex[axis] == engine.vertex_size()[axis] - 1
434                                        {
435                                            inner[axis] -= 1;
436                                        }
437                                    }
438                                
439                                    let inner_position = engine.position(inner);
440                                    let inner_r = inner_position.iter().map(|&v| v * v).sum::<f64>().sqrt();
441                                    let inner_index = engine.index_from_vertex(inner);
442                                
443                                    // Get condition parameters.
444                                    let params = boundary.radiative(position);
445                                    // Inner R dependence.
446                                    let mut inner_advection = source[inner_index] - params.target;
447                                
448                                    for axis in 0..N {
449                                        let derivative = engine.derivative(source, axis, inner);
450                                        inner_advection += inner_position[axis] * derivative;
451                                    }
452                                
453                                    inner_advection *= params.speed;
454                                
455                                    let k = inner_r
456                                        * inner_r
457                                        * inner_r
458                                        * (dest[inner_index] + inner_advection / inner_r);
459                                
460                                    // Vertex
461                                    let mut advection = source[index] - params.target;
462                                
463                                    for axis in 0..N {
464                                        let derivative = engine.derivative(source, axis, vertex);
465                                        advection += position[axis] * derivative;
466                                    }
467                                
468                                    advection *= params.speed;
469                                    dest[index] = -advection / r + k / (r * r * r);
470                                };
471                            }
472
473                            match order {
474                                2 => { inner!(2); },
475                                4 => { inner!(4); },
476                                6 => { inner!(6); },
477                                _ => unimplemented!("Order unimplemented")
478                            }
479                        }
480                    }
481                }
482
483                Ok(())
484            }
485        })
486    }
487
488    /// Copies an immutable src slice into a mutable dest slice.
489    pub fn copy_from_slice(&mut self, mut dest: ImageMut, src: ImageRef) {
490        assert_eq!(dest.num_nodes(), src.num_nodes());
491
492        for label in dest.channels() {
493            dest.channel_mut(label).copy_from_slice(src.channel(label));
494        }
495    }
496
497    /// Applies a projection and stores the result in the destination vector.
498    pub fn project<P: Projection<N> + Sync>(
499        &mut self,
500        order: usize,
501        projection: P,
502        dest: &mut [f64],
503    ) {
504        assert_eq!(dest.len(), self.num_nodes());
505        self.evaluate(
506            order,
507            ProjectionAsFunction(projection),
508            ImageRef::empty(),
509            ImageMut::from(dest),
510        )
511        .unwrap();
512    }
513
514    /// Applies the projection to `source`, and stores the result in `dest`.
515    pub fn dissipation<const ORDER: usize>(
516        &mut self,
517        amplitude: f64,
518        mut dest: ImageMut,
519    ) {
520        assert_eq!(dest.num_nodes(), self.num_nodes());
521
522        #[derive(Clone)]
523        struct Dissipation(f64);
524
525        impl<const N: usize> Function<N> for Dissipation {
526            type Error = Infallible;
527
528            fn evaluate(
529                &self,
530                engine: impl Engine<N>,
531                input: ImageRef,
532                mut output: ImageMut,
533            ) -> Result<(), Infallible> {
534                let input = input.channel(0);
535                let output = output.channel_mut(0);
536
537                for vertex in IndexSpace::new(engine.vertex_size()).iter() {
538                    let index = engine.index_from_vertex(vertex);
539
540                    for axis in 0..N {
541                        output[index] += self.0 * engine.dissipation(input, axis, vertex);
542                    }
543                }
544
545                Ok(())
546            }
547        }
548
549        for field in dest.channels() {
550            self.evaluate_mut::<ORDER, _>(
551                Dissipation(amplitude),
552                ImageMut::from(dest.channel_mut(field)),
553            )
554            .unwrap();
555        }
556    }
557
558    /// This function computes the distance between each vertex and its nearest
559    /// neighbor.
560    pub fn spacing_per_vertex(&mut self, dest: &mut [f64]) {
561        assert!(dest.len() == self.num_nodes());
562
563        let dest = SharedSlice::new(dest);
564
565        self.block_compute(|mesh, _, block| {
566            let nodes = mesh.block_nodes(block);
567            let space = mesh.block_space(block);
568
569            let spacing = space.spacing();
570            let min_spacing = spacing
571                .iter()
572                .min_by(|a, b| a.total_cmp(&b))
573                .cloned()
574                .unwrap_or(1.0);
575
576            let vertex_size = space.vertex_size();
577
578            let block_dest = unsafe { dest.slice_mut(nodes) };
579
580            for &cell in mesh.blocks.active_cells(block) {
581                let node_size = mesh.cell_node_size(cell);
582                let node_origin = mesh.active_node_origin(cell);
583
584                let mut flags = FaceMask::empty();
585
586                for face in Face::iterate() {
587                    let Some(neighbor) = mesh
588                        .tree
589                        .neighbor(mesh.tree.cell_from_active_index(cell), face)
590                    else {
591                        continue;
592                    };
593                    // If neighbors have larger refinement than us
594                    if !mesh.tree.is_active(neighbor) {
595                        flags.set(face);
596                    }
597                }
598
599                for offset in IndexSpace::new(node_size).iter() {
600                    let vertex: [_; N] = array::from_fn(|axis| node_origin[axis] + offset[axis]);
601                    let index = space.index_from_vertex(vertex);
602
603                    let mut refined = false;
604
605                    for axis in 0..N {
606                        refined |= vertex[axis] == 0 && flags.is_set(Face::negative(axis));
607                        refined |= vertex[axis] == vertex_size[axis] - 1
608                            && flags.is_set(Face::positive(axis));
609                    }
610
611                    if refined {
612                        block_dest[index] = min_spacing / 2.0;
613                    } else {
614                        block_dest[index] = min_spacing;
615                    }
616                }
617            }
618        });
619    }
620
621    pub fn adaptive_cfl(
622        &mut self,
623        spacing_per_vertex: &[f64],
624        dest: ImageMut,
625    ) {
626        let dest = ImageShared::from(dest);
627        let min_spacing = self.min_spacing();
628
629        self.block_compute(|mesh, _, block| {
630            let block_space = mesh.block_space(block);
631            let block_nodes = mesh.block_nodes(block);
632
633            let block_spacings = &spacing_per_vertex[block_nodes.clone()];
634            let mut block_dest = unsafe { dest.slice_mut(block_nodes.clone()) };
635
636            for field in block_dest.channels() {
637                let block_dest = block_dest.channel_mut(field);
638
639                for vertex in IndexSpace::new(block_space.vertex_size()).iter() {
640                    let index = block_space.index_from_vertex(vertex);
641                    block_dest[index] *= block_spacings[index] / min_spacing;
642                }
643            }
644        });
645    }
646}