Skip to main content

aeon_tk/mesh/
checkpoint.rs

1use crate::image::{Image, ImageRef};
2use crate::mesh::{Mesh, MeshSer};
3use crate::prelude::IndexSpace;
4use bincode::{Decode, Encode};
5use ron::ser::PrettyConfig;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::fmt::Write as _;
9use std::fs::File;
10use std::io::{Read as _, Write as _};
11use std::path::Path;
12use std::str::FromStr;
13use thiserror::Error;
14
15#[derive(Debug, Error)]
16pub enum CheckpointParseError {
17    #[error("Invalid key")]
18    InvalidKey,
19    #[error("Failed to parse {0}")]
20    ParseFailed(String),
21}
22
23/// Represents a snapshot of a `Mesh` along with relavent field data.
24///
25/// The `Checkpoint<N>` can then be serialized and deserialized from disk,
26/// allowing data to be loaded/reused across runs.
27#[derive(Default, Debug, Clone, serde::Serialize, serde::Deserialize)]
28pub struct Checkpoint<const N: usize> {
29    /// Meta data to be stored in checkpoint (useful for storing time, number of steps, ect.)
30    meta: HashMap<String, String>,
31    /// Mesh data to be attached to checkpoint,
32    mesh: Option<MeshSer<N>>,
33    /// Is this mesh embedded in a higher dimensional mesh?
34    embedding: Option<Embedding>,
35    /// Systems which are stored in the checkpoint.
36    systems: HashMap<String, ImageMeta>,
37    /// Fields which are stored in the checkpoint
38    fields: HashMap<String, Vec<f64>>,
39    /// Int fields (useful for debugging) which are stored in the checkpoint.
40    int_fields: HashMap<String, Vec<i64>>,
41}
42
43impl<const N: usize> Checkpoint<N> {
44    /// Attaches a mesh to the checkpoint
45    pub fn attach_mesh(&mut self, mesh: &Mesh<N>) {
46        self.mesh.replace(MeshSer {
47            tree: mesh.tree.clone().into(),
48            width: mesh.width,
49            ghost: mesh.ghost,
50            boundary: mesh.boundary,
51        });
52    }
53
54    /// Sets the mesh as embedded in a higher dimensional space.
55    pub fn set_embedding<const S: usize>(&mut self, positions: &[[f64; S]]) {
56        self.embedding.replace(Embedding {
57            dimension: S,
58            positions: positions.iter().flatten().cloned().collect(),
59        });
60    }
61
62    /// Clones the mesh attached to the checkpoint.
63    pub fn read_mesh(&self) -> Mesh<N> {
64        self.mesh.clone().unwrap().into()
65    }
66
67    /// Saves image data with the given name key.
68    pub fn save_image(&mut self, name: &str, data: ImageRef) {
69        assert!(!self.systems.contains_key(name));
70
71        let num_channels = data.num_channels();
72        let buffer = data
73            .channels()
74            .flat_map(|label| data.channel(label).iter().cloned())
75            .collect();
76
77        self.systems.insert(
78            name.to_string(),
79            ImageMeta {
80                channels: num_channels,
81                buffer,
82            },
83        );
84    }
85
86    /// Reads image data associated to the given name.
87    pub fn read_image(&self, name: &str) -> Image {
88        let data = self.systems.get(name).unwrap();
89        // let system = ron::de::from_str::<S>(&data.meta).unwrap();
90
91        Image::from_storage(data.buffer.clone(), data.channels)
92    }
93
94    // pub fn save_system_default<S: System + Default>(&mut self, data: SystemSlice<S>) {
95    //     assert!(!self.systems.contains_key(S::NAME));
96
97    //     let count = data.len();
98    //     let buffer = data
99    //         .system()
100    //         .enumerate()
101    //         .flat_map(|label| data.field(label).iter().cloned())
102    //         .collect();
103
104    //     let fields = data
105    //         .system()
106    //         .enumerate()
107    //         .map(|label| data.system().label_name(label))
108    //         .collect();
109
110    //     self.systems.insert(
111    //         S::NAME.to_string(),
112    //         SystemMeta {
113    //             meta: String::new(),
114    //             count,
115    //             buffer,
116    //             fields,
117    //         },
118    //     );
119    // }
120
121    // pub fn read_system_default<S: System + Default>(&mut self) -> SystemVec<S> {
122    //     let data = self.systems.get(S::NAME).unwrap();
123    //     SystemVec::from_contiguous(data.buffer.clone(), S::default())
124    // }
125
126    /// Attaches a field for serialization in the model.
127    pub fn save_field(&mut self, name: &str, data: &[f64]) {
128        assert!(!self.fields.contains_key(name));
129        self.fields.insert(name.to_string(), data.to_vec());
130    }
131
132    /// Reads a field from the model.
133    pub fn load_field(&self, name: &str, data: &mut Vec<f64>) {
134        data.clear();
135        data.extend_from_slice(self.fields.get(name).unwrap());
136    }
137
138    /// Reads a field from the model.
139    pub fn read_field(&self, name: &str) -> Vec<f64> {
140        let mut result = Vec::new();
141        self.load_field(name, &mut result);
142        result
143    }
144
145    /// Attaches an integer field for serialization in the checkpoint.
146    pub fn save_int_field(&mut self, name: &str, data: &[i64]) {
147        assert!(!self.int_fields.contains_key(name));
148        self.int_fields.insert(name.to_string(), data.to_vec());
149    }
150
151    /// Reads an integer field from the checkpoint.
152    pub fn load_int_field(&self, name: &str, data: &mut Vec<i64>) {
153        data.clear();
154        data.extend_from_slice(self.int_fields.get(name).unwrap());
155    }
156
157    /// Saves meta data associated with a given string.
158    pub fn save_meta(&mut self, name: &str, data: &str) {
159        let _ = self.meta.insert(name.to_string(), data.to_string());
160    }
161
162    /// Loads meta data associated with a given string into `data`
163    pub fn load_meta(&self, name: &str, data: &mut String) {
164        data.clone_from(self.meta.get(name).unwrap())
165    }
166
167    /// Writes meta data associated with a given type to the checkpoint.
168    pub fn write_meta<T: ToString>(&mut self, name: &str, data: T) {
169        self.save_meta(name, &data.to_string());
170    }
171
172    /// Reads meta data associated with a given type from the checkpoint.
173    pub fn read_meta<T: FromStr>(&self, name: &str) -> Result<T, CheckpointParseError> {
174        let data = self
175            .meta
176            .get(name)
177            .ok_or(CheckpointParseError::InvalidKey)?;
178
179        data.parse()
180            .map_err(|_| CheckpointParseError::ParseFailed(data.clone()))
181    }
182
183    /// Loads the mesh and any additional data from disk.
184    pub fn import_dat(path: impl AsRef<Path>) -> std::io::Result<Self> {
185        let mut contents: String = String::new();
186        let mut file = File::open(path)?;
187        file.read_to_string(&mut contents)?;
188
189        ron::from_str(&contents).map_err(std::io::Error::other)
190    }
191
192    /// Saves the checkpoint as a dat file at the given path.
193    pub fn export_dat(&self, path: impl AsRef<Path>) -> std::io::Result<()> {
194        let data = ron::ser::to_string_pretty::<Checkpoint<N>>(self, PrettyConfig::default())
195            .map_err(std::io::Error::other)?;
196        let mut file = File::create(path)?;
197        file.write_all(data.as_bytes())
198    }
199}
200
201/// Metadata required for storing a system.
202#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
203pub struct ImageMeta {
204    pub channels: usize,
205    pub buffer: Vec<f64>,
206}
207
208#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
209pub struct Embedding {
210    dimension: usize,
211    positions: Vec<f64>,
212}
213
214use num_traits::ToPrimitive;
215use vtkio::{
216    IOBuffer, Vtk,
217    model::{
218        Attribute, Attributes, ByteOrder, CellType, Cells, DataArrayBase, DataSet, ElementType,
219        Piece, UnstructuredGridPiece, Version, VertexNumbers,
220    },
221};
222
223#[derive(Clone, Debug)]
224pub struct ExportVtuConfig {
225    pub title: String,
226    pub ghost: bool,
227    pub stride: ExportStride,
228}
229
230impl Default for ExportVtuConfig {
231    fn default() -> Self {
232        Self {
233            title: "Title".to_string(),
234            ghost: false,
235            stride: ExportStride::PerVertex,
236        }
237    }
238}
239
240#[derive(Serialize, Deserialize, Debug, Clone, Copy, Encode, Decode)]
241pub enum ExportStride {
242    /// Output data for every vertex in the simulation
243    #[serde(rename = "per_vertex")]
244    PerVertex,
245    /// Output data for each corner of a cell in the simulation
246    /// This is significantly more compressed
247    #[serde(rename = "per_cell")]
248    PerCell,
249}
250
251impl<const N: usize> Checkpoint<N> {
252    pub fn export_csv(&self, path: impl AsRef<Path>, stride: ExportStride) -> std::io::Result<()> {
253        let mesh: Mesh<N> = self.mesh.clone().unwrap().into();
254        let stride = match stride {
255            ExportStride::PerVertex => 1,
256            ExportStride::PerCell => mesh.width,
257        };
258
259        let mut wtr = csv::Writer::from_path(path)?;
260        let mut header = (0..N)
261            .into_iter()
262            .map(|i| format!("Coord{i}"))
263            .collect::<Vec<String>>();
264
265        for field in self.fields.keys() {
266            header.push(field.clone());
267        }
268
269        wtr.write_record(header.iter())?;
270
271        let mut buffer = String::new();
272
273        for block in mesh.blocks.indices() {
274            let space = mesh.block_space(block);
275            let nodes = mesh.block_nodes(block);
276            let window = space.inner_window();
277
278            'window: for node in window {
279                for axis in 0..N {
280                    if node[axis] % (stride as isize) != 0 {
281                        continue 'window;
282                    }
283                }
284
285                let index = space.index_from_node(node);
286                let position = space.position(node);
287
288                for i in 0..N {
289                    buffer.clear();
290                    write!(&mut buffer, "{}", position[i]).unwrap();
291                    wtr.write_field(&buffer)?;
292                }
293
294                for (i, (name, data)) in self.fields.iter().enumerate() {
295                    debug_assert_eq!(&header[i + N], name);
296
297                    let value = data[nodes.clone()][index];
298                    buffer.clear();
299                    write!(&mut buffer, "{}", value).unwrap();
300                    wtr.write_field(&buffer)?;
301                }
302
303                wtr.write_record::<&[String], &String>(&[])?;
304            }
305        }
306
307        Ok(())
308    }
309
310    /// Checkpoint and additional field data to a .vtu file, for visualisation in applications like
311    /// Paraview. This requires a mesh be attached to the checkpoint.
312    pub fn export_vtu(
313        &self,
314        path: impl AsRef<Path>,
315        config: ExportVtuConfig,
316    ) -> std::io::Result<()> {
317        const {
318            assert!(N > 0 && N <= 2, "Vtu Output only supported for 0 < N ≤ 2");
319        }
320        assert!(self.mesh.is_some(), "Mesh must be attached to checkpoint");
321
322        // Uncompress mesh.
323        let mesh: Mesh<N> = self.mesh.clone().unwrap().into();
324
325        let stride = match config.stride {
326            ExportStride::PerVertex => 1,
327            ExportStride::PerCell => mesh.width,
328        };
329
330        assert!(stride <= mesh.width, "Stride must be <= width");
331        assert!(
332            mesh.width % stride == 0,
333            "Width must be evenly divided by stride"
334        );
335        assert!(
336            !config.ghost || mesh.ghost % stride == 0,
337            "Ghost must be evenly divided by stride"
338        );
339
340        // Generate Cells
341        let cells = Self::mesh_cells(&mesh, config.ghost, stride);
342        // Generate Point Data
343        let points = match self.embedding {
344            Some(ref embedding) => {
345                Self::mesh_points_embedded(&mesh, embedding, config.ghost, stride)
346            }
347            None => Self::mesh_points(&mesh, config.ghost, stride),
348        };
349        // Attributes
350        let mut attributes = Attributes {
351            point: Vec::new(),
352            cell: Vec::new(),
353        };
354
355        // for (name, system) in self.systems.iter() {
356        //     for (idx, field) in system.fields.iter().enumerate() {
357        //         let start = idx * system.count;
358        //         let end = idx * system.count + system.count;
359
360        //         attributes.point.push(Self::field_attribute(
361        //             &mesh,
362        //             format!("{}::{}", name, field),
363        //             &system.buffer[start..end],
364        //             config.ghost,
365        //             stride,
366        //         ));
367        //     }
368        // }
369
370        for (name, system) in self.fields.iter() {
371            attributes.point.push(Self::field_attribute(
372                &mesh,
373                format!("Field::{}", name),
374                system,
375                config.ghost,
376                stride,
377            ));
378        }
379
380        for (name, system) in self.int_fields.iter() {
381            attributes.point.push(Self::field_attribute(
382                &mesh,
383                format!("IntField::{}", name),
384                system,
385                config.ghost,
386                stride,
387            ));
388        }
389
390        let mut pieces = Vec::new();
391        // Primary piece
392        pieces.push(Piece::Inline(Box::new(UnstructuredGridPiece {
393            points,
394            cells,
395            data: attributes,
396        })));
397
398        let model = Vtk {
399            version: Version::XML { major: 2, minor: 2 },
400            title: config.title,
401            byte_order: ByteOrder::LittleEndian,
402            data: DataSet::UnstructuredGrid { meta: None, pieces },
403            file_path: None,
404        };
405
406        model.export(path).map_err(|i| match i {
407            vtkio::Error::IO(io) => io,
408            v => {
409                log::error!("Encountered error {:?} while exporting vtu", v);
410                std::io::Error::from(std::io::ErrorKind::Other)
411            }
412        })?;
413
414        Ok(())
415    }
416
417    fn mesh_cells(mesh: &Mesh<N>, ghost: bool, stride: usize) -> Cells {
418        let mut connectivity = Vec::new();
419        let mut offsets = Vec::new();
420
421        let mut vertex_total = 0;
422        let mut cell_total = 0;
423
424        for block in mesh.blocks.indices() {
425            let space = mesh.block_space(block);
426
427            let mut cell_size = space.cell_size();
428            let mut vertex_size = space.vertex_size();
429
430            if ghost {
431                for axis in 0..N {
432                    cell_size[axis] += 2 * space.ghost();
433                    vertex_size[axis] += 2 * space.ghost();
434                }
435            }
436
437            for axis in 0..N {
438                debug_assert!(cell_size[axis] % stride == 0);
439                debug_assert!((vertex_size[axis] - 1) % stride == 0);
440
441                cell_size[axis] /= stride;
442                vertex_size[axis] = (vertex_size[axis] - 1) / stride + 1;
443            }
444
445            let cell_space = IndexSpace::new(cell_size);
446            let vertex_space = IndexSpace::new(vertex_size);
447
448            for cell in cell_space.iter() {
449                let mut vertex = [0; N];
450
451                if N == 1 {
452                    vertex[0] = cell[0];
453                    let v1 = vertex_space.linear_from_cartesian(vertex);
454                    vertex[0] = cell[0] + 1;
455                    let v2 = vertex_space.linear_from_cartesian(vertex);
456
457                    connectivity.push(vertex_total + v1 as u64);
458                    connectivity.push(vertex_total + v2 as u64);
459                } else if N == 2 {
460                    vertex[0] = cell[0];
461                    vertex[1] = cell[1];
462                    let v1 = vertex_space.linear_from_cartesian(vertex);
463                    vertex[0] = cell[0];
464                    vertex[1] = cell[1] + 1;
465                    let v2 = vertex_space.linear_from_cartesian(vertex);
466                    vertex[0] = cell[0] + 1;
467                    vertex[1] = cell[1] + 1;
468                    let v3 = vertex_space.linear_from_cartesian(vertex);
469                    vertex[0] = cell[0] + 1;
470                    vertex[1] = cell[1];
471                    let v4 = vertex_space.linear_from_cartesian(vertex);
472
473                    connectivity.push(vertex_total + v1 as u64);
474                    connectivity.push(vertex_total + v2 as u64);
475                    connectivity.push(vertex_total + v3 as u64);
476                    connectivity.push(vertex_total + v4 as u64);
477                }
478
479                offsets.push(connectivity.len() as u64);
480            }
481
482            cell_total += cell_space.index_count();
483            vertex_total += vertex_space.index_count() as u64;
484        }
485
486        let cell_type = match N {
487            1 => CellType::Line,
488            2 => CellType::Quad,
489            // 3 => CellType::Hexahedron,
490            _ => panic!("Unsupported dimension"),
491        };
492
493        Cells {
494            cell_verts: VertexNumbers::XML {
495                connectivity,
496                offsets,
497            },
498            types: vec![cell_type; cell_total],
499        }
500    }
501
502    fn mesh_points(mesh: &Mesh<N>, ghost: bool, stride: usize) -> IOBuffer {
503        // Generate point data
504        let mut vertices = Vec::new();
505
506        for block in mesh.blocks.indices() {
507            let space = mesh.block_space(block);
508            let window = if ghost {
509                space.full_window()
510            } else {
511                space.inner_window()
512            };
513
514            'window: for node in window {
515                for axis in 0..N {
516                    if node[axis] % (stride as isize) != 0 {
517                        continue 'window;
518                    }
519                }
520
521                let position = space.position(node);
522                let mut vertex = [0.0; 3];
523                vertex[..N].copy_from_slice(&position);
524                vertices.extend(vertex);
525            }
526        }
527
528        IOBuffer::new(vertices)
529    }
530
531    fn mesh_points_embedded(
532        mesh: &Mesh<N>,
533        embedding: &Embedding,
534        ghost: bool,
535        stride: usize,
536    ) -> IOBuffer {
537        let dim = embedding.dimension;
538        assert!(dim <= 3);
539
540        // Generate point data
541        let mut vertices = Vec::new();
542
543        for block in mesh.blocks.indices() {
544            let space = mesh.block_space(block);
545            let nodes = mesh.block_nodes(block);
546            let window = if ghost {
547                space.full_window()
548            } else {
549                space.inner_window()
550            };
551
552            let block_positions = &embedding.positions[nodes.start * dim..nodes.end * dim];
553
554            'window: for node in window {
555                let index = space.index_from_node(node);
556                let position = &block_positions[index * dim..(index + 1) * dim];
557
558                for axis in 0..N {
559                    if node[axis] % (stride as isize) != 0 {
560                        continue 'window;
561                    }
562                }
563
564                let mut vertex = [0.0; 3];
565                vertex[..dim].copy_from_slice(&position);
566                vertices.extend(vertex);
567            }
568        }
569
570        IOBuffer::new(vertices)
571    }
572
573    fn field_attribute<T: ToPrimitive + Copy + 'static>(
574        mesh: &Mesh<N>,
575        name: String,
576        data: &[T],
577        ghost: bool,
578        stride: usize,
579    ) -> Attribute {
580        let mut buffer = Vec::new();
581
582        for block in mesh.blocks.indices() {
583            let space = mesh.block_space(block);
584            let nodes = mesh.block_nodes(block);
585            let window = if ghost {
586                space.full_window()
587            } else {
588                space.inner_window()
589            };
590
591            'window: for node in window {
592                for axis in 0..N {
593                    if node[axis] % (stride as isize) != 0 {
594                        continue 'window;
595                    }
596                }
597
598                let index = space.index_from_node(node);
599                let value = data[nodes.clone()][index];
600                buffer.push(value);
601            }
602        }
603
604        Attribute::DataArray(DataArrayBase {
605            name,
606            elem: ElementType::Scalars {
607                num_comp: 1,
608                lookup_table: None,
609            },
610            data: IOBuffer::new(buffer),
611        })
612    }
613}