Skip to main content

dsfb_hret/
lib.rs

//! Hierarchical Residual-Envelope Trust (HRET) for grouped multi-sensor fusion.
//!
//! `HretObserver` maintains per-channel and per-group residual envelopes, converts
//! those envelopes into trust weights, and produces a convexly weighted correction vector.
//!
//! # Example
//!
//! ```rust
//! use dsfb_hret::HretObserver;
//!
//! let mut obs = HretObserver::new(
//!     3,                       // m: number of residual channels
//!     2,                       // g: number of groups
//!     vec![0, 0, 1],           // group_mapping: channel -> group
//!     0.95,                    // rho: channel envelope forgetting factor
//!     vec![0.9, 0.85],         // rho_g: per-group forgetting factors
//!     vec![1.0, 1.0, 1.0],     // beta_k: per-channel trust scale
//!     vec![1.0, 1.0],          // beta_g: per-group trust scale
//!     vec![
//!         vec![1.0, 0.5, 0.5], // k_k: gain matrix with shape (p = 2, m = 3)
//!         vec![0.0, 1.0, 0.0],
//!     ],
//! )
//! .unwrap();
//!
//! let (delta_x, weights, s_k, s_g) = obs.update(vec![0.05, 0.12, 0.30]).unwrap();
//! assert_eq!(delta_x.len(), 2);
//! assert_eq!(weights.len(), 3);
//! assert_eq!(s_k.len(), 3);
//! assert_eq!(s_g.len(), 2);
//! ```
//!
33#![allow(clippy::useless_conversion)] // False positive from PyO3-generated PyResult signature.
34
35use ndarray::{Array1, Array2};
36use pyo3::exceptions::PyValueError;
37use pyo3::prelude::*;
38
/// Lower bound on the sum of the composed trust weights.
///
/// At or below this value the weights are considered collapsed, and the update uses
/// uniform channel weights instead of normalizing by a vanishing sum.
const WEIGHT_SUM_EPS: f64 = 1.0e-12;

/// Value returned by a single HRET update.
///
/// Tuple layout:
/// - `.0`: fused correction `delta_x`
/// - `.1`: normalized (convex) channel weights
/// - `.2`: channel envelopes `s_k`
/// - `.3`: group envelopes `s_g`
pub type HretUpdate = (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>);
49
/// Validation failure produced while constructing or updating an HRET observer.
///
/// Carries a single human-readable message. Its `Display` output is exactly that
/// message, with no prefix and no padding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HretError {
    /// Description of the check that failed.
    message: String,
}

impl HretError {
    /// Wraps anything convertible into a `String` as an error message.
    fn new(message: impl Into<String>) -> Self {
        let message: String = message.into();
        HretError { message }
    }
}

impl std::fmt::Display for HretError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.message)
    }
}

// No underlying cause is stored, so the default `source()` (returning `None`) applies.
impl std::error::Error for HretError {}
71
72#[derive(Clone, Debug)]
73#[pyclass]
74/// Stateful HRET observer for grouped residual fusion.
75///
76/// The observer keeps exponentially weighted absolute residual envelopes for each
77/// channel and group, then combines channel and group trust into convex fusion weights.
78pub struct HretObserver {
79    m: usize,
80    g: usize,
81    group_mapping: Array1<usize>,
82    group_indices: Vec<Vec<usize>>,
83    rho: f64,
84    rho_g: Array1<f64>,
85    beta_k: Array1<f64>,
86    beta_g: Array1<f64>,
87    s_k: Array1<f64>,
88    s_g: Array1<f64>,
89    k_k: Array2<f64>,
90}
91
92impl HretObserver {
93    /// Constructs a new observer and validates all dimensions and scalar parameters.
94    ///
95    /// `k_k` is the fusion gain matrix with shape `(p, m)`, where `m` is the number
96    /// of channels and `p` is the correction dimension.
97    #[allow(clippy::too_many_arguments)]
98    pub fn new(
99        m: usize,
100        g: usize,
101        group_mapping: Vec<usize>,
102        rho: f64,
103        rho_g: Vec<f64>,
104        beta_k: Vec<f64>,
105        beta_g: Vec<f64>,
106        k_k: Vec<Vec<f64>>,
107    ) -> Result<Self, HretError> {
108        validate_positive("m", m)?;
109        validate_positive("g", g)?;
110        validate_len("group_mapping", m, group_mapping.len())?;
111        validate_len("rho_g", g, rho_g.len())?;
112        validate_len("beta_k", m, beta_k.len())?;
113        validate_len("beta_g", g, beta_g.len())?;
114        validate_forgetting_factor("rho", rho)?;
115        validate_forgetting_factors("rho_g", &rho_g)?;
116        validate_non_negative_finite("beta_k", &beta_k)?;
117        validate_non_negative_finite("beta_g", &beta_g)?;
118
119        let mut group_indices = vec![Vec::new(); g];
120        for (channel_idx, &group_idx) in group_mapping.iter().enumerate() {
121            if group_idx >= g {
122                return Err(HretError::new(format!(
123                    "group_mapping[{channel_idx}] = {group_idx} is out of range 0..{g}",
124                )));
125            }
126            group_indices[group_idx].push(channel_idx);
127        }
128
129        if k_k.is_empty() {
130            return Err(HretError::new("k_k must contain at least one gain row"));
131        }
132
133        let p = k_k.len();
134        let mut k_k_flat = Vec::with_capacity(p * m);
135        for (row_idx, row) in k_k.into_iter().enumerate() {
136            validate_len(&format!("k_k[{row_idx}]"), m, row.len())?;
137            for (col_idx, value) in row.into_iter().enumerate() {
138                if !value.is_finite() {
139                    return Err(HretError::new(format!(
140                        "k_k[{row_idx}][{col_idx}] must be finite (got {value})",
141                    )));
142                }
143                k_k_flat.push(value);
144            }
145        }
146
147        let k_k = Array2::from_shape_vec((p, m), k_k_flat).map_err(|e| {
148            HretError::new(format!(
149                "failed to build gain matrix with shape ({p}, {m}): {e}",
150            ))
151        })?;
152
153        Ok(Self {
154            m,
155            g,
156            group_mapping: Array1::from(group_mapping),
157            group_indices,
158            rho,
159            rho_g: Array1::from(rho_g),
160            beta_k: Array1::from(beta_k),
161            beta_g: Array1::from(beta_g),
162            s_k: Array1::zeros(m),
163            s_g: Array1::zeros(g),
164            k_k,
165        })
166    }
167
168    /// Applies one HRET update for the provided channel residuals.
169    ///
170    /// Returns the fused correction, normalized channel weights, updated channel
171    /// envelopes, and updated group envelopes.
172    pub fn update(&mut self, residuals: Vec<f64>) -> Result<HretUpdate, HretError> {
173        validate_len("residuals", self.m, residuals.len())?;
174        validate_finite("residuals", &residuals)?;
175
176        let r_arr = Array1::from(residuals);
177
178        // Channel envelopes (eq. 8)
179        self.s_k = self.rho * &self.s_k + (1.0 - self.rho) * r_arr.mapv(f64::abs);
180
181        // Group envelopes (eq. 11)
182        for (group_idx, channels) in self.group_indices.iter().enumerate() {
183            if channels.is_empty() {
184                continue;
185            }
186
187            let avg_abs_r =
188                channels.iter().map(|&i| r_arr[i].abs()).sum::<f64>() / channels.len() as f64;
189            self.s_g[group_idx] = self.rho_g[group_idx] * self.s_g[group_idx]
190                + (1.0 - self.rho_g[group_idx]) * avg_abs_r;
191        }
192
193        // Trusts (eq. 9, 12)
194        let w_k =
195            Array1::from_iter((0..self.m).map(|i| 1.0 / (1.0 + self.beta_k[i] * self.s_k[i])));
196        let w_g =
197            Array1::from_iter((0..self.g).map(|i| 1.0 / (1.0 + self.beta_g[i] * self.s_g[i])));
198
199        // Hierarchical composition (eq. 14-15)
200        let w_g_mapped =
201            Array1::from_iter(self.group_mapping.iter().map(|&group_idx| w_g[group_idx]));
202        let hat_w_k = &w_k * &w_g_mapped;
203        let sum_hat = hat_w_k.sum();
204        let tilde_w_k = if sum_hat > WEIGHT_SUM_EPS {
205            hat_w_k / sum_hat
206        } else {
207            Array1::from_elem(self.m, 1.0 / self.m as f64)
208        };
209
210        // Fusion correction (eq. 19): Delta_x = K * (tilde_w ⊙ r)
211        let weighted_r = &tilde_w_k * &r_arr;
212        let delta_x = self.k_k.dot(&weighted_r);
213
214        debug_assert!(tilde_w_k.iter().all(|&w| w >= -1e-12));
215        debug_assert!((tilde_w_k.sum() - 1.0).abs() < 1e-8);
216
217        Ok((
218            delta_x.to_vec(),
219            tilde_w_k.to_vec(),
220            self.s_k.to_vec(),
221            self.s_g.to_vec(),
222        ))
223    }
224
225    /// Resets the stored channel and group envelope state to zero.
226    pub fn reset_envelopes(&mut self) {
227        self.s_k.fill(0.0);
228        self.s_g.fill(0.0);
229    }
230
231    /// Returns the configured number of residual channels.
232    pub fn channel_count(&self) -> usize {
233        self.m
234    }
235
236    /// Returns the configured number of groups.
237    pub fn group_count(&self) -> usize {
238        self.g
239    }
240
241    /// Returns the channel-to-group mapping as a plain vector.
242    pub fn group_mapping_vec(&self) -> Vec<usize> {
243        self.group_mapping.to_vec()
244    }
245}
246
247#[pymethods]
248impl HretObserver {
249    #[new]
250    #[pyo3(signature = (m, g, group_mapping, rho, rho_g, beta_k, beta_g, k_k))]
251    #[allow(clippy::too_many_arguments)]
252    fn py_new(
253        m: usize,
254        g: usize,
255        group_mapping: Vec<usize>,
256        rho: f64,
257        rho_g: Vec<f64>,
258        beta_k: Vec<f64>,
259        beta_g: Vec<f64>,
260        k_k: Vec<Vec<f64>>,
261    ) -> PyResult<Self> {
262        Self::new(m, g, group_mapping, rho, rho_g, beta_k, beta_g, k_k)
263            .map_err(|error| PyValueError::new_err(error.to_string()))
264    }
265
266    #[pyo3(name = "update")]
267    #[allow(clippy::useless_conversion)]
268    fn py_update(&mut self, residuals: Vec<f64>) -> PyResult<HretUpdate> {
269        self.update(residuals)
270            .map_err(|error| PyValueError::new_err(error.to_string()))
271    }
272
273    #[pyo3(name = "reset_envelopes")]
274    fn py_reset_envelopes(&mut self) {
275        self.reset_envelopes();
276    }
277
278    #[getter]
279    fn m(&self) -> usize {
280        self.channel_count()
281    }
282
283    #[getter]
284    fn g(&self) -> usize {
285        self.group_count()
286    }
287
288    #[getter]
289    fn group_mapping(&self) -> Vec<usize> {
290        self.group_mapping_vec()
291    }
292
293    fn __repr__(&self) -> String {
294        format!(
295            "HretObserver(m={}, g={}, p={})",
296            self.m,
297            self.g,
298            self.k_k.nrows()
299        )
300    }
301}
302
303fn validate_positive(field: &str, value: usize) -> Result<(), HretError> {
304    if value == 0 {
305        return Err(HretError::new(format!("{field} must be > 0 (got 0)")));
306    }
307    Ok(())
308}
309
310fn validate_len(field: &str, expected: usize, got: usize) -> Result<(), HretError> {
311    if expected != got {
312        return Err(HretError::new(format!(
313            "{field} length mismatch: expected {expected}, got {got}",
314        )));
315    }
316    Ok(())
317}
318
319fn validate_forgetting_factor(field: &str, value: f64) -> Result<(), HretError> {
320    if !value.is_finite() || value <= 0.0 || value >= 1.0 {
321        return Err(HretError::new(format!(
322            "{field} must be finite and in (0, 1); got {value}",
323        )));
324    }
325    Ok(())
326}
327
328fn validate_forgetting_factors(field: &str, values: &[f64]) -> Result<(), HretError> {
329    for (idx, value) in values.iter().copied().enumerate() {
330        if !value.is_finite() || value <= 0.0 || value >= 1.0 {
331            return Err(HretError::new(format!(
332                "{field}[{idx}] must be finite and in (0, 1); got {value}",
333            )));
334        }
335    }
336    Ok(())
337}
338
339fn validate_non_negative_finite(field: &str, values: &[f64]) -> Result<(), HretError> {
340    for (idx, value) in values.iter().copied().enumerate() {
341        if !value.is_finite() || value < 0.0 {
342            return Err(HretError::new(format!(
343                "{field}[{idx}] must be finite and >= 0; got {value}",
344            )));
345        }
346    }
347    Ok(())
348}
349
350fn validate_finite(field: &str, values: &[f64]) -> Result<(), HretError> {
351    for (idx, value) in values.iter().copied().enumerate() {
352        if !value.is_finite() {
353            return Err(HretError::new(format!(
354                "{field}[{idx}] must be finite; got {value}",
355            )));
356        }
357    }
358    Ok(())
359}
360
361#[pymodule]
362fn dsfb_hret(m: &Bound<'_, PyModule>) -> PyResult<()> {
363    m.add_class::<HretObserver>()?;
364    Ok(())
365}
366
367#[cfg(test)]
368mod tests;