1use crate::image::{Image, ImageRef};
2use crate::mesh::{Mesh, MeshSer};
3use crate::prelude::IndexSpace;
4use bincode::{Decode, Encode};
5use ron::ser::PrettyConfig;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::fmt::Write as _;
9use std::fs::File;
10use std::io::{Read as _, Write as _};
11use std::path::Path;
12use std::str::FromStr;
13use thiserror::Error;
14
15#[derive(Debug, Error)]
16pub enum CheckpointParseError {
17 #[error("Invalid key")]
18 InvalidKey,
19 #[error("Failed to parse {0}")]
20 ParseFailed(String),
21}
22
23#[derive(Default, Debug, Clone, serde::Serialize, serde::Deserialize)]
28pub struct Checkpoint<const N: usize> {
29 meta: HashMap<String, String>,
31 mesh: Option<MeshSer<N>>,
33 embedding: Option<Embedding>,
35 systems: HashMap<String, ImageMeta>,
37 fields: HashMap<String, Vec<f64>>,
39 int_fields: HashMap<String, Vec<i64>>,
41}
42
43impl<const N: usize> Checkpoint<N> {
44 pub fn attach_mesh(&mut self, mesh: &Mesh<N>) {
46 self.mesh.replace(MeshSer {
47 tree: mesh.tree.clone().into(),
48 width: mesh.width,
49 ghost: mesh.ghost,
50 boundary: mesh.boundary,
51 });
52 }
53
54 pub fn set_embedding<const S: usize>(&mut self, positions: &[[f64; S]]) {
56 self.embedding.replace(Embedding {
57 dimension: S,
58 positions: positions.iter().flatten().cloned().collect(),
59 });
60 }
61
62 pub fn read_mesh(&self) -> Mesh<N> {
64 self.mesh.clone().unwrap().into()
65 }
66
67 pub fn save_image(&mut self, name: &str, data: ImageRef) {
69 assert!(!self.systems.contains_key(name));
70
71 let num_channels = data.num_channels();
72 let buffer = data
73 .channels()
74 .flat_map(|label| data.channel(label).iter().cloned())
75 .collect();
76
77 self.systems.insert(
78 name.to_string(),
79 ImageMeta {
80 channels: num_channels,
81 buffer,
82 },
83 );
84 }
85
86 pub fn read_image(&self, name: &str) -> Image {
88 let data = self.systems.get(name).unwrap();
89 Image::from_storage(data.buffer.clone(), data.channels)
92 }
93
94 pub fn save_field(&mut self, name: &str, data: &[f64]) {
128 assert!(!self.fields.contains_key(name));
129 self.fields.insert(name.to_string(), data.to_vec());
130 }
131
132 pub fn load_field(&self, name: &str, data: &mut Vec<f64>) {
134 data.clear();
135 data.extend_from_slice(self.fields.get(name).unwrap());
136 }
137
138 pub fn read_field(&self, name: &str) -> Vec<f64> {
140 let mut result = Vec::new();
141 self.load_field(name, &mut result);
142 result
143 }
144
145 pub fn save_int_field(&mut self, name: &str, data: &[i64]) {
147 assert!(!self.int_fields.contains_key(name));
148 self.int_fields.insert(name.to_string(), data.to_vec());
149 }
150
151 pub fn load_int_field(&self, name: &str, data: &mut Vec<i64>) {
153 data.clear();
154 data.extend_from_slice(self.int_fields.get(name).unwrap());
155 }
156
157 pub fn save_meta(&mut self, name: &str, data: &str) {
159 let _ = self.meta.insert(name.to_string(), data.to_string());
160 }
161
162 pub fn load_meta(&self, name: &str, data: &mut String) {
164 data.clone_from(self.meta.get(name).unwrap())
165 }
166
167 pub fn write_meta<T: ToString>(&mut self, name: &str, data: T) {
169 self.save_meta(name, &data.to_string());
170 }
171
172 pub fn read_meta<T: FromStr>(&self, name: &str) -> Result<T, CheckpointParseError> {
174 let data = self
175 .meta
176 .get(name)
177 .ok_or(CheckpointParseError::InvalidKey)?;
178
179 data.parse()
180 .map_err(|_| CheckpointParseError::ParseFailed(data.clone()))
181 }
182
183 pub fn import_dat(path: impl AsRef<Path>) -> std::io::Result<Self> {
185 let mut contents: String = String::new();
186 let mut file = File::open(path)?;
187 file.read_to_string(&mut contents)?;
188
189 ron::from_str(&contents).map_err(std::io::Error::other)
190 }
191
192 pub fn export_dat(&self, path: impl AsRef<Path>) -> std::io::Result<()> {
194 let data = ron::ser::to_string_pretty::<Checkpoint<N>>(self, PrettyConfig::default())
195 .map_err(std::io::Error::other)?;
196 let mut file = File::create(path)?;
197 file.write_all(data.as_bytes())
198 }
199}
200
201#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
203pub struct ImageMeta {
204 pub channels: usize,
205 pub buffer: Vec<f64>,
206}
207
208#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
209pub struct Embedding {
210 dimension: usize,
211 positions: Vec<f64>,
212}
213
214use num_traits::ToPrimitive;
215use vtkio::{
216 IOBuffer, Vtk,
217 model::{
218 Attribute, Attributes, ByteOrder, CellType, Cells, DataArrayBase, DataSet, ElementType,
219 Piece, UnstructuredGridPiece, Version, VertexNumbers,
220 },
221};
222
223#[derive(Clone, Debug)]
224pub struct ExportVtuConfig {
225 pub title: String,
226 pub ghost: bool,
227 pub stride: ExportStride,
228}
229
230impl Default for ExportVtuConfig {
231 fn default() -> Self {
232 Self {
233 title: "Title".to_string(),
234 ghost: false,
235 stride: ExportStride::PerVertex,
236 }
237 }
238}
239
240#[derive(Serialize, Deserialize, Debug, Clone, Copy, Encode, Decode)]
241pub enum ExportStride {
242 #[serde(rename = "per_vertex")]
244 PerVertex,
245 #[serde(rename = "per_cell")]
248 PerCell,
249}
250
251impl<const N: usize> Checkpoint<N> {
252 pub fn export_csv(&self, path: impl AsRef<Path>, stride: ExportStride) -> std::io::Result<()> {
253 let mesh: Mesh<N> = self.mesh.clone().unwrap().into();
254 let stride = match stride {
255 ExportStride::PerVertex => 1,
256 ExportStride::PerCell => mesh.width,
257 };
258
259 let mut wtr = csv::Writer::from_path(path)?;
260 let mut header = (0..N)
261 .into_iter()
262 .map(|i| format!("Coord{i}"))
263 .collect::<Vec<String>>();
264
265 for field in self.fields.keys() {
266 header.push(field.clone());
267 }
268
269 wtr.write_record(header.iter())?;
270
271 let mut buffer = String::new();
272
273 for block in mesh.blocks.indices() {
274 let space = mesh.block_space(block);
275 let nodes = mesh.block_nodes(block);
276 let window = space.inner_window();
277
278 'window: for node in window {
279 for axis in 0..N {
280 if node[axis] % (stride as isize) != 0 {
281 continue 'window;
282 }
283 }
284
285 let index = space.index_from_node(node);
286 let position = space.position(node);
287
288 for i in 0..N {
289 buffer.clear();
290 write!(&mut buffer, "{}", position[i]).unwrap();
291 wtr.write_field(&buffer)?;
292 }
293
294 for (i, (name, data)) in self.fields.iter().enumerate() {
295 debug_assert_eq!(&header[i + N], name);
296
297 let value = data[nodes.clone()][index];
298 buffer.clear();
299 write!(&mut buffer, "{}", value).unwrap();
300 wtr.write_field(&buffer)?;
301 }
302
303 wtr.write_record::<&[String], &String>(&[])?;
304 }
305 }
306
307 Ok(())
308 }
309
310 pub fn export_vtu(
313 &self,
314 path: impl AsRef<Path>,
315 config: ExportVtuConfig,
316 ) -> std::io::Result<()> {
317 const {
318 assert!(N > 0 && N <= 2, "Vtu Output only supported for 0 < N ≤ 2");
319 }
320 assert!(self.mesh.is_some(), "Mesh must be attached to checkpoint");
321
322 let mesh: Mesh<N> = self.mesh.clone().unwrap().into();
324
325 let stride = match config.stride {
326 ExportStride::PerVertex => 1,
327 ExportStride::PerCell => mesh.width,
328 };
329
330 assert!(stride <= mesh.width, "Stride must be <= width");
331 assert!(
332 mesh.width % stride == 0,
333 "Width must be evenly divided by stride"
334 );
335 assert!(
336 !config.ghost || mesh.ghost % stride == 0,
337 "Ghost must be evenly divided by stride"
338 );
339
340 let cells = Self::mesh_cells(&mesh, config.ghost, stride);
342 let points = match self.embedding {
344 Some(ref embedding) => {
345 Self::mesh_points_embedded(&mesh, embedding, config.ghost, stride)
346 }
347 None => Self::mesh_points(&mesh, config.ghost, stride),
348 };
349 let mut attributes = Attributes {
351 point: Vec::new(),
352 cell: Vec::new(),
353 };
354
355 for (name, system) in self.fields.iter() {
371 attributes.point.push(Self::field_attribute(
372 &mesh,
373 format!("Field::{}", name),
374 system,
375 config.ghost,
376 stride,
377 ));
378 }
379
380 for (name, system) in self.int_fields.iter() {
381 attributes.point.push(Self::field_attribute(
382 &mesh,
383 format!("IntField::{}", name),
384 system,
385 config.ghost,
386 stride,
387 ));
388 }
389
390 let mut pieces = Vec::new();
391 pieces.push(Piece::Inline(Box::new(UnstructuredGridPiece {
393 points,
394 cells,
395 data: attributes,
396 })));
397
398 let model = Vtk {
399 version: Version::XML { major: 2, minor: 2 },
400 title: config.title,
401 byte_order: ByteOrder::LittleEndian,
402 data: DataSet::UnstructuredGrid { meta: None, pieces },
403 file_path: None,
404 };
405
406 model.export(path).map_err(|i| match i {
407 vtkio::Error::IO(io) => io,
408 v => {
409 log::error!("Encountered error {:?} while exporting vtu", v);
410 std::io::Error::from(std::io::ErrorKind::Other)
411 }
412 })?;
413
414 Ok(())
415 }
416
417 fn mesh_cells(mesh: &Mesh<N>, ghost: bool, stride: usize) -> Cells {
418 let mut connectivity = Vec::new();
419 let mut offsets = Vec::new();
420
421 let mut vertex_total = 0;
422 let mut cell_total = 0;
423
424 for block in mesh.blocks.indices() {
425 let space = mesh.block_space(block);
426
427 let mut cell_size = space.cell_size();
428 let mut vertex_size = space.vertex_size();
429
430 if ghost {
431 for axis in 0..N {
432 cell_size[axis] += 2 * space.ghost();
433 vertex_size[axis] += 2 * space.ghost();
434 }
435 }
436
437 for axis in 0..N {
438 debug_assert!(cell_size[axis] % stride == 0);
439 debug_assert!((vertex_size[axis] - 1) % stride == 0);
440
441 cell_size[axis] /= stride;
442 vertex_size[axis] = (vertex_size[axis] - 1) / stride + 1;
443 }
444
445 let cell_space = IndexSpace::new(cell_size);
446 let vertex_space = IndexSpace::new(vertex_size);
447
448 for cell in cell_space.iter() {
449 let mut vertex = [0; N];
450
451 if N == 1 {
452 vertex[0] = cell[0];
453 let v1 = vertex_space.linear_from_cartesian(vertex);
454 vertex[0] = cell[0] + 1;
455 let v2 = vertex_space.linear_from_cartesian(vertex);
456
457 connectivity.push(vertex_total + v1 as u64);
458 connectivity.push(vertex_total + v2 as u64);
459 } else if N == 2 {
460 vertex[0] = cell[0];
461 vertex[1] = cell[1];
462 let v1 = vertex_space.linear_from_cartesian(vertex);
463 vertex[0] = cell[0];
464 vertex[1] = cell[1] + 1;
465 let v2 = vertex_space.linear_from_cartesian(vertex);
466 vertex[0] = cell[0] + 1;
467 vertex[1] = cell[1] + 1;
468 let v3 = vertex_space.linear_from_cartesian(vertex);
469 vertex[0] = cell[0] + 1;
470 vertex[1] = cell[1];
471 let v4 = vertex_space.linear_from_cartesian(vertex);
472
473 connectivity.push(vertex_total + v1 as u64);
474 connectivity.push(vertex_total + v2 as u64);
475 connectivity.push(vertex_total + v3 as u64);
476 connectivity.push(vertex_total + v4 as u64);
477 }
478
479 offsets.push(connectivity.len() as u64);
480 }
481
482 cell_total += cell_space.index_count();
483 vertex_total += vertex_space.index_count() as u64;
484 }
485
486 let cell_type = match N {
487 1 => CellType::Line,
488 2 => CellType::Quad,
489 _ => panic!("Unsupported dimension"),
491 };
492
493 Cells {
494 cell_verts: VertexNumbers::XML {
495 connectivity,
496 offsets,
497 },
498 types: vec![cell_type; cell_total],
499 }
500 }
501
502 fn mesh_points(mesh: &Mesh<N>, ghost: bool, stride: usize) -> IOBuffer {
503 let mut vertices = Vec::new();
505
506 for block in mesh.blocks.indices() {
507 let space = mesh.block_space(block);
508 let window = if ghost {
509 space.full_window()
510 } else {
511 space.inner_window()
512 };
513
514 'window: for node in window {
515 for axis in 0..N {
516 if node[axis] % (stride as isize) != 0 {
517 continue 'window;
518 }
519 }
520
521 let position = space.position(node);
522 let mut vertex = [0.0; 3];
523 vertex[..N].copy_from_slice(&position);
524 vertices.extend(vertex);
525 }
526 }
527
528 IOBuffer::new(vertices)
529 }
530
531 fn mesh_points_embedded(
532 mesh: &Mesh<N>,
533 embedding: &Embedding,
534 ghost: bool,
535 stride: usize,
536 ) -> IOBuffer {
537 let dim = embedding.dimension;
538 assert!(dim <= 3);
539
540 let mut vertices = Vec::new();
542
543 for block in mesh.blocks.indices() {
544 let space = mesh.block_space(block);
545 let nodes = mesh.block_nodes(block);
546 let window = if ghost {
547 space.full_window()
548 } else {
549 space.inner_window()
550 };
551
552 let block_positions = &embedding.positions[nodes.start * dim..nodes.end * dim];
553
554 'window: for node in window {
555 let index = space.index_from_node(node);
556 let position = &block_positions[index * dim..(index + 1) * dim];
557
558 for axis in 0..N {
559 if node[axis] % (stride as isize) != 0 {
560 continue 'window;
561 }
562 }
563
564 let mut vertex = [0.0; 3];
565 vertex[..dim].copy_from_slice(&position);
566 vertices.extend(vertex);
567 }
568 }
569
570 IOBuffer::new(vertices)
571 }
572
573 fn field_attribute<T: ToPrimitive + Copy + 'static>(
574 mesh: &Mesh<N>,
575 name: String,
576 data: &[T],
577 ghost: bool,
578 stride: usize,
579 ) -> Attribute {
580 let mut buffer = Vec::new();
581
582 for block in mesh.blocks.indices() {
583 let space = mesh.block_space(block);
584 let nodes = mesh.block_nodes(block);
585 let window = if ghost {
586 space.full_window()
587 } else {
588 space.inner_window()
589 };
590
591 'window: for node in window {
592 for axis in 0..N {
593 if node[axis] % (stride as isize) != 0 {
594 continue 'window;
595 }
596 }
597
598 let index = space.index_from_node(node);
599 let value = data[nodes.clone()][index];
600 buffer.push(value);
601 }
602 }
603
604 Attribute::DataArray(DataArrayBase {
605 name,
606 elem: ElementType::Scalars {
607 num_comp: 1,
608 lookup_table: None,
609 },
610 data: IOBuffer::new(buffer),
611 })
612 }
613}