1use std::collections::HashMap;
13
14use ndarray::{Array3, ArrayD};
15
16use crate::error::MolRsError;
17use crate::types::F;
18
19#[derive(Debug, Clone, PartialEq)]
27pub struct Grid {
28 pub dim: [usize; 3],
30 pub origin: [F; 3],
32 pub cell: [[F; 3]; 3],
34 pub pbc: [bool; 3],
36 arrays: HashMap<String, Vec<F>>,
38}
39
40impl Grid {
41 pub fn new(dim: [usize; 3], origin: [F; 3], cell: [[F; 3]; 3], pbc: [bool; 3]) -> Self {
43 Self {
44 dim,
45 origin,
46 cell,
47 pbc,
48 arrays: HashMap::new(),
49 }
50 }
51
52 pub fn total(&self) -> usize {
54 self.dim[0] * self.dim[1] * self.dim[2]
55 }
56
57 pub fn insert(&mut self, name: impl Into<String>, data: Vec<F>) -> Result<(), MolRsError> {
61 let expected = self.total();
62 let name = name.into();
63 if data.len() != expected {
64 return Err(MolRsError::validation(format!(
65 "grid array '{}' length mismatch: expected {}, got {}",
66 name,
67 expected,
68 data.len()
69 )));
70 }
71 self.arrays.insert(name, data);
72 Ok(())
73 }
74
75 pub fn get(&self, name: &str) -> Option<ArrayD<F>> {
77 self.arrays.get(name).map(|data| {
78 Array3::from_shape_vec([self.dim[0], self.dim[1], self.dim[2]], data.clone())
79 .expect("grid shape matches stored data")
80 .into_dyn()
81 })
82 }
83
84 pub fn get_raw(&self, name: &str) -> Option<&[F]> {
86 self.arrays.get(name).map(|v| v.as_slice())
87 }
88
89 pub fn contains(&self, name: &str) -> bool {
91 self.arrays.contains_key(name)
92 }
93
94 pub fn len(&self) -> usize {
96 self.arrays.len()
97 }
98
99 pub fn is_empty(&self) -> bool {
101 self.arrays.is_empty()
102 }
103
104 pub fn raw_arrays(&self) -> impl Iterator<Item = (&str, &[F])> {
106 self.arrays.iter().map(|(k, v)| (k.as_str(), v.as_slice()))
107 }
108
109 pub fn keys(&self) -> impl Iterator<Item = &str> {
111 self.arrays.keys().map(|s| s.as_str())
112 }
113
114 pub fn voxel_position(&self, ix: usize, iy: usize, iz: usize) -> [F; 3] {
116 let fx = ix as F / self.dim[0] as F;
117 let fy = iy as F / self.dim[1] as F;
118 let fz = iz as F / self.dim[2] as F;
119 [
120 self.origin[0] + fx * self.cell[0][0] + fy * self.cell[1][0] + fz * self.cell[2][0],
121 self.origin[1] + fx * self.cell[0][1] + fy * self.cell[1][1] + fz * self.cell[2][1],
122 self.origin[2] + fx * self.cell[0][2] + fy * self.cell[1][2] + fz * self.cell[2][2],
123 ]
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty, non-periodic grid with a unit cubic cell at the origin.
    fn unit_grid(dim: [usize; 3]) -> Grid {
        Grid::new(
            dim,
            [0.0; 3],
            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
            [false; 3],
        )
    }

    #[test]
    fn insert_validates_length() {
        let mut g = unit_grid([2, 2, 2]);
        assert!(g.insert("rho", vec![0.0; 7]).is_err());
        // A rejected insert must not store anything.
        assert!(!g.contains("rho"));
        assert!(g.is_empty());
        assert!(g.insert("rho", vec![0.0; 8]).is_ok());
        assert!(g.contains("rho"));
        assert_eq!(g.len(), 1);
    }

    #[test]
    fn insert_replaces_existing_array() {
        let mut g = unit_grid([1, 1, 2]);
        g.insert("rho", vec![1.0, 2.0]).unwrap();
        g.insert("rho", vec![3.0, 4.0]).unwrap();
        assert_eq!(g.len(), 1);
        assert_eq!(g.get_raw("rho"), Some(&[3.0, 4.0][..]));
    }

    #[test]
    fn get_returns_shaped_array() {
        let mut g = unit_grid([2, 3, 4]);
        g.insert("rho", (0..24).map(|x| x as F).collect()).unwrap();
        let arr = g.get("rho").unwrap();
        assert_eq!(arr.shape(), &[2, 3, 4]);
        // Logical (row-major) iteration order matches the flat storage order.
        assert!(arr.iter().copied().eq((0..24).map(|x| x as F)));
        assert!(g.get("missing").is_none());
    }

    #[test]
    fn voxel_position_uses_fractional_coordinates() {
        let mut g = unit_grid([2, 2, 2]);
        g.origin = [1.0, 2.0, 3.0];
        assert_eq!(g.voxel_position(0, 0, 0), [1.0, 2.0, 3.0]);
        assert_eq!(g.voxel_position(1, 0, 1), [1.5, 2.0, 3.5]);
    }
}