use ark_bn254::Fr as BN254Scalar;
use ark_ff::{BigInteger, PrimeField};
/// Number of field elements in the permutation state (rate `WIDTH - 1` plus one capacity word).
pub const WIDTH: usize = 16;
/// Full rounds (S-box on every word) applied before the partial rounds.
pub const FULL_ROUNDS_BEGIN: usize = 4;
/// Partial rounds (S-box on `state[0]` only).
///
/// NOTE(review): this is well below the partial-round counts in published Poseidon2
/// parameter sets for BN254 — confirm against the intended security analysis.
pub const PARTIAL_ROUNDS: usize = 14;
/// Full rounds applied after the partial rounds.
pub const FULL_ROUNDS_END: usize = 4;
/// Total number of rounds, and therefore of round-constant rows.
pub const TOTAL_ROUNDS: usize = FULL_ROUNDS_BEGIN + FULL_ROUNDS_END + PARTIAL_ROUNDS;
/// S-box of the permutation: raises `x` to the fifth power.
///
/// Evaluated as `(x^2)^2 * x`, i.e. three field multiplications.
fn sbox(x: BN254Scalar) -> BN254Scalar {
    let squared = x * x;
    squared * squared * x
}
/// Generates the additive round constants: one `WIDTH`-wide row per round, `TOTAL_ROUNDS` rows.
///
/// Entry `i` of round `r` is the field element of the `u64`
/// `(r * WIDTH + i) * 0x9e3779b97f4a7c15`, computed with wrapping (mod 2^64) multiplication.
/// The very first entry (`r = 0`, `i = 0`) is therefore zero.
///
/// NOTE(review): this is a simple deterministic placeholder sequence, not a standard
/// constant-generation procedure (the reference Poseidon/Poseidon2 parameters use a Grain
/// LFSR). Any circuit that must reproduce these hashes has to use the same sequence.
fn round_constants() -> Vec<[BN254Scalar; WIDTH]> {
    const MULTIPLIER: u64 = 0x9e3779b97f4a7c15;
    (0..TOTAL_ROUNDS)
        .map(|round| {
            let mut row = [BN254Scalar::from(0u64); WIDTH];
            for (i, slot) in row.iter_mut().enumerate() {
                let index = (round * WIDTH + i) as u64;
                *slot = BN254Scalar::from(index.wrapping_mul(MULTIPLIER));
            }
            row
        })
        .collect()
}
/// Builds the `WIDTH x WIDTH` matrix used by the full-round linear layer.
///
/// The matrix is block-diagonal: four copies of the 4x4 block `M4` along the diagonal,
/// zeros everywhere else. Each group of four state words is therefore mixed only with
/// itself in the full rounds; mixing across groups happens only in `internal_matrix_mult`.
fn mds_matrix() -> [[BN254Scalar; WIDTH]; WIDTH] {
    // Coefficients of the 4x4 diagonal block, as small integers.
    const M4: [[u64; 4]; 4] = [
        [2, 3, 1, 1],
        [1, 2, 3, 1],
        [1, 1, 2, 3],
        [3, 1, 1, 2],
    ];
    let mut out = [[BN254Scalar::from(0u64); WIDTH]; WIDTH];
    for offset in [0usize, 4, 8, 12] {
        for (r, coeffs) in M4.iter().enumerate() {
            for (c, &coeff) in coeffs.iter().enumerate() {
                out[offset + r][offset + c] = BN254Scalar::from(coeff);
            }
        }
    }
    out
}
/// Linear layer of the full rounds: replaces `state` with `M * state`, where `M` is
/// the matrix returned by `mds_matrix` (rebuilt on every call).
fn external_matrix_mult(state: &mut [BN254Scalar; WIDTH]) {
    let m = mds_matrix();
    // Snapshot the input so every output row reads the pre-update state.
    let input = *state;
    for (out, row) in state.iter_mut().zip(m.iter()) {
        *out = row
            .iter()
            .zip(input.iter())
            .fold(BN254Scalar::from(0u64), |acc, (&coeff, &x)| acc + coeff * x);
    }
}
/// Linear layer of the partial rounds: `state[i] <- (i + 1) * state[i] + S`, where `S` is the
/// sum of all words *before* the update.
///
/// Equivalently, multiplication by `diag(1, 2, ..., WIDTH) + J`, with `J` the all-ones matrix.
fn internal_matrix_mult(state: &mut [BN254Scalar; WIDTH]) {
    let total: BN254Scalar = state.iter().copied().sum();
    for (i, word) in state.iter_mut().enumerate() {
        *word = *word * BN254Scalar::from(i as u64 + 1) + total;
    }
}
/// One full round: add `rc` to every word, apply the S-box to every word, then mix with
/// `external_matrix_mult`.
fn full_round(state: &mut [BN254Scalar; WIDTH], rc: &[BN254Scalar; WIDTH]) {
    // Constant addition and S-box touch each word independently, so they can share one pass.
    for (word, &c) in state.iter_mut().zip(rc.iter()) {
        *word = sbox(*word + c);
    }
    external_matrix_mult(state);
}
/// One partial round: add `rc` to every word, apply the S-box to `state[0]` only, then mix
/// with `internal_matrix_mult`.
///
/// NOTE(review): constants are added to all `WIDTH` words here, whereas the reference
/// Poseidon2 design adds a constant only to the first word in partial rounds. Changing
/// this would change every digest, so it is documented rather than altered.
fn partial_round(state: &mut [BN254Scalar; WIDTH], rc: &[BN254Scalar; WIDTH]) {
    state
        .iter_mut()
        .zip(rc.iter())
        .for_each(|(word, &c)| *word += c);
    state[0] = sbox(state[0]);
    internal_matrix_mult(state);
}
/// Applies the permutation to `state` in place.
///
/// The schedule is:
/// 1. `FULL_ROUNDS_BEGIN` full rounds;
/// 2. `PARTIAL_ROUNDS` partial rounds;
/// 3. `FULL_ROUNDS_END` full rounds.
///
/// Round `r` uses row `r` of `round_constants()`. The constants are regenerated on every
/// call.
///
/// NOTE(review): no linear layer is applied before the first round; the reference Poseidon2
/// applies the external matrix to the input first — confirm this deviation is intended.
pub fn poseidon2_permutation(state: &mut [BN254Scalar; WIDTH]) {
    let partial_end = FULL_ROUNDS_BEGIN + PARTIAL_ROUNDS;
    for (round, rc) in round_constants().iter().enumerate() {
        if (FULL_ROUNDS_BEGIN..partial_end).contains(&round) {
            partial_round(state, rc);
        } else {
            full_round(state, rc);
        }
    }
}
/// Hashes a slice of field elements with a sponge built on `poseidon2_permutation`.
///
/// The first `WIDTH - 1` state words are the rate and the last word is the capacity. Inputs are
/// absorbed `WIDTH - 1` at a time by field addition into the rate, with one permutation per
/// chunk; the result is `state[0]` after the last permutation.
///
/// The capacity word is initialised with the number of inputs. Without it, absorbing by
/// addition into an all-zero state makes inputs that differ only by trailing zeros collide
/// (`[x]` vs `[x, 0]`), and so do fixed-arity wrappers of different widths
/// (`poseidon2_hash_5(&[a, b, c, 0, 0]) == poseidon2_hash_3(a, b, c)`). The permutation is
/// always applied at least once, so the empty input does not hash to the trivial value 0.
///
/// NOTE(review): any external implementation (e.g. a circuit) or stored digest must use the
/// same length-initialised capacity word to agree with this function.
pub fn poseidon2_hash(inputs: &[BN254Scalar]) -> BN254Scalar {
    let rate = WIDTH - 1;
    let mut state = [BN254Scalar::from(0u64); WIDTH];
    // Length-based domain separation; the capacity word is never written by absorption.
    // `usize -> u64` is lossless on all targets with pointers of at most 64 bits.
    state[WIDTH - 1] = BN254Scalar::from(inputs.len() as u64);
    if inputs.is_empty() {
        poseidon2_permutation(&mut state);
    }
    for chunk in inputs.chunks(rate) {
        for (slot, &input) in state.iter_mut().zip(chunk) {
            *slot += input;
        }
        poseidon2_permutation(&mut state);
    }
    state[0]
}
/// Hashes exactly five field elements; identical to `poseidon2_hash` on the same slice.
pub fn poseidon2_hash_5(inputs: &[BN254Scalar; 5]) -> BN254Scalar {
    poseidon2_hash(&inputs[..])
}
/// Two-to-one hash of `(left, right)`; argument order matters.
///
/// The tests in this file use it as a Merkle-node hash.
pub fn poseidon2_hash_2(left: BN254Scalar, right: BN254Scalar) -> BN254Scalar {
    let pair = [left, right];
    poseidon2_hash(&pair)
}
/// Hashes exactly three field elements, in the order given.
pub fn poseidon2_hash_3(a: BN254Scalar, b: BN254Scalar, c: BN254Scalar) -> BN254Scalar {
    let triple = [a, b, c];
    poseidon2_hash(triple.as_slice())
}
/// Derives the nsec from a leaf secret by hashing its first three elements.
///
/// Elements 3 and 4 of `leaf_secret` do not influence the result.
pub fn derive_nsec(leaf_secret: &[BN254Scalar; 5]) -> BN254Scalar {
    let [s0, s1, s2, _, _] = *leaf_secret;
    poseidon2_hash_3(s0, s1, s2)
}
/// Derives the npub commitment of `nsec` as `poseidon2_hash(&[nsec, tag])`.
///
/// `tag` is the ASCII string `"npub"` read as a big-endian integer (`0x6e707562`). It acts as a
/// domain separator, so the commitment is not simply a hash of `nsec` alone.
pub fn derive_npub_commitment(nsec: BN254Scalar) -> BN254Scalar {
    let tag = u64::from(u32::from_be_bytes(*b"npub"));
    poseidon2_hash(&[nsec, BN254Scalar::from(tag)])
}
pub fn generate_test_vectors() -> Vec<TestVector> {
let mut vectors = Vec::new();
let leaf_secret_1 = [
BN254Scalar::from(1u64),
BN254Scalar::from(2u64),
BN254Scalar::from(3u64),
BN254Scalar::from(4u64),
BN254Scalar::from(5u64),
];
let leaf_1 = poseidon2_hash_5(&leaf_secret_1);
let nsec_1 = derive_nsec(&leaf_secret_1);
let npub_1 = derive_npub_commitment(nsec_1);
vectors.push(TestVector {
name: "simple_sequential".to_string(),
leaf_secret: leaf_secret_1,
expected_leaf: leaf_1,
expected_nsec: nsec_1,
expected_npub: npub_1,
});
let leaf_secret_2 = [BN254Scalar::from(0u64); 5];
let leaf_2 = poseidon2_hash_5(&leaf_secret_2);
let nsec_2 = derive_nsec(&leaf_secret_2);
let npub_2 = derive_npub_commitment(nsec_2);
vectors.push(TestVector {
name: "all_zeros".to_string(),
leaf_secret: leaf_secret_2,
expected_leaf: leaf_2,
expected_nsec: nsec_2,
expected_npub: npub_2,
});
let leaf_secret_3 = [
BN254Scalar::from(0xdeadbeefu64),
BN254Scalar::from(0xcafebabeu64),
BN254Scalar::from(0x12345678u64),
BN254Scalar::from(0x87654321u64),
BN254Scalar::from(0xfeedface64u64),
];
let leaf_3 = poseidon2_hash_5(&leaf_secret_3);
let nsec_3 = derive_nsec(&leaf_secret_3);
let npub_3 = derive_npub_commitment(nsec_3);
vectors.push(TestVector {
name: "hex_values".to_string(),
leaf_secret: leaf_secret_3,
expected_leaf: leaf_3,
expected_nsec: nsec_3,
expected_npub: npub_3,
});
vectors
}
/// A regression vector tying one leaf secret to the values derived from it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TestVector {
    /// Human-readable label identifying the vector.
    pub name: String,
    /// The five-element leaf secret the other fields are derived from.
    pub leaf_secret: [BN254Scalar; 5],
    /// Expected `poseidon2_hash_5(&leaf_secret)`.
    pub expected_leaf: BN254Scalar,
    /// Expected `derive_nsec(&leaf_secret)`.
    pub expected_nsec: BN254Scalar,
    /// Expected `derive_npub_commitment(expected_nsec)`.
    pub expected_npub: BN254Scalar,
}
/// Encodes a scalar as big-endian hex of its canonical integer representative.
///
/// No `0x` prefix is added and leading zero bytes are kept, so the width is fixed by the
/// byte length of the underlying `BigInt` (presumably 32 bytes / 64 hex digits for BN254).
pub fn scalar_to_hex(s: BN254Scalar) -> String {
    hex::encode(s.into_bigint().to_bytes_be())
}
/// Parses a big-endian hex string into a BN254 scalar.
///
/// Accepted input:
/// - An optional single `0x`/`0X` prefix.
/// - An odd number of hex digits, treated as if left-padded with one `0` (so `"0x1"` parses
///   as 1).
/// - Any number of leading zero bytes, followed by at most 32 significant bytes.
///
/// The integer is then reduced modulo the scalar field order. Encodings of values at or
/// above the modulus are therefore wrapped, not rejected; this is not a canonical-encoding
/// check.
///
/// Returns `None` if the string contains non-hex characters (after the prefix) or encodes
/// more than 32 significant bytes.
pub fn hex_to_scalar(s: &str) -> Option<BN254Scalar> {
    let digits = s
        .strip_prefix("0x")
        .or_else(|| s.strip_prefix("0X"))
        .unwrap_or(s);
    // `hex::decode` rejects odd-length input; pad so short forms like "0x1" parse.
    let bytes = if digits.len() % 2 == 1 {
        hex::decode(format!("0{}", digits)).ok()?
    } else {
        hex::decode(digits).ok()?
    };
    // Ignore leading zero bytes so zero-padded encodings longer than 32 bytes still parse,
    // but refuse (rather than silently truncate) anything with more significant bytes.
    let first_significant = bytes.iter().position(|&b| b != 0).unwrap_or(bytes.len());
    let significant = &bytes[first_significant..];
    if significant.len() > 32 {
        return None;
    }
    // Right-align into a fixed 32-byte big-endian buffer before reducing.
    let mut buf = [0u8; 32];
    buf[32 - significant.len()..].copy_from_slice(significant);
    Some(BN254Scalar::from_be_bytes_mod_order(&buf))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns five small integers into a leaf secret.
    fn secret(values: [u64; 5]) -> [BN254Scalar; 5] {
        values.map(BN254Scalar::from)
    }

    #[test]
    fn test_sbox() {
        // 2^5 = 32
        assert_eq!(sbox(BN254Scalar::from(2u64)), BN254Scalar::from(32u64));
    }

    #[test]
    fn test_poseidon2_deterministic() {
        let inputs = [BN254Scalar::from(1u64), BN254Scalar::from(2u64)];
        assert_eq!(poseidon2_hash(&inputs), poseidon2_hash(&inputs));
    }

    #[test]
    fn test_poseidon2_different_inputs() {
        let forward = [BN254Scalar::from(1u64), BN254Scalar::from(2u64)];
        let backward = [BN254Scalar::from(2u64), BN254Scalar::from(1u64)];
        assert_ne!(poseidon2_hash(&forward), poseidon2_hash(&backward));
    }

    #[test]
    fn test_leaf_commitment() {
        let leaf = poseidon2_hash_5(&secret([1, 2, 3, 4, 5]));
        assert_ne!(leaf, BN254Scalar::from(0u64));
    }

    #[test]
    fn test_nsec_derivation() {
        let leaf_secret = secret([10, 20, 30, 40, 50]);
        assert_eq!(derive_nsec(&leaf_secret), derive_nsec(&leaf_secret));
    }

    #[test]
    fn test_nsec_only_uses_first_three() {
        let a = secret([1, 2, 3, 100, 200]);
        let b = secret([1, 2, 3, 999, 888]);
        assert_eq!(derive_nsec(&a), derive_nsec(&b));
    }

    #[test]
    fn test_npub_derivation() {
        let nsec = BN254Scalar::from(42u64);
        let npub = derive_npub_commitment(nsec);
        assert_eq!(npub, derive_npub_commitment(nsec));
        assert_ne!(npub, nsec);
    }

    #[test]
    fn test_generate_test_vectors() {
        let vectors = generate_test_vectors();
        assert!(vectors.len() >= 3);
        for v in &vectors {
            let nsec = derive_nsec(&v.leaf_secret);
            assert_eq!(
                poseidon2_hash_5(&v.leaf_secret),
                v.expected_leaf,
                "Leaf mismatch for {}",
                v.name
            );
            assert_eq!(nsec, v.expected_nsec, "Nsec mismatch for {}", v.name);
            assert_eq!(
                derive_npub_commitment(nsec),
                v.expected_npub,
                "Npub mismatch for {}",
                v.name
            );
        }
    }

    #[test]
    fn test_merkle_hash_2() {
        let (left, right) = (BN254Scalar::from(1u64), BN254Scalar::from(2u64));
        assert_ne!(poseidon2_hash_2(left, right), poseidon2_hash_2(right, left));
    }
}