Skip to main content

nabled_sim/
pipeline.rs

1//! Opt-in multi-stage Physical AI pipelines.
2
3use nabled_core::scalar::NabledReal;
4use nabled_linalg::lu::LuProviderScalar;
5use nabled_ml::stats::rolling::rolling_covariance;
6use ndarray::{Array1, Array2};
7
8use crate::SimError;
9use crate::context::RobotContext;
10use crate::sim::{SimConfig, SimState, semi_implicit_step};
11
12/// Logged torque rows from a simulation run.
13#[derive(Debug, Clone, PartialEq)]
14pub struct TorqueLog<T> {
15    pub samples: Array2<T>,
16}
17
18impl<T: NabledReal> TorqueLog<T> {
19    #[must_use]
20    pub fn len(&self) -> usize { self.samples.nrows() }
21
22    #[must_use]
23    pub fn is_empty(&self) -> bool { self.samples.nrows() == 0 }
24}
25
26/// Builder for cross-crate workflows (sim → stats, etc.).
27#[derive(Debug, Clone, PartialEq)]
28pub struct PhysicalAiPipeline<T> {
29    pub ctx:               RobotContext<T>,
30    pub sim_config:        SimConfig<T>,
31    pub torque_log_window: usize,
32}
33
34impl<T: NabledReal + Default + LuProviderScalar> PhysicalAiPipeline<T> {
35    #[must_use]
36    pub fn new(ctx: RobotContext<T>, sim_config: SimConfig<T>, torque_log_window: usize) -> Self {
37        Self { ctx, sim_config, torque_log_window }
38    }
39
40    /// Run `steps` semi-implicit steps, logging applied joint torques each step.
41    pub fn run_sim_with_torque_log(
42        &self,
43        initial: &SimState<T>,
44        tau_fn: impl Fn(usize) -> Array1<T>,
45        steps: usize,
46    ) -> Result<(SimState<T>, TorqueLog<T>), SimError> {
47        self.ctx.validate()?;
48        let dof = self.ctx.chain.num_joints();
49        let mut state = initial.clone();
50        let mut log = Array2::zeros((steps, dof));
51        for step in 0..steps {
52            let tau = tau_fn(step);
53            if tau.len() != dof {
54                return Err(SimError::DimensionMismatch);
55            }
56            log.row_mut(step).assign(&tau);
57            let result = semi_implicit_step(&self.ctx, &state, &tau.view(), &self.sim_config)?;
58            state = result.state;
59        }
60        Ok((state, TorqueLog { samples: log }))
61    }
62
63    /// Rolling covariance over logged torques (compose-down to `nabled-ml::stats`).
64    pub fn rolling_torque_covariance(log: &TorqueLog<T>, window: usize) -> Array2<T> {
65        rolling_covariance(&log.samples.view(), window)
66    }
67
68    /// Convenience: simulate, log torques, and return rolling covariance.
69    pub fn sim_torque_rolling_covariance(
70        &self,
71        initial: &SimState<T>,
72        tau_fn: impl Fn(usize) -> Array1<T>,
73        steps: usize,
74    ) -> Result<(SimState<T>, Array2<T>), SimError> {
75        let (state, log) = self.run_sim_with_torque_log(initial, tau_fn, steps)?;
76        let cov = Self::rolling_torque_covariance(&log, self.torque_log_window);
77        Ok((state, cov))
78    }
79}
80
#[cfg(test)]
mod tests {
    use nabled_dynamics::DynamicsConfig;
    use nabled_model::fixture::load_planar2r_json;
    use ndarray::{Array1, arr1};

    use super::*;
    use crate::context::RobotContext;

    /// Pipeline around the planar 2R fixture: 0.01 s step, 5-sample covariance window.
    fn planar2r_pipeline() -> PhysicalAiPipeline<f64> {
        let fixture = load_planar2r_json().expect("fixture");
        let ctx = RobotContext::new(
            fixture.to_robot_model::<f64>().expect("model"),
            fixture.to_chain_spec::<f64>().expect("chain"),
            DynamicsConfig {
                gravity: fixture.gravity.unwrap_or([0.0, -9.81, 0.0]),
                ..DynamicsConfig::default()
            },
        );
        PhysicalAiPipeline::new(ctx, SimConfig::new(0.01), 5)
    }

    #[test]
    fn sim_torque_pipeline_produces_bounded_covariance() {
        let pipeline = planar2r_pipeline();
        let initial = SimState::new(arr1(&[0.2, 0.4]), Array1::zeros(2));
        let (_, cov) = pipeline
            .sim_torque_rolling_covariance(
                &initial,
                |step| {
                    let t = f64::from(u32::try_from(step).expect("step fits in u32")) * 0.01;
                    arr1(&[0.5 * t.sin(), 0.2 * t.cos()])
                },
                20,
            )
            .expect("pipeline");
        assert_eq!(cov.nrows(), 20);
        let last = cov.row(19);
        assert!(last.iter().all(|v| v.is_finite()));
        // Time-varying torques over a full window must yield some non-zero covariance.
        assert!(last.iter().any(|v| v.abs() > 0.0));
    }

    #[test]
    fn torque_of_wrong_length_is_rejected() {
        let pipeline = planar2r_pipeline();
        let initial = SimState::new(arr1(&[0.2, 0.4]), Array1::zeros(2));
        // The planar 2R chain has two joints; a 3-vector must be rejected, not broadcast.
        let result = pipeline.run_sim_with_torque_log(&initial, |_| arr1(&[1.0, 2.0, 3.0]), 3);
        assert!(matches!(result, Err(SimError::DimensionMismatch)));
    }
}
119}