mfsk_core/core/equalize.rs
1//! Adaptive per-tone equaliser using the protocol's Costas pilot tones.
2//!
3//! Estimates the channel response `H(tone)` by averaging the pilot-tone
4//! observations gathered from every [`SyncBlock`](super::SyncBlock) across
5//! the frame, then applies a Wiener-regularised zero-forcing correction to
6//! every symbol's complex spectrum so the downstream LLR sees flat tones.
7//!
8//! Protocol differences handled automatically:
9//! - **FT8** (3 × Costas-7): tones 0..6 observed 3× each, tone 7 never →
10//! extrapolated as `2·H[6] − H[5]`.
//! - **FT4** (4 × Costas-4): every tone is observed 4 times → the
//!   extrapolation branch is not exercised.
//! - Future protocols with any subset of observed tones use the same
//!   machinery: a missing tone is linearly extrapolated from its two lower
//!   neighbours when both are observed, copied from its nearest lower
//!   neighbour when only that one is observed, and otherwise left at zero.
16
17use super::Protocol;
18use num_complex::Complex;
19
/// Selects whether, and when, the per-tone equaliser is applied.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EqMode {
    /// Equalisation disabled: spectra are passed through unchanged.
    Off,
    /// Equalise each signal using its own local Costas pilot tones.
    Local,
    /// Decode without equalisation first, and retry with equalisation only
    /// when the BP decode fails.
    Adaptive,
}
29}
30
31/// Apply local (per-signal) Wiener equalisation to a flat symbol-spectra
32/// buffer in place. `cs` is laid out row-major by symbol (length
33/// `N_SYMBOLS × NTONES`) — the same layout produced by
34/// [`super::llr::symbol_spectra`].
35pub fn equalize_local<P: Protocol>(cs: &mut [Complex<f32>]) {
36 let ntones = P::NTONES as usize;
37 let _n_sym = P::N_SYMBOLS as usize;
38
39 // Gather per-tone observations across all sync blocks.
40 let mut obs: Vec<Vec<Complex<f32>>> = vec![Vec::new(); ntones];
41 for block in P::SYNC_MODE.blocks() {
42 let start = block.start_symbol as usize;
43 for (k, &tone) in block.pattern.iter().enumerate() {
44 let t = tone as usize;
45 if t < ntones {
46 obs[t].push(cs[(start + k) * ntones + t]);
47 }
48 }
49 }
50
51 // Per-tone pilot estimate: mean of observations. Missing tones are
52 // linearly extrapolated from the previous two in ascending order.
53 let mut pilots = vec![Complex::new(0.0f32, 0.0); ntones];
54 let mut observed = vec![false; ntones];
55 for t in 0..ntones {
56 if !obs[t].is_empty() {
57 let n = obs[t].len() as f32;
58 pilots[t] = obs[t].iter().copied().sum::<Complex<f32>>() / n;
59 observed[t] = true;
60 }
61 }
62 for t in 0..ntones {
63 if !observed[t] {
64 // Try `2·p[t-1] − p[t-2]` if both predecessors are observed.
65 if t >= 2 && observed[t - 1] && observed[t - 2] {
66 pilots[t] = pilots[t - 1] * 2.0 - pilots[t - 2];
67 } else if t >= 1 && observed[t - 1] {
68 // Fall back to flat extrapolation.
69 pilots[t] = pilots[t - 1];
70 }
71 // else: stays zero — callers must ensure pattern visits enough tones.
72 }
73 }
74
75 // Noise variance from the scatter of observations around the per-tone mean.
76 let (total_var, count) = obs.iter().enumerate().filter(|(_, o)| !o.is_empty()).fold(
77 (0.0f32, 0usize),
78 |(v, n), (t, obs_t)| {
79 let mean = pilots[t];
80 (
81 v + obs_t.iter().map(|o| (*o - mean).norm_sqr()).sum::<f32>(),
82 n + obs_t.len(),
83 )
84 },
85 );
86 let noise_var = if count > 0 {
87 total_var / count as f32
88 } else {
89 1.0
90 };
91
92 // Regularise by median pilot power × 0.3 (prevents over-correction at low SNR).
93 let mut powers: Vec<f32> = pilots.iter().map(|p| p.norm_sqr()).collect();
94 powers.sort_by(|a, b| a.partial_cmp(b).unwrap());
95 let median_power = powers[powers.len() / 2];
96 let noise_var = noise_var.max(median_power * 0.3);
97
98 // Wiener weights.
99 let mut weights = vec![Complex::new(0.0f32, 0.0); ntones];
100 for t in 0..ntones {
101 let p = pilots[t];
102 weights[t] = p.conj() / (p.norm_sqr() + noise_var);
103 }
104
105 // Normalise mean |w| → 1 so downstream SNR estimates remain meaningful.
106 let mean_mag = weights.iter().map(|w| w.norm()).sum::<f32>() / ntones as f32;
107 if mean_mag > f32::EPSILON {
108 for w in weights.iter_mut() {
109 *w /= mean_mag;
110 }
111 }
112
113 // Apply to every symbol.
114 let n_sym = cs.len() / ntones;
115 for sym in 0..n_sym {
116 for (t, w) in weights.iter().enumerate() {
117 cs[sym * ntones + t] *= *w;
118 }
119 }
120}