/// Tuning parameters for [`HretEstimator`].
///
/// The `*_rho` fields are the smoothing factors of the exponential envelopes
/// (`envelope = rho * envelope + (1 - rho) * |residual|`). The `beta_*` fields scale the
/// envelope inside the trust weight `1 / (1 + beta * envelope)`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HretParams {
    /// Envelope smoothing factor for individual channels.
    pub channel_rho: f32,
    /// Envelope smoothing factor for channel groups.
    pub group_rho: f32,
    /// Sensitivity of a channel's trust to its envelope.
    pub beta_channel: f32,
    /// Sensitivity of a group's trust to its envelope.
    pub beta_group: f32,
}

impl HretParams {
    /// Preset parameters (also returned by `Default`).
    ///
    /// Values: `channel_rho = 0.95`, `group_rho = 0.97`, and `10.0` for both betas.
    /// NOTE(review): the intended use case implied by the "sdr" suffix is inferred from the
    /// name only.
    pub const fn default_sdr() -> Self {
        HretParams {
            beta_channel: 10.0,
            beta_group: 10.0,
            channel_rho: 0.95,
            group_rho: 0.97,
        }
    }

    /// Builds parameters whose channel and group betas are both `1 / sigma0`.
    ///
    /// Both betas fall back to `10.0` unless `sigma0 > 1e-12`. That covers zero,
    /// negative and NaN inputs.
    pub fn from_sigma(sigma0: f32, channel_rho: f32, group_rho: f32) -> Self {
        const FALLBACK_BETA: f32 = 10.0;
        let beta = match sigma0 {
            s if s > 1e-12 => 1.0 / s,
            _ => FALLBACK_BETA,
        };
        HretParams {
            channel_rho,
            group_rho,
            beta_channel: beta,
            beta_group: beta,
        }
    }
}

impl Default for HretParams {
    /// Same as [`HretParams::default_sdr`].
    fn default() -> Self {
        HretParams::default_sdr()
    }
}
/// Per-channel state tracked by [`HretEstimator`].
#[derive(Debug, Clone, Copy)]
pub struct ChannelState {
    /// Exponentially smoothed envelope of the channel's absolute residual.
    pub envelope: f32,
    /// Trust derived from `envelope`; `1.0` in a fresh state.
    pub trust_weight: f32,
}

impl Default for ChannelState {
    /// A fresh state with a zero envelope and full trust.
    fn default() -> Self {
        let envelope = 0.0;
        let trust_weight = 1.0;
        ChannelState { envelope, trust_weight }
    }
}
/// Per-group state tracked by [`HretEstimator`].
#[derive(Debug, Clone, Copy)]
pub struct GroupState {
    /// Exponentially smoothed envelope of the group's mean absolute residual.
    pub envelope: f32,
    /// Trust derived from `envelope`; `1.0` in a fresh state.
    pub trust_weight: f32,
    /// Channel count recorded for this group (`0` in a fresh state).
    pub count: u8,
}

impl Default for GroupState {
    /// A fresh state with a zero envelope, full trust and a zero count.
    fn default() -> Self {
        let (envelope, trust_weight, count) = (0.0, 1.0, 0);
        GroupState { envelope, trust_weight, count }
    }
}
/// Output of one [`HretEstimator::observe`] step for `C` channels.
#[derive(Clone, Copy, Debug)]
pub struct HretResult<const C: usize> {
    /// Normalised per-channel weights.
    pub weights: [f32; C],
    /// `gain * sum_k(weights[k] * residuals[k])`.
    pub weighted_residual: f32,
    /// Largest entry of `weights`.
    pub max_weight: f32,
    /// Smallest entry of `weights`.
    pub min_weight: f32,
    /// `1 - (max_weight - min_weight)`; `1.0` when all weights are equal.
    pub trust_diversity: f32,
}
/// Hierarchical residual-envelope trust estimator over `C` channels split into `G` groups.
///
/// Each [`observe`](HretEstimator::observe) step does four things:
/// 1. Updates an exponentially smoothed envelope of `|residual|` per channel, and of the
///    mean `|residual|` per group.
/// 2. Turns every envelope into a trust weight `1 / (1 + beta * envelope)`.
/// 3. Multiplies each channel's trust by its group's trust.
/// 4. Normalises those products into weights.
pub struct HretEstimator<const C: usize, const G: usize> {
    /// Smoothed envelope and trust weight of every channel.
    channel_states: [ChannelState; C],
    /// Smoothed envelope, trust weight and channel count of every group.
    group_states: [GroupState; G],
    /// `group_map[k]` is the group of channel `k`; values `>= G` are clamped to `G - 1`.
    group_map: [usize; C],
    params: HretParams,
    /// Multiplier applied to the weighted residual returned by `observe`.
    gain: f32,
}

impl<const C: usize, const G: usize> HretEstimator<C, G> {
    /// Creates an estimator with the given channel-to-group mapping.
    ///
    /// All envelopes start at `0`, all trust weights at `1`, and `gain` at `1`.
    ///
    /// # Panics
    /// In debug builds, panics if any entry of `group_map` is `>= G`. Release builds instead
    /// clamp such entries to `G - 1` whenever the mapping is used.
    pub fn new(group_map: [usize; C], params: HretParams) -> Self {
        debug_assert!(
            group_map.iter().all(|&g| g < G),
            "group_map contains index >= G"
        );
        Self {
            channel_states: [ChannelState::default(); C],
            group_states: [GroupState::default(); G],
            group_map,
            params,
            gain: 1.0,
        }
    }

    /// Creates an estimator in which every channel belongs to group `0`.
    pub fn single_group(params: HretParams) -> Self {
        Self::new([0usize; C], params)
    }

    /// Builder-style setter for the multiplier applied to `weighted_residual`.
    ///
    /// Consumes and returns the estimator, so ignoring the result drops it.
    #[must_use]
    pub fn with_gain(mut self, gain: f32) -> Self {
        self.gain = gain;
        self
    }

    /// Feeds one residual per channel and returns the updated weights.
    ///
    /// The weights are the normalised products of channel and group trust. If their sum is
    /// not greater than `1e-30`, uniform weights `1 / C` are returned instead.
    /// `weighted_residual` is `gain * sum_k(weights[k] * residuals[k])`, and
    /// `trust_diversity` is `1 - (max_weight - min_weight)`.
    pub fn observe(&mut self, residuals: &[f32; C]) -> HretResult<C> {
        self.update_group_envelopes(residuals);
        self.update_channel_envelopes(residuals);
        let weights = self.compose_normalised_weights();
        let weighted_residual = self.gain * dot_product_c(&weights, residuals);
        let (max_w, min_w) = weight_extrema(&weights);
        HretResult {
            weights,
            weighted_residual,
            max_weight: max_w,
            min_weight: min_w,
            trust_diversity: 1.0 - (max_w - min_w),
        }
    }

    /// Group index used for channel `k`, clamped to the last group.
    #[inline]
    fn group_of(&self, k: usize) -> usize {
        self.group_map[k].min(G - 1)
    }

    /// Updates each group's envelope from the mean `|residual|` of its channels.
    ///
    /// Also refreshes each group's trust weight and `count`.
    fn update_group_envelopes(&mut self, residuals: &[f32; C]) {
        let mut group_sum = [0.0_f32; G];
        let mut group_cnt = [0_u32; G];
        for (k, &r) in residuals.iter().enumerate() {
            let g = self.group_of(k);
            group_sum[g] += r.abs();
            group_cnt[g] += 1;
        }
        let rho = self.params.group_rho;
        let beta = self.params.beta_group;
        for ((state, &sum), &cnt) in self
            .group_states
            .iter_mut()
            .zip(group_sum.iter())
            .zip(group_cnt.iter())
        {
            // Groups with no channels are treated as having zero residual.
            let mean_abs = if cnt > 0 { sum / cnt as f32 } else { 0.0 };
            state.envelope = rho * state.envelope + (1.0 - rho) * mean_abs;
            state.trust_weight = 1.0 / (1.0 + beta * state.envelope);
            // `count` was previously never written. Record how many channels fed this
            // group, saturating because the field is only a u8.
            state.count = cnt.min(u32::from(u8::MAX)) as u8;
        }
    }

    /// Updates each channel's envelope from its own `|residual|`, then its trust weight.
    fn update_channel_envelopes(&mut self, residuals: &[f32; C]) {
        let rho = self.params.channel_rho;
        let beta = self.params.beta_channel;
        for (state, &r) in self.channel_states.iter_mut().zip(residuals.iter()) {
            state.envelope = rho * state.envelope + (1.0 - rho) * r.abs();
            state.trust_weight = 1.0 / (1.0 + beta * state.envelope);
        }
    }

    /// Combines channel and group trust into weights.
    ///
    /// The weights sum to one, or are uniform when the total trust is not greater than
    /// `1e-30`. Because that comparison is false for NaN, a NaN total also falls back to
    /// uniform weights.
    fn compose_normalised_weights(&self) -> [f32; C] {
        let mut weights = [0.0_f32; C];
        for (k, w) in weights.iter_mut().enumerate() {
            *w = self.channel_states[k].trust_weight
                * self.group_states[self.group_of(k)].trust_weight;
        }
        let sum_hat: f32 = weights.iter().sum();
        if sum_hat > 1e-30 {
            for w in weights.iter_mut() {
                *w /= sum_hat;
            }
            weights
        } else {
            [1.0 / C as f32; C]
        }
    }

    /// Read-only view of every channel's state.
    #[inline]
    pub fn channel_states(&self) -> &[ChannelState; C] {
        &self.channel_states
    }

    /// Read-only view of every group's state.
    #[inline]
    pub fn group_states(&self) -> &[GroupState; G] {
        &self.group_states
    }

    /// Trust weight of channel `k`, or `0.0` if `k` is out of range.
    pub fn channel_trust(&self, k: usize) -> f32 {
        self.channel_states.get(k).map(|s| s.trust_weight).unwrap_or(0.0)
    }

    /// Restores every channel and group state to its default.
    ///
    /// That means zero envelope, unit trust and zero count. The group map, parameters and
    /// gain are kept.
    pub fn reset(&mut self) {
        for s in &mut self.channel_states {
            *s = ChannelState::default();
        }
        for s in &mut self.group_states {
            *s = GroupState::default();
        }
    }
}
/// Inner product of two equal-length arrays.
///
/// Products are accumulated left to right starting from `+0.0`, so an empty input yields
/// `+0.0`.
fn dot_product_c<const C: usize>(a: &[f32; C], b: &[f32; C]) -> f32 {
    a.iter()
        .zip(b.iter())
        .fold(0.0_f32, |acc, (&x, &y)| acc + x * y)
}
/// Returns `(max, min)` of `weights`.
///
/// An empty array (`C == 0`) yields `(0.0, 0.0)`; the original indexed `weights[0]`
/// unconditionally and panicked. Strict `>`/`<` comparisons are kept, so NaN entries after
/// the first are ignored, and a NaN first entry stays as both extrema.
fn weight_extrema<const C: usize>(weights: &[f32; C]) -> (f32, f32) {
    match weights.split_first() {
        None => (0.0, 0.0),
        Some((&first, rest)) => rest.iter().fold((first, first), |(max_w, min_w), &w| {
            let max_w = if w > max_w { w } else { max_w };
            let min_w = if w < min_w { w } else { min_w };
            (max_w, min_w)
        }),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Observes `residuals` `warmup` times, then returns the result of one further step.
    fn settle<const C: usize, const G: usize>(
        est: &mut HretEstimator<C, G>,
        residuals: [f32; C],
        warmup: usize,
    ) -> HretResult<C> {
        for _ in 0..warmup {
            est.observe(&residuals);
        }
        est.observe(&residuals)
    }

    #[test]
    fn single_group_uniform_channels() {
        let mut est = HretEstimator::<4, 1>::single_group(HretParams::default_sdr());
        let out = settle(&mut est, [0.1; 4], 50);
        for (k, &w) in out.weights.iter().enumerate() {
            let diff = (w - 0.25).abs();
            assert!(diff < 0.01, "weight[{}]={} (expected ~0.25)", k, w);
        }
        assert!((out.weights.iter().sum::<f32>() - 1.0).abs() < 1e-5);
    }

    #[test]
    fn faulty_channel_down_weighted() {
        let mut est = HretEstimator::<4, 1>::single_group(HretParams::default_sdr());
        let out = settle(&mut est, [0.02, 0.02, 0.02, 0.20], 200);
        let good_sum: f32 = out.weights[..3].iter().sum();
        let bad = out.weights[3];
        assert!(
            good_sum > bad,
            "good_sum={}, bad={}: faulty channel should be down-weighted",
            good_sum, bad
        );
    }

    #[test]
    fn hierarchical_group_fault_down_weights_entire_group() {
        let mut est = HretEstimator::<4, 2>::new([0usize, 0, 1, 1], HretParams::default_sdr());
        let out = settle(&mut est, [0.02, 0.02, 0.20, 0.20], 200);
        let group0_sum = out.weights[0] + out.weights[1];
        let group1_sum = out.weights[2] + out.weights[3];
        assert!(
            group0_sum > group1_sum,
            "clean group0={} should outweigh noisy group1={}",
            group0_sum, group1_sum
        );
    }

    #[test]
    fn weights_always_sum_to_one() {
        let mut est = HretEstimator::<4, 2>::new([0usize, 0, 1, 1], HretParams::default_sdr());
        for step in 0..100 {
            let x = step as f32;
            let out = est.observe(&[x * 0.01, 0.05, 0.03, x * 0.02]);
            let sum: f32 = out.weights.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5, "weights sum={} at step {}", sum, step);
        }
    }

    #[test]
    fn trust_diversity_bounded() {
        let mut est = HretEstimator::<4, 1>::single_group(HretParams::default_sdr());
        for _ in 0..100 {
            let diversity = est.observe(&[0.1, 0.2, 0.3, 0.4]).trust_diversity;
            assert!(diversity >= 0.0, "diversity must be non-negative");
            assert!(diversity <= 1.0, "diversity must be <= 1.0");
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut est = HretEstimator::<2, 1>::single_group(HretParams::default_sdr());
        for _ in 0..100 {
            est.observe(&[0.5, 0.5]);
        }
        est.reset();
        assert_eq!(est.channel_states[0].envelope, 0.0);
        assert_eq!(est.group_states[0].envelope, 0.0);
    }
}