use super::filter::{FilterCoeff, Window};
#[cfg(not(test))]
use log::debug;
#[cfg(test)]
use std::println as debug;
/// Error returned by `Equalizer::train` when the equalizer was built without a
/// training word (`train_to == None`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NoTrainingSequenceErr;
/// Adaptive decision-feedback equalizer.
///
/// Consumes two received samples per symbol, decides each symbol by sign and
/// adapts its feedforward and feedback filters with normalized LMS.
#[derive(Debug, Clone)]
pub struct Equalizer {
    /// NLMS step-size numerator (see `nlms_gain`).
    relaxation: f32,
    /// Added to the window energy in the NLMS gain denominator.
    regularization: f32,
    /// Optional 32-bit training word, consumed least-significant bit first.
    train_to: Option<u32>,
    /// Taps applied to the received-sample history.
    feedforward_coeff: FilterCoeff<f32>,
    /// Taps applied to the history of past decisions; their output is subtracted.
    feedback_coeff: FilterCoeff<f32>,
    /// History of received samples.
    feedforward_wind: Window<f32>,
    /// History of past decisions, each pushed together with a zero pad.
    feedback_wind: Window<f32>,
    /// Current operating mode (disabled, decision-directed, or training).
    mode: EqualizerState,
}
// NOTE(review): `dead_code` is presumably allowed because not every method here
// is used by every build of the crate; confirm, and narrow the allow to the
// specific items if possible.
#[allow(dead_code)]
impl Equalizer {
    /// Number of samples consumed by [`Equalizer::input`]: eight symbols (one
    /// byte) at two samples per symbol.
    pub const INPUT_LENGTH: usize = 16;

    /// Number of symbols in one training run: one per bit of the `u32`
    /// training word.
    const TRAINING_SYMBOLS: u32 = 32;

    /// Creates an equalizer with decision-feedback adaptation enabled.
    ///
    /// Both filters start from `FilterCoeff::from_identity`, and both windows
    /// start from `Window::new`.
    ///
    /// * `nfeedforward` - number of taps over received samples.
    /// * `nfeedback` - number of taps over past decisions.
    /// * `relaxation` - NLMS step-size numerator.
    /// * `regularization` - added to the window energy so the NLMS gain stays finite.
    /// * `train_to` - optional training word for [`Equalizer::train`], sent
    ///   least-significant bit first.
    ///
    /// # Panics
    /// Panics if the coefficient and window lengths created for either path differ.
    pub fn new(
        nfeedforward: usize,
        nfeedback: usize,
        relaxation: f32,
        regularization: f32,
        train_to: Option<u32>,
    ) -> Self {
        let feedforward_coeff = FilterCoeff::from_identity(nfeedforward);
        let feedback_coeff = FilterCoeff::from_identity(nfeedback);
        let feedforward_wind = Window::new(nfeedforward);
        let feedback_wind = Window::new(nfeedback);
        assert_eq!(feedforward_coeff.inner().len(), feedforward_wind.len());
        assert_eq!(feedback_coeff.inner().len(), feedback_wind.len());
        Self {
            relaxation,
            regularization,
            train_to,
            feedforward_coeff,
            feedback_coeff,
            feedforward_wind,
            feedback_wind,
            mode: EqualizerState::EnabledFeedback,
        }
    }

    /// Equalizes one byte's worth of samples.
    ///
    /// Returns the decoded byte together with the error of its *last* symbol.
    /// Bits are assembled least-significant first: the first symbol becomes bit 0.
    ///
    /// # Panics
    /// Panics if `byte_samples.len() != Self::INPUT_LENGTH`.
    pub fn input(&mut self, byte_samples: &[f32]) -> (u8, f32) {
        assert_eq!(byte_samples.len(), Self::INPUT_LENGTH);
        let mut byte = 0u8;
        let mut last_err = 0.0f32;
        for (bitind, twosamp) in byte_samples.chunks_exact(2).enumerate() {
            let (bit, err) = self.estimate_symbol(twosamp);
            byte |= u8::from(bit) << bitind;
            last_err = err;
        }
        (byte, last_err)
    }

    /// Returns both filters to `identity()` and resets both sample windows.
    ///
    /// The operating mode is left unchanged.
    pub fn reset(&mut self) {
        self.feedforward_coeff.identity();
        self.feedback_coeff.identity();
        self.feedforward_wind.reset();
        self.feedback_wind.reset();
    }

    /// Enables (decision-feedback adaptation) or disables the equalizer.
    ///
    /// Enabling always selects decision-feedback mode, so it also cancels any
    /// training run that is in progress.
    pub fn enable(&mut self, enable: bool) {
        self.mode = if enable {
            EqualizerState::EnabledFeedback
        } else {
            EqualizerState::Disabled
        };
    }

    /// Starts (or restarts) a training run from the first bit of the training word.
    ///
    /// This also enables the equalizer if it was disabled.
    ///
    /// # Errors
    /// Returns [`NoTrainingSequenceErr`] if no training word was given to
    /// [`Equalizer::new`].
    pub fn train(&mut self) -> Result<(), NoTrainingSequenceErr> {
        let train_to = self.train_to.ok_or(NoTrainingSequenceErr)?;
        self.mode = EqualizerState::EnabledTraining(train_to, 0);
        Ok(())
    }

    /// Returns `true` unless the equalizer is disabled. Training counts as enabled.
    pub fn is_enabled(&self) -> bool {
        self.mode != EqualizerState::Disabled
    }

    /// Returns `true` while a training run is in progress.
    pub fn is_training(&self) -> bool {
        matches!(self.mode, EqualizerState::EnabledTraining(..))
    }

    /// Equalizes one symbol (exactly two samples) and returns `(decision, error)`.
    ///
    /// The equalizer output is `feedforward(samples) - feedback(past decisions)`.
    /// The error is `reference - output`. What happens next depends on the mode:
    /// * `Disabled`: the reference is the sign of the output; no adaptation; the error is 0.
    /// * `EnabledFeedback`: the reference is the sign of the output; both filters adapt.
    /// * `EnabledTraining`: the reference is the next training bit (0 -> -1, 1 -> +1);
    ///   both filters adapt. After `TRAINING_SYMBOLS` symbols the mode switches
    ///   to `EnabledFeedback`.
    ///
    /// The reference symbol is pushed into the feedback window. The returned
    /// decision is `true` when that symbol is non-negative.
    fn estimate_symbol(&mut self, input: &[f32]) -> (bool, f32) {
        assert_eq!(2, input.len());
        self.feedforward_wind.push(input);
        let ff = self.feedforward_coeff.filter(&self.feedforward_wind);
        let fb = self.feedback_coeff.filter(&self.feedback_wind);
        let sym_val = ff - fb;
        let (sym_est, err) = match self.mode {
            EqualizerState::Disabled => (sym_val.signum(), 0.0f32),
            EqualizerState::EnabledFeedback => {
                let sym_est = sym_val.signum();
                let err = sym_est - sym_val;
                self.evolve(err);
                (sym_est, err)
            }
            EqualizerState::EnabledTraining(seq, count) => {
                // Reference symbol: LSB of the remaining training word, mapped 0 -> -1, 1 -> +1.
                let sym_est = (2.0f32 * (seq & 0x1u32) as f32) - 1.0f32;
                let err = sym_est - sym_val;
                self.evolve(err);
                let count = count + 1;
                self.mode = if count >= Self::TRAINING_SYMBOLS {
                    debug!("equalizer: end training with err: {:.4}", err);
                    EqualizerState::EnabledFeedback
                } else {
                    EqualizerState::EnabledTraining(seq >> 1, count)
                };
                (sym_est, err)
            }
        };
        // Push the symbol followed by a zero pad. Presumably this makes the
        // feedback window advance two entries per symbol, like the feedforward one.
        self.feedback_wind.push(&[sym_est, 0.0f32]);
        (sym_est >= 0.0f32, err)
    }

    /// Adapts both filters by one NLMS step driven by `error`.
    ///
    /// The feedback path uses `-error` because its output is subtracted from the
    /// feedforward output in `estimate_symbol`.
    #[inline]
    fn evolve(&mut self, error: f32) {
        nlms_update(
            self.relaxation,
            self.regularization,
            error,
            &self.feedforward_wind,
            self.feedforward_coeff.as_mut(),
        );
        assert_eq!(self.feedback_wind.inner().len(), self.feedback_coeff.len());
        nlms_update(
            self.relaxation,
            self.regularization,
            -error,
            &self.feedback_wind,
            self.feedback_coeff.as_mut(),
        );
    }
}
/// Operating mode of an `Equalizer`.
#[derive(Debug, Clone, Eq, PartialEq)]
enum EqualizerState {
    /// No adaptation: each decision is the sign of the filter output.
    Disabled,
    /// Decision-directed adaptation toward the sign of the filter output.
    EnabledFeedback,
    /// Adaptation toward a known word: `(remaining bits, symbols consumed)`.
    /// The next reference bit is the least-significant bit of the first field.
    EnabledTraining(u32, u32),
}
/// Applies one normalized-LMS update to `filter`, in place.
///
/// Each coefficient changes by `gain * error * sample`, where `gain` is
/// [`nlms_gain`] over the same `window`. The window is walked in *reverse*
/// iteration order, so `filter[0]` pairs with the last item the window yields.
/// If `filter` and `window` have different lengths, the extra items of the
/// longer one are ignored (`zip` truncates).
fn nlms_update<W>(relaxation: f32, regularization: f32, error: f32, window: W, filter: &mut [f32])
where
    W: IntoIterator<Item = f32>,
    W::IntoIter: DoubleEndedIterator + Clone,
{
    let samples = window.into_iter();
    let gain = nlms_gain(relaxation, regularization, samples.clone());
    // Loop-invariant; `gain * error * data` already evaluates as `(gain * error) * data`.
    let step = gain * error;
    for (coeff, data) in filter.iter_mut().zip(samples.rev()) {
        *coeff += step * data;
    }
}

/// Returns the NLMS step gain: `relaxation / (regularization + sum(w^2))`.
///
/// `regularization` keeps the gain finite when the window energy is zero.
#[inline]
fn nlms_gain<W>(relaxation: f32, regularization: f32, window: W) -> f32
where
    W: IntoIterator<Item = f32>,
{
    let energy = window.into_iter().fold(0.0f32, |acc, w| acc + w * w);
    relaxation / (regularization + energy)
}
impl std::fmt::Display for NoTrainingSequenceErr {
    /// Writes the fixed, user-facing message for this error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("no training sequence defined")
    }
}
impl std::error::Error for NoTrainingSequenceErr {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
None
}
}
#[cfg(test)]
mod tests {
    use super::super::waveform;
    use super::*;
    use assert_approx_eq::assert_approx_eq;

    /// Taps of the Proakis "B" test channel, used here as a simulated ISI channel.
    const PROAKIS_B: &[f32] = &[0.407, 0.815, 0.407];

    #[test]
    fn test_estimate_symbol_simple() {
        const INPUT: &[f32] = &[0.0f32, 0.5f32, 0.0f32, -0.5f32];

        let mut eqlz = Equalizer::new(8, 4, 0.2, 1.0e-5, None);

        // Disabled: decisions are the sign of the filter output, error is always 0.
        eqlz.enable(false);
        let out: Vec<(bool, f32)> = INPUT
            .chunks(2)
            .map(|samps| eqlz.estimate_symbol(samps))
            .collect();
        assert_eq!(2, out.len());
        assert!(out[0].0);
        assert_approx_eq!(out[0].1, 0.0f32);
        assert!(!out[1].0);
        assert_approx_eq!(out[1].1, 0.0f32);

        // Decision-feedback adaptation should converge on the repeating pattern.
        eqlz.enable(true);
        let out: Vec<(bool, f32)> = INPUT
            .chunks(2)
            .cycle()
            .take(32)
            .map(|sample| eqlz.estimate_symbol(sample))
            .collect();
        let last_err = out.last().expect("32 symbols were produced").1;
        assert!(last_err.abs() < 1.0e-5f32);
    }

    #[test]
    fn test_nlms_evolve() {
        const INPUT: &[f32] = &[0.0f32, 1.0f32, 0.0f32, -1.0f32];
        const RELAXATION: f32 = 0.10f32;
        const REGULARIZATION: f32 = 1.0e-6f32;

        // Adapt a 3-tap filter to invert the PROAKIS_B channel, one sample at a time.
        let channel_coeff = FilterCoeff::from_slice(PROAKIS_B);
        let mut channel_wind = Window::<f32>::new(PROAKIS_B.len());
        let mut inverse_coeff = FilterCoeff::<f32>::from_identity(3);
        let mut inverse_wind = Window::<f32>::new(3);
        let mut err = 0.0f32;
        for sample in INPUT.iter().cycle().take(128) {
            channel_wind.push(&[*sample]);
            let ch_sample = channel_coeff.filter(&channel_wind);
            inverse_wind.push(&[ch_sample]);
            let est_sample = inverse_coeff.filter(&inverse_wind);
            err = sample - est_sample;
            nlms_update(
                RELAXATION,
                REGULARIZATION,
                err,
                &inverse_wind,
                inverse_coeff.as_mut(),
            );
        }
        assert!(err.abs() < 1e-2);
    }

    #[test]
    fn test_estimate_symbol_proakis() {
        const CHANNEL_COEFF: &[f32] = &[0.8f32, -0.2f32];
        const INPUT: &[f32] = &[0.0f32, 1.0f32, 0.0f32, -1.0f32];

        let channel_coeff = FilterCoeff::from_slice(CHANNEL_COEFF);
        let mut channel_wind = Window::<f32>::new(CHANNEL_COEFF.len());
        let mut uut = Equalizer::new(8, 4, 0.2, 1.0e-5, None);

        // Converge on the repeating pattern seen through the channel.
        let mut last = (false, 0.0f32);
        for samples in INPUT.chunks(2).cycle().take(32) {
            let mut ch_samples = [0.0f32, 0.0f32];
            for (inp, outp) in samples.iter().zip(ch_samples.iter_mut()) {
                channel_wind.push(&[*inp]);
                *outp = channel_coeff.filter(&channel_wind);
            }
            last = uut.estimate_symbol(&ch_samples);
        }
        assert!(last.1.abs() < 1e-4);

        // Once converged, every decision must match the transmitted symbol.
        for samples in INPUT.chunks(2) {
            let mut ch_samples = [0.0f32, 0.0f32];
            for (inp, outp) in samples.iter().zip(ch_samples.iter_mut()) {
                channel_wind.push(&[*inp]);
                *outp = channel_coeff.filter(&channel_wind);
            }
            last = uut.estimate_symbol(&ch_samples);
            assert_eq!(samples[1] >= 0.0f32, last.0);
            assert!(last.1.abs() < 1e-4);
        }
    }

    #[test]
    fn test_estimate_symbol_training() {
        let mut uut = Equalizer::new(8, 4, 0.2, 1.0e-5, Some(waveform::PREAMBLE_SYNC_WORD));
        uut.train().expect("training mode");
        assert_eq!(
            uut.mode,
            EqualizerState::EnabledTraining(waveform::PREAMBLE_SYNC_WORD, 0)
        );

        // Two bytes -> 16 symbols, assuming `bytes_to_samples` emits 2 samples
        // per bit: half of the 32-symbol training run.
        let chansig = waveform::bytes_to_samples(&[0x54, 0x54], 2);
        let _rxsig0: Vec<(bool, f32)> = chansig
            .chunks(2)
            .map(|sa| uut.estimate_symbol(sa))
            .collect();
        assert!(
            matches!(uut.mode, EqualizerState::EnabledTraining(_, 16)),
            "expected 16 training symbols consumed, got {:?}",
            uut.mode
        );

        // The second half completes training and falls back to decision feedback.
        let _rxsig1: Vec<(bool, f32)> = chansig
            .chunks(2)
            .map(|sa| uut.estimate_symbol(sa))
            .collect();
        assert_eq!(uut.mode, EqualizerState::EnabledFeedback);
        let rx = uut.estimate_symbol(&[0.0f32, -1.0f32]);
        assert!(rx.0);

        // Retrain from scratch on a different signal.
        uut.reset();
        uut.train().expect("training mode err");
        let chansig = waveform::bytes_to_samples(&[0xAB, 0xAB, 0xAB, 0xAB], 2);
        let _rxsig3: Vec<(bool, f32)> = chansig
            .chunks(2)
            .map(|sa| uut.estimate_symbol(sa))
            .collect();
        assert_eq!(uut.mode, EqualizerState::EnabledFeedback);
        let rx = uut.estimate_symbol(&[0.0f32, -1.0f32]);
        assert!(!rx.0);
    }

    #[test]
    fn test_input() {
        let chansig = waveform::bytes_to_samples(&[0xAB, 0xBA], 2);
        let mut uut = Equalizer::new(8, 4, 0.2, 1.0e-5, None);
        let out: Vec<(u8, f32)> = chansig.chunks(16).map(|sa| uut.input(sa)).collect();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].0, 0xAB);
        assert_eq!(out[1].0, 0xBA);
    }
}