use std::collections::{BTreeSet, HashMap};
use reed_solomon_erasure::galois_8::{self, add as gf_add, mul as gf_mul};
use crate::coords::get_plane_vector;
use crate::encode::EncodeParams;
use crate::error::ClayError;
use crate::transforms::{
compute_c_from_u_and_cstar, compute_u_from_c_and_ustar, pft_compute_both, prt_compute_both,
GAMMA,
};
/// Decoding reuses the exact same parameter set as encoding
/// (`k`, `m`, `n`, `q`, `t`, `nu`, `sub_chunk_no`, …).
pub type DecodeParams = EncodeParams;
/// Decodes and returns the concatenation of the `k` data chunks,
/// repairing the nodes listed in `erasures` from the chunks in
/// `available`.
///
/// `available` maps an external node index (in `[0, n)`) to that node's
/// chunk bytes. Every index in `[0, n)` must appear in exactly one of
/// `available` / `erasures`.
///
/// # Errors
/// * `TooManyErasures` when `erasures.len() > m`.
/// * `InvalidChunkSize` when chunks are empty or not a multiple of the
///   sub-chunk count `α`; `InconsistentChunkSizes` when chunk lengths
///   disagree.
/// * `InvalidParameters` for out-of-range indices or when
///   `available`/`erasures` do not form a valid partition of the nodes.
pub fn decode(
    params: &DecodeParams,
    available: &HashMap<usize, Vec<u8>>,
    erasures: &[usize],
) -> Result<Vec<u8>, ClayError> {
    // With no input chunks there is nothing to size the output by;
    // treat it as an empty payload.
    if available.is_empty() {
        return Ok(Vec::new());
    }
    // The code can tolerate at most `m` missing nodes.
    if erasures.len() > params.m {
        return Err(ClayError::TooManyErasures {
            max: params.m,
            actual: erasures.len(),
        });
    }
    // All chunks must share one non-zero length that splits evenly into
    // the `sub_chunk_no` (α) sub-chunks used by layered decoding.
    let mut iter = available.iter();
    let (_, first_chunk) = iter.next().unwrap();
    let chunk_size = first_chunk.len();
    if chunk_size == 0 || chunk_size % params.sub_chunk_no != 0 {
        return Err(ClayError::InvalidChunkSize {
            expected: params.sub_chunk_no,
            actual: chunk_size,
        });
    }
    // `iter` has already consumed the first entry, so this checks the
    // remaining chunks against the first one's size.
    for (&idx, chunk) in iter {
        if chunk.len() != chunk_size {
            return Err(ClayError::InconsistentChunkSizes {
                first_size: chunk_size,
                mismatched_idx: idx,
                mismatched_size: chunk.len(),
            });
        }
    }
    // Both available and erased indices must be valid external node ids.
    for &idx in available.keys() {
        if idx >= params.n {
            return Err(ClayError::InvalidParameters(format!(
                "Chunk index {} out of range [0, {})",
                idx, params.n
            )));
        }
    }
    for &e in erasures {
        if e >= params.n {
            return Err(ClayError::InvalidParameters(format!(
                "Erasure index {} out of range [0, {})",
                e, params.n
            )));
        }
    }
    // A node cannot be both present and erased.
    for &e in erasures {
        if available.contains_key(&e) {
            return Err(ClayError::InvalidParameters(format!(
                "Node {} is both in available chunks and marked as erased",
                e
            )));
        }
    }
    // Together, `available` and `erasures` must cover all n nodes exactly.
    let expected_available = params.n - erasures.len();
    if available.len() != expected_available {
        return Err(ClayError::InvalidParameters(format!(
            "Expected {} available chunks (n={} - erasures={}), but got {}",
            expected_available,
            params.n,
            erasures.len(),
            available.len()
        )));
    }
    for node in 0..params.n {
        if !erasures.contains(&node) && !available.contains_key(&node) {
            return Err(ClayError::InvalidParameters(format!(
                "Node {} is neither erased nor provided in available chunks",
                node
            )));
        }
    }
    let sub_chunk_size = chunk_size / params.sub_chunk_no;
    // Internal layout has q*t slots. External indices >= k are shifted by
    // `nu` to skip the virtual shortening slots, which stay all-zero
    // (the buffers are zero-initialized and never filled).
    let total_nodes = params.q * params.t;
    let mut chunks: Vec<Vec<u8>> = vec![vec![0u8; chunk_size]; total_nodes];
    for (&idx, data) in available.iter() {
        let internal_idx = if idx < params.k { idx } else { idx + params.nu };
        chunks[internal_idx] = data.clone();
    }
    // Translate erasures to internal indices as well.
    let mut erased_set: BTreeSet<usize> = BTreeSet::new();
    for &e in erasures {
        let internal_idx = if e < params.k { e } else { e + params.nu };
        erased_set.insert(internal_idx);
    }
    decode_layered(params, &erased_set, &mut chunks, sub_chunk_size)?;
    // The first k internal slots hold the (now repaired) data chunks.
    let mut result = Vec::with_capacity(params.k * chunk_size);
    for i in 0..params.k {
        result.extend_from_slice(&chunks[i]);
    }
    Ok(result)
}
/// Drives layered Clay decoding across all `sub_chunk_no` (α) planes.
///
/// Planes are visited in increasing intersection-score order (the number
/// of erased nodes sitting on a plane's diagonal), so each plane can
/// reuse uncoupled values already recovered by lower-score planes. For
/// every plane at the current score, the erased nodes' uncoupled values
/// are recovered first (`decode_layered_with_tracking`), then their
/// coupled chunk data is rebuilt from those uncoupled values.
///
/// `chunks` holds all `q * t` internal node chunks and is repaired in
/// place; `erased_chunks` contains *internal* node indices.
///
/// # Errors
/// Propagates any error from the per-plane MDS (Reed-Solomon) decode.
pub fn decode_layered(
    params: &DecodeParams,
    erased_chunks: &BTreeSet<usize>,
    // `&mut [Vec<u8>]` instead of `&mut Vec<Vec<u8>>` (clippy `ptr_arg`):
    // the length is never changed, only the contents. `&mut vec` at call
    // sites still coerces.
    chunks: &mut [Vec<u8>],
    sub_chunk_size: usize,
) -> Result<(), ClayError> {
    let total_nodes = params.q * params.t;
    let chunk_size = chunks[0].len();
    // Uncoupled-layer buffer plus per-(node, plane) tracking of which
    // uncoupled values have been computed so far.
    let mut u_buf: Vec<Vec<u8>> = vec![vec![0u8; chunk_size]; total_nodes];
    let mut u_computed: Vec<Vec<bool>> = vec![vec![false; params.sub_chunk_no]; total_nodes];
    let mut order: Vec<usize> = vec![0; params.sub_chunk_no];
    set_planes_sequential_decoding_order(params, &mut order, erased_chunks);
    let max_iscore = get_max_iscore(params, erased_chunks);
    for iscore in 0..=max_iscore {
        // Pass 1: recover uncoupled values for every plane at this score.
        for z in 0..params.sub_chunk_no {
            if order[z] == iscore {
                decode_layered_with_tracking(
                    params,
                    erased_chunks,
                    z,
                    chunks,
                    &mut u_buf,
                    &mut u_computed,
                    sub_chunk_size,
                )?;
            }
        }
        // Pass 2: convert recovered uncoupled values back into the
        // coupled chunk data of the erased nodes.
        for z in 0..params.sub_chunk_no {
            if order[z] == iscore {
                let z_vec = get_plane_vector(z, params.t, params.q);
                for &node_xy in erased_chunks {
                    let x = node_xy % params.q;
                    let y = node_xy / params.q;
                    let z_y = z_vec[y];
                    let node_sw = y * params.q + z_y;
                    let z_sw = get_companion_layer(params, z, x, y, z_y);
                    if z_y != x {
                        if !erased_chunks.contains(&node_sw) {
                            // Companion survived: combine our uncoupled
                            // value with its coupled value.
                            recover_type1_erasure(
                                params,
                                chunks,
                                &u_buf,
                                x,
                                y,
                                z,
                                z_y,
                                z_sw,
                                sub_chunk_size,
                            );
                        } else if z_y < x {
                            // Both ends of the couple are erased: rebuild
                            // both from their uncoupled values, once per
                            // pair (the z_y < x end does the work).
                            get_coupled_from_uncoupled(
                                params, chunks, &u_buf, x, y, z, z_y, z_sw, sub_chunk_size,
                            );
                        }
                    } else {
                        // Diagonal position: coupled and uncoupled values
                        // coincide, so copy straight from u_buf.
                        let offset = z * sub_chunk_size;
                        chunks[node_xy][offset..offset + sub_chunk_size]
                            .copy_from_slice(&u_buf[node_xy][offset..offset + sub_chunk_size]);
                    }
                }
            }
        }
    }
    Ok(())
}
/// Recovers the uncoupled-layer values (`u_buf`) of plane `z` for all
/// nodes, updating `u_computed` to record which (node, plane) uncoupled
/// values are now known.
///
/// Surviving nodes derive their uncoupled value directly from coupled
/// chunks where possible; any node whose value cannot be derived yet
/// (erased, or its companion is erased and not yet computed) is added to
/// `needs_mds` and recovered through the plane's MDS (RS) decode.
fn decode_layered_with_tracking(
    params: &DecodeParams,
    erased_chunks: &BTreeSet<usize>,
    z: usize,
    chunks: &[Vec<u8>],
    u_buf: &mut [Vec<u8>],
    u_computed: &mut [Vec<bool>],
    sub_chunk_size: usize,
) -> Result<(), ClayError> {
    let z_vec = get_plane_vector(z, params.t, params.q);
    // Nodes whose uncoupled value must come from the MDS decode: all
    // erased nodes, plus survivors we cannot pair-decouple yet.
    let mut needs_mds: BTreeSet<usize> = erased_chunks.clone();
    for x in 0..params.q {
        for y in 0..params.t {
            let node_xy = params.q * y + x;
            let z_y = z_vec[y];
            let node_sw = params.q * y + z_y;
            let z_sw = get_companion_layer(params, z, x, y, z_y);
            if !erased_chunks.contains(&node_xy) {
                if z_y == x {
                    // Diagonal node: uncoupled value equals the coupled one.
                    let offset = z * sub_chunk_size;
                    u_buf[node_xy][offset..offset + sub_chunk_size]
                        .copy_from_slice(&chunks[node_xy][offset..offset + sub_chunk_size]);
                    u_computed[node_xy][z] = true;
                } else if !erased_chunks.contains(&node_sw) {
                    // Both couple ends survive: decouple the pair once,
                    // from the z_y < x end only, to avoid doing it twice.
                    if z_y < x {
                        get_uncoupled_from_coupled(
                            params, chunks, u_buf, x, y, z, z_y, z_sw, sub_chunk_size,
                        );
                        u_computed[node_xy][z] = true;
                        u_computed[node_sw][z_sw] = true;
                    }
                } else {
                    // Companion is erased: usable only if its uncoupled
                    // value was already recovered (e.g. by an earlier,
                    // lower-score plane).
                    if u_computed[node_sw][z_sw] {
                        let offset_z = z * sub_chunk_size;
                        let offset_zsw = z_sw * sub_chunk_size;
                        let c_xy = &chunks[node_xy][offset_z..offset_z + sub_chunk_size];
                        let u_sw = &u_buf[node_sw][offset_zsw..offset_zsw + sub_chunk_size];
                        let u_xy = compute_u_from_c_and_ustar(c_xy, u_sw);
                        u_buf[node_xy][offset_z..offset_z + sub_chunk_size].copy_from_slice(&u_xy);
                        u_computed[node_xy][z] = true;
                    } else {
                        needs_mds.insert(node_xy);
                    }
                }
            }
        }
    }
    // MDS-recover every uncoupled value that could not be derived above.
    decode_uncoupled_layer(params, &needs_mds, z, sub_chunk_size, u_buf)?;
    for &node in &needs_mds {
        u_computed[node][z] = true;
    }
    Ok(())
}
/// Recovers the erased sub-chunks of plane `z` in the uncoupled layer
/// using the underlying Reed-Solomon (MDS) code.
///
/// If any original (data-position) shard is erased, a full RS
/// reconstruction is run; if only parity positions are missing, they are
/// re-encoded from the surviving originals instead.
///
/// # Errors
/// * `TooManyErasures` when more than `m` positions are erased.
/// * `ReconstructionFailed` when RS setup, reconstruction, or encoding fails.
pub fn decode_uncoupled_layer(
    params: &DecodeParams,
    erased_chunks: &BTreeSet<usize>,
    z: usize,
    sub_chunk_size: usize,
    u_buf: &mut [Vec<u8>],
) -> Result<(), ClayError> {
    let total_nodes = params.q * params.t;
    // Byte offset of plane z's sub-chunk within each node's buffer.
    let offset = z * sub_chunk_size;
    // Internal indices below this are "original" shards, at/above parity.
    let parity_start = params.original_count;
    if erased_chunks.len() > params.m {
        return Err(ClayError::TooManyErasures {
            max: params.m,
            actual: erased_chunks.len(),
        });
    }
    if erased_chunks.is_empty() {
        return Ok(());
    }
    let has_erased_originals = erased_chunks.iter().any(|&i| i < parity_start);
    let has_erased_parities = erased_chunks.iter().any(|&i| i >= parity_start);
    let rs = reed_solomon_erasure::ReedSolomon::<galois_8::Field>::new(
        params.original_count,
        params.recovery_count,
    )
    .map_err(|e| ClayError::ReconstructionFailed(format!("Layer {} RS init failed: {:?}", z, e)))?;
    if has_erased_originals {
        // Full reconstruction: present shards as Some(..), erased as None,
        // and let RS rebuild every missing shard (data and parity alike).
        let mut shards: Vec<Option<Vec<u8>>> = Vec::with_capacity(total_nodes);
        for i in 0..total_nodes {
            if erased_chunks.contains(&i) {
                shards.push(None);
            } else {
                shards.push(Some(u_buf[i][offset..offset + sub_chunk_size].to_vec()));
            }
        }
        rs.reconstruct(&mut shards).map_err(|e| {
            ClayError::ReconstructionFailed(format!("Layer {} RS reconstruct failed: {:?}", z, e))
        })?;
        // Copy the reconstructed shards back into the erased positions.
        for i in 0..total_nodes {
            if erased_chunks.contains(&i) {
                if let Some(ref data) = shards[i] {
                    u_buf[i][offset..offset + sub_chunk_size].copy_from_slice(data);
                }
            }
        }
    } else if has_erased_parities {
        // Only parity shards are missing: re-encoding from the intact
        // originals overwrites all parity positions with fresh values.
        let mut shards: Vec<Vec<u8>> = Vec::with_capacity(total_nodes);
        for i in 0..total_nodes {
            shards.push(u_buf[i][offset..offset + sub_chunk_size].to_vec());
        }
        rs.encode(&mut shards).map_err(|e| {
            ClayError::ReconstructionFailed(format!("Layer {} RS encode failed: {:?}", z, e))
        })?;
        // Only copy back the parity positions that were actually erased.
        for i in parity_start..total_nodes {
            if erased_chunks.contains(&i) {
                u_buf[i][offset..offset + sub_chunk_size].copy_from_slice(&shards[i]);
            }
        }
    }
    Ok(())
}
/// Returns the companion plane index `z_sw` for the node at column `x`,
/// row `y`, given plane `z` whose digit in position `y` is `z_y`.
///
/// Viewing `z` as a base-`q` number, `z_sw` is `z` with its `y`-th digit
/// replaced by `x`: `z_sw = (z + (x - z_y) * q^(t-1-y)) mod α`.
pub fn get_companion_layer(params: &DecodeParams, z: usize, x: usize, y: usize, z_y: usize) -> usize {
    debug_assert!(y < params.t, "y={} must be < t={}", y, params.t);
    debug_assert!(x < params.q, "x={} must be < q={}", x, params.q);
    debug_assert!(z_y < params.q, "z_y={} must be < q={}", z_y, params.q);
    debug_assert!(
        z < params.sub_chunk_no,
        "z={} must be < α={}",
        z,
        params.sub_chunk_no
    );
    // Signed arithmetic because `x - z_y` may be negative; `rem_euclid`
    // brings the result back into [0, α).
    let digit_weight = params.q.pow((params.t - 1 - y) as u32) as isize;
    let companion = (z as isize + (x as isize - z_y as isize) * digit_weight)
        .rem_euclid(params.sub_chunk_no as isize) as usize;
    debug_assert!(
        companion < params.sub_chunk_no,
        "z_sw out of bounds: {} >= {}",
        companion,
        params.sub_chunk_no
    );
    companion
}
/// Derives the uncoupled values of a pairing couple — node `(x, y)` in
/// plane `z` and its companion `(z_y, y)` in plane `z_sw` — from their
/// coupled chunk data, writing both results into `u_buf`.
fn get_uncoupled_from_coupled(
    params: &DecodeParams,
    chunks: &[Vec<u8>],
    u_buf: &mut [Vec<u8>],
    x: usize,
    y: usize,
    z: usize,
    z_y: usize,
    z_sw: usize,
    sub_chunk_size: usize,
) {
    let row_base = y * params.q;
    let node_a = row_base + x;
    let node_b = row_base + z_y;
    let range_a = z * sub_chunk_size..(z + 1) * sub_chunk_size;
    let range_b = z_sw * sub_chunk_size..(z_sw + 1) * sub_chunk_size;
    let c_a = &chunks[node_a][range_a.clone()];
    let c_b = &chunks[node_b][range_b.clone()];
    // The pairwise reverse transform takes operands in (smaller-x,
    // larger-x) order; swap around the call when this node is the larger.
    let (u_a, u_b) = if x < z_y {
        prt_compute_both(c_a, c_b)
    } else {
        let (second, first) = prt_compute_both(c_b, c_a);
        (first, second)
    };
    u_buf[node_a][range_a].copy_from_slice(&u_a);
    u_buf[node_b][range_b].copy_from_slice(&u_b);
}
/// Rebuilds the coupled chunk of an erased node `(x, y)` in plane `z`
/// whose pairing companion `(z_y, y)` survived, combining the erased
/// node's uncoupled value with the companion's coupled value.
fn recover_type1_erasure(
    params: &DecodeParams,
    chunks: &mut [Vec<u8>],
    u_buf: &[Vec<u8>],
    x: usize,
    y: usize,
    z: usize,
    z_y: usize,
    z_sw: usize,
    sub_chunk_size: usize,
) {
    let row_base = y * params.q;
    let erased = row_base + x;
    let companion = row_base + z_y;
    let range_z = z * sub_chunk_size..(z + 1) * sub_chunk_size;
    let range_sw = z_sw * sub_chunk_size..(z_sw + 1) * sub_chunk_size;
    let repaired = compute_c_from_u_and_cstar(
        &u_buf[erased][range_z.clone()],
        &chunks[companion][range_sw],
    );
    chunks[erased][range_z].copy_from_slice(&repaired);
}
/// Rebuilds the coupled chunks of BOTH nodes in a pairing couple — node
/// `(x, y)` in plane `z` and companion `(z_y, y)` in plane `z_sw` — from
/// their already-recovered uncoupled values.
fn get_coupled_from_uncoupled(
    params: &DecodeParams,
    chunks: &mut [Vec<u8>],
    u_buf: &[Vec<u8>],
    x: usize,
    y: usize,
    z: usize,
    z_y: usize,
    z_sw: usize,
    sub_chunk_size: usize,
) {
    let row_base = y * params.q;
    let node_a = row_base + x;
    let node_b = row_base + z_y;
    let range_a = z * sub_chunk_size..(z + 1) * sub_chunk_size;
    let range_b = z_sw * sub_chunk_size..(z_sw + 1) * sub_chunk_size;
    let u_a = &u_buf[node_a][range_a.clone()];
    let u_b = &u_buf[node_b][range_b.clone()];
    // The pairwise forward transform takes operands in (smaller-x,
    // larger-x) order; swap around the call when this node is the larger.
    let (c_a, c_b) = if x < z_y {
        pft_compute_both(u_a, u_b)
    } else {
        let (second, first) = pft_compute_both(u_b, u_a);
        (first, second)
    };
    chunks[node_a][range_a].copy_from_slice(&c_a);
    chunks[node_b][range_b].copy_from_slice(&c_b);
}
/// Fills `order[z]` with plane `z`'s intersection score: the number of
/// erased nodes that lie on the plane's diagonal, i.e. whose column
/// digit equals the plane's digit for that node's row.
fn set_planes_sequential_decoding_order(
    params: &DecodeParams,
    order: &mut [usize],
    erasures: &BTreeSet<usize>,
) {
    for z in 0..params.sub_chunk_no {
        let plane = get_plane_vector(z, params.t, params.q);
        order[z] = erasures
            .iter()
            .filter(|&&node| node % params.q == plane[node / params.q])
            .count();
    }
}
/// Returns the maximum intersection score: the number of distinct rows
/// (`y = node / q`) that contain at least one erased node.
fn get_max_iscore(params: &DecodeParams, erased_chunks: &BTreeSet<usize>) -> usize {
    let mut seen_rows = vec![false; params.t];
    let mut score = 0;
    for &node in erased_chunks {
        let row = node / params.q;
        // Count each row only the first time it is seen.
        if !std::mem::replace(&mut seen_rows[row], true) {
            score += 1;
        }
    }
    score
}
/// Computes the companion coupled value `c*` element-wise in GF(2^8) from
/// a helper's coupled value `c` and uncoupled value `u`:
/// `c*[i] = (u[i] + c[i]) * γ⁻¹`.
pub fn compute_cstar_from_c_and_u(c_helper: &[u8], u_helper: &[u8]) -> Vec<u8> {
    let gamma_inv = crate::transforms::gf_inv(GAMMA);
    let mut companion_c = Vec::with_capacity(c_helper.len());
    for i in 0..c_helper.len() {
        companion_c.push(gf_mul(gf_add(u_helper[i], c_helper[i]), gamma_inv));
    }
    companion_c
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small (k=4, m=2, q=2, t=3, α=8) parameter set for the tests below.
    fn test_params() -> DecodeParams {
        DecodeParams {
            k: 4,
            m: 2,
            n: 6,
            q: 2,
            t: 3,
            nu: 0,
            sub_chunk_no: 8,
            original_count: 4,
            recovery_count: 2,
        }
    }

    /// Every companion layer index must stay within [0, α) for all valid
    /// (z, x, y) combinations.
    // Fixed mis-encoded `¶ms` (HTML-entity damage of `&para`) back to
    // `&params` throughout this module; it did not compile as written.
    #[test]
    fn test_companion_layer_valid_range() {
        let params = test_params();
        for z in 0..params.sub_chunk_no {
            let z_vec = get_plane_vector(z, params.t, params.q);
            for y in 0..params.t {
                for x in 0..params.q {
                    let z_sw = get_companion_layer(&params, z, x, y, z_vec[y]);
                    assert!(
                        z_sw < params.sub_chunk_no,
                        "z_sw {} out of range for z={}, x={}, y={}",
                        z_sw,
                        z,
                        x,
                        y
                    );
                }
            }
        }
    }

    /// The max iscore counts distinct erased rows, not erased nodes.
    #[test]
    fn test_get_max_iscore() {
        let params = test_params();
        let empty: BTreeSet<usize> = BTreeSet::new();
        assert_eq!(get_max_iscore(&params, &empty), 0);
        let mut one: BTreeSet<usize> = BTreeSet::new();
        one.insert(0);
        assert_eq!(get_max_iscore(&params, &one), 1);
        // Nodes 0 and 1 share row y=0 (q=2), so the score stays 1.
        let mut two_same: BTreeSet<usize> = BTreeSet::new();
        two_same.insert(0);
        two_same.insert(1);
        assert_eq!(get_max_iscore(&params, &two_same), 1);
        // Nodes 0 and 2 are in different rows, so the score is 2.
        let mut two_diff: BTreeSet<usize> = BTreeSet::new();
        two_diff.insert(0);
        two_diff.insert(2);
        assert_eq!(get_max_iscore(&params, &two_diff), 2);
    }
}