use alloy::primitives::{keccak256, FixedBytes, U256};
use serde::{Deserialize, Serialize};
/// Single byte prepended to every operator-info leaf preimage before it is hashed by
/// `compute_operator_info_leaf`.
// NOTE(review): presumably a domain-separation tag that must mirror the on-chain
// contract's leaf salt — confirm against the Solidity source before changing.
const OPERATOR_INFO_LEAF_SALT: u8 = 0x75;
/// One operator's BN254 G1 public key (affine `X`/`Y` coordinates) and its weight vector.
///
/// This is the off-chain counterpart of the contract's `BN254OperatorInfo` struct and is the
/// input hashed by `compute_operator_info_leaf`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OperatorInfo {
    /// `X` coordinate of the operator's G1 public key.
    pub pubkey_x: U256,
    /// `Y` coordinate of the operator's G1 public key.
    pub pubkey_y: U256,
    /// Weights, ABI-encoded as a dynamic `uint256[]` when the leaf is computed.
    pub weights: Vec<U256>,
}

impl OperatorInfo {
    /// Bundles public-key coordinates and weights into an [`OperatorInfo`].
    pub fn new(pubkey_x: U256, pubkey_y: U256, weights: Vec<U256>) -> Self {
        Self {
            weights,
            pubkey_y,
            pubkey_x,
        }
    }
}
/// Inclusion proof for a single leaf of the operator-info Merkle tree.
///
/// Produced by `generate_merkle_proof` and checked by `verify_merkle_proof`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MerkleProof {
/// Sibling hashes ordered from the leaf level upward (the root itself is not included).
pub proof: Vec<FixedBytes<32>>,
/// Position of the leaf among the original leaves; at each level its low bit says whether
/// the running node is the left (even) or right (odd) child.
pub index: usize,
/// The leaf hash whose inclusion is being proven.
pub leaf: FixedBytes<32>,
}
/// Hashes one operator into a Merkle leaf.
///
/// The preimage is [`OPERATOR_INFO_LEAF_SALT`] (one byte) followed by the ABI encoding of the
/// operator as a `(G1Point pubkey, uint256[] weights)` tuple. Byte-for-byte this is
/// `abi.encodePacked(bytes1(SALT), abi.encode(info))`. The leaf is the keccak256 of that preimage.
pub fn compute_operator_info_leaf(operator: &OperatorInfo) -> FixedBytes<32> {
    // Leading words of `abi.encode(tuple)`:
    // - the offset to the dynamic tuple body (one word);
    // - the static G1 point;
    // - the offset of `weights` within the tuple (three words in);
    // - the array length.
    let head_words = [
        U256::from(32),
        operator.pubkey_x,
        operator.pubkey_y,
        U256::from(96),
        U256::from(operator.weights.len()),
    ];
    let word_count = head_words.len() + operator.weights.len();

    let mut preimage = Vec::with_capacity(1 + 32 * word_count);
    preimage.push(OPERATOR_INFO_LEAF_SALT);
    for word in head_words.iter().chain(&operator.weights) {
        preimage.extend_from_slice(&word.to_be_bytes::<32>());
    }
    keccak256(&preimage)
}
/// Computes the Merkle root over `operators`, hashing each one with
/// [`compute_operator_info_leaf`] in slice order.
///
/// An empty slice yields `FixedBytes::ZERO`. A single operator's root is its own leaf.
pub fn compute_operator_info_tree_root(operators: &[OperatorInfo]) -> FixedBytes<32> {
    match operators {
        [] => FixedBytes::ZERO,
        _ => {
            let leaves = operators
                .iter()
                .map(compute_operator_info_leaf)
                .collect::<Vec<FixedBytes<32>>>();
            compute_merkle_root_from_leaves(&leaves)
        }
    }
}
/// Folds leaf hashes into a single Merkle root.
///
/// - No leaves gives `FixedBytes::ZERO`.
/// - One leaf is returned unchanged.
/// - Otherwise the leaves are padded with zero hashes up to the next power of two, and each level
///   is produced by hashing adjacent pairs with [`hash_pair`] until one node remains.
fn compute_merkle_root_from_leaves(leaves: &[FixedBytes<32>]) -> FixedBytes<32> {
    match leaves {
        [] => FixedBytes::ZERO,
        [only] => *only,
        _ => {
            let width = leaves.len().next_power_of_two();
            let mut layer: Vec<FixedBytes<32>> = leaves
                .iter()
                .copied()
                .chain(std::iter::repeat(FixedBytes::ZERO))
                .take(width)
                .collect();
            // `width` is a power of two >= 2, so every layer splits evenly into pairs.
            while layer.len() > 1 {
                layer = layer
                    .chunks_exact(2)
                    .map(|pair| hash_pair(pair[0], pair[1]))
                    .collect();
            }
            layer[0]
        }
    }
}
/// Hashes two child nodes into their parent: `keccak256(left || right)`.
fn hash_pair(left: FixedBytes<32>, right: FixedBytes<32>) -> FixedBytes<32> {
    let mut preimage = [0u8; 64];
    let (low_half, high_half) = preimage.split_at_mut(32);
    low_half.copy_from_slice(left.as_slice());
    high_half.copy_from_slice(right.as_slice());
    keccak256(preimage)
}
/// Builds an inclusion proof for the operator at `operator_index`.
///
/// Returns `None` when `operator_index` is out of bounds. Otherwise the proof's `leaf` is
/// `compute_operator_info_leaf(&operators[operator_index])`. Its siblings come from the same
/// zero-padded tree that [`compute_operator_info_tree_root`] builds over `operators`.
pub fn generate_merkle_proof(operators: &[OperatorInfo], operator_index: usize) -> Option<MerkleProof> {
    (operator_index < operators.len()).then(|| {
        let leaves: Vec<FixedBytes<32>> =
            operators.iter().map(compute_operator_info_leaf).collect();
        MerkleProof {
            leaf: leaves[operator_index],
            proof: generate_proof_from_leaves(&leaves, operator_index),
            index: operator_index,
        }
    })
}
/// Collects the sibling hashes needed to recompute the root starting from `leaves[index]`.
///
/// Uses the same zero-padded, power-of-two tree shape as `compute_merkle_root_from_leaves`.
/// Siblings are ordered from the leaf level upward. Zero or one leaves yield an empty proof.
/// A sibling position that falls outside the current level is reported as a zero hash.
fn generate_proof_from_leaves(leaves: &[FixedBytes<32>], index: usize) -> Vec<FixedBytes<32>> {
    if leaves.len() < 2 {
        return Vec::new();
    }

    let width = leaves.len().next_power_of_two();
    let mut layer = leaves.to_vec();
    layer.resize(width, FixedBytes::ZERO);

    let mut siblings = Vec::new();
    let mut position = index;
    while layer.len() > 1 {
        // Flipping the low bit moves between the left (even) and right (odd) child of a pair.
        siblings.push(layer.get(position ^ 1).copied().unwrap_or(FixedBytes::ZERO));
        layer = layer
            .chunks(2)
            .map(|pair| hash_pair(pair[0], pair.get(1).copied().unwrap_or(FixedBytes::ZERO)))
            .collect();
        position >>= 1;
    }
    siblings
}
/// Checks that `proof` authenticates `proof.leaf` at position `proof.index` under `root`.
///
/// The leaf is hashed upward with each sibling in `proof.proof`. At every level the low bit of
/// the running index decides whether the current node is the left (even) or right (odd) child.
///
/// The proof is accepted only when both of these hold:
/// - the recomputed root equals `root`;
/// - `proof.index` fits within the tree depth the proof describes
///   (`proof.index < 2^proof.proof.len()`).
///
/// The index check closes a malleability gap. Without it, a proof for position `i` would also
/// verify for `i + k * 2^depth`, so the `index` field could be forged while the proof still
/// "verified".
pub fn verify_merkle_proof(root: FixedBytes<32>, proof: &MerkleProof) -> bool {
    let mut computed_hash = proof.leaf;
    let mut index = proof.index;
    for sibling in &proof.proof {
        computed_hash = if index.is_multiple_of(2) {
            hash_pair(computed_hash, *sibling)
        } else {
            hash_pair(*sibling, computed_hash)
        };
        index /= 2;
    }
    // Each level of a legitimate proof consumes one bit of the index. Leftover high bits mean
    // the claimed position lies outside the tree, so reject instead of silently ignoring them.
    index == 0 && computed_hash == root
}
// Unit tests for leaf encoding, root computation, and proof generation/verification.
#[cfg(test)]
mod tests {
use std::slice;
use super::*;
// Checks that the hand-rolled ABI layout used by `compute_operator_info_leaf` (without the
// leading salt byte) equals alloy's `abi_encode` of the contract binding type.
#[test]
fn test_encoding_matches_alloy_abi_encode() {
use crate::bn254_table_calculator::{IOperatorTableCalculatorTypes::BN254OperatorInfo, BN254::G1Point};
use alloy::sol_types::SolValue;
let op = OperatorInfo::new(U256::from(0x1234), U256::from(0x5678), vec![U256::from(1000)]);
let mut manual_encoded = Vec::new();
manual_encoded.extend_from_slice(&U256::from(32).to_be_bytes::<32>());
manual_encoded.extend_from_slice(&op.pubkey_x.to_be_bytes::<32>());
manual_encoded.extend_from_slice(&op.pubkey_y.to_be_bytes::<32>());
manual_encoded.extend_from_slice(&U256::from(96).to_be_bytes::<32>());
manual_encoded.extend_from_slice(&U256::from(op.weights.len()).to_be_bytes::<32>());
for w in &op.weights {
manual_encoded.extend_from_slice(&w.to_be_bytes::<32>());
}
let contract_op = BN254OperatorInfo {
pubkey: G1Point {
X: U256::from(0x1234),
Y: U256::from(0x5678),
},
weights: vec![U256::from(1000)],
};
let alloy_encoded = contract_op.abi_encode();
assert_eq!(
alloy_encoded,
manual_encoded,
"Manual encoding doesn't match alloy's abi_encode (alloy={}, manual={})",
alloy_encoded.len(),
manual_encoded.len()
);
}
// An empty operator set yields the all-zero root.
#[test]
fn test_empty_operators_returns_zero_root() {
let root = compute_operator_info_tree_root(&[]);
assert_eq!(root, FixedBytes::ZERO);
}
// With a single operator no pair hashing happens: the root is the leaf itself.
#[test]
fn test_single_operator_leaf_is_root() {
let operator = OperatorInfo::new(
U256::from(123),
U256::from(456),
vec![U256::from(1_000_000_000_000_000_000u128)],
);
let root = compute_operator_info_tree_root(slice::from_ref(&operator));
let leaf = compute_operator_info_leaf(&operator);
assert_eq!(root, leaf);
}
// The same input must always produce the same root.
#[test]
fn test_deterministic_root_computation() {
let operators = vec![
OperatorInfo::new(U256::from(1), U256::from(2), vec![U256::from(100)]),
OperatorInfo::new(U256::from(3), U256::from(4), vec![U256::from(200)]),
];
let root1 = compute_operator_info_tree_root(&operators);
let root2 = compute_operator_info_tree_root(&operators);
assert_eq!(root1, root2);
}
// Leaf order is significant: swapping two operators changes the root.
#[test]
fn test_different_order_different_root() {
let op1 = OperatorInfo::new(U256::from(1), U256::from(2), vec![U256::from(100)]);
let op2 = OperatorInfo::new(U256::from(3), U256::from(4), vec![U256::from(200)]);
let root1 = compute_operator_info_tree_root(&[op1.clone(), op2.clone()]);
let root2 = compute_operator_info_tree_root(&[op2, op1]);
assert_ne!(root1, root2);
}
// Every index in a three-operator tree (padded to four leaves) yields a verifying proof.
#[test]
fn test_merkle_proof_generation_and_verification() {
let operators = vec![
OperatorInfo::new(U256::from(1), U256::from(2), vec![U256::from(100)]),
OperatorInfo::new(U256::from(3), U256::from(4), vec![U256::from(200)]),
OperatorInfo::new(U256::from(5), U256::from(6), vec![U256::from(300)]),
];
let root = compute_operator_info_tree_root(&operators);
for i in 0..operators.len() {
let proof = generate_merkle_proof(&operators, i).unwrap();
assert!(
verify_merkle_proof(root, &proof),
"Proof verification failed for operator {}",
i
);
}
}
// Tampering with the proven leaf must make verification fail.
#[test]
fn test_invalid_proof_fails_verification() {
let operators = vec![
OperatorInfo::new(U256::from(1), U256::from(2), vec![U256::from(100)]),
OperatorInfo::new(U256::from(3), U256::from(4), vec![U256::from(200)]),
];
let root = compute_operator_info_tree_root(&operators);
let mut proof = generate_merkle_proof(&operators, 0).unwrap();
proof.leaf = FixedBytes::ZERO;
assert!(!verify_merkle_proof(root, &proof));
}
// Asking for a proof past the end of the operator list returns `None`.
#[test]
fn test_out_of_bounds_index_returns_none() {
let operators = vec![OperatorInfo::new(U256::from(1), U256::from(2), vec![U256::from(100)])];
let proof = generate_merkle_proof(&operators, 5);
assert!(proof.is_none());
}
// Pins the exact leaf preimage layout: salt byte, tuple offset, X, Y, weights offset,
// weights length, then each weight as a big-endian 32-byte word.
#[test]
fn test_leaf_computation_matches_solidity_encoding() {
let operator = OperatorInfo::new(U256::from(0x1234), U256::from(0x5678), vec![U256::from(1000)]);
let leaf = compute_operator_info_leaf(&operator);
let mut expected_encoded = Vec::new();
expected_encoded.push(OPERATOR_INFO_LEAF_SALT);
expected_encoded.extend_from_slice(&U256::from(32).to_be_bytes::<32>());
expected_encoded.extend_from_slice(&U256::from(0x1234).to_be_bytes::<32>());
expected_encoded.extend_from_slice(&U256::from(0x5678).to_be_bytes::<32>());
expected_encoded.extend_from_slice(&U256::from(96).to_be_bytes::<32>());
expected_encoded.extend_from_slice(&U256::from(1).to_be_bytes::<32>());
expected_encoded.extend_from_slice(&U256::from(1000).to_be_bytes::<32>());
let expected_leaf = keccak256(&expected_encoded);
assert_eq!(leaf, expected_leaf);
}
// Multi-element weight arrays are hashed, and changing any single weight changes the leaf.
#[test]
fn test_multiple_weights() {
let operator = OperatorInfo::new(
U256::from(1),
U256::from(2),
vec![U256::from(100), U256::from(200), U256::from(300)],
);
let leaf = compute_operator_info_leaf(&operator);
assert_ne!(leaf, FixedBytes::ZERO);
let operator2 = OperatorInfo::new(
U256::from(1),
U256::from(2),
vec![U256::from(100), U256::from(200), U256::from(400)],
);
let leaf2 = compute_operator_info_leaf(&operator2);
assert_ne!(leaf, leaf2);
}
// A non-power-of-two operator count (3) still yields a non-zero root and verifying proofs.
// NOTE(review): largely overlaps `test_merkle_proof_generation_and_verification`.
#[test]
fn test_power_of_two_padding() {
let operators = vec![
OperatorInfo::new(U256::from(1), U256::from(2), vec![U256::from(100)]),
OperatorInfo::new(U256::from(3), U256::from(4), vec![U256::from(200)]),
OperatorInfo::new(U256::from(5), U256::from(6), vec![U256::from(300)]),
];
let root = compute_operator_info_tree_root(&operators);
assert_ne!(root, FixedBytes::ZERO);
for i in 0..operators.len() {
let proof = generate_merkle_proof(&operators, i).unwrap();
assert!(verify_merkle_proof(root, &proof));
}
}
// Proofs are generated by slice position. The pubkeys appear chosen so that op0 would sort
// after op1, presumably to catch an implementation that reorders operators.
// NOTE(review): this only checks that each proof verifies; it does not compare
// `proof.leaf` against `compute_operator_info_leaf(&operators[i])`.
#[test]
fn test_operator_index_uses_iteration_order() {
let op0 = OperatorInfo::new(
U256::from_be_slice(&[0xff; 32]),
U256::from_be_slice(&[0xff; 32]),
vec![U256::from(20)],
);
let op1 = OperatorInfo::new(
U256::from_be_slice(&[0x00; 32]),
U256::from_be_slice(&[0x01; 32]),
vec![U256::from(40)],
);
let operators = vec![op0, op1];
let root = compute_operator_info_tree_root(&operators);
for i in 0..operators.len() {
let proof = generate_merkle_proof(&operators, i).unwrap();
assert!(verify_merkle_proof(root, &proof));
}
}
}