use reed_solomon_simd::ReedSolomonEncoder;
use crate::CryptoError;
/// Computes the length of an encoded buffer for `original_size` input bytes.
///
/// The layout is a single padding-flag byte followed by three shards, each
/// `original_size` rounded up to the next even number.
pub fn rs_encoded_size(original_size: usize) -> usize {
    let even_shard_len = original_size + original_size % 2;
    even_shard_len * 3 + 1
}
/// Encodes `data` with a 1-original / 2-recovery Reed-Solomon configuration.
///
/// The output layout is:
/// 1. one padding-flag byte: `1` if a zero byte was appended to make the shard
///    length even, otherwise `0`;
/// 2. the (padded) original shard;
/// 3. recovery shard 0;
/// 4. recovery shard 1.
///
/// [`rs_encoded_size`] computes the expected total length.
///
/// Empty input encodes to the single flag byte `[0]`, which [`rs_decode`] turns
/// back into an empty vector.
///
/// # Errors
///
/// Propagates encoder errors from `reed_solomon_simd` (converted into
/// [`CryptoError`] via `?`). Returns [`CryptoError::Message`] if the encoder does
/// not yield both recovery shards.
pub fn rs_encode(data: &[u8]) -> Result<Vec<u8>, CryptoError> {
    // The encoder rejects a shard size of zero, so empty input is handled here.
    // The result agrees with `rs_encoded_size(0) == 1` and with `rs_decode`.
    if data.is_empty() {
        return Ok(vec![0]);
    }

    // The encoder needs an even shard size: append one zero byte when the input
    // length is odd, and record that in the leading flag byte.
    let padding_byte = u8::from(data.len() % 2 != 0);
    let shard_bytes = data.len() + usize::from(padding_byte);
    let mut shard = Vec::with_capacity(shard_bytes);
    shard.extend_from_slice(data);
    shard.resize(shard_bytes, 0);

    let mut encoder = ReedSolomonEncoder::new(1, 2, shard_bytes)?;
    encoder.add_original_shard(&shard)?;
    let result = encoder.encode()?;
    let recovery_0 = result
        .recovery(0)
        .ok_or_else(|| CryptoError::Message("Missing recovery shard 0".to_string()))?;
    let recovery_1 = result
        .recovery(1)
        .ok_or_else(|| CryptoError::Message("Missing recovery shard 1".to_string()))?;

    let mut output = Vec::with_capacity(1 + 3 * shard_bytes);
    output.push(padding_byte);
    output.extend_from_slice(&shard);
    output.extend_from_slice(recovery_0);
    output.extend_from_slice(recovery_1);
    Ok(output)
}
/// Decodes a buffer produced by [`rs_encode`]. Byte-level corruption is repaired
/// by a per-position majority vote across the three stored shards.
///
/// Expected layout: one padding-flag byte, then three equally sized shards
/// (original, recovery 0, recovery 1).
/// - At each byte position, a value shared by at least two shards wins.
/// - If all three shards disagree, the original shard's byte is kept.
/// - If the flag byte equals `1`, the final (padding) byte is dropped.
///
/// NOTE(review): the vote compares recovery bytes directly with original bytes.
/// That is only meaningful if the encoder's recovery shards are byte-for-byte
/// copies of the original shard in this 1-original/2-recovery setup — confirm
/// against reed-solomon-simd. The flag byte itself is not protected by the vote.
///
/// # Errors
///
/// Returns [`CryptoError::Message`] in two cases:
/// - `data` is empty;
/// - the bytes after the flag cannot be split into three equal shards.
pub fn rs_decode(data: &[u8]) -> Result<Vec<u8>, CryptoError> {
    let (&padding_byte, remaining) = data
        .split_first()
        .ok_or_else(|| CryptoError::Message("Empty data for decoding".to_string()))?;
    if remaining.len() % 3 != 0 {
        return Err(CryptoError::Message(
            "Incorrect encoded bytes length".to_string(),
        ));
    }

    let shard_bytes = remaining.len() / 3;
    let (original_shard, recovery) = remaining.split_at(shard_bytes);
    let (recovery_shard_0, recovery_shard_1) = recovery.split_at(shard_bytes);

    // With three votes, at most one value can reach a majority of two.
    // - If the recovery shards agree, they hold the majority (possibly together
    //   with the original).
    // - Otherwise, either the original matches one of them or all three differ;
    //   in both cases the original byte is the answer.
    // This avoids building a frequency map for every byte.
    let mut result: Vec<u8> = original_shard
        .iter()
        .zip(recovery_shard_0)
        .zip(recovery_shard_1)
        .map(|((&orig, &rec0), &rec1)| if rec0 == rec1 { rec0 } else { orig })
        .collect();

    // The encoder appends a single zero byte to odd-length input and sets the
    // flag to 1; drop it again (a no-op if `result` is empty).
    if padding_byte == 1 {
        result.pop();
    }
    Ok(result)
}
/// Appends PKCS#7 padding to `data` for the given `block_size`.
///
/// Between 1 and `block_size` bytes are always appended, each holding the
/// padding length. A full extra block is added when `data.len()` is already a
/// multiple of `block_size`.
///
/// # Panics
///
/// Panics if `block_size` is not in `1..=255`. PKCS#7 stores the padding length
/// in a single byte, so a larger block size would silently truncate it.
// Private helper with no non-test caller visible here; keep it without lint noise.
#[allow(dead_code)]
fn pad_pkcs7(data: &[u8], block_size: usize) -> Vec<u8> {
    assert!(
        (1..=255).contains(&block_size),
        "PKCS#7 block size must be in 1..=255, got {}",
        block_size
    );
    let padding_size = block_size - data.len() % block_size;
    // Cannot truncate: 1 <= padding_size <= block_size <= 255.
    let padding_byte = padding_size as u8;
    let mut padded = Vec::with_capacity(data.len() + padding_size);
    padded.extend_from_slice(data);
    padded.resize(data.len() + padding_size, padding_byte);
    padded
}
/// Strips PKCS#7-style padding from `data`.
///
/// The last byte is read as the padding length, and that many bytes are removed
/// from the end.
/// - The padding bytes themselves are not validated.
/// - A length larger than `data` yields an empty vector.
/// - Empty input returns an empty vector instead of panicking.
// Private helper with no non-test caller visible here; keep it without lint noise.
#[allow(dead_code)]
fn unpad_pkcs7(data: &[u8]) -> Vec<u8> {
    let padding_size = data.last().map_or(0, |&last| usize::from(last));
    let keep = data.len().saturating_sub(padding_size);
    data[..keep].to_vec()
}
#[cfg(test)]
mod tests {
    use super::{pad_pkcs7, rs_decode, rs_encode, unpad_pkcs7};

    /// Encodes a 32-byte buffer, zeroes several bytes of the encoded form, and
    /// checks that decoding still returns the original bytes.
    #[test]
    fn encode_reconstruct_test() {
        let original: [u8; 32] = [
            1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
            0, 1, 2,
        ];
        let mut encoded = rs_encode(&original).unwrap();
        println!("encoded_salt_32.len(): {}", &encoded.len());
        println!("encoded_salt_32: {:?}", &encoded);

        // Offsets into the encoding of a 32-byte input:
        // - 0 is the padding flag (already 0 for even-length input);
        // - 33..65 is recovery shard 0;
        // - 65..97 is recovery shard 1.
        for &offset in &[0usize, 35, 40, 65, 90] {
            encoded[offset] = 0;
        }

        let decoded = rs_decode(&encoded).unwrap();
        println!("{:?}", &original);
        println!("{:?}", &decoded);
        assert_eq!(original.to_vec(), decoded);
    }

    /// Round-trips a partial block and an exactly full block through PKCS#7
    /// padding with a 16-byte block size.
    #[test]
    fn pkcs_padding_unpadding() {
        fn dump(padded: &[u8], original: &[u8], unpadded: &[u8]) {
            println!("{:?}", padded);
            println!("{:?}", original);
            println!("{:?}", unpadded);
        }

        let partial: [u8; 12] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2];
        let full: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6];

        let partial_padded = pad_pkcs7(&partial, 16);
        let full_padded = pad_pkcs7(&full, 16);
        let partial_unpadded = unpad_pkcs7(&partial_padded);
        let full_unpadded = unpad_pkcs7(&full_padded);

        dump(&partial_padded, &partial, &partial_unpadded);
        println!();
        dump(&full_padded, &full, &full_unpadded);

        assert_eq!(&partial[..], partial_unpadded.as_slice());
        assert_eq!(&full[..], full_unpadded.as_slice());
    }
}