use std::fmt;
use alloy_primitives::{B256, keccak256};
use serde::{Deserialize, Serialize};
use crate::error::CowError;
use super::{
order_id,
types::{ConditionalOrderParams, ProofLocation},
};
/// A Merkle inclusion proof for one order, together with that order's id and parameters.
///
/// Produced by `Multiplexer::proof`.
#[derive(Debug, Clone)]
pub struct OrderProof {
    /// Id of the order (as returned by `order_id` when built by `Multiplexer::proof`).
    pub order_id: B256,
    /// Sibling hashes, ordered from the leaf level towards the root.
    pub proof: Vec<B256>,
    /// The parameters of the proven order.
    pub params: ConditionalOrderParams,
}
impl OrderProof {
    /// Bundles an order id, its sibling-hash proof and its parameters.
    #[must_use]
    pub const fn new(order_id: B256, proof: Vec<B256>, params: ConditionalOrderParams) -> Self {
        Self { params, proof, order_id }
    }

    /// Number of sibling hashes carried in `proof`.
    #[must_use]
    pub const fn proof_len(&self) -> usize {
        let siblings = &self.proof;
        siblings.len()
    }
}
/// A Merkle proof paired with the order parameters it proves, without the precomputed id.
///
/// Produced by `Multiplexer::dump_proofs_and_params` and
/// `Multiplexer::decode_proofs_from_json`.
#[derive(Debug, Clone)]
pub struct ProofWithParams {
    /// Sibling hashes, ordered from the leaf level towards the root.
    pub proof: Vec<B256>,
    /// The parameters of the proven order.
    pub params: ConditionalOrderParams,
}
impl ProofWithParams {
    /// Pairs a sibling-hash proof with the order parameters it proves.
    #[must_use]
    pub const fn new(proof: Vec<B256>, params: ConditionalOrderParams) -> Self {
        Self { params, proof }
    }

    /// Number of sibling hashes carried in `proof`.
    #[must_use]
    pub const fn proof_len(&self) -> usize {
        let siblings = &self.proof;
        siblings.len()
    }
}
impl fmt::Display for OrderProof {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "order-proof({:#x}, {} siblings)", self.order_id, self.proof.len())
}
}
impl fmt::Display for ProofWithParams {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"proof-with-params({} siblings, handler={:#x})",
self.proof.len(),
self.params.handler
)
}
}
/// JSON shape written by `Multiplexer::to_json` and read by `Multiplexer::from_json`.
///
/// Field names and declaration order define the serialized output; keep them stable.
#[derive(Serialize, Deserialize)]
struct MultiplexerJson {
    /// Numeric `ProofLocation` discriminant; `from_json` accepts 0 through 5.
    proof_location: u8,
    /// The stored orders, each hex-encoded.
    orders: Vec<ParamsJson>,
}
/// String-encoded form of `ConditionalOrderParams` used inside [`MultiplexerJson`].
///
/// Keys are serialized exactly as the field names are written (snake_case).
#[derive(Serialize, Deserialize)]
struct ParamsJson {
    /// Handler address as a hex string.
    handler: String,
    /// 32-byte salt as a `0x`-prefixed hex string (the prefix is optional when parsing).
    salt: String,
    /// Arbitrary-length static input as a `0x`-prefixed hex string (prefix optional when parsing).
    static_input: String,
}
/// One element of the JSON array parsed by `Multiplexer::decode_proofs_from_json`.
#[derive(Deserialize)]
struct WatchtowerEntry {
    /// Sibling hashes as hex strings (an optional `0x` prefix is stripped when decoding).
    proof: Vec<String>,
    /// The order parameters this proof belongs to.
    params: WatchtowerParams,
}
/// String-encoded order parameters inside a [`WatchtowerEntry`].
///
/// Unlike [`ParamsJson`], keys are camelCase, so the input key is `staticInput`.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct WatchtowerParams {
    /// Handler address as a hex string.
    handler: String,
    /// Salt as a hex string (optional `0x` prefix).
    salt: String,
    /// Static input as a hex string (optional `0x` prefix); JSON key `staticInput`.
    static_input: String,
}
impl From<&ConditionalOrderParams> for ParamsJson {
    /// Hex-encodes every field as a `0x`-prefixed lowercase string.
    fn from(p: &ConditionalOrderParams) -> Self {
        Self {
            // Use the explicit LowerHex formatter: `Debug` output is meant for
            // diagnostics and is not a stable serialization format.
            handler: format!("{:#x}", p.handler),
            salt: format!("0x{}", alloy_primitives::hex::encode(p.salt.as_slice())),
            static_input: format!("0x{}", alloy_primitives::hex::encode(&p.static_input)),
        }
    }
}
impl TryFrom<ParamsJson> for ConditionalOrderParams {
    type Error = CowError;

    /// Parses the string-encoded JSON form back into typed order parameters.
    ///
    /// `salt` and `static_input` may carry an optional `0x` prefix.
    ///
    /// # Errors
    ///
    /// Returns [`CowError::AppData`] if the handler is not a valid address, if
    /// either hex field fails to decode, or if `salt` is not exactly 32 bytes.
    fn try_from(j: ParamsJson) -> Result<Self, CowError> {
        let handler = j
            .handler
            .parse()
            .map_err(|e: alloy_primitives::hex::FromHexError| CowError::AppData(e.to_string()))?;

        let salt_hex = j.salt.strip_prefix("0x").unwrap_or(j.salt.as_str());
        let salt_bytes = alloy_primitives::hex::decode(salt_hex)
            .map_err(|e| CowError::AppData(format!("salt: {e}")))?;
        // Reject wrong-length salts with an error instead of panicking in a
        // fixed-size copy; this input can come from untrusted JSON.
        let salt: [u8; 32] = salt_bytes.as_slice().try_into().map_err(|_| {
            CowError::AppData(format!("salt: expected 32 bytes, got {}", salt_bytes.len()))
        })?;

        let input_hex = j.static_input.strip_prefix("0x").unwrap_or(j.static_input.as_str());
        let static_input = alloy_primitives::hex::decode(input_hex)
            .map_err(|e| CowError::AppData(format!("static_input: {e}")))?;

        Ok(Self { handler, salt: B256::new(salt), static_input })
    }
}
/// An ordered set of conditional orders committed to by a single Merkle root.
///
/// Each order's position in the collection is its leaf position in the tree
/// (see `Multiplexer::root` and `Multiplexer::proof`).
#[derive(Debug, Clone, Default)]
pub struct Multiplexer {
    /// Orders in insertion order; indices are leaf positions.
    orders: Vec<ConditionalOrderParams>,
    /// Where the proofs are published; carried through `to_json`/`from_json`.
    proof_location: ProofLocation,
}
impl Multiplexer {
#[must_use]
pub const fn new(proof_location: ProofLocation) -> Self {
Self { orders: Vec::new(), proof_location }
}
pub fn add(&mut self, params: ConditionalOrderParams) {
self.orders.push(params);
}
pub fn remove(&mut self, id: B256) {
self.orders.retain(|p| order_id(p) != id);
}
pub fn update(&mut self, index: usize, params: ConditionalOrderParams) -> Result<(), CowError> {
if index >= self.orders.len() {
return Err(CowError::AppData(format!(
"index {index} out of range (len {})",
self.orders.len()
)));
}
self.orders[index] = params;
Ok(())
}
#[must_use]
pub fn get_by_index(&self, index: usize) -> Option<&ConditionalOrderParams> {
self.orders.get(index)
}
#[must_use]
pub fn get_by_id(&self, id: B256) -> Option<&ConditionalOrderParams> {
self.orders.iter().find(|p| order_id(p) == id)
}
#[must_use]
pub const fn len(&self) -> usize {
self.orders.len()
}
#[must_use]
pub const fn is_empty(&self) -> bool {
self.orders.is_empty()
}
pub fn root(&self) -> Result<Option<B256>, CowError> {
if self.orders.is_empty() {
return Ok(None);
}
let leaves: Vec<B256> = self.orders.iter().map(leaf_hash).collect();
Ok(Some(merkle_root(&leaves)))
}
pub fn proof(&self, index: usize) -> Result<OrderProof, CowError> {
if index >= self.orders.len() {
return Err(CowError::AppData(format!(
"index {index} out of range (len {})",
self.orders.len()
)));
}
let leaves: Vec<B256> = self.orders.iter().map(leaf_hash).collect();
Ok(OrderProof {
order_id: order_id(&self.orders[index]),
proof: generate_proof(&leaves, index),
params: self.orders[index].clone(),
})
}
pub fn dump_proofs_and_params(&self) -> Result<Vec<ProofWithParams>, CowError> {
(0..self.orders.len())
.map(|i| {
let op = self.proof(i)?;
Ok(ProofWithParams { proof: op.proof, params: op.params })
})
.collect()
}
pub fn order_ids(&self) -> impl Iterator<Item = alloy_primitives::B256> + '_ {
self.orders.iter().map(order_id)
}
pub fn iter(&self) -> impl Iterator<Item = &ConditionalOrderParams> {
self.orders.iter()
}
#[must_use]
pub fn as_slice(&self) -> &[ConditionalOrderParams] {
&self.orders
}
pub fn clear(&mut self) {
self.orders.clear();
}
#[must_use]
pub const fn proof_location(&self) -> ProofLocation {
self.proof_location
}
#[must_use]
pub const fn with_proof_location(mut self, location: ProofLocation) -> Self {
self.proof_location = location;
self
}
#[must_use]
pub fn into_vec(self) -> Vec<ConditionalOrderParams> {
self.orders
}
pub fn to_json(&self) -> Result<String, CowError> {
let j = MultiplexerJson {
proof_location: self.proof_location as u8,
orders: self.orders.iter().map(ParamsJson::from).collect(),
};
serde_json::to_string(&j).map_err(|e| CowError::AppData(e.to_string()))
}
pub fn decode_proofs_from_json(json: &str) -> Result<Vec<ProofWithParams>, CowError> {
let entries: Vec<WatchtowerEntry> =
serde_json::from_str(json).map_err(|e| CowError::AppData(e.to_string()))?;
entries
.into_iter()
.map(|entry| {
let proof = entry
.proof
.iter()
.map(|s| {
let hex = s.strip_prefix("0x").map_or(s.as_str(), |h| h);
let bytes = alloy_primitives::hex::decode(hex)
.map_err(|e| CowError::AppData(format!("proof hash: {e}")))?;
let mut arr = [0u8; 32];
arr.copy_from_slice(&bytes);
Ok(B256::new(arr))
})
.collect::<Result<Vec<_>, CowError>>()?;
let p = entry.params;
let handler =
p.handler.parse().map_err(|e: alloy_primitives::hex::FromHexError| {
CowError::AppData(e.to_string())
})?;
let salt_hex = p.salt.strip_prefix("0x").map_or(p.salt.as_str(), |s| s);
let salt_bytes = alloy_primitives::hex::decode(salt_hex)
.map_err(|e| CowError::AppData(format!("salt: {e}")))?;
let mut salt = [0u8; 32];
salt.copy_from_slice(&salt_bytes);
let input_hex =
p.static_input.strip_prefix("0x").map_or(p.static_input.as_str(), |s| s);
let static_input = alloy_primitives::hex::decode(input_hex)
.map_err(|e| CowError::AppData(format!("staticInput: {e}")))?;
let params =
ConditionalOrderParams { handler, salt: B256::new(salt), static_input };
Ok(ProofWithParams { proof, params })
})
.collect()
}
pub fn from_json(json: &str) -> Result<Self, CowError> {
let j: MultiplexerJson =
serde_json::from_str(json).map_err(|e| CowError::AppData(e.to_string()))?;
let proof_location = match j.proof_location {
0 => ProofLocation::Private,
1 => ProofLocation::Emitted,
2 => ProofLocation::Swarm,
3 => ProofLocation::Waku,
4 => ProofLocation::Reserved,
5 => ProofLocation::Ipfs,
n => {
return Err(CowError::AppData(format!("unknown ProofLocation: {n}")));
}
};
let orders = j
.orders
.into_iter()
.map(ConditionalOrderParams::try_from)
.collect::<Result<Vec<_>, _>>()?;
Ok(Self { orders, proof_location })
}
}
impl fmt::Display for Multiplexer {
    /// Renders as `multiplexer(<n> orders, <proof location>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { orders, proof_location } = self;
        let order_count = orders.len();
        write!(f, "multiplexer({} orders, {})", order_count, proof_location)
    }
}
/// Merkle leaf for an order: `keccak256` applied to the order's id.
fn leaf_hash(params: &ConditionalOrderParams) -> B256 {
    let id = order_id(params);
    keccak256(id)
}
/// Computes the Merkle root of `leaves`.
///
/// Adjacent nodes are combined with [`hash_pair`]. An unpaired trailing node
/// is promoted to the next level unchanged, and a single leaf is its own root.
///
/// # Panics
///
/// Panics if `leaves` is empty; callers must check for that first.
fn merkle_root(leaves: &[B256]) -> B256 {
    let mut layer = leaves.to_vec();
    while layer.len() > 1 {
        layer = next_layer(&layer);
    }
    *layer.first().expect("merkle_root requires at least one leaf")
}

/// Builds the parent level of `layer`: pairs are hashed with [`hash_pair`], and an
/// unpaired trailing node is carried up unchanged.
fn next_layer(layer: &[B256]) -> Vec<B256> {
    layer
        .chunks(2)
        .map(|pair| match pair {
            [left, right] => hash_pair(*left, *right),
            [single] => *single,
            _ => unreachable!("chunks(2) yields slices of length 1 or 2"),
        })
        .collect()
}

/// Order-independent hash of two nodes.
///
/// The inputs are sorted before hashing `keccak256(lo ++ hi)`, so
/// `hash_pair(a, b) == hash_pair(b, a)`.
fn hash_pair(a: B256, b: B256) -> B256 {
    let (lo, hi) = if a <= b { (a, b) } else { (b, a) };
    let mut buf = [0u8; 64];
    buf[..32].copy_from_slice(lo.as_slice());
    buf[32..].copy_from_slice(hi.as_slice());
    keccak256(buf)
}

/// Collects the sibling hashes proving the leaf at `index`, from the leaf level upward.
///
/// A level where the node is an unpaired trailing element contributes no
/// sibling, because that node is promoted unchanged. This mirrors [`merkle_root`].
fn generate_proof(leaves: &[B256], mut index: usize) -> Vec<B256> {
    let mut proof = Vec::new();
    let mut layer = leaves.to_vec();
    while layer.len() > 1 {
        // A node's pair partner differs only in the lowest index bit.
        if let Some(&sibling) = layer.get(index ^ 1) {
            proof.push(sibling);
        }
        layer = next_layer(&layer);
        index /= 2;
    }
    proof
}
#[cfg(test)]
mod tests {
    use alloy_primitives::Address;

    use super::*;

    fn make_params(salt_byte: u8) -> ConditionalOrderParams {
        ConditionalOrderParams {
            handler: Address::ZERO,
            salt: B256::new([salt_byte; 32]),
            static_input: vec![salt_byte; 4],
        }
    }

    #[test]
    fn decode_proofs_from_json_roundtrip() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(0xaa));
        mux.add(make_params(0xbb));
        let proofs = mux.dump_proofs_and_params().unwrap();
        let json_entries: Vec<serde_json::Value> = proofs
            .iter()
            .map(|p| {
                let proof_arr: Vec<String> = p
                    .proof
                    .iter()
                    .map(|h| format!("0x{}", alloy_primitives::hex::encode(h.as_slice())))
                    .collect();
                serde_json::json!({
                    "proof": proof_arr,
                    "params": {
                        "handler": format!("{:#x}", p.params.handler),
                        "salt": format!("0x{}", alloy_primitives::hex::encode(p.params.salt.as_slice())),
                        "staticInput": format!("0x{}", alloy_primitives::hex::encode(&p.params.static_input)),
                    }
                })
            })
            .collect();
        let json = serde_json::to_string(&json_entries).unwrap();
        let decoded = Multiplexer::decode_proofs_from_json(&json).unwrap();
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0].params.salt, proofs[0].params.salt);
        assert_eq!(decoded[1].params.static_input, proofs[1].params.static_input);
    }

    #[test]
    fn decode_proofs_from_json_invalid_returns_error() {
        let result = Multiplexer::decode_proofs_from_json("not json");
        assert!(result.is_err());
    }

    #[test]
    fn multiplexer_root_single_order() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(1));
        let root = mux.root().unwrap();
        assert!(root.is_some());
    }

    #[test]
    fn multiplexer_root_empty() {
        let mux = Multiplexer::new(ProofLocation::Private);
        assert!(mux.root().unwrap().is_none());
    }

    #[test]
    fn add_increases_len() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        assert!(mux.is_empty());
        mux.add(make_params(1));
        assert_eq!(mux.len(), 1);
        mux.add(make_params(2));
        assert_eq!(mux.len(), 2);
    }

    #[test]
    fn remove_by_id() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        let p = make_params(0xaa);
        let id = order_id(&p);
        mux.add(p);
        mux.add(make_params(0xbb));
        assert_eq!(mux.len(), 2);
        mux.remove(id);
        assert_eq!(mux.len(), 1);
    }

    #[test]
    fn remove_nonexistent_is_noop() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(1));
        mux.remove(B256::ZERO);
        assert_eq!(mux.len(), 1);
    }

    #[test]
    fn update_in_range() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(1));
        mux.add(make_params(2));
        let new_params = make_params(99);
        mux.update(1, new_params.clone()).unwrap();
        assert_eq!(mux.get_by_index(1).unwrap().salt, new_params.salt);
    }

    #[test]
    fn update_out_of_range() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(1));
        assert!(mux.update(5, make_params(2)).is_err());
    }

    #[test]
    fn get_by_index_valid() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        let p = make_params(0xcc);
        mux.add(p.clone());
        let got = mux.get_by_index(0).unwrap();
        assert_eq!(got.salt, p.salt);
    }

    #[test]
    fn get_by_index_out_of_range() {
        let mux = Multiplexer::new(ProofLocation::Private);
        assert!(mux.get_by_index(0).is_none());
    }

    #[test]
    fn get_by_id_found() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        let p = make_params(0xdd);
        let id = order_id(&p);
        mux.add(p.clone());
        let got = mux.get_by_id(id).unwrap();
        assert_eq!(got.salt, p.salt);
    }

    #[test]
    fn get_by_id_not_found() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(1));
        assert!(mux.get_by_id(B256::ZERO).is_none());
    }

    #[test]
    fn root_changes_when_order_added() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(1));
        let root1 = mux.root().unwrap().unwrap();
        mux.add(make_params(2));
        let root2 = mux.root().unwrap().unwrap();
        assert_ne!(root1, root2);
    }

    #[test]
    fn root_two_orders() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(0xaa));
        mux.add(make_params(0xbb));
        let root = mux.root().unwrap();
        assert!(root.is_some());
    }

    #[test]
    fn proof_valid_index() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(0xaa));
        mux.add(make_params(0xbb));
        let proof = mux.proof(0).unwrap();
        assert!(!proof.proof.is_empty());
        assert_eq!(proof.params.salt, make_params(0xaa).salt);
    }

    #[test]
    fn proof_out_of_range() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(1));
        assert!(mux.proof(5).is_err());
    }

    #[test]
    fn dump_proofs_and_params_returns_all() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(0xaa));
        mux.add(make_params(0xbb));
        mux.add(make_params(0xcc));
        let proofs = mux.dump_proofs_and_params().unwrap();
        assert_eq!(proofs.len(), 3);
    }

    /// Every generated proof must fold back to the root, including trees whose
    /// size is not a power of two (where trailing nodes are promoted).
    #[test]
    fn every_proof_folds_to_root() {
        for n in 1..=7u8 {
            let mut mux = Multiplexer::new(ProofLocation::Private);
            for i in 0..n {
                mux.add(make_params(i));
            }
            let root = mux.root().unwrap().unwrap();
            for i in 0..mux.len() {
                let op = mux.proof(i).unwrap();
                let folded = op
                    .proof
                    .iter()
                    .fold(leaf_hash(&op.params), |acc, sibling| hash_pair(acc, *sibling));
                assert_eq!(folded, root, "proof {i} of {n} does not verify");
            }
        }
    }

    #[test]
    fn to_json_from_json_roundtrip() {
        let mut mux = Multiplexer::new(ProofLocation::Ipfs);
        mux.add(make_params(0x11));
        mux.add(make_params(0x22));
        let json = mux.to_json().unwrap();
        let restored = Multiplexer::from_json(&json).unwrap();
        assert_eq!(restored.len(), 2);
        assert_eq!(restored.proof_location(), ProofLocation::Ipfs);
        assert_eq!(restored.get_by_index(0).unwrap().salt, make_params(0x11).salt);
        assert_eq!(restored.get_by_index(1).unwrap().salt, make_params(0x22).salt);
    }

    #[test]
    fn from_json_invalid() {
        assert!(Multiplexer::from_json("not json").is_err());
    }

    #[test]
    fn from_json_unknown_proof_location() {
        let json = r#"{"proof_location": 99, "orders": []}"#;
        assert!(Multiplexer::from_json(json).is_err());
    }

    #[test]
    fn clear_empties_multiplexer() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(1));
        mux.add(make_params(2));
        mux.clear();
        assert!(mux.is_empty());
    }

    #[test]
    fn order_ids_iterator() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(0xaa));
        mux.add(make_params(0xbb));
        let ids: Vec<_> = mux.order_ids().collect();
        assert_eq!(ids.len(), 2);
        assert_eq!(ids[0], order_id(&make_params(0xaa)));
    }

    #[test]
    fn iter_and_as_slice() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(1));
        mux.add(make_params(2));
        assert_eq!(mux.iter().count(), 2);
        assert_eq!(mux.as_slice().len(), 2);
    }

    #[test]
    fn into_vec_returns_orders() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(1));
        let v = mux.into_vec();
        assert_eq!(v.len(), 1);
    }

    #[test]
    fn with_proof_location_builder() {
        let mux =
            Multiplexer::new(ProofLocation::Private).with_proof_location(ProofLocation::Swarm);
        assert_eq!(mux.proof_location(), ProofLocation::Swarm);
    }

    #[test]
    fn display_multiplexer() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(1));
        let s = format!("{mux}");
        assert!(s.contains("1 orders"));
    }

    #[test]
    fn display_order_proof() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(0xaa));
        mux.add(make_params(0xbb));
        let proof = mux.proof(0).unwrap();
        let s = format!("{proof}");
        assert!(s.contains("order-proof"));
    }

    #[test]
    fn from_json_all_proof_locations() {
        for (val, expected) in [
            (0, ProofLocation::Private),
            (1, ProofLocation::Emitted),
            (2, ProofLocation::Swarm),
            (3, ProofLocation::Waku),
            (4, ProofLocation::Reserved),
            (5, ProofLocation::Ipfs),
        ] {
            let json = format!(r#"{{"proof_location": {val}, "orders": []}}"#);
            let mux = Multiplexer::from_json(&json).unwrap();
            assert_eq!(mux.proof_location(), expected);
        }
    }

    #[test]
    fn multiplexer_root_three_orders() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(0xaa));
        mux.add(make_params(0xbb));
        mux.add(make_params(0xcc));
        let root = mux.root().unwrap();
        assert!(root.is_some());
        for i in 0..3 {
            let proof = mux.proof(i).unwrap();
            assert!(!proof.proof.is_empty() || mux.len() == 1);
        }
    }

    #[test]
    fn order_proof_accessors() {
        let params = make_params(0xaa);
        let id = order_id(&params);
        let proof = OrderProof::new(id, vec![B256::ZERO], params.clone());
        assert_eq!(proof.order_id, id);
        assert_eq!(proof.proof_len(), 1);
        assert_eq!(proof.params.salt, params.salt);
    }

    #[test]
    fn proof_with_params_accessors() {
        let params = make_params(0xaa);
        let pwp = ProofWithParams::new(vec![B256::ZERO, B256::ZERO], params);
        assert_eq!(pwp.proof_len(), 2);
    }

    #[test]
    fn display_proof_with_params() {
        let mut mux = Multiplexer::new(ProofLocation::Private);
        mux.add(make_params(0xaa));
        mux.add(make_params(0xbb));
        let proofs = mux.dump_proofs_and_params().unwrap();
        let s = format!("{}", proofs[0]);
        assert!(s.contains("proof-with-params"));
    }

    #[test]
    fn order_proof_new_and_proof_len() {
        let op = OrderProof::new(B256::ZERO, vec![B256::ZERO, B256::ZERO], make_params(1));
        assert_eq!(op.proof_len(), 2);
    }

    #[test]
    fn proof_with_params_new_and_proof_len() {
        let pwp = ProofWithParams::new(vec![B256::ZERO], make_params(1));
        assert_eq!(pwp.proof_len(), 1);
    }
}