use super::{
channel::AwgnChannel,
modulation::{BpskDemodulator, BpskModulator},
puncturing::Puncturer,
};
use crate::{
decoder::{factory::DecoderImplementation, LdpcDecoder},
encoder::{Encoder, Error},
gf2::GF2,
sparse::SparseMatrix,
};
use ndarray::Array1;
use num_traits::{One, Zero};
use rand::{distributions::Standard, Rng};
use std::{
sync::mpsc::{self, Receiver, Sender, SyncSender, TryRecvError},
time::{Duration, Instant},
};
/// Bit error rate (BER) simulation of an LDPC code.
///
/// For each Eb/N0 point in `ebn0s_db`, `num_workers` threads repeatedly encode
/// random messages, optionally puncture them, BPSK-modulate them, add noise
/// through an `AwgnChannel`, demodulate, depuncture and decode, until
/// `max_frame_errors` frame errors have been collected.
#[derive(Debug)]
pub struct BerTest {
/// Decoder kind used to build one decoder per worker thread.
decoder_implementation: DecoderImplementation,
/// Parity check matrix of the code.
h: SparseMatrix,
/// Number of worker threads per Eb/N0 point (`num_cpus::get()`).
num_workers: usize,
/// Message length in bits: `h.num_cols() - h.num_rows()`.
k: usize,
/// Number of transmitted bits per frame after puncturing.
n: usize,
/// Codeword length before puncturing (`h.num_cols()`).
n_cw: usize,
/// Rate of the transmitted code: `k / n`.
rate: f64,
/// Encoder built from `h`.
encoder: Encoder,
/// Puncturer, present only when a puncturing pattern was given.
puncturer: Option<Puncturer>,
/// BPSK modulator cloned into every worker.
modulator: BpskModulator,
/// Eb/N0 points to simulate, in dB.
ebn0s_db: Vec<f32>,
/// Final statistics, one entry per completed Eb/N0 point.
statistics: Vec<Statistics>,
/// Maximum number of decoder iterations per frame.
max_iterations: usize,
/// Simulation of an Eb/N0 point stops once this many frame errors are seen.
max_frame_errors: u64,
/// Optional destination for progress reports.
reporter: Option<Reporter>,
/// Time at which the last progress report was sent.
last_reported: Instant,
}
/// Per-thread simulation state.
///
/// Each worker owns its own copies of the encoder, puncturer and modulator,
/// and its own decoder, so frames can be simulated without sharing state
/// between threads.
#[derive(Debug)]
struct Worker {
/// Receives the request to stop simulating.
terminate_rx: Receiver<()>,
/// Sends the outcome of each simulated frame to the controller.
results_tx: Sender<WorkerResult>,
/// Message length in bits.
k: usize,
encoder: Encoder,
puncturer: Option<Puncturer>,
modulator: BpskModulator,
/// Channel configured with the noise standard deviation of this Eb/N0 point.
channel: AwgnChannel,
/// Demodulator configured with the same noise standard deviation.
demodulator: BpskDemodulator,
decoder: Box<dyn LdpcDecoder>,
/// Maximum number of decoder iterations per frame.
max_iterations: usize,
}
/// Outcome of simulating a single frame.
#[derive(Debug, Clone)]
struct WorkerResultOk {
/// Number of message bits that differ from the decoded bits.
bit_errors: u64,
/// `true` when `bit_errors > 0`.
frame_error: bool,
/// `true` when the decoder reported success but the frame has bit errors.
false_decode: bool,
/// Number of decoder iterations used for this frame.
iterations: u64,
}
/// Message sent from a worker to the controller for each frame.
///
/// `Err(())` only signals that the simulation failed; the actual error value
/// is returned from the worker thread and retrieved when it is joined.
type WorkerResult = Result<WorkerResultOk, ()>;
/// Running counters for the Eb/N0 point currently being simulated.
#[derive(Debug, Clone, PartialEq)]
struct CurrentStatistics {
/// Number of frames received from workers so far.
num_frames: u64,
/// Total number of message bit errors.
bit_errors: u64,
/// Number of frames with at least one message bit error.
frame_errors: u64,
/// Number of frames the decoder reported as successful but had bit errors.
false_decodes: u64,
/// Decoder iterations summed over all frames.
total_iterations: u64,
/// Decoder iterations summed over frames without bit errors.
correct_iterations: u64,
/// When counting started; used to compute elapsed time and throughput.
start: Instant,
}
/// Results of a BER test at one Eb/N0 point.
///
/// Ratios are computed with floating-point division, so they are NaN or
/// infinite when their denominator is zero (e.g. no frames simulated).
#[derive(Debug, Clone, PartialEq)]
pub struct Statistics {
/// Eb/N0 of this point, in dB.
pub ebn0_db: f32,
/// Number of frames simulated.
pub num_frames: u64,
/// Total number of message bit errors.
pub bit_errors: u64,
/// Number of frames with at least one message bit error.
pub frame_errors: u64,
/// Decoder iterations summed over all frames.
pub total_iterations: u64,
/// Decoder iterations summed over frames without bit errors.
pub correct_iterations: u64,
/// Frames the decoder reported as successful but that had bit errors.
pub false_decodes: u64,
/// Bit error rate: `bit_errors / (k * num_frames)`.
pub ber: f64,
/// Frame error rate: `frame_errors / num_frames`.
pub fer: f64,
/// `total_iterations / num_frames`.
pub average_iterations: f64,
/// `correct_iterations / (num_frames - frame_errors)`.
pub average_iterations_correct: f64,
/// Time elapsed since the counters for this point were created.
pub elapsed: Duration,
/// Message bits processed per second, in Mbit/s: `1e-6 * k * num_frames / elapsed`.
pub throughput_mbps: f64,
}
/// Destination for progress reports of a running BER test.
#[derive(Debug, Clone)]
pub struct Reporter {
/// Channel on which [`Report`] messages are sent.
pub tx: Sender<Report>,
/// Minimum time between two intermediate statistics reports.
/// Final reports for each Eb/N0 point are sent regardless of this interval.
pub interval: Duration,
}
/// Progress message sent to a [`Reporter`].
#[derive(Debug, Clone, PartialEq)]
pub enum Report {
/// Snapshot of the statistics of the Eb/N0 point being simulated
/// (intermediate or final).
Statistics(Statistics),
/// Sent once by `BerTest::run` when the test ends, whether it succeeded or not.
Finished,
}
/// Sends a `Report::Statistics` snapshot to `$self.reporter`, if one is set.
///
/// Intermediate reports (`$final == false`) are rate-limited: one is only sent
/// once more than `interval` has passed since `$self.last_reported`. Final
/// reports are always sent. Every sent report updates `$self.last_reported`.
///
/// # Panics
/// Panics if the receiving end of the reporter channel has been dropped.
macro_rules! report {
    ($self:expr, $current_statistics:expr, $ebn0_db:expr, $final:expr) => {
        if let Some(sink) = &$self.reporter {
            let timestamp = Instant::now();
            // Short-circuit keeps the deadline computation off the final path.
            let due = $final || timestamp > $self.last_reported + sink.interval;
            if due {
                let snapshot =
                    Statistics::from_current(&$current_statistics, $ebn0_db, $self.k);
                sink.tx.send(Report::Statistics(snapshot)).unwrap();
                $self.last_reported = timestamp;
            }
        }
    };
}
impl BerTest {
    /// Creates a BER test for the LDPC code with parity check matrix `h`.
    ///
    /// * `decoder_implementation` selects the decoder built for each worker.
    /// * `puncturing_pattern`, when given, is passed to `Puncturer::new`.
    /// * `max_frame_errors` is the number of frame errors after which each
    ///   Eb/N0 point is considered finished.
    /// * `max_iterations` is the decoder iteration limit per frame.
    /// * `ebn0s_db` lists the Eb/N0 points to simulate, in dB.
    /// * `reporter` optionally receives progress reports.
    ///
    /// # Errors
    /// Returns the error produced by `Encoder::from_h` if no encoder can be
    /// built from `h`.
    ///
    /// # Panics
    /// `h.num_cols() - h.num_rows()` underflows (panicking in debug builds) if
    /// `h` has more rows than columns.
    pub fn new(
        h: SparseMatrix,
        decoder_implementation: DecoderImplementation,
        puncturing_pattern: Option<&[bool]>,
        max_frame_errors: u64,
        max_iterations: usize,
        ebn0s_db: &[f32],
        reporter: Option<Reporter>,
    ) -> Result<BerTest, Error> {
        let k = h.num_cols() - h.num_rows();
        let n_cw = h.num_cols();
        let puncturer = puncturing_pattern.map(Puncturer::new);
        let puncturer_rate = if let Some(p) = puncturer.as_ref() {
            p.rate()
        } else {
            1.0
        };
        // Number of bits actually transmitted per frame after puncturing.
        let n = (n_cw as f64 / puncturer_rate).round() as usize;
        let rate = k as f64 / n as f64;
        Ok(BerTest {
            decoder_implementation,
            num_workers: num_cpus::get(),
            k,
            n,
            n_cw,
            rate,
            // Must be built before `h` is moved into the struct below.
            encoder: Encoder::from_h(&h)?,
            h,
            puncturer,
            modulator: BpskModulator::new(),
            ebn0s_db: ebn0s_db.to_owned(),
            statistics: Vec::with_capacity(ebn0s_db.len()),
            max_iterations,
            max_frame_errors,
            reporter,
            last_reported: Instant::now(),
        })
    }

    /// Runs the test over all Eb/N0 points and returns one [`Statistics`]
    /// entry per completed point.
    ///
    /// `Report::Finished` is sent to the reporter (if any) whether or not the
    /// run succeeds.
    ///
    /// # Errors
    /// Returns the simulation error reported by a worker thread.
    ///
    /// # Panics
    /// Panics if the reporter's receiver has been dropped.
    pub fn run(mut self) -> Result<Vec<Statistics>, Box<dyn std::error::Error>> {
        let ret = self.do_run();
        if let Some(reporter) = self.reporter.as_ref() {
            reporter.tx.send(Report::Finished).unwrap();
        }
        ret?;
        Ok(self.statistics)
    }

    /// Number of transmitted bits per frame (after puncturing).
    pub fn n(&self) -> usize {
        self.n
    }

    /// Codeword length before puncturing.
    pub fn n_cw(&self) -> usize {
        self.n_cw
    }

    /// Message length in bits.
    pub fn k(&self) -> usize {
        self.k
    }

    /// Rate of the transmitted code, `k / n`.
    pub fn rate(&self) -> f64 {
        self.rate
    }

    /// Simulates every Eb/N0 point in `self.ebn0s_db`, appending the final
    /// statistics of each point to `self.statistics`.
    ///
    /// For each point, `num_workers` threads are spawned; their results are
    /// accumulated until `max_frame_errors` frame errors have been observed,
    /// a worker reports a failure, or all workers have exited. All workers
    /// are then asked to stop and are joined.
    ///
    /// # Errors
    /// Returns a simulation error reported by a worker (if several fail, the
    /// one joined last is returned).
    ///
    /// # Panics
    /// Re-raises the panic of any worker thread that panicked.
    fn do_run(&mut self) -> Result<(), Box<dyn std::error::Error>> {
        self.last_reported = Instant::now();
        for &ebn0_db in &self.ebn0s_db {
            // Eb/N0 (dB) -> Es/N0 -> noise standard deviation
            // sigma = sqrt(1 / (2 * Es/N0)).
            let ebn0 = 10.0_f64.powf(0.1 * f64::from(ebn0_db));
            let esn0 = self.rate * ebn0;
            let noise_sigma = (0.5 / esn0).sqrt();

            let (results_tx, results_rx) = mpsc::channel();
            let workers = std::iter::repeat_with(|| {
                let (mut worker, terminate_tx) = self.make_worker(noise_sigma, results_tx.clone());
                let handle = std::thread::spawn(move || worker.work());
                (handle, terminate_tx)
            })
            .take(self.num_workers)
            .collect::<Vec<_>>();
            // Only the workers must keep the results channel open; otherwise
            // `recv` would block forever if every worker exited.
            drop(results_tx);

            let mut current_statistics = CurrentStatistics::new();
            while current_statistics.frame_errors < self.max_frame_errors {
                match results_rx.recv() {
                    Ok(Ok(result)) => {
                        current_statistics.bit_errors += result.bit_errors;
                        current_statistics.frame_errors += u64::from(result.frame_error);
                        current_statistics.false_decodes += u64::from(result.false_decode);
                        current_statistics.total_iterations += result.iterations;
                        if !result.frame_error {
                            current_statistics.correct_iterations += result.iterations;
                        }
                        current_statistics.num_frames += 1;
                    }
                    // A worker failed (its error is collected when joining
                    // below), or every worker has already exited.
                    Ok(Err(())) | Err(_) => break,
                }
                report!(self, current_statistics, ebn0_db, false);
            }
            report!(self, current_statistics, ebn0_db, true);

            for (_, terminate_tx) in &workers {
                // A worker that already stopped has dropped its receiver, so
                // a failed send is expected and harmless.
                let _ = terminate_tx.send(());
            }
            let mut join_error = None;
            for (handle, _) in workers {
                let outcome = handle
                    .join()
                    .unwrap_or_else(|payload| std::panic::resume_unwind(payload));
                if let Err(e) = outcome {
                    join_error = Some(e);
                }
            }
            if let Some(e) = join_error {
                return Err(e);
            }

            self.statistics.push(Statistics::from_current(
                &current_statistics,
                ebn0_db,
                self.k,
            ));
        }
        Ok(())
    }

    /// Builds a worker for noise standard deviation `noise_sigma` that sends
    /// its per-frame results on `results_tx`.
    ///
    /// Returns the worker together with the sender used to ask it to stop
    /// (a sync channel with capacity 1, so one stop request never blocks).
    fn make_worker(
        &self,
        noise_sigma: f64,
        results_tx: Sender<WorkerResult>,
    ) -> (Worker, SyncSender<()>) {
        let (terminate_tx, terminate_rx) = mpsc::sync_channel(1);
        (
            Worker {
                terminate_rx,
                results_tx,
                k: self.k,
                encoder: self.encoder.clone(),
                puncturer: self.puncturer.clone(),
                modulator: self.modulator.clone(),
                channel: AwgnChannel::new(noise_sigma),
                demodulator: BpskDemodulator::new(noise_sigma),
                decoder: self.decoder_implementation.build_decoder(self.h.clone()),
                max_iterations: self.max_iterations,
            },
            terminate_tx,
        )
    }
}
impl Worker {
    /// Simulates frames in a loop until asked to stop.
    ///
    /// Before each frame, `terminate_rx` is polled without blocking. The loop
    /// ends when a stop request arrives, or when the controller has gone away.
    /// The controller is considered gone when either the terminate channel or
    /// the results channel is disconnected. Each frame's outcome is forwarded
    /// on `results_tx`, with failures sent as `Err(())`.
    ///
    /// # Errors
    /// Returns the first error produced by [`Worker::simulate`], after having
    /// signalled it to the controller.
    fn work(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        let mut rng = rand::thread_rng();
        loop {
            match self.terminate_rx.try_recv() {
                // A disconnected channel means the controller dropped its end
                // without asking us to stop (e.g. while unwinding from a
                // panic); nobody will read more results, so stop quietly.
                Ok(()) | Err(TryRecvError::Disconnected) => return Ok(()),
                Err(TryRecvError::Empty) => (),
            }
            let result = self.simulate(&mut rng);
            let to_send = match &result {
                Ok(outcome) => Ok(outcome.clone()),
                Err(_) => Err(()),
            };
            let delivered = self.results_tx.send(to_send).is_ok();
            result?;
            if !delivered {
                // The results receiver is gone: nobody is collecting anymore.
                return Ok(());
            }
        }
    }

    /// Simulates the transmission and decoding of one random frame.
    ///
    /// A random `k`-bit message is encoded, punctured (if a puncturer is
    /// configured), modulated, passed through the channel, demodulated to
    /// LLRs, depunctured and decoded with at most `max_iterations` iterations.
    ///
    /// # Errors
    /// Propagates errors from puncturing or depuncturing.
    fn simulate<R: Rng>(
        &mut self,
        rng: &mut R,
    ) -> Result<WorkerResultOk, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let message = Self::random_message(rng, self.k);
        let codeword = self.encoder.encode(&Self::gf2_array(&message));
        let transmitted = match self.puncturer.as_ref() {
            Some(p) => p.puncture(&codeword)?,
            None => codeword,
        };
        let mut symbols = self.modulator.modulate(&transmitted);
        self.channel.add_noise(rng, &mut symbols);
        let llrs_demod = self.demodulator.demodulate(&symbols);
        let llrs_decoder = match self.puncturer.as_ref() {
            Some(p) => p.depuncture(&llrs_demod)?,
            None => llrs_demod,
        };
        // The decoder returns its output in both cases; `Err` marks a decode
        // it did not report as successful.
        let (decoded, iterations, success) =
            match self.decoder.decode(&llrs_decoder, self.max_iterations) {
                Ok(output) => (output.codeword, output.iterations, true),
                Err(output) => (output.codeword, output.iterations, false),
            };
        // `zip` stops at the shorter sequence, so only the first `k` decoded
        // bits are compared with the message.
        // NOTE(review): assumes the encoder places the message bits first —
        // confirm against `Encoder`.
        let bit_errors = message
            .iter()
            .zip(decoded.iter())
            .filter(|(&a, &b)| a != b)
            .count() as u64;
        let frame_error = bit_errors > 0;
        // The decoder claimed success but the message is wrong.
        let false_decode = frame_error && success;
        Ok(WorkerResultOk {
            bit_errors,
            frame_error,
            false_decode,
            iterations: iterations as u64,
        })
    }

    /// Draws `size` random bits, each stored as a `u8` equal to 0 or 1.
    fn random_message<R: Rng>(rng: &mut R, size: usize) -> Vec<u8> {
        rng.sample_iter(Standard)
            .map(<u8 as From<bool>>::from)
            .take(size)
            .collect()
    }

    /// Converts 0/1 bytes into a GF(2) array (1 maps to one, anything else to zero).
    fn gf2_array(bits: &[u8]) -> Array1<GF2> {
        Array1::from_iter(
            bits.iter()
                .map(|&b| if b == 1 { GF2::one() } else { GF2::zero() }),
        )
    }
}
impl CurrentStatistics {
    /// Returns a fresh set of counters, all zero, whose clock starts now.
    fn new() -> CurrentStatistics {
        let start = Instant::now();
        CurrentStatistics {
            start,
            num_frames: 0,
            frame_errors: 0,
            bit_errors: 0,
            false_decodes: 0,
            correct_iterations: 0,
            total_iterations: 0,
        }
    }
}
impl Default for CurrentStatistics {
fn default() -> CurrentStatistics {
CurrentStatistics::new()
}
}
impl Statistics {
    /// Builds a statistics snapshot from the running counters `stats`.
    ///
    /// `k` is the number of message bits per frame and is used for the BER
    /// and throughput. Ratios use floating-point division, so they are NaN or
    /// infinite when their denominator is zero.
    ///
    /// # Panics
    /// `num_frames - frame_errors` underflows (panicking in debug builds) if
    /// `frame_errors > num_frames`.
    fn from_current(stats: &CurrentStatistics, ebn0_db: f32, k: usize) -> Statistics {
        let elapsed = stats.start.elapsed();
        let frames = stats.num_frames as f64;
        let message_bits = k as f64 * frames;
        let correct_frames = (stats.num_frames - stats.frame_errors) as f64;
        Statistics {
            ebn0_db,
            num_frames: stats.num_frames,
            bit_errors: stats.bit_errors,
            frame_errors: stats.frame_errors,
            false_decodes: stats.false_decodes,
            total_iterations: stats.total_iterations,
            correct_iterations: stats.correct_iterations,
            ber: stats.bit_errors as f64 / message_bits,
            fer: stats.frame_errors as f64 / frames,
            average_iterations: stats.total_iterations as f64 / frames,
            average_iterations_correct: stats.correct_iterations as f64 / correct_frames,
            throughput_mbps: 1e-6 * message_bits / elapsed.as_secs_f64(),
            elapsed,
        }
    }
}