use std::collections::{HashMap, HashSet, VecDeque};
use num_bigint::BigUint;
use num_traits::{ToPrimitive, Zero};
use tycho_simulation::tycho_common::{
dto::ProtocolStateDelta,
models::token::Token,
simulation::{
errors::{SimulationError, TransitionError},
protocol_sim::{Balances, GetAmountOutResult, ProtocolSim},
},
Bytes,
};
use crate::{
algorithm::AlgorithmError,
feed::market_data::MarketState,
types::{ComponentId, Order, Route, Swap},
};
/// One leg of a path: sell `token_in` for `token_out` on `component_id`.
///
/// `amount_out` and `gas` start out as `None` and are attached with
/// [`HopDescriptor::with_amounts`]; route building requires both to be set.
pub(crate) struct HopDescriptor {
    /// Component (pool) that executes this hop.
    pub(crate) component_id: ComponentId,
    /// Token sold into the component.
    pub(crate) token_in: Token,
    /// Token bought from the component.
    pub(crate) token_out: Token,
    /// Output amount attributed to this hop, if attached.
    pub(crate) amount_out: Option<BigUint>,
    /// Gas attributed to this hop, if attached.
    pub(crate) gas: Option<BigUint>,
}
impl HopDescriptor {
    /// Creates a hop without any amounts attached.
    pub(crate) fn new(component_id: ComponentId, token_in: Token, token_out: Token) -> Self {
        let (amount_out, gas) = (None, None);
        Self { component_id, token_in, token_out, amount_out, gas }
    }
    /// Returns the hop with `amount_out` and `gas` set, replacing any previous values.
    pub(crate) fn with_amounts(self, amount_out: BigUint, gas: BigUint) -> Self {
        Self { amount_out: Some(amount_out), gas: Some(gas), ..self }
    }
}
/// A path together with the share of the order that is routed through it.
pub(crate) struct PathAllocation {
    /// Hops of the path, in the order they are executed.
    pub(crate) hops: Vec<HopDescriptor>,
    /// Share of the total order flow sent down this path; shares of paths that
    /// use the same hop are added together when hops are merged.
    pub(crate) flow_fraction: f64,
    /// Amount of the input token sent into the first hop.
    pub(crate) amount_in: BigUint,
    /// Amount produced by the path.
    pub(crate) amount_out: BigUint,
    /// Product of the per-hop spot prices along the path (presumably the value
    /// returned by `compute_marginal_price_product` — confirm at construction sites).
    pub(crate) marginal_price_product: f64,
}
/// Outcome of simulating a single path with `simulate_path`.
pub(crate) struct SimResult {
    /// Amount produced by the last hop.
    pub(crate) amount_out: BigUint,
    /// Saturating sum of the gas reported by every hop.
    pub(crate) gas: u64,
    /// Product of the hops' spot prices, taken before the swap is applied.
    pub(crate) marginal_price_product: f64,
}
#[derive(Default)]
pub(crate) struct MarketOverrides(HashMap<ComponentId, Box<dyn ProtocolSim>>);
impl MarketOverrides {
pub(crate) fn empty() -> Self {
Self::default()
}
pub(crate) fn with_override(mut self, id: ComponentId, sim: Box<dyn ProtocolSim>) -> Self {
self.0.insert(id, sim);
self
}
pub(crate) fn with_zero_gas(mut self, id: ComponentId, sim: Box<dyn ProtocolSim>) -> Self {
self.0
.insert(id, Box::new(ZeroGasSim(sim)));
self
}
pub(crate) fn get(&self, id: &ComponentId) -> Option<&dyn ProtocolSim> {
self.0.get(id).map(|b| b.as_ref())
}
}
/// Wraps `sim` so its swaps report zero gas; pricing and state handling are
/// delegated to `sim` unchanged.
pub(crate) fn wrap_zero_gas(sim: Box<dyn ProtocolSim>) -> Box<dyn ProtocolSim> {
    let wrapped = ZeroGasSim(sim);
    Box::new(wrapped)
}
/// `ProtocolSim` decorator that forwards every call to the inner simulation,
/// except that `get_amount_out` reports zero gas.
///
/// The post-swap state returned by `get_amount_out` is wrapped again, so
/// chained simulations keep reporting zero gas. The type name must not change:
/// `typetag` uses it as the serialization tag.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct ZeroGasSim(Box<dyn ProtocolSim>);
#[typetag::serde]
impl ProtocolSim for ZeroGasSim {
    fn fee(&self) -> f64 {
        let Self(inner) = self;
        inner.fee()
    }
    fn spot_price(&self, base: &Token, quote: &Token) -> Result<f64, SimulationError> {
        let Self(inner) = self;
        inner.spot_price(base, quote)
    }
    fn get_amount_out(
        &self,
        amount_in: BigUint,
        token_in: &Token,
        token_out: &Token,
    ) -> Result<GetAmountOutResult, SimulationError> {
        let Self(inner) = self;
        let mut swapped = inner.get_amount_out(amount_in, token_in, token_out)?;
        // Keep the zero-gas behaviour for whoever simulates on the new state.
        swapped.new_state = Box::new(ZeroGasSim(swapped.new_state));
        swapped.gas = BigUint::ZERO;
        Ok(swapped)
    }
    fn get_limits(
        &self,
        sell_token: Bytes,
        buy_token: Bytes,
    ) -> Result<(BigUint, BigUint), SimulationError> {
        let Self(inner) = self;
        inner.get_limits(sell_token, buy_token)
    }
    fn delta_transition(
        &mut self,
        delta: ProtocolStateDelta,
        tokens: &HashMap<Bytes, Token>,
        balances: &Balances,
    ) -> Result<(), TransitionError> {
        let Self(inner) = self;
        inner.delta_transition(delta, tokens, balances)
    }
    fn clone_box(&self) -> Box<dyn ProtocolSim> {
        let Self(inner) = self;
        Box::new(ZeroGasSim(inner.clone_box()))
    }
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
    /// Equal only to another `ZeroGasSim` whose inner simulation compares equal.
    fn eq(&self, other: &dyn ProtocolSim) -> bool {
        match other.as_any().downcast_ref::<Self>() {
            Some(that) => self.0.eq(&*that.0),
            None => false,
        }
    }
}
/// Maximizes a unimodal function `f` on `[lo, hi]` using golden-section search.
///
/// Two interior probes are always evaluated, so `f` is called
/// `max(max_evals, 2)` times in total. Each iteration discards the part of the
/// interval that cannot contain the maximum and reuses one previous probe.
/// The function returns whichever of the last two probe points scored higher.
/// Ties, and comparisons where `f1 >= f2` is true, return the left probe.
pub(crate) fn golden_section_search(
    f: impl Fn(f64) -> f64,
    mut lo: f64,
    mut hi: f64,
    max_evals: usize,
) -> f64 {
    let ratio = (5_f64.sqrt() - 1.0) / 2.0;
    let (mut left, mut right) = (hi - ratio * (hi - lo), lo + ratio * (hi - lo));
    let (mut f_left, mut f_right) = (f(left), f(right));
    let mut budget = max_evals.saturating_sub(2);
    while budget > 0 {
        budget -= 1;
        if f_right > f_left {
            // The maximum lies in [left, hi]: slide the window right, reuse `right`.
            lo = left;
            left = right;
            f_left = f_right;
            right = lo + ratio * (hi - lo);
            f_right = f(right);
        } else {
            // The maximum lies in [lo, right]: slide the window left, reuse `left`.
            hi = right;
            right = left;
            f_right = f_left;
            left = hi - ratio * (hi - lo);
            f_left = f(left);
        }
    }
    if f_left >= f_right {
        left
    } else {
        right
    }
}
/// Splits `total` into `(part, remainder)` with `part ≈ total * fraction`.
///
/// `fraction` is clamped to `[0.0, 1.0]` and turned into an 18-decimal
/// fixed-point integer, so the split itself is exact integer arithmetic.
/// `part` is rounded down and `part + remainder == total` always holds.
/// A NaN fraction stays NaN after clamping and casts to 0, giving `part == 0`.
pub(crate) fn split_amount(total: &BigUint, fraction: f64) -> (BigUint, BigUint) {
    const SCALE: u64 = 1_000_000_000_000_000_000;
    let fixed_point = (fraction.clamp(0.0, 1.0) * SCALE as f64) as u64;
    let part = total * BigUint::from(fixed_point) / BigUint::from(SCALE);
    let remainder = total - &part;
    (part, remainder)
}
/// Errors reported by the fraction/amount helpers in this module.
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub(crate) enum SplitMathError {
    /// No fractions were supplied.
    #[error("fractions slice must not be empty")]
    EmptyFractions,
    /// The fractions sum to zero, so they cannot be rescaled.
    #[error("all fractions are zero, cannot normalize")]
    AllZeroFractions,
    /// At least one fraction is below zero.
    #[error("fractions must not be negative")]
    NegativeFraction,
}
/// Rescales `fractions` in place so they sum to 1.0.
///
/// # Errors
/// - [`SplitMathError::EmptyFractions`] when the slice is empty.
/// - [`SplitMathError::NegativeFraction`] when any entry is below zero; the
///   slice is left untouched.
/// - [`SplitMathError::AllZeroFractions`] when the entries sum to zero.
///
/// NaN entries are not rejected and propagate into the result.
pub(crate) fn normalize_fractions(fractions: &mut [f64]) -> Result<(), SplitMathError> {
    let rejection = if fractions.is_empty() {
        Some(SplitMathError::EmptyFractions)
    } else if fractions.iter().copied().any(|value| value < 0.0) {
        Some(SplitMathError::NegativeFraction)
    } else {
        None
    };
    if let Some(error) = rejection {
        return Err(error);
    }
    let total: f64 = fractions.iter().sum();
    if total == 0.0 {
        return Err(SplitMathError::AllZeroFractions);
    }
    fractions.iter_mut().for_each(|value| *value /= total);
    Ok(())
}
/// Splits `total` into one amount per entry of `fractions`.
///
/// Every entry except the last receives `split_amount(total, frac)`. The last
/// entry receives whatever is left, so the amounts always sum exactly to `total`.
///
/// Fractions are expected to be normalized (summing to 1.0), but the function
/// never panics when they are not. Each part is capped at the amount still
/// unallocated. Fractions that sum above 1.0 therefore saturate: later parts
/// shrink, possibly to zero, and the last entry gets the remainder, which is
/// never negative.
///
/// # Errors
/// Returns [`SplitMathError::EmptyFractions`] when `fractions` is empty.
pub(crate) fn fractions_to_amounts(
    total: &BigUint,
    fractions: &[f64],
) -> Result<Vec<BigUint>, SplitMathError> {
    let Some((_, leading)) = fractions.split_last() else {
        return Err(SplitMathError::EmptyFractions);
    };
    let mut amounts = Vec::with_capacity(fractions.len());
    let mut remaining = total.clone();
    for &frac in leading {
        let (mut part, _) = split_amount(total, frac);
        // Without this cap, leading fractions summing above 1.0 (un-normalized
        // input, or f64 rounding in the 1e18 fixed-point scaling) would make the
        // final subtraction underflow, which panics for BigUint.
        if part > remaining {
            part = remaining.clone();
        }
        remaining -= &part;
        amounts.push(part);
    }
    amounts.push(remaining);
    Ok(amounts)
}
/// Multiplies `spot_price(token_in, token_out)` over every hop.
///
/// Each component's state is taken from `overrides` when present, otherwise
/// from `market`. An empty `hops` slice yields `1.0`.
///
/// # Errors
/// - `AlgorithmError::DataNotFound` when neither source has a state for a hop.
/// - `AlgorithmError::SimulationFailed` when a spot-price query fails.
pub(crate) fn compute_marginal_price_product(
    hops: &[HopDescriptor],
    market: &MarketState,
    overrides: &MarketOverrides,
) -> Result<f64, AlgorithmError> {
    hops.iter().try_fold(1.0, |product: f64, hop| -> Result<f64, AlgorithmError> {
        let state = overrides
            .get(&hop.component_id)
            .or_else(|| market.get_simulation_state(&hop.component_id))
            .ok_or_else(|| AlgorithmError::DataNotFound {
                kind: "simulation state",
                id: Some(hop.component_id.clone()),
            })?;
        let hop_price = state
            .spot_price(&hop.token_in, &hop.token_out)
            .map_err(|e| AlgorithmError::SimulationFailed {
                component_id: hop.component_id.clone(),
                error: e.to_string(),
            })?;
        Ok(product * hop_price)
    })
}
/// Simulates `amount_in` through `hops`, feeding each hop's output into the next.
///
/// Every hop is simulated against its state from `overrides`, or from `market`
/// when no override exists. States are not advanced between hops, so a pool
/// visited twice is priced from the same state both times. Gas is summed with
/// saturation, and a hop whose gas does not fit in `u64` counts as `u64::MAX`.
/// The returned marginal price product is computed from the same pre-swap states.
///
/// # Errors
/// - `AlgorithmError::DataNotFound` when a hop has no simulation state.
/// - `AlgorithmError::SimulationFailed` when a hop's simulation fails.
pub(crate) fn simulate_path(
    hops: &[HopDescriptor],
    amount_in: &BigUint,
    market: &MarketState,
    overrides: &MarketOverrides,
) -> Result<SimResult, AlgorithmError> {
    let mut amount = amount_in.clone();
    let mut gas_used = 0_u64;
    for hop in hops {
        let Some(state) = overrides
            .get(&hop.component_id)
            .or_else(|| market.get_simulation_state(&hop.component_id))
        else {
            return Err(AlgorithmError::DataNotFound {
                kind: "simulation state",
                id: Some(hop.component_id.clone()),
            });
        };
        let swapped = state
            .get_amount_out(amount, &hop.token_in, &hop.token_out)
            .map_err(|e| AlgorithmError::SimulationFailed {
                component_id: hop.component_id.clone(),
                error: e.to_string(),
            })?;
        let hop_gas = swapped.gas.to_u64().unwrap_or(u64::MAX);
        gas_used = gas_used.saturating_add(hop_gas);
        amount = swapped.amount;
    }
    Ok(SimResult {
        marginal_price_product: compute_marginal_price_product(hops, market, overrides)?,
        amount_out: amount,
        gas: gas_used,
    })
}
/// Splits `total_amount` across `paths` according to `fractions` and returns
/// the combined output amount and gas.
///
/// `fractions[i]` is the share of `total_amount` sent down `paths[i]` (see
/// `fractions_to_amounts`; the last path receives the exact remainder).
///
/// Each path is simulated independently against the same starting states
/// (overrides first, then `market`), so paths do not see each other's price
/// impact. Paths that receive a zero amount are skipped. Gas is counted once
/// per distinct `(component, token_in, token_out)` hop across all paths. Gas
/// values that do not fit in `u64` count as `u64::MAX`, and the total saturates.
///
/// # Errors
/// - `AlgorithmError::Other` when `fractions` is empty, or when its length
///   differs from `paths`. A mismatch would otherwise silently drop either the
///   trailing flow or the trailing paths.
/// - `AlgorithmError::DataNotFound` / `AlgorithmError::SimulationFailed` when a
///   hop's state is missing or its simulation fails.
pub(crate) fn evaluate_total_output(
    paths: &[&[HopDescriptor]],
    fractions: &[f64],
    total_amount: &BigUint,
    market: &MarketState,
    overrides: &MarketOverrides,
) -> Result<(BigUint, u64), AlgorithmError> {
    let amounts = fractions_to_amounts(total_amount, fractions)
        .map_err(|e| AlgorithmError::Other(e.to_string()))?;
    if paths.len() != amounts.len() {
        return Err(AlgorithmError::Other(format!(
            "path count ({}) does not match fraction count ({})",
            paths.len(),
            fractions.len()
        )));
    }
    let mut total_out = BigUint::zero();
    let mut total_gas: u64 = 0;
    // Hops already charged for gas, so that shared hops are only paid once.
    let mut seen_hops: HashSet<(ComponentId, Bytes, Bytes)> = HashSet::new();
    for (path, amount) in paths.iter().zip(amounts.iter()) {
        if amount.is_zero() {
            continue;
        }
        let mut current_amount = amount.clone();
        for hop in path.iter() {
            let sim = overrides
                .get(&hop.component_id)
                .or_else(|| market.get_simulation_state(&hop.component_id))
                .ok_or_else(|| AlgorithmError::DataNotFound {
                    kind: "simulation state",
                    id: Some(hop.component_id.clone()),
                })?;
            let result = sim
                .get_amount_out(current_amount, &hop.token_in, &hop.token_out)
                .map_err(|e| AlgorithmError::SimulationFailed {
                    component_id: hop.component_id.clone(),
                    error: e.to_string(),
                })?;
            let hop_key = (
                hop.component_id.clone(),
                hop.token_in.address.clone(),
                hop.token_out.address.clone(),
            );
            if seen_hops.insert(hop_key) {
                total_gas = total_gas.saturating_add(result.gas.to_u64().unwrap_or(u64::MAX));
            }
            current_amount = result.amount;
        }
        total_out += current_amount;
    }
    Ok((total_out, total_gas))
}
/// Replays every allocated path and records the resulting pool states as overrides.
///
/// Paths are replayed in order, starting from each path's `amount_in`. A hop
/// is simulated from the latest override of its component if one exists, else
/// from `market`. Pools used by several hops therefore accumulate the effect of
/// all earlier swaps.
///
/// This is best-effort. When a hop has no state, or its simulation fails, the
/// rest of that path is skipped. Overrides recorded so far are kept.
pub(crate) fn build_post_swap_overrides(
    paths: &[&PathAllocation],
    market: &MarketState,
) -> MarketOverrides {
    let mut overrides = MarketOverrides::empty();
    for allocation in paths {
        let mut amount = allocation.amount_in.clone();
        for hop in &allocation.hops {
            let Some(state) = overrides
                .get(&hop.component_id)
                .or_else(|| market.get_simulation_state(&hop.component_id))
            else {
                break;
            };
            let outcome = match state.get_amount_out(amount, &hop.token_in, &hop.token_out) {
                Ok(outcome) => outcome,
                Err(_) => break,
            };
            amount = outcome.amount;
            overrides = overrides.with_override(hop.component_id.clone(), outcome.new_state);
        }
    }
    overrides
}
/// A hop after merging all paths that use it, before it becomes a `Swap`.
struct SplitSwap {
    /// The merged hop (without per-path amounts attached).
    hop: HopDescriptor,
    /// Summed `flow_fraction` of the paths using this hop. `assign_splits_and_amounts`
    /// later rewrites it to the normalized share within its token, with 0.0 for
    /// the remainder leg.
    split: f64,
    /// Amount of `hop.token_in` assigned to this swap; zero until assigned.
    amount_in: BigUint,
    /// Summed `amount_out` of the merged hops.
    amount_out: BigUint,
    /// Gas taken from the first path that contributed this hop.
    gas: BigUint,
}
/// Merges hops shared by several paths and groups the result by input token.
///
/// Hops are identified by `(component_id, token_in, token_out)`. Merging adds
/// up the paths' `flow_fraction`s into `split` and their `amount_out`s; gas is
/// taken from the first occurrence. Each token's collection is sorted by
/// descending `split`, with ties broken by component id.
///
/// The sort uses `f64::total_cmp`, which is a total order. A `partial_cmp`
/// fallback is not: with NaN values, `sort_by` may panic on recent Rust
/// versions. The tie-break makes the order deterministic instead of following
/// `HashMap` iteration order. The order matters because the last swap in each
/// collection becomes the remainder leg.
///
/// # Errors
/// `AlgorithmError::DataNotFound` when a hop lacks `amount_out` or `gas`.
fn merge_shared_hops(
    paths: &[PathAllocation],
) -> Result<HashMap<Bytes, Vec<SplitSwap>>, AlgorithmError> {
    type HopKey = (ComponentId, Bytes, Bytes);
    let mut hops: HashMap<HopKey, SplitSwap> = HashMap::new();
    for path in paths {
        for hop in &path.hops {
            let key: HopKey = (
                hop.component_id.clone(),
                hop.token_in.address.clone(),
                hop.token_out.address.clone(),
            );
            let hop_amount_out =
                hop.amount_out
                    .clone()
                    .ok_or_else(|| AlgorithmError::DataNotFound {
                        kind: "hop amount_out",
                        id: Some(hop.component_id.clone()),
                    })?;
            let hop_gas = hop
                .gas
                .clone()
                .ok_or_else(|| AlgorithmError::DataNotFound {
                    kind: "hop gas",
                    id: Some(hop.component_id.clone()),
                })?;
            hops.entry(key)
                .and_modify(|h| {
                    h.split += path.flow_fraction;
                    h.amount_out += &hop_amount_out;
                })
                // Lazily build the entry so tokens are only cloned for new hops.
                .or_insert_with(|| SplitSwap {
                    hop: HopDescriptor::new(
                        hop.component_id.clone(),
                        hop.token_in.clone(),
                        hop.token_out.clone(),
                    ),
                    split: path.flow_fraction,
                    amount_in: BigUint::ZERO,
                    amount_out: hop_amount_out,
                    gas: hop_gas,
                });
        }
    }
    let mut branch_collections: HashMap<Bytes, Vec<SplitSwap>> = HashMap::new();
    for swap in hops.into_values() {
        branch_collections
            .entry(swap.hop.token_in.address.clone())
            .or_default()
            .push(swap);
    }
    for branch_collection in branch_collections.values_mut() {
        branch_collection.sort_by(|a, b| {
            b.split
                .total_cmp(&a.split)
                .then_with(|| a.hop.component_id.cmp(&b.hop.component_id))
        });
    }
    Ok(branch_collections)
}
/// Distributes `total_available` across one token's outgoing swaps in
/// proportion to their `split` weights.
///
/// Each swap's `amount_in` receives its share. The last swap takes the exact
/// remainder, so the amounts sum to `total_available`. `split` is rewritten to
/// the normalized share, except for the last swap, whose split is set to `0.0`
/// to mark it as the remainder leg.
///
/// When the weights sum to zero, or to a non-finite value, normalizing would
/// produce NaN shares. In that case the amount is divided evenly instead.
/// An empty input is returned unchanged.
fn assign_splits_and_amounts(
    mut hops: Vec<SplitSwap>,
    total_available: &BigUint,
) -> Vec<SplitSwap> {
    let Some(last_index) = hops.len().checked_sub(1) else {
        return hops;
    };
    let fraction_total: f64 = hops.iter().map(|h| h.split).sum();
    let normalized: Vec<f64> = if fraction_total.is_finite() && fraction_total > 0.0 {
        hops.iter().map(|h| h.split / fraction_total).collect()
    } else {
        vec![1.0 / hops.len() as f64; hops.len()]
    };
    let amounts = fractions_to_amounts(total_available, &normalized)
        .unwrap_or_else(|_| vec![total_available.clone()]);
    for (i, (swap, amount)) in hops.iter_mut().zip(amounts).enumerate() {
        swap.amount_in = amount;
        swap.split = if i == last_index { 0.0 } else { normalized[i] };
    }
    hops
}
pub(crate) fn build_split_route(
paths: &[PathAllocation],
market: &MarketState,
order: &Order,
) -> Result<Route, AlgorithmError> {
let mut hops_by_token = merge_shared_hops(paths)?;
let mut pending_tokens = VecDeque::new();
pending_tokens.push_back(order.token_in().clone());
let mut available: HashMap<Bytes, BigUint> = HashMap::new();
available.insert(order.token_in().clone(), order.amount().clone());
let mut swaps = Vec::new();
let mut route_tokens: HashMap<Bytes, Token> = HashMap::new();
let mut visited: HashSet<Bytes> = HashSet::new();
while let Some(token_addr) = pending_tokens.pop_front() {
if !visited.insert(token_addr.clone()) {
continue;
}
let Some(branch_collection) = hops_by_token.remove(&token_addr) else {
continue;
};
let total = available
.get(&token_addr)
.cloned()
.unwrap_or_default();
for split_swap in assign_splits_and_amounts(branch_collection, &total) {
let sim = market
.get_simulation_state(&split_swap.hop.component_id)
.ok_or_else(|| AlgorithmError::DataNotFound {
kind: "simulation state",
id: Some(split_swap.hop.component_id.clone()),
})?;
let component = market
.get_component(&split_swap.hop.component_id)
.ok_or_else(|| AlgorithmError::DataNotFound {
kind: "protocol component",
id: Some(split_swap.hop.component_id.clone()),
})?;
let in_addr = split_swap.hop.token_in.address.clone();
let out_addr = split_swap.hop.token_out.address.clone();
*available
.entry(out_addr.clone())
.or_default() += &split_swap.amount_out;
swaps.push(
Swap::new(
split_swap.hop.component_id,
component.protocol_system.clone(),
in_addr.clone(),
out_addr.clone(),
split_swap.amount_in,
split_swap.amount_out,
split_swap.gas,
component.clone(),
sim.clone_box(),
)
.with_split(split_swap.split),
);
route_tokens
.entry(in_addr)
.or_insert(split_swap.hop.token_in);
route_tokens
.entry(out_addr.clone())
.or_insert(split_swap.hop.token_out);
if !visited.contains(&out_addr) {
pending_tokens.push_back(out_addr);
}
}
}
Ok(Route::new(swaps, route_tokens))
}
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use super::*;
    use crate::{
        algorithm::test_utils::{component, order, token, ConstantProductSim, MockProtocolSim},
        types::OrderSide,
    };
    /// Builds a `MarketState` with one component, state and token set per pool entry.
    fn make_market(pools: Vec<(&str, Vec<Token>, Box<dyn ProtocolSim>)>) -> MarketState {
        let mut market = MarketState::new();
        for (pool_id, tokens, sim) in pools {
            market.upsert_components(std::iter::once(component(pool_id, &tokens)));
            market.update_states([(pool_id.to_string(), sim)]);
            market.upsert_tokens(tokens);
        }
        market
    }
    #[test]
    fn test_split_amount_exact_sum() {
        let total = BigUint::from(1_000_000_000_000_000_000_u64);
        for fraction in [0.1, 0.5, 0.9, 0.999] {
            let (part, remainder) = split_amount(&total, fraction);
            assert_eq!(
                &part + &remainder,
                total,
                "part + remainder must equal total for fraction={fraction}"
            );
        }
    }
    #[test]
    fn test_split_amount_edge_fraction_zero() {
        let total = BigUint::from(1_000_000_000_000_000_000_u64);
        let (part, remainder) = split_amount(&total, 0.0);
        assert!(part.is_zero());
        assert_eq!(remainder, total);
    }
    #[test]
    fn test_split_amount_clamps_above_one() {
        let total = BigUint::from(1_000_000_000_000_000_000_u64);
        let (part, remainder) = split_amount(&total, 1.5);
        assert_eq!(part, total);
        assert!(remainder.is_zero());
    }
    #[test]
    fn test_split_amount_clamps_negative() {
        let total = BigUint::from(1_000_000_000_000_000_000_u64);
        let (part, remainder) = split_amount(&total, -0.5);
        assert!(part.is_zero());
        assert_eq!(remainder, total);
    }
    #[test]
    fn test_fractions_to_amounts_exact_sum() {
        let total = BigUint::from(999_999_999_999_999_999_u64);
        let fractions = [0.3, 0.5, 0.2];
        let amounts = fractions_to_amounts(&total, &fractions).unwrap();
        assert_eq!(amounts.len(), 3);
        let sum: BigUint = amounts.iter().sum();
        assert_eq!(sum, total, "amounts must sum exactly to total");
    }
    #[test]
    fn test_fractions_to_amounts_empty() {
        let total = BigUint::from(1_000_u64);
        let err = fractions_to_amounts(&total, &[]).unwrap_err();
        assert_eq!(err, SplitMathError::EmptyFractions);
    }
    // Normalized sums can land one or two ulps away from 1.0, so use a tolerance
    // well above f64::EPSILON instead of EPSILON itself.
    #[rstest]
    #[case::already_normalized(&[0.3, 0.5, 0.2])]
    #[case::drift(&[0.33, 0.33, 0.33])]
    fn test_normalize_fractions(#[case] input: &[f64]) {
        let mut fractions = input.to_vec();
        normalize_fractions(&mut fractions).unwrap();
        let sum: f64 = fractions.iter().sum();
        assert!((sum - 1.0).abs() < 1e-12);
    }
    #[rstest]
    #[case::empty(&[], SplitMathError::EmptyFractions)]
    #[case::all_zeros(&[0.0, 0.0, 0.0], SplitMathError::AllZeroFractions)]
    #[case::negative(&[-0.5, 0.5], SplitMathError::NegativeFraction)]
    fn test_normalize_fractions_invalid(#[case] input: &[f64], #[case] expected: SplitMathError) {
        let mut fractions = input.to_vec();
        let err = normalize_fractions(&mut fractions).unwrap_err();
        assert_eq!(err, expected);
    }
    #[test]
    fn test_golden_section_finds_maximum() {
        let f = |x: f64| -(x - 0.3) * (x - 0.3);
        let result = golden_section_search(f, 0.0, 1.0, 100);
        assert!((result - 0.3).abs() < 1e-4, "expected ~0.3, got {result}");
    }
    #[test]
    fn test_compute_marginal_price_product_single_hop() {
        let token_a = token(0x0A, "A");
        let token_b = token(0x0B, "B");
        let market = make_market(vec![(
            "pool_ab",
            vec![token_a.clone(), token_b.clone()],
            Box::new(MockProtocolSim::new(3.0)),
        )]);
        let hops = [HopDescriptor::new("pool_ab".to_string(), token_a, token_b)];
        let product =
            compute_marginal_price_product(&hops, &market, &MarketOverrides::empty()).unwrap();
        assert!((product - 3.0).abs() < f64::EPSILON, "expected 3.0, got {product}");
    }
    #[test]
    fn test_compute_marginal_price_product_multi_hop() {
        let token_a = token(0x0A, "A");
        let token_b = token(0x0B, "B");
        let token_c = token(0x0C, "C");
        let market = make_market(vec![
            (
                "pool_ab",
                vec![token_a.clone(), token_b.clone()],
                Box::new(MockProtocolSim::new(2.0)),
            ),
            (
                "pool_bc",
                vec![token_b.clone(), token_c.clone()],
                Box::new(MockProtocolSim::new(4.0)),
            ),
        ]);
        let hops = [
            HopDescriptor::new("pool_ab".to_string(), token_a, token_b.clone()),
            HopDescriptor::new("pool_bc".to_string(), token_b, token_c),
        ];
        let product =
            compute_marginal_price_product(&hops, &market, &MarketOverrides::empty()).unwrap();
        assert!((product - 8.0).abs() < f64::EPSILON, "expected 8.0, got {product}");
    }
    #[test]
    fn test_compute_marginal_price_product_uses_overrides() {
        let token_a = token(0x0A, "A");
        let token_b = token(0x0B, "B");
        let market = make_market(vec![(
            "pool_ab",
            vec![token_a.clone(), token_b.clone()],
            Box::new(MockProtocolSim::new(3.0)),
        )]);
        let hops = [HopDescriptor::new("pool_ab".to_string(), token_a, token_b)];
        let overrides = MarketOverrides::empty()
            .with_override("pool_ab".to_string(), Box::new(MockProtocolSim::new(7.0)));
        let product = compute_marginal_price_product(&hops, &market, &overrides).unwrap();
        assert!((product - 7.0).abs() < f64::EPSILON, "expected 7.0, got {product}");
    }
    #[test]
    fn test_simulate_path_correct_output() {
        let token_a = token(0x0A, "A");
        let token_b = token(0x0B, "B");
        let token_c = token(0x0C, "C");
        let market = make_market(vec![
            (
                "pool_ab",
                vec![token_a.clone(), token_b.clone()],
                Box::new(MockProtocolSim::new(2.0)),
            ),
            (
                "pool_bc",
                vec![token_b.clone(), token_c.clone()],
                Box::new(MockProtocolSim::new(3.0)),
            ),
        ]);
        let hops = [
            HopDescriptor::new("pool_ab".to_string(), token_a, token_b.clone()),
            HopDescriptor::new("pool_bc".to_string(), token_b, token_c),
        ];
        let amount_in = BigUint::from(1000u64);
        let overrides = MarketOverrides::empty();
        let result = simulate_path(&hops, &amount_in, &market, &overrides).unwrap();
        assert_eq!(result.amount_out, BigUint::from(6000u64));
        assert!(
            (result.marginal_price_product - 6.0).abs() < f64::EPSILON,
            "expected marginal_price_product 6.0, got {}",
            result.marginal_price_product
        );
    }
    #[test]
    fn test_market_overrides_with_zero_gas() {
        let token_a = token(0x0A, "A");
        let token_b = token(0x0B, "B");
        let token_c = token(0x0C, "C");
        let sim_ab = MockProtocolSim::new(2.0).with_gas(100_000);
        let sim_bc = MockProtocolSim::new(3.0).with_gas(70_000);
        let market = make_market(vec![
            ("pool_ab", vec![token_a.clone(), token_b.clone()], Box::new(sim_ab.clone())),
            ("pool_bc", vec![token_b.clone(), token_c.clone()], Box::new(sim_bc.clone())),
        ]);
        let overrides = MarketOverrides::empty()
            .with_zero_gas("pool_ab".to_string(), Box::new(sim_ab))
            .with_override("pool_bc".to_string(), Box::new(sim_bc));
        let hops_ab = [HopDescriptor::new("pool_ab".to_string(), token_a.clone(), token_b.clone())];
        let hops_bc = [HopDescriptor::new("pool_bc".to_string(), token_b, token_c)];
        let amount_in = BigUint::from(1000u64);
        let normal_ab =
            simulate_path(&hops_ab, &amount_in, &market, &MarketOverrides::empty()).unwrap();
        let zero_gas_ab = simulate_path(&hops_ab, &amount_in, &market, &overrides).unwrap();
        assert_eq!(normal_ab.amount_out, zero_gas_ab.amount_out);
        assert!(normal_ab.gas > 0, "normal gas should be non-zero");
        assert_eq!(zero_gas_ab.gas, 0, "zero-gas override should report gas=0");
        let result_bc = simulate_path(&hops_bc, &amount_in, &market, &overrides).unwrap();
        assert_eq!(result_bc.gas, 70_000, "non-zero-gas override should keep its gas");
    }
    #[test]
    fn test_evaluate_total_output_two_paths() {
        let token_a = token(0x0A, "A");
        let token_b = token(0x0B, "B");
        let market = make_market(vec![
            (
                "pool_1",
                vec![token_a.clone(), token_b.clone()],
                Box::new(MockProtocolSim::new(2.0).with_gas(50_000)),
            ),
            (
                "pool_2",
                vec![token_a.clone(), token_b.clone()],
                Box::new(MockProtocolSim::new(3.0).with_gas(60_000)),
            ),
        ]);
        let hops_1 = [HopDescriptor::new("pool_1".to_string(), token_a.clone(), token_b.clone())];
        let hops_2 = [HopDescriptor::new("pool_2".to_string(), token_a, token_b)];
        let paths: Vec<&[HopDescriptor]> = vec![&hops_1, &hops_2];
        let fractions = [0.5, 0.5];
        let total_amount = BigUint::from(1000u64);
        let overrides = MarketOverrides::empty();
        let (total_out, total_gas) =
            evaluate_total_output(&paths, &fractions, &total_amount, &market, &overrides).unwrap();
        assert_eq!(total_out, BigUint::from(2500u64));
        assert_eq!(total_gas, 110_000);
    }
    // P1 is shared by both paths, so its gas must only be counted once.
    #[test]
    fn test_evaluate_total_output_gas_deduplication() {
        let token_a = token(0x0A, "A");
        let token_b = token(0x0B, "B");
        let token_c = token(0x0C, "C");
        let token_d = token(0x0D, "D");
        let market = make_market(vec![
            (
                "P1",
                vec![token_a.clone(), token_b.clone()],
                Box::new(MockProtocolSim::new(2.0).with_gas(100_000)),
            ),
            (
                "P2",
                vec![token_b.clone(), token_c.clone()],
                Box::new(MockProtocolSim::new(1.5).with_gas(50_000)),
            ),
            (
                "P3",
                vec![token_b.clone(), token_d.clone()],
                Box::new(MockProtocolSim::new(3.0).with_gas(70_000)),
            ),
        ]);
        let hops_1 = [
            HopDescriptor::new("P1".to_string(), token_a.clone(), token_b.clone()),
            HopDescriptor::new("P2".to_string(), token_b.clone(), token_c),
        ];
        let hops_2 = [
            HopDescriptor::new("P1".to_string(), token_a, token_b.clone()),
            HopDescriptor::new("P3".to_string(), token_b, token_d),
        ];
        let paths: Vec<&[HopDescriptor]> = vec![&hops_1, &hops_2];
        let fractions = [0.5, 0.5];
        let total_amount = BigUint::from(1000u64);
        let overrides = MarketOverrides::empty();
        let (_, total_gas) =
            evaluate_total_output(&paths, &fractions, &total_amount, &market, &overrides).unwrap();
        assert_eq!(total_gas, 220_000);
    }
    // Same pool but different token pairs are distinct hops, so gas is charged twice.
    #[test]
    fn test_gas_dedup_different_tokens() {
        let token_a = token(0x0A, "A");
        let token_b = token(0x0B, "B");
        let token_c = token(0x0C, "C");
        let market = make_market(vec![(
            "tripool",
            vec![token_a.clone(), token_b.clone(), token_c.clone()],
            Box::new(MockProtocolSim::new(1.0).with_gas(80_000)),
        )]);
        let hops_1 = [HopDescriptor::new("tripool".to_string(), token_a, token_b.clone())];
        let hops_2 = [HopDescriptor::new("tripool".to_string(), token_b, token_c)];
        let paths: Vec<&[HopDescriptor]> = vec![&hops_1, &hops_2];
        let fractions = [0.5, 0.5];
        let total_amount = BigUint::from(1000u64);
        let overrides = MarketOverrides::empty();
        let (_, total_gas) =
            evaluate_total_output(&paths, &fractions, &total_amount, &market, &overrides).unwrap();
        assert_eq!(total_gas, 160_000);
    }
    #[test]
    fn test_build_post_swap_overrides_degrades_used_pools() {
        let token_a = token(0x0A, "A");
        let token_b = token(0x0B, "B");
        let market = make_market(vec![(
            "pool_ab",
            vec![token_a.clone(), token_b.clone()],
            Box::new(ConstantProductSim {
                reserve_0: BigUint::from(10_000u64),
                reserve_1: BigUint::from(20_000u64),
                gas: 50_000,
            }),
        )]);
        let allocation = PathAllocation {
            hops: vec![HopDescriptor::new("pool_ab".to_string(), token_a.clone(), token_b.clone())],
            flow_fraction: 1.0,
            amount_in: BigUint::from(1000u64),
            amount_out: BigUint::from(1818u64),
            marginal_price_product: 2.0,
        };
        let degraded = build_post_swap_overrides(&[&allocation], &market);
        let probe = BigUint::from(100u64);
        let fresh_out = market
            .get_simulation_state("pool_ab")
            .unwrap()
            .get_amount_out(probe.clone(), &token_a, &token_b)
            .unwrap()
            .amount;
        assert_eq!(fresh_out, BigUint::from(198u64));
        let degraded_out = degraded
            .get(&"pool_ab".to_string())
            .unwrap()
            .get_amount_out(probe, &token_a, &token_b)
            .unwrap()
            .amount;
        assert_eq!(degraded_out, BigUint::from(163u64));
    }
    #[test]
    fn test_merge_shared_hops_combines_fractions() {
        let token_a = token(0x0A, "A");
        let token_b = token(0x0B, "B");
        let token_c = token(0x0C, "C");
        let gas = BigUint::from(50_000u64);
        let paths = vec![
            PathAllocation {
                hops: vec![
                    HopDescriptor::new("P1".to_string(), token_a.clone(), token_b.clone())
                        .with_amounts(BigUint::from(1200u64), gas.clone()),
                    HopDescriptor::new("P2".to_string(), token_b.clone(), token_c.clone())
                        .with_amounts(BigUint::from(3600u64), gas.clone()),
                ],
                flow_fraction: 0.6,
                amount_in: BigUint::from(600u64),
                amount_out: BigUint::from(3600u64),
                marginal_price_product: 6.0,
            },
            PathAllocation {
                hops: vec![
                    HopDescriptor::new("P1".to_string(), token_a.clone(), token_b.clone())
                        .with_amounts(BigUint::from(800u64), gas.clone()),
                    HopDescriptor::new("P3".to_string(), token_b.clone(), token_c.clone())
                        .with_amounts(BigUint::from(1600u64), gas),
                ],
                flow_fraction: 0.4,
                amount_in: BigUint::from(400u64),
                amount_out: BigUint::from(1600u64),
                marginal_price_product: 4.0,
            },
        ];
        let hops_by_token = merge_shared_hops(&paths).unwrap();
        let branch_collection_a = &hops_by_token[&token_a.address];
        assert_eq!(branch_collection_a.len(), 1);
        assert_eq!(branch_collection_a[0].hop.component_id, "P1");
        assert!((branch_collection_a[0].split - 1.0).abs() < f64::EPSILON);
        let branch_collection_b = &hops_by_token[&token_b.address];
        assert_eq!(branch_collection_b.len(), 2);
        assert_eq!(branch_collection_b[0].hop.component_id, "P2");
        assert!((branch_collection_b[0].split - 0.6).abs() < f64::EPSILON);
        assert_eq!(branch_collection_b[1].hop.component_id, "P3");
        assert!((branch_collection_b[1].split - 0.4).abs() < f64::EPSILON);
    }
    #[test]
    fn test_assign_splits_and_amounts_splits_and_amounts() {
        let token_a = token(0x0A, "A");
        let token_b = token(0x0B, "B");
        let branch_collection = vec![
            SplitSwap {
                hop: HopDescriptor::new("pool1".to_string(), token_a.clone(), token_b.clone()),
                split: 0.7,
                amount_in: BigUint::ZERO,
                amount_out: BigUint::ZERO,
                gas: BigUint::ZERO,
            },
            SplitSwap {
                hop: HopDescriptor::new("pool2".to_string(), token_a.clone(), token_b.clone()),
                split: 0.3,
                amount_in: BigUint::ZERO,
                amount_out: BigUint::ZERO,
                gas: BigUint::ZERO,
            },
        ];
        let result = assign_splits_and_amounts(branch_collection, &BigUint::from(1000u64));
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].split, 0.7);
        assert_eq!(result[0].amount_in, BigUint::from(700u64));
        assert_eq!(result[1].split, 0.0);
        assert_eq!(result[1].amount_in, BigUint::from(300u64));
    }
    #[test]
    fn test_assign_splits_and_amounts_single_hop() {
        let token_a = token(0x0A, "A");
        let token_b = token(0x0B, "B");
        let branch_collection = vec![SplitSwap {
            hop: HopDescriptor::new("pool1".to_string(), token_a, token_b),
            split: 1.0,
            amount_in: BigUint::ZERO,
            amount_out: BigUint::ZERO,
            gas: BigUint::ZERO,
        }];
        let total = BigUint::from(1000u64);
        let result = assign_splits_and_amounts(branch_collection, &total);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].split, 0.0);
        assert_eq!(result[0].amount_in, total);
    }
    // The last swap of each branch collection carries split 0.0 ("take the rest").
    #[test]
    fn test_build_split_route_remainder_convention() {
        let token_a = token(0x0A, "A");
        let token_b = token(0x0B, "B");
        let market = make_market(vec![
            ("pool1", vec![token_a.clone(), token_b.clone()], Box::new(MockProtocolSim::new(2.0))),
            ("pool2", vec![token_a.clone(), token_b.clone()], Box::new(MockProtocolSim::new(3.0))),
            ("pool3", vec![token_a.clone(), token_b.clone()], Box::new(MockProtocolSim::new(4.0))),
        ]);
        let ord = order(&token_a, &token_b, 1000, OrderSide::Sell);
        let gas = BigUint::from(50_000u64);
        let paths = vec![
            PathAllocation {
                hops: vec![HopDescriptor::new(
                    "pool1".to_string(),
                    token_a.clone(),
                    token_b.clone(),
                )
                .with_amounts(BigUint::from(1000u64), gas.clone())],
                flow_fraction: 0.5,
                amount_in: BigUint::from(500u64),
                amount_out: BigUint::from(1000u64),
                marginal_price_product: 2.0,
            },
            PathAllocation {
                hops: vec![HopDescriptor::new(
                    "pool2".to_string(),
                    token_a.clone(),
                    token_b.clone(),
                )
                .with_amounts(BigUint::from(900u64), gas.clone())],
                flow_fraction: 0.3,
                amount_in: BigUint::from(300u64),
                amount_out: BigUint::from(900u64),
                marginal_price_product: 3.0,
            },
            PathAllocation {
                hops: vec![HopDescriptor::new(
                    "pool3".to_string(),
                    token_a.clone(),
                    token_b.clone(),
                )
                .with_amounts(BigUint::from(800u64), gas)],
                flow_fraction: 0.2,
                amount_in: BigUint::from(200u64),
                amount_out: BigUint::from(800u64),
                marginal_price_product: 4.0,
            },
        ];
        let route = build_split_route(&paths, &market, &ord).unwrap();
        let swaps = route.swaps();
        assert_eq!(swaps.len(), 3);
        assert_eq!(swaps[0].component_id(), "pool1");
        assert_eq!(*swaps[0].split(), 0.5);
        assert_eq!(swaps[1].component_id(), "pool2");
        assert_eq!(*swaps[1].split(), 0.3);
        assert_eq!(swaps[2].component_id(), "pool3");
        assert_eq!(*swaps[2].split(), 0.0);
    }
    #[test]
    fn test_build_split_route_single_path() {
        let token_a = token(0x0A, "A");
        let token_b = token(0x0B, "B");
        let token_c = token(0x0C, "C");
        let market = make_market(vec![
            (
                "pool_ab",
                vec![token_a.clone(), token_b.clone()],
                Box::new(MockProtocolSim::new(2.0)),
            ),
            (
                "pool_bc",
                vec![token_b.clone(), token_c.clone()],
                Box::new(MockProtocolSim::new(3.0)),
            ),
        ]);
        let ord = order(&token_a, &token_c, 1000, OrderSide::Sell);
        let gas = BigUint::from(50_000u64);
        let paths = vec![PathAllocation {
            hops: vec![
                HopDescriptor::new("pool_ab".to_string(), token_a.clone(), token_b.clone())
                    .with_amounts(BigUint::from(2000u64), gas.clone()),
                HopDescriptor::new("pool_bc".to_string(), token_b, token_c)
                    .with_amounts(BigUint::from(6000u64), gas),
            ],
            flow_fraction: 1.0,
            amount_in: BigUint::from(1000u64),
            amount_out: BigUint::from(6000u64),
            marginal_price_product: 6.0,
        }];
        let route = build_split_route(&paths, &market, &ord).unwrap();
        let swaps = route.swaps();
        assert_eq!(swaps.len(), 2);
        for swap in swaps {
            assert_eq!(*swap.split(), 0.0, "single path should produce all-zero splits");
        }
    }
    #[test]
    fn test_build_split_route_shared_first_pool() {
        let token_a = token(0x0A, "A");
        let token_b = token(0x0B, "B");
        let token_c = token(0x0C, "C");
        let market = make_market(vec![
            ("P1", vec![token_a.clone(), token_b.clone()], Box::new(MockProtocolSim::new(2.0))),
            ("P2", vec![token_b.clone(), token_c.clone()], Box::new(MockProtocolSim::new(3.0))),
            ("P3", vec![token_b.clone(), token_c.clone()], Box::new(MockProtocolSim::new(4.0))),
        ]);
        let ord = order(&token_a, &token_c, 1000, OrderSide::Sell);
        let gas = BigUint::from(50_000u64);
        let paths = vec![
            PathAllocation {
                hops: vec![
                    HopDescriptor::new("P1".to_string(), token_a.clone(), token_b.clone())
                        .with_amounts(BigUint::from(1400u64), gas.clone()),
                    HopDescriptor::new("P2".to_string(), token_b.clone(), token_c.clone())
                        .with_amounts(BigUint::from(4200u64), gas.clone()),
                ],
                flow_fraction: 0.7,
                amount_in: BigUint::from(700u64),
                amount_out: BigUint::from(4200u64),
                marginal_price_product: 6.0,
            },
            PathAllocation {
                hops: vec![
                    HopDescriptor::new("P1".to_string(), token_a.clone(), token_b.clone())
                        .with_amounts(BigUint::from(600u64), gas.clone()),
                    HopDescriptor::new("P3".to_string(), token_b.clone(), token_c.clone())
                        .with_amounts(BigUint::from(2400u64), gas),
                ],
                flow_fraction: 0.3,
                amount_in: BigUint::from(300u64),
                // Matches the last hop's output (600 B at price 4.0).
                amount_out: BigUint::from(2400u64),
                marginal_price_product: 8.0,
            },
        ];
        let route = build_split_route(&paths, &market, &ord).unwrap();
        let swaps = route.swaps();
        assert_eq!(swaps.len(), 3, "expected 3 swaps, got {}", swaps.len());
        let ab_swap = &swaps[0];
        assert_eq!(ab_swap.component_id(), "P1");
        assert_eq!(
            *ab_swap.amount_in(),
            BigUint::from(1000u64),
            "A→B swap amount_in should equal sum of both paths"
        );
        assert_eq!(
            *ab_swap.amount_out(),
            BigUint::from(2000u64),
            "A→B amount_out should be sum of per-path outputs (1400+600)"
        );
        assert_eq!(
            *ab_swap.split(),
            0.0,
            "A→B is the sole swap in its branch collection, so it gets the remainder convention (split = 0.0)"
        );
        assert_eq!(swaps[1].component_id(), "P2");
        assert_eq!(*swaps[1].split(), 0.7);
        assert_eq!(swaps[2].component_id(), "P3");
        assert_eq!(*swaps[2].split(), 0.0);
    }
    #[test]
    fn test_build_split_route_source_level_split_different_intermediates() {
        let token_a = token(0x0A, "A");
        let token_b = token(0x0B, "B");
        let token_c = token(0x0C, "C");
        let token_z = token(0x1A, "Z");
        let market = make_market(vec![
            (
                "pool_ab",
                vec![token_a.clone(), token_b.clone()],
                Box::new(MockProtocolSim::new(2.0)),
            ),
            (
                "pool_ac",
                vec![token_a.clone(), token_c.clone()],
                Box::new(MockProtocolSim::new(3.0)),
            ),
            (
                "pool_bz",
                vec![token_b.clone(), token_z.clone()],
                Box::new(MockProtocolSim::new(4.0)),
            ),
            (
                "pool_cz",
                vec![token_c.clone(), token_z.clone()],
                Box::new(MockProtocolSim::new(5.0)),
            ),
        ]);
        let ord = order(&token_a, &token_z, 1000, OrderSide::Sell);
        let gas = BigUint::from(50_000u64);
        let paths = vec![
            PathAllocation {
                hops: vec![
                    HopDescriptor::new("pool_ab".to_string(), token_a.clone(), token_b.clone())
                        .with_amounts(BigUint::from(1200u64), gas.clone()),
                    HopDescriptor::new("pool_bz".to_string(), token_b, token_z.clone())
                        .with_amounts(BigUint::from(4800u64), gas.clone()),
                ],
                flow_fraction: 0.6,
                amount_in: BigUint::from(600u64),
                amount_out: BigUint::from(4800u64),
                marginal_price_product: 8.0,
            },
            PathAllocation {
                hops: vec![
                    HopDescriptor::new("pool_ac".to_string(), token_a.clone(), token_c.clone())
                        .with_amounts(BigUint::from(1200u64), gas.clone()),
                    HopDescriptor::new("pool_cz".to_string(), token_c, token_z)
                        .with_amounts(BigUint::from(6000u64), gas),
                ],
                flow_fraction: 0.4,
                amount_in: BigUint::from(400u64),
                amount_out: BigUint::from(6000u64),
                marginal_price_product: 15.0,
            },
        ];
        let route = build_split_route(&paths, &market, &ord).unwrap();
        let swaps = route.swaps();
        assert_eq!(swaps.len(), 4, "expected 4 swaps (2 source + 2 intermediate)");
        assert_eq!(swaps[0].component_id(), "pool_ab");
        assert_eq!(*swaps[0].split(), 0.6);
        assert_eq!(*swaps[0].amount_in(), BigUint::from(600u64));
        assert_eq!(*swaps[0].amount_out(), BigUint::from(1200u64));
        assert_eq!(swaps[1].component_id(), "pool_ac");
        assert_eq!(*swaps[1].split(), 0.0);
        assert_eq!(*swaps[1].amount_in(), BigUint::from(400u64));
        assert_eq!(*swaps[1].amount_out(), BigUint::from(1200u64));
        assert_eq!(swaps[2].component_id(), "pool_bz");
        assert_eq!(*swaps[2].split(), 0.0);
        assert_eq!(*swaps[2].amount_in(), BigUint::from(1200u64));
        assert_eq!(*swaps[2].amount_out(), BigUint::from(4800u64));
        assert_eq!(swaps[3].component_id(), "pool_cz");
        assert_eq!(*swaps[3].split(), 0.0);
        assert_eq!(*swaps[3].amount_in(), BigUint::from(1200u64));
        assert_eq!(*swaps[3].amount_out(), BigUint::from(6000u64));
    }
}