Skip to main content

dsfb_robotics/
balancing.rs

//! Shared residual helper for balancing datasets (MIT Cheetah 3 /
//! Mini-Cheetah, IIT iCub push-recovery).
//!
//! Balancing platforms expose a **dual-channel** residual:
//!
//! - `r_F(k) = F_contact,measured(k) − F_contact,planned(k)` — the
//!   contact force / wrench actually realised minus the whole-body
//!   controller's planned ground-reaction / contact wrench. Rolled
//!   forward into the next MPC horizon and discarded as "tracking error".
//!
//! - `r_ξ(k) = ξ_measured(k) − ξ_model(k)` — the centroidal-momentum
//!   (or full-body centre-of-mass) observer discrepancy between the
//!   IMU-fused estimate and the rigid-body model prediction. Fused
//!   into the state estimate and discarded once consumed.
//!
//! DSFB ingests both channels through the same `observe()` core by
//! combining them into a single scalar residual norm. This module
//! provides the combiner; the per-dataset adapters (Phase 3) supply
//! the raw channels from their respective controllers.
20
21use crate::math;
22
/// Channel-combination strategy for the dual balancing residual.
///
/// [`BalancingCombine::SumOfSquares`] is the default (it is what
/// [`Default::default`] returns) and the most conservative choice: both
/// channels are treated as equally important and the result is their
/// Euclidean norm. [`BalancingCombine::WeightedSum`] lets the caller bias
/// one channel — for example, weighting `r_F` higher during stance and
/// `r_ξ` higher during swing.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BalancingCombine {
    /// `sqrt(r_F² + r_ξ²)` — Euclidean combination, equal weights.
    SumOfSquares,
    /// `sqrt(w_force · r_F² + w_xi · r_ξ²)` — weighted Euclidean.
    ///
    /// Both weights must be finite and non-negative; otherwise the
    /// strategy is treated as mis-configured and the combiner returns
    /// `None`. A zero weight suppresses the corresponding channel;
    /// `w_force = 1.0, w_xi = 1.0` gives parity with `SumOfSquares`.
    WeightedSum {
        /// Weight on the contact-force residual channel `r_F`.
        w_force: f64,
        /// Weight on the centroidal-momentum residual channel `r_ξ`.
        w_xi: f64,
    },
}

impl Default for BalancingCombine {
    /// Returns [`BalancingCombine::SumOfSquares`], the equal-weight strategy.
    fn default() -> Self {
        Self::SumOfSquares
    }
}
46
47/// Combine a force-residual channel and a centroidal-momentum channel
48/// into a single scalar residual norm.
49///
50/// Returns `None` if both inputs are non-finite or the strategy is
51/// mis-configured. Missingness-aware: a finite channel combined with
52/// a non-finite one degrades to the finite channel's magnitude.
53#[must_use]
54pub fn combine_channels(
55    r_force: f64,
56    r_xi: f64,
57    strategy: BalancingCombine,
58) -> Option<f64> {
59    let f_finite = r_force.is_finite();
60    let x_finite = r_xi.is_finite();
61    if !f_finite && !x_finite {
62        return None;
63    }
64    debug_assert!(f_finite || x_finite, "guarded above: at least one channel is finite");
65    let rf = if f_finite { r_force } else { 0.0 };
66    let rx = if x_finite { r_xi } else { 0.0 };
67    debug_assert!(rf.is_finite() && rx.is_finite(), "post-degrade rf/rx must be finite");
68
69    let ssq = match strategy {
70        BalancingCombine::SumOfSquares => rf * rf + rx * rx,
71        BalancingCombine::WeightedSum { w_force, w_xi } => {
72            if w_force < 0.0 || w_xi < 0.0 || !w_force.is_finite() || !w_xi.is_finite() {
73                return None;
74            }
75            debug_assert!(w_force >= 0.0 && w_xi >= 0.0, "weights validated above");
76            w_force * rf * rf + w_xi * rx * rx
77        }
78    };
79    debug_assert!(ssq >= 0.0, "sum-of-squares is non-negative by construction");
80
81    math::sqrt_f64(ssq)
82}
83
84/// Vectorised variant: produce a streaming residual-norm sequence from
85/// two aligned channel slices. Returns the number of finite samples
86/// written into `out` (never exceeds `out.len()`). Non-finite entries
87/// in either channel become zero in the combined residual (per
88/// [`combine_channels`] missingness rule).
89pub fn combine_stream(
90    r_force: &[f64],
91    r_xi: &[f64],
92    out: &mut [f64],
93    strategy: BalancingCombine,
94) -> usize {
95    debug_assert!(r_force.len() == r_xi.len(), "channels must have equal length");
96    debug_assert!(!out.is_empty() || r_force.is_empty(), "non-empty output requires non-empty input");
97    let n = r_force.len().min(r_xi.len()).min(out.len());
98    debug_assert!(n <= out.len(), "n must respect destination capacity");
99    let mut i = 0_usize;
100    while i < n {
101        let combined = combine_channels(r_force[i], r_xi[i], strategy).unwrap_or(0.0);
102        debug_assert!(combined.is_finite(), "combined residual must be finite (non-finite inputs degrade to 0)");
103        out[i] = combined;
104        i += 1;
105    }
106    n
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sum_of_squares_zero_inputs_zero_out() {
        let r = combine_channels(0.0, 0.0, BalancingCombine::SumOfSquares).expect("finite");
        assert!(r.abs() < 1e-12);
    }

    #[test]
    fn sum_of_squares_3_4_5_triangle() {
        let r = combine_channels(3.0, 4.0, BalancingCombine::SumOfSquares).expect("finite");
        assert!((r - 5.0).abs() < 1e-12);
    }

    #[test]
    fn weighted_sum_zero_weight_suppresses_channel() {
        let r = combine_channels(
            10.0,
            0.1,
            BalancingCombine::WeightedSum { w_force: 0.0, w_xi: 1.0 },
        )
        .expect("finite");
        assert!((r - 0.1).abs() < 1e-6, "force channel zero-weighted → only xi shows through");
    }

    #[test]
    fn weighted_sum_unit_weights_match_sum_of_squares() {
        let unit = BalancingCombine::WeightedSum { w_force: 1.0, w_xi: 1.0 };
        let weighted = combine_channels(3.0, 4.0, unit).expect("finite");
        let plain = combine_channels(3.0, 4.0, BalancingCombine::SumOfSquares).expect("finite");
        assert!((weighted - plain).abs() < 1e-12);
        assert!((weighted - 5.0).abs() < 1e-12);
    }

    #[test]
    fn weighted_sum_rejects_negative_weights() {
        let r = combine_channels(
            1.0,
            1.0,
            BalancingCombine::WeightedSum { w_force: -1.0, w_xi: 1.0 },
        );
        assert!(r.is_none());
    }

    #[test]
    fn weighted_sum_rejects_non_finite_weights() {
        let nan_weight = BalancingCombine::WeightedSum { w_force: f64::NAN, w_xi: 1.0 };
        let inf_weight = BalancingCombine::WeightedSum { w_force: 1.0, w_xi: f64::INFINITY };
        assert!(combine_channels(1.0, 1.0, nan_weight).is_none());
        assert!(combine_channels(1.0, 1.0, inf_weight).is_none());
    }

    #[test]
    fn both_non_finite_is_none() {
        assert!(combine_channels(f64::NAN, f64::NAN, BalancingCombine::SumOfSquares).is_none());
    }

    #[test]
    fn one_non_finite_degrades_to_other() {
        let r = combine_channels(3.0, f64::NAN, BalancingCombine::SumOfSquares).expect("finite");
        assert!((r - 3.0).abs() < 1e-12);
    }

    #[test]
    fn non_finite_force_degrades_to_xi() {
        let r = combine_channels(f64::INFINITY, -4.0, BalancingCombine::SumOfSquares)
            .expect("xi channel is finite");
        assert!((r - 4.0).abs() < 1e-12);
    }

    #[test]
    fn stream_aligns_and_respects_capacity() {
        let rf = [3.0, 0.0, 1.0, 2.0];
        let rx = [4.0, 0.0, 0.0, 2.0];
        let mut out = [0.0_f64; 3];
        let n = combine_stream(&rf, &rx, &mut out, BalancingCombine::SumOfSquares);
        assert_eq!(n, 3);
        assert!((out[0] - 5.0).abs() < 1e-12);
        assert!(out[1].abs() < 1e-12);
        assert!((out[2] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn stream_both_non_finite_writes_zero() {
        let rf = [f64::NAN, 3.0];
        let rx = [f64::NAN, 4.0];
        let mut out = [9.0_f64; 2];
        let n = combine_stream(&rf, &rx, &mut out, BalancingCombine::SumOfSquares);
        assert_eq!(n, 2);
        assert!(out[0].abs() < 1e-12, "both channels missing → combined residual is 0");
        assert!((out[1] - 5.0).abs() < 1e-12);
    }

    #[test]
    fn stream_empty_input_writes_nothing() {
        let mut out: [f64; 0] = [];
        let n = combine_stream(&[], &[], &mut out, BalancingCombine::SumOfSquares);
        assert_eq!(n, 0);
    }
}
168}