use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::fmt;
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::Path;
use rayon::prelude::*;
use tracing::{debug, info};
use crate::gf;
use crate::gf_simd;
use crate::matrix::{GfMatrix, par2_input_constants};
use crate::recovery::{RecoveryBlock, load_recovery_blocks};
use crate::types::{Par2FileSet, VerifyResult};
use crate::verify;
/// Outcome of a repair attempt that ran to completion.
#[derive(Debug)]
pub struct RepairResult {
    /// `true` when the repair finished (and, if requested, re-verified) cleanly.
    pub success: bool,
    /// Number of input blocks that were reconstructed and written back.
    pub blocks_repaired: u32,
    /// Number of distinct files that received at least one repaired block.
    pub files_repaired: usize,
    /// Human-readable summary of the outcome.
    pub message: String,
}

impl fmt::Display for RepairResult {
    /// Renders either the success summary (block/file counts) or the failure message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if !self.success {
            return write!(f, "Repair failed: {}", self.message);
        }
        let (blocks, files) = (self.blocks_repaired, self.files_repaired);
        write!(f, "Repair complete: {} blocks repaired across {} files", blocks, files)
    }
}
/// Reasons a repair attempt can fail.
#[derive(Debug, thiserror::Error)]
pub enum RepairError {
    /// A file read or write failed; converted automatically from `std::io::Error`.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Fewer recovery blocks were available than blocks that must be rebuilt.
    #[error("Insufficient recovery data: need {needed} blocks, have {available}")]
    InsufficientRecovery {
        /// How many blocks have to be reconstructed.
        needed: u32,
        /// How many recovery blocks were loaded.
        available: u32,
    },

    /// The decode matrix built from the chosen recovery blocks has no inverse.
    #[error("Decode matrix is singular — cannot repair with these recovery blocks")]
    SingularMatrix,

    /// The verification result reported nothing wrong, so there is nothing to do.
    #[error("No damage detected — nothing to repair")]
    NoDamage,

    /// Files still failed verification after repair; carries the rendered report.
    #[error("Verification after repair failed: {0}")]
    VerifyFailed(String),
}
/// Verifies `file_set` against the files in `dir`, repairs any damage found,
/// and re-verifies the result.
///
/// # Errors
/// `NoDamage` if verification finds nothing wrong; `InsufficientRecovery` or
/// `SingularMatrix` if the recovery data cannot reconstruct the damage; `Io`
/// on file errors; `VerifyFailed` if the repaired set still fails verification.
pub fn repair(file_set: &Par2FileSet, dir: &Path) -> Result<RepairResult, RepairError> {
    repair_from_verify(file_set, dir, &verify::verify(file_set, dir))
}

/// Repairs using an already-computed `verify_result`, then re-verifies.
///
/// # Errors
/// Same as [`repair`].
pub fn repair_from_verify(
    file_set: &Par2FileSet,
    dir: &Path,
    verify_result: &VerifyResult,
) -> Result<RepairResult, RepairError> {
    let re_verify = true;
    repair_from_verify_inner(file_set, dir, verify_result, re_verify)
}

/// Repairs using an already-computed `verify_result` and skips the final
/// verification pass (the caller takes responsibility for checking the result).
///
/// # Errors
/// Same as [`repair`], except `VerifyFailed` is never returned.
pub fn repair_from_verify_no_reverify(
    file_set: &Par2FileSet,
    dir: &Path,
    verify_result: &VerifyResult,
) -> Result<RepairResult, RepairError> {
    let re_verify = false;
    repair_from_verify_inner(file_set, dir, verify_result, re_verify)
}
fn repair_from_verify_inner(
file_set: &Par2FileSet,
dir: &Path,
verify_result: &VerifyResult,
re_verify: bool,
) -> Result<RepairResult, RepairError> {
if verify_result.all_correct() {
return Err(RepairError::NoDamage);
}
let blocks_needed = verify_result.blocks_needed();
info!(
blocks_needed,
damaged = verify_result.damaged.len(),
missing = verify_result.missing.len(),
"Repair: damage detected"
);
let recovery_blocks = load_recovery_blocks(dir, &file_set.recovery_set_id, file_set.slice_size);
if (recovery_blocks.len() as u32) < blocks_needed {
return Err(RepairError::InsufficientRecovery {
needed: blocks_needed,
available: recovery_blocks.len() as u32,
});
}
let block_map = build_block_map(file_set);
let total_input_blocks = block_map.total_blocks as usize;
let damaged_indices = find_damaged_block_indices(verify_result, &block_map);
let num_damaged = damaged_indices.len();
info!(
damaged_block_count = num_damaged,
total_input_blocks, "Mapped damaged blocks to global indices"
);
let recovery_to_use: Vec<&RecoveryBlock> = recovery_blocks.iter().take(num_damaged).collect();
let recovery_exponents: Vec<u32> = recovery_to_use.iter().map(|b| b.exponent).collect();
let constants = par2_input_constants(total_input_blocks);
let mut vandermonde = GfMatrix::zeros(num_damaged, num_damaged);
for (e, &exp) in recovery_exponents.iter().enumerate() {
for (j, &dmg_idx) in damaged_indices.iter().enumerate() {
vandermonde.set(e, j, gf::pow(constants[dmg_idx], exp));
}
}
let inverse = vandermonde.invert().ok_or(RepairError::SingularMatrix)?;
info!(
"D×D decode matrix inverted ({}×{})",
num_damaged, num_damaged
);
let slice_size = file_set.slice_size as usize;
let damaged_set: std::collections::HashSet<usize> = damaged_indices.iter().copied().collect();
let mut adjusted: Vec<Vec<u8>> = recovery_to_use.iter().map(|rb| rb.data.clone()).collect();
let intact_indices: Vec<usize> = (0..total_input_blocks)
.filter(|i| !damaged_set.contains(i))
.collect();
const BATCH_SIZE: usize = 24;
let mut file_handles: HashMap<String, std::fs::File> = HashMap::new();
for batch in intact_indices.chunks(BATCH_SIZE) {
let batch_data: Vec<Vec<u8>> = batch
.iter()
.map(|&idx| read_source_block(dir, &block_map, idx, slice_size, &mut file_handles))
.collect::<std::io::Result<Vec<_>>>()?;
let batch_refs: Vec<&[u8]> = batch_data.iter().map(|v| v.as_slice()).collect();
adjusted
.par_iter_mut()
.enumerate()
.for_each(|(e, adj_buf)| {
let coeffs: Vec<u16> = batch
.iter()
.map(|&src_idx| gf::pow(constants[src_idx], recovery_exponents[e]))
.collect();
gf_simd::mul_add_multi(adj_buf, &batch_refs, &coeffs);
});
}
info!("Intact-block contributions subtracted from recovery data");
let adj_refs: Vec<&[u8]> = adjusted.iter().map(|v| v.as_slice()).collect();
let mut outputs: Vec<Vec<u8>> = (0..num_damaged).map(|_| vec![0u8; slice_size]).collect();
outputs.par_iter_mut().enumerate().for_each(|(j, dst)| {
let coeffs: Vec<u16> = (0..num_damaged).map(|e| inverse.get(j, e)).collect();
gf_simd::mul_add_multi(dst, &adj_refs, &coeffs);
});
info!("Repaired blocks reconstructed via D×D inverse");
let repaired_blocks: Vec<(usize, Vec<u8>)> =
damaged_indices.iter().copied().zip(outputs).collect();
let mut files_touched = std::collections::HashSet::new();
for (global_idx, data) in &repaired_blocks {
let (filename, file_offset, write_len) = block_map.global_to_file(*global_idx, slice_size);
let file_path = dir.join(&filename);
debug!(
filename,
global_block = global_idx,
offset = file_offset,
len = write_len,
"Writing repaired block"
);
let mut f = std::fs::OpenOptions::new()
.create(true)
.truncate(false)
.write(true)
.open(&file_path)?;
let expected_size = block_map
.files
.iter()
.find(|bf| bf.filename == filename)
.map(|bf| bf.file_size)
.unwrap_or(0);
let current_size = f.metadata()?.len();
if current_size < expected_size {
f.set_len(expected_size)?;
}
f.seek(SeekFrom::Start(file_offset as u64))?;
f.write_all(&data[..write_len])?;
files_touched.insert(filename.clone());
}
if re_verify {
let verification = verify::verify(file_set, dir);
if verification.all_correct() {
info!(
blocks = repaired_blocks.len(),
files = files_touched.len(),
"Repair successful — all files verified"
);
Ok(RepairResult {
success: true,
blocks_repaired: repaired_blocks.len() as u32,
files_repaired: files_touched.len(),
message: "All files repaired and verified".to_string(),
})
} else {
Err(RepairError::VerifyFailed(format!("{verification}")))
}
} else {
info!(
blocks = repaired_blocks.len(),
files = files_touched.len(),
"Repair complete (re-verify skipped)"
);
Ok(RepairResult {
success: true,
blocks_repaired: repaired_blocks.len() as u32,
files_repaired: files_touched.len(),
message: "All files repaired (re-verify skipped)".to_string(),
})
}
}
/// Layout of the recovery set's source files in the global block index space.
///
/// Files are laid out back to back: the file at position `k` occupies global
/// indices `start_block..start_block + block_count`.
struct BlockMap {
    /// Source files, in the order their block ranges were assigned.
    files: Vec<BlockFile>,
    /// Total number of input blocks across all files.
    total_blocks: u32,
}

/// One source file's slot in a [`BlockMap`].
struct BlockFile {
    /// File name, relative to the repair directory.
    filename: String,
    /// Size of the file in bytes, as recorded in the recovery set.
    file_size: u64,
    /// Number of slices the file spans (the last may be partial).
    block_count: u32,
    /// Global index of the file's first slice.
    start_block: u32,
}
fn build_block_map(file_set: &Par2FileSet) -> BlockMap {
let slice_size = file_set.slice_size;
let mut files = Vec::new();
let mut block_offset = 0u32;
let mut sorted_files: Vec<_> = file_set.files.values().collect();
sorted_files.sort_by_key(|f| f.file_id);
for f in sorted_files {
let block_count = if slice_size == 0 {
0
} else {
f.size.div_ceil(slice_size) as u32
};
files.push(BlockFile {
filename: f.filename.clone(),
file_size: f.size,
block_count,
start_block: block_offset,
});
block_offset += block_count;
}
BlockMap {
files,
total_blocks: block_offset,
}
}
impl BlockMap {
    /// Resolves a global block index to `(filename, byte_offset, valid_len)`.
    ///
    /// `byte_offset` is where the block starts inside the file. `valid_len` is
    /// the number of real file bytes in the block: `slice_size` for every block
    /// except possibly the file's last one, which may be shorter.
    ///
    /// The comparison is done in `usize`. The previous `as u32` cast silently
    /// truncated large indices, so they wrapped around to the wrong file instead
    /// of being rejected.
    ///
    /// # Panics
    /// Panics if `global_idx` does not fall inside any file's block range.
    fn global_to_file(&self, global_idx: usize, slice_size: usize) -> (String, usize, usize) {
        let located = self.files.iter().find_map(|f| {
            let local = global_idx.checked_sub(f.start_block as usize)?;
            (local < f.block_count as usize).then_some((f, local))
        });
        let Some((f, local_block)) = located else {
            panic!("Global block index {} out of range", global_idx);
        };
        let file_offset = local_block * slice_size;
        // local_block < block_count = ceil(file_size / slice_size), so this cannot underflow.
        let remaining = f.file_size as usize - file_offset;
        (f.filename.clone(), file_offset, remaining.min(slice_size))
    }
}
fn find_damaged_block_indices(verify_result: &VerifyResult, block_map: &BlockMap) -> Vec<usize> {
let mut indices = Vec::new();
for damaged in &verify_result.damaged {
if let Some(bf) = block_map
.files
.iter()
.find(|f| f.filename == damaged.filename)
{
if damaged.damaged_block_indices.is_empty() {
for i in 0..bf.block_count {
indices.push((bf.start_block + i) as usize);
}
} else {
for &local_idx in &damaged.damaged_block_indices {
indices.push((bf.start_block + local_idx) as usize);
}
}
}
}
for missing in &verify_result.missing {
if let Some(bf) = block_map
.files
.iter()
.find(|f| f.filename == missing.filename)
{
for i in 0..bf.block_count {
indices.push((bf.start_block + i) as usize);
}
}
}
indices.sort();
indices.dedup();
indices
}
/// Reads one input block (global index `global_idx`) from its source file.
///
/// Handles are cached in `file_handles` (keyed by filename) so each file is
/// opened at most once across calls. The returned buffer is always
/// `slice_size` bytes long. Only the block's valid prefix (the file's real
/// bytes, as reported by `BlockMap::global_to_file`) is read; the remainder
/// stays zero. This matches PAR2's zero padding of a file's final slice even
/// when the file on disk is longer than its recorded size. A short read
/// (truncated file) likewise leaves the unread tail zeroed.
///
/// # Errors
/// Propagates any open/seek/read error other than `Interrupted`, which is retried.
fn read_source_block(
    dir: &Path,
    block_map: &BlockMap,
    global_idx: usize,
    slice_size: usize,
    file_handles: &mut HashMap<String, std::fs::File>,
) -> std::io::Result<Vec<u8>> {
    let (filename, file_offset, valid_len) = block_map.global_to_file(global_idx, slice_size);
    let handle = match file_handles.entry(filename) {
        Entry::Occupied(e) => e.into_mut(),
        Entry::Vacant(e) => {
            let file = std::fs::File::open(dir.join(e.key()))?;
            e.insert(file)
        }
    };
    handle.seek(SeekFrom::Start(file_offset as u64))?;

    let mut buf = vec![0u8; slice_size];
    let mut total = 0;
    while total < valid_len {
        match handle.read(&mut buf[total..valid_len]) {
            Ok(0) => break,
            Ok(n) => total += n,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
    Ok(buf)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes two inputs into two recovery slices, treats both inputs as lost,
    /// and reconstructs them through the inverted Vandermonde matrix using only
    /// scalar `gf` arithmetic (no `gf_simd`).
    #[test]
    fn test_rs_roundtrip_simple() {
        // Two 4-byte input slices, each read below as two little-endian u16 words.
        let input0: Vec<u8> = vec![0x01, 0x00, 0x02, 0x00]; let input1: Vec<u8> = vec![0x03, 0x00, 0x04, 0x00];
        let input_count = 2;
        let recovery_exponents = vec![0u32, 1u32];
        let enc = GfMatrix::par2_encoding_matrix(input_count, &recovery_exponents);
        let slice_size = 4;
        let u16_per_slice = slice_size / 2;
        let inputs = [&input0, &input1];
        let mut recovery0 = vec![0u8; slice_size];
        let mut recovery1 = vec![0u8; slice_size];
        // Encode word by word. Rows `input_count + e` of the encoding matrix are
        // used as the recovery rows (here rows 2 and 3 ↔ exponents 0 and 1).
        for pos in 0..u16_per_slice {
            let off = pos * 2;
            let mut r0: u16 = 0;
            let mut r1: u16 = 0;
            for (i, inp) in inputs.iter().enumerate() {
                let val = u16::from_le_bytes([inp[off], inp[off + 1]]);
                r0 = gf::add(r0, gf::mul(enc.get(2, i), val));
                r1 = gf::add(r1, gf::mul(enc.get(3, i), val));
            }
            // Store results back as little-endian words.
            recovery0[off] = r0 as u8;
            recovery0[off + 1] = (r0 >> 8) as u8;
            recovery1[off] = r1 as u8;
            recovery1[off + 1] = (r1 >> 8) as u8;
        }
        // Build the decode matrix exactly as `repair_from_verify_inner` does:
        // entry (e, j) = constant[damaged j] ^ exponent[e].
        let constants = par2_input_constants(input_count);
        let damaged_indices = [0usize, 1usize];
        let num_damaged = damaged_indices.len();
        let mut vandermonde = GfMatrix::zeros(num_damaged, num_damaged);
        for (e, &exp) in recovery_exponents.iter().enumerate() {
            for (j, &dmg_idx) in damaged_indices.iter().enumerate() {
                vandermonde.set(e, j, gf::pow(constants[dmg_idx], exp));
            }
        }
        let inv = vandermonde.invert().expect("Should be invertible");
        // Every input is damaged, so there are no intact contributions to remove:
        // the recovery data is used unadjusted.
        let adjusted = [&recovery0[..], &recovery1[..]];
        let mut result0 = vec![0u8; slice_size];
        let mut result1 = vec![0u8; slice_size];
        // Decode word by word: output j = Σ_e inv[j][e] · adjusted[e].
        for pos in 0..u16_per_slice {
            let off = pos * 2;
            let mut out0: u16 = 0;
            let mut out1: u16 = 0;
            for (e, adj) in adjusted.iter().enumerate() {
                let val = u16::from_le_bytes([adj[off], adj[off + 1]]);
                out0 = gf::add(out0, gf::mul(inv.get(0, e), val));
                out1 = gf::add(out1, gf::mul(inv.get(1, e), val));
            }
            result0[off] = out0 as u8;
            result0[off + 1] = (out0 >> 8) as u8;
            result1[off] = out1 as u8;
            result1[off + 1] = (out1 >> 8) as u8;
        }
        assert_eq!(result0, input0, "Recovered block 0 should match original");
        assert_eq!(result1, input1, "Recovered block 1 should match original");
    }

    /// Encodes four inputs, treats inputs 1 and 3 as damaged while 0 and 2 stay
    /// intact, and recovers the damaged ones through the same
    /// subtract-intact-then-invert path as `repair_from_verify_inner`, using the
    /// `gf_simd` buffer helpers.
    #[test]
    fn test_rs_roundtrip_partial_damage() {
        let slice_size = 4;
        let input_count = 4;
        let recovery_exponents = vec![0u32, 1u32];
        let inputs: Vec<Vec<u8>> = vec![
            vec![0x01, 0x00, 0x02, 0x00],
            vec![0x03, 0x00, 0x04, 0x00],
            vec![0x05, 0x00, 0x06, 0x00],
            vec![0x07, 0x00, 0x08, 0x00],
        ];
        let enc = GfMatrix::par2_encoding_matrix(input_count, &recovery_exponents);
        // Scalar reference encoding of the two recovery slices.
        let mut recovery = vec![vec![0u8; slice_size]; 2];
        for pos in 0..(slice_size / 2) {
            let off = pos * 2;
            for (e, rec) in recovery.iter_mut().enumerate() {
                let mut val: u16 = 0;
                for (i, inp) in inputs.iter().enumerate() {
                    let d = u16::from_le_bytes([inp[off], inp[off + 1]]);
                    val = gf::add(val, gf::mul(enc.get(input_count + e, i), d));
                }
                rec[off] = val as u8;
                rec[off + 1] = (val >> 8) as u8;
            }
        }
        let damaged_indices = [1usize, 3usize];
        let intact_indices: Vec<usize> = (0..input_count)
            .filter(|i| !damaged_indices.contains(i))
            .collect();
        let num_damaged = damaged_indices.len();
        let constants = par2_input_constants(input_count);
        let mut vandermonde = GfMatrix::zeros(num_damaged, num_damaged);
        for (e, &exp) in recovery_exponents.iter().enumerate() {
            for (j, &dmg_idx) in damaged_indices.iter().enumerate() {
                vandermonde.set(e, j, gf::pow(constants[dmg_idx], exp));
            }
        }
        let inv = vandermonde.invert().expect("Should be invertible");
        // Remove each intact block's contribution from every recovery slice
        // (GF addition is its own inverse, so multiply-add cancels it).
        let mut adjusted = recovery.clone();
        for &intact_idx in &intact_indices {
            let c_i = constants[intact_idx];
            for (e, adj) in adjusted.iter_mut().enumerate() {
                let coeff = gf::pow(c_i, recovery_exponents[e]);
                gf_simd::mul_add_buffer(adj, &inputs[intact_idx], coeff);
            }
        }
        // Decode: output j = Σ_e inv[j][e] · adjusted[e].
        let adj_refs: Vec<&[u8]> = adjusted.iter().map(|v| v.as_slice()).collect();
        let mut outputs: Vec<Vec<u8>> = (0..num_damaged).map(|_| vec![0u8; slice_size]).collect();
        for (j, dst) in outputs.iter_mut().enumerate() {
            let coeffs: Vec<u16> = (0..num_damaged).map(|e| inv.get(j, e)).collect();
            gf_simd::mul_add_multi(dst, &adj_refs, &coeffs);
        }
        assert_eq!(outputs[0], inputs[1], "Recovered block 1 should match");
        assert_eq!(outputs[1], inputs[3], "Recovered block 3 should match");
    }
}