1use crate::Array2;
2use core::fmt;
3
4pub struct FieldSet {
6 w: [Array2<f64>; 3],
7 u: [Array2<f64>; 3],
8 v: [Array2<f64>; 3],
9 generation: usize,
10}
11
12pub struct SplitBufs<'a> {
14 pub w_cur: &'a Array2<f64>,
15 pub w_prev: &'a Array2<f64>,
16 pub w_next: &'a mut Array2<f64>,
17 pub u_cur: &'a Array2<f64>,
18 pub u_prev: &'a Array2<f64>,
19 pub u_next: &'a mut Array2<f64>,
20 pub v_cur: &'a Array2<f64>,
21 pub v_prev: &'a Array2<f64>,
22 pub v_next: &'a mut Array2<f64>,
23}
24
/// Reasons a `FieldSetCheckpoint` cannot be restored into a `FieldSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckpointError {
    /// The flattened data for `field` (one of `"w"`, `"u"`, `"v"`) was rejected
    /// when rebuilding it with the checkpoint's recorded shape.
    InvalidBufferShape { field: &'static str },
}

impl fmt::Display for CheckpointError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            CheckpointError::InvalidBufferShape { field } => f.write_fmt(format_args!(
                "checkpoint {} buffer must match its declared shape",
                field
            )),
        }
    }
}

impl std::error::Error for CheckpointError {}
41
42impl FieldSet {
43 pub fn new(nx: usize, ny: usize) -> Self {
44 Self {
45 w: std::array::from_fn(|_| Array2::zeros((nx, ny))),
46 u: std::array::from_fn(|_| Array2::zeros((nx, ny))),
47 v: std::array::from_fn(|_| Array2::zeros((nx, ny))),
48 generation: 0,
49 }
50 }
51
52 #[inline]
53 pub fn w(&self) -> &Array2<f64> {
54 &self.w[self.generation % 3]
55 }
56
57 #[inline]
58 pub fn w_prev(&self) -> &Array2<f64> {
59 &self.w[(self.generation + 2) % 3]
60 }
61
62 #[inline]
63 pub fn u(&self) -> &Array2<f64> {
64 &self.u[self.generation % 3]
65 }
66
67 #[inline]
68 pub fn u_prev(&self) -> &Array2<f64> {
69 &self.u[(self.generation + 2) % 3]
70 }
71
72 #[inline]
73 pub fn v(&self) -> &Array2<f64> {
74 &self.v[self.generation % 3]
75 }
76
77 #[inline]
78 pub fn v_prev(&self) -> &Array2<f64> {
79 &self.v[(self.generation + 2) % 3]
80 }
81
82 #[inline]
83 pub fn w_mut(&mut self) -> &mut Array2<f64> {
84 let index = self.generation % 3;
85 &mut self.w[index]
86 }
87
88 #[inline]
89 pub fn u_mut(&mut self) -> &mut Array2<f64> {
90 let index = self.generation % 3;
91 &mut self.u[index]
92 }
93
94 #[inline]
95 pub fn v_mut(&mut self) -> &mut Array2<f64> {
96 let index = self.generation % 3;
97 &mut self.v[index]
98 }
99
100 #[inline]
102 pub fn split_bufs(&mut self) -> SplitBufs<'_> {
103 let cur = self.generation % 3;
104 let prev = (self.generation + 2) % 3;
105 let next = (self.generation + 1) % 3;
106 unsafe {
107 let w = self.w.as_mut_ptr();
108 let u = self.u.as_mut_ptr();
109 let v = self.v.as_mut_ptr();
110 SplitBufs {
111 w_cur: &*w.add(cur),
112 w_prev: &*w.add(prev),
113 w_next: &mut *w.add(next),
114 u_cur: &*u.add(cur),
115 u_prev: &*u.add(prev),
116 u_next: &mut *u.add(next),
117 v_cur: &*v.add(cur),
118 v_prev: &*v.add(prev),
119 v_next: &mut *v.add(next),
120 }
121 }
122 }
123
124 #[inline]
125 pub fn advance(&mut self) {
126 self.generation += 1;
127 }
128
129 pub fn to_checkpoint(&self) -> FieldSetCheckpoint {
130 let flatten = |buffers: &[Array2<f64>; 3]| -> [Vec<f64>; 3] {
131 std::array::from_fn(|index| buffers[index].as_slice().to_vec())
132 };
133 let shape = (self.w[0].nrows(), self.w[0].ncols());
134 FieldSetCheckpoint {
135 w: flatten(&self.w),
136 u: flatten(&self.u),
137 v: flatten(&self.v),
138 generation: self.generation,
139 shape,
140 }
141 }
142
143 pub fn restore_checkpoint(
144 &mut self,
145 checkpoint: &FieldSetCheckpoint,
146 ) -> Result<(), CheckpointError> {
147 let (nx, ny) = checkpoint.shape;
148 for index in 0..3 {
149 self.w[index] = Array2::from_shape_vec((nx, ny), checkpoint.w[index].clone())
150 .map_err(|_| CheckpointError::InvalidBufferShape { field: "w" })?;
151 self.u[index] = Array2::from_shape_vec((nx, ny), checkpoint.u[index].clone())
152 .map_err(|_| CheckpointError::InvalidBufferShape { field: "u" })?;
153 self.v[index] = Array2::from_shape_vec((nx, ny), checkpoint.v[index].clone())
154 .map_err(|_| CheckpointError::InvalidBufferShape { field: "v" })?;
155 }
156 self.generation = checkpoint.generation;
157 Ok(())
158 }
159}
160
/// Plain-data snapshot of a `FieldSet`, suitable for storage or (with the `serde`
/// feature) serialization.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FieldSetCheckpoint {
    /// Flattened `w` buffers, indexed by ring slot (not by generation).
    pub w: [Vec<f64>; 3],
    /// Flattened `u` buffers, indexed by ring slot (not by generation).
    pub u: [Vec<f64>; 3],
    /// Flattened `v` buffers, indexed by ring slot (not by generation).
    pub v: [Vec<f64>; 3],
    /// Generation counter at snapshot time; selects which ring slot is current.
    pub generation: usize,
    /// `(nrows, ncols)` that every flattened buffer is rebuilt with on restore.
    pub shape: (usize, usize),
}
168}