1#![allow(dead_code)]
12#![allow(clippy::too_many_arguments)]
13
/// Four-byte header that opens every native binary restart file (ASCII `"OXRS"`).
const BINARY_MAGIC: [u8; 4] = *b"OXRS";

/// Layout version written as a little-endian `u32` right after [`BINARY_MAGIC`].
const FORMAT_VERSION: u32 = 1;
21
/// On-disk encodings understood by the restart writer.
///
/// `Hdf5Like` and `MessagePack` share the native binary layout and differ from
/// `Binary` only in their four-byte magic header.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RestartFormat {
    /// Native little-endian binary layout (magic `"OXRS"`).
    Binary,
    /// Human-readable keyword / `BEGIN_…`/`END_…` section text.
    Ascii,
    /// Native binary layout with magic `"OXH5"`; not an actual HDF5 file.
    Hdf5Like,
    /// JSON text document.
    Json,
    /// Native binary layout with magic `"MPK1"`; not actual MessagePack.
    MessagePack,
}
38
/// Descriptive header stored alongside every restart snapshot.
#[derive(Clone, Debug, PartialEq)]
pub struct RestartMetadata {
    /// Free-form version string.
    pub version: String,
    /// Timestamp stored verbatim (units are not fixed by this type).
    pub timestamp: u64,
    /// Simulation step at which the snapshot was taken.
    pub step: u64,
    /// Simulation time at which the snapshot was taken.
    pub time: f64,
    /// Name of the crate that produced the snapshot.
    pub crate_name: String,
    /// Human-readable description.
    pub description: String,
}

impl RestartMetadata {
    /// Assembles metadata from its parts, converting the string-like arguments
    /// into owned `String`s.
    pub fn new(
        version: impl Into<String>,
        timestamp: u64,
        step: u64,
        time: f64,
        crate_name: impl Into<String>,
        description: impl Into<String>,
    ) -> Self {
        let version: String = version.into();
        let crate_name: String = crate_name.into();
        let description: String = description.into();
        RestartMetadata {
            version,
            timestamp,
            step,
            time,
            crate_name,
            description,
        }
    }

    /// Fixture used by tests: version `"1.0"`, timestamp/step 0, time 0.0,
    /// crate `"oxiphysics"`, description `"test checkpoint"`.
    pub fn default_test() -> Self {
        RestartMetadata::new("1.0", 0, 0, 0.0, "oxiphysics", "test checkpoint")
    }
}
83
84#[derive(Debug, Clone, PartialEq)]
88pub struct RestartData {
89 pub metadata: RestartMetadata,
91 pub positions: Vec<[f64; 3]>,
93 pub velocities: Vec<[f64; 3]>,
95 pub forces: Vec<[f64; 3]>,
97 pub masses: Vec<f64>,
99 pub types: Vec<u32>,
101 pub box_matrix: [[f64; 3]; 3],
103 pub extra_scalars: Vec<(String, Vec<f64>)>,
105 pub extra_vectors: Vec<(String, Vec<[f64; 3]>)>,
107}
108
109impl RestartData {
110 pub fn n_particles(&self) -> usize {
112 self.positions.len()
113 }
114
115 pub fn empty(metadata: RestartMetadata) -> Self {
117 Self {
118 metadata,
119 positions: Vec::new(),
120 velocities: Vec::new(),
121 forces: Vec::new(),
122 masses: Vec::new(),
123 types: Vec::new(),
124 box_matrix: [[0.0; 3]; 3],
125 extra_scalars: Vec::new(),
126 extra_vectors: Vec::new(),
127 }
128 }
129
130 pub fn single_particle_test() -> Self {
132 let meta = RestartMetadata::default_test();
133 let mut d = Self::empty(meta);
134 d.positions = vec![[1.0, 2.0, 3.0]];
135 d.velocities = vec![[0.1, 0.2, 0.3]];
136 d.forces = vec![[0.0, -9.81, 0.0]];
137 d.masses = vec![1.0];
138 d.types = vec![0];
139 d.box_matrix = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]];
140 d
141 }
142}
143
/// Appends `v` to `buf` as 4 little-endian bytes.
fn encode_u32(buf: &mut Vec<u8>, v: u32) {
    buf.extend_from_slice(&v.to_le_bytes());
}

/// Appends `v` to `buf` as 8 little-endian bytes.
fn encode_u64(buf: &mut Vec<u8>, v: u64) {
    buf.extend_from_slice(&v.to_le_bytes());
}

/// Appends the IEEE-754 bit pattern of `v` as 8 little-endian bytes
/// (bit-exact, including NaN payloads and signed zero).
fn encode_f64(buf: &mut Vec<u8>, v: f64) {
    buf.extend_from_slice(&v.to_le_bytes());
}

/// Appends a length-prefixed UTF-8 string: a `u32` byte count followed by the bytes.
///
/// # Panics
/// Panics if `s` is longer than `u32::MAX` bytes. The length prefix cannot
/// represent such a string, and truncating it (as an `as u32` cast would) would
/// silently produce a file that decodes incorrectly.
fn encode_str(buf: &mut Vec<u8>, s: &str) {
    let bytes = s.as_bytes();
    let len = u32::try_from(bytes.len())
        .expect("string longer than u32::MAX bytes cannot be length-prefixed");
    encode_u32(buf, len);
    buf.extend_from_slice(bytes);
}

/// Appends the three components of `v` as consecutive little-endian `f64`s.
fn encode_vec3(buf: &mut Vec<u8>, v: &[f64; 3]) {
    for &component in v {
        encode_f64(buf, component);
    }
}
169
/// Returns the `N` bytes starting at `*offset` and advances `*offset` past them.
///
/// Fails without moving `*offset` when fewer than `N` bytes remain. The bound
/// is computed with `checked_add`, so an offset near `usize::MAX` yields an
/// error instead of an arithmetic-overflow panic or a wrapped (passing) check.
fn take_bytes<const N: usize>(data: &[u8], offset: &mut usize) -> Result<[u8; N], String> {
    let start = *offset;
    let chunk = start
        .checked_add(N)
        .and_then(|end| data.get(start..end))
        .and_then(|slice| <[u8; N]>::try_from(slice).ok())
        .ok_or_else(|| format!("unexpected EOF at offset {start}"))?;
    // Cannot overflow: `checked_add` above succeeded.
    *offset = start + N;
    Ok(chunk)
}

/// Reads a little-endian `u32` at `*offset` and advances the offset by 4.
fn decode_u32(data: &[u8], offset: &mut usize) -> Result<u32, String> {
    take_bytes::<4>(data, offset).map(u32::from_le_bytes)
}

/// Reads a little-endian `u64` at `*offset` and advances the offset by 8.
fn decode_u64(data: &[u8], offset: &mut usize) -> Result<u64, String> {
    take_bytes::<8>(data, offset).map(u64::from_le_bytes)
}

/// Reads a little-endian IEEE-754 `f64` at `*offset` and advances the offset by 8.
fn decode_f64(data: &[u8], offset: &mut usize) -> Result<f64, String> {
    take_bytes::<8>(data, offset).map(f64::from_le_bytes)
}

/// Reads a `u32`-length-prefixed UTF-8 string and advances past it.
///
/// # Errors
/// Fails on EOF in the prefix or the payload, or on invalid UTF-8. After a
/// payload error the offset has already moved past the 4-byte prefix.
fn decode_str(data: &[u8], offset: &mut usize) -> Result<String, String> {
    let len = usize::try_from(decode_u32(data, offset)?)
        .map_err(|_| format!("string extends past EOF at offset {}", *offset))?;
    let start = *offset;
    // Overflow-checked so a huge length cannot wrap past the bounds check on
    // 32-bit targets.
    let raw = start
        .checked_add(len)
        .and_then(|end| data.get(start..end))
        .ok_or_else(|| format!("string extends past EOF at offset {start}"))?;
    let s = std::str::from_utf8(raw)
        .map_err(|e| format!("UTF-8 error: {e}"))?
        .to_owned();
    *offset = start + len;
    Ok(s)
}

/// Reads three consecutive `f64`s as an `[x, y, z]` vector.
fn decode_vec3(data: &[u8], offset: &mut usize) -> Result<[f64; 3], String> {
    let x = decode_f64(data, offset)?;
    let y = decode_f64(data, offset)?;
    let z = decode_f64(data, offset)?;
    Ok([x, y, z])
}
227
228#[derive(Debug, Clone)]
232pub struct RestartWriter {
233 pub path: String,
235 pub format: RestartFormat,
237}
238
239impl RestartWriter {
240 pub fn new(path: &str, format: RestartFormat) -> Self {
242 Self {
243 path: path.to_string(),
244 format,
245 }
246 }
247
248 pub fn write(&self, data: &RestartData) -> Result<(), String> {
252 let bytes: Vec<u8> = match &self.format {
253 RestartFormat::Binary => Self::write_binary(data),
254 RestartFormat::Ascii => Self::write_ascii(data).into_bytes(),
255 RestartFormat::Json => Self::write_json(data).into_bytes(),
256 RestartFormat::Hdf5Like => Self::write_hdf5like(data),
257 RestartFormat::MessagePack => Self::write_msgpack(data),
258 };
259 std::fs::write(&self.path, &bytes)
260 .map_err(|e| format!("failed to write '{}': {e}", self.path))
261 }
262
263 pub fn write_binary(data: &RestartData) -> Vec<u8> {
269 let mut buf = Vec::new();
270 buf.extend_from_slice(&BINARY_MAGIC);
271 encode_u32(&mut buf, FORMAT_VERSION);
272 encode_str(&mut buf, &data.metadata.version);
274 encode_u64(&mut buf, data.metadata.timestamp);
275 encode_u64(&mut buf, data.metadata.step);
276 encode_f64(&mut buf, data.metadata.time);
277 encode_str(&mut buf, &data.metadata.crate_name);
278 encode_str(&mut buf, &data.metadata.description);
279 let n = data.n_particles() as u64;
281 encode_u64(&mut buf, n);
282 for p in &data.positions {
283 encode_vec3(&mut buf, p);
284 }
285 for v in &data.velocities {
286 encode_vec3(&mut buf, v);
287 }
288 for f in &data.forces {
289 encode_vec3(&mut buf, f);
290 }
291 for m in &data.masses {
292 encode_f64(&mut buf, *m);
293 }
294 for t in &data.types {
295 encode_u32(&mut buf, *t);
296 }
297 for row in &data.box_matrix {
299 for &c in row {
300 encode_f64(&mut buf, c);
301 }
302 }
303 encode_u32(&mut buf, data.extra_scalars.len() as u32);
305 for (name, vals) in &data.extra_scalars {
306 encode_str(&mut buf, name);
307 encode_u64(&mut buf, vals.len() as u64);
308 for &v in vals {
309 encode_f64(&mut buf, v);
310 }
311 }
312 encode_u32(&mut buf, data.extra_vectors.len() as u32);
314 for (name, vecs) in &data.extra_vectors {
315 encode_str(&mut buf, name);
316 encode_u64(&mut buf, vecs.len() as u64);
317 for v in vecs {
318 encode_vec3(&mut buf, v);
319 }
320 }
321 buf
322 }
323
324 pub fn write_ascii(data: &RestartData) -> String {
326 let mut s = String::new();
327 s.push_str("# OxiPhysics restart file\n");
328 s.push_str(&format!("VERSION {}\n", data.metadata.version));
329 s.push_str(&format!("TIMESTAMP {}\n", data.metadata.timestamp));
330 s.push_str(&format!("STEP {}\n", data.metadata.step));
331 s.push_str(&format!("TIME {:.6}\n", data.metadata.time));
332 s.push_str(&format!("CRATE {}\n", data.metadata.crate_name));
333 s.push_str(&format!("DESC {}\n", data.metadata.description));
334 let n = data.n_particles();
335 s.push_str(&format!("N_PARTICLES {n}\n"));
336 s.push_str("BEGIN_POSITIONS\n");
337 for p in &data.positions {
338 s.push_str(&format!("{:.6} {:.6} {:.6}\n", p[0], p[1], p[2]));
339 }
340 s.push_str("END_POSITIONS\n");
341 s.push_str("BEGIN_VELOCITIES\n");
342 for v in &data.velocities {
343 s.push_str(&format!("{:.6} {:.6} {:.6}\n", v[0], v[1], v[2]));
344 }
345 s.push_str("END_VELOCITIES\n");
346 s.push_str("BEGIN_FORCES\n");
347 for f in &data.forces {
348 s.push_str(&format!("{:.6} {:.6} {:.6}\n", f[0], f[1], f[2]));
349 }
350 s.push_str("END_FORCES\n");
351 s.push_str("BEGIN_MASSES\n");
352 for m in &data.masses {
353 s.push_str(&format!("{:.6}\n", m));
354 }
355 s.push_str("END_MASSES\n");
356 s.push_str("BEGIN_TYPES\n");
357 for t in &data.types {
358 s.push_str(&format!("{t}\n"));
359 }
360 s.push_str("END_TYPES\n");
361 s.push_str("BEGIN_BOX\n");
362 for row in &data.box_matrix {
363 s.push_str(&format!("{:.6} {:.6} {:.6}\n", row[0], row[1], row[2]));
364 }
365 s.push_str("END_BOX\n");
366 s
367 }
368
369 pub fn write_json(data: &RestartData) -> String {
371 let mut s = String::new();
372 s.push_str("{\n");
373 s.push_str(&format!(
374 " \"version\": \"{}\",\n \"timestamp\": {},\n \"step\": {},\n \"time\": {:.6},\n \"crate\": \"{}\",\n \"description\": \"{}\",\n",
375 data.metadata.version, data.metadata.timestamp, data.metadata.step,
376 data.metadata.time, data.metadata.crate_name, data.metadata.description
377 ));
378 s.push_str(&format!(" \"n_particles\": {},\n", data.n_particles()));
379 s.push_str(" \"positions\": [");
381 for (i, p) in data.positions.iter().enumerate() {
382 if i > 0 {
383 s.push(',');
384 }
385 s.push_str(&format!("[{:.6},{:.6},{:.6}]", p[0], p[1], p[2]));
386 }
387 s.push_str("],\n");
388 s.push_str(" \"velocities\": [");
390 for (i, v) in data.velocities.iter().enumerate() {
391 if i > 0 {
392 s.push(',');
393 }
394 s.push_str(&format!("[{:.6},{:.6},{:.6}]", v[0], v[1], v[2]));
395 }
396 s.push_str("],\n");
397 s.push_str(" \"masses\": [");
399 for (i, m) in data.masses.iter().enumerate() {
400 if i > 0 {
401 s.push(',');
402 }
403 s.push_str(&format!("{:.6}", m));
404 }
405 s.push_str("],\n");
406 s.push_str(" \"types\": [");
408 for (i, t) in data.types.iter().enumerate() {
409 if i > 0 {
410 s.push(',');
411 }
412 s.push_str(&format!("{t}"));
413 }
414 s.push_str("]\n}\n");
415 s
416 }
417
418 pub fn write_hdf5like(data: &RestartData) -> Vec<u8> {
420 let mut buf = Self::write_binary(data);
422 buf[0] = 0x4F; buf[1] = 0x58; buf[2] = 0x48; buf[3] = 0x35; buf
428 }
429
430 pub fn write_msgpack(data: &RestartData) -> Vec<u8> {
432 let mut buf = Self::write_binary(data);
434 buf[0] = 0x4D; buf[1] = 0x50; buf[2] = 0x4B; buf[3] = 0x31; buf
439 }
440}
441
/// Loads [`RestartData`] from a file, detecting its encoding from the leading bytes.
#[derive(Clone, Debug)]
pub struct RestartReader {
    /// Source file path.
    pub path: String,
}
450
451impl RestartReader {
452 pub fn new(path: &str) -> Self {
454 Self {
455 path: path.to_string(),
456 }
457 }
458
459 pub fn read(&self) -> Result<RestartData, String> {
461 let bytes = std::fs::read(&self.path)
462 .map_err(|e| format!("failed to read '{}': {e}", self.path))?;
463 let fmt = Self::detect_format(&bytes);
464 match fmt {
465 RestartFormat::Ascii => {
466 let text = std::str::from_utf8(&bytes).map_err(|e| format!("UTF-8 error: {e}"))?;
467 Self::read_ascii(text)
468 }
469 _ => Self::read_binary(&bytes),
470 }
471 }
472
473 pub fn detect_format(bytes: &[u8]) -> RestartFormat {
475 if bytes.len() < 4 {
476 return RestartFormat::Ascii;
477 }
478 match &bytes[0..4] {
479 b"OXRS" => RestartFormat::Binary,
480 b"OXHS" | [0x4F, 0x58, 0x48, 0x35] => RestartFormat::Hdf5Like,
481 [0x4D, 0x50, 0x4B, 0x31] => RestartFormat::MessagePack,
482 b"# Ox" | b"VERS" => RestartFormat::Ascii,
483 _ if bytes.starts_with(b"{") => RestartFormat::Json,
484 _ => RestartFormat::Ascii,
485 }
486 }
487
488 pub fn read_binary(bytes: &[u8]) -> Result<RestartData, String> {
490 let mut off = 0usize;
491 if bytes.len() < 8 {
493 return Err("binary too short".into());
494 }
495 off += 4; let _file_version = decode_u32(bytes, &mut off)?;
497 let version = decode_str(bytes, &mut off)?;
499 let timestamp = decode_u64(bytes, &mut off)?;
500 let step = decode_u64(bytes, &mut off)?;
501 let time = decode_f64(bytes, &mut off)?;
502 let crate_name = decode_str(bytes, &mut off)?;
503 let description = decode_str(bytes, &mut off)?;
504 let metadata = RestartMetadata {
505 version,
506 timestamp,
507 step,
508 time,
509 crate_name,
510 description,
511 };
512 let n = decode_u64(bytes, &mut off)? as usize;
514 let mut positions = Vec::with_capacity(n);
515 for _ in 0..n {
516 positions.push(decode_vec3(bytes, &mut off)?);
517 }
518 let mut velocities = Vec::with_capacity(n);
519 for _ in 0..n {
520 velocities.push(decode_vec3(bytes, &mut off)?);
521 }
522 let mut forces = Vec::with_capacity(n);
523 for _ in 0..n {
524 forces.push(decode_vec3(bytes, &mut off)?);
525 }
526 let mut masses = Vec::with_capacity(n);
527 for _ in 0..n {
528 masses.push(decode_f64(bytes, &mut off)?);
529 }
530 let mut types = Vec::with_capacity(n);
531 for _ in 0..n {
532 types.push(decode_u32(bytes, &mut off)?);
533 }
534 let mut box_matrix = [[0.0f64; 3]; 3];
536 for row in &mut box_matrix {
537 for c in row.iter_mut() {
538 *c = decode_f64(bytes, &mut off)?;
539 }
540 }
541 let n_es = decode_u32(bytes, &mut off)? as usize;
543 let mut extra_scalars = Vec::with_capacity(n_es);
544 for _ in 0..n_es {
545 let name = decode_str(bytes, &mut off)?;
546 let count = decode_u64(bytes, &mut off)? as usize;
547 let mut vals = Vec::with_capacity(count);
548 for _ in 0..count {
549 vals.push(decode_f64(bytes, &mut off)?);
550 }
551 extra_scalars.push((name, vals));
552 }
553 let n_ev = decode_u32(bytes, &mut off)? as usize;
555 let mut extra_vectors = Vec::with_capacity(n_ev);
556 for _ in 0..n_ev {
557 let name = decode_str(bytes, &mut off)?;
558 let count = decode_u64(bytes, &mut off)? as usize;
559 let mut vecs = Vec::with_capacity(count);
560 for _ in 0..count {
561 vecs.push(decode_vec3(bytes, &mut off)?);
562 }
563 extra_vectors.push((name, vecs));
564 }
565 Ok(RestartData {
566 metadata,
567 positions,
568 velocities,
569 forces,
570 masses,
571 types,
572 box_matrix,
573 extra_scalars,
574 extra_vectors,
575 })
576 }
577
578 pub fn read_ascii(text: &str) -> Result<RestartData, String> {
580 let mut version = String::new();
581 let mut timestamp = 0u64;
582 let mut step = 0u64;
583 let mut time = 0.0f64;
584 let mut crate_name = String::new();
585 let mut description = String::new();
586 let mut positions: Vec<[f64; 3]> = Vec::new();
587 let mut velocities: Vec<[f64; 3]> = Vec::new();
588 let mut forces: Vec<[f64; 3]> = Vec::new();
589 let mut masses: Vec<f64> = Vec::new();
590 let mut types: Vec<u32> = Vec::new();
591 let mut box_matrix = [[0.0f64; 3]; 3];
592
593 #[derive(PartialEq)]
594 enum Section {
595 None,
596 Positions,
597 Velocities,
598 Forces,
599 Masses,
600 Types,
601 Box,
602 }
603 let mut section = Section::None;
604 let mut box_row = 0usize;
605
606 for line in text.lines() {
607 let line = line.trim();
608 if line.is_empty() || line.starts_with('#') {
609 continue;
610 }
611 match line {
612 "BEGIN_POSITIONS" => {
613 section = Section::Positions;
614 continue;
615 }
616 "END_POSITIONS" => {
617 section = Section::None;
618 continue;
619 }
620 "BEGIN_VELOCITIES" => {
621 section = Section::Velocities;
622 continue;
623 }
624 "END_VELOCITIES" => {
625 section = Section::None;
626 continue;
627 }
628 "BEGIN_FORCES" => {
629 section = Section::Forces;
630 continue;
631 }
632 "END_FORCES" => {
633 section = Section::None;
634 continue;
635 }
636 "BEGIN_MASSES" => {
637 section = Section::Masses;
638 continue;
639 }
640 "END_MASSES" => {
641 section = Section::None;
642 continue;
643 }
644 "BEGIN_TYPES" => {
645 section = Section::Types;
646 continue;
647 }
648 "END_TYPES" => {
649 section = Section::None;
650 continue;
651 }
652 "BEGIN_BOX" => {
653 section = Section::Box;
654 box_row = 0;
655 continue;
656 }
657 "END_BOX" => {
658 section = Section::None;
659 continue;
660 }
661 _ => {}
662 }
663 match section {
664 Section::Positions | Section::Velocities | Section::Forces => {
665 let nums: Vec<f64> = line
666 .split_whitespace()
667 .map(|s| s.parse::<f64>().unwrap_or(0.0))
668 .collect();
669 if nums.len() >= 3 {
670 let arr = [nums[0], nums[1], nums[2]];
671 match section {
672 Section::Positions => positions.push(arr),
673 Section::Velocities => velocities.push(arr),
674 Section::Forces => forces.push(arr),
675 _ => {}
676 }
677 }
678 }
679 Section::Masses => {
680 if let Ok(m) = line.parse::<f64>() {
681 masses.push(m);
682 }
683 }
684 Section::Types => {
685 if let Ok(t) = line.parse::<u32>() {
686 types.push(t);
687 }
688 }
689 Section::Box => {
690 if box_row < 3 {
691 let nums: Vec<f64> = line
692 .split_whitespace()
693 .map(|s| s.parse::<f64>().unwrap_or(0.0))
694 .collect();
695 if nums.len() >= 3 {
696 box_matrix[box_row] = [nums[0], nums[1], nums[2]];
697 box_row += 1;
698 }
699 }
700 }
701 Section::None => {
702 if let Some(rest) = line.strip_prefix("VERSION ") {
704 version = rest.trim().to_string();
705 } else if let Some(rest) = line.strip_prefix("TIMESTAMP ") {
706 timestamp = rest.trim().parse().unwrap_or(0);
707 } else if let Some(rest) = line.strip_prefix("STEP ") {
708 step = rest.trim().parse().unwrap_or(0);
709 } else if let Some(rest) = line.strip_prefix("TIME ") {
710 time = rest.trim().parse().unwrap_or(0.0);
711 } else if let Some(rest) = line.strip_prefix("CRATE ") {
712 crate_name = rest.trim().to_string();
713 } else if let Some(rest) = line.strip_prefix("DESC ") {
714 description = rest.trim().to_string();
715 }
716 }
717 }
718 }
719 let metadata = RestartMetadata {
720 version,
721 timestamp,
722 step,
723 time,
724 crate_name,
725 description,
726 };
727 Ok(RestartData {
728 metadata,
729 positions,
730 velocities,
731 forces,
732 masses,
733 types,
734 box_matrix,
735 extra_scalars: Vec::new(),
736 extra_vectors: Vec::new(),
737 })
738 }
739}
740
/// Writes numbered binary checkpoints into a directory, retaining only the
/// most recently saved `max_checkpoints` of them.
#[derive(Clone, Debug)]
pub struct CheckpointManager {
    /// Directory that checkpoint files are written into.
    pub base_dir: String,
    /// Maximum number of checkpoint files to keep.
    pub max_checkpoints: usize,
    /// Recorded `(step, path)` pairs, oldest first.
    checkpoints: Vec<(u64, String)>,
}
753
754impl CheckpointManager {
755 pub fn new(base_dir: &str, max_checkpoints: usize) -> Self {
757 Self {
758 base_dir: base_dir.to_string(),
759 max_checkpoints: max_checkpoints.max(1),
760 checkpoints: Vec::new(),
761 }
762 }
763
764 pub fn save_checkpoint(&mut self, data: &RestartData, step: u64) -> String {
768 let filename = format!("{}/checkpoint_{step:010}.bin", self.base_dir);
769 let writer = RestartWriter::new(&filename, RestartFormat::Binary);
770 let _ = std::fs::create_dir_all(&self.base_dir);
771 let _ = writer.write(data);
772 self.checkpoints.push((step, filename.clone()));
773 self.prune_old_checkpoints();
774 filename
775 }
776
777 pub fn load_latest(&self) -> Option<RestartData> {
779 let (_, path) = self.checkpoints.last()?;
780 let reader = RestartReader::new(path);
781 reader.read().ok()
782 }
783
784 pub fn list_checkpoints(&self) -> Vec<(u64, String)> {
786 self.checkpoints.clone()
787 }
788
789 pub fn prune_old_checkpoints(&mut self) {
791 while self.checkpoints.len() > self.max_checkpoints {
792 let (_, path) = self.checkpoints.remove(0);
793 let _ = std::fs::remove_file(&path);
794 }
795 }
796}
797
798#[derive(Debug, Clone, Default)]
805pub struct IncrementalRestart {
806 pub changed: Vec<usize>,
808}
809
810impl IncrementalRestart {
811 pub fn new() -> Self {
813 Self::default()
814 }
815
816 pub fn mark_changed(&mut self, idx: usize) {
818 if !self.changed.contains(&idx) {
819 self.changed.push(idx);
820 }
821 }
822
823 pub fn reset(&mut self) {
825 self.changed.clear();
826 }
827
828 pub fn extract_delta(&self, full: &RestartData) -> RestartData {
832 let indices = &self.changed;
833 let positions = indices
834 .iter()
835 .filter_map(|&i| full.positions.get(i).copied())
836 .collect();
837 let velocities = indices
838 .iter()
839 .filter_map(|&i| full.velocities.get(i).copied())
840 .collect();
841 let forces = indices
842 .iter()
843 .filter_map(|&i| full.forces.get(i).copied())
844 .collect();
845 let masses = indices
846 .iter()
847 .filter_map(|&i| full.masses.get(i).copied())
848 .collect();
849 let types = indices
850 .iter()
851 .filter_map(|&i| full.types.get(i).copied())
852 .collect();
853 RestartData {
854 metadata: full.metadata.clone(),
855 positions,
856 velocities,
857 forces,
858 masses,
859 types,
860 box_matrix: full.box_matrix,
861 extra_scalars: Vec::new(),
862 extra_vectors: Vec::new(),
863 }
864 }
865}
866
/// Stateless helpers computing simple byte checksums of restart data.
#[derive(Clone, Debug, Default)]
pub struct RestartValidator;

impl RestartValidator {
    /// Creates the (field-less) validator.
    pub fn new() -> Self {
        RestartValidator
    }

    /// Sum of all byte values as a `u64`.
    pub fn checksum_sum(bytes: &[u8]) -> u64 {
        bytes.iter().fold(0u64, |acc, &b| acc + u64::from(b))
    }

    /// XOR of all bytes (0 for empty input).
    pub fn checksum_xor(bytes: &[u8]) -> u8 {
        let mut acc = 0u8;
        for &b in bytes {
            acc ^= b;
        }
        acc
    }

    /// Whether the byte sum of `bytes` equals `expected_sum`.
    pub fn verify_sum(bytes: &[u8], expected_sum: u64) -> bool {
        expected_sum == Self::checksum_sum(bytes)
    }

    /// Whether the XOR of `bytes` equals `expected_xor`.
    pub fn verify_xor(bytes: &[u8], expected_xor: u8) -> bool {
        expected_xor == Self::checksum_xor(bytes)
    }
}
898}
899
900pub fn restart_to_xyz(data: &RestartData) -> String {
906 let n = data.n_particles();
907 let mut s = String::new();
908 s.push_str(&format!("{n}\n"));
909 s.push_str(&format!(
910 "Restart step={} time={:.6}\n",
911 data.metadata.step, data.metadata.time
912 ));
913 for (i, p) in data.positions.iter().enumerate() {
914 let sym = if let Some(&t) = data.types.get(i) {
915 match t {
916 0 => "H",
917 1 => "C",
918 2 => "N",
919 3 => "O",
920 _ => "X",
921 }
922 } else {
923 "X"
924 };
925 s.push_str(&format!("{sym} {:.6} {:.6} {:.6}\n", p[0], p[1], p[2]));
926 }
927 s
928}
929
930pub fn restart_to_lammps_dump(data: &RestartData) -> String {
934 let n = data.n_particles();
935 let step = data.metadata.step;
936 let mut s = String::new();
937 s.push_str(&format!("ITEM: TIMESTEP\n{step}\n"));
938 s.push_str(&format!("ITEM: NUMBER OF ATOMS\n{n}\n"));
939 let bm = &data.box_matrix;
940 s.push_str("ITEM: BOX BOUNDS pp pp pp\n");
941 s.push_str(&format!(
942 "0.0 {:.6}\n0.0 {:.6}\n0.0 {:.6}\n",
943 bm[0][0], bm[1][1], bm[2][2]
944 ));
945 s.push_str("ITEM: ATOMS id type x y z vx vy vz fx fy fz mass\n");
946 for i in 0..n {
947 let p = data.positions.get(i).copied().unwrap_or([0.0; 3]);
948 let v = data.velocities.get(i).copied().unwrap_or([0.0; 3]);
949 let f = data.forces.get(i).copied().unwrap_or([0.0; 3]);
950 let m = data.masses.get(i).copied().unwrap_or(1.0);
951 let tp = data.types.get(i).copied().unwrap_or(0);
952 s.push_str(&format!(
953 "{} {} {:.6} {:.6} {:.6} {:.6} {:.6} {:.6} {:.6} {:.6} {:.6} {:.6}\n",
954 i + 1,
955 tp,
956 p[0],
957 p[1],
958 p[2],
959 v[0],
960 v[1],
961 v[2],
962 f[0],
963 f[1],
964 f[2],
965 m
966 ));
967 }
968 s
969}
970
// Unit tests covering metadata/data constructors, binary and ASCII round-trips,
// JSON output, format detection, checksums, incremental deltas, export helpers
// and checkpoint management.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds `n` particles with deterministic positions/velocities, a constant
    // force, unit masses, types cycling 0,1,2 and a diagonal box of side 10.
    fn make_data(n: usize) -> RestartData {
        let meta = RestartMetadata::new("1.0", 1234567890, 42, 4.2, "oxiphysics-md", "unit test");
        let positions = (0..n).map(|i| [i as f64, i as f64 * 0.5, 0.0]).collect();
        let velocities = (0..n).map(|i| [0.1 * i as f64, 0.0, 0.0]).collect();
        let forces = (0..n).map(|_| [0.0, -9.81, 0.0]).collect();
        let masses = (0..n).map(|_| 1.0).collect();
        let types = (0..n).map(|i| (i % 3) as u32).collect();
        let box_matrix = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]];
        RestartData {
            metadata: meta,
            positions,
            velocities,
            forces,
            masses,
            types,
            box_matrix,
            extra_scalars: Vec::new(),
            extra_vectors: Vec::new(),
        }
    }

    // --- RestartMetadata ---

    #[test]
    fn test_metadata_new_fields() {
        let m = RestartMetadata::new("2.0", 100, 5, 0.5, "crate_x", "desc");
        assert_eq!(m.version, "2.0");
        assert_eq!(m.timestamp, 100);
        assert_eq!(m.step, 5);
        assert!((m.time - 0.5).abs() < 1e-12);
        assert_eq!(m.crate_name, "crate_x");
        assert_eq!(m.description, "desc");
    }

    #[test]
    fn test_metadata_default_test() {
        let m = RestartMetadata::default_test();
        assert_eq!(m.version, "1.0");
        assert_eq!(m.step, 0);
    }

    // --- RestartData constructors ---

    #[test]
    fn test_restart_data_n_particles() {
        let d = make_data(7);
        assert_eq!(d.n_particles(), 7);
    }

    #[test]
    fn test_restart_data_empty() {
        let d = RestartData::empty(RestartMetadata::default_test());
        assert_eq!(d.n_particles(), 0);
        assert!(d.positions.is_empty());
    }

    #[test]
    fn test_restart_data_single_particle_test() {
        let d = RestartData::single_particle_test();
        assert_eq!(d.n_particles(), 1);
        assert!((d.positions[0][0] - 1.0).abs() < 1e-12);
    }

    // --- Binary round-trips ---

    #[test]
    fn test_binary_roundtrip_zero_particles() {
        let d = make_data(0);
        let bytes = RestartWriter::write_binary(&d);
        let recovered = RestartReader::read_binary(&bytes).unwrap();
        assert_eq!(recovered.n_particles(), 0);
        assert_eq!(recovered.metadata.step, 42);
    }

    #[test]
    fn test_binary_roundtrip_positions() {
        let d = make_data(5);
        let bytes = RestartWriter::write_binary(&d);
        let r = RestartReader::read_binary(&bytes).unwrap();
        for i in 0..5 {
            assert!((r.positions[i][0] - d.positions[i][0]).abs() < 1e-12);
            assert!((r.positions[i][1] - d.positions[i][1]).abs() < 1e-12);
        }
    }

    #[test]
    fn test_binary_roundtrip_metadata() {
        let d = make_data(3);
        let bytes = RestartWriter::write_binary(&d);
        let r = RestartReader::read_binary(&bytes).unwrap();
        assert_eq!(r.metadata.version, "1.0");
        assert_eq!(r.metadata.timestamp, 1234567890);
        assert_eq!(r.metadata.step, 42);
        assert!((r.metadata.time - 4.2).abs() < 1e-10);
        assert_eq!(r.metadata.crate_name, "oxiphysics-md");
    }

    #[test]
    fn test_binary_roundtrip_box_matrix() {
        let d = make_data(2);
        let bytes = RestartWriter::write_binary(&d);
        let r = RestartReader::read_binary(&bytes).unwrap();
        assert!((r.box_matrix[0][0] - 10.0).abs() < 1e-12);
        assert!((r.box_matrix[1][1] - 10.0).abs() < 1e-12);
        assert!((r.box_matrix[2][2] - 10.0).abs() < 1e-12);
        assert!(r.box_matrix[0][1].abs() < 1e-12);
    }

    #[test]
    fn test_binary_roundtrip_types() {
        let d = make_data(6);
        let bytes = RestartWriter::write_binary(&d);
        let r = RestartReader::read_binary(&bytes).unwrap();
        for i in 0..6 {
            assert_eq!(r.types[i], (i % 3) as u32);
        }
    }

    #[test]
    fn test_binary_roundtrip_extra_scalars() {
        let mut d = make_data(3);
        d.extra_scalars
            .push(("charge".into(), vec![0.1, -0.2, 0.3]));
        let bytes = RestartWriter::write_binary(&d);
        let r = RestartReader::read_binary(&bytes).unwrap();
        assert_eq!(r.extra_scalars.len(), 1);
        assert_eq!(r.extra_scalars[0].0, "charge");
        assert!((r.extra_scalars[0].1[1] - (-0.2)).abs() < 1e-12);
    }

    #[test]
    fn test_binary_roundtrip_extra_vectors() {
        let mut d = make_data(2);
        d.extra_vectors
            .push(("spin".into(), vec![[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]));
        let bytes = RestartWriter::write_binary(&d);
        let r = RestartReader::read_binary(&bytes).unwrap();
        assert_eq!(r.extra_vectors.len(), 1);
        assert_eq!(r.extra_vectors[0].0, "spin");
        assert!((r.extra_vectors[0].1[0][2] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_binary_magic_bytes() {
        let d = make_data(1);
        let bytes = RestartWriter::write_binary(&d);
        assert_eq!(&bytes[0..4], b"OXRS");
    }

    // --- ASCII round-trips ---

    #[test]
    fn test_ascii_roundtrip_basic() {
        let d = make_data(4);
        let text = RestartWriter::write_ascii(&d);
        let r = RestartReader::read_ascii(&text).unwrap();
        assert_eq!(r.n_particles(), 4);
        assert_eq!(r.metadata.step, 42);
    }

    #[test]
    fn test_ascii_roundtrip_positions() {
        let d = make_data(3);
        let text = RestartWriter::write_ascii(&d);
        let r = RestartReader::read_ascii(&text).unwrap();
        for i in 0..3 {
            assert!((r.positions[i][0] - d.positions[i][0]).abs() < 1e-4);
        }
    }

    #[test]
    fn test_ascii_roundtrip_box() {
        let d = make_data(1);
        let text = RestartWriter::write_ascii(&d);
        let r = RestartReader::read_ascii(&text).unwrap();
        assert!((r.box_matrix[0][0] - 10.0).abs() < 1e-4);
    }

    #[test]
    fn test_ascii_contains_keywords() {
        let d = make_data(2);
        let text = RestartWriter::write_ascii(&d);
        assert!(text.contains("BEGIN_POSITIONS"));
        assert!(text.contains("END_POSITIONS"));
        assert!(text.contains("VERSION"));
        assert!(text.contains("STEP"));
    }

    // --- JSON output ---

    #[test]
    fn test_json_contains_particles_key() {
        let d = make_data(3);
        let j = RestartWriter::write_json(&d);
        assert!(j.contains("n_particles"));
        assert!(j.contains("positions"));
        assert!(j.contains("velocities"));
    }

    #[test]
    fn test_json_step_present() {
        let d = make_data(1);
        let j = RestartWriter::write_json(&d);
        assert!(j.contains("\"step\": 42"));
    }

    // --- Format detection ---

    #[test]
    fn test_detect_format_binary() {
        let d = make_data(1);
        let bytes = RestartWriter::write_binary(&d);
        assert_eq!(RestartReader::detect_format(&bytes), RestartFormat::Binary);
    }

    #[test]
    fn test_detect_format_ascii() {
        let text = "# OxiPhysics restart\nVERSION 1.0\n";
        assert_eq!(
            RestartReader::detect_format(text.as_bytes()),
            RestartFormat::Ascii
        );
    }

    #[test]
    fn test_detect_format_json() {
        let j = "{\"n_particles\": 0}";
        assert_eq!(
            RestartReader::detect_format(j.as_bytes()),
            RestartFormat::Json
        );
    }

    // --- Checksums ---

    #[test]
    fn test_validator_sum_consistent() {
        let bytes = b"hello world";
        let sum = RestartValidator::checksum_sum(bytes);
        assert!(RestartValidator::verify_sum(bytes, sum));
    }

    #[test]
    fn test_validator_xor_consistent() {
        let bytes = b"test data";
        let xor = RestartValidator::checksum_xor(bytes);
        assert!(RestartValidator::verify_xor(bytes, xor));
    }

    #[test]
    fn test_validator_sum_detects_corruption() {
        let bytes = b"original";
        let sum = RestartValidator::checksum_sum(bytes);
        // 'o' replaced by '0' changes the byte sum.
        let corrupt = b"0riginal";
        assert!(!RestartValidator::verify_sum(corrupt, sum));
    }

    #[test]
    fn test_validator_xor_empty() {
        let xor = RestartValidator::checksum_xor(b"");
        assert_eq!(xor, 0);
    }

    // --- Incremental deltas ---

    #[test]
    fn test_incremental_mark_and_extract() {
        let full = make_data(5);
        let mut inc = IncrementalRestart::new();
        inc.mark_changed(1);
        inc.mark_changed(3);
        let delta = inc.extract_delta(&full);
        assert_eq!(delta.n_particles(), 2);
        assert!((delta.positions[0][0] - full.positions[1][0]).abs() < 1e-12);
    }

    #[test]
    fn test_incremental_reset_clears_changes() {
        let mut inc = IncrementalRestart::new();
        inc.mark_changed(0);
        inc.mark_changed(2);
        inc.reset();
        assert!(inc.changed.is_empty());
    }

    #[test]
    fn test_incremental_no_duplicates() {
        let mut inc = IncrementalRestart::new();
        inc.mark_changed(0);
        inc.mark_changed(0);
        inc.mark_changed(0);
        assert_eq!(inc.changed.len(), 1);
    }

    // --- Export helpers ---

    #[test]
    fn test_restart_to_xyz_header_line_count() {
        let d = make_data(4);
        let xyz = restart_to_xyz(&d);
        let lines: Vec<&str> = xyz.lines().collect();
        // Two header lines plus one line per particle.
        assert!(lines.len() >= 6);
        assert_eq!(lines[0], "4");
    }

    #[test]
    fn test_restart_to_lammps_dump_has_timestep() {
        let d = make_data(2);
        let dump = restart_to_lammps_dump(&d);
        assert!(dump.contains("ITEM: TIMESTEP"));
        assert!(dump.contains("ITEM: NUMBER OF ATOMS"));
        assert!(dump.contains("ITEM: ATOMS"));
    }

    #[test]
    fn test_restart_to_lammps_dump_atom_count() {
        let d = make_data(3);
        let dump = restart_to_lammps_dump(&d);
        // Every line after the "ITEM: ATOMS" header is one atom record.
        let atom_lines: Vec<&str> = dump
            .lines()
            .skip_while(|l| !l.starts_with("ITEM: ATOMS"))
            .skip(1)
            .collect();
        assert_eq!(atom_lines.len(), 3);
    }

    // --- CheckpointManager (these tests touch the filesystem under /tmp) ---

    #[test]
    fn test_checkpoint_manager_list_empty() {
        let mgr = CheckpointManager::new("/tmp/oxi_test_ckpt_empty", 3);
        assert!(mgr.list_checkpoints().is_empty());
    }

    #[test]
    fn test_checkpoint_manager_prune_keeps_max() {
        let dir = "/tmp/oxi_test_ckpt_prune";
        let _ = std::fs::remove_dir_all(dir);
        let mut mgr = CheckpointManager::new(dir, 2);
        let d = make_data(1);
        mgr.save_checkpoint(&d, 1);
        mgr.save_checkpoint(&d, 2);
        mgr.save_checkpoint(&d, 3);
        assert_eq!(mgr.list_checkpoints().len(), 2);
        // The newest checkpoint must survive pruning.
        let latest = mgr.list_checkpoints().last().unwrap().0;
        assert_eq!(latest, 3);
        let _ = std::fs::remove_dir_all(dir);
    }
}