Skip to main content

aeon_tk/solver/
hyper.rs

1use std::convert::Infallible;
2
3use crate::IRef;
4use crate::geometry::{Face, IndexSpace};
5use crate::image::{ImageMut, ImageRef};
6use crate::kernel::{BoundaryKind, DirichletParams, RadiativeParams, SystemBoundaryConds};
7use crate::mesh::FunctionBorrowMut;
8use datasize::DataSize;
9use reborrow::{Reborrow, ReborrowMut};
10use thiserror::Error;
11
12use crate::{
13    mesh::{Engine, Function, Mesh},
14    solver::{Integrator, Method},
15};
16
17use super::SolverCallback;
18
19/// Error which may be thrown during hyperbolic relaxation.
20#[derive(Error, Debug)]
21pub enum HyperRelaxError<A, B> {
22    #[error("failed to relax below tolerance in allotted number of steps")]
23    ReachedMaxSteps,
24    #[error("norm diverged to NaN")]
25    NormDiverged,
26    #[error("function error")]
27    FunctionFailed(A),
28    #[error("callback error")]
29    CallbackFailed(B),
30}
31
32/// A solver which implements the algorithm described in NRPyElliptic. This transforms the elliptic equation
33/// 𝓛{u} = p, into the hyperbolic equation ∂ₜ²u + η∂ₜu = c² (𝓛{u} - p), where c is the speed of the wave, and η is
34/// a dampening term that speeds up convergence.
35#[derive(Clone, Debug, DataSize)]
36pub struct HyperRelaxSolver {
37    /// Error tolerance (relaxation stops once error goes below this value).
38    pub tolerance: f64,
39    /// Maximum number of relaxation steps to perform
40    pub max_steps: usize,
41    /// Dampening term η.
42    pub dampening: f64,
43    /// CFL factor for ficticuous time step.
44    pub cfl: f64,
45    /// If set, the relax solver uses larger time steps for
46    /// vertices in less refined regions (subject to the CFL condition
47    /// of course).
48    pub adaptive: bool,
49
50    integrator: Integrator,
51}
52
53impl Default for HyperRelaxSolver {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl HyperRelaxSolver {
60    /// Constructs a new `HyperRelaxSolver` with default settings.
61    pub fn new() -> Self {
62        Self {
63            tolerance: 1e-5,
64            max_steps: 100000,
65            dampening: 1.0,
66            cfl: 0.1,
67            adaptive: false,
68            // visualize: None,
69            integrator: Integrator::new(Method::RK4),
70        }
71    }
72
73    /// Solves a given elliptic system
74    pub fn solve<const N: usize, C: SystemBoundaryConds<N> + Sync, F: Function<N> + Sync>(
75        &mut self,
76        mesh: &mut Mesh<N>,
77        order: usize,
78        conditions: C,
79        deriv: F,
80        result: ImageMut,
81    ) -> Result<(), HyperRelaxError<F::Error, Infallible>>
82    where
83        F::Error: Send,
84    {
85        self.solve_with_callback(mesh, order, conditions, (), deriv, result)
86    }
87
88    pub fn solve_with_callback<
89        const N: usize,
90        C: SystemBoundaryConds<N> + Sync,
91        F: Function<N> + Sync,
92        Call: SolverCallback<N>,
93    >(
94        &mut self,
95        mesh: &mut Mesh<N>,
96        order: usize,
97        conditions: C,
98        mut callback: Call,
99        mut deriv: F,
100        mut result: ImageMut,
101    ) -> Result<(), HyperRelaxError<F::Error, Call::Error>>
102    where
103        F::Error: Send,
104        Call::Error: Send,
105    {
106        assert_eq!(result.num_nodes(), mesh.num_nodes());
107        // Total number of degreees of freedom in the whole system
108        let dimension = result.num_channels() * mesh.num_nodes();
109        let num_channels = result.num_channels();
110
111        // assert!(result.len() == dimension);
112
113        // Allocate storage
114        let mut data = vec![0.0; 2 * dimension].into_boxed_slice();
115        // Compute minimum spacing and spacing per vertex.
116        let min_spacing = mesh.min_spacing();
117
118        let mut spacing_per_vertex = vec![min_spacing; mesh.num_nodes()];
119        if self.adaptive {
120            mesh.spacing_per_vertex(&mut spacing_per_vertex);
121        }
122
123        // Use CFL factor to compute time_step
124        let time_step = self.cfl * min_spacing;
125
126        // Fill initial guess
127        {
128            let (u, v) = data.split_at_mut(dimension);
129            // u is initial guess
130            mesh.copy_from_slice(ImageMut::from_storage(u, num_channels), result.rb());
131            // Let us assume that du/dt is initially zero
132            mesh.copy_from_slice(ImageMut::from_storage(v, num_channels), result.rb());
133            for value in v.iter_mut() {
134                *value *= self.dampening;
135            }
136        }
137
138        for index in 0..self.max_steps {
139            mesh.fill_boundary(
140                order,
141                FicticuousBoundaryConds {
142                    dampening: self.dampening,
143                    conditions: conditions.clone(),
144                    channels: num_channels,
145                },
146                ImageMut::from_storage(&mut data, 2 * num_channels),
147            );
148
149            {
150                let u = ImageRef::from_storage(&data[..dimension], num_channels);
151                mesh.copy_from_slice(result.rb_mut(), u.rb());
152                mesh.apply(
153                    order,
154                    conditions.clone(),
155                    FunctionBorrowMut(&mut deriv),
156                    result.rb_mut(),
157                )
158                .map_err(|err| HyperRelaxError::FunctionFailed(err))?;
159                callback
160                    .callback(mesh, u.rb(), result.rb(), index)
161                    .map_err(|err| HyperRelaxError::CallbackFailed(err))?;
162            }
163
164            let norm = mesh.l2_norm_system(result.rb());
165
166            if !norm.is_finite() || norm >= 1e60 {
167                return Err(HyperRelaxError::NormDiverged);
168            }
169
170            if index % 100 == 0 {
171                log::trace!("Relaxed {}k steps, norm: {:.5e}", index / 100, norm);
172            }
173
174            if norm <= self.tolerance {
175                log::trace!("Converged in {} steps.", index);
176
177                // Copy solution back to system vector
178                mesh.copy_from_slice(
179                    result.rb_mut(),
180                    ImageRef::from_storage(&data[..dimension], num_channels),
181                );
182                mesh.fill_boundary(order, conditions, result.rb_mut());
183
184                return Ok(());
185            }
186
187            self.integrator
188                .step(
189                    mesh,
190                    order,
191                    FicticuousBoundaryConds {
192                        dampening: self.dampening,
193                        conditions: conditions.clone(),
194                        channels: num_channels,
195                    },
196                    FicticuousDerivs {
197                        dampening: self.dampening,
198                        function: &deriv,
199                        spacing_per_vertex: &spacing_per_vertex,
200                        min_spacing,
201                    },
202                    time_step,
203                    ImageMut::from_storage(&mut data, 2 * num_channels),
204                )
205                .map_err(|err| HyperRelaxError::FunctionFailed(err))?;
206
207            if index == self.max_steps - 1 {
208                log::error!(
209                    "Hyperbolic relaxation failed to converge in {} steps.",
210                    self.max_steps
211                );
212            }
213        }
214
215        // Copy solution back to system vector
216        mesh.copy_from_slice(
217            result.rb_mut(),
218            ImageRef::from_storage(&data[..dimension], num_channels),
219        );
220        mesh.fill_boundary(order, conditions, result.rb_mut());
221
222        Err(HyperRelaxError::ReachedMaxSteps)
223    }
224}
225
226#[derive(Clone)]
227struct FicticuousBoundaryConds<C> {
228    dampening: f64,
229    conditions: C,
230    channels: usize,
231}
232
233impl<const N: usize, C: SystemBoundaryConds<N>> SystemBoundaryConds<N>
234    for FicticuousBoundaryConds<C>
235{
236    fn kind(&self, channel: usize, face: Face<N>) -> BoundaryKind {
237        let boundary_kind: BoundaryKind = self.conditions.kind(channel % self.channels, face);
238
239        match boundary_kind {
240            BoundaryKind::Symmetric => BoundaryKind::Symmetric,
241            BoundaryKind::AntiSymmetric => BoundaryKind::AntiSymmetric,
242            BoundaryKind::Custom => BoundaryKind::Custom,
243            BoundaryKind::Radiative => BoundaryKind::Radiative,
244            BoundaryKind::Free => BoundaryKind::Free,
245            BoundaryKind::StrongDirichlet | BoundaryKind::WeakDirichlet => {
246                BoundaryKind::WeakDirichlet
247            }
248        }
249    }
250
251    fn radiative(&self, channel: usize, position: [f64; N]) -> RadiativeParams {
252        let mut result = self.conditions.radiative(channel % self.channels, position);
253        if channel >= self.channels {
254            result.target *= self.dampening;
255        }
256        result
257    }
258
259    fn dirichlet(&self, channel: usize, position: [f64; N]) -> DirichletParams {
260        let mut result = self.conditions.dirichlet(channel % self.channels, position);
261        if channel >= self.channels {
262            result.target *= self.dampening;
263        }
264        result
265    }
266}
267
268#[derive(Clone)]
269struct FicticuousDerivs<'a, const N: usize, F> {
270    dampening: f64,
271    function: &'a F,
272    spacing_per_vertex: &'a [f64],
273    min_spacing: f64,
274}
275
276impl<const N: usize, F: Function<N>> Function<N> for FicticuousDerivs<'_, N, F> {
277    type Error = F::Error;
278
279    fn evaluate(
280        &self,
281        engine: impl Engine<N>,
282        input: ImageRef,
283        mut output: ImageMut,
284    ) -> Result<(), F::Error> {
285        assert_eq!(input.num_channels(), output.num_channels());
286        let num_channels = output.num_channels();
287        let (uin, vin) = input.split_channels(num_channels / 2);
288        let (mut uout, mut vout) = output.rb_mut().split_channels(num_channels / 2);
289
290        // Find du/dt from the definition v = du/dt + η u
291        for field in uin.channels() {
292            let u = uin.channel(field);
293            let v = vin.channel(field);
294
295            let udest = uout.channel_mut(field);
296
297            for vertex in IndexSpace::new(engine.vertex_size()).iter() {
298                let index = engine.index_from_vertex(vertex);
299                udest[index] = v[index] - u[index] * self.dampening;
300            }
301        }
302
303        // dv/dt = c^2 Lu
304        // TODO speed
305        self.function.evaluate(IRef(&engine), uin, vout.rb_mut())?;
306
307        // Use adaptive timestep
308        let block_spacing = &self.spacing_per_vertex[engine.node_range()];
309
310        for field in uout.channels() {
311            let uout = uout.channel_mut(field);
312            for vertex in IndexSpace::new(engine.vertex_size()).iter() {
313                let index = engine.index_from_vertex(vertex);
314                uout[index] *= block_spacing[index] / self.min_spacing;
315            }
316        }
317
318        for field in vout.channels() {
319            let vout = vout.channel_mut(field);
320            for vertex in IndexSpace::new(engine.vertex_size()).iter() {
321                let index = engine.index_from_vertex(vertex);
322                vout[index] *= block_spacing[index] / self.min_spacing;
323            }
324        }
325
326        Ok(())
327    }
328}
329
#[cfg(test)]
mod tests {
    use crate::{
        geometry::{FaceArray, HyperBox},
        kernel::{BoundaryClass, DirichletParams},
    };

    use super::*;
    use crate::{kernel::BoundaryKind, mesh::Projection};
    use std::{convert::Infallible, f64::consts};

    /// Homogeneous (zero-target) strong Dirichlet conditions on every channel and face.
    #[derive(Clone)]
    struct _PoissonConditions;

    impl SystemBoundaryConds<2> for _PoissonConditions {
        fn kind(&self, _channel: usize, _face: Face<2>) -> BoundaryKind {
            BoundaryKind::StrongDirichlet
        }

        fn dirichlet(&self, _channel: usize, _position: [f64; 2]) -> DirichletParams {
            DirichletParams {
                strength: 1.0,
                target: 0.0,
            }
        }
    }

    /// Analytic solution u(x, y) = sin(2πx) sin(2πy) of the test problem.
    #[derive(Clone)]
    pub struct PoissonSolution;

    impl Projection<2> for PoissonSolution {
        fn project(&self, position: [f64; 2]) -> f64 {
            let [x, y] = position;
            let wavenumber = 2.0 * consts::PI;
            (wavenumber * x).sin() * (wavenumber * y).sin()
        }
    }

    /// Residual ∇²u - p of the Poisson problem whose source p = -8π² sin(2πx) sin(2πy)
    /// is the Laplacian of [`PoissonSolution`].
    #[derive(Clone)]
    pub struct _PoissonEquation;

    impl Function<2> for _PoissonEquation {
        type Error = Infallible;

        fn evaluate(
            &self,
            engine: impl Engine<2>,
            input: ImageRef,
            mut output: ImageMut,
        ) -> Result<(), Infallible> {
            let field = input.channel(0);
            let dest = output.channel_mut(0);
            let wavenumber = 2.0 * consts::PI;

            for vertex in IndexSpace::new(engine.vertex_size()).iter() {
                let [x, y] = engine.position(vertex);

                let laplacian = engine.second_derivative(field, 0, vertex)
                    + engine.second_derivative(field, 1, vertex);
                let source = -8.0
                    * consts::PI
                    * consts::PI
                    * (wavenumber * x).sin()
                    * (wavenumber * y).sin();

                dest[engine.index_from_vertex(vertex)] = laplacian - source;
            }

            Ok(())
        }
    }

    #[test]
    fn poisson() {
        let mut mesh = Mesh::new(
            HyperBox::from_aabb([0.0, 0.0], [1.0, 1.0]),
            4,
            2,
            FaceArray::splat(BoundaryClass::Ghost),
        );
        // Two rounds of uniform refinement.
        for _ in 0..2 {
            mesh.refine_global();
        }

        // Project the analytic solution onto the mesh.
        let mut solution = vec![0.0; mesh.num_nodes()];
        mesh.project(4, PoissonSolution, &mut solution);

        let _solver = HyperRelaxSolver {
            adaptive: true,
            cfl: 0.5,
            dampening: 0.4,
            max_steps: 1_000_000,
            tolerance: 1e-4,
            ..HyperRelaxSolver::new()
        };

        // NOTE(review): the adaptive solve/regrid loop below is disabled, so this test
        // currently only checks that mesh construction, refinement and projection succeed.
        // The sketch predates the current API (`Order::<4>`, `flag_wavelets::<Scalar>`);
        // `HyperRelaxSolver::solve` now takes `order: usize` and an `ImageMut`.
        //
        // loop {
        //     if mesh.max_level() > 11 {
        //         panic!("Poisson mesh solver exceeded max levels");
        //     }
        //
        //     let mut result = vec![1.0; mesh.num_nodes()];
        //
        //     _solver
        //         .solve(
        //             &mut mesh,
        //             Order::<4>,
        //             PoissonConditions,
        //             PoissonEquation,
        //             (&mut result).into(),
        //         )
        //         .unwrap();
        //
        //     mesh.flag_wavelets::<Scalar>(4, 0.0, 1e-4, result.as_slice().into());
        //     mesh.balance_flags();
        //
        //     if mesh.requires_regridding() {
        //         mesh.regrid();
        //     } else {
        //         return;
        //     }
        // }
    }
}