1use std::convert::Infallible;
2
3use crate::IRef;
4use crate::geometry::{Face, IndexSpace};
5use crate::image::{ImageMut, ImageRef};
6use crate::kernel::{BoundaryKind, DirichletParams, RadiativeParams, SystemBoundaryConds};
7use crate::mesh::FunctionBorrowMut;
8use datasize::DataSize;
9use reborrow::{Reborrow, ReborrowMut};
10use thiserror::Error;
11
12use crate::{
13 mesh::{Engine, Function, Mesh},
14 solver::{Integrator, Method},
15};
16
17use super::SolverCallback;
18
19#[derive(Error, Debug)]
21pub enum HyperRelaxError<A, B> {
22 #[error("failed to relax below tolerance in allotted number of steps")]
23 ReachedMaxSteps,
24 #[error("norm diverged to NaN")]
25 NormDiverged,
26 #[error("function error")]
27 FunctionFailed(A),
28 #[error("callback error")]
29 CallbackFailed(B),
30}
31
32#[derive(Clone, Debug, DataSize)]
36pub struct HyperRelaxSolver {
37 pub tolerance: f64,
39 pub max_steps: usize,
41 pub dampening: f64,
43 pub cfl: f64,
45 pub adaptive: bool,
49
50 integrator: Integrator,
51}
52
53impl Default for HyperRelaxSolver {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl HyperRelaxSolver {
60 pub fn new() -> Self {
62 Self {
63 tolerance: 1e-5,
64 max_steps: 100000,
65 dampening: 1.0,
66 cfl: 0.1,
67 adaptive: false,
68 integrator: Integrator::new(Method::RK4),
70 }
71 }
72
73 pub fn solve<const N: usize, C: SystemBoundaryConds<N> + Sync, F: Function<N> + Sync>(
75 &mut self,
76 mesh: &mut Mesh<N>,
77 order: usize,
78 conditions: C,
79 deriv: F,
80 result: ImageMut,
81 ) -> Result<(), HyperRelaxError<F::Error, Infallible>>
82 where
83 F::Error: Send,
84 {
85 self.solve_with_callback(mesh, order, conditions, (), deriv, result)
86 }
87
88 pub fn solve_with_callback<
89 const N: usize,
90 C: SystemBoundaryConds<N> + Sync,
91 F: Function<N> + Sync,
92 Call: SolverCallback<N>,
93 >(
94 &mut self,
95 mesh: &mut Mesh<N>,
96 order: usize,
97 conditions: C,
98 mut callback: Call,
99 mut deriv: F,
100 mut result: ImageMut,
101 ) -> Result<(), HyperRelaxError<F::Error, Call::Error>>
102 where
103 F::Error: Send,
104 Call::Error: Send,
105 {
106 assert_eq!(result.num_nodes(), mesh.num_nodes());
107 let dimension = result.num_channels() * mesh.num_nodes();
109 let num_channels = result.num_channels();
110
111 let mut data = vec![0.0; 2 * dimension].into_boxed_slice();
115 let min_spacing = mesh.min_spacing();
117
118 let mut spacing_per_vertex = vec![min_spacing; mesh.num_nodes()];
119 if self.adaptive {
120 mesh.spacing_per_vertex(&mut spacing_per_vertex);
121 }
122
123 let time_step = self.cfl * min_spacing;
125
126 {
128 let (u, v) = data.split_at_mut(dimension);
129 mesh.copy_from_slice(ImageMut::from_storage(u, num_channels), result.rb());
131 mesh.copy_from_slice(ImageMut::from_storage(v, num_channels), result.rb());
133 for value in v.iter_mut() {
134 *value *= self.dampening;
135 }
136 }
137
138 for index in 0..self.max_steps {
139 mesh.fill_boundary(
140 order,
141 FicticuousBoundaryConds {
142 dampening: self.dampening,
143 conditions: conditions.clone(),
144 channels: num_channels,
145 },
146 ImageMut::from_storage(&mut data, 2 * num_channels),
147 );
148
149 {
150 let u = ImageRef::from_storage(&data[..dimension], num_channels);
151 mesh.copy_from_slice(result.rb_mut(), u.rb());
152 mesh.apply(
153 order,
154 conditions.clone(),
155 FunctionBorrowMut(&mut deriv),
156 result.rb_mut(),
157 )
158 .map_err(|err| HyperRelaxError::FunctionFailed(err))?;
159 callback
160 .callback(mesh, u.rb(), result.rb(), index)
161 .map_err(|err| HyperRelaxError::CallbackFailed(err))?;
162 }
163
164 let norm = mesh.l2_norm_system(result.rb());
165
166 if !norm.is_finite() || norm >= 1e60 {
167 return Err(HyperRelaxError::NormDiverged);
168 }
169
170 if index % 100 == 0 {
171 log::trace!("Relaxed {}k steps, norm: {:.5e}", index / 100, norm);
172 }
173
174 if norm <= self.tolerance {
175 log::trace!("Converged in {} steps.", index);
176
177 mesh.copy_from_slice(
179 result.rb_mut(),
180 ImageRef::from_storage(&data[..dimension], num_channels),
181 );
182 mesh.fill_boundary(order, conditions, result.rb_mut());
183
184 return Ok(());
185 }
186
187 self.integrator
188 .step(
189 mesh,
190 order,
191 FicticuousBoundaryConds {
192 dampening: self.dampening,
193 conditions: conditions.clone(),
194 channels: num_channels,
195 },
196 FicticuousDerivs {
197 dampening: self.dampening,
198 function: &deriv,
199 spacing_per_vertex: &spacing_per_vertex,
200 min_spacing,
201 },
202 time_step,
203 ImageMut::from_storage(&mut data, 2 * num_channels),
204 )
205 .map_err(|err| HyperRelaxError::FunctionFailed(err))?;
206
207 if index == self.max_steps - 1 {
208 log::error!(
209 "Hyperbolic relaxation failed to converge in {} steps.",
210 self.max_steps
211 );
212 }
213 }
214
215 mesh.copy_from_slice(
217 result.rb_mut(),
218 ImageRef::from_storage(&data[..dimension], num_channels),
219 );
220 mesh.fill_boundary(order, conditions, result.rb_mut());
221
222 Err(HyperRelaxError::ReachedMaxSteps)
223 }
224}
225
/// Boundary conditions for the doubled `(u, v)` system integrated during
/// relaxation.
///
/// Channels below `channels` address `u`; channels at or above it address `v`.
/// Both are mapped back onto the wrapped `conditions`.
#[derive(Clone)]
struct FicticuousBoundaryConds<C> {
    /// Conditions of the original, un-doubled system.
    conditions: C,
    /// Number of channels in the original system.
    channels: usize,
    /// Factor applied to boundary targets of the `v` channels.
    dampening: f64,
}
232
233impl<const N: usize, C: SystemBoundaryConds<N>> SystemBoundaryConds<N>
234 for FicticuousBoundaryConds<C>
235{
236 fn kind(&self, channel: usize, face: Face<N>) -> BoundaryKind {
237 let boundary_kind: BoundaryKind = self.conditions.kind(channel % self.channels, face);
238
239 match boundary_kind {
240 BoundaryKind::Symmetric => BoundaryKind::Symmetric,
241 BoundaryKind::AntiSymmetric => BoundaryKind::AntiSymmetric,
242 BoundaryKind::Custom => BoundaryKind::Custom,
243 BoundaryKind::Radiative => BoundaryKind::Radiative,
244 BoundaryKind::Free => BoundaryKind::Free,
245 BoundaryKind::StrongDirichlet | BoundaryKind::WeakDirichlet => {
246 BoundaryKind::WeakDirichlet
247 }
248 }
249 }
250
251 fn radiative(&self, channel: usize, position: [f64; N]) -> RadiativeParams {
252 let mut result = self.conditions.radiative(channel % self.channels, position);
253 if channel >= self.channels {
254 result.target *= self.dampening;
255 }
256 result
257 }
258
259 fn dirichlet(&self, channel: usize, position: [f64; N]) -> DirichletParams {
260 let mut result = self.conditions.dirichlet(channel % self.channels, position);
261 if channel >= self.channels {
262 result.target *= self.dampening;
263 }
264 result
265 }
266}
267
/// Right-hand side of the first-order relaxation system built around
/// `function`; see its `Function` impl for the exact equations.
#[derive(Clone)]
struct FicticuousDerivs<'a, const N: usize, F> {
    /// Function whose root is being sought.
    function: &'a F,
    /// Damping coefficient `d` in `du/dt = v - d * u`.
    dampening: f64,
    /// Spacing for every node of the mesh; sliced per block via
    /// `engine.node_range()`.
    spacing_per_vertex: &'a [f64],
    /// Smallest spacing on the mesh; rates are scaled by `spacing / min_spacing`.
    min_spacing: f64,
}
275
276impl<const N: usize, F: Function<N>> Function<N> for FicticuousDerivs<'_, N, F> {
277 type Error = F::Error;
278
279 fn evaluate(
280 &self,
281 engine: impl Engine<N>,
282 input: ImageRef,
283 mut output: ImageMut,
284 ) -> Result<(), F::Error> {
285 assert_eq!(input.num_channels(), output.num_channels());
286 let num_channels = output.num_channels();
287 let (uin, vin) = input.split_channels(num_channels / 2);
288 let (mut uout, mut vout) = output.rb_mut().split_channels(num_channels / 2);
289
290 for field in uin.channels() {
292 let u = uin.channel(field);
293 let v = vin.channel(field);
294
295 let udest = uout.channel_mut(field);
296
297 for vertex in IndexSpace::new(engine.vertex_size()).iter() {
298 let index = engine.index_from_vertex(vertex);
299 udest[index] = v[index] - u[index] * self.dampening;
300 }
301 }
302
303 self.function.evaluate(IRef(&engine), uin, vout.rb_mut())?;
306
307 let block_spacing = &self.spacing_per_vertex[engine.node_range()];
309
310 for field in uout.channels() {
311 let uout = uout.channel_mut(field);
312 for vertex in IndexSpace::new(engine.vertex_size()).iter() {
313 let index = engine.index_from_vertex(vertex);
314 uout[index] *= block_spacing[index] / self.min_spacing;
315 }
316 }
317
318 for field in vout.channels() {
319 let vout = vout.channel_mut(field);
320 for vertex in IndexSpace::new(engine.vertex_size()).iter() {
321 let index = engine.index_from_vertex(vertex);
322 vout[index] *= block_spacing[index] / self.min_spacing;
323 }
324 }
325
326 Ok(())
327 }
328}
329
#[cfg(test)]
mod tests {
    use crate::{
        geometry::{FaceArray, HyperBox},
        kernel::{BoundaryClass, DirichletParams},
    };

    use super::*;
    use crate::{kernel::BoundaryKind, mesh::Projection};
    use std::{convert::Infallible, f64::consts};

    /// `sin(2 * pi * t)`, the 1-D factor of the manufactured solution.
    fn sin_2pi(t: f64) -> f64 {
        (2.0 * consts::PI * t).sin()
    }

    /// Homogeneous strong Dirichlet conditions on every face and channel.
    #[derive(Clone)]
    struct _PoissonConditions;

    impl SystemBoundaryConds<2> for _PoissonConditions {
        fn kind(&self, _channel: usize, _face: Face<2>) -> BoundaryKind {
            BoundaryKind::StrongDirichlet
        }

        fn dirichlet(&self, _channel: usize, _position: [f64; 2]) -> DirichletParams {
            DirichletParams {
                target: 0.0,
                strength: 1.0,
            }
        }
    }

    /// Exact solution `sin(2 pi x) * sin(2 pi y)`.
    #[derive(Clone)]
    pub struct PoissonSolution;

    impl Projection<2> for PoissonSolution {
        fn project(&self, [x, y]: [f64; 2]) -> f64 {
            sin_2pi(x) * sin_2pi(y)
        }
    }

    /// Residual `laplacian(u) - source` of the Poisson problem solved by
    /// `PoissonSolution`.
    #[derive(Clone)]
    pub struct _PoissonEquation;

    impl Function<2> for _PoissonEquation {
        type Error = Infallible;

        fn evaluate(
            &self,
            engine: impl Engine<2>,
            input: ImageRef,
            mut output: ImageMut,
        ) -> Result<(), Infallible> {
            let field = input.channel(0);
            let residual = output.channel_mut(0);

            for vertex in IndexSpace::new(engine.vertex_size()).iter() {
                let idx = engine.index_from_vertex(vertex);
                let [x, y] = engine.position(vertex);

                let laplacian = engine.second_derivative(field, 0, vertex)
                    + engine.second_derivative(field, 1, vertex);
                let source = -8.0 * consts::PI * consts::PI * sin_2pi(x) * sin_2pi(y);

                residual[idx] = laplacian - source;
            }

            Ok(())
        }
    }

    #[test]
    fn poisson() {
        let mut mesh = Mesh::new(
            HyperBox::from_aabb([0.0, 0.0], [1.0, 1.0]),
            4,
            2,
            FaceArray::splat(BoundaryClass::Ghost),
        );
        for _ in 0..2 {
            mesh.refine_global();
        }

        let mut solution = vec![0.0; mesh.num_nodes()];
        mesh.project(4, PoissonSolution, &mut solution);

        // NOTE(review): the solver is configured but never run, so this test
        // only exercises mesh construction and projection. `_PoissonConditions`
        // and `_PoissonEquation` look intended for the missing `solve` call.
        let _solver = HyperRelaxSolver {
            adaptive: true,
            cfl: 0.5,
            dampening: 0.4,
            max_steps: 1_000_000,
            tolerance: 1e-4,
            ..HyperRelaxSolver::new()
        };
    }
}