use crate::{
bls12_381_const::{
FP_LENGTH, FP_PAD_BY, G1_LENGTH, PADDED_FP_LENGTH, PADDED_G1_LENGTH, PADDED_G2_LENGTH,
},
PrecompileHalt,
};
/// Strips the zero prefix from a single padded field element.
///
/// `input` must be exactly `PADDED_FP_LENGTH` bytes. Its first `FP_PAD_BY` bytes are
/// padding and must all be zero. The remaining bytes are returned as a borrowed
/// `FP_LENGTH`-byte array without copying.
///
/// # Errors
///
/// - [`PrecompileHalt::Bls12381FpPaddingLength`] if `input.len() != PADDED_FP_LENGTH`.
/// - [`PrecompileHalt::Bls12381FpPaddingInvalid`] if any of the padding bytes is non-zero.
pub(super) fn remove_fp_padding(input: &[u8]) -> Result<&[u8; FP_LENGTH], PrecompileHalt> {
    if input.len() != PADDED_FP_LENGTH {
        return Err(PrecompileHalt::Bls12381FpPaddingLength);
    }

    let (prefix, value) = input.split_at(FP_PAD_BY);
    if prefix.iter().any(|&byte| byte != 0) {
        return Err(PrecompileHalt::Bls12381FpPaddingInvalid);
    }

    // The length check above fixes `value` at `PADDED_FP_LENGTH - FP_PAD_BY` bytes;
    // this conversion relies on that being `FP_LENGTH`.
    Ok(value.try_into().unwrap())
}
/// Splits a padded G1 point into its two unpadded field elements.
///
/// `input` must be exactly `PADDED_G1_LENGTH` bytes: two consecutive padded field
/// elements of `PADDED_FP_LENGTH` bytes each. They are returned in input order as
/// `[x, y]`, each borrowed from `input`.
///
/// # Errors
///
/// - [`PrecompileHalt::Bls12381G1PaddingLength`] if `input.len() != PADDED_G1_LENGTH`.
/// - Any error from [`remove_fp_padding`] for the first element. If the first element
///   is fine, any error for the second element.
pub(super) fn remove_g1_padding(input: &[u8]) -> Result<[&[u8; FP_LENGTH]; 2], PrecompileHalt> {
    if input.len() != PADDED_G1_LENGTH {
        return Err(PrecompileHalt::Bls12381G1PaddingLength);
    }

    let (x_padded, y_padded) = input.split_at(PADDED_FP_LENGTH);
    // Array elements are evaluated left to right, so `x` is validated before `y`.
    Ok([remove_fp_padding(x_padded)?, remove_fp_padding(y_padded)?])
}
/// Splits a padded G2 point into its four unpadded field elements.
///
/// `input` must be exactly `PADDED_G2_LENGTH` bytes, made of consecutive padded field
/// elements of `PADDED_FP_LENGTH` bytes each. The four unpadded elements are returned in
/// input order, each borrowed from `input`.
///
/// # Errors
///
/// - [`PrecompileHalt::Bls12381G2PaddingLength`] if `input.len() != PADDED_G2_LENGTH`.
/// - The first error returned by [`remove_fp_padding`] while scanning the elements in order.
pub(super) fn remove_g2_padding(input: &[u8]) -> Result<[&[u8; FP_LENGTH]; 4], PrecompileHalt> {
    if input.len() != PADDED_G2_LENGTH {
        return Err(PrecompileHalt::Bls12381G2PaddingLength);
    }

    // Placeholder values; every slot is overwritten below before the array is returned.
    let mut input_fps = [&[0; FP_LENGTH]; 4];
    // Pair each output slot with its padded chunk instead of computing
    // `i * PADDED_FP_LENGTH..(i + 1) * PADDED_FP_LENGTH` by hand.
    for (fp, padded_fp) in input_fps
        .iter_mut()
        .zip(input.chunks_exact(PADDED_FP_LENGTH))
    {
        *fp = remove_fp_padding(padded_fp)?;
    }
    Ok(input_fps)
}
/// Expands an unpadded G1 point into its padded encoding.
///
/// Each `FP_LENGTH`-byte field element of `unpadded` is placed in the last `FP_LENGTH`
/// bytes of a `PADDED_FP_LENGTH` slot. The leading `FP_PAD_BY` bytes of each slot stay zero.
///
/// # Panics
///
/// Panics if `unpadded.len() != G1_LENGTH`.
pub(super) fn pad_g1_point(unpadded: &[u8]) -> [u8; PADDED_G1_LENGTH] {
    assert_eq!(unpadded.len(), G1_LENGTH, "Invalid unpadded G1 point length");

    let mut padded = [0u8; PADDED_G1_LENGTH];
    padded
        .chunks_exact_mut(PADDED_FP_LENGTH)
        .zip(unpadded.chunks_exact(FP_LENGTH))
        .for_each(|(slot, fp)| slot[FP_PAD_BY..].copy_from_slice(fp));
    padded
}
/// Expands an unpadded G2 point (four field elements) into its padded encoding.
///
/// Each `FP_LENGTH`-byte field element of `unpadded` is copied to the tail of its own
/// `PADDED_FP_LENGTH`-byte slot. The leading `FP_PAD_BY` bytes of every slot are left zero.
///
/// # Panics
///
/// Panics if `unpadded.len() != 4 * FP_LENGTH`.
pub(super) fn pad_g2_point(unpadded: &[u8]) -> [u8; PADDED_G2_LENGTH] {
    assert_eq!(unpadded.len(), 4 * FP_LENGTH, "Invalid unpadded G2 point length");

    let mut padded = [0u8; PADDED_G2_LENGTH];
    let source_fps = unpadded.chunks_exact(FP_LENGTH);
    let target_slots = padded.chunks_exact_mut(PADDED_FP_LENGTH);
    for (slot, fp) in target_slots.zip(source_fps) {
        // The zeroed prefix stays as padding; only the tail receives the element bytes.
        let (_zero_prefix, tail) = slot.split_at_mut(FP_PAD_BY);
        tail.copy_from_slice(fp);
    }
    padded
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fills `bytes` with a deterministic non-zero pattern, so that a misplaced copy in
    /// the pad/unpad helpers shows up as a mismatch.
    fn fill_pattern(bytes: &mut [u8]) {
        for (i, byte) in bytes.iter_mut().enumerate() {
            *byte = (i * 2 + 1) as u8;
        }
    }

    #[test]
    fn test_pad_g1_point_roundtrip() {
        let mut unpadded = [0u8; G1_LENGTH];
        fill_pattern(&mut unpadded);
        let padded = pad_g1_point(&unpadded);
        let result = remove_g1_padding(&padded).unwrap();
        assert_eq!(result[0], &unpadded[0..FP_LENGTH]);
        assert_eq!(result[1], &unpadded[FP_LENGTH..G1_LENGTH]);
    }

    #[test]
    fn test_pad_g2_point_roundtrip() {
        let mut unpadded = [0u8; 4 * FP_LENGTH];
        fill_pattern(&mut unpadded);
        let padded = pad_g2_point(&unpadded);
        let result = remove_g2_padding(&padded).unwrap();
        assert_eq!(result[0], &unpadded[0..FP_LENGTH]);
        assert_eq!(result[1], &unpadded[FP_LENGTH..2 * FP_LENGTH]);
        assert_eq!(result[2], &unpadded[2 * FP_LENGTH..3 * FP_LENGTH]);
        assert_eq!(result[3], &unpadded[3 * FP_LENGTH..4 * FP_LENGTH]);
    }

    #[test]
    fn test_remove_fp_padding_rejects_wrong_length() {
        let short = [0u8; PADDED_FP_LENGTH - 1];
        let long = [0u8; PADDED_FP_LENGTH + 1];
        assert!(matches!(
            remove_fp_padding(&short),
            Err(PrecompileHalt::Bls12381FpPaddingLength)
        ));
        assert!(matches!(
            remove_fp_padding(&long),
            Err(PrecompileHalt::Bls12381FpPaddingLength)
        ));
    }

    #[test]
    fn test_remove_fp_padding_rejects_nonzero_padding() {
        // First and last byte of the padding prefix.
        for idx in [0, FP_PAD_BY - 1] {
            let mut input = [0u8; PADDED_FP_LENGTH];
            input[idx] = 1;
            assert!(matches!(
                remove_fp_padding(&input),
                Err(PrecompileHalt::Bls12381FpPaddingInvalid)
            ));
        }
    }

    #[test]
    fn test_remove_fp_padding_accepts_value_right_after_padding() {
        let mut input = [0u8; PADDED_FP_LENGTH];
        input[FP_PAD_BY] = 0xAB;
        let value = remove_fp_padding(&input).unwrap();
        assert_eq!(value[0], 0xAB);
        assert!(value[1..].iter().all(|&b| b == 0));
    }

    #[test]
    fn test_remove_g1_padding_rejects_wrong_length() {
        let input = [0u8; PADDED_G1_LENGTH - 1];
        assert!(matches!(
            remove_g1_padding(&input),
            Err(PrecompileHalt::Bls12381G1PaddingLength)
        ));
    }

    #[test]
    fn test_remove_g1_padding_rejects_bad_second_element() {
        let mut input = [0u8; PADDED_G1_LENGTH];
        // First padding byte of the second field element.
        input[PADDED_FP_LENGTH] = 1;
        assert!(matches!(
            remove_g1_padding(&input),
            Err(PrecompileHalt::Bls12381FpPaddingInvalid)
        ));
    }

    #[test]
    fn test_remove_g2_padding_rejects_wrong_length() {
        let input = [0u8; PADDED_G2_LENGTH + 1];
        assert!(matches!(
            remove_g2_padding(&input),
            Err(PrecompileHalt::Bls12381G2PaddingLength)
        ));
    }

    #[test]
    fn test_remove_g2_padding_rejects_bad_last_element() {
        let mut input = [0u8; PADDED_G2_LENGTH];
        // First padding byte of the fourth field element.
        input[3 * PADDED_FP_LENGTH] = 1;
        assert!(matches!(
            remove_g2_padding(&input),
            Err(PrecompileHalt::Bls12381FpPaddingInvalid)
        ));
    }

    #[test]
    #[should_panic(expected = "Invalid unpadded G1 point length")]
    fn test_pad_g1_point_panics_on_wrong_length() {
        let _ = pad_g1_point(&[0u8; G1_LENGTH - 1]);
    }

    #[test]
    #[should_panic(expected = "Invalid unpadded G2 point length")]
    fn test_pad_g2_point_panics_on_wrong_length() {
        let _ = pad_g2_point(&[0u8; 4 * FP_LENGTH + 1]);
    }
}