use crate::AggregateOrders;
use crate::MarketError;
use crate::SignedFill;
use alloy::primitives::{Address, U256};
use serde::{Deserialize, Serialize};
use signet_zenith::RollupOrders;
use std::collections::HashMap;
/// Running totals of filled outputs, keyed first by `(chain_id, token)` and
/// then by recipient address.
#[derive(Debug, Default, Clone)]
#[derive(PartialEq, Eq, Serialize, Deserialize)]
pub struct AggregateFills {
    /// `(chain_id, token)` -> recipient -> total amount filled so far.
    fills: HashMap<(u64, Address), HashMap<Address, U256>>,
}
impl AggregateFills {
    /// Create an empty fill context.
    pub fn new() -> Self {
        Self::default()
    }

    /// Total amount of `output_asset` (a `(chain_id, token)` pair) recorded as
    /// filled to `recipient`.
    ///
    /// Assets or recipients with no entry read as zero.
    pub fn filled(&self, output_asset: &(u64, Address), recipient: Address) -> U256 {
        self.fills.get(output_asset).and_then(|m| m.get(&recipient)).copied().unwrap_or_default()
    }

    /// Check that at least `amount` of `output_asset` has been filled to
    /// `recipient`.
    ///
    /// # Errors
    ///
    /// Returns [`MarketError::InsufficientBalance`] (carrying the requested
    /// `amount`) if the recorded fill is smaller than `amount`.
    pub fn check_filled(
        &self,
        output_asset: &(u64, Address),
        recipient: Address,
        amount: U256,
    ) -> Result<(), MarketError> {
        if self.filled(output_asset, recipient) < amount {
            return Err(MarketError::InsufficientBalance {
                chain_id: output_asset.0,
                asset: output_asset.1,
                recipient,
                amount,
            });
        }
        Ok(())
    }

    /// Record `amount` of `asset` on `chain_id` as filled to `recipient`.
    ///
    /// Uses saturating addition: a running total stops at [`U256::MAX`]
    /// instead of overflowing, so this never panics.
    pub fn add_raw_fill(
        &mut self,
        chain_id: u64,
        asset: Address,
        recipient: Address,
        amount: U256,
    ) {
        let entry = self.fills.entry((chain_id, asset)).or_default().entry(recipient).or_default();
        *entry = entry.saturating_add(amount);
    }

    /// Record one fill output, attributing it to `chain_id`.
    ///
    /// The output's own `chainId` field is not consulted.
    fn add_fill_output(&mut self, chain_id: u64, output: &RollupOrders::Output) {
        self.add_raw_fill(chain_id, output.token, output.recipient, output.amount)
    }

    /// Record every output of a `Filled` event, attributing all of them to
    /// `chain_id` (saturating, see [`add_raw_fill`](Self::add_raw_fill)).
    pub fn add_fill(&mut self, chain_id: u64, fill: &RollupOrders::Filled) {
        fill.outputs.iter().for_each(|o| self.add_fill_output(chain_id, o));
    }

    /// Record every output of a [`SignedFill`], attributing all of them to
    /// `chain_id` (saturating, see [`add_raw_fill`](Self::add_raw_fill)).
    pub fn add_signed_fill(&mut self, chain_id: u64, fill: &SignedFill) {
        fill.outputs.iter().for_each(|o| self.add_fill_output(chain_id, o));
    }

    /// Add all of `other`'s fills into this context, saturating at
    /// [`U256::MAX`].
    pub fn absorb(&mut self, other: &Self) {
        for (output_asset, recipients) in other.fills.iter() {
            let context_recipients = self.fills.entry(*output_asset).or_default();
            for (recipient, value) in recipients {
                let filled = context_recipients.entry(*recipient).or_default();
                *filled = filled.saturating_add(*value);
            }
        }
    }

    /// Subtract `other`'s fills from this context, intended to undo a prior
    /// [`absorb`](Self::absorb).
    ///
    /// Entries of `other` whose asset or recipient has no entry here are
    /// skipped. NOTE(review): they are skipped even when their value is
    /// non-zero; confirm this leniency is intended.
    ///
    /// # Errors
    ///
    /// Returns [`MarketError::InsufficientBalance`] if an entry of `other`
    /// exceeds the amount recorded here. This is not atomic: entries processed
    /// before the failing one stay subtracted.
    pub fn unchecked_unabsorb(&mut self, other: &Self) -> Result<(), MarketError> {
        for (output_asset, recipients) in other.fills.iter() {
            if let Some(context_recipients) = self.fills.get_mut(output_asset) {
                for (recipient, value) in recipients {
                    if let Some(filled) = context_recipients.get_mut(recipient) {
                        *filled =
                            filled.checked_sub(*value).ok_or(MarketError::InsufficientBalance {
                                chain_id: output_asset.0,
                                asset: output_asset.1,
                                recipient: *recipient,
                                amount: *value,
                            })?;
                    }
                }
            }
        }
        Ok(())
    }

    /// Check that this context covers every output in `aggregate`, without
    /// modifying anything.
    ///
    /// # Errors
    ///
    /// - [`MarketError::MissingAsset`] if an output asset has no fills recorded
    ///   at all. This applies even when every amount for that asset is zero.
    /// - [`MarketError::InsufficientBalance`] if some recipient's fills fall
    ///   short of the required amount.
    pub fn check_aggregate(&self, aggregate: &AggregateOrders) -> Result<(), MarketError> {
        for (output_asset, recipients) in aggregate.outputs.iter() {
            if !self.fills.contains_key(output_asset) {
                return Err(MarketError::MissingAsset {
                    chain_id: output_asset.0,
                    asset: output_asset.1,
                });
            }
            for (recipient, value) in recipients {
                self.check_filled(output_asset, *recipient, *value)?;
            }
        }
        Ok(())
    }

    /// Subtract every output of `aggregate` from this context without
    /// checking first.
    ///
    /// # Errors
    ///
    /// - [`MarketError::MissingAsset`] if an output asset has no entry.
    /// - [`MarketError::InsufficientBalance`] if a recipient has no entry
    ///   (even for a zero amount) or its fills fall short.
    ///
    /// This is not atomic: outputs processed before the failing one stay
    /// subtracted. Use [`checked_remove_aggregate`](Self::checked_remove_aggregate)
    /// when all-or-nothing behavior is required.
    pub fn unchecked_remove_aggregate(
        &mut self,
        aggregate: &AggregateOrders,
    ) -> Result<(), MarketError> {
        for (output_asset, recipients) in aggregate.outputs.iter() {
            let context_recipients =
                self.fills.get_mut(output_asset).ok_or(MarketError::MissingAsset {
                    chain_id: output_asset.0,
                    asset: output_asset.1,
                })?;
            for (recipient, amount) in recipients {
                let filled = context_recipients.get_mut(recipient).ok_or(
                    MarketError::InsufficientBalance {
                        chain_id: output_asset.0,
                        asset: output_asset.1,
                        recipient: *recipient,
                        amount: *amount,
                    },
                )?;
                *filled = filled.checked_sub(*amount).ok_or(MarketError::InsufficientBalance {
                    chain_id: output_asset.0,
                    asset: output_asset.1,
                    recipient: *recipient,
                    amount: *amount,
                })?;
            }
        }
        Ok(())
    }

    /// Validate `aggregate` with [`check_aggregate`](Self::check_aggregate)
    /// and, only if it passes, subtract its outputs.
    ///
    /// # Errors
    ///
    /// Same as [`check_aggregate`](Self::check_aggregate). On error the context
    /// is left unchanged.
    pub fn checked_remove_aggregate(
        &mut self,
        aggregate: &AggregateOrders,
    ) -> Result<(), MarketError> {
        self.check_aggregate(aggregate)?;
        self.remove_prechecked(aggregate);
        Ok(())
    }

    /// Subtract every output of `aggregate`, given that the caller has already
    /// verified that each recipient's balance covers its amount.
    ///
    /// A missing asset or recipient entry has an implied balance of zero. It
    /// can only have passed that verification for a zero amount, so there is
    /// nothing to subtract and the entry is skipped rather than unwrapped.
    /// Earlier code unwrapped here and panicked on zero-amount outputs.
    ///
    /// # Panics
    ///
    /// If an amount exceeds the recorded balance, which means the caller
    /// violated the precondition above.
    fn remove_prechecked(&mut self, aggregate: &AggregateOrders) {
        for (output_asset, recipients) in aggregate.outputs.iter() {
            let Some(context_recipients) = self.fills.get_mut(output_asset) else {
                // No fills for this asset: every balance is zero, so only zero
                // amounts could have passed the caller's check.
                continue;
            };
            for (recipient, amount) in recipients {
                match context_recipients.get_mut(recipient) {
                    Some(filled) => {
                        *filled = filled
                            .checked_sub(*amount)
                            .expect("balance was verified before removal");
                    }
                    // Implied zero balance; nothing to subtract.
                    None => debug_assert!(
                        *amount == U256::ZERO,
                        "missing entry passed a non-zero balance check"
                    ),
                }
            }
        }
    }

    /// Check that this context covers every output of a single `order`.
    ///
    /// # Errors
    ///
    /// Same as [`check_aggregate`](Self::check_aggregate).
    pub fn check_order(&self, order: &RollupOrders::Order) -> Result<(), MarketError> {
        self.check_aggregate(&std::iter::once(order).collect())
    }

    /// Check a single `order` and, only if it passes, subtract its outputs.
    ///
    /// # Errors
    ///
    /// Same as [`check_aggregate`](Self::check_aggregate). On error the context
    /// is left unchanged.
    pub fn checked_remove_order(&mut self, order: &RollupOrders::Order) -> Result<(), MarketError> {
        let aggregate = std::iter::once(order).collect();
        self.checked_remove_aggregate(&aggregate)
    }

    /// Subtract a single `order`'s outputs without checking first.
    ///
    /// # Errors
    ///
    /// Same as [`unchecked_remove_aggregate`](Self::unchecked_remove_aggregate),
    /// including its lack of atomicity.
    pub fn unchecked_remove_order(
        &mut self,
        order: &RollupOrders::Order,
    ) -> Result<(), MarketError> {
        let aggregate = std::iter::once(order).collect();
        self.unchecked_remove_aggregate(&aggregate)
    }

    /// Borrow the raw `(chain_id, token) -> recipient -> amount` map.
    pub const fn fills(&self) -> &HashMap<(u64, Address), HashMap<Address, U256>> {
        &self.fills
    }

    /// Mutably borrow the raw `(chain_id, token) -> recipient -> amount` map.
    ///
    /// Direct edits bypass the saturating arithmetic used by the `add_*`
    /// methods.
    pub const fn fills_mut(&mut self) -> &mut HashMap<(u64, Address), HashMap<Address, U256>> {
        &mut self.fills
    }

    /// Check `orders` against this context's fills plus `fills`, without
    /// modifying either.
    ///
    /// The two balances are summed with saturating addition. Unlike
    /// [`check_aggregate`](Self::check_aggregate), an asset absent from both
    /// contexts is not reported as missing; it reads as a zero balance.
    ///
    /// # Errors
    ///
    /// Returns [`MarketError::InsufficientBalance`] if a recipient's combined
    /// balance falls short.
    pub fn check_ru_tx_events(
        &self,
        fills: &AggregateFills,
        orders: &AggregateOrders,
    ) -> Result<(), MarketError> {
        let combined = CombinedContext { context: self, extra: fills };
        combined.check_aggregate(orders)?;
        Ok(())
    }

    /// Check `orders` against this context plus `fills`. If the check passes,
    /// absorb `fills` and subtract `orders`.
    ///
    /// # Errors
    ///
    /// Same as [`check_ru_tx_events`](Self::check_ru_tx_events). On error the
    /// context is left unchanged.
    pub fn checked_remove_ru_tx_events(
        &mut self,
        fills: &AggregateFills,
        orders: &AggregateOrders,
    ) -> Result<(), MarketError> {
        self.check_ru_tx_events(fills, orders)?;
        // After absorbing, each balance here equals the saturating sum that
        // was just checked, so every order output is covered.
        self.absorb(fills);
        self.remove_prechecked(orders);
        Ok(())
    }

    /// Absorb `fills`, then subtract `orders` without checking first.
    ///
    /// # Errors
    ///
    /// Same as [`unchecked_remove_aggregate`](Self::unchecked_remove_aggregate).
    /// `fills` is absorbed before any error can occur, and that absorption is
    /// not rolled back.
    pub fn unchecked_remove_ru_tx_events(
        &mut self,
        fills: &AggregateFills,
        orders: &AggregateOrders,
    ) -> Result<(), MarketError> {
        self.absorb(fills);
        self.unchecked_remove_aggregate(orders)
    }
}
/// Read-only pairing of two fill contexts whose balances are summed when
/// checking orders (see `AggregateFills::check_ru_tx_events`).
struct CombinedContext<'ctx, 'ext> {
    /// The primary fill context.
    context: &'ctx AggregateFills,
    /// Additional fills whose balances are added on top of `context`.
    extra: &'ext AggregateFills,
}
impl CombinedContext<'_, '_> {
    /// Combined balance of `recipient` for `output_asset`: the sum of both
    /// contexts' recorded fills, saturating at `U256::MAX`.
    fn balance(&self, output_asset: &(u64, Address), recipient: Address) -> U256 {
        let primary = self.context.filled(output_asset, recipient);
        let additional = self.extra.filled(output_asset, recipient);
        primary.saturating_add(additional)
    }

    /// Succeed when the combined balance covers `amount`; otherwise report
    /// `MarketError::InsufficientBalance` carrying the requested `amount`.
    fn check_filled(
        &self,
        output_asset: &(u64, Address),
        recipient: Address,
        amount: U256,
    ) -> Result<(), MarketError> {
        if amount <= self.balance(output_asset, recipient) {
            return Ok(());
        }
        Err(MarketError::InsufficientBalance {
            chain_id: output_asset.0,
            asset: output_asset.1,
            recipient,
            amount,
        })
    }

    /// Check every `(asset, recipient, amount)` output of `aggregate` against
    /// the combined balance, stopping at the first shortfall. Assets with no
    /// fills in either context read as zero rather than as missing.
    fn check_aggregate(&self, aggregate: &AggregateOrders) -> Result<(), MarketError> {
        aggregate.outputs.iter().try_for_each(|(output_asset, recipients)| {
            recipients.iter().try_for_each(|(recipient, amount)| {
                self.check_filled(output_asset, *recipient, *amount)
            })
        })
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use signet_zenith::RollupOrders::{Filled, Order, Output};

    /// Shorthand for `Address::with_last_byte`.
    fn addr(last: u8) -> Address {
        Address::with_last_byte(last)
    }

    /// Raw recorded amount for `recipient` under `(chain_id, asset)`.
    /// Panics if either map level lacks an entry, so callers also prove that
    /// the entry still exists.
    fn recorded(ctx: &AggregateFills, chain_id: u64, asset: Address, recipient: Address) -> U256 {
        *ctx.fills().get(&(chain_id, asset)).unwrap().get(&recipient).unwrap()
    }

    /// Build an aggregate containing exactly one output.
    fn single_output(
        chain_id: u64,
        asset: Address,
        recipient: Address,
        amount: U256,
    ) -> AggregateOrders {
        let mut aggregate = AggregateOrders::new();
        aggregate.ingest_raw_output(chain_id, asset, recipient, amount);
        aggregate
    }

    #[test]
    fn basic_fills() {
        let (user_a, user_b) = (addr(1), addr(2));
        let (asset_a, asset_b) = (addr(3), addr(4));

        let outputs = vec![
            Output { token: asset_a, amount: U256::from(100), recipient: user_a, chainId: 1 },
            Output { token: asset_b, amount: U256::from(200), recipient: user_b, chainId: 1 },
            Output { token: asset_a, amount: U256::from(300), recipient: user_b, chainId: 1 },
        ];
        let fill = Filled { outputs: outputs.clone() };
        let order = Order { deadline: U256::ZERO, inputs: vec![], outputs };

        let mut context = AggregateFills::default();
        context.add_fill(1, &fill);
        assert_eq!(context.fills().len(), 2);

        let expected = [(asset_a, user_a, 100u64), (asset_b, user_b, 200), (asset_a, user_b, 300)];
        for (asset, user, amount) in expected {
            assert_eq!(recorded(&context, 1, asset, user), U256::from(amount));
        }

        context.checked_remove_order(&order).unwrap();
        assert_eq!(context.fills().len(), 2);
        for (asset, user, _) in expected {
            assert_eq!(recorded(&context, 1, asset, user), U256::from(0));
        }
    }

    #[test]
    fn empty_everything() {
        let mut context = AggregateFills::default();
        context.checked_remove_ru_tx_events(&Default::default(), &Default::default()).unwrap();
    }

    #[test]
    fn absorb_unabsorb() {
        let (user, asset) = (addr(1), addr(2));
        let mut context_a = AggregateFills::new();
        context_a.add_raw_fill(1, asset, user, U256::from(100));
        let mut context_b = AggregateFills::new();
        context_b.add_raw_fill(1, asset, user, U256::from(200));

        let snapshot = context_a.clone();
        context_a.absorb(&context_b);
        assert_eq!(context_a.filled(&(1, asset), user), U256::from(300));
        context_a.unchecked_unabsorb(&context_b).unwrap();
        assert_eq!(context_a, snapshot);
    }

    #[test]
    fn combined_context_saturates_on_overflow() {
        let (user, asset) = (addr(1), addr(2));
        let mut context = AggregateFills::new();
        context.add_raw_fill(1, asset, user, U256::MAX - U256::from(100));
        let mut extra = AggregateFills::new();
        extra.add_raw_fill(1, asset, user, U256::from(200));

        let orders = single_output(1, asset, user, U256::from(150));
        context.check_ru_tx_events(&extra, &orders).unwrap();
    }

    #[test]
    fn combined_context_near_max_no_overflow() {
        let (user, asset) = (addr(1), addr(2));
        let half_max = U256::MAX / U256::from(2);
        let mut context = AggregateFills::new();
        context.add_raw_fill(1, asset, user, half_max);
        let mut extra = AggregateFills::new();
        extra.add_raw_fill(1, asset, user, U256::from(100));

        let orders = single_output(1, asset, user, half_max + U256::from(50));
        context.check_ru_tx_events(&extra, &orders).unwrap();
    }

    #[test]
    fn fill_saturates_at_max() {
        let key = (1u64, Address::ZERO);
        let mut context = AggregateFills::new();
        context.add_raw_fill(key.0, key.1, Address::ZERO, U256::MAX);
        context.add_raw_fill(key.0, key.1, Address::ZERO, U256::from(1));
        assert_eq!(context.filled(&key, Address::ZERO), U256::MAX);
    }

    #[test]
    fn remove_max_from_max() {
        let mut context = AggregateFills::new();
        context.add_raw_fill(1, Address::ZERO, Address::ZERO, U256::MAX);

        let aggregate = single_output(1, Address::ZERO, Address::ZERO, U256::MAX);
        context.checked_remove_aggregate(&aggregate).unwrap();
        assert_eq!(context.filled(&(1, Address::ZERO), Address::ZERO), U256::ZERO);
    }

    #[test]
    fn absorb_saturates_at_max() {
        let mut target = AggregateFills::new();
        target.add_raw_fill(1, Address::ZERO, Address::ZERO, U256::MAX);
        let mut source = AggregateFills::new();
        source.add_raw_fill(1, Address::ZERO, Address::ZERO, U256::from(1000));

        target.absorb(&source);
        assert_eq!(target.filled(&(1, Address::ZERO), Address::ZERO), U256::MAX);
    }

    #[test]
    fn unchecked_remove_aggregate_errors_on_missing_recipient() {
        let asset = addr(1);
        let mut context = AggregateFills::new();
        context.add_raw_fill(1, asset, addr(99), U256::from(100));

        let aggregate = single_output(1, asset, addr(2), U256::from(50));
        let result = context.unchecked_remove_aggregate(&aggregate);
        assert!(matches!(result, Err(MarketError::InsufficientBalance { .. })));
    }

    #[test]
    fn checked_remove_handles_missing_recipient() {
        let asset = addr(1);
        let mut context = AggregateFills::new();
        context.add_raw_fill(1, asset, addr(99), U256::from(100));

        let aggregate = single_output(1, asset, addr(2), U256::from(50));
        let result = context.checked_remove_aggregate(&aggregate);
        assert!(matches!(result, Err(MarketError::InsufficientBalance { .. })));
    }

    #[test]
    fn insufficient_balance_error_fields() {
        let err = AggregateFills::new()
            .check_filled(&(42, addr(1)), addr(2), U256::from(100))
            .unwrap_err();
        assert!(matches!(err, MarketError::InsufficientBalance { chain_id: 42, .. }));
    }

    #[test]
    fn missing_asset_error_fields() {
        let aggregate = single_output(42, addr(1), Address::ZERO, U256::from(100));
        let result = AggregateFills::new().check_aggregate(&aggregate);
        assert!(matches!(result, Err(MarketError::MissingAsset { chain_id: 42, .. })));
    }

    #[test]
    fn fills_across_multiple_chains() {
        let (user, asset) = (addr(1), addr(2));
        let per_chain = [(1u64, 100u64), (10, 200), (42161, 300)];

        let mut context = AggregateFills::new();
        for (chain_id, amount) in per_chain {
            context.add_raw_fill(chain_id, asset, user, U256::from(amount));
        }
        for (chain_id, amount) in per_chain {
            assert_eq!(context.filled(&(chain_id, asset), user), U256::from(amount));
        }
    }

    #[test]
    fn remove_from_wrong_chain_fails() {
        let (user, asset) = (addr(1), addr(2));
        let mut context = AggregateFills::new();
        context.add_raw_fill(1, asset, user, U256::from(100));

        let aggregate = single_output(10, asset, user, U256::from(50));
        let result = context.checked_remove_aggregate(&aggregate);
        assert!(matches!(result, Err(MarketError::MissingAsset { chain_id: 10, .. })));
    }

    #[test]
    fn clone_equality() {
        let mut context = AggregateFills::new();
        context.add_raw_fill(1, addr(1), addr(2), U256::from(12345));
        context.add_raw_fill(10, addr(3), addr(2), U256::MAX);
        let duplicate = context.clone();
        assert_eq!(duplicate, context);
    }
}
#[cfg(all(test, feature = "proptest"))]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Arbitrary `U256` with zero remapped to one, so every value is non-zero.
    fn nonzero_u256() -> impl Strategy<Value = U256> {
        any::<U256>().prop_map(|raw| std::cmp::max(raw, U256::from(1)))
    }

    proptest! {
        #[test]
        fn add_then_remove_returns_to_zero(
            chain_id in 0u64..100,
            asset_byte in 0u8..=255,
            recipient_byte in 0u8..=255,
            amount in nonzero_u256(),
        ) {
            let key = (chain_id, Address::with_last_byte(asset_byte));
            let recipient = Address::with_last_byte(recipient_byte);

            let mut context = AggregateFills::new();
            context.add_raw_fill(key.0, key.1, recipient, amount);

            let mut aggregate = AggregateOrders::new();
            aggregate.ingest_raw_output(key.0, key.1, recipient, amount);
            context.checked_remove_aggregate(&aggregate).unwrap();

            prop_assert_eq!(context.filled(&key, recipient), U256::ZERO);
        }

        #[test]
        fn fill_addition_is_commutative(
            chain_id in 0u64..100,
            asset_byte in 0u8..=255,
            recipient_byte in 0u8..=255,
            amount_a in any::<U256>(),
            amount_b in any::<U256>(),
        ) {
            let asset = Address::with_last_byte(asset_byte);
            let recipient = Address::with_last_byte(recipient_byte);

            // Apply two fills in the given order to a fresh context.
            let build = |first: U256, second: U256| {
                let mut ctx = AggregateFills::new();
                ctx.add_raw_fill(chain_id, asset, recipient, first);
                ctx.add_raw_fill(chain_id, asset, recipient, second);
                ctx
            };

            prop_assert_eq!(build(amount_a, amount_b), build(amount_b, amount_a));
        }
    }
}