1use ndarray::Array1;
9use serde::{Deserialize, Serialize};
10
11use crate::block::Block;
12use crate::error::MolRsError;
13use crate::frame::Frame;
14use crate::types::F;
15
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18pub struct FieldObservable {
19 pub name: String,
21 pub quantity: String,
23 pub scope: String,
25 pub domain: String,
27 pub unit: String,
29 pub encoding: FieldEncoding,
31}
32
33impl FieldObservable {
34 pub fn uniform_grid(
36 name: impl Into<String>,
37 quantity: impl Into<String>,
38 unit: impl Into<String>,
39 grid: UniformGridField,
40 ) -> Self {
41 Self {
42 name: name.into(),
43 quantity: quantity.into(),
44 scope: "field".to_string(),
45 domain: "real_space".to_string(),
46 unit: unit.into(),
47 encoding: FieldEncoding::UniformGrid(grid),
48 }
49 }
50
51 pub fn validate(&self) -> Result<(), MolRsError> {
53 if self.scope != "field" {
54 return Err(MolRsError::validation(format!(
55 "field observable scope must be 'field', got '{}'",
56 self.scope
57 )));
58 }
59 match &self.encoding {
60 FieldEncoding::UniformGrid(grid) => grid.validate(),
61 }
62 }
63
64 pub fn to_point_cloud_frame(&self, threshold: F, stride: usize) -> Result<Frame, MolRsError> {
70 match &self.encoding {
71 FieldEncoding::UniformGrid(grid) => grid.to_point_cloud_frame(threshold, stride),
72 }
73 }
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
78#[serde(tag = "kind", rename_all = "snake_case")]
79pub enum FieldEncoding {
80 UniformGrid(UniformGridField),
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
85pub struct UniformGridField {
86 pub shape: [usize; 3],
88 pub origin: [F; 3],
90 pub cell: [[F; 3]; 3],
93 pub pbc: [bool; 3],
95 pub values: Vec<F>,
97}
98
99impl UniformGridField {
100 pub fn len(&self) -> usize {
102 self.shape.iter().product()
103 }
104
105 pub fn is_empty(&self) -> bool {
107 self.values.is_empty()
108 }
109
110 pub fn validate(&self) -> Result<(), MolRsError> {
112 if self.shape.contains(&0) {
113 return Err(MolRsError::validation(format!(
114 "uniform_grid shape must be strictly positive, got {:?}",
115 self.shape
116 )));
117 }
118 let expected = self.len();
119 if self.values.len() != expected {
120 return Err(MolRsError::validation(format!(
121 "uniform_grid values length mismatch: expected {}, got {}",
122 expected,
123 self.values.len()
124 )));
125 }
126 Ok(())
127 }
128
129 pub fn sample_position(&self, ix: usize, iy: usize, iz: usize) -> [F; 3] {
131 let fx = ix as F / self.shape[0] as F;
132 let fy = iy as F / self.shape[1] as F;
133 let fz = iz as F / self.shape[2] as F;
134 [
135 self.origin[0] + fx * self.cell[0][0] + fy * self.cell[1][0] + fz * self.cell[2][0],
136 self.origin[1] + fx * self.cell[0][1] + fy * self.cell[1][1] + fz * self.cell[2][1],
137 self.origin[2] + fx * self.cell[0][2] + fy * self.cell[1][2] + fz * self.cell[2][2],
138 ]
139 }
140
141 pub fn index(&self, ix: usize, iy: usize, iz: usize) -> usize {
143 (ix * self.shape[1] + iy) * self.shape[2] + iz
144 }
145
146 pub fn to_point_cloud_frame(&self, threshold: F, stride: usize) -> Result<Frame, MolRsError> {
148 self.validate()?;
149 let stride = stride.max(1);
150
151 let mut xs = Vec::new();
152 let mut ys = Vec::new();
153 let mut zs = Vec::new();
154 let mut density = Vec::new();
155 let mut element = Vec::new();
156
157 for ix in (0..self.shape[0]).step_by(stride) {
158 for iy in (0..self.shape[1]).step_by(stride) {
159 for iz in (0..self.shape[2]).step_by(stride) {
160 let idx = self.index(ix, iy, iz);
161 let value = self.values[idx];
162 if value.abs() < threshold {
163 continue;
164 }
165 let pos = self.sample_position(ix, iy, iz);
166 xs.push(pos[0]);
167 ys.push(pos[1]);
168 zs.push(pos[2]);
169 density.push(value);
170 element.push(String::from("He"));
174 }
175 }
176 }
177
178 let mut atoms = Block::new();
179 atoms.insert("x", Array1::from_vec(xs).into_dyn())?;
180 atoms.insert("y", Array1::from_vec(ys).into_dyn())?;
181 atoms.insert("z", Array1::from_vec(zs).into_dyn())?;
182 atoms.insert("density", Array1::from_vec(density).into_dyn())?;
183 atoms.insert("element", Array1::from_vec(element).into_dyn())?;
184
185 let mut frame = Frame::new();
186 frame.insert("atoms", atoms);
187 frame
188 .meta
189 .insert("molrec_view".into(), "field_point_cloud".into());
190 Ok(frame)
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// 2×2×2 grid with a diagonal 2-unit cell, so samples are 1 unit apart on each axis.
    fn grid_2x2x2(values: Vec<F>) -> UniformGridField {
        UniformGridField {
            shape: [2, 2, 2],
            origin: [0.0, 0.0, 0.0],
            cell: [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]],
            pbc: [false, false, false],
            values,
        }
    }

    /// Copies a float column of the frame's `"atoms"` block into a `Vec`.
    fn column(frame: &Frame, key: &str) -> Vec<F> {
        let atoms = frame.get("atoms").unwrap();
        atoms.get_float(key).unwrap().iter().copied().collect()
    }

    #[test]
    fn uniform_grid_validates_shape() {
        let grid = grid_2x2x2(vec![0.0; 7]);
        assert!(grid.validate().is_err());
    }

    #[test]
    fn uniform_grid_rejects_zero_extent() {
        let grid = UniformGridField {
            shape: [0, 2, 2],
            origin: [0.0, 0.0, 0.0],
            cell: [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]],
            pbc: [false, false, false],
            values: Vec::new(),
        };
        assert!(grid.validate().is_err());
    }

    #[test]
    fn uniform_grid_accepts_consistent_values() {
        assert!(grid_2x2x2(vec![0.0; 8]).validate().is_ok());
    }

    #[test]
    fn uniform_grid_index_and_position() {
        let grid = grid_2x2x2(vec![0.0; 8]);
        assert_eq!(grid.index(0, 0, 0), 0);
        assert_eq!(grid.index(1, 0, 1), 5);
        assert_eq!(grid.index(1, 1, 1), 7);
        assert_eq!(grid.sample_position(1, 0, 1), [1.0, 0.0, 1.0]);
    }

    #[test]
    fn uniform_grid_to_point_cloud_filters_values() {
        let grid = grid_2x2x2(vec![0.0, 0.2, 0.0, 0.3, 0.0, 0.0, 0.4, 0.0]);

        let frame = grid.to_point_cloud_frame(0.1, 1).unwrap();
        let atoms = frame.get("atoms").unwrap();
        assert_eq!(atoms.nrows(), Some(3));
        assert!(atoms.get_float("density").is_some());

        // Retained samples are flat indices 1, 3, 6 -> (0,0,1), (0,1,1), (1,1,0).
        assert_eq!(column(&frame, "density"), vec![0.2, 0.3, 0.4]);
        assert_eq!(column(&frame, "x"), vec![0.0, 0.0, 1.0]);
        assert_eq!(column(&frame, "y"), vec![0.0, 1.0, 1.0]);
        assert_eq!(column(&frame, "z"), vec![1.0, 1.0, 0.0]);
    }

    #[test]
    fn uniform_grid_to_point_cloud_respects_stride() {
        let grid = grid_2x2x2(vec![0.0, 0.2, 0.0, 0.3, 0.0, 0.0, 0.4, 0.0]);

        // Stride 2 on a 2-wide grid visits only (0,0,0); threshold 0 keeps it.
        let frame = grid.to_point_cloud_frame(0.0, 2).unwrap();
        assert_eq!(frame.get("atoms").unwrap().nrows(), Some(1));
        assert_eq!(column(&frame, "x"), vec![0.0]);

        // Stride 0 is treated as 1 instead of panicking.
        let frame = grid.to_point_cloud_frame(0.1, 0).unwrap();
        assert_eq!(frame.get("atoms").unwrap().nrows(), Some(3));
    }

    #[test]
    fn field_observable_checks_scope() {
        let mut obs =
            FieldObservable::uniform_grid("rho", "density", "e/bohr^3", grid_2x2x2(vec![0.0; 8]));
        assert!(obs.validate().is_ok());
        obs.scope = String::from("atom");
        assert!(obs.validate().is_err());
    }
}