Skip to main content

neco_gridfield/
field.rs

1use crate::Array2;
2use core::fmt;
3
4/// Störmer-Verlet triple-buffer for `w`, `u`, and `v`.
5pub struct FieldSet {
6    w: [Array2<f64>; 3],
7    u: [Array2<f64>; 3],
8    v: [Array2<f64>; 3],
9    generation: usize,
10}
11
12/// Split borrows: current / previous (read) and next (write).
13pub struct SplitBufs<'a> {
14    pub w_cur: &'a Array2<f64>,
15    pub w_prev: &'a Array2<f64>,
16    pub w_next: &'a mut Array2<f64>,
17    pub u_cur: &'a Array2<f64>,
18    pub u_prev: &'a Array2<f64>,
19    pub u_next: &'a mut Array2<f64>,
20    pub v_cur: &'a Array2<f64>,
21    pub v_prev: &'a Array2<f64>,
22    pub v_next: &'a mut Array2<f64>,
23}
24
/// Error returned when a [`FieldSetCheckpoint`] cannot be restored into a [`FieldSet`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckpointError {
    /// One of the checkpoint's flat buffers could not be reshaped to the checkpoint's
    /// declared `shape`; `field` names the offending field (`"w"`, `"u"`, or `"v"`).
    InvalidBufferShape { field: &'static str },
}

impl fmt::Display for CheckpointError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum: the pattern is irrefutable, and adding a variant will turn
        // this into a compile error that forces the message to be extended.
        let Self::InvalidBufferShape { field } = *self;
        write!(f, "checkpoint {field} buffer must match its declared shape")
    }
}

impl std::error::Error for CheckpointError {}
41
42impl FieldSet {
43    pub fn new(nx: usize, ny: usize) -> Self {
44        Self {
45            w: std::array::from_fn(|_| Array2::zeros((nx, ny))),
46            u: std::array::from_fn(|_| Array2::zeros((nx, ny))),
47            v: std::array::from_fn(|_| Array2::zeros((nx, ny))),
48            generation: 0,
49        }
50    }
51
52    #[inline]
53    pub fn w(&self) -> &Array2<f64> {
54        &self.w[self.generation % 3]
55    }
56
57    #[inline]
58    pub fn w_prev(&self) -> &Array2<f64> {
59        &self.w[(self.generation + 2) % 3]
60    }
61
62    #[inline]
63    pub fn u(&self) -> &Array2<f64> {
64        &self.u[self.generation % 3]
65    }
66
67    #[inline]
68    pub fn u_prev(&self) -> &Array2<f64> {
69        &self.u[(self.generation + 2) % 3]
70    }
71
72    #[inline]
73    pub fn v(&self) -> &Array2<f64> {
74        &self.v[self.generation % 3]
75    }
76
77    #[inline]
78    pub fn v_prev(&self) -> &Array2<f64> {
79        &self.v[(self.generation + 2) % 3]
80    }
81
82    #[inline]
83    pub fn w_mut(&mut self) -> &mut Array2<f64> {
84        let index = self.generation % 3;
85        &mut self.w[index]
86    }
87
88    #[inline]
89    pub fn u_mut(&mut self) -> &mut Array2<f64> {
90        let index = self.generation % 3;
91        &mut self.u[index]
92    }
93
94    #[inline]
95    pub fn v_mut(&mut self) -> &mut Array2<f64> {
96        let index = self.generation % 3;
97        &mut self.v[index]
98    }
99
100    /// Safety: current, previous, and next are distinct indices.
101    #[inline]
102    pub fn split_bufs(&mut self) -> SplitBufs<'_> {
103        let cur = self.generation % 3;
104        let prev = (self.generation + 2) % 3;
105        let next = (self.generation + 1) % 3;
106        unsafe {
107            let w = self.w.as_mut_ptr();
108            let u = self.u.as_mut_ptr();
109            let v = self.v.as_mut_ptr();
110            SplitBufs {
111                w_cur: &*w.add(cur),
112                w_prev: &*w.add(prev),
113                w_next: &mut *w.add(next),
114                u_cur: &*u.add(cur),
115                u_prev: &*u.add(prev),
116                u_next: &mut *u.add(next),
117                v_cur: &*v.add(cur),
118                v_prev: &*v.add(prev),
119                v_next: &mut *v.add(next),
120            }
121        }
122    }
123
124    #[inline]
125    pub fn advance(&mut self) {
126        self.generation += 1;
127    }
128
129    pub fn to_checkpoint(&self) -> FieldSetCheckpoint {
130        let flatten = |buffers: &[Array2<f64>; 3]| -> [Vec<f64>; 3] {
131            std::array::from_fn(|index| buffers[index].as_slice().to_vec())
132        };
133        let shape = (self.w[0].nrows(), self.w[0].ncols());
134        FieldSetCheckpoint {
135            w: flatten(&self.w),
136            u: flatten(&self.u),
137            v: flatten(&self.v),
138            generation: self.generation,
139            shape,
140        }
141    }
142
143    pub fn restore_checkpoint(
144        &mut self,
145        checkpoint: &FieldSetCheckpoint,
146    ) -> Result<(), CheckpointError> {
147        let (nx, ny) = checkpoint.shape;
148        for index in 0..3 {
149            self.w[index] = Array2::from_shape_vec((nx, ny), checkpoint.w[index].clone())
150                .map_err(|_| CheckpointError::InvalidBufferShape { field: "w" })?;
151            self.u[index] = Array2::from_shape_vec((nx, ny), checkpoint.u[index].clone())
152                .map_err(|_| CheckpointError::InvalidBufferShape { field: "u" })?;
153            self.v[index] = Array2::from_shape_vec((nx, ny), checkpoint.v[index].clone())
154                .map_err(|_| CheckpointError::InvalidBufferShape { field: "v" })?;
155        }
156        self.generation = checkpoint.generation;
157        Ok(())
158    }
159}
160
/// Owned snapshot of a [`FieldSet`]: all three slots of every field, flattened, plus the
/// generation counter and the grid shape needed to rebuild the arrays.
///
/// Produced by `FieldSet::to_checkpoint` and consumed by `FieldSet::restore_checkpoint`.
/// Every field is plain data, so the snapshot can be cloned, compared, and debug-printed.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FieldSetCheckpoint {
    /// Flattened contents of the three `w` slots, indexed by physical slot.
    pub w: [Vec<f64>; 3],
    /// Flattened contents of the three `u` slots, indexed by physical slot.
    pub u: [Vec<f64>; 3],
    /// Flattened contents of the three `v` slots, indexed by physical slot.
    pub v: [Vec<f64>; 3],
    /// Generation counter that determines which slot is current on restore.
    pub generation: usize,
    /// `(nrows, ncols)` every flat buffer is reshaped to on restore.
    pub shape: (usize, usize),
}