Skip to main content

dsfb_rf/
trust.rs

1//! Hierarchical Residual-Envelope Trust (HRET) for RF multi-channel receivers.
2//!
3//! ## Theoretical Basis
4//!
5//! Derived from the DSFB-HRET framework (de Beer 2026, §III–IV).  The core
6//! insight is that in a multi-channel RF receiver (multi-antenna, multi-band,
7//! dual-polarisation) different observation channels vary in reliability.
8//! Naively averaging all channel residuals degrades the composite estimate when
9//! one antenna or band is in deep fade, experiencing local RFI, or has a faulty
10//! LNA.  HRET builds **two levels** of EMA-based envelope trust and combines
11//! them before computing the weighted residual.
12//!
13//! ### Level 1 — Channel envelope (eq. 8)
14//!
15//! For each channel k, a per-channel EMA envelope tracks the running
16//! absolute residual:
17//!
18//! ```text
19//! s_k ← ρ · s_k + (1 − ρ) · |r_k|
20//! ```
21//!
22//! Channel trust weight (eq. 9):
23//!
24//! ```text
25//! w_k = 1 / (1 + β · s_k)
26//! ```
27//!
28//! ### Level 2 — Group envelope (eq. 11)
29//!
30//! Channels are partitioned into groups (e.g., by polarisation, frequency band,
31//! or spatial cluster).  A per-group EMA envelope tracks the mean absolute
32//! residual across the group:
33//!
34//! ```text
35//! s_g ← ρ_g · s_g + (1 − ρ_g) · (1/|G| · Σ_{k∈G} |r_k|)
36//! ```
37//!
38//! Group trust weight (eq. 12):
39//!
40//! ```text
41//! w_g = 1 / (1 + β_g · s_g)
42//! ```
43//!
44//! ### Hierarchical composition (eqs. 14–15) and correction (eq. 19)
45//!
46//! Composite weights are the product of the channel weight and the weight of
47//! that channel's group, then L1-normalised:
48//!
49//! ```text
50//! ŵ_k = w_k · w_{g[k]}
51//! w̃_k = ŵ_k / Σ_j ŵ_j          (normalisation)
52//! ```
53//!
//! The correction signal fed to downstream stages is the gain-scaled,
//! trust-weighted sum of the signed residuals (eq. 19):
//!
//! ```text
//! Δx = K · Σ_k w̃_k · r_k          (a scalar: K · ⟨w̃, r⟩)
58//! ```
59//!
60//! ### RF interpretation
61//!
62//! | HRET concept | RF analogue |
63//! |---|---|
64//! | Channel k | Receive antenna element / ADC lane |
65//! | Group g | Polarisation pair / sub-array / frequency band |
66//! | Channel envelope s_k | Per-antenna noise / interference run-in |
67//! | Group envelope s_g | Sub-array health / band cleanliness |
68//! | ŵ_k | Phased-array weighting analogous to optimal combining |
69//! | Δx | Weighted residual anomaly injected into grammar layer |
70//!
71//! The hierarchical scheme is empirically superior to flat average combining
72//! in the presence of partial-array failures and spectrally local RFI.
73//!
74//! ## Design
75//!
76//! - `no_std`, `no_alloc`, zero `unsafe`
77//! - Const-generic over `C` (channel count) and `G` (group count)
78//! - O(C+G) per call — no heap scan
79//! - Channel-to-group mapping supplied as a `[usize; C]` index array
80
/// Tuning parameters for the HRET trust estimator.
///
/// Holds two EMA smoothing factors (channel level and group level) and two
/// trust-shaping coefficients. At either level the trust weight is
/// `1 / (1 + β · s)`, where `s` is that level's residual envelope.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HretParams {
    /// Channel EMA smoothing factor ρ, expected in (0, 1).
    ///
    /// Values closer to 1 give longer memory and slower adaptation; smaller
    /// values adapt faster. Typical choices: 0.95 for slowly varying RF
    /// channels, 0.80 for fast fades.
    pub channel_rho: f32,

    /// Group EMA smoothing factor ρ_g, expected in (0, 1).
    ///
    /// Normally a little smoother than the channel level, e.g. 0.97.
    pub group_rho: f32,

    /// Channel trust shaping coefficient β (> 0).
    ///
    /// A larger β makes trust fall off more steeply as the envelope grows.
    /// A common choice is β = 1/σ₀, with σ₀ the nominal healthy-window sigma.
    pub beta_channel: f32,

    /// Group trust shaping coefficient β_g (> 0).
    pub beta_group: f32,
}

impl HretParams {
    /// Conservative defaults intended for typical SDR receivers:
    /// ρ = 0.95, ρ_g = 0.97, β = β_g = 10.0 (i.e. nominal σ₀ = 0.1).
    pub const fn default_sdr() -> Self {
        HretParams {
            beta_group: 10.0,
            beta_channel: 10.0,
            group_rho: 0.97,
            channel_rho: 0.95,
        }
    }

    /// Build parameters from a nominal healthy-window sigma σ₀.
    ///
    /// Both β and β_g are set to `1 / sigma0`. If `sigma0` is not greater
    /// than `1e-12` (zero, negative or NaN), the fallback β = 10.0 is used
    /// instead. The smoothing factors are stored as given, without validation.
    pub fn from_sigma(sigma0: f32, channel_rho: f32, group_rho: f32) -> Self {
        let beta = match sigma0 {
            s if s > 1e-12 => s.recip(),
            _ => 10.0,
        };
        HretParams {
            channel_rho,
            group_rho,
            beta_channel: beta,
            beta_group: beta,
        }
    }
}

impl Default for HretParams {
    /// Identical to [`HretParams::default_sdr`].
    fn default() -> Self {
        HretParams::default_sdr()
    }
}
133
/// Trust state HRET keeps for a single observation channel.
#[derive(Debug, Clone, Copy)]
pub struct ChannelState {
    /// Channel envelope s_k: exponential moving average of |r_k|.
    /// Starts at 0.0.
    pub envelope: f32,
    /// Channel trust weight w_k = 1 / (1 + β·s_k) from the most recent
    /// update. Starts at 1.0 (full trust).
    pub trust_weight: f32,
}

impl Default for ChannelState {
    fn default() -> Self {
        let (envelope, trust_weight) = (0.0, 1.0);
        ChannelState { envelope, trust_weight }
    }
}
148
/// Trust state HRET keeps for a single channel group.
#[derive(Debug, Clone, Copy)]
pub struct GroupState {
    /// Group envelope s_g: exponential moving average of the mean |r_k| over
    /// the channels mapped to this group. Starts at 0.0.
    pub envelope: f32,
    /// Group trust weight w_g = 1 / (1 + β_g·s_g) from the most recent
    /// update. Starts at 1.0 (full trust).
    pub trust_weight: f32,
    /// Channel-count slot, initialised to 0.
    ///
    /// This field does not take part in the group-mean computation.
    /// `HretEstimator` tallies group membership in a local counter on every
    /// observation and never writes this field, so it stays at 0 unless the
    /// caller sets it. It is kept because it is part of the public struct.
    pub count: u8,
}

impl Default for GroupState {
    fn default() -> Self {
        Self { envelope: 0.0, trust_weight: 1.0, count: 0 }
    }
}
165
/// Complete HRET result returned by a single `observe()` call.
#[derive(Debug, Clone, Copy)]
pub struct HretResult<const C: usize> {
    /// Normalised hierarchical channel weights w̃_k. They sum to 1 (up to
    /// floating-point rounding) whenever C ≥ 1.
    pub weights: [f32; C],
    /// Scalar composite residual Δx = K · Σ_k w̃_k · r_k.
    ///
    /// This is the gain-scaled dot product of the normalised weights with the
    /// *signed* residuals. It is the single scalar anomaly signal fed to the
    /// downstream grammar/DSA stages.
    pub weighted_residual: f32,
    /// Largest entry of `weights`. This is the value only; it does not
    /// identify which channel holds it.
    pub max_weight: f32,
    /// Smallest entry of `weights`. This is the value only; it does not
    /// identify which channel holds it.
    pub min_weight: f32,
    /// Trust diversity index = 1 − (max_weight − min_weight).
    ///
    /// A value of 1 means every channel is trusted equally. Values near 0
    /// mean the weight spread is close to its maximum, i.e. one channel
    /// carries almost all of the weight.
    pub trust_diversity: f32,
}
183
184/// Hierarchical Residual-Envelope Trust estimator.
185///
186/// ## Type Parameters
187///
188/// - `C`: number of observation channels (antenna elements, ADC lanes)
189/// - `G`: number of channel groups (polarisation pairs, sub-arrays, bands)
190///
191/// ## Memory footprint (no_std / no_alloc)
192///
193/// For C=4, G=2: 4×ChannelState + 2×GroupState + 4×usize = ~128 bytes.
194pub struct HretEstimator<const C: usize, const G: usize> {
195    /// Per-channel trust state.
196    channel_states: [ChannelState; C],
197    /// Per-group trust state.
198    group_states: [GroupState; G],
199    /// Channel-to-group mapping: group_map[k] = group index for channel k.
200    group_map: [usize; C],
201    /// HRET parameters.
202    params: HretParams,
203    /// Observation gain K applied to the weighted residual (default 1.0).
204    gain: f32,
205}
206
207impl<const C: usize, const G: usize> HretEstimator<C, G> {
208    /// Construct with a channel-to-group mapping and given parameters.
209    ///
210    /// # Panics (debug only)
211    ///
212    /// Panics in debug mode if any `group_map[k] >= G`.
213    /// In release mode, out-of-range indices are silently saturated (no UB).
214    pub fn new(group_map: [usize; C], params: HretParams) -> Self {
215        // Validate mapping in debug builds
216        debug_assert!(
217            group_map.iter().all(|&g| g < G),
218            "group_map contains index >= G"
219        );
220        Self {
221            channel_states: [ChannelState::default(); C],
222            group_states: [GroupState::default(); G],
223            group_map,
224            params,
225            gain: 1.0,
226        }
227    }
228
229    /// Construct with default SDR parameters and a uniform group mapping
230    /// (all channels in group 0) — useful for single-band single-array receivers.
231    pub fn single_group(params: HretParams) -> Self {
232        Self::new([0usize; C], params)
233    }
234
235    /// Set the output gain K (default 1.0).
236    pub fn with_gain(mut self, gain: f32) -> Self {
237        self.gain = gain;
238        self
239    }
240
241    /// Process one observation of per-channel residuals.
242    ///
243    /// `residuals[k]` = signed residual r_k for channel k.  We use |r_k|
244    /// for envelope update but the signed value for the weighted composite.
245    ///
246    /// Returns an `HretResult<C>` with normalised weights and the weighted
247    /// composite residual Δx.
248    pub fn observe(&mut self, residuals: &[f32; C]) -> HretResult<C> {
249        self.update_group_envelopes(residuals);
250        self.update_channel_envelopes(residuals);
251        let weights = self.compose_normalised_weights();
252        let weighted_residual = self.gain * dot_product_c(&weights, residuals);
253        let (max_w, min_w) = weight_extrema(&weights);
254        HretResult {
255            weights,
256            weighted_residual,
257            max_weight: max_w,
258            min_weight: min_w,
259            trust_diversity: 1.0 - (max_w - min_w),
260        }
261    }
262
263    fn update_group_envelopes(&mut self, residuals: &[f32; C]) {
264        let mut group_sum = [0.0_f32; G];
265        let mut group_cnt = [0_u32; G];
266        for (k, &r) in residuals.iter().enumerate() {
267            let g = self.group_map[k].min(G - 1);
268            group_sum[g] += r.abs();
269            group_cnt[g] += 1;
270        }
271        let rho_gr = self.params.group_rho;
272        let beta_gr = self.params.beta_group;
273        for g in 0..G {
274            let mean_abs = if group_cnt[g] > 0 { group_sum[g] / group_cnt[g] as f32 } else { 0.0 };
275            let s = &mut self.group_states[g].envelope;
276            *s = rho_gr * (*s) + (1.0 - rho_gr) * mean_abs;
277            self.group_states[g].trust_weight = 1.0 / (1.0 + beta_gr * self.group_states[g].envelope);
278        }
279    }
280
281    fn update_channel_envelopes(&mut self, residuals: &[f32; C]) {
282        let rho_ch = self.params.channel_rho;
283        let beta_ch = self.params.beta_channel;
284        for (k, &r) in residuals.iter().enumerate() {
285            let s = &mut self.channel_states[k].envelope;
286            *s = rho_ch * (*s) + (1.0 - rho_ch) * r.abs();
287            self.channel_states[k].trust_weight = 1.0 / (1.0 + beta_ch * self.channel_states[k].envelope);
288        }
289    }
290
291    fn compose_normalised_weights(&self) -> [f32; C] {
292        let mut hat_w = [0.0_f32; C];
293        for k in 0..C {
294            let g = self.group_map[k].min(G - 1);
295            hat_w[k] = self.channel_states[k].trust_weight * self.group_states[g].trust_weight;
296        }
297        let sum_hat: f32 = hat_w.iter().sum();
298        let mut weights = [0.0_f32; C];
299        if sum_hat > 1e-30 {
300            for k in 0..C { weights[k] = hat_w[k] / sum_hat; }
301        } else {
302            let unif = 1.0 / C as f32;
303            for k in 0..C { weights[k] = unif; }
304        }
305        weights
306    }
307
308    /// Return a snapshot of all channel states (trust weights + envelopes).
309    #[inline]
310    pub fn channel_states(&self) -> &[ChannelState; C] { &self.channel_states }
311
312    /// Return a snapshot of all group states.
313    #[inline]
314    pub fn group_states(&self) -> &[GroupState; G] { &self.group_states }
315
316    /// Normalised channel trust weight for channel k.
317    ///
318    /// Returns the last computed normalised weight w̃_k.
319    /// This is safe to call after at least one `observe()` call.
320    pub fn channel_trust(&self, k: usize) -> f32 {
321        self.channel_states.get(k).map(|s| s.trust_weight).unwrap_or(0.0)
322    }
323
324    /// Reset all state to initial values.
325    pub fn reset(&mut self) {
326        for s in &mut self.channel_states { *s = ChannelState::default(); }
327        for s in &mut self.group_states { *s = GroupState::default(); }
328    }
329}
330
/// Dot product Σ_k a[k]·b[k] of two equal-length `f32` arrays.
///
/// Accumulates from left to right starting at `0.0`, so the result matches
/// a plain indexed accumulation loop bit for bit.
fn dot_product_c<const C: usize>(a: &[f32; C], b: &[f32; C]) -> f32 {
    a.iter()
        .zip(b.iter())
        .fold(0.0_f32, |acc, (&x, &y)| acc + x * y)
}
336
/// Largest and smallest entries of `weights`, returned as `(max, min)`.
///
/// Strict `>` / `<` comparisons are seeded from the first element, so a NaN
/// in a later position is never selected. When `C == 0` there is no element
/// to seed from, and `(0.0, 0.0)` is returned instead of panicking on an
/// out-of-bounds `weights[0]`.
fn weight_extrema<const C: usize>(weights: &[f32; C]) -> (f32, f32) {
    match weights.split_first() {
        None => (0.0, 0.0),
        Some((&first, rest)) => rest.iter().fold((first, first), |(max_w, min_w), &w| {
            (
                if w > max_w { w } else { max_w },
                if w < min_w { w } else { min_w },
            )
        }),
    }
}
346
347// ---------------------------------------------------------------
348// Tests
349// ---------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    /// Feed `residuals` to `h` `n` times, then return the result of one more
    /// observation of the same vector.
    fn settle<const C: usize, const G: usize>(
        h: &mut HretEstimator<C, G>,
        residuals: &[f32; C],
        n: usize,
    ) -> HretResult<C> {
        for _ in 0..n {
            h.observe(residuals);
        }
        h.observe(residuals)
    }

    #[test]
    fn single_group_uniform_channels() {
        // 4 channels all same residual → nearly uniform weights
        let mut h = HretEstimator::<4, 1>::single_group(HretParams::default_sdr());
        let r = settle(&mut h, &[0.1, 0.1, 0.1, 0.1], 50);
        for k in 0..4 {
            let diff = (r.weights[k] - 0.25).abs();
            assert!(diff < 0.01, "weight[{}]={} (expected ~0.25)", k, r.weights[k]);
        }
        assert!((r.weights.iter().sum::<f32>() - 1.0).abs() < 1e-5);
    }

    #[test]
    fn faulty_channel_down_weighted() {
        // Channel 3 has 10× the noise of the others. Each healthy channel
        // should individually end up with more weight than the faulty one.
        let mut h = HretEstimator::<4, 1>::single_group(HretParams::default_sdr());
        let r = settle(&mut h, &[0.02, 0.02, 0.02, 0.20], 200);
        for k in 0..3 {
            assert!(
                r.weights[k] > r.weights[3],
                "good weight[{}]={} should exceed bad={}",
                k, r.weights[k], r.weights[3]
            );
        }
        let good_sum = r.weights[0] + r.weights[1] + r.weights[2];
        assert!(
            good_sum > r.weights[3],
            "good_sum={}, bad={}: faulty channel should be down-weighted",
            good_sum, r.weights[3]
        );
    }

    #[test]
    fn hierarchical_group_fault_down_weights_entire_group() {
        // 4 channels: channels 0,1 in group 0; channels 2,3 in group 1.
        // Group 1 has persistent large residuals → both channels 2,3 should lose trust.
        let map = [0usize, 0, 1, 1];
        let mut h = HretEstimator::<4, 2>::new(map, HretParams::default_sdr());
        let r = settle(&mut h, &[0.02, 0.02, 0.20, 0.20], 200);
        let group0_sum = r.weights[0] + r.weights[1];
        let group1_sum = r.weights[2] + r.weights[3];
        assert!(
            group0_sum > group1_sum,
            "clean group0={} should outweigh noisy group1={}",
            group0_sum, group1_sum
        );
    }

    #[test]
    fn weights_always_sum_to_one() {
        let map = [0usize, 0, 1, 1];
        let mut h = HretEstimator::<4, 2>::new(map, HretParams::default_sdr());
        for i in 0..100 {
            let r = h.observe(&[i as f32 * 0.01, 0.05, 0.03, i as f32 * 0.02]);
            let sum: f32 = r.weights.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-5,
                "weights sum={} at step {}", sum, i
            );
        }
    }

    #[test]
    fn trust_diversity_bounded() {
        let mut h = HretEstimator::<4, 1>::single_group(HretParams::default_sdr());
        for _ in 0..100 {
            let r = h.observe(&[0.1, 0.2, 0.3, 0.4]);
            assert!(r.trust_diversity >= 0.0, "diversity must be non-negative");
            assert!(r.trust_diversity <= 1.0, "diversity must be <= 1.0");
        }
    }

    #[test]
    fn gain_scales_weighted_residual() {
        // Equal residuals → weights 0.5 each → Δx = K · (0.5 + 0.5) = K.
        let mut h =
            HretEstimator::<2, 1>::single_group(HretParams::default_sdr()).with_gain(2.0);
        let r = h.observe(&[1.0, 1.0]);
        assert!(
            (r.weighted_residual - 2.0).abs() < 1e-6,
            "weighted_residual={} (expected 2.0)",
            r.weighted_residual
        );
    }

    #[test]
    fn weighted_residual_uses_signed_values() {
        // Equal magnitudes give equal trust, so opposite signs cancel in Δx.
        let mut h = HretEstimator::<2, 1>::single_group(HretParams::default_sdr());
        let r = h.observe(&[1.0, -1.0]);
        assert!((r.weights[0] - 0.5).abs() < 1e-6);
        assert!((r.weights[1] - 0.5).abs() < 1e-6);
        assert!(r.weighted_residual.abs() < 1e-6, "got {}", r.weighted_residual);
    }

    #[test]
    fn channel_trust_defaults_and_out_of_range() {
        let h = HretEstimator::<2, 1>::single_group(HretParams::default_sdr());
        assert_eq!(h.channel_trust(0), 1.0);
        assert_eq!(h.channel_trust(1), 1.0);
        assert_eq!(h.channel_trust(2), 0.0);
    }

    #[test]
    fn reset_clears_state() {
        let mut h = HretEstimator::<2, 1>::single_group(HretParams::default_sdr());
        for _ in 0..100 { h.observe(&[0.5, 0.5]); }
        assert!(h.channel_trust(0) < 1.0);
        h.reset();
        assert_eq!(h.channel_states[0].envelope, 0.0);
        assert_eq!(h.group_states[0].envelope, 0.0);
        assert_eq!(h.channel_states[0].trust_weight, 1.0);
        assert_eq!(h.group_states[0].trust_weight, 1.0);
        assert_eq!(h.channel_trust(0), 1.0);
    }
}