1use crate::SensorOutput;
4use std::collections::HashMap;
5use phyz_model::State;
6
7pub struct TrajectoryRecorder {
9 pub q_history: Vec<Vec<f64>>,
11 pub v_history: Vec<Vec<f64>>,
13 pub ctrl_history: Vec<Vec<f64>>,
15 pub time_history: Vec<f64>,
17 pub sensor_history: Vec<Vec<SensorOutput>>,
19}
20
21impl TrajectoryRecorder {
22 pub fn new() -> Self {
24 Self {
25 q_history: Vec::new(),
26 v_history: Vec::new(),
27 ctrl_history: Vec::new(),
28 time_history: Vec::new(),
29 sensor_history: Vec::new(),
30 }
31 }
32
33 pub fn record(&mut self, state: &State) {
35 self.q_history.push(state.q.as_slice().to_vec());
36 self.v_history.push(state.v.as_slice().to_vec());
37 self.ctrl_history.push(state.ctrl.as_slice().to_vec());
38 self.time_history.push(state.time);
39 }
40
41 pub fn record_with_sensors(&mut self, state: &State, sensors: Vec<SensorOutput>) {
43 self.record(state);
44 self.sensor_history.push(sensors);
45 }
46
47 pub fn len(&self) -> usize {
49 self.time_history.len()
50 }
51
52 pub fn is_empty(&self) -> bool {
54 self.time_history.is_empty()
55 }
56
57 pub fn clear(&mut self) {
59 self.q_history.clear();
60 self.v_history.clear();
61 self.ctrl_history.clear();
62 self.time_history.clear();
63 self.sensor_history.clear();
64 }
65
66 pub fn to_json(&self) -> Result<String, serde_json::Error> {
68 let mut data = HashMap::new();
69 data.insert("q", &self.q_history);
70 data.insert("v", &self.v_history);
71 data.insert("ctrl", &self.ctrl_history);
72
73 let time_nested: Vec<Vec<f64>> = self.time_history.iter().map(|&t| vec![t]).collect();
75 data.insert("time", &time_nested);
76
77 serde_json::to_string_pretty(&data)
78 }
79
80 pub fn to_json_file(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
82 let json = self.to_json()?;
83 std::fs::write(path, json)?;
84 Ok(())
85 }
86
87 pub fn to_flat_dict(&self) -> HashMap<String, Vec<f64>> {
95 let mut dict = HashMap::new();
96
97 let q_flat: Vec<f64> = self.q_history.iter().flatten().copied().collect();
99 let v_flat: Vec<f64> = self.v_history.iter().flatten().copied().collect();
100 let ctrl_flat: Vec<f64> = self.ctrl_history.iter().flatten().copied().collect();
101
102 dict.insert("q".to_string(), q_flat);
103 dict.insert("v".to_string(), v_flat);
104 dict.insert("ctrl".to_string(), ctrl_flat);
105 dict.insert("time".to_string(), self.time_history.clone());
106
107 dict
108 }
109
110 pub fn stats(&self) -> TrajectoryStats {
112 if self.is_empty() {
113 return TrajectoryStats::default();
114 }
115
116 let nsteps = self.len();
117 let nq = self.q_history[0].len();
118 let nv = self.v_history[0].len();
119 let duration = self.time_history.last().unwrap() - self.time_history[0];
120
121 TrajectoryStats {
122 nsteps,
123 nq,
124 nv,
125 duration,
126 }
127 }
128}
129
130impl Default for TrajectoryRecorder {
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
/// Summary figures for a recorded trajectory, as produced by
/// `TrajectoryRecorder::stats`.
///
/// The derived `Default` sets every field to zero.
#[derive(Clone, Debug, Default)]
pub struct TrajectoryStats {
    /// Number of recorded samples.
    pub nsteps: usize,
    /// Length of the first recorded `q` vector.
    pub nq: usize,
    /// Length of the first recorded `v` vector.
    pub nv: usize,
    /// Last recorded time stamp minus the first one.
    pub duration: f64,
}
148
#[cfg(test)]
mod tests {
    use super::*;
    use phyz_math::DVec;

    /// Builds a state whose `q`, `v` and `ctrl` entries are fixed multiples of
    /// `t` and whose time stamp is `t`.
    fn make_test_state(t: f64) -> State {
        let mut state = State::new(2, 2, 1);
        state.time = t;
        state.q = DVec::from_vec(vec![t, t * 2.0]);
        state.v = DVec::from_vec(vec![t * 3.0, t * 4.0]);
        state.ctrl = DVec::from_vec(vec![t * 5.0, t * 6.0]);
        state
    }

    /// Returns a fresh recorder holding one test state per given time stamp.
    fn recorder_from_times<I: IntoIterator<Item = f64>>(times: I) -> TrajectoryRecorder {
        let mut recorder = TrajectoryRecorder::new();
        for t in times {
            recorder.record(&make_test_state(t));
        }
        recorder
    }

    #[test]
    fn test_trajectory_recording() {
        let recorder = recorder_from_times((0..10).map(|i| i as f64 * 0.1));

        // Every state history grows by one entry per recorded sample.
        assert_eq!(recorder.len(), 10);
        assert_eq!(recorder.q_history.len(), 10);
        assert_eq!(recorder.v_history.len(), 10);
        assert_eq!(recorder.time_history.len(), 10);
    }

    #[test]
    fn test_trajectory_stats() {
        let recorder = recorder_from_times((0..5).map(|i| i as f64));

        let stats = recorder.stats();
        assert_eq!(stats.nsteps, 5);
        assert_eq!(stats.nq, 2);
        assert_eq!(stats.nv, 2);
        // Time stamps run from 0.0 to 4.0.
        assert_eq!(stats.duration, 4.0);
    }

    #[test]
    fn test_trajectory_to_json() {
        let recorder = recorder_from_times((0..3).map(|i| i as f64));

        let json = recorder.to_json();
        assert!(json.is_ok());
        let json_str = json.unwrap();
        for key in ["\"q\"", "\"v\"", "\"ctrl\""].iter() {
            assert!(json_str.contains(*key));
        }
    }

    #[test]
    fn test_trajectory_to_flat_dict() {
        let recorder = recorder_from_times((0..3).map(|i| i as f64));

        let dict = recorder.to_flat_dict();
        // Three samples of two values each flatten to six values.
        assert_eq!(dict.get("q").unwrap().len(), 6);
        assert_eq!(dict.get("v").unwrap().len(), 6);
        assert_eq!(dict.get("time").unwrap().len(), 3);
    }

    #[test]
    fn test_trajectory_clear() {
        let mut recorder = recorder_from_times(vec![0.0, 1.0]);
        assert_eq!(recorder.len(), 2);

        recorder.clear();
        assert_eq!(recorder.len(), 0);
        assert!(recorder.is_empty());
    }
}