1use std::collections::BTreeMap;
4
5use serde::{Deserialize, Serialize};
6use serde_json::{Map as JsonMap, Value as JsonValue};
7
8use crate::MolRsError;
9use crate::block::Column;
10use crate::frame::Frame;
11use crate::grid::Grid;
12use crate::types::F;
13
/// JSON value type used for the free-form sections of a [`MolRec`]
/// (`meta`, `method`, `parameters`).
pub type SchemaValue = JsonValue;
16
/// Ordered sequence of frames with optional per-frame `step` and `time` labels.
///
/// When `step` or `time` is present it is expected to hold exactly one entry
/// per frame; [`Trajectory::validate`] checks this.
#[derive(Debug, Clone, Default)]
pub struct Trajectory {
    /// Frames in trajectory order.
    pub frames: Vec<Frame>,
    /// Optional step label for each frame (one entry per frame when set).
    pub step: Option<Vec<i64>>,
    /// Optional time label for each frame (one entry per frame when set).
    /// NOTE(review): the time unit is not established in this module — confirm.
    pub time: Option<Vec<F>>,
}
27
28impl Trajectory {
29 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn from_frames(frames: Vec<Frame>) -> Self {
36 Self {
37 frames,
38 step: None,
39 time: None,
40 }
41 }
42
43 pub fn len(&self) -> usize {
45 self.frames.len()
46 }
47
48 pub fn is_empty(&self) -> bool {
50 self.frames.is_empty()
51 }
52
53 pub fn validate(&self) -> Result<(), MolRsError> {
55 let n = self.frames.len();
56 if let Some(step) = &self.step
57 && step.len() != n
58 {
59 return Err(MolRsError::validation(format!(
60 "trajectory.step length mismatch: expected {}, got {}",
61 n,
62 step.len()
63 )));
64 }
65 if let Some(time) = &self.time
66 && time.len() != n
67 {
68 return Err(MolRsError::validation(format!(
69 "trajectory.time length mismatch: expected {}, got {}",
70 n,
71 time.len()
72 )));
73 }
74 Ok(())
75 }
76}
77
/// Shape category of an observable.
///
/// Serialized in snake_case (`"scalar"`, `"vector"`, `"grid"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ObservableKind {
    /// Backed by [`ObservableData::Column`].
    Scalar,
    /// Backed by [`ObservableData::Column`].
    Vector,
    /// Backed by [`ObservableData::Grid`].
    Grid,
}
86
/// Payload stored in an [`ObservableRecord`].
#[derive(Debug, Clone)]
pub enum ObservableData {
    /// Column payload; paired with scalar and vector observables.
    Column(Column),
    /// Grid payload; paired with grid observables.
    Grid(Grid),
}
95
/// A named observable: its data payload plus descriptive metadata.
///
/// `kind` and `data` must agree (column data for scalar/vector, grid data
/// for grid); `ObservableRecord::validate` checks this.
#[derive(Debug, Clone)]
pub struct ObservableRecord {
    /// Observable name; `MolRec` uses it as the key in its observable map.
    pub name: String,
    /// Shape category; must match the variant of `data`.
    pub kind: ObservableKind,
    /// Free-form description; empty when built by the constructors.
    pub description: String,
    /// Presumably whether the observable varies across trajectory frames —
    /// TODO confirm; the constructors default it to `false`.
    pub time_dependent: bool,
    /// Optional unit label; `None` when unspecified.
    pub unit: Option<String>,
    /// Axis labels; empty when built by the constructors.
    /// NOTE(review): expected to name the payload's axes — confirm.
    pub axes: Vec<String>,
    /// Optional sampling descriptor (semantics not defined in this module).
    pub sampling: Option<String>,
    /// Optional domain descriptor (semantics not defined in this module).
    pub domain: Option<String>,
    /// Optional target descriptor (semantics not defined in this module).
    pub target: Option<String>,
    /// Additional free-form JSON metadata.
    pub extra: JsonMap<String, JsonValue>,
    /// The observable's payload.
    pub data: ObservableData,
}
111
112impl ObservableRecord {
113 pub fn scalar(name: impl Into<String>, data: Column) -> Self {
115 Self {
116 name: name.into(),
117 kind: ObservableKind::Scalar,
118 description: String::new(),
119 time_dependent: false,
120 unit: None,
121 axes: Vec::new(),
122 sampling: None,
123 domain: None,
124 target: None,
125 extra: JsonMap::new(),
126 data: ObservableData::Column(data),
127 }
128 }
129
130 pub fn vector(name: impl Into<String>, data: Column) -> Self {
132 Self {
133 name: name.into(),
134 kind: ObservableKind::Vector,
135 description: String::new(),
136 time_dependent: false,
137 unit: None,
138 axes: Vec::new(),
139 sampling: None,
140 domain: None,
141 target: None,
142 extra: JsonMap::new(),
143 data: ObservableData::Column(data),
144 }
145 }
146
147 pub fn grid(name: impl Into<String>, grid: Grid) -> Self {
149 Self {
150 name: name.into(),
151 kind: ObservableKind::Grid,
152 description: String::new(),
153 time_dependent: false,
154 unit: None,
155 axes: Vec::new(),
156 sampling: None,
157 domain: None,
158 target: None,
159 extra: JsonMap::new(),
160 data: ObservableData::Grid(grid),
161 }
162 }
163
164 pub fn validate(&self) -> Result<(), MolRsError> {
166 match (&self.kind, &self.data) {
167 (ObservableKind::Scalar | ObservableKind::Vector, ObservableData::Column(_)) => Ok(()),
168 (ObservableKind::Grid, ObservableData::Grid(_)) => Ok(()),
169 _ => Err(MolRsError::validation("observable kind/data mismatch")),
170 }
171 }
172}
173
/// Record bundling a canonical frame, an optional trajectory, named
/// observables, and free-form JSON `meta`/`method`/`parameters` sections.
#[derive(Debug, Clone)]
pub struct MolRec {
    /// Free-form JSON metadata; an empty object for newly built records.
    pub meta: SchemaValue,
    /// Canonical frame; `frame_at(0)` returns it when there is no non-empty trajectory.
    pub frame: Frame,
    /// Optional frame sequence; when non-empty it drives `count_frames`/`frame_at`.
    pub trajectory: Option<Trajectory>,
    /// Observables keyed by their `name`.
    pub observables: BTreeMap<String, ObservableRecord>,
    /// Free-form JSON `method` section; an empty object for newly built records.
    pub method: SchemaValue,
    /// Free-form JSON `parameters` section; an empty object for newly built records.
    pub parameters: SchemaValue,
}
190
191impl Default for MolRec {
192 fn default() -> Self {
193 Self::new(Frame::new())
194 }
195}
196
197impl MolRec {
198 pub fn new(frame: Frame) -> Self {
200 Self {
201 meta: empty_object(),
202 frame,
203 trajectory: None,
204 observables: BTreeMap::new(),
205 method: empty_object(),
206 parameters: empty_object(),
207 }
208 }
209
210 pub fn from_frames(frame: Frame, frames: Vec<Frame>) -> Self {
212 let mut rec = Self::new(frame);
213 let trajectory = Trajectory::from_frames(frames);
214 rec.trajectory = Some(trajectory);
215 rec
216 }
217
218 pub fn from_trajectory(trajectory: Trajectory) -> Result<Self, MolRsError> {
222 trajectory.validate()?;
223 let Some(frame) = trajectory.frames.first().cloned() else {
224 return Err(MolRsError::validation(
225 "cannot build MolRec from an empty trajectory",
226 ));
227 };
228 let mut rec = Self::new(frame);
229 rec.trajectory = Some(trajectory);
230 Ok(rec)
231 }
232
233 pub fn count_frames(&self) -> usize {
235 match &self.trajectory {
236 Some(traj) if !traj.frames.is_empty() => traj.frames.len(),
237 _ => 1,
238 }
239 }
240
241 pub fn frame_at(&self, index: usize) -> Option<Frame> {
243 match &self.trajectory {
244 Some(traj) if !traj.frames.is_empty() => traj.frames.get(index).cloned(),
245 _ if index == 0 => Some(self.frame.clone()),
246 _ => None,
247 }
248 }
249
250 pub fn set_frame(&mut self, frame: Frame) {
252 self.frame = frame;
253 }
254
255 pub fn add_frame(&mut self, frame: Frame) {
257 match &mut self.trajectory {
258 Some(traj) => traj.frames.push(frame),
259 None => {
260 self.trajectory = Some(Trajectory::from_frames(vec![frame]));
261 }
262 }
263 }
264
265 pub fn set_trajectory(&mut self, trajectory: Option<Trajectory>) {
267 self.trajectory = trajectory;
268 }
269
270 pub fn add_observable(&mut self, observable: ObservableRecord) -> Option<ObservableRecord> {
272 self.observables.insert(observable.name.clone(), observable)
273 }
274
275 pub fn get_observable(&self, name: &str) -> Option<&ObservableRecord> {
277 self.observables.get(name)
278 }
279
280 pub fn remove_observable(&mut self, name: &str) -> Option<ObservableRecord> {
282 self.observables.remove(name)
283 }
284}
285
286fn empty_object() -> JsonValue {
287 JsonValue::Object(JsonMap::new())
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn static_molrec_counts_one_frame() {
        let rec = MolRec::new(Frame::new());
        assert_eq!(rec.count_frames(), 1);
        assert!(rec.frame_at(0).is_some());
        assert!(rec.frame_at(1).is_none());
    }

    #[test]
    fn from_trajectory_uses_first_frame_as_canonical() {
        let mut traj = Trajectory::new();
        traj.frames.push(Frame::new());
        traj.frames.push(Frame::new());
        let rec = MolRec::from_trajectory(traj).unwrap();
        assert_eq!(rec.count_frames(), 2);
    }

    #[test]
    fn from_trajectory_rejects_empty_trajectory() {
        assert!(MolRec::from_trajectory(Trajectory::new()).is_err());
    }

    #[test]
    fn trajectory_validate_checks_step_and_time_lengths() {
        let mut traj = Trajectory::from_frames(vec![Frame::new(), Frame::new()]);
        assert!(traj.validate().is_ok());

        // Matching per-frame labels are accepted.
        traj.step = Some(vec![0, 10]);
        traj.time = Some(vec![0.0, 1.0]);
        assert!(traj.validate().is_ok());

        // A short `step` is rejected.
        traj.step = Some(vec![0]);
        assert!(traj.validate().is_err());

        // A long `time` is rejected even when `step` is absent.
        traj.step = None;
        traj.time = Some(vec![0.0, 1.0, 2.0]);
        assert!(traj.validate().is_err());

        // from_trajectory surfaces the validation failure.
        assert!(MolRec::from_trajectory(traj).is_err());
    }

    #[test]
    fn add_frame_builds_trajectory_on_demand() {
        let mut rec = MolRec::default();
        rec.add_frame(Frame::new());
        assert_eq!(rec.count_frames(), 1);
        rec.add_frame(Frame::new());
        assert_eq!(rec.count_frames(), 2);
        assert!(rec.frame_at(1).is_some());
        assert!(rec.frame_at(2).is_none());

        // Clearing the trajectory falls back to the canonical frame only.
        rec.set_trajectory(None);
        assert_eq!(rec.count_frames(), 1);
        assert!(rec.frame_at(0).is_some());
    }

    #[test]
    fn from_frames_exposes_every_frame() {
        let rec = MolRec::from_frames(
            Frame::new(),
            vec![Frame::new(), Frame::new(), Frame::new()],
        );
        assert_eq!(rec.count_frames(), 3);
        assert!(rec.frame_at(2).is_some());
        assert!(rec.frame_at(3).is_none());
    }

    #[test]
    fn empty_trajectory_falls_back_to_canonical_frame() {
        let mut rec = MolRec::new(Frame::new());
        rec.set_trajectory(Some(Trajectory::new()));
        assert_eq!(rec.count_frames(), 1);
        assert!(rec.frame_at(0).is_some());
        assert!(rec.frame_at(1).is_none());
    }
}
310}