Skip to main content

phyz_world/
trajectory.rs

1//! Trajectory recording and export for machine learning training data.
2
3use crate::SensorOutput;
4use std::collections::HashMap;
5use phyz_model::State;
6
7/// Records trajectories of (q, v, ctrl, time) for ML training.
8pub struct TrajectoryRecorder {
9    /// Recorded positions at each timestep.
10    pub q_history: Vec<Vec<f64>>,
11    /// Recorded velocities at each timestep.
12    pub v_history: Vec<Vec<f64>>,
13    /// Recorded control inputs at each timestep.
14    pub ctrl_history: Vec<Vec<f64>>,
15    /// Timestamps for each step.
16    pub time_history: Vec<f64>,
17    /// Sensor outputs at each timestep.
18    pub sensor_history: Vec<Vec<SensorOutput>>,
19}
20
21impl TrajectoryRecorder {
22    /// Create a new empty trajectory recorder.
23    pub fn new() -> Self {
24        Self {
25            q_history: Vec::new(),
26            v_history: Vec::new(),
27            ctrl_history: Vec::new(),
28            time_history: Vec::new(),
29            sensor_history: Vec::new(),
30        }
31    }
32
33    /// Record the current state.
34    pub fn record(&mut self, state: &State) {
35        self.q_history.push(state.q.as_slice().to_vec());
36        self.v_history.push(state.v.as_slice().to_vec());
37        self.ctrl_history.push(state.ctrl.as_slice().to_vec());
38        self.time_history.push(state.time);
39    }
40
41    /// Record state with sensor outputs.
42    pub fn record_with_sensors(&mut self, state: &State, sensors: Vec<SensorOutput>) {
43        self.record(state);
44        self.sensor_history.push(sensors);
45    }
46
47    /// Number of timesteps recorded.
48    pub fn len(&self) -> usize {
49        self.time_history.len()
50    }
51
52    /// Check if recorder is empty.
53    pub fn is_empty(&self) -> bool {
54        self.time_history.is_empty()
55    }
56
57    /// Clear all recorded data.
58    pub fn clear(&mut self) {
59        self.q_history.clear();
60        self.v_history.clear();
61        self.ctrl_history.clear();
62        self.time_history.clear();
63        self.sensor_history.clear();
64    }
65
66    /// Export to JSON string.
67    pub fn to_json(&self) -> Result<String, serde_json::Error> {
68        let mut data = HashMap::new();
69        data.insert("q", &self.q_history);
70        data.insert("v", &self.v_history);
71        data.insert("ctrl", &self.ctrl_history);
72
73        // Convert time history to nested vec for consistency
74        let time_nested: Vec<Vec<f64>> = self.time_history.iter().map(|&t| vec![t]).collect();
75        data.insert("time", &time_nested);
76
77        serde_json::to_string_pretty(&data)
78    }
79
80    /// Export to JSON file.
81    pub fn to_json_file(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
82        let json = self.to_json()?;
83        std::fs::write(path, json)?;
84        Ok(())
85    }
86
87    /// Convert to a dictionary-like structure suitable for numpy/ML frameworks.
88    ///
89    /// Returns:
90    /// - "q": (nsteps, nq) flattened data
91    /// - "v": (nsteps, nv) flattened data
92    /// - "ctrl": (nsteps, nv) flattened data
93    /// - "time": (nsteps,) flattened data
94    pub fn to_flat_dict(&self) -> HashMap<String, Vec<f64>> {
95        let mut dict = HashMap::new();
96
97        // Flatten q, v, ctrl
98        let q_flat: Vec<f64> = self.q_history.iter().flatten().copied().collect();
99        let v_flat: Vec<f64> = self.v_history.iter().flatten().copied().collect();
100        let ctrl_flat: Vec<f64> = self.ctrl_history.iter().flatten().copied().collect();
101
102        dict.insert("q".to_string(), q_flat);
103        dict.insert("v".to_string(), v_flat);
104        dict.insert("ctrl".to_string(), ctrl_flat);
105        dict.insert("time".to_string(), self.time_history.clone());
106
107        dict
108    }
109
110    /// Get trajectory statistics.
111    pub fn stats(&self) -> TrajectoryStats {
112        if self.is_empty() {
113            return TrajectoryStats::default();
114        }
115
116        let nsteps = self.len();
117        let nq = self.q_history[0].len();
118        let nv = self.v_history[0].len();
119        let duration = self.time_history.last().unwrap() - self.time_history[0];
120
121        TrajectoryStats {
122            nsteps,
123            nq,
124            nv,
125            duration,
126        }
127    }
128}
129
130impl Default for TrajectoryRecorder {
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
/// Summary statistics of a recorded trajectory.
///
/// The `Default` value (all zeros) describes an empty recording.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct TrajectoryStats {
    /// Number of recorded timesteps.
    pub nsteps: usize,
    /// Length of the position vector `q` in the first recorded step.
    pub nq: usize,
    /// Length of the velocity vector `v` in the first recorded step.
    pub nv: usize,
    /// Last recorded time minus first recorded time (in the simulation's time units,
    /// presumably seconds).
    pub duration: f64,
}
148
#[cfg(test)]
mod tests {
    use super::*;
    use phyz_math::DVec;

    /// Build a state whose q, v, ctrl entries are distinct multiples of `t`, so each
    /// recorded step can be identified from its values.
    fn make_test_state(t: f64) -> State {
        let mut state = State::new(2, 2, 1);
        state.q = DVec::from_vec(vec![t, t * 2.0]);
        state.v = DVec::from_vec(vec![t * 3.0, t * 4.0]);
        state.ctrl = DVec::from_vec(vec![t * 5.0, t * 6.0]);
        state.time = t;
        state
    }

    /// Record `n` steps at integer times `0, 1, ..., n - 1`.
    fn recorder_with_integer_steps(n: usize) -> TrajectoryRecorder {
        let mut recorder = TrajectoryRecorder::new();
        for i in 0..n {
            recorder.record(&make_test_state(i as f64));
        }
        recorder
    }

    #[test]
    fn test_trajectory_recording() {
        let mut recorder = TrajectoryRecorder::new();

        for i in 0..10 {
            let t = i as f64 * 0.1;
            recorder.record(&make_test_state(t));
        }

        assert_eq!(recorder.len(), 10);
        assert_eq!(recorder.q_history.len(), 10);
        assert_eq!(recorder.v_history.len(), 10);
        assert_eq!(recorder.ctrl_history.len(), 10);
        assert_eq!(recorder.time_history.len(), 10);
        // Plain `record` never produces sensor entries.
        assert!(recorder.sensor_history.is_empty());
    }

    #[test]
    fn test_trajectory_stats() {
        let stats = recorder_with_integer_steps(5).stats();
        assert_eq!(stats.nsteps, 5);
        assert_eq!(stats.nq, 2);
        assert_eq!(stats.nv, 2);
        assert_eq!(stats.duration, 4.0);
    }

    #[test]
    fn test_trajectory_stats_empty() {
        let stats = TrajectoryRecorder::new().stats();
        assert_eq!(stats.nsteps, 0);
        assert_eq!(stats.nq, 0);
        assert_eq!(stats.nv, 0);
        assert_eq!(stats.duration, 0.0);
    }

    #[test]
    fn test_trajectory_to_json() {
        let json_str = recorder_with_integer_steps(3)
            .to_json()
            .expect("serializing plain f64 vectors cannot fail");
        assert!(json_str.contains("\"q\""));
        assert!(json_str.contains("\"v\""));
        assert!(json_str.contains("\"ctrl\""));
        assert!(json_str.contains("\"time\""));
    }

    #[test]
    fn test_trajectory_to_flat_dict() {
        let dict = recorder_with_integer_steps(3).to_flat_dict();
        // 3 steps * 2 DOF each, laid out step by step (row-major).
        assert_eq!(dict["q"], vec![0.0, 0.0, 1.0, 2.0, 2.0, 4.0]);
        assert_eq!(dict["v"], vec![0.0, 0.0, 3.0, 4.0, 6.0, 8.0]);
        assert_eq!(dict["ctrl"], vec![0.0, 0.0, 5.0, 6.0, 10.0, 12.0]);
        assert_eq!(dict["time"], vec![0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_trajectory_clear() {
        let mut recorder = recorder_with_integer_steps(2);
        assert_eq!(recorder.len(), 2);

        recorder.clear();
        assert_eq!(recorder.len(), 0);
        assert!(recorder.is_empty());
        assert!(recorder.q_history.is_empty());
        assert!(recorder.sensor_history.is_empty());
    }
}