Skip to main content

mfsk_core/core/
equalize.rs

1//! Adaptive per-tone equaliser using the protocol's Costas pilot tones.
2//!
3//! Estimates the channel response `H(tone)` by averaging the pilot-tone
4//! observations gathered from every [`SyncBlock`](super::SyncBlock) across
5//! the frame, then applies a Wiener-regularised zero-forcing correction to
6//! every symbol's complex spectrum so the downstream LLR sees flat tones.
7//!
8//! Protocol differences handled automatically:
9//! - **FT8** (3 × Costas-7): tones 0..6 observed 3× each, tone 7 never →
10//!   extrapolated as `2·H[6] − H[5]`.
//! - **FT4** (4 × Costas-4): every tone is observed 4× → the extrapolation
//!   branch is not exercised.
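//! - Future protocols with any subset of observed tones use the same
//!   machinery: a missing tone `t` is linearly extrapolated as
//!   `2·H[t−1] − H[t−2]` when both lower neighbours were observed, copies
//!   `H[t−1]` when only that neighbour was observed, and is otherwise left
//!   at zero — so the sync pattern must visit enough tones.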
16
17use super::Protocol;
18use num_complex::Complex;
19
/// Selects how, and whether, the per-tone equaliser is applied.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EqMode {
    /// Passthrough: no equalisation is performed.
    Off,
    /// Equalise each signal against its own (local) Costas pilot tones.
    Local,
    /// Decode without EQ first; only fall back to EQ when BP decoding fails.
    Adaptive,
}
30
31/// Apply local (per-signal) Wiener equalisation to a flat symbol-spectra
32/// buffer in place. `cs` is laid out row-major by symbol (length
33/// `N_SYMBOLS × NTONES`) — the same layout produced by
34/// [`super::llr::symbol_spectra`].
35pub fn equalize_local<P: Protocol>(cs: &mut [Complex<f32>]) {
36    let ntones = P::NTONES as usize;
37    let _n_sym = P::N_SYMBOLS as usize;
38
39    // Gather per-tone observations across all sync blocks.
40    let mut obs: Vec<Vec<Complex<f32>>> = vec![Vec::new(); ntones];
41    for block in P::SYNC_MODE.blocks() {
42        let start = block.start_symbol as usize;
43        for (k, &tone) in block.pattern.iter().enumerate() {
44            let t = tone as usize;
45            if t < ntones {
46                obs[t].push(cs[(start + k) * ntones + t]);
47            }
48        }
49    }
50
51    // Per-tone pilot estimate: mean of observations. Missing tones are
52    // linearly extrapolated from the previous two in ascending order.
53    let mut pilots = vec![Complex::new(0.0f32, 0.0); ntones];
54    let mut observed = vec![false; ntones];
55    for t in 0..ntones {
56        if !obs[t].is_empty() {
57            let n = obs[t].len() as f32;
58            pilots[t] = obs[t].iter().copied().sum::<Complex<f32>>() / n;
59            observed[t] = true;
60        }
61    }
62    for t in 0..ntones {
63        if !observed[t] {
64            // Try `2·p[t-1] − p[t-2]` if both predecessors are observed.
65            if t >= 2 && observed[t - 1] && observed[t - 2] {
66                pilots[t] = pilots[t - 1] * 2.0 - pilots[t - 2];
67            } else if t >= 1 && observed[t - 1] {
68                // Fall back to flat extrapolation.
69                pilots[t] = pilots[t - 1];
70            }
71            // else: stays zero — callers must ensure pattern visits enough tones.
72        }
73    }
74
75    // Noise variance from the scatter of observations around the per-tone mean.
76    let (total_var, count) = obs.iter().enumerate().filter(|(_, o)| !o.is_empty()).fold(
77        (0.0f32, 0usize),
78        |(v, n), (t, obs_t)| {
79            let mean = pilots[t];
80            (
81                v + obs_t.iter().map(|o| (*o - mean).norm_sqr()).sum::<f32>(),
82                n + obs_t.len(),
83            )
84        },
85    );
86    let noise_var = if count > 0 {
87        total_var / count as f32
88    } else {
89        1.0
90    };
91
92    // Regularise by median pilot power × 0.3 (prevents over-correction at low SNR).
93    let mut powers: Vec<f32> = pilots.iter().map(|p| p.norm_sqr()).collect();
94    powers.sort_by(|a, b| a.partial_cmp(b).unwrap());
95    let median_power = powers[powers.len() / 2];
96    let noise_var = noise_var.max(median_power * 0.3);
97
98    // Wiener weights.
99    let mut weights = vec![Complex::new(0.0f32, 0.0); ntones];
100    for t in 0..ntones {
101        let p = pilots[t];
102        weights[t] = p.conj() / (p.norm_sqr() + noise_var);
103    }
104
105    // Normalise mean |w| → 1 so downstream SNR estimates remain meaningful.
106    let mean_mag = weights.iter().map(|w| w.norm()).sum::<f32>() / ntones as f32;
107    if mean_mag > f32::EPSILON {
108        for w in weights.iter_mut() {
109            *w /= mean_mag;
110        }
111    }
112
113    // Apply to every symbol.
114    let n_sym = cs.len() / ntones;
115    for sym in 0..n_sym {
116        for (t, w) in weights.iter().enumerate() {
117            cs[sym * ntones + t] *= *w;
118        }
119    }
120}