1#![allow(missing_docs)]
11#![allow(dead_code)]
12
13use serde::{Deserialize, Serialize};
14
/// In-memory stand-in for a VTK writer.
///
/// Named arrays are accumulated with `add_point_data` / `add_cell_data`. Nothing is
/// written to disk by this type: `write_ascii` renders a text summary and
/// `write_binary` returns an estimated byte count.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyVtkWriter {
    /// Target file name; in this file it is only echoed into the ASCII header and
    /// counted in the binary size estimate.
    pub filename: String,
    /// Named per-point arrays, kept in insertion order (duplicate names are allowed).
    pub point_data: Vec<(String, Vec<f64>)>,
    /// Named per-cell arrays, kept in insertion order (duplicate names are allowed).
    pub cell_data: Vec<(String, Vec<f64>)>,
}
29
30impl PyVtkWriter {
31 pub fn new(filename: impl Into<String>) -> Self {
33 Self {
34 filename: filename.into(),
35 point_data: Vec::new(),
36 cell_data: Vec::new(),
37 }
38 }
39
40 pub fn add_point_data(&mut self, name: impl Into<String>, data: Vec<f64>) {
42 self.point_data.push((name.into(), data));
43 }
44
45 pub fn add_cell_data(&mut self, name: impl Into<String>, data: Vec<f64>) {
47 self.cell_data.push((name.into(), data));
48 }
49
50 pub fn write_ascii(&self) -> String {
52 let mut out = format!(
53 "# vtk DataFile Version 3.0\nOxiPhysics output\nASCII\nDATASET UNSTRUCTURED_GRID\nfile={}\n",
54 self.filename
55 );
56 for (name, data) in &self.point_data {
57 out.push_str(&format!("POINT_DATA {} len={}\n", name, data.len()));
58 }
59 for (name, data) in &self.cell_data {
60 out.push_str(&format!("CELL_DATA {} len={}\n", name, data.len()));
61 }
62 out
63 }
64
65 pub fn write_binary(&self) -> usize {
67 let base = self.filename.len() + 64;
68 let pd: usize = self.point_data.iter().map(|(_, v)| v.len() * 8).sum();
69 let cd: usize = self.cell_data.iter().map(|(_, v)| v.len() * 8).sum();
70 base + pd + cd
71 }
72
73 pub fn n_point_arrays(&self) -> usize {
75 self.point_data.len()
76 }
77
78 pub fn n_cell_arrays(&self) -> usize {
80 self.cell_data.len()
81 }
82}
83
84impl Default for PyVtkWriter {
85 fn default() -> Self {
86 Self::new("output.vtk")
87 }
88}
89
/// In-memory table of `f64` values with optional column headers.
///
/// No parsing happens in this type: `filename` is only stored, and the table
/// contents are supplied through `load_data`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyCsvReader {
    /// Source file name (stored only; not opened in this file).
    pub filename: String,
    /// Column header names; nothing checks them against the row widths.
    pub headers: Vec<String>,
    /// Row-major table values; nothing enforces equal row lengths.
    pub rows: Vec<Vec<f64>>,
}
104
105impl PyCsvReader {
106 pub fn new(filename: impl Into<String>) -> Self {
108 Self {
109 filename: filename.into(),
110 headers: Vec::new(),
111 rows: Vec::new(),
112 }
113 }
114
115 pub fn load_data(&mut self, headers: Vec<String>, rows: Vec<Vec<f64>>) {
117 self.headers = headers;
118 self.rows = rows;
119 }
120
121 pub fn read_column(&self, col: usize) -> Vec<f64> {
123 self.rows
124 .iter()
125 .filter_map(|r| r.get(col).copied())
126 .collect()
127 }
128
129 pub fn read_all_f64(&self) -> Vec<f64> {
131 self.rows.iter().flat_map(|r| r.iter().copied()).collect()
132 }
133
134 pub fn header_names(&self) -> &[String] {
136 &self.headers
137 }
138
139 pub fn n_rows(&self) -> usize {
141 self.rows.len()
142 }
143
144 pub fn n_cols(&self) -> usize {
146 self.rows.first().map_or(0, |r| r.len())
147 }
148}
149
150impl Default for PyCsvReader {
151 fn default() -> Self {
152 Self::new("input.csv")
153 }
154}
155
/// Buffers rows of `f64` values and renders them as comma-separated text on `flush`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyCsvWriter {
    /// Destination file name (stored only; `flush` returns text rather than writing a file).
    pub filename: String,
    /// Rows accumulated since the last `flush`.
    pub buffer: Vec<Vec<f64>>,
}
168
169impl PyCsvWriter {
170 pub fn new(filename: impl Into<String>) -> Self {
172 Self {
173 filename: filename.into(),
174 buffer: Vec::new(),
175 }
176 }
177
178 pub fn write_row(&mut self, data: Vec<f64>) {
180 self.buffer.push(data);
181 }
182
183 pub fn flush(&mut self) -> String {
185 let csv = self
186 .buffer
187 .iter()
188 .map(|row| {
189 row.iter()
190 .map(|v| v.to_string())
191 .collect::<Vec<_>>()
192 .join(",")
193 })
194 .collect::<Vec<_>>()
195 .join("\n");
196 self.buffer.clear();
197 csv
198 }
199
200 pub fn buffered_rows(&self) -> usize {
202 self.buffer.len()
203 }
204}
205
206impl Default for PyCsvWriter {
207 fn default() -> Self {
208 Self::new("output.csv")
209 }
210}
211
/// In-memory XYZ-style atom list: one species label per atom plus flattened coordinates.
///
/// Data is supplied through `load_data`; `filename` is only stored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyXyzReader {
    /// Source file name (stored only; not opened in this file).
    pub filename: String,
    /// Species label for each atom; its length is what `n_atoms` reports.
    pub species: Vec<String>,
    /// Coordinates flattened three per atom: `[x0, y0, z0, x1, y1, z1, ...]`.
    pub pos_flat: Vec<f64>,
}
226
227impl PyXyzReader {
228 pub fn new(filename: impl Into<String>) -> Self {
230 Self {
231 filename: filename.into(),
232 species: Vec::new(),
233 pos_flat: Vec::new(),
234 }
235 }
236
237 pub fn load_data(&mut self, species: Vec<String>, pos_flat: Vec<f64>) {
239 self.species = species;
240 self.pos_flat = pos_flat;
241 }
242
243 pub fn n_atoms(&self) -> usize {
245 self.species.len()
246 }
247
248 pub fn positions(&self) -> &[f64] {
250 &self.pos_flat
251 }
252
253 pub fn species(&self) -> &[String] {
255 &self.species
256 }
257
258 pub fn position_of(&self, i: usize) -> Option<[f64; 3]> {
260 let base = i * 3;
261 if base + 2 < self.pos_flat.len() {
262 Some([
263 self.pos_flat[base],
264 self.pos_flat[base + 1],
265 self.pos_flat[base + 2],
266 ])
267 } else {
268 None
269 }
270 }
271}
272
273impl Default for PyXyzReader {
274 fn default() -> Self {
275 Self::new("atoms.xyz")
276 }
277}
278
/// Accumulates XYZ-format frames as text in memory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyXyzWriter {
    /// Destination file name (stored only; frames are kept in memory).
    pub filename: String,
    /// One rendered frame per entry: an atom-count line, a comment line, then one
    /// `species x y z` line per atom.
    pub frames: Vec<String>,
}
291
292impl PyXyzWriter {
293 pub fn new(filename: impl Into<String>) -> Self {
295 Self {
296 filename: filename.into(),
297 frames: Vec::new(),
298 }
299 }
300
301 pub fn write_frame(
306 &mut self,
307 positions: &[f64],
308 species: &[String],
309 comment: impl Into<String>,
310 ) {
311 let n = species.len();
312 let mut frame = format!("{}\n{}\n", n, comment.into());
313 for (i, sp) in species.iter().enumerate() {
314 let base = i * 3;
315 let (x, y, z) = if base + 2 < positions.len() {
316 (positions[base], positions[base + 1], positions[base + 2])
317 } else {
318 (0.0, 0.0, 0.0)
319 };
320 frame.push_str(&format!("{} {} {} {}\n", sp, x, y, z));
321 }
322 self.frames.push(frame);
323 }
324
325 pub fn n_frames(&self) -> usize {
327 self.frames.len()
328 }
329
330 pub fn as_string(&self) -> String {
332 self.frames.concat()
333 }
334}
335
336impl Default for PyXyzWriter {
337 fn default() -> Self {
338 Self::new("output.xyz")
339 }
340}
341
/// In-memory stand-in for a LAMMPS dump reader: per-atom numeric rows plus box bounds.
///
/// Data is supplied through `load_data`; `filename` is only stored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyLammpsReader {
    /// Source file name (stored only; not opened in this file).
    pub filename: String,
    /// One row of numeric columns per atom. The column layout is not interpreted
    /// here (presumably id/type/coordinates as in a dump file; confirm with callers).
    pub atom_data: Vec<Vec<f64>>,
    /// Per-axis `[lo, hi]` bounds; `new` sets every axis to `[0.0, 1.0]`.
    pub box_bounds: [[f64; 2]; 3],
}
356
357impl PyLammpsReader {
358 pub fn new(filename: impl Into<String>) -> Self {
360 Self {
361 filename: filename.into(),
362 atom_data: Vec::new(),
363 box_bounds: [[0.0, 1.0]; 3],
364 }
365 }
366
367 pub fn load_data(&mut self, atom_data: Vec<Vec<f64>>, box_bounds: [[f64; 2]; 3]) {
369 self.atom_data = atom_data;
370 self.box_bounds = box_bounds;
371 }
372
373 pub fn read_atoms(&self) -> &[Vec<f64>] {
375 &self.atom_data
376 }
377
378 pub fn n_atoms(&self) -> usize {
380 self.atom_data.len()
381 }
382
383 pub fn box_bounds(&self) -> [[f64; 2]; 3] {
385 self.box_bounds
386 }
387
388 pub fn box_lengths(&self) -> [f64; 3] {
390 [
391 self.box_bounds[0][1] - self.box_bounds[0][0],
392 self.box_bounds[1][1] - self.box_bounds[1][0],
393 self.box_bounds[2][1] - self.box_bounds[2][0],
394 ]
395 }
396}
397
398impl Default for PyLammpsReader {
399 fn default() -> Self {
400 Self::new("dump.lammpstrj")
401 }
402}
403
/// In-memory stand-in for an HDF5 writer holding named `f64` datasets and scalar attributes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyHdf5Writer {
    /// Destination file name (stored only; nothing in this file writes to disk).
    pub filename: String,
    /// Named datasets in insertion order; writing a name again adds another entry.
    pub datasets: Vec<(String, Vec<f64>)>,
    /// Named scalar attributes in insertion order; writing a name again adds another entry.
    pub attributes: Vec<(String, f64)>,
}
418
419impl PyHdf5Writer {
420 pub fn new(filename: impl Into<String>) -> Self {
422 Self {
423 filename: filename.into(),
424 datasets: Vec::new(),
425 attributes: Vec::new(),
426 }
427 }
428
429 pub fn write_dataset(&mut self, name: impl Into<String>, data: Vec<f64>) {
431 self.datasets.push((name.into(), data));
432 }
433
434 pub fn write_attribute(&mut self, name: impl Into<String>, value: f64) {
436 self.attributes.push((name.into(), value));
437 }
438
439 pub fn n_datasets(&self) -> usize {
441 self.datasets.len()
442 }
443
444 pub fn n_attributes(&self) -> usize {
446 self.attributes.len()
447 }
448
449 pub fn get_dataset(&self, name: &str) -> Option<&Vec<f64>> {
451 self.datasets
452 .iter()
453 .find(|(n, _)| n == name)
454 .map(|(_, d)| d)
455 }
456
457 pub fn get_attribute(&self, name: &str) -> Option<f64> {
459 self.attributes
460 .iter()
461 .find(|(n, _)| n == name)
462 .map(|(_, v)| *v)
463 }
464}
465
466impl Default for PyHdf5Writer {
467 fn default() -> Self {
468 Self::new("output.h5")
469 }
470}
471
/// Buffers one summary line per trajectory frame until the writer is closed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyTrajectoryWriter {
    /// Destination file name (stored only; frames are kept in memory).
    pub filename: String,
    /// Free-form format tag echoed into every frame line (e.g. `"xyz"`, `"lammps"`).
    pub format: String,
    /// Once `true`, `write_frame` ignores further frames.
    pub closed: bool,
    /// Number of frames accepted so far (this is what `n_frames` reports).
    pub frame_count: usize,
    /// Rendered frame lines in write order.
    pub frame_buffer: Vec<String>,
}
490
491impl PyTrajectoryWriter {
492 pub fn new(filename: impl Into<String>, format: impl Into<String>) -> Self {
494 Self {
495 filename: filename.into(),
496 format: format.into(),
497 closed: false,
498 frame_count: 0,
499 frame_buffer: Vec::new(),
500 }
501 }
502
503 pub fn write_frame(&mut self, positions: &[f64], velocities: &[f64], step: u64) {
508 if self.closed {
509 return;
510 }
511 let frame = format!(
512 "FRAME step={} n_pos={} n_vel={} fmt={}\n",
513 step,
514 positions.len(),
515 velocities.len(),
516 self.format
517 );
518 self.frame_buffer.push(frame);
519 self.frame_count += 1;
520 }
521
522 pub fn close(&mut self) {
524 self.closed = true;
525 }
526
527 pub fn is_closed(&self) -> bool {
529 self.closed
530 }
531
532 pub fn n_frames(&self) -> usize {
534 self.frame_count
535 }
536
537 pub fn as_string(&self) -> String {
539 self.frame_buffer.concat()
540 }
541}
542
543impl Default for PyTrajectoryWriter {
544 fn default() -> Self {
545 Self::new("trajectory.xyz", "xyz")
546 }
547}
548
/// Module-registration hook taking a module name.
///
/// The body is empty, so every call is a no-op. The name suggests it is meant to
/// register this file's I/O types with a host module; confirm with callers before
/// relying on that.
pub fn register_io_module(_m: &str) {}
561
#[cfg(test)]
mod tests {
    use super::*;

    // --- PyVtkWriter ---------------------------------------------------------

    #[test]
    fn test_vtk_new() {
        let w = PyVtkWriter::new("out.vtk");
        assert_eq!(w.filename, "out.vtk");
        assert_eq!(w.n_point_arrays(), 0);
    }

    #[test]
    fn test_vtk_add_point_data() {
        let mut w = PyVtkWriter::default();
        w.add_point_data("pressure", vec![1.0, 2.0, 3.0]);
        assert_eq!(w.n_point_arrays(), 1);
    }

    #[test]
    fn test_vtk_add_cell_data() {
        let mut w = PyVtkWriter::default();
        w.add_cell_data("stress", vec![10.0, 20.0]);
        assert_eq!(w.n_cell_arrays(), 1);
    }

    #[test]
    fn test_vtk_write_ascii_contains_header() {
        let w = PyVtkWriter::new("test.vtk");
        let s = w.write_ascii();
        assert!(s.contains("vtk DataFile"));
    }

    #[test]
    fn test_vtk_write_ascii_contains_point_data_name() {
        let mut w = PyVtkWriter::new("test.vtk");
        w.add_point_data("velocity", vec![1.0, 2.0]);
        let s = w.write_ascii();
        // Check the whole summary line so a wrong length is caught, not just the name.
        assert!(s.contains("POINT_DATA velocity len=2\n"));
    }

    #[test]
    fn test_vtk_write_binary_size_grows() {
        let mut w = PyVtkWriter::new("test.vtk");
        let s0 = w.write_binary();
        w.add_point_data("p", vec![1.0; 100]);
        let s1 = w.write_binary();
        assert!(s1 > s0);
    }

    #[test]
    fn test_vtk_default() {
        let w = PyVtkWriter::default();
        assert!(w.filename.ends_with(".vtk"));
    }

    // --- PyCsvReader ---------------------------------------------------------

    #[test]
    fn test_csv_reader_new() {
        let r = PyCsvReader::new("data.csv");
        assert_eq!(r.filename, "data.csv");
        assert_eq!(r.n_rows(), 0);
    }

    #[test]
    fn test_csv_reader_load_and_read_column() {
        let mut r = PyCsvReader::default();
        r.load_data(
            vec!["x".to_string(), "y".to_string()],
            vec![vec![1.0, 2.0], vec![3.0, 4.0]],
        );
        let col0 = r.read_column(0);
        assert_eq!(col0, vec![1.0, 3.0]);
    }

    #[test]
    fn test_csv_reader_read_column_skips_short_rows() {
        let mut r = PyCsvReader::default();
        r.load_data(vec![], vec![vec![1.0, 2.0], vec![3.0], vec![4.0, 5.0]]);
        assert_eq!(r.read_column(1), vec![2.0, 5.0]);
    }

    #[test]
    fn test_csv_reader_read_all_f64() {
        let mut r = PyCsvReader::default();
        r.load_data(vec![], vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let all = r.read_all_f64();
        // Row-major order, not just the right count.
        assert_eq!(all, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_csv_reader_header_names() {
        let mut r = PyCsvReader::default();
        r.load_data(vec!["a".to_string(), "b".to_string()], vec![]);
        assert_eq!(r.header_names().len(), 2);
    }

    #[test]
    fn test_csv_reader_n_cols() {
        let mut r = PyCsvReader::default();
        r.load_data(vec![], vec![vec![1.0, 2.0, 3.0]]);
        assert_eq!(r.n_cols(), 3);
    }

    #[test]
    fn test_csv_reader_empty_n_cols_zero() {
        let r = PyCsvReader::default();
        assert_eq!(r.n_cols(), 0);
    }

    // --- PyCsvWriter ---------------------------------------------------------

    #[test]
    fn test_csv_writer_new() {
        let w = PyCsvWriter::new("out.csv");
        assert_eq!(w.filename, "out.csv");
        assert_eq!(w.buffered_rows(), 0);
    }

    #[test]
    fn test_csv_writer_write_row() {
        let mut w = PyCsvWriter::default();
        w.write_row(vec![1.0, 2.0, 3.0]);
        assert_eq!(w.buffered_rows(), 1);
    }

    #[test]
    fn test_csv_writer_flush_clears_buffer() {
        let mut w = PyCsvWriter::default();
        w.write_row(vec![1.0]);
        w.flush();
        assert_eq!(w.buffered_rows(), 0);
    }

    #[test]
    fn test_csv_writer_flush_returns_csv() {
        let mut w = PyCsvWriter::default();
        w.write_row(vec![1.0, 2.0]);
        let s = w.flush();
        // Exact output: `contains("1") && contains("2")` also passes for malformed text.
        assert_eq!(s, "1,2");
    }

    #[test]
    fn test_csv_writer_flush_empty_buffer() {
        let mut w = PyCsvWriter::default();
        assert_eq!(w.flush(), "");
    }

    #[test]
    fn test_csv_writer_default() {
        let w = PyCsvWriter::default();
        assert!(w.filename.ends_with(".csv"));
    }

    // --- PyXyzReader ---------------------------------------------------------

    #[test]
    fn test_xyz_reader_new() {
        let r = PyXyzReader::new("mol.xyz");
        assert_eq!(r.filename, "mol.xyz");
        assert_eq!(r.n_atoms(), 0);
    }

    #[test]
    fn test_xyz_reader_load_and_n_atoms() {
        let mut r = PyXyzReader::default();
        r.load_data(
            vec!["C".to_string(), "H".to_string()],
            vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
        );
        assert_eq!(r.n_atoms(), 2);
    }

    #[test]
    fn test_xyz_reader_positions() {
        let mut r = PyXyzReader::default();
        r.load_data(vec!["O".to_string()], vec![1.0, 2.0, 3.0]);
        assert_eq!(r.positions(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_xyz_reader_species() {
        let mut r = PyXyzReader::default();
        r.load_data(vec!["N".to_string()], vec![0.0, 0.0, 0.0]);
        assert_eq!(r.species()[0], "N");
    }

    #[test]
    fn test_xyz_reader_position_of() {
        let mut r = PyXyzReader::default();
        r.load_data(
            vec!["C".to_string(), "H".to_string()],
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        );
        let p = r.position_of(0).unwrap();
        assert_eq!(p, [1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_xyz_reader_position_of_out_of_range() {
        let mut r = PyXyzReader::default();
        r.load_data(
            vec!["C".to_string(), "H".to_string()],
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        );
        assert_eq!(r.position_of(1), Some([4.0, 5.0, 6.0]));
        assert_eq!(r.position_of(2), None);
    }

    #[test]
    fn test_xyz_reader_default() {
        let r = PyXyzReader::default();
        assert!(r.filename.ends_with(".xyz"));
    }

    // --- PyXyzWriter ---------------------------------------------------------

    #[test]
    fn test_xyz_writer_new() {
        let w = PyXyzWriter::new("out.xyz");
        assert_eq!(w.filename, "out.xyz");
        assert_eq!(w.n_frames(), 0);
    }

    #[test]
    fn test_xyz_writer_write_frame_increments_count() {
        let mut w = PyXyzWriter::default();
        w.write_frame(&[0.0, 0.0, 0.0], &["C".to_string()], "frame 0");
        assert_eq!(w.n_frames(), 1);
    }

    #[test]
    fn test_xyz_writer_as_string_contains_n_atoms() {
        let mut w = PyXyzWriter::default();
        w.write_frame(&[0.0, 0.0, 0.0], &["C".to_string()], "test");
        // Exact frame: count line, comment line, one atom line.
        assert_eq!(w.as_string(), "1\ntest\nC 0 0 0\n");
    }

    #[test]
    fn test_xyz_writer_multiple_frames() {
        let mut w = PyXyzWriter::default();
        for _ in 0..5 {
            w.write_frame(&[0.0, 0.0, 0.0], &["H".to_string()], "");
        }
        assert_eq!(w.n_frames(), 5);
    }

    #[test]
    fn test_xyz_writer_default() {
        let w = PyXyzWriter::default();
        assert!(w.filename.ends_with(".xyz"));
    }

    // --- PyLammpsReader ------------------------------------------------------

    #[test]
    fn test_lammps_reader_new() {
        let r = PyLammpsReader::new("dump.lammps");
        assert_eq!(r.n_atoms(), 0);
    }

    #[test]
    fn test_lammps_reader_load_and_n_atoms() {
        let mut r = PyLammpsReader::default();
        r.load_data(vec![vec![1.0, 1.0, 0.0, 0.5, 0.5]], [[0.0, 1.0]; 3]);
        assert_eq!(r.n_atoms(), 1);
    }

    #[test]
    fn test_lammps_reader_box_bounds() {
        let mut r = PyLammpsReader::default();
        r.load_data(vec![], [[-5.0, 5.0], [-5.0, 5.0], [-5.0, 5.0]]);
        let b = r.box_bounds();
        assert_eq!(b[0], [-5.0, 5.0]);
    }

    #[test]
    fn test_lammps_reader_box_lengths() {
        let mut r = PyLammpsReader::default();
        r.load_data(vec![], [[0.0, 10.0], [0.0, 20.0], [0.0, 30.0]]);
        let l = r.box_lengths();
        assert_eq!(l, [10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_lammps_reader_default_box_is_unit() {
        let r = PyLammpsReader::default();
        assert_eq!(r.box_lengths(), [1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_lammps_reader_read_atoms() {
        let mut r = PyLammpsReader::default();
        let atom = vec![1.0, 1.0, 0.1, 0.2, 0.3];
        r.load_data(vec![atom.clone()], [[0.0, 1.0]; 3]);
        assert_eq!(r.read_atoms()[0], atom);
    }

    #[test]
    fn test_lammps_reader_default() {
        let r = PyLammpsReader::default();
        assert!(!r.filename.is_empty());
    }

    // --- PyHdf5Writer --------------------------------------------------------

    #[test]
    fn test_hdf5_writer_new() {
        let w = PyHdf5Writer::new("out.h5");
        assert_eq!(w.filename, "out.h5");
        assert_eq!(w.n_datasets(), 0);
    }

    #[test]
    fn test_hdf5_writer_write_dataset() {
        let mut w = PyHdf5Writer::default();
        w.write_dataset("pressure", vec![1.0, 2.0, 3.0]);
        assert_eq!(w.n_datasets(), 1);
    }

    #[test]
    fn test_hdf5_writer_write_attribute() {
        let mut w = PyHdf5Writer::default();
        w.write_attribute("timestep", 0.001);
        assert_eq!(w.n_attributes(), 1);
    }

    #[test]
    fn test_hdf5_writer_get_dataset() {
        let mut w = PyHdf5Writer::default();
        w.write_dataset("vel", vec![1.0, 2.0]);
        let d = w.get_dataset("vel").unwrap();
        assert_eq!(d.len(), 2);
    }

    #[test]
    fn test_hdf5_writer_get_attribute() {
        let mut w = PyHdf5Writer::default();
        w.write_attribute("dt", 1e-4);
        let v = w.get_attribute("dt").unwrap();
        assert!((v - 1e-4).abs() < 1e-12);
    }

    #[test]
    fn test_hdf5_writer_missing_dataset_none() {
        let w = PyHdf5Writer::default();
        assert!(w.get_dataset("missing").is_none());
    }

    #[test]
    fn test_hdf5_writer_missing_attribute_none() {
        let w = PyHdf5Writer::default();
        assert!(w.get_attribute("missing").is_none());
    }

    #[test]
    fn test_hdf5_writer_default() {
        let w = PyHdf5Writer::default();
        assert!(w.filename.ends_with(".h5"));
    }

    // --- PyTrajectoryWriter --------------------------------------------------

    #[test]
    fn test_trajectory_writer_new() {
        let w = PyTrajectoryWriter::new("traj.xyz", "xyz");
        assert_eq!(w.format, "xyz");
        assert_eq!(w.n_frames(), 0);
    }

    #[test]
    fn test_trajectory_writer_write_frame() {
        let mut w = PyTrajectoryWriter::default();
        w.write_frame(&[0.0, 0.0, 0.0], &[0.0, 0.0, 0.0], 0);
        assert_eq!(w.n_frames(), 1);
    }

    #[test]
    fn test_trajectory_writer_close() {
        let mut w = PyTrajectoryWriter::default();
        w.close();
        assert!(w.is_closed());
    }

    #[test]
    fn test_trajectory_writer_no_write_after_close() {
        let mut w = PyTrajectoryWriter::default();
        w.close();
        w.write_frame(&[0.0], &[], 1);
        assert_eq!(w.n_frames(), 0);
    }

    #[test]
    fn test_trajectory_writer_as_string_contains_step() {
        let mut w = PyTrajectoryWriter::default();
        w.write_frame(&[1.0, 2.0, 3.0], &[0.1, 0.2, 0.3], 42);
        // Exact line: `contains("42")` would also match an unrelated field.
        assert_eq!(w.as_string(), "FRAME step=42 n_pos=3 n_vel=3 fmt=xyz\n");
    }

    #[test]
    fn test_trajectory_writer_multiple_frames() {
        let mut w = PyTrajectoryWriter::new("t.lammps", "lammps");
        for i in 0..10_u64 {
            w.write_frame(&[0.0], &[0.0], i);
        }
        assert_eq!(w.n_frames(), 10);
    }

    #[test]
    fn test_trajectory_writer_default() {
        let w = PyTrajectoryWriter::default();
        assert!(!w.format.is_empty());
    }

    // --- register_io_module --------------------------------------------------

    #[test]
    fn test_register_io_module_no_panic() {
        register_io_module("io");
    }
}