use std::collections::HashMap;
use std::fmt;
use crate::error::{NetError, NetResult};
/// Shape of a two-dimensional (row/column) XOR FEC matrix.
///
/// Source packets are laid out row-major: flat index `i` sits at row
/// `i / num_cols`, column `i % num_cols`. One repair packet is produced per
/// row and one per column.
#[derive(Debug, Clone)]
pub struct InterleavedFecConfig {
    /// Number of matrix rows (and therefore row repair packets).
    pub num_rows: usize,
    /// Number of matrix columns (and therefore column repair packets).
    pub num_cols: usize,
}

impl InterleavedFecConfig {
    /// Validates the dimensions and builds a configuration.
    ///
    /// The fields are public, so a value built with a struct literal bypasses
    /// these checks; the other methods assume the invariants hold.
    ///
    /// # Errors
    ///
    /// Returns a protocol error when either dimension is zero, or when
    /// `num_rows * num_cols` exceeds 256 (including arithmetic overflow).
    pub fn new(num_rows: usize, num_cols: usize) -> NetResult<Self> {
        if num_rows.min(num_cols) == 0 {
            return Err(NetError::protocol("num_rows and num_cols must be >= 1"));
        }
        // `None` means the product overflowed, which is certainly > 256.
        match num_rows.checked_mul(num_cols) {
            Some(cells) if cells <= 256 => Ok(Self { num_rows, num_cols }),
            _ => Err(NetError::protocol(
                "FEC group size (rows × cols) must not exceed 256",
            )),
        }
    }

    /// Number of source packets in one group (`num_rows * num_cols`).
    #[must_use]
    pub const fn group_size(&self) -> usize {
        self.num_cols * self.num_rows
    }

    /// Number of row repair packets per group (one per row).
    #[must_use]
    pub const fn row_repair_count(&self) -> usize {
        self.num_rows
    }

    /// Number of column repair packets per group (one per column).
    #[must_use]
    pub const fn col_repair_count(&self) -> usize {
        self.num_cols
    }

    /// Maps a flat row-major index to its `(row, col)` position.
    ///
    /// # Panics
    ///
    /// Panics on division by zero if `num_cols` is 0, which is only possible
    /// when the checks in [`InterleavedFecConfig::new`] were bypassed.
    #[must_use]
    pub fn to_matrix_pos(&self, idx: usize) -> (usize, usize) {
        let row = idx / self.num_cols;
        let col = idx % self.num_cols;
        (row, col)
    }

    /// Maps a `(row, col)` position back to its flat row-major index.
    #[must_use]
    pub const fn from_matrix_pos(&self, row: usize, col: usize) -> usize {
        col + row * self.num_cols
    }
}
/// XORs `src` into `dst` in place.
///
/// Only the overlapping prefix (`min(dst.len(), src.len())` bytes) is
/// touched. Bytes of `dst` past the end of `src` keep their value, and any
/// extra bytes of `src` are ignored.
fn xor_into(dst: &mut [u8], src: &[u8]) {
    // `zip` stops at the shorter slice, which is exactly the overlap. A plain
    // byte loop like this is typically vectorized by the compiler.
    for (d, s) in dst.iter_mut().zip(src) {
        *d ^= *s;
    }
}
/// XORs all `packets` together into a fresh buffer of exactly `width` bytes.
///
/// Packets shorter than `width` behave as if zero-padded to `width`. Bytes of
/// a packet beyond `width` are ignored.
fn compute_xor_repair(packets: &[&[u8]], width: usize) -> Vec<u8> {
    packets.iter().fold(vec![0u8; width], |mut acc, packet| {
        xor_into(&mut acc, packet);
        acc
    })
}
/// One XOR repair packet covering a single row or a single column of the
/// FEC matrix.
#[derive(Debug, Clone)]
pub struct RepairPacket {
/// Row index (for a row repair) or column index (for a column repair).
pub index: usize,
/// XOR of the covered source payloads. When built by the encoder, each payload
/// is zero-padded to the widest payload in the group.
pub payload: Vec<u8>,
/// Sequence numbers of the source packets XOR-ed into `payload`, in matrix order.
/// The decoder uses these to work out which covered packet is missing.
pub source_seqs: Vec<u16>,
}
/// All repair packets produced for one FEC group, together with its shape.
#[derive(Debug, Clone)]
pub struct FecGroup {
/// One repair packet per matrix row; element `r` has `index == r`.
pub row_repair: Vec<RepairPacket>,
/// One repair packet per matrix column; element `c` has `index == c`.
pub col_repair: Vec<RepairPacket>,
/// The matrix shape the repair packets were built for.
pub config: InterleavedFecConfig,
}
#[derive(Debug)]
pub struct InterleavedFecEncoder {
config: InterleavedFecConfig,
sources: HashMap<u16, Vec<u8>>,
anchor_seq: Option<u16>,
}
impl InterleavedFecEncoder {
#[must_use]
pub fn new(config: InterleavedFecConfig) -> Self {
Self {
config,
sources: HashMap::new(),
anchor_seq: None,
}
}
pub fn feed(&mut self, seq: u16, payload: &[u8]) {
if self.anchor_seq.is_none() {
self.anchor_seq = Some(seq);
}
self.sources.insert(seq, payload.to_vec());
}
#[must_use]
pub fn count(&self) -> usize {
self.sources.len()
}
pub fn finalize(&self) -> NetResult<FecGroup> {
let group_size = self.config.group_size();
if self.sources.len() < group_size {
return Err(NetError::encoding(format!(
"FEC group incomplete: {} / {} packets",
self.sources.len(),
group_size
)));
}
let anchor = self
.anchor_seq
.ok_or_else(|| NetError::encoding("no packets fed"))?;
let mut ordered: Vec<(u16, &[u8])> = (0..group_size)
.filter_map(|i| {
let seq = anchor.wrapping_add(i as u16);
self.sources.get(&seq).map(|p| (seq, p.as_slice()))
})
.collect();
if ordered.len() < group_size {
return Err(NetError::encoding(
"FEC group has gaps — all source packets required for encoding",
));
}
ordered.sort_by_key(|(s, _)| *s);
let max_width = ordered.iter().map(|(_, p)| p.len()).max().unwrap_or(0);
let mut row_repair = Vec::with_capacity(self.config.num_rows);
for row in 0..self.config.num_rows {
let slices: Vec<&[u8]> = (0..self.config.num_cols)
.map(|col| ordered[self.config.from_matrix_pos(row, col)].1)
.collect();
let source_seqs: Vec<u16> = (0..self.config.num_cols)
.map(|col| ordered[self.config.from_matrix_pos(row, col)].0)
.collect();
let payload = compute_xor_repair(&slices, max_width);
row_repair.push(RepairPacket {
index: row,
payload,
source_seqs,
});
}
let mut col_repair = Vec::with_capacity(self.config.num_cols);
for col in 0..self.config.num_cols {
let slices: Vec<&[u8]> = (0..self.config.num_rows)
.map(|row| ordered[self.config.from_matrix_pos(row, col)].1)
.collect();
let source_seqs: Vec<u16> = (0..self.config.num_rows)
.map(|row| ordered[self.config.from_matrix_pos(row, col)].0)
.collect();
let payload = compute_xor_repair(&slices, max_width);
col_repair.push(RepairPacket {
index: col,
payload,
source_seqs,
});
}
Ok(FecGroup {
row_repair,
col_repair,
config: self.config.clone(),
})
}
}
/// Reassembles one FEC group from whichever source and repair packets arrived.
#[derive(Debug)]
pub struct InterleavedFecDecoder {
    config: InterleavedFecConfig,
    /// Source payloads received or recovered so far, keyed by sequence number.
    sources: HashMap<u16, Vec<u8>>,
    /// Row repair packets keyed by their row index.
    row_repairs: HashMap<usize, RepairPacket>,
    /// Column repair packets keyed by their column index.
    col_repairs: HashMap<usize, RepairPacket>,
}

impl InterleavedFecDecoder {
    /// Creates an empty decoder for a group shaped by `config`.
    #[must_use]
    pub fn new(config: InterleavedFecConfig) -> Self {
        Self {
            config,
            sources: HashMap::new(),
            row_repairs: HashMap::new(),
            col_repairs: HashMap::new(),
        }
    }

    /// Records a received source packet, replacing any earlier one with the same `seq`.
    pub fn feed_source(&mut self, seq: u16, payload: Vec<u8>) {
        self.sources.insert(seq, payload);
    }

    /// Records a row repair packet, replacing any earlier one with the same index.
    pub fn feed_row_repair(&mut self, rp: RepairPacket) {
        self.row_repairs.insert(rp.index, rp);
    }

    /// Records a column repair packet, replacing any earlier one with the same index.
    pub fn feed_col_repair(&mut self, cp: RepairPacket) {
        self.col_repairs.insert(cp.index, cp);
    }

    /// Recovers as many missing source packets as the stored repair packets allow.
    ///
    /// Each pass visits row repairs `0..num_rows`, then column repairs
    /// `0..num_cols`. A repair packet whose covered sequence numbers have exactly
    /// one absent member rebuilds that member. Recovered packets are added to the
    /// known sources, so they can unlock further rows or columns in the same or
    /// later passes. Stops after a pass with no progress, or after
    /// `group_size()` passes.
    ///
    /// Returns only the packets recovered by this call. Recovered payloads pass
    /// through `trim_zeros`, so payloads that really end in zero bytes come back
    /// shorter.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn recover(&mut self) -> NetResult<HashMap<u16, Vec<u8>>> {
        let mut recovered: HashMap<u16, Vec<u8>> = HashMap::new();
        for _ in 0..self.config.group_size() {
            let mut progress = false;
            for row in 0..self.config.num_rows {
                if let Some(rp) = self.row_repairs.get(&row) {
                    progress |= Self::recover_from(rp, &mut self.sources, &mut recovered);
                }
            }
            for col in 0..self.config.num_cols {
                if let Some(cp) = self.col_repairs.get(&col) {
                    progress |= Self::recover_from(cp, &mut self.sources, &mut recovered);
                }
            }
            if !progress {
                break;
            }
        }
        Ok(recovered)
    }

    /// Rebuilds the single missing source packet covered by `repair`, if there is one.
    ///
    /// When exactly one of `repair.source_seqs` is absent from `sources`, its
    /// payload is reconstructed, stored in both `sources` and `recovered`, and
    /// `true` is returned. Otherwise nothing changes and `false` is returned.
    fn recover_from(
        repair: &RepairPacket,
        sources: &mut HashMap<u16, Vec<u8>>,
        recovered: &mut HashMap<u16, Vec<u8>>,
    ) -> bool {
        // Every sequence number is real data, including 0xFFFF. Treating
        // u16::MAX as "no packet" would hide a lost seq 65535. A second loss in
        // the same row or column would then look like a single loss and be
        // "recovered" with corrupted bytes.
        let mut absent = repair
            .source_seqs
            .iter()
            .copied()
            .filter(|seq| !sources.contains_key(seq));
        let missing_seq = match (absent.next(), absent.next()) {
            (Some(seq), None) => seq,
            _ => return false,
        };
        // XOR out every member that is present. The missing one is skipped
        // automatically because it is not in `sources`.
        let mut buf = repair.payload.clone();
        for pkt in repair.source_seqs.iter().filter_map(|seq| sources.get(seq)) {
            xor_into(&mut buf, pkt);
        }
        // Undo the encoder's zero padding (heuristic: repair packets carry no length).
        trim_zeros(&mut buf);
        sources.insert(missing_seq, buf.clone());
        recovered.insert(missing_seq, buf);
        true
    }
}
/// Drops all trailing `0x00` bytes from `buf` in place.
///
/// The decoder calls this to undo the zero padding the encoder applies when it
/// XORs payloads of unequal length up to the widest one. Repair packets carry
/// no length field, so this cannot tell padding from data. A source payload
/// that really ends in zero bytes is recovered shorter than it was sent, and an
/// all-zero payload is recovered as empty.
fn trim_zeros(buf: &mut Vec<u8>) {
    let keep = buf
        .iter()
        .rposition(|&b| b != 0)
        .map_or(0, |last_nonzero| last_nonzero + 1);
    buf.truncate(keep);
}
/// Predicted outcome of repairing one FEC group (see [`RecoveryPlanner::plan`]).
#[derive(Debug, Clone)]
pub struct RecoveryPlan {
    /// Lost sequence numbers that row/column XOR repair can rebuild, in group order.
    pub recoverable: Vec<u16>,
    /// Lost sequence numbers that cannot be rebuilt even with every repair packet,
    /// in group order.
    pub unrecoverable: Vec<u16>,
    /// `true` when the recoverable losses can be rebuilt neither from row repair
    /// packets alone nor from column repair packets alone, so both are required.
    pub needs_both_dimensions: bool,
}

/// Predicts, from the received sequence numbers alone, which losses in a group
/// the decoder can repair.
#[derive(Debug)]
pub struct RecoveryPlanner {
    config: InterleavedFecConfig,
}

impl RecoveryPlanner {
    /// Creates a planner for groups shaped by `config`.
    #[must_use]
    pub fn new(config: InterleavedFecConfig) -> Self {
        Self { config }
    }

    /// Classifies every lost packet of the group that starts at `anchor_seq`.
    ///
    /// A packet is lost when its sequence number (`anchor_seq + i`, wrapping, for
    /// `i` in `0..group_size()`) does not appear in `received_seqs`. Entries of
    /// `received_seqs` outside the group are ignored. The plan assumes every row
    /// and column repair packet arrives. It mirrors the decoder's iterative
    /// peeling: rebuilding one packet can leave a single loss in another row or
    /// column, which then becomes recoverable too.
    #[must_use]
    pub fn plan(&self, anchor_seq: u16, received_seqs: &[u16]) -> RecoveryPlan {
        let cfg = &self.config;
        let group_size = cfg.group_size();
        let seq_at = |flat: usize| anchor_seq.wrapping_add(flat as u16);
        let received: std::collections::HashSet<u16> = received_seqs.iter().copied().collect();

        // Flat matrix positions (offsets from the anchor) of every lost packet, in group order.
        let lost: Vec<usize> = (0..group_size)
            .filter(|&flat| !received.contains(&seq_at(flat)))
            .collect();
        if lost.is_empty() {
            return RecoveryPlan {
                recoverable: vec![],
                unrecoverable: vec![],
                needs_both_dimensions: false,
            };
        }

        let mut row_loss = vec![0usize; cfg.num_rows];
        let mut col_loss = vec![0usize; cfg.num_cols];
        for &flat in &lost {
            let (row, col) = cfg.to_matrix_pos(flat);
            row_loss[row] += 1;
            col_loss[col] += 1;
        }

        // Rows do not interact with each other. With row repairs alone, exactly
        // the losses that are alone in their row can be rebuilt; likewise for columns.
        let row_only = lost
            .iter()
            .filter(|&&flat| row_loss[cfg.to_matrix_pos(flat).0] == 1)
            .count();
        let col_only = lost
            .iter()
            .filter(|&&flat| col_loss[cfg.to_matrix_pos(flat).1] == 1)
            .count();

        // Peel: a loss that is alone in its row or column is rebuilt, which lowers
        // the loss counts and may isolate further losses. Each productive pass
        // removes at least one loss, so the loop terminates.
        let mut missing = vec![false; group_size];
        for &flat in &lost {
            missing[flat] = true;
        }
        let mut progress = true;
        while progress {
            progress = false;
            for &flat in &lost {
                if !missing[flat] {
                    continue;
                }
                let (row, col) = cfg.to_matrix_pos(flat);
                if row_loss[row] == 1 || col_loss[col] == 1 {
                    missing[flat] = false;
                    row_loss[row] -= 1;
                    col_loss[col] -= 1;
                    progress = true;
                }
            }
        }

        let mut recoverable = Vec::new();
        let mut unrecoverable = Vec::new();
        for &flat in &lost {
            if missing[flat] {
                unrecoverable.push(seq_at(flat));
            } else {
                recoverable.push(seq_at(flat));
            }
        }
        // Both dimensions are needed only when neither one alone reaches every
        // recoverable loss. A lone loss is repairable either way, so it does not count.
        let needs_both_dimensions = recoverable.len() > row_only && recoverable.len() > col_only;

        RecoveryPlan {
            recoverable,
            unrecoverable,
            needs_both_dimensions,
        }
    }
}
impl fmt::Display for InterleavedFecConfig {
    /// Renders the matrix shape as `FEC(<rows>×<cols>)`, e.g. `FEC(4×8)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { num_rows, num_cols } = self;
        write!(f, "FEC({}×{})", num_rows, num_cols)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `n` payloads of `len` bytes, where payload `i` is filled with byte `i`.
    ///
    /// Payload 0 is therefore all zeros. The decoder trims trailing zero bytes
    /// from recovered packets, so the recovery tests below never drop index 0.
    fn make_payloads(n: usize, len: usize) -> Vec<Vec<u8>> {
        (0..n).map(|i| vec![i as u8; len]).collect()
    }

    /// Encodes `payloads` as one `rows` × `cols` group, fed with sequence numbers 0, 1, 2, ...
    fn encode_group(rows: usize, cols: usize, payloads: &[Vec<u8>]) -> FecGroup {
        let cfg = InterleavedFecConfig::new(rows, cols).expect("cfg");
        let mut enc = InterleavedFecEncoder::new(cfg);
        for (i, p) in payloads.iter().enumerate() {
            enc.feed(i as u16, p);
        }
        enc.finalize().expect("group")
    }

    /// Builds a decoder holding every repair packet of `group` plus every source
    /// payload whose index is not listed in `lost` (sequence number == index).
    fn decoder_with_losses(
        group: FecGroup,
        payloads: &[Vec<u8>],
        lost: &[usize],
    ) -> InterleavedFecDecoder {
        let mut dec = InterleavedFecDecoder::new(group.config.clone());
        for (i, p) in payloads.iter().enumerate() {
            if !lost.contains(&i) {
                dec.feed_source(i as u16, p.clone());
            }
        }
        for rp in group.row_repair {
            dec.feed_row_repair(rp);
        }
        for cp in group.col_repair {
            dec.feed_col_repair(cp);
        }
        dec
    }

    #[test]
    fn test_config_rejects_zero_rows() {
        assert!(InterleavedFecConfig::new(0, 4).is_err());
    }

    #[test]
    fn test_config_rejects_oversized_group() {
        // 17 × 16 = 272 > 256.
        assert!(InterleavedFecConfig::new(17, 16).is_err());
    }

    #[test]
    fn test_encoder_repair_packet_count() {
        let payloads = make_payloads(8, 16);
        let group = encode_group(2, 4, &payloads);
        assert_eq!(group.row_repair.len(), 2);
        assert_eq!(group.col_repair.len(), 4);
    }

    #[test]
    fn test_xor_into_inverse() {
        // 33 bytes: deliberately not a multiple of 8.
        let src = vec![0xABu8; 33];
        let mut dst = vec![0u8; 33];
        xor_into(&mut dst, &src);
        xor_into(&mut dst, &src);
        assert!(dst.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_single_row_loss_recovery() {
        let payloads = make_payloads(8, 12);
        let group = encode_group(2, 4, &payloads);
        let mut dec = decoder_with_losses(group, &payloads, &[2]);
        let recovered = dec.recover().expect("recover");
        assert!(recovered.contains_key(&2), "seq 2 must be recovered");
        assert_eq!(recovered[&2], payloads[2]);
    }

    #[test]
    fn test_single_col_loss_recovery() {
        let payloads = make_payloads(8, 12);
        let group = encode_group(2, 4, &payloads);
        let mut dec = decoder_with_losses(group, &payloads, &[4]);
        let recovered = dec.recover().expect("recover");
        assert_eq!(recovered[&4], payloads[4]);
    }

    #[test]
    fn test_two_loss_recovery_same_col_different_rows() {
        // In a 2 × 4 matrix, seq 1 is at (0, 1) and seq 5 is at (1, 1). They share
        // column 1, so its column repair cannot help, but each row has a single loss.
        let payloads = make_payloads(8, 8);
        let group = encode_group(2, 4, &payloads);
        let mut dec = decoder_with_losses(group, &payloads, &[1, 5]);
        let recovered = dec.recover().expect("recover");
        assert_eq!(recovered[&1], payloads[1]);
        assert_eq!(recovered[&5], payloads[5]);
    }

    #[test]
    fn test_no_losses_empty_recovery() {
        let payloads = make_payloads(6, 10);
        let group = encode_group(2, 3, &payloads);
        let mut dec = decoder_with_losses(group, &payloads, &[]);
        let recovered = dec.recover().expect("recover");
        assert!(recovered.is_empty());
    }

    #[test]
    fn test_encoder_rejects_incomplete_group() {
        let cfg = InterleavedFecConfig::new(2, 4).expect("cfg");
        let mut enc = InterleavedFecEncoder::new(cfg);
        enc.feed(0, b"only one packet");
        assert!(enc.finalize().is_err());
    }

    #[test]
    fn test_matrix_pos_roundtrip() {
        let cfg = InterleavedFecConfig::new(3, 5).expect("cfg");
        for i in 0..15 {
            let (r, c) = cfg.to_matrix_pos(i);
            assert_eq!(cfg.from_matrix_pos(r, c), i);
        }
    }

    #[test]
    fn test_matrix_pos_is_row_major() {
        let cfg = InterleavedFecConfig::new(3, 5).expect("cfg");
        assert_eq!(cfg.to_matrix_pos(0), (0, 0));
        assert_eq!(cfg.to_matrix_pos(7), (1, 2));
        assert_eq!(cfg.from_matrix_pos(2, 4), 14);
    }

    #[test]
    fn test_planner_column_pair_recoverable_via_rows() {
        // Seqs 1 and 5 share column 1 of a 2 × 4 matrix, but each sits alone in its row.
        let cfg = InterleavedFecConfig::new(2, 4).expect("cfg");
        let planner = RecoveryPlanner::new(cfg);
        let received: Vec<u16> = (0u16..8).filter(|&s| s != 1 && s != 5).collect();
        let plan = planner.plan(0, &received);
        assert!(plan.recoverable.contains(&1));
        assert!(plan.recoverable.contains(&5));
        assert!(plan.unrecoverable.is_empty());
    }

    #[test]
    fn test_planner_no_losses() {
        let cfg = InterleavedFecConfig::new(2, 3).expect("cfg");
        let planner = RecoveryPlanner::new(cfg);
        let received: Vec<u16> = (0..6).collect();
        let plan = planner.plan(0, &received);
        assert!(plan.recoverable.is_empty());
        assert!(plan.unrecoverable.is_empty());
    }

    #[test]
    fn test_xor_into_mismatched_lengths() {
        let mut dst = vec![0xFFu8; 10];
        let src = vec![0xFFu8; 5];
        xor_into(&mut dst, &src);
        assert!(dst[..5].iter().all(|&b| b == 0));
        assert!(dst[5..].iter().all(|&b| b == 0xFF));
    }

    #[test]
    fn test_config_display() {
        let cfg = InterleavedFecConfig::new(4, 8).expect("cfg");
        let s = format!("{cfg}");
        assert!(s.contains("4×8") || s.contains("4\u{00D7}8"));
    }

    #[test]
    fn test_large_payload_roundtrip() {
        let payloads: Vec<Vec<u8>> = (0u8..6)
            .map(|i| {
                let mut v = vec![0u8; 1316];
                v[0] = i;
                v[1315] = i.wrapping_mul(7);
                v
            })
            .collect();
        let group = encode_group(2, 3, &payloads);
        let mut dec = decoder_with_losses(group, &payloads, &[3]);
        let recovered = dec.recover().expect("recover");
        assert_eq!(recovered[&3], payloads[3]);
    }
}