kontor-crypto-core 0.2.0

Shared cryptographic primitives for Kontor PoR (prepare_file, encode, Merkle)
Documentation
//! Reed-Solomon erasure coding for Kontor PoR (31-byte symbols, multi-codeword).

use crate::config;
use crate::error::{CoreError, Result};
use reed_solomon_erasure::galois_8::ReedSolomon;

/// Encodes a file into 31-byte symbols using multi-codeword Reed-Solomon.
///
/// The input is cut into `config::CHUNK_SIZE_BYTES`-sized data symbols (the last
/// one zero-padded), which are grouped `config::DATA_SYMBOLS_PER_CODEWORD` at a
/// time into codewords. Each codeword is zero-padded up to
/// `config::TOTAL_SYMBOLS_PER_CODEWORD` symbols and the trailing parity symbols
/// are then filled in by the Reed-Solomon encoder.
///
/// The returned vector holds the codewords back to back: for every codeword the
/// data symbols come first, followed by its parity symbols.
///
/// # Errors
///
/// * [`CoreError::EmptyData`] if `data` is empty.
/// * [`CoreError::ErasureCoding`] if the Reed-Solomon codec cannot be constructed
///   or encoding a codeword fails.
pub fn encode_file_symbols(data: &[u8]) -> Result<Vec<Vec<u8>>> {
    if data.is_empty() {
        return Err(CoreError::EmptyData {
            operation: "encode_file_symbols".to_string(),
        });
    }

    let rs = ReedSolomon::new(
        config::DATA_SYMBOLS_PER_CODEWORD,
        config::PARITY_SYMBOLS_PER_CODEWORD,
    )
    .map_err(|e| CoreError::ErasureCoding {
        details: format!("Reed-Solomon setup failed: {e}"),
    })?;

    // Number of input bytes carried by the data symbols of one codeword.
    let bytes_per_codeword = config::DATA_SYMBOLS_PER_CODEWORD * config::CHUNK_SIZE_BYTES;

    let mut all_symbols = Vec::new();
    for codeword_bytes in data.chunks(bytes_per_codeword) {
        // Build the data symbols straight from the input slice so every byte is
        // copied exactly once; only the final symbol of the file can be short
        // and is zero-padded to the fixed symbol size.
        let mut codeword: Vec<Vec<u8>> = Vec::with_capacity(config::TOTAL_SYMBOLS_PER_CODEWORD);
        codeword.extend(codeword_bytes.chunks(config::CHUNK_SIZE_BYTES).map(|chunk| {
            let mut symbol = Vec::with_capacity(config::CHUNK_SIZE_BYTES);
            symbol.extend_from_slice(chunk);
            symbol.resize(config::CHUNK_SIZE_BYTES, 0);
            symbol
        }));

        // Zero-fill the rest of the codeword: this covers both the missing data
        // symbols of a short final codeword and the parity slots that the
        // encoder overwrites below.
        // NOTE(review): assumes TOTAL_SYMBOLS_PER_CODEWORD ==
        // DATA_SYMBOLS_PER_CODEWORD + PARITY_SYMBOLS_PER_CODEWORD; the encoder
        // rejects any other shard count.
        codeword.resize(
            config::TOTAL_SYMBOLS_PER_CODEWORD,
            vec![0; config::CHUNK_SIZE_BYTES],
        );

        rs.encode(&mut codeword)
            .map_err(|e| CoreError::ErasureCoding {
                details: format!("RS encode failed: {e}"),
            })?;
        all_symbols.extend(codeword);
    }

    Ok(all_symbols)
}

/// Reconstructs original file from erasure-coded symbols.
///
/// `symbols` must contain exactly
/// `num_codewords * config::TOTAL_SYMBOLS_PER_CODEWORD` entries, laid out
/// codeword after codeword. A missing symbol is passed as `None`; every present
/// symbol must be exactly `config::CHUNK_SIZE_BYTES` long.
///
/// Each codeword is reconstructed on a private copy, so the caller's `symbols`
/// slice is left unmodified. The data symbols of all codewords are concatenated
/// and the result is truncated to `original_size` bytes.
///
/// # Errors
///
/// * [`CoreError::InvalidInput`] if `num_codewords` is zero, the symbol count or
///   a symbol length is wrong, a size computation overflows, or `original_size`
///   is larger than the number of bytes the data symbols can hold.
/// * [`CoreError::ErasureCoding`] if the Reed-Solomon codec cannot be constructed
///   or a codeword cannot be reconstructed (e.g. too many symbols are missing).
pub fn decode_file_symbols(
    symbols: &mut [Option<Vec<u8>>],
    num_codewords: usize,
    original_size: usize,
) -> Result<Vec<u8>> {
    if num_codewords == 0 {
        return Err(CoreError::InvalidInput(
            "decode_file_symbols: num_codewords must be > 0".to_string(),
        ));
    }

    let expected_symbols = num_codewords
        .checked_mul(config::TOTAL_SYMBOLS_PER_CODEWORD)
        .ok_or_else(|| {
            CoreError::InvalidInput("decode_file_symbols: symbol count overflow".to_string())
        })?;
    if symbols.len() != expected_symbols {
        return Err(CoreError::InvalidInput(format!(
            "decode_file_symbols: symbol vector length {} does not match expected {}",
            symbols.len(),
            expected_symbols
        )));
    }

    for (i, sym) in symbols.iter().enumerate() {
        if let Some(bytes) = sym {
            if bytes.len() != config::CHUNK_SIZE_BYTES {
                return Err(CoreError::InvalidInput(format!(
                    "decode_file_symbols: symbol {} has invalid length {} (expected {})",
                    i,
                    bytes.len(),
                    config::CHUNK_SIZE_BYTES
                )));
            }
        }
    }

    // Only the data symbols of each codeword carry file bytes; parity symbols
    // are never copied into the output. Bounding `original_size` by the data
    // capacity (rather than by the size of all symbols) guarantees the returned
    // buffer is exactly `original_size` bytes long instead of silently short.
    let data_capacity = num_codewords
        .checked_mul(config::DATA_SYMBOLS_PER_CODEWORD)
        .and_then(|data_symbols| data_symbols.checked_mul(config::CHUNK_SIZE_BYTES))
        .ok_or_else(|| {
            CoreError::InvalidInput(
                "decode_file_symbols: encoded byte capacity overflow".to_string(),
            )
        })?;
    if original_size > data_capacity {
        return Err(CoreError::InvalidInput(format!(
            "decode_file_symbols: original_size {} exceeds data capacity {}",
            original_size, data_capacity
        )));
    }

    let rs = ReedSolomon::new(
        config::DATA_SYMBOLS_PER_CODEWORD,
        config::PARITY_SYMBOLS_PER_CODEWORD,
    )
    .map_err(|e| CoreError::ErasureCoding {
        details: format!("Reed-Solomon setup failed: {e}"),
    })?;

    let mut reconstructed = Vec::with_capacity(original_size);
    for cw_idx in 0..num_codewords {
        // Cannot overflow: `end` is at most `expected_symbols`, checked above.
        let start = cw_idx * config::TOTAL_SYMBOLS_PER_CODEWORD;
        let end = start + config::TOTAL_SYMBOLS_PER_CODEWORD;

        // Work on a copy so the caller's slice keeps its `None` markers.
        let mut codeword_symbols = symbols[start..end].to_vec();
        rs.reconstruct(&mut codeword_symbols)
            .map_err(|e| CoreError::ErasureCoding {
                details: format!("RS decode failed for codeword {}: {e}", cw_idx),
            })?;

        // The leading DATA_SYMBOLS_PER_CODEWORD entries are the data symbols.
        for sym in codeword_symbols
            .iter()
            .take(config::DATA_SYMBOLS_PER_CODEWORD)
            .flatten()
        {
            reconstructed.extend_from_slice(sym);
        }
    }

    // Drop the zero padding that encoding appended after the real file bytes.
    reconstructed.truncate(original_size);
    Ok(reconstructed)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Marks every encoded symbol as present by wrapping it in `Some`.
    fn all_present(symbols: Vec<Vec<u8>>) -> Vec<Option<Vec<u8>>> {
        symbols.into_iter().map(Some).collect()
    }

    /// Marks the first `count` symbols as missing.
    fn erase_leading(symbols: &mut [Option<Vec<u8>>], count: usize) {
        symbols.iter_mut().take(count).for_each(|slot| *slot = None);
    }

    /// A short input fits in one codeword and survives a lossless round trip.
    #[test]
    fn test_encode_decode_roundtrip_small() {
        let data = b"Hello, world! This is a test.";
        let encoded = encode_file_symbols(data).unwrap();
        assert_eq!(encoded.len(), 255);

        let mut full_symbols = all_present(encoded);
        let reconstructed = decode_file_symbols(&mut full_symbols, 1, data.len()).unwrap();
        assert_eq!(reconstructed, data);
    }

    /// Losing 20 symbols of a codeword is still recoverable.
    #[test]
    fn test_encode_decode_with_missing_symbols() {
        let data = b"Test data for reconstruction.";
        let mut damaged = all_present(encode_file_symbols(data).unwrap());
        erase_leading(&mut damaged, 20);

        let reconstructed = decode_file_symbols(&mut damaged, 1, data.len()).unwrap();
        assert_eq!(reconstructed, data);
    }

    /// Empty input is rejected by the encoder.
    #[test]
    fn test_empty_data() {
        let outcome = encode_file_symbols(b"");
        assert!(outcome.is_err());
    }

    /// A 15000-byte input spans three codewords; one loss per codeword is repaired.
    #[test]
    fn test_multi_codeword_file() {
        let data = vec![42u8; 15000];
        let encoded = encode_file_symbols(&data).unwrap();
        assert_eq!(encoded.len(), 765);

        let mut damaged = all_present(encoded);
        // First symbol of each of the three codewords.
        for lost in [0, 255, 510] {
            damaged[lost] = None;
        }

        let reconstructed = decode_file_symbols(&mut damaged, 3, data.len()).unwrap();
        assert_eq!(reconstructed, data);
    }

    /// Losing 25 symbols of a codeword makes decoding fail.
    #[test]
    fn test_too_many_missing_symbols() {
        let data = b"Test";
        let mut damaged = all_present(encode_file_symbols(data).unwrap());
        erase_leading(&mut damaged, 25);

        let outcome = decode_file_symbols(&mut damaged, 1, data.len());
        assert!(outcome.is_err());
    }

    /// A single byte still produces one full codeword and round-trips.
    #[test]
    fn test_single_byte() {
        let data = b"A";
        let encoded = encode_file_symbols(data).unwrap();
        assert_eq!(encoded.len(), 255);

        let mut full = all_present(encoded);
        let reconstructed = decode_file_symbols(&mut full, 1, data.len()).unwrap();
        assert_eq!(reconstructed, data);
    }
}