use alloc::vec;
use alloc::vec::Vec;
use super::Protocol;
use num_complex::Complex;
/// Selects which equalization strategy, if any, is applied.
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names only; confirm against the code that matches on this enum.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EqMode {
    /// Presumably: no equalization is performed.
    Off,
    /// Presumably: the sync-pilot equalizer implemented by `equalize_local`
    /// in this module — confirm at call sites.
    Local,
    /// Presumably: an adaptive equalizer defined elsewhere — confirm.
    Adaptive,
}
/// Equalizes the symbol/tone matrix `cs` in place, using the protocol's sync
/// symbols as per-tone pilots.
///
/// `cs` is read as a row-major matrix with `P::NTONES` entries per symbol:
/// the value for symbol `s`, tone `t` is `cs[s * NTONES + t]`. For every sync
/// block in `P::SYNC_MODE`, symbol `start_symbol + k` is expected to carry
/// tone `pattern[k]`. Those entries are averaged per tone into a channel
/// (pilot) estimate `h[t]`.
///
/// Tones that never appear in a sync block are estimated as follows:
/// - linear extrapolation when the two tones just below are observed;
/// - otherwise a copy of the tone just below, when that tone is observed;
/// - otherwise the nearest observed tone.
///
/// This means no tone is left with a zero pilot, which would give it a zero
/// weight and erase it.
///
/// Each tone gets the one-tap MMSE-form weight `conj(h) / (|h|^2 + σ²)`, where
/// `σ²` is the pooled pilot variance floored at 0.3 × the median pilot power.
/// The weights are scaled to unit mean magnitude, and every complete symbol
/// row of `cs` is multiplied by them.
///
/// Edge cases:
/// - `NTONES == 0`, or no sync tone lies inside `cs`: there is no channel
///   information, and `cs` is left untouched.
/// - Sync positions outside `cs` are ignored.
/// - Trailing entries that do not form a complete row are not modified.
pub fn equalize_local<P: Protocol>(cs: &mut [Complex<f32>]) {
    let ntones = P::NTONES as usize;
    if ntones == 0 {
        // Nothing to equalize; also avoids dividing by zero below.
        return;
    }

    // Collect the sync (pilot) observations, grouped by tone.
    let mut obs: Vec<Vec<Complex<f32>>> = vec![Vec::new(); ntones];
    for block in P::SYNC_MODE.blocks() {
        let start = block.start_symbol as usize;
        for (k, &tone) in block.pattern.iter().enumerate() {
            let t = tone as usize;
            if t >= ntones {
                continue;
            }
            // A sync block may extend past a truncated `cs`; skip rather than panic.
            if let Some(&v) = cs.get((start + k) * ntones + t) {
                obs[t].push(v);
            }
        }
    }

    // Per-tone pilot = mean of that tone's observations.
    let mut pilots = vec![Complex::new(0.0f32, 0.0); ntones];
    let mut observed = vec![false; ntones];
    for (t, obs_t) in obs.iter().enumerate() {
        if !obs_t.is_empty() {
            pilots[t] = obs_t.iter().copied().sum::<Complex<f32>>() / obs_t.len() as f32;
            observed[t] = true;
        }
    }
    if !observed.contains(&true) {
        // Without any pilot every weight would be zero and `cs` would be wiped.
        return;
    }

    // Estimate pilots for tones never seen in sync. Only observed pilots are
    // read here and those are never overwritten, so the loop order is irrelevant.
    for t in 0..ntones {
        if observed[t] {
            continue;
        }
        let estimate = if t >= 2 && observed[t - 1] && observed[t - 2] {
            pilots[t - 1] * 2.0 - pilots[t - 2]
        } else if t >= 1 && observed[t - 1] {
            pilots[t - 1]
        } else {
            // Leading tone or inside a run of unobserved tones: use the nearest
            // observed tone (lower tone wins ties). At least one is observed.
            (0..ntones)
                .filter(|&u| observed[u])
                .min_by_key(|&u| u.abs_diff(t))
                .map_or(pilots[t], |u| pilots[u])
        };
        pilots[t] = estimate;
    }

    // Pooled noise variance: mean squared deviation of each sync observation
    // from its own tone's pilot. Unobserved tones contribute nothing.
    let (total_var, count) = obs
        .iter()
        .zip(&pilots)
        .fold((0.0f32, 0usize), |(acc, n), (obs_t, &mean)| {
            let dev: f32 = obs_t.iter().map(|o| (*o - mean).norm_sqr()).sum();
            (acc + dev, n + obs_t.len())
        });
    // `count >= 1` because at least one tone has an observation.
    let noise_var = total_var / count as f32;

    // Floor the noise term at 0.3 × the median pilot power. This keeps the
    // weights bounded when few observations make the variance estimate ~0.
    // (0.3 is a tuning constant whose origin is not documented here.)
    // `total_cmp` keeps the sort panic-free if the input contains NaN.
    let mut powers: Vec<f32> = pilots.iter().map(|p| p.norm_sqr()).collect();
    powers.sort_unstable_by(f32::total_cmp);
    let noise_var = noise_var.max(powers[ntones / 2] * 0.3);

    // Per-tone MMSE-form weights. If both the pilot power and the noise term
    // are zero (e.g. an all-zero input), use weight 0 instead of 0/0 = NaN.
    let mut weights: Vec<Complex<f32>> = pilots
        .iter()
        .map(|p| {
            let denom = p.norm_sqr() + noise_var;
            if denom > 0.0 {
                p.conj() / denom
            } else {
                Complex::new(0.0, 0.0)
            }
        })
        .collect();

    // Normalize to unit mean magnitude so equalization does not change overall scale.
    let mean_mag = weights.iter().map(|w| w.norm()).sum::<f32>() / ntones as f32;
    if mean_mag > f32::EPSILON {
        for w in &mut weights {
            *w /= mean_mag;
        }
    }

    // Apply the weights to every complete symbol row.
    for row in cs.chunks_exact_mut(ntones) {
        for (c, w) in row.iter_mut().zip(&weights) {
            *c *= *w;
        }
    }
}