use alloy_primitives::{B256, keccak256};
use alloy_sol_types::SolValue;
use crate::composable::ConditionalOrderParams;
/// Computes the Merkle leaf committed to for a single conditional order.
///
/// The order parameters are ABI-encoded and hashed with keccak256. That 32-byte digest is
/// then hashed a second time, and the second digest is the leaf.
///
/// NOTE(review): presumably this mirrors the on-chain leaf derivation
/// (`keccak256(bytes.concat(keccak256(abi.encode(params))))`); confirm against the contract.
pub fn conditional_order_leaf(params: &ConditionalOrderParams) -> B256 {
    let encoded = params.abi_encode();
    let order_hash = keccak256(&encoded);
    // Second hash of the 32-byte digest produces the leaf value.
    keccak256(order_hash.as_slice())
}
/// Hashes two tree nodes into their parent, independent of argument order.
///
/// The smaller node (by byte-wise ordering) is placed first in the 64-byte preimage.
/// This makes `hash_pair(a, b) == hash_pair(b, a)`, so proofs need no left/right flags.
fn hash_pair(a: B256, b: B256) -> B256 {
    let (first, second) = if b < a { (b, a) } else { (a, b) };
    let mut preimage = [0_u8; 64];
    let (left, right) = preimage.split_at_mut(32);
    left.copy_from_slice(first.as_slice());
    right.copy_from_slice(second.as_slice());
    keccak256(preimage)
}
/// A Merkle tree over 32-byte leaves (for example, those from `conditional_order_leaf`),
/// stored level by level.
///
/// `levels[0]` holds the leaves, and each following entry holds the parents of the previous
/// one. The last entry holds the single root node.
#[derive(Debug, Clone)]
pub struct Multiplexer {
    /// Every tree level, from the leaves (index 0) up to the root (last index).
    levels: Vec<Vec<B256>>,
}
/// Errors produced while building a [`Multiplexer`] or querying it for proofs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MultiplexerError {
    /// The tree was constructed from an empty list of leaves.
    Empty,
    /// A proof was requested for a leaf index that is not smaller than the leaf count.
    IndexOutOfRange,
}

impl std::fmt::Display for MultiplexerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            MultiplexerError::Empty => "Multiplexer needs at least one leaf",
            MultiplexerError::IndexOutOfRange => "leaf index out of range",
        };
        f.write_str(message)
    }
}

impl std::error::Error for MultiplexerError {}
impl Multiplexer {
    /// Builds a Merkle tree over `leaves`, keeping them in the given order.
    ///
    /// Adjacent nodes are combined pairwise with `hash_pair` (sorted-pair keccak256). When a
    /// level has an odd number of nodes, its last node is carried up to the next level
    /// unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`MultiplexerError::Empty`] if `leaves` is empty.
    pub fn new(leaves: &[B256]) -> Result<Self, MultiplexerError> {
        if leaves.is_empty() {
            return Err(MultiplexerError::Empty);
        }
        let mut levels: Vec<Vec<B256>> = Vec::new();
        let mut current = leaves.to_vec();
        while current.len() > 1 {
            let next: Vec<B256> = current
                .chunks(2)
                .map(|pair| match pair {
                    [left, right] => hash_pair(*left, *right),
                    // Odd node out: promoted to the parent level without hashing.
                    [lone] => *lone,
                    _ => unreachable!("chunks(2) yields slices of length 1 or 2"),
                })
                .collect();
            levels.push(std::mem::replace(&mut current, next));
        }
        // `current` is now the single-node root level.
        levels.push(current);
        Ok(Self { levels })
    }

    /// Builds a tree whose leaves are `conditional_order_leaf` of each order, in order.
    ///
    /// # Errors
    ///
    /// Returns [`MultiplexerError::Empty`] if `orders` is empty.
    pub fn from_params(orders: &[ConditionalOrderParams]) -> Result<Self, MultiplexerError> {
        let leaves: Vec<B256> = orders.iter().map(conditional_order_leaf).collect();
        Self::new(&leaves)
    }

    /// Returns the Merkle root. For a single-leaf tree, this is that leaf.
    pub fn root(&self) -> B256 {
        // `new` rejects empty input and stops at a one-element level, so the top level
        // always exists and holds exactly one node.
        self.levels
            .last()
            .and_then(|top| top.first())
            .copied()
            .expect("Multiplexer invariant: top level holds exactly one node")
    }

    /// Returns the leaves in the order they were supplied to [`Multiplexer::new`].
    pub fn leaves(&self) -> &[B256] {
        &self.levels[0]
    }

    /// Returns the sibling hashes needed to prove the leaf at `index`, ordered from the
    /// leaf level upward.
    ///
    /// Levels where the node was promoted without a sibling contribute nothing, so the
    /// proof may be shorter than the tree depth. Use `verify_proof` to check it.
    ///
    /// # Errors
    ///
    /// Returns [`MultiplexerError::IndexOutOfRange`] if `index >= self.leaves().len()`.
    pub fn proof(&self, index: usize) -> Result<Vec<B256>, MultiplexerError> {
        if index >= self.leaves().len() {
            return Err(MultiplexerError::IndexOutOfRange);
        }
        // Every level except the root level can contribute one sibling.
        let depth = self.levels.len() - 1;
        let mut proof = Vec::with_capacity(depth);
        let mut position = index;
        for level in &self.levels[..depth] {
            if let Some(&sibling) = level.get(position ^ 1) {
                proof.push(sibling);
            }
            position /= 2;
        }
        Ok(proof)
    }
}
pub fn verify_proof(root: B256, proof: &[B256], leaf: B256) -> bool {
let mut current = leaf;
for sibling in proof {
current = hash_pair(current, *sibling);
}
current == root
}
#[cfg(test)]
mod tests {
    use alloy_primitives::{Address, Bytes, hex};

    use super::*;

    /// Deterministic, distinct test leaf: 32 copies of byte `i`.
    fn leaf(i: u8) -> B256 {
        B256::repeat_byte(i)
    }

    #[test]
    fn single_leaf_tree() {
        let leaves = vec![leaf(1)];
        let tree = Multiplexer::new(&leaves).unwrap();
        assert_eq!(tree.root(), leaf(1));
        assert!(tree.proof(0).unwrap().is_empty());
        assert!(verify_proof(tree.root(), &[], leaf(1)));
        assert!(!verify_proof(tree.root(), &[], leaf(2)));
    }

    #[test]
    fn two_leaf_tree_root_matches_sorted_pair_keccak() {
        let leaves = vec![leaf(1), leaf(2)];
        let tree = Multiplexer::new(&leaves).unwrap();
        let expected_root = {
            let mut buf = [0_u8; 64];
            buf[..32].copy_from_slice(leaf(1).as_slice());
            buf[32..].copy_from_slice(leaf(2).as_slice());
            keccak256(buf)
        };
        assert_eq!(tree.root(), expected_root);
        let proof_for_0 = tree.proof(0).unwrap();
        assert_eq!(proof_for_0, vec![leaf(2)]);
        assert!(verify_proof(tree.root(), &proof_for_0, leaf(1)));
        let proof_for_1 = tree.proof(1).unwrap();
        assert_eq!(proof_for_1, vec![leaf(1)]);
        assert!(verify_proof(tree.root(), &proof_for_1, leaf(2)));
    }

    #[test]
    fn three_leaf_tree_round_trips_every_proof() {
        let leaves: Vec<B256> = (1_u8..=3).map(leaf).collect();
        let tree = Multiplexer::new(&leaves).unwrap();
        for (i, l) in leaves.iter().enumerate() {
            let proof = tree.proof(i).unwrap();
            assert!(
                verify_proof(tree.root(), &proof, *l),
                "leaf {i} did not round-trip",
            );
        }
        let proof = tree.proof(0).unwrap();
        assert!(!verify_proof(tree.root(), &proof, leaf(99)));
    }

    #[test]
    fn arbitrary_size_tree_round_trips() {
        let leaves: Vec<B256> = (1_u8..=10).map(leaf).collect();
        let tree = Multiplexer::new(&leaves).unwrap();
        for (i, l) in leaves.iter().enumerate() {
            let proof = tree.proof(i).unwrap();
            assert!(verify_proof(tree.root(), &proof, *l));
        }
    }

    // Exercises both even and odd level widths (odd-node promotion) at many depths.
    #[test]
    fn every_tree_size_round_trips() {
        for size in 1_u8..=17 {
            let leaves: Vec<B256> = (1..=size).map(leaf).collect();
            let tree = Multiplexer::new(&leaves).unwrap();
            assert_eq!(tree.leaves(), leaves.as_slice());
            for (i, l) in leaves.iter().enumerate() {
                let proof = tree.proof(i).unwrap();
                assert!(
                    verify_proof(tree.root(), &proof, *l),
                    "size {size}: leaf {i} did not round-trip",
                );
            }
        }
    }

    #[test]
    fn tampered_proof_is_rejected() {
        let leaves: Vec<B256> = (1_u8..=4).map(leaf).collect();
        let tree = Multiplexer::new(&leaves).unwrap();
        let mut proof = tree.proof(0).unwrap();
        proof[0] = leaf(0xff);
        assert!(!verify_proof(tree.root(), &proof, leaf(1)));
    }

    #[test]
    fn conditional_order_leaf_double_hashes_params() {
        let params = ConditionalOrderParams {
            handler: Address::repeat_byte(0xab),
            salt: B256::from(hex!(
                "0101010101010101010101010101010101010101010101010101010101010101"
            )),
            staticInput: Bytes::from_static(&hex!("deadbeef")),
        };
        let id = keccak256(params.abi_encode());
        let expected = keccak256(id.as_slice());
        // Previously this line read `¶ms` (a mangled `&params`), which does not lex.
        assert_eq!(conditional_order_leaf(&params), expected);
    }

    #[test]
    fn empty_input_is_rejected() {
        assert_eq!(Multiplexer::new(&[]).unwrap_err(), MultiplexerError::Empty);
    }

    #[test]
    fn out_of_range_proof_is_rejected() {
        let tree = Multiplexer::new(&[leaf(1), leaf(2)]).unwrap();
        assert_eq!(
            tree.proof(2).unwrap_err(),
            MultiplexerError::IndexOutOfRange
        );
    }
}