dsfb-hret 0.1.1

Hierarchical Residual-Envelope Trust (HRET): A deterministic extension of DSFB for grouped multi-sensor fusion with correlated disturbance handling.
Documentation
//! Hierarchical Residual-Envelope Trust (HRET) for grouped multi-sensor fusion.
//!
//! `HretObserver` maintains channel and group envelope state, converts those envelopes
//! into trust weights, and produces a convexly weighted correction vector.
//!
//! # Example
//!
//! ```rust
//! use dsfb_hret::HretObserver;
//!
//! let mut obs = HretObserver::new(
//!     3,
//!     2,
//!     vec![0, 0, 1],
//!     0.95,
//!     vec![0.9, 0.85],
//!     vec![1.0, 1.0, 1.0],
//!     vec![1.0, 1.0],
//!     vec![
//!         vec![1.0, 0.5, 0.5],
//!         vec![0.0, 1.0, 0.0],
//!     ],
//! )
//! .unwrap();
//!
//! let (delta_x, weights, s_k, s_g) = obs.update(vec![0.05, 0.12, 0.30]).unwrap();
//! assert_eq!(delta_x.len(), 2);
//! assert_eq!(weights.len(), 3);
//! assert_eq!(s_k.len(), 3);
//! assert_eq!(s_g.len(), 2);
//! ```
//!
#![allow(clippy::useless_conversion)] // False positive from PyO3-generated PyResult signature.

use ndarray::{Array1, Array2};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

/// Smallest sum of unnormalized hierarchical weights that is still normalized.
///
/// When the sum is not strictly greater than this value, `HretObserver::update`
/// falls back to uniform weights instead of dividing by a near-zero total.
const WEIGHT_SUM_EPS: f64 = 1e-12;

/// Result of a single HRET update.
///
/// The tuple components are, in order:
/// 1. fused correction `delta_x` (one entry per row of the gain matrix)
/// 2. normalized channel weights (one entry per channel)
/// 3. channel envelopes `s_k` (one entry per channel)
/// 4. group envelopes `s_g` (one entry per group)
pub type HretUpdate = (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>);

/// Validation failure raised when HRET inputs are rejected.
///
/// The error carries only a human-readable description, which is exactly what
/// its `Display` implementation prints.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HretError {
    /// Human-readable description of the rejected input.
    message: String,
}

impl HretError {
    /// Wraps anything convertible into a `String` as the error description.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl std::fmt::Display for HretError {
    /// Writes the stored description verbatim; width/fill flags of the
    /// formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message } = self;
        f.write_str(message)
    }
}

// The default `Error` methods are sufficient: no underlying source is stored.
impl std::error::Error for HretError {}

#[derive(Clone, Debug)]
#[pyclass]
/// Stateful HRET observer for grouped residual fusion.
///
/// The observer keeps exponentially weighted absolute residual envelopes for each
/// channel and group, then combines channel and group trust into convex fusion weights.
pub struct HretObserver {
    /// Number of residual channels.
    m: usize,
    /// Number of groups.
    g: usize,
    /// Group index of every channel (`m` entries, each validated to be `< g`).
    group_mapping: Array1<usize>,
    /// Inverse of `group_mapping`: for each group, the channel indices assigned to it.
    /// A group may have no channels, in which case its list is empty.
    group_indices: Vec<Vec<usize>>,
    /// Forgetting factor shared by all channel envelopes; validated to lie in `(0, 1)`.
    rho: f64,
    /// Per-group forgetting factors (`g` entries, each validated to lie in `(0, 1)`).
    rho_g: Array1<f64>,
    /// Per-channel trust sensitivity (finite, `>= 0`); trust is `1 / (1 + beta_k * s_k)`.
    beta_k: Array1<f64>,
    /// Per-group trust sensitivity (finite, `>= 0`); trust is `1 / (1 + beta_g * s_g)`.
    beta_g: Array1<f64>,
    /// Current channel envelopes; initialized to zero.
    s_k: Array1<f64>,
    /// Current group envelopes; initialized to zero.
    s_g: Array1<f64>,
    /// Fusion gain matrix of shape `(p, m)`, where `p` is the correction dimension.
    k_k: Array2<f64>,
}

impl HretObserver {
    /// Builds an observer after checking every dimension and scalar parameter.
    ///
    /// * `m` / `g` – number of channels and groups; both must be non-zero.
    /// * `group_mapping` – group index of each channel (`m` entries, each `< g`).
    /// * `rho` – channel forgetting factor, strictly inside `(0, 1)`.
    /// * `rho_g` – per-group forgetting factors (`g` entries, each inside `(0, 1)`).
    /// * `beta_k` / `beta_g` – finite, non-negative trust sensitivities with
    ///   `m` and `g` entries respectively.
    /// * `k_k` – fusion gain matrix given as `p >= 1` rows of `m` finite values.
    ///
    /// Both envelope vectors start at zero.
    ///
    /// # Errors
    ///
    /// Returns an [`HretError`] describing the first check that fails.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        m: usize,
        g: usize,
        group_mapping: Vec<usize>,
        rho: f64,
        rho_g: Vec<f64>,
        beta_k: Vec<f64>,
        beta_g: Vec<f64>,
        k_k: Vec<Vec<f64>>,
    ) -> Result<Self, HretError> {
        // Sizes first, then lengths, then value ranges: the order determines
        // which error a caller sees when several inputs are wrong at once.
        validate_positive("m", m)?;
        validate_positive("g", g)?;
        validate_len("group_mapping", m, group_mapping.len())?;
        validate_len("rho_g", g, rho_g.len())?;
        validate_len("beta_k", m, beta_k.len())?;
        validate_len("beta_g", g, beta_g.len())?;
        validate_forgetting_factor("rho", rho)?;
        validate_forgetting_factors("rho_g", &rho_g)?;
        validate_non_negative_finite("beta_k", &beta_k)?;
        validate_non_negative_finite("beta_g", &beta_g)?;

        // Invert the channel -> group mapping; a lookup miss means the group
        // index is outside 0..g.
        let mut group_indices: Vec<Vec<usize>> = vec![Vec::new(); g];
        for (channel_idx, &group_idx) in group_mapping.iter().enumerate() {
            match group_indices.get_mut(group_idx) {
                Some(members) => members.push(channel_idx),
                None => {
                    return Err(HretError::new(format!(
                        "group_mapping[{channel_idx}] = {group_idx} is out of range 0..{g}",
                    )));
                }
            }
        }

        let p = k_k.len();
        if p == 0 {
            return Err(HretError::new("k_k must contain at least one gain row"));
        }

        // Flatten the rows in row-major order, checking each row's length
        // before its values.
        let mut flat_gain = Vec::with_capacity(p * m);
        for (row_idx, row) in k_k.iter().enumerate() {
            validate_len(&format!("k_k[{row_idx}]"), m, row.len())?;
            if let Some(col_idx) = row.iter().position(|entry| !entry.is_finite()) {
                let value = row[col_idx];
                return Err(HretError::new(format!(
                    "k_k[{row_idx}][{col_idx}] must be finite (got {value})",
                )));
            }
            flat_gain.extend_from_slice(row);
        }

        let gain = match Array2::from_shape_vec((p, m), flat_gain) {
            Ok(matrix) => matrix,
            Err(e) => {
                return Err(HretError::new(format!(
                    "failed to build gain matrix with shape ({p}, {m}): {e}",
                )));
            }
        };

        Ok(Self {
            m,
            g,
            group_mapping: Array1::from(group_mapping),
            group_indices,
            rho,
            rho_g: Array1::from(rho_g),
            beta_k: Array1::from(beta_k),
            beta_g: Array1::from(beta_g),
            s_k: Array1::zeros(m),
            s_g: Array1::zeros(g),
            k_k: gain,
        })
    }

    /// Runs one HRET step on the given channel residuals.
    ///
    /// The residuals are validated before any state is touched. On success the
    /// channel and group envelopes are advanced and the result holds, in order,
    /// the fused correction, the normalized channel weights, the updated channel
    /// envelopes and the updated group envelopes.
    ///
    /// # Errors
    ///
    /// Returns an [`HretError`] if `residuals` does not have one entry per
    /// channel or contains a non-finite value.
    pub fn update(&mut self, residuals: Vec<f64>) -> Result<HretUpdate, HretError> {
        validate_len("residuals", self.m, residuals.len())?;
        validate_finite("residuals", &residuals)?;

        let r_arr = Array1::from(residuals);

        // Channel envelopes (eq. 8): exponential smoothing of |r|.
        let keep = self.rho;
        let blend = 1.0 - self.rho;
        for (envelope, residual) in self.s_k.iter_mut().zip(r_arr.iter()) {
            *envelope = keep * *envelope + blend * residual.abs();
        }

        // Group envelopes (eq. 11): smoothing of the mean |r| over the group's
        // channels. Groups without channels keep their previous envelope.
        let per_group = self
            .group_indices
            .iter()
            .zip(self.s_g.iter_mut())
            .zip(self.rho_g.iter());
        for ((members, envelope), &rho_group) in per_group {
            if members.is_empty() {
                continue;
            }

            let avg_abs_r =
                members.iter().map(|&ch| r_arr[ch].abs()).sum::<f64>() / members.len() as f64;
            *envelope = rho_group * *envelope + (1.0 - rho_group) * avg_abs_r;
        }

        // Trusts (eq. 9, 12): larger envelopes mean lower trust.
        let w_k = Array1::from_iter(
            self.beta_k
                .iter()
                .zip(self.s_k.iter())
                .map(|(&beta, &env)| 1.0 / (1.0 + beta * env)),
        );
        let w_g = Array1::from_iter(
            self.beta_g
                .iter()
                .zip(self.s_g.iter())
                .map(|(&beta, &env)| 1.0 / (1.0 + beta * env)),
        );

        // Hierarchical composition (eq. 14-15): channel trust times the trust
        // of the group that channel belongs to, then normalized to sum to one.
        let hat_w_k = Array1::from_iter(
            w_k.iter()
                .zip(self.group_mapping.iter())
                .map(|(&channel_trust, &group_idx)| channel_trust * w_g[group_idx]),
        );
        let sum_hat = hat_w_k.sum();
        let tilde_w_k = if sum_hat > WEIGHT_SUM_EPS {
            hat_w_k / sum_hat
        } else {
            // Degenerate total: fall back to uniform weights.
            Array1::from_elem(self.m, 1.0 / self.m as f64)
        };

        // Fusion correction (eq. 19): Delta_x = K * (tilde_w ⊙ r)
        let weighted_r = &tilde_w_k * &r_arr;
        let delta_x = self.k_k.dot(&weighted_r);

        debug_assert!(tilde_w_k.iter().all(|&w| w >= -1e-12));
        debug_assert!((tilde_w_k.sum() - 1.0).abs() < 1e-8);

        let outcome: HretUpdate = (
            delta_x.to_vec(),
            tilde_w_k.to_vec(),
            self.s_k.to_vec(),
            self.s_g.to_vec(),
        );
        Ok(outcome)
    }

    /// Clears both the channel and the group envelopes back to zero.
    pub fn reset_envelopes(&mut self) {
        for envelopes in [&mut self.s_k, &mut self.s_g] {
            envelopes.fill(0.0);
        }
    }

    /// Number of residual channels this observer was configured with.
    pub fn channel_count(&self) -> usize {
        self.m
    }

    /// Number of groups this observer was configured with.
    pub fn group_count(&self) -> usize {
        self.g
    }

    /// Copy of the channel-to-group mapping as a plain vector.
    pub fn group_mapping_vec(&self) -> Vec<usize> {
        self.group_mapping.iter().copied().collect()
    }
}

#[pymethods]
impl HretObserver {
    // Python constructor: delegates to `HretObserver::new` and reports any
    // validation failure as a Python `ValueError`.
    #[new]
    #[pyo3(signature = (m, g, group_mapping, rho, rho_g, beta_k, beta_g, k_k))]
    #[allow(clippy::too_many_arguments)]
    fn py_new(
        m: usize,
        g: usize,
        group_mapping: Vec<usize>,
        rho: f64,
        rho_g: Vec<f64>,
        beta_k: Vec<f64>,
        beta_g: Vec<f64>,
        k_k: Vec<Vec<f64>>,
    ) -> PyResult<Self> {
        let built = Self::new(m, g, group_mapping, rho, rho_g, beta_k, beta_g, k_k);
        built.map_err(|err| PyValueError::new_err(err.to_string()))
    }

    // Python `update`: same contract as the Rust method, with validation
    // failures mapped to `ValueError`.
    #[pyo3(name = "update")]
    #[allow(clippy::useless_conversion)]
    fn py_update(&mut self, residuals: Vec<f64>) -> PyResult<HretUpdate> {
        let outcome = self.update(residuals);
        outcome.map_err(|err| PyValueError::new_err(err.to_string()))
    }

    // Python `reset_envelopes`: zeroes the channel and group envelopes.
    #[pyo3(name = "reset_envelopes")]
    fn py_reset_envelopes(&mut self) {
        self.reset_envelopes()
    }

    // Read-only property: number of channels.
    #[getter]
    fn m(&self) -> usize {
        self.m
    }

    // Read-only property: number of groups.
    #[getter]
    fn g(&self) -> usize {
        self.g
    }

    // Read-only property: copy of the channel-to-group mapping.
    #[getter]
    fn group_mapping(&self) -> Vec<usize> {
        self.group_mapping.to_vec()
    }

    // Short summary with channel count, group count and correction dimension.
    fn __repr__(&self) -> String {
        let channels = self.m;
        let groups = self.g;
        let correction_dim = self.k_k.nrows();
        format!(
            "HretObserver(m={}, g={}, p={})",
            channels, groups, correction_dim
        )
    }
}

/// Rejects a zero count; `field` names the offending parameter in the message.
fn validate_positive(field: &str, value: usize) -> Result<(), HretError> {
    match value {
        0 => Err(HretError::new(format!("{field} must be > 0 (got 0)"))),
        _ => Ok(()),
    }
}

/// Checks that a collection named `field` has exactly `expected` entries.
fn validate_len(field: &str, expected: usize, got: usize) -> Result<(), HretError> {
    if got == expected {
        return Ok(());
    }
    Err(HretError::new(format!(
        "{field} length mismatch: expected {expected}, got {got}",
    )))
}

/// Accepts a forgetting factor only if it lies strictly between 0 and 1.
///
/// NaN and infinities fail both comparisons' conjunction and are rejected too.
fn validate_forgetting_factor(field: &str, value: f64) -> Result<(), HretError> {
    let inside_open_unit = value > 0.0 && value < 1.0;
    if inside_open_unit {
        Ok(())
    } else {
        Err(HretError::new(format!(
            "{field} must be finite and in (0, 1); got {value}",
        )))
    }
}

/// Applies the `(0, 1)` forgetting-factor check to every entry of `values`,
/// reporting the first offending index.
fn validate_forgetting_factors(field: &str, values: &[f64]) -> Result<(), HretError> {
    // NaN and infinities fail the range test, so they are reported as well.
    let first_bad = values
        .iter()
        .position(|&candidate| !(candidate > 0.0 && candidate < 1.0));
    match first_bad {
        None => Ok(()),
        Some(idx) => {
            let value = values[idx];
            Err(HretError::new(format!(
                "{field}[{idx}] must be finite and in (0, 1); got {value}",
            )))
        }
    }
}

/// Requires every entry of `values` to be finite and not below zero,
/// reporting the first offending index.
fn validate_non_negative_finite(field: &str, values: &[f64]) -> Result<(), HretError> {
    let first_bad = values
        .iter()
        .position(|&candidate| !(candidate.is_finite() && candidate >= 0.0));
    match first_bad {
        None => Ok(()),
        Some(idx) => {
            let value = values[idx];
            Err(HretError::new(format!(
                "{field}[{idx}] must be finite and >= 0; got {value}",
            )))
        }
    }
}

/// Requires every entry of `values` to be finite (neither NaN nor infinite),
/// reporting the first offending index.
fn validate_finite(field: &str, values: &[f64]) -> Result<(), HretError> {
    let offender = values
        .iter()
        .enumerate()
        .find(|(_, candidate)| !candidate.is_finite());
    if let Some((idx, &value)) = offender {
        return Err(HretError::new(format!(
            "{field}[{idx}] must be finite; got {value}",
        )));
    }
    Ok(())
}

// Python module initializer: registers `HretObserver` as the module's only class.
#[pymodule]
fn dsfb_hret(module: &Bound<'_, PyModule>) -> PyResult<()> {
    module.add_class::<HretObserver>()
}

#[cfg(test)]
mod tests;