use std::collections::{HashMap, HashSet, VecDeque};
use num_bigint::BigUint;
use num_traits::{ToPrimitive, Zero};
use tycho_simulation::tycho_common::{
dto::ProtocolStateDelta,
models::token::Token,
simulation::{
errors::{SimulationError, TransitionError},
protocol_sim::{Balances, GetAmountOutResult, ProtocolSim},
},
Bytes,
};
use crate::{
algorithm::AlgorithmError,
feed::market_data::MarketState,
types::{ComponentId, Order, Route, Swap},
};
/// One step of a route: the component to trade through and the token pair it trades.
#[derive(Clone)]
pub(crate) struct HopDescriptor {
    /// Identifier of the protocol component executing this hop.
    pub(crate) component_id: ComponentId,
    /// Token sold into the component.
    pub(crate) token_in: Token,
    /// Token bought from the component.
    pub(crate) token_out: Token,
}

impl HopDescriptor {
    /// Builds a descriptor from a component id and the directed token pair.
    pub(crate) fn new(component_id: ComponentId, token_in: Token, token_out: Token) -> Self {
        HopDescriptor { token_in, token_out, component_id }
    }

    /// Test helper: attaches simulated output and gas figures, producing a [`SimulatedHop`].
    #[cfg(test)]
    pub(crate) fn with_amounts(self, amount_out: BigUint, gas: BigUint) -> SimulatedHop {
        SimulatedHop { amount_out, gas, descriptor: self }
    }
}
/// A [`HopDescriptor`] annotated with the figures its simulation produced.
#[derive(Clone)]
pub(crate) struct SimulatedHop {
    /// Which component and token pair this hop trades.
    pub(crate) descriptor: HopDescriptor,
    /// Simulated amount of `descriptor.token_out` received from this hop.
    pub(crate) amount_out: BigUint,
    /// Gas reported by the simulation for this hop.
    pub(crate) gas: BigUint,
}
#[derive(Clone)]
pub(crate) struct PathAllocation {
pub(crate) hops: Vec<SimulatedHop>,
pub(crate) flow_fraction: f64,
pub(crate) amount_in: BigUint,
pub(crate) amount_out: BigUint,
pub(crate) marginal_price_product: f64,
}
impl PathAllocation {
pub(crate) fn validate_token_cycles(&self) -> Result<(), AlgorithmError> {
if self.hops.is_empty() {
return Err(AlgorithmError::Other("path has no hops".to_string()));
}
let first_token = &self.hops[0].descriptor.token_in.address;
let mut seen = HashSet::new();
seen.insert(first_token.clone());
let last_idx = self.hops.len() - 1;
for (i, hop) in self.hops.iter().enumerate() {
let out_addr = &hop.descriptor.token_out.address;
if !seen.insert(out_addr.clone()) {
let is_valid_round_trip = i == last_idx && out_addr == first_token;
if !is_valid_round_trip {
return Err(AlgorithmError::Other(format!(
"path revisits token {out_addr} at hop {i} \
(would corrupt merge_shared_hops)",
)));
}
}
}
Ok(())
}
}
/// Outcome of simulating one path with `simulate_path`.
pub(crate) struct SimResult {
    /// Amount of the last hop's output token received.
    pub(crate) amount_out: BigUint,
    /// Sum of per-hop gas, saturating at `u64::MAX`.
    pub(crate) gas: u64,
    /// Product of the hops' spot prices, evaluated on the pre-trade states.
    pub(crate) marginal_price_product: f64,
}
/// Simulation states that shadow the market's own states for specific components.
///
/// [`MarketOverrides::get`] only consults this map; callers in this module fall back to
/// the market themselves (`overrides.get(id).or_else(|| market.get_simulation_state(id))`).
#[derive(Default)]
pub(crate) struct MarketOverrides(HashMap<ComponentId, Box<dyn ProtocolSim>>);

impl MarketOverrides {
    /// Creates an override set with no entries.
    pub(crate) fn empty() -> Self {
        MarketOverrides(HashMap::new())
    }

    /// Returns `self` with `sim` installed for `id`, replacing any previous override.
    pub(crate) fn with_override(self, id: ComponentId, sim: Box<dyn ProtocolSim>) -> Self {
        let Self(mut states) = self;
        states.insert(id, sim);
        Self(states)
    }

    /// Marks the `(token_in, token_out)` direction of an already-overridden component as
    /// gas-free.
    ///
    /// The existing override is wrapped in a `SelectiveZeroGasSim`; if it already is one,
    /// the new pair is added to its set. When `id` has no override yet this is a no-op —
    /// the market's own state is not consulted here.
    pub(crate) fn with_zero_gas(
        mut self,
        id: ComponentId,
        token_in: Bytes,
        token_out: Bytes,
    ) -> Self {
        let Some(existing) = self.0.remove(&id) else {
            return self;
        };
        let direction = (token_in, token_out);
        let replacement: Box<dyn ProtocolSim> =
            match existing.as_any().downcast_ref::<SelectiveZeroGasSim>() {
                Some(already_wrapped) => {
                    // Extend the existing pair set instead of nesting wrappers.
                    let mut zero_gas_pairs = already_wrapped.zero_gas_pairs.clone();
                    zero_gas_pairs.insert(direction);
                    Box::new(SelectiveZeroGasSim {
                        inner: already_wrapped.inner.clone_box(),
                        zero_gas_pairs,
                    })
                }
                None => Box::new(SelectiveZeroGasSim {
                    inner: existing,
                    zero_gas_pairs: HashSet::from([direction]),
                }),
            };
        self.0.insert(id, replacement);
        self
    }

    /// Looks up the override for `id`, if one was installed.
    pub(crate) fn get(&self, id: &ComponentId) -> Option<&dyn ProtocolSim> {
        self.0.get(id).map(|boxed| &**boxed)
    }
}
/// A [`ProtocolSim`] wrapper that reports zero gas for selected swap directions.
///
/// Quotes, limits, fees, spot prices and state transitions all come from `inner`; only the
/// `gas` of a `get_amount_out` result is zeroed, and only when the directed
/// `(token_in, token_out)` address pair is listed in `zero_gas_pairs`.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct SelectiveZeroGasSim {
    /// Wrapped simulation that does the actual pricing.
    inner: Box<dyn ProtocolSim>,
    /// Directed `(token_in, token_out)` address pairs whose swaps are reported as gas-free.
    zero_gas_pairs: HashSet<(Bytes, Bytes)>,
}

#[typetag::serde]
impl ProtocolSim for SelectiveZeroGasSim {
    fn fee(&self) -> f64 {
        self.inner.fee()
    }

    fn spot_price(&self, base: &Token, quote: &Token) -> Result<f64, SimulationError> {
        self.inner.spot_price(base, quote)
    }

    /// Delegates to `inner`, zeroes the gas for listed directions, and re-wraps the
    /// resulting state so the zero-gas behaviour carries over to follow-up swaps.
    fn get_amount_out(
        &self,
        amount_in: BigUint,
        token_in: &Token,
        token_out: &Token,
    ) -> Result<GetAmountOutResult, SimulationError> {
        let mut quote = self.inner.get_amount_out(amount_in, token_in, token_out)?;
        let direction = (token_in.address.clone(), token_out.address.clone());
        if self.zero_gas_pairs.contains(&direction) {
            quote.gas = BigUint::ZERO;
        }
        quote.new_state = Box::new(SelectiveZeroGasSim {
            inner: quote.new_state,
            zero_gas_pairs: self.zero_gas_pairs.clone(),
        });
        Ok(quote)
    }

    fn get_limits(
        &self,
        sell_token: Bytes,
        buy_token: Bytes,
    ) -> Result<(BigUint, BigUint), SimulationError> {
        self.inner.get_limits(sell_token, buy_token)
    }

    /// Applies the delta to the wrapped state only; the zero-gas pair set is unchanged.
    fn delta_transition(
        &mut self,
        delta: ProtocolStateDelta,
        tokens: &HashMap<Bytes, Token>,
        balances: &Balances,
    ) -> Result<(), TransitionError> {
        self.inner.delta_transition(delta, tokens, balances)
    }

    fn clone_box(&self) -> Box<dyn ProtocolSim> {
        let zero_gas_pairs = self.zero_gas_pairs.clone();
        Box::new(SelectiveZeroGasSim { inner: self.inner.clone_box(), zero_gas_pairs })
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    /// Equal only to another `SelectiveZeroGasSim` with an equal inner state and the same
    /// zero-gas pair set.
    fn eq(&self, other: &dyn ProtocolSim) -> bool {
        match other.as_any().downcast_ref::<Self>() {
            Some(that) => {
                self.inner.eq(&*that.inner) && self.zero_gas_pairs == that.zero_gas_pairs
            }
            None => false,
        }
    }
}
/// Searches `[lo, hi]` for the argument that maximizes `f` (golden-section search).
///
/// `f` is assumed to be unimodal on the interval; otherwise a local maximum may be
/// returned. `f` is evaluated `max(max_evals, 2)` times: two interior probes to start,
/// then one new probe per iteration, each iteration shrinking the bracket by 1/φ ≈ 0.618.
///
/// Returns whichever of the two final interior probes scored higher (the left one on ties).
pub(crate) fn golden_section_search(
    f: impl Fn(f64) -> f64,
    mut lo: f64,
    mut hi: f64,
    max_evals: usize,
) -> f64 {
    let inv_phi = (5_f64.sqrt() - 1.0) / 2.0;
    // Interior probes at the golden-ratio points of the bracket.
    let (mut left_x, mut right_x) = (hi - inv_phi * (hi - lo), lo + inv_phi * (hi - lo));
    let (mut left_f, mut right_f) = (f(left_x), f(right_x));
    // The two initial probes count against the budget; each iteration spends one more.
    for _ in 2..max_evals.max(2) {
        if right_f > left_f {
            // Maximum lies right of `left_x`: discard [lo, left_x], reuse the right probe.
            lo = left_x;
            left_x = right_x;
            left_f = right_f;
            right_x = lo + inv_phi * (hi - lo);
            right_f = f(right_x);
        } else {
            // Maximum lies left of `right_x`: discard [right_x, hi], reuse the left probe.
            hi = right_x;
            right_x = left_x;
            right_f = left_f;
            left_x = hi - inv_phi * (hi - lo);
            left_f = f(left_x);
        }
    }
    if left_f >= right_f {
        left_x
    } else {
        right_x
    }
}
/// Splits `total` into `(part, remainder)` with `part ≈ total * fraction`.
///
/// `fraction` is clamped to `[0, 1]` and truncated to a multiple of 1e-18 before an integer
/// multiply and floor-divide, so `part <= total` and `part + remainder == total` exactly.
/// A NaN fraction yields a zero `part` (the float-to-integer cast maps NaN to 0).
pub(crate) fn split_amount(total: &BigUint, fraction: f64) -> (BigUint, BigUint) {
    const SCALE: u64 = 1_000_000_000_000_000_000;
    let scaled_fraction = (fraction.clamp(0.0, 1.0) * SCALE as f64) as u64;
    let part = total * BigUint::from(scaled_fraction) / BigUint::from(SCALE);
    let rest = total - &part;
    (part, rest)
}
/// Errors raised by the fraction helpers `normalize_fractions` and `fractions_to_amounts`.
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub(crate) enum SplitMathError {
    /// No fractions were supplied.
    #[error("fractions slice must not be empty")]
    EmptyFractions,
    /// Every fraction is zero, so there is nothing to scale up to a total of one.
    #[error("all fractions are zero, cannot normalize")]
    AllZeroFractions,
    /// At least one fraction is below zero.
    #[error("fractions must not be negative")]
    NegativeFraction,
}
/// Rescales `fractions` in place so that they sum to one.
///
/// # Errors
/// * [`SplitMathError::EmptyFractions`] for an empty slice.
/// * [`SplitMathError::NegativeFraction`] if any entry is below zero.
/// * [`SplitMathError::AllZeroFractions`] if the entries sum to zero.
///
/// The slice is left untouched whenever an error is returned.
pub(crate) fn normalize_fractions(fractions: &mut [f64]) -> Result<(), SplitMathError> {
    if fractions.is_empty() {
        return Err(SplitMathError::EmptyFractions);
    }
    let has_negative = fractions.iter().copied().any(|value| value < 0.0);
    if has_negative {
        return Err(SplitMathError::NegativeFraction);
    }
    let total: f64 = fractions.iter().sum();
    if total == 0.0 {
        return Err(SplitMathError::AllZeroFractions);
    }
    fractions.iter_mut().for_each(|value| *value /= total);
    Ok(())
}
/// Converts fractions of `total` into integer amounts that sum exactly to `total`.
///
/// Every entry except the last is `split_amount(total, fraction).0`; the last entry takes
/// whatever is left, so rounding dust lands there. If the leading fractions request more
/// than `total` in aggregate (they sum above one, e.g. through floating-point drift or an
/// unnormalized input), each part is capped at the amount still unallocated. Later entries
/// then shrink, possibly to zero, instead of the final subtraction underflowing, which
/// would panic for `BigUint`.
///
/// # Errors
/// [`SplitMathError::EmptyFractions`] if `fractions` is empty.
pub(crate) fn fractions_to_amounts(
    total: &BigUint,
    fractions: &[f64],
) -> Result<Vec<BigUint>, SplitMathError> {
    let Some((_, leading)) = fractions.split_last() else {
        return Err(SplitMathError::EmptyFractions);
    };
    let mut amounts = Vec::with_capacity(fractions.len());
    // Still-unallocated amount; it becomes the final (remainder) entry.
    let mut remaining = total.clone();
    for &frac in leading {
        let (mut part, _) = split_amount(total, frac);
        // Never hand out more than is left, so `remaining` cannot underflow.
        if part > remaining {
            part = remaining.clone();
        }
        remaining -= &part;
        amounts.push(part);
    }
    amounts.push(remaining);
    Ok(amounts)
}
/// Multiplies the spot prices of every hop along a path.
///
/// Each hop's state is taken from `overrides` when present, otherwise from `market`.
/// An empty path yields `1.0`.
///
/// # Errors
/// * `AlgorithmError::DataNotFound` if neither source has a state for a hop's component.
/// * `AlgorithmError::SimulationFailed` if a `spot_price` call fails.
pub(crate) fn compute_marginal_price_product(
    hops: &[HopDescriptor],
    market: &MarketState,
    overrides: &MarketOverrides,
) -> Result<f64, AlgorithmError> {
    hops.iter().try_fold(
        1.0_f64,
        |running: f64, hop: &HopDescriptor| -> Result<f64, AlgorithmError> {
            let id = &hop.component_id;
            let state = overrides
                .get(id)
                .or_else(|| market.get_simulation_state(id))
                .ok_or_else(|| AlgorithmError::DataNotFound {
                    kind: "simulation state",
                    id: Some(id.clone()),
                })?;
            let spot = state
                .spot_price(&hop.token_in, &hop.token_out)
                .map_err(|err| AlgorithmError::SimulationFailed {
                    component_id: id.clone(),
                    error: err.to_string(),
                })?;
            Ok(running * spot)
        },
    )
}
/// Pushes `amount_in` through every hop of a path and reports output, gas and price.
///
/// Hops are simulated in order, each receiving the previous hop's output. States come from
/// `overrides` first, then `market`. The path is not chained through post-swap states:
/// every hop reads those same sources. Gas is summed with saturation, and a per-hop gas
/// value that does not fit in `u64` counts as `u64::MAX`.
///
/// # Errors
/// * `AlgorithmError::DataNotFound` if a hop's component has no simulation state.
/// * `AlgorithmError::SimulationFailed` if a hop's `get_amount_out` fails.
/// * Any error from [`compute_marginal_price_product`].
pub(crate) fn simulate_path(
    hops: &[HopDescriptor],
    amount_in: &BigUint,
    market: &MarketState,
    overrides: &MarketOverrides,
) -> Result<SimResult, AlgorithmError> {
    let (amount_out, gas) = hops.iter().try_fold(
        (amount_in.clone(), 0_u64),
        |(flowing, gas_so_far): (BigUint, u64),
         hop: &HopDescriptor|
         -> Result<(BigUint, u64), AlgorithmError> {
            let id = &hop.component_id;
            let state = overrides
                .get(id)
                .or_else(|| market.get_simulation_state(id))
                .ok_or_else(|| AlgorithmError::DataNotFound {
                    kind: "simulation state",
                    id: Some(id.clone()),
                })?;
            let quote = state
                .get_amount_out(flowing, &hop.token_in, &hop.token_out)
                .map_err(|err| AlgorithmError::SimulationFailed {
                    component_id: id.clone(),
                    error: err.to_string(),
                })?;
            let hop_gas = quote.gas.to_u64().unwrap_or(u64::MAX);
            Ok((quote.amount, gas_so_far.saturating_add(hop_gas)))
        },
    )?;
    let marginal_price_product = compute_marginal_price_product(hops, market, overrides)?;
    Ok(SimResult { amount_out, gas, marginal_price_product })
}
/// Simulates every path with its share of `total_amount` and sums the outputs.
///
/// `fractions` is turned into integer amounts with [`fractions_to_amounts`]. The last path
/// absorbs rounding dust, and `paths[i]` receives the amount for `fractions[i]`. Paths
/// whose amount is zero are skipped. Every path is simulated against the same starting
/// states (overrides first, then market), so state changes from one path are not applied
/// to the next. Gas is charged once per distinct `(component, token_in, token_out)` hop.
/// This mirrors `merge_shared_hops`, which merges identical hops into one swap.
///
/// # Errors
/// * `AlgorithmError::Other` if `paths` and `fractions` differ in length, or if
///   `fractions` is empty.
/// * `AlgorithmError::DataNotFound` if a hop's component has no simulation state.
/// * `AlgorithmError::SimulationFailed` if a hop's `get_amount_out` fails.
pub(crate) fn evaluate_total_output(
    paths: &[&[HopDescriptor]],
    fractions: &[f64],
    total_amount: &BigUint,
    market: &MarketState,
    overrides: &MarketOverrides,
) -> Result<(BigUint, u64), AlgorithmError> {
    // The remainder lives in the last amount and `zip` stops at the shorter side, so a
    // mismatch would silently ignore paths or drop part of the input.
    if paths.len() != fractions.len() {
        return Err(AlgorithmError::Other(format!(
            "paths and fractions must have the same length (got {} paths, {} fractions)",
            paths.len(),
            fractions.len(),
        )));
    }
    let amounts = fractions_to_amounts(total_amount, fractions)
        .map_err(|e| AlgorithmError::Other(e.to_string()))?;
    let mut total_out = BigUint::zero();
    let mut total_gas: u64 = 0;
    // Hops already charged for gas.
    let mut seen_hops: HashSet<(ComponentId, Bytes, Bytes)> = HashSet::new();
    for (path, amount) in paths.iter().zip(amounts.iter()) {
        if amount.is_zero() {
            continue;
        }
        let mut current_amount = amount.clone();
        for hop in path.iter() {
            let sim = overrides
                .get(&hop.component_id)
                .or_else(|| market.get_simulation_state(&hop.component_id))
                .ok_or_else(|| AlgorithmError::DataNotFound {
                    kind: "simulation state",
                    id: Some(hop.component_id.clone()),
                })?;
            let result = sim
                .get_amount_out(current_amount, &hop.token_in, &hop.token_out)
                .map_err(|e| AlgorithmError::SimulationFailed {
                    component_id: hop.component_id.clone(),
                    error: e.to_string(),
                })?;
            let hop_key = (
                hop.component_id.clone(),
                hop.token_in.address.clone(),
                hop.token_out.address.clone(),
            );
            if seen_hops.insert(hop_key) {
                total_gas = total_gas.saturating_add(result.gas.to_u64().unwrap_or(u64::MAX));
            }
            current_amount = result.amount;
        }
        total_out += current_amount;
    }
    Ok((total_out, total_gas))
}
pub(crate) fn build_post_swap_overrides(
paths: &[PathAllocation],
market: &MarketState,
) -> MarketOverrides {
let mut overrides = MarketOverrides::empty();
for path in paths {
let mut current_amount = path.amount_in.clone();
for hop in &path.hops {
let desc = &hop.descriptor;
let sim = overrides
.get(&desc.component_id)
.or_else(|| market.get_simulation_state(&desc.component_id));
let Some(sim) = sim else { break };
let Ok(result) = sim.get_amount_out(current_amount, &desc.token_in, &desc.token_out)
else {
break;
};
current_amount = result.amount;
overrides = overrides.with_override(desc.component_id.clone(), result.new_state);
}
}
overrides
}
/// A hop merged across paths, plus the split and amounts it carries in the final route.
struct SplitSwap {
    /// Component and directed token pair of the hop.
    hop: HopDescriptor,
    /// Summed `flow_fraction` of the contributing paths. `assign_splits_and_amounts`
    /// replaces it with the normalized share among hops with the same input token
    /// (`0.0` for the last one).
    split: f64,
    /// Input amount; zero until `assign_splits_and_amounts` fills it in.
    amount_in: BigUint,
    /// Sum of the contributing paths' simulated outputs for this hop.
    amount_out: BigUint,
    /// Gas from the first path that contributed this hop.
    gas: BigUint,
}
/// Collapses identical hops across paths and groups the result by input-token address.
///
/// Hops are identical when they share `(component_id, token_in, token_out)`. For a merged
/// hop, `split` is the sum of the contributing paths' `flow_fraction`s and `amount_out` is
/// the sum of their simulated outputs. `gas` is taken from the first occurrence only, and
/// `amount_in` stays zero for `assign_splits_and_amounts` to fill in.
///
/// Each group is ordered by descending `split`. Ties are broken by ascending
/// `component_id`, so the order does not depend on `HashMap` iteration order.
///
/// # Errors
/// Currently always returns `Ok`.
fn merge_shared_hops(
    paths: &[PathAllocation],
) -> Result<HashMap<Bytes, Vec<SplitSwap>>, AlgorithmError> {
    type HopKey = (ComponentId, Bytes, Bytes);
    let mut hops: HashMap<HopKey, SplitSwap> = HashMap::new();
    for path in paths {
        for hop in &path.hops {
            let desc = &hop.descriptor;
            let key: HopKey = (
                desc.component_id.clone(),
                desc.token_in.address.clone(),
                desc.token_out.address.clone(),
            );
            hops.entry(key)
                .and_modify(|h| {
                    h.split += path.flow_fraction;
                    h.amount_out += &hop.amount_out;
                })
                // Lazily built: only the first occurrence of a hop pays for the clones.
                .or_insert_with(|| SplitSwap {
                    hop: desc.clone(),
                    split: path.flow_fraction,
                    amount_in: BigUint::ZERO,
                    amount_out: hop.amount_out.clone(),
                    gas: hop.gas.clone(),
                });
        }
    }
    let mut branch_collections: HashMap<Bytes, Vec<SplitSwap>> = HashMap::new();
    for swap in hops.into_values() {
        branch_collections
            .entry(swap.hop.token_in.address.clone())
            .or_default()
            .push(swap);
    }
    for branch_collection in branch_collections.values_mut() {
        // `total_cmp` is a total order even with NaN, as `sort_by` requires. The previous
        // `partial_cmp(..).unwrap_or(Equal)` was not, and the std sort may panic on that.
        branch_collection.sort_by(|a, b| {
            b.split
                .total_cmp(&a.split)
                .then_with(|| a.hop.component_id.cmp(&b.hop.component_id))
        });
    }
    Ok(branch_collections)
}
/// Distributes `total_available` across the outgoing hops of one token.
///
/// Each hop's weight is its `split` divided by the sum of all splits. Amounts come from
/// [`fractions_to_amounts`], so they add up exactly to `total_available`, with the last hop
/// taking the rounding remainder. On return, `split` holds the normalized weight for every
/// hop except the last, which is set to `0.0` to mark it as the remainder leg.
///
/// If the split total is zero, NaN or infinite, dividing by it would put NaN into the
/// route's splits. In that case the hops share `total_available` evenly instead.
fn assign_splits_and_amounts(
    mut hops: Vec<SplitSwap>,
    total_available: &BigUint,
) -> Vec<SplitSwap> {
    let len = hops.len();
    let fraction_total: f64 = hops.iter().map(|h| h.split).sum();
    let normalized: Vec<f64> = if fraction_total.is_finite() && fraction_total > 0.0 {
        hops.iter()
            .map(|h| h.split / fraction_total)
            .collect()
    } else {
        // Degenerate weights: fall back to an even split (empty when there are no hops).
        vec![1.0 / len as f64; len]
    };
    // Only fails for an empty slice, in which case there are no hops to assign anyway.
    let amounts = fractions_to_amounts(total_available, &normalized)
        .unwrap_or_else(|_| vec![total_available.clone()]);
    for (i, (swap, amount)) in hops.iter_mut().zip(amounts).enumerate() {
        swap.amount_in = amount;
        swap.split = if i + 1 == len { 0.0 } else { normalized[i] };
    }
    hops
}
/// Assembles the final split [`Route`] from the optimizer's path allocations.
///
/// Steps:
/// 1. Every path is checked with [`PathAllocation::validate_token_cycles`].
/// 2. Identical hops across paths are merged and grouped by input token
///    (`merge_shared_hops`).
/// 3. Token groups are emitted in dependency order, using an in-degree-driven (Kahn-style)
///    traversal seeded with the order's input token. A token's outgoing hops are emitted
///    only after every merged hop producing that token has been emitted.
/// 4. Within a group, the amount accumulated for the token is divided among its hops by
///    `assign_splits_and_amounts`. Each hop's `amount_out` (from the path simulations) is
///    then credited to its output token.
///
/// Swap states and components are read from `market`, not from any post-swap overrides.
///
/// # Errors
/// * Any error from `validate_token_cycles`.
/// * `AlgorithmError::DataNotFound` if a hop's simulation state or component is missing.
/// * `AlgorithmError::Other` if some token groups were never reached. This happens with a
///   cycle between tokens or with hops not connected to the order's input token.
pub(crate) fn build_split_route(
    paths: &[PathAllocation],
    market: &MarketState,
    order: &Order,
) -> Result<Route, AlgorithmError> {
    for path in paths {
        path.validate_token_cycles()?;
    }
    let mut hops_by_token = merge_shared_hops(paths)?;
    // In-degree of a token = number of merged hops producing it. Every input token gets an
    // entry (possibly 0), so it is tracked even when nothing produces it.
    let mut in_degree: HashMap<Bytes, usize> = HashMap::new();
    for (token_in_addr, branch_collection) in &hops_by_token {
        in_degree
            .entry(token_in_addr.clone())
            .or_insert(0);
        for swap in branch_collection {
            *in_degree
                .entry(swap.hop.token_out.address.clone())
                .or_insert(0) += 1;
        }
    }
    // Processing starts from the order's input token regardless of its in-degree. A round
    // trip back to it would otherwise keep it blocked.
    let mut ready = VecDeque::new();
    ready.push_back(order.token_in().clone());
    // Amount of each token available to spend, accumulated from upstream hop outputs.
    let mut available: HashMap<Bytes, BigUint> = HashMap::new();
    available.insert(order.token_in().clone(), order.amount().clone());
    let mut swaps = Vec::new();
    let mut route_tokens: HashMap<Bytes, Token> = HashMap::new();
    while let Some(token_addr) = ready.pop_front() {
        // Tokens with no outgoing hops (e.g. the final output) or already emitted ones
        // (re-queued after a round trip) have nothing left to emit.
        let Some(branch_collection) = hops_by_token.remove(&token_addr) else {
            continue;
        };
        let total = available
            .get(&token_addr)
            .cloned()
            .unwrap_or_default();
        for split_swap in assign_splits_and_amounts(branch_collection, &total) {
            let sim = market
                .get_simulation_state(&split_swap.hop.component_id)
                .ok_or_else(|| AlgorithmError::DataNotFound {
                    kind: "simulation state",
                    id: Some(split_swap.hop.component_id.clone()),
                })?;
            let component = market
                .get_component(&split_swap.hop.component_id)
                .ok_or_else(|| AlgorithmError::DataNotFound {
                    kind: "protocol component",
                    id: Some(split_swap.hop.component_id.clone()),
                })?;
            let in_addr = split_swap.hop.token_in.address.clone();
            let out_addr = split_swap.hop.token_out.address.clone();
            // Credit this hop's simulated output to its output token for downstream hops.
            *available
                .entry(out_addr.clone())
                .or_default() += &split_swap.amount_out;
            swaps.push(
                Swap::new(
                    split_swap.hop.component_id,
                    component.protocol_system.clone(),
                    in_addr.clone(),
                    out_addr.clone(),
                    split_swap.amount_in,
                    split_swap.amount_out,
                    split_swap.gas,
                    component.clone(),
                    sim.clone_box(),
                )
                .with_split(split_swap.split),
            );
            route_tokens
                .entry(in_addr)
                .or_insert(split_swap.hop.token_in);
            route_tokens
                .entry(out_addr.clone())
                .or_insert(split_swap.hop.token_out);
            // Once every producer of `out_addr` has been emitted, its own hops may run.
            if let Some(deg) = in_degree.get_mut(&out_addr) {
                *deg = deg.saturating_sub(1);
                if *deg == 0 {
                    ready.push_back(out_addr);
                }
            }
        }
    }
    // Groups still present were never reached: a token cycle or hops disconnected from
    // the order's input token.
    if !hops_by_token.is_empty() {
        let stuck: Vec<_> = hops_by_token
            .keys()
            .map(|k| format!("{k}"))
            .collect();
        return Err(AlgorithmError::Other(format!(
            "dependency cycle — unprocessed tokens: [{}]",
            stuck.join(", "),
        )));
    }
    Ok(Route::new(swaps, route_tokens))
}
#[cfg(test)]
mod tests {
use rstest::rstest;
use super::*;
use crate::{
algorithm::test_utils::{component, order, token, ConstantProductSim, MockProtocolSim},
types::OrderSide,
};
fn make_market(pools: Vec<(&str, Vec<Token>, Box<dyn ProtocolSim>)>) -> MarketState {
let mut market = MarketState::new();
for (pool_id, tokens, sim) in pools {
market.upsert_components(std::iter::once(component(pool_id, &tokens)));
market.update_states([(pool_id.to_string(), sim)]);
market.upsert_tokens(tokens);
}
market
}
#[test]
fn test_split_amount_exact_sum() {
let total = BigUint::from(1_000_000_000_000_000_000_u64);
for fraction in [0.1, 0.5, 0.9, 0.999] {
let (part, remainder) = split_amount(&total, fraction);
assert_eq!(
&part + &remainder,
total,
"part + remainder must equal total for fraction={fraction}"
);
}
}
#[test]
fn test_split_amount_edge_fraction_zero() {
let total = BigUint::from(1_000_000_000_000_000_000_u64);
let (part, remainder) = split_amount(&total, 0.0);
assert!(part.is_zero());
assert_eq!(remainder, total);
}
#[test]
fn test_split_amount_clamps_above_one() {
let total = BigUint::from(1_000_000_000_000_000_000_u64);
let (part, remainder) = split_amount(&total, 1.5);
assert_eq!(part, total);
assert!(remainder.is_zero());
}
#[test]
fn test_split_amount_clamps_negative() {
let total = BigUint::from(1_000_000_000_000_000_000_u64);
let (part, remainder) = split_amount(&total, -0.5);
assert!(part.is_zero());
assert_eq!(remainder, total);
}
#[test]
fn test_fractions_to_amounts_exact_sum() {
let total = BigUint::from(999_999_999_999_999_999_u64);
let fractions = [0.3, 0.5, 0.2];
let amounts = fractions_to_amounts(&total, &fractions).unwrap();
assert_eq!(amounts.len(), 3);
let sum: BigUint = amounts.iter().sum();
assert_eq!(sum, total, "amounts must sum exactly to total");
}
#[test]
fn test_fractions_to_amounts_empty() {
let total = BigUint::from(1_000_u64);
let err = fractions_to_amounts(&total, &[]).unwrap_err();
assert_eq!(err, SplitMathError::EmptyFractions);
}
#[rstest]
#[case::already_normalized(&[0.3, 0.5, 0.2])]
#[case::drift(&[0.33, 0.33, 0.33])]
fn test_normalize_fractions(#[case] input: &[f64]) {
let mut fractions = input.to_vec();
normalize_fractions(&mut fractions).unwrap();
let sum: f64 = fractions.iter().sum();
assert!((sum - 1.0).abs() < f64::EPSILON);
}
#[rstest]
#[case::empty(&[], SplitMathError::EmptyFractions)]
#[case::all_zeros(&[0.0, 0.0, 0.0], SplitMathError::AllZeroFractions)]
#[case::negative(&[-0.5, 0.5], SplitMathError::NegativeFraction)]
fn test_normalize_fractions_invalid(#[case] input: &[f64], #[case] expected: SplitMathError) {
let mut fractions = input.to_vec();
let err = normalize_fractions(&mut fractions).unwrap_err();
assert_eq!(err, expected);
}
#[test]
fn test_golden_section_finds_maximum() {
let f = |x: f64| -(x - 0.3) * (x - 0.3);
let result = golden_section_search(f, 0.0, 1.0, 100);
assert!((result - 0.3).abs() < 1e-4, "expected ~0.3, got {result}");
}
#[test]
fn test_validate_token_cycles_valid_path() {
let gas = BigUint::from(50_000u64);
let path = PathAllocation {
hops: vec![
HopDescriptor::new("p1".to_string(), token(0x01, "A"), token(0x02, "B"))
.with_amounts(BigUint::from(100u64), gas.clone()),
HopDescriptor::new("p2".to_string(), token(0x02, "B"), token(0x03, "C"))
.with_amounts(BigUint::from(100u64), gas),
],
flow_fraction: 1.0,
amount_in: BigUint::from(100u64),
amount_out: BigUint::from(100u64),
marginal_price_product: 1.0,
};
assert!(path.validate_token_cycles().is_ok());
}
#[test]
fn test_validate_token_cycles_empty_hops() {
let path = PathAllocation {
hops: vec![],
flow_fraction: 1.0,
amount_in: BigUint::from(100u64),
amount_out: BigUint::from(100u64),
marginal_price_product: 1.0,
};
assert!(path.validate_token_cycles().is_err());
}
#[test]
fn test_validate_token_cycles_valid_round_trip() {
let gas = BigUint::from(50_000u64);
let path = PathAllocation {
hops: vec![
HopDescriptor::new("p1".to_string(), token(0x01, "A"), token(0x02, "B"))
.with_amounts(BigUint::from(100u64), gas.clone()),
HopDescriptor::new("p2".to_string(), token(0x02, "B"), token(0x01, "A"))
.with_amounts(BigUint::from(100u64), gas),
],
flow_fraction: 1.0,
amount_in: BigUint::from(100u64),
amount_out: BigUint::from(100u64),
marginal_price_product: 1.0,
};
assert!(path.validate_token_cycles().is_ok());
}
#[test]
fn test_validate_token_cycles_rejects_mid_path_cycle() {
let gas = BigUint::from(50_000u64);
let path = PathAllocation {
hops: vec![
HopDescriptor::new("p1".to_string(), token(0x01, "A"), token(0x02, "B"))
.with_amounts(BigUint::from(100u64), gas.clone()),
HopDescriptor::new("p2".to_string(), token(0x02, "B"), token(0x03, "C"))
.with_amounts(BigUint::from(100u64), gas.clone()),
HopDescriptor::new("p3".to_string(), token(0x03, "C"), token(0x01, "A"))
.with_amounts(BigUint::from(100u64), gas.clone()),
HopDescriptor::new("p4".to_string(), token(0x01, "A"), token(0x04, "D"))
.with_amounts(BigUint::from(100u64), gas),
],
flow_fraction: 1.0,
amount_in: BigUint::from(100u64),
amount_out: BigUint::from(100u64),
marginal_price_product: 1.0,
};
assert!(path.validate_token_cycles().is_err());
}
#[test]
fn test_validate_token_cycles_rejects_intermediate_revisit() {
let gas = BigUint::from(50_000u64);
let path = PathAllocation {
hops: vec![
HopDescriptor::new("p1".to_string(), token(0x01, "A"), token(0x02, "B"))
.with_amounts(BigUint::from(100u64), gas.clone()),
HopDescriptor::new("p2".to_string(), token(0x02, "B"), token(0x03, "C"))
.with_amounts(BigUint::from(100u64), gas.clone()),
HopDescriptor::new("p3".to_string(), token(0x03, "C"), token(0x02, "B"))
.with_amounts(BigUint::from(100u64), gas.clone()),
HopDescriptor::new("p4".to_string(), token(0x02, "B"), token(0x04, "D"))
.with_amounts(BigUint::from(100u64), gas),
],
flow_fraction: 1.0,
amount_in: BigUint::from(100u64),
amount_out: BigUint::from(100u64),
marginal_price_product: 1.0,
};
assert!(path.validate_token_cycles().is_err());
}
#[test]
fn test_compute_marginal_price_product_single_hop() {
let token_a = token(0x0A, "A");
let token_b = token(0x0B, "B");
let market = make_market(vec![(
"pool_ab",
vec![token_a.clone(), token_b.clone()],
Box::new(MockProtocolSim::new(3.0)),
)]);
let hops = [HopDescriptor::new("pool_ab".to_string(), token_a, token_b)];
let product =
compute_marginal_price_product(&hops, &market, &MarketOverrides::empty()).unwrap();
assert!((product - 3.0).abs() < f64::EPSILON, "expected 3.0, got {product}");
}
#[test]
fn test_compute_marginal_price_product_multi_hop() {
let token_a = token(0x0A, "A");
let token_b = token(0x0B, "B");
let token_c = token(0x0C, "C");
let market = make_market(vec![
(
"pool_ab",
vec![token_a.clone(), token_b.clone()],
Box::new(MockProtocolSim::new(2.0)),
),
(
"pool_bc",
vec![token_b.clone(), token_c.clone()],
Box::new(MockProtocolSim::new(4.0)),
),
]);
let hops = [
HopDescriptor::new("pool_ab".to_string(), token_a, token_b.clone()),
HopDescriptor::new("pool_bc".to_string(), token_b, token_c),
];
let product =
compute_marginal_price_product(&hops, &market, &MarketOverrides::empty()).unwrap();
assert!((product - 8.0).abs() < f64::EPSILON, "expected 8.0, got {product}");
}
#[test]
fn test_compute_marginal_price_product_uses_overrides() {
let token_a = token(0x0A, "A");
let token_b = token(0x0B, "B");
let market = make_market(vec![(
"pool_ab",
vec![token_a.clone(), token_b.clone()],
Box::new(MockProtocolSim::new(3.0)),
)]);
let hops = [HopDescriptor::new("pool_ab".to_string(), token_a, token_b)];
let overrides = MarketOverrides::empty()
.with_override("pool_ab".to_string(), Box::new(MockProtocolSim::new(7.0)));
let product = compute_marginal_price_product(&hops, &market, &overrides).unwrap();
assert!((product - 7.0).abs() < f64::EPSILON, "expected 7.0, got {product}");
}
#[test]
fn test_simulate_path_correct_output() {
let token_a = token(0x0A, "A");
let token_b = token(0x0B, "B");
let token_c = token(0x0C, "C");
let market = make_market(vec![
(
"pool_ab",
vec![token_a.clone(), token_b.clone()],
Box::new(MockProtocolSim::new(2.0)),
),
(
"pool_bc",
vec![token_b.clone(), token_c.clone()],
Box::new(MockProtocolSim::new(3.0)),
),
]);
let hops = [
HopDescriptor::new("pool_ab".to_string(), token_a, token_b.clone()),
HopDescriptor::new("pool_bc".to_string(), token_b, token_c),
];
let amount_in = BigUint::from(1000u64);
let overrides = MarketOverrides::empty();
let result = simulate_path(&hops, &amount_in, &market, &overrides).unwrap();
assert_eq!(result.amount_out, BigUint::from(6000u64));
assert!(
(result.marginal_price_product - 6.0).abs() < f64::EPSILON,
"expected marginal_price_product 6.0, got {}",
result.marginal_price_product
);
}
#[test]
fn test_market_overrides_with_zero_gas() {
let token_a = token(0x0A, "A");
let token_b = token(0x0B, "B");
let token_c = token(0x0C, "C");
let sim_ab = MockProtocolSim::new(2.0).with_gas(100_000);
let sim_bc = MockProtocolSim::new(3.0).with_gas(70_000);
let market = make_market(vec![
("pool_ab", vec![token_a.clone(), token_b.clone()], Box::new(sim_ab.clone())),
("pool_bc", vec![token_b.clone(), token_c.clone()], Box::new(sim_bc.clone())),
]);
let overrides = MarketOverrides::empty()
.with_override("pool_ab".to_string(), Box::new(sim_ab))
.with_zero_gas("pool_ab".to_string(), token_a.address.clone(), token_b.address.clone())
.with_override("pool_bc".to_string(), Box::new(sim_bc));
let hops_ab = [HopDescriptor::new("pool_ab".to_string(), token_a.clone(), token_b.clone())];
let hops_bc = [HopDescriptor::new("pool_bc".to_string(), token_b, token_c)];
let amount_in = BigUint::from(1000u64);
let normal_ab =
simulate_path(&hops_ab, &amount_in, &market, &MarketOverrides::empty()).unwrap();
let zero_gas_ab = simulate_path(&hops_ab, &amount_in, &market, &overrides).unwrap();
assert_eq!(normal_ab.amount_out, zero_gas_ab.amount_out);
assert!(normal_ab.gas > 0, "normal gas should be non-zero");
assert_eq!(zero_gas_ab.gas, 0, "zero-gas override should report gas=0");
let result_bc = simulate_path(&hops_bc, &amount_in, &market, &overrides).unwrap();
assert_eq!(result_bc.gas, 70_000, "non-zero-gas override should keep its gas");
}
#[test]
fn test_evaluate_total_output_two_paths() {
let token_a = token(0x0A, "A");
let token_b = token(0x0B, "B");
let market = make_market(vec![
(
"pool_1",
vec![token_a.clone(), token_b.clone()],
Box::new(MockProtocolSim::new(2.0).with_gas(50_000)),
),
(
"pool_2",
vec![token_a.clone(), token_b.clone()],
Box::new(MockProtocolSim::new(3.0).with_gas(60_000)),
),
]);
let hops_1 = [HopDescriptor::new("pool_1".to_string(), token_a.clone(), token_b.clone())];
let hops_2 = [HopDescriptor::new("pool_2".to_string(), token_a, token_b)];
let paths: Vec<&[HopDescriptor]> = vec![&hops_1, &hops_2];
let fractions = [0.5, 0.5];
let total_amount = BigUint::from(1000u64);
let overrides = MarketOverrides::empty();
let (total_out, total_gas) =
evaluate_total_output(&paths, &fractions, &total_amount, &market, &overrides).unwrap();
assert_eq!(total_out, BigUint::from(2500u64));
assert_eq!(total_gas, 110_000);
}
#[test]
fn test_evaluate_total_output_gas_deduplication() {
let token_a = token(0x0A, "A");
let token_b = token(0x0B, "B");
let token_c = token(0x0C, "C");
let token_d = token(0x0D, "D");
let market = make_market(vec![
(
"P1",
vec![token_a.clone(), token_b.clone()],
Box::new(MockProtocolSim::new(2.0).with_gas(100_000)),
),
(
"P2",
vec![token_b.clone(), token_c.clone()],
Box::new(MockProtocolSim::new(1.5).with_gas(50_000)),
),
(
"P3",
vec![token_b.clone(), token_d.clone()],
Box::new(MockProtocolSim::new(3.0).with_gas(70_000)),
),
]);
let hops_1 = [
HopDescriptor::new("P1".to_string(), token_a.clone(), token_b.clone()),
HopDescriptor::new("P2".to_string(), token_b.clone(), token_c),
];
let hops_2 = [
HopDescriptor::new("P1".to_string(), token_a, token_b.clone()),
HopDescriptor::new("P3".to_string(), token_b, token_d),
];
let paths: Vec<&[HopDescriptor]> = vec![&hops_1, &hops_2];
let fractions = [0.5, 0.5];
let total_amount = BigUint::from(1000u64);
let overrides = MarketOverrides::empty();
let (_, total_gas) =
evaluate_total_output(&paths, &fractions, &total_amount, &market, &overrides).unwrap();
assert_eq!(total_gas, 220_000);
}
#[test]
fn test_gas_dedup_different_tokens() {
let token_a = token(0x0A, "A");
let token_b = token(0x0B, "B");
let token_c = token(0x0C, "C");
let market = make_market(vec![(
"tripool",
vec![token_a.clone(), token_b.clone(), token_c.clone()],
Box::new(MockProtocolSim::new(1.0).with_gas(80_000)),
)]);
let hops_1 = [HopDescriptor::new("tripool".to_string(), token_a, token_b.clone())];
let hops_2 = [HopDescriptor::new("tripool".to_string(), token_b, token_c)];
let paths: Vec<&[HopDescriptor]> = vec![&hops_1, &hops_2];
let fractions = [0.5, 0.5];
let total_amount = BigUint::from(1000u64);
let overrides = MarketOverrides::empty();
let (_, total_gas) =
evaluate_total_output(&paths, &fractions, &total_amount, &market, &overrides).unwrap();
assert_eq!(total_gas, 160_000);
}
#[test]
fn test_build_post_swap_overrides_degrades_used_pools() {
let token_a = token(0x0A, "A");
let token_b = token(0x0B, "B");
let market = make_market(vec![(
"pool_ab",
vec![token_a.clone(), token_b.clone()],
Box::new(ConstantProductSim {
reserve_0: BigUint::from(10_000u64),
reserve_1: BigUint::from(20_000u64),
gas: 50_000,
}),
)]);
let allocation = PathAllocation {
hops: vec![SimulatedHop {
descriptor: HopDescriptor::new(
"pool_ab".to_string(),
token_a.clone(),
token_b.clone(),
),
amount_out: BigUint::from(1818u64),
gas: BigUint::from(50_000u64),
}],
flow_fraction: 1.0,
amount_in: BigUint::from(1000u64),
amount_out: BigUint::from(1818u64),
marginal_price_product: 2.0,
};
let degraded = build_post_swap_overrides(&[allocation], &market);
let probe = BigUint::from(100u64);
let fresh_out = market
.get_simulation_state("pool_ab")
.unwrap()
.get_amount_out(probe.clone(), &token_a, &token_b)
.unwrap()
.amount;
assert_eq!(fresh_out, BigUint::from(198u64));
let degraded_out = degraded
.get(&"pool_ab".to_string())
.unwrap()
.get_amount_out(probe, &token_a, &token_b)
.unwrap()
.amount;
assert_eq!(degraded_out, BigUint::from(163u64));
}
#[test]
fn test_merge_shared_hops_combines_fractions() {
let token_a = token(0x0A, "A");
let token_b = token(0x0B, "B");
let token_c = token(0x0C, "C");
let gas = BigUint::from(50_000u64);
let paths = vec![
PathAllocation {
hops: vec![
SimulatedHop {
descriptor: HopDescriptor::new(
"P1".to_string(),
token_a.clone(),
token_b.clone(),
),
amount_out: BigUint::from(1200u64),
gas: gas.clone(),
},
SimulatedHop {
descriptor: HopDescriptor::new(
"P2".to_string(),
token_b.clone(),
token_c.clone(),
),
amount_out: BigUint::from(3600u64),
gas: gas.clone(),
},
],
flow_fraction: 0.6,
amount_in: BigUint::from(600u64),
amount_out: BigUint::from(3600u64),
marginal_price_product: 6.0,
},
PathAllocation {
hops: vec![
SimulatedHop {
descriptor: HopDescriptor::new(
"P1".to_string(),
token_a.clone(),
token_b.clone(),
),
amount_out: BigUint::from(800u64),
gas: gas.clone(),
},
SimulatedHop {
descriptor: HopDescriptor::new(
"P3".to_string(),
token_b.clone(),
token_c.clone(),
),
amount_out: BigUint::from(1600u64),
gas,
},
],
flow_fraction: 0.4,
amount_in: BigUint::from(400u64),
amount_out: BigUint::from(1600u64),
marginal_price_product: 4.0,
},
];
let hops_by_token = merge_shared_hops(&paths).unwrap();
let branch_collection_a = &hops_by_token[&token_a.address];
assert_eq!(branch_collection_a.len(), 1);
assert_eq!(branch_collection_a[0].hop.component_id, "P1");
assert!((branch_collection_a[0].split - 1.0).abs() < f64::EPSILON);
let branch_collection_b = &hops_by_token[&token_b.address];
assert_eq!(branch_collection_b.len(), 2);
assert_eq!(branch_collection_b[0].hop.component_id, "P2");
assert!((branch_collection_b[0].split - 0.6).abs() < f64::EPSILON);
assert_eq!(branch_collection_b[1].hop.component_id, "P3");
assert!((branch_collection_b[1].split - 0.4).abs() < f64::EPSILON);
}
#[test]
fn test_assign_splits_and_amounts_splits_and_amounts() {
let token_a = token(0x0A, "A");
let token_b = token(0x0B, "B");
let branch_collection = vec![
SplitSwap {
hop: HopDescriptor::new("pool1".to_string(), token_a.clone(), token_b.clone()),
split: 0.7,
amount_in: BigUint::ZERO,
amount_out: BigUint::ZERO,
gas: BigUint::ZERO,
},
SplitSwap {
hop: HopDescriptor::new("pool2".to_string(), token_a.clone(), token_b.clone()),
split: 0.3,
amount_in: BigUint::ZERO,
amount_out: BigUint::ZERO,
gas: BigUint::ZERO,
},
];
let result = assign_splits_and_amounts(branch_collection, &BigUint::from(1000u64));
assert_eq!(result.len(), 2);
assert_eq!(result[0].split, 0.7);
assert_eq!(result[0].amount_in, BigUint::from(700u64));
assert_eq!(result[1].split, 0.0);
assert_eq!(result[1].amount_in, BigUint::from(300u64));
}
#[test]
fn test_assign_splits_and_amounts_single_hop() {
let token_a = token(0x0A, "A");
let token_b = token(0x0B, "B");
let branch_collection = vec![SplitSwap {
hop: HopDescriptor::new("pool1".to_string(), token_a, token_b),
split: 1.0,
amount_in: BigUint::ZERO,
amount_out: BigUint::ZERO,
gas: BigUint::ZERO,
}];
let total = BigUint::from(1000u64);
let result = assign_splits_and_amounts(branch_collection, &total);
assert_eq!(result.len(), 1);
assert_eq!(result[0].split, 0.0);
assert_eq!(result[0].amount_in, total);
}
/// Three single-hop A→B paths carrying 50% / 30% / 20% of the flow become
/// three sibling swaps. The first two keep their flow fractions as splits
/// and the last one carries the `0.0` remainder marker.
#[test]
fn test_build_split_route_remainder_convention() {
    let token_a = token(0x0A, "A");
    let token_b = token(0x0B, "B");
    let market = make_market(vec![
        ("pool1", vec![token_a.clone(), token_b.clone()], Box::new(MockProtocolSim::new(2.0))),
        ("pool2", vec![token_a.clone(), token_b.clone()], Box::new(MockProtocolSim::new(3.0))),
        ("pool3", vec![token_a.clone(), token_b.clone()], Box::new(MockProtocolSim::new(4.0))),
    ]);
    let ord = order(&token_a, &token_b, 1000, OrderSide::Sell);
    let gas = BigUint::from(50_000u64);

    // Each path is a single A→B hop whose output is also the path output.
    let direct_path = |pool: &str, flow_fraction: f64, amount_in: u64, amount_out: u64, price: f64| {
        PathAllocation {
            hops: vec![HopDescriptor::new(pool.to_string(), token_a.clone(), token_b.clone())
                .with_amounts(BigUint::from(amount_out), gas.clone())],
            flow_fraction,
            amount_in: BigUint::from(amount_in),
            amount_out: BigUint::from(amount_out),
            marginal_price_product: price,
        }
    };
    let paths = vec![
        direct_path("pool1", 0.5, 500, 1000, 2.0),
        direct_path("pool2", 0.3, 300, 900, 3.0),
        direct_path("pool3", 0.2, 200, 800, 4.0),
    ];

    let route = build_split_route(&paths, &market, &ord).unwrap();
    let swaps = route.swaps();

    // (component id, expected split) in emission order.
    let expected = [("pool1", 0.5), ("pool2", 0.3), ("pool3", 0.0)];
    assert_eq!(swaps.len(), expected.len());
    for (swap, (pool, split)) in swaps.iter().zip(expected) {
        assert_eq!(swap.component_id(), pool);
        assert_eq!(*swap.split(), split);
    }
}
/// One two-hop A→B→C path carrying the whole order yields one swap per hop,
/// and every swap uses the `0.0` split marker.
#[test]
fn test_build_split_route_single_path() {
    let token_a = token(0x0A, "A");
    let token_b = token(0x0B, "B");
    let token_c = token(0x0C, "C");
    let market = make_market(vec![
        (
            "pool_ab",
            vec![token_a.clone(), token_b.clone()],
            Box::new(MockProtocolSim::new(2.0)),
        ),
        (
            "pool_bc",
            vec![token_b.clone(), token_c.clone()],
            Box::new(MockProtocolSim::new(3.0)),
        ),
    ]);
    let ord = order(&token_a, &token_c, 1000, OrderSide::Sell);
    let gas = BigUint::from(50_000u64);

    // 1000 A --pool_ab--> 2000 B --pool_bc--> 6000 C
    let hops = vec![
        HopDescriptor::new("pool_ab".to_string(), token_a.clone(), token_b.clone())
            .with_amounts(BigUint::from(2000u64), gas.clone()),
        HopDescriptor::new("pool_bc".to_string(), token_b, token_c)
            .with_amounts(BigUint::from(6000u64), gas),
    ];
    let paths = vec![PathAllocation {
        hops,
        flow_fraction: 1.0,
        amount_in: BigUint::from(1000u64),
        amount_out: BigUint::from(6000u64),
        marginal_price_product: 6.0,
    }];

    let route = build_split_route(&paths, &market, &ord).unwrap();
    let swaps = route.swaps();

    assert_eq!(swaps.len(), 2);
    swaps.iter().for_each(|swap| {
        assert_eq!(*swap.split(), 0.0, "single path should produce all-zero splits");
    });
}
/// Two A→B→C paths share the first pool `P1` (70% / 30% of the input) and
/// then diverge into `P2` and `P3`. The shared hop must be emitted once with
/// the combined amounts, followed by a 0.7 / remainder split at token B.
#[test]
fn test_build_split_route_shared_first_pool() {
    let token_a = token(0x0A, "A");
    let token_b = token(0x0B, "B");
    let token_c = token(0x0C, "C");
    let market = make_market(vec![
        ("P1", vec![token_a.clone(), token_b.clone()], Box::new(MockProtocolSim::new(2.0))),
        ("P2", vec![token_b.clone(), token_c.clone()], Box::new(MockProtocolSim::new(3.0))),
        ("P3", vec![token_b.clone(), token_c.clone()], Box::new(MockProtocolSim::new(4.0))),
    ]);
    let ord = order(&token_a, &token_c, 1000, OrderSide::Sell);
    let gas = BigUint::from(50_000u64);
    let paths = vec![
        // 700 A --P1 (x2)--> 1400 B --P2 (x3)--> 4200 C
        PathAllocation {
            hops: vec![
                HopDescriptor::new("P1".to_string(), token_a.clone(), token_b.clone())
                    .with_amounts(BigUint::from(1400u64), gas.clone()),
                HopDescriptor::new("P2".to_string(), token_b.clone(), token_c.clone())
                    .with_amounts(BigUint::from(4200u64), gas.clone()),
            ],
            flow_fraction: 0.7,
            amount_in: BigUint::from(700u64),
            amount_out: BigUint::from(4200u64),
            marginal_price_product: 6.0,
        },
        // 300 A --P1 (x2)--> 600 B --P3 (x4)--> 2400 C
        PathAllocation {
            hops: vec![
                HopDescriptor::new("P1".to_string(), token_a.clone(), token_b.clone())
                    .with_amounts(BigUint::from(600u64), gas.clone()),
                HopDescriptor::new("P3".to_string(), token_b.clone(), token_c.clone())
                    .with_amounts(BigUint::from(2400u64), gas),
            ],
            flow_fraction: 0.3,
            amount_in: BigUint::from(300u64),
            // A path's output must match its final hop's output (600 B * 4.0),
            // otherwise the fixture contradicts itself.
            amount_out: BigUint::from(2400u64),
            marginal_price_product: 8.0,
        },
    ];
    let route = build_split_route(&paths, &market, &ord).unwrap();
    let swaps = route.swaps();
    assert_eq!(swaps.len(), 3, "expected 3 swaps, got {}", swaps.len());
    let ab_swap = &swaps[0];
    assert_eq!(ab_swap.component_id(), "P1");
    assert_eq!(
        *ab_swap.amount_in(),
        BigUint::from(1000u64),
        "A→B swap amount_in should equal sum of both paths"
    );
    assert_eq!(
        *ab_swap.amount_out(),
        BigUint::from(2000u64),
        "A→B amount_out should be sum of per-path outputs (1400+600)"
    );
    assert_eq!(
        *ab_swap.split(),
        0.0,
        "A→B is the sole swap in its branch collection, so it gets the remainder convention (split = 0.0)"
    );
    assert_eq!(swaps[1].component_id(), "P2");
    assert_eq!(*swaps[1].split(), 0.7);
    assert_eq!(
        *swaps[1].amount_out(),
        BigUint::from(4200u64),
        "P2 amount_out should be the first path's final output"
    );
    assert_eq!(swaps[2].component_id(), "P3");
    assert_eq!(*swaps[2].split(), 0.0);
    assert_eq!(
        *swaps[2].amount_out(),
        BigUint::from(2400u64),
        "P3 amount_out should be the second path's final output"
    );
}
/// The source token A is split 60% / 40% across pools leading to different
/// intermediates (B and C), each of which reaches Z through its own pool.
/// The two source swaps come first, followed by the two intermediate swaps.
#[test]
fn test_build_split_route_source_level_split_different_intermediates() {
    let token_a = token(0x0A, "A");
    let token_b = token(0x0B, "B");
    let token_c = token(0x0C, "C");
    let token_z = token(0x1A, "Z");
    let market = make_market(vec![
        (
            "pool_ab",
            vec![token_a.clone(), token_b.clone()],
            Box::new(MockProtocolSim::new(2.0)),
        ),
        (
            "pool_ac",
            vec![token_a.clone(), token_c.clone()],
            Box::new(MockProtocolSim::new(3.0)),
        ),
        (
            "pool_bz",
            vec![token_b.clone(), token_z.clone()],
            Box::new(MockProtocolSim::new(4.0)),
        ),
        (
            "pool_cz",
            vec![token_c.clone(), token_z.clone()],
            Box::new(MockProtocolSim::new(5.0)),
        ),
    ]);
    let ord = order(&token_a, &token_z, 1000, OrderSide::Sell);
    let gas = BigUint::from(50_000u64);

    // 600 A --pool_ab--> 1200 B --pool_bz--> 4800 Z
    let via_b = PathAllocation {
        hops: vec![
            HopDescriptor::new("pool_ab".to_string(), token_a.clone(), token_b.clone())
                .with_amounts(BigUint::from(1200u64), gas.clone()),
            HopDescriptor::new("pool_bz".to_string(), token_b, token_z.clone())
                .with_amounts(BigUint::from(4800u64), gas.clone()),
        ],
        flow_fraction: 0.6,
        amount_in: BigUint::from(600u64),
        amount_out: BigUint::from(4800u64),
        marginal_price_product: 8.0,
    };
    // 400 A --pool_ac--> 1200 C --pool_cz--> 6000 Z
    let via_c = PathAllocation {
        hops: vec![
            HopDescriptor::new("pool_ac".to_string(), token_a.clone(), token_c.clone())
                .with_amounts(BigUint::from(1200u64), gas.clone()),
            HopDescriptor::new("pool_cz".to_string(), token_c, token_z)
                .with_amounts(BigUint::from(6000u64), gas),
        ],
        flow_fraction: 0.4,
        amount_in: BigUint::from(400u64),
        amount_out: BigUint::from(6000u64),
        marginal_price_product: 15.0,
    };

    let route = build_split_route(&[via_b, via_c], &market, &ord).unwrap();
    let swaps = route.swaps();
    assert_eq!(swaps.len(), 4, "expected 4 swaps (2 source + 2 intermediate)");

    // (component id, split, amount_in, amount_out) in emission order.
    let expected: [(&str, f64, u64, u64); 4] = [
        ("pool_ab", 0.6, 600, 1200),
        ("pool_ac", 0.0, 400, 1200),
        ("pool_bz", 0.0, 1200, 4800),
        ("pool_cz", 0.0, 1200, 6000),
    ];
    for (swap, (pool, split, amount_in, amount_out)) in swaps.iter().zip(expected) {
        assert_eq!(swap.component_id(), pool);
        assert_eq!(*swap.split(), split);
        assert_eq!(*swap.amount_in(), BigUint::from(amount_in));
        assert_eq!(*swap.amount_out(), BigUint::from(amount_out));
    }
}
/// `pool_a` (USDC→DAI) is the second hop of the direct WETH→USDC path and the
/// third hop of the WETH→USDT→USDC detour. It must be emitted once with the
/// combined amounts, and its gas must be counted only once.
#[test]
fn test_build_split_route_cross_depth_shared_pool() {
    let weth = token(0x01, "WETH");
    let usdc = token(0x02, "USDC");
    let usdt = token(0x03, "USDT");
    let dai = token(0x04, "DAI");
    let market = make_market(vec![
        (
            "pool_weth_usdc",
            vec![weth.clone(), usdc.clone()],
            Box::new(MockProtocolSim::new(2.0)),
        ),
        (
            "pool_weth_usdt",
            vec![weth.clone(), usdt.clone()],
            Box::new(MockProtocolSim::new(3.0)),
        ),
        (
            "pool_usdt_usdc",
            vec![usdt.clone(), usdc.clone()],
            Box::new(MockProtocolSim::new(1.0)),
        ),
        ("pool_a", vec![usdc.clone(), dai.clone()], Box::new(MockProtocolSim::new(1.0))),
    ]);
    let ord = order(&weth, &dai, 1000, OrderSide::Sell);
    let gas = BigUint::from(50_000u64);
    // Every hop in this scenario records an output of 1200 units.
    let hop_output = BigUint::from(1200u64);

    let direct = PathAllocation {
        hops: [("pool_weth_usdc", &weth, &usdc), ("pool_a", &usdc, &dai)]
            .iter()
            .map(|&(pool, from, to)| {
                HopDescriptor::new(pool.to_string(), from.clone(), to.clone())
                    .with_amounts(hop_output.clone(), gas.clone())
            })
            .collect(),
        flow_fraction: 0.6,
        amount_in: BigUint::from(600u64),
        amount_out: hop_output.clone(),
        marginal_price_product: 2.0,
    };
    let detour = PathAllocation {
        hops: [
            ("pool_weth_usdt", &weth, &usdt),
            ("pool_usdt_usdc", &usdt, &usdc),
            ("pool_a", &usdc, &dai),
        ]
        .iter()
        .map(|&(pool, from, to)| {
            HopDescriptor::new(pool.to_string(), from.clone(), to.clone())
                .with_amounts(hop_output.clone(), gas.clone())
        })
        .collect(),
        flow_fraction: 0.4,
        amount_in: BigUint::from(400u64),
        amount_out: hop_output.clone(),
        marginal_price_product: 3.0,
    };

    let route = build_split_route(&[direct, detour], &market, &ord).unwrap();
    let swaps = route.swaps();
    let shared = swaps
        .iter()
        .find(|s| s.component_id() == "pool_a")
        .expect("pool_a swap must exist");
    assert_eq!(
        *shared.amount_in(),
        BigUint::from(2400u64),
        "pool_a must receive USDC from both paths (1200 + 1200)"
    );
    assert_eq!(
        *shared.amount_out(),
        BigUint::from(2400u64),
        "pool_a amount_out should be the merged total"
    );
    assert_eq!(swaps.len(), 4, "pool_a must appear once, not once per path");
    assert_eq!(
        route.total_gas(),
        BigUint::from(200_000u64),
        "gas must be counted once per pool, not once per path"
    );
}
/// Three paths converge on `pool_a` (USDC→DAI) from different depths: two
/// share the direct WETH→USDC pool and the third arrives via the USDT detour.
/// After `pool_a` the DAI flow splits again between `pool_b` and `pool_c`.
/// The test checks merged amounts, dependency ordering, swap count and gas.
#[test]
fn test_build_split_route_cross_depth_convergence_with_downstream_split() {
    let weth = token(0x01, "WETH");
    let usdc = token(0x02, "USDC");
    let usdt = token(0x03, "USDT");
    let dai = token(0x04, "DAI");
    let pepe = token(0x05, "PEPE");
    let market = make_market(vec![
        ("pool_wu", vec![weth.clone(), usdc.clone()], Box::new(MockProtocolSim::new(2.0))),
        ("pool_wt", vec![weth.clone(), usdt.clone()], Box::new(MockProtocolSim::new(3.0))),
        ("pool_tu", vec![usdt.clone(), usdc.clone()], Box::new(MockProtocolSim::new(1.0))),
        ("pool_a", vec![usdc.clone(), dai.clone()], Box::new(MockProtocolSim::new(1.0))),
        ("pool_b", vec![dai.clone(), pepe.clone()], Box::new(MockProtocolSim::new(5.0))),
        ("pool_c", vec![dai.clone(), pepe.clone()], Box::new(MockProtocolSim::new(4.0))),
    ]);
    let ord = order(&weth, &pepe, 1000, OrderSide::Sell);
    let gas = BigUint::from(50_000u64);

    // 300 WETH --pool_wu--> 600 USDC --pool_a--> 600 DAI --pool_b--> 3000 PEPE
    let direct_to_b = PathAllocation {
        hops: vec![
            HopDescriptor::new("pool_wu".to_string(), weth.clone(), usdc.clone())
                .with_amounts(BigUint::from(600u64), gas.clone()),
            HopDescriptor::new("pool_a".to_string(), usdc.clone(), dai.clone())
                .with_amounts(BigUint::from(600u64), gas.clone()),
            HopDescriptor::new("pool_b".to_string(), dai.clone(), pepe.clone())
                .with_amounts(BigUint::from(3000u64), gas.clone()),
        ],
        flow_fraction: 0.3,
        amount_in: BigUint::from(300u64),
        amount_out: BigUint::from(3000u64),
        marginal_price_product: 10.0,
    };
    // 300 WETH --pool_wu--> 600 USDC --pool_a--> 600 DAI --pool_c--> 2400 PEPE
    let direct_to_c = PathAllocation {
        hops: vec![
            HopDescriptor::new("pool_wu".to_string(), weth.clone(), usdc.clone())
                .with_amounts(BigUint::from(600u64), gas.clone()),
            HopDescriptor::new("pool_a".to_string(), usdc.clone(), dai.clone())
                .with_amounts(BigUint::from(600u64), gas.clone()),
            HopDescriptor::new("pool_c".to_string(), dai.clone(), pepe.clone())
                .with_amounts(BigUint::from(2400u64), gas.clone()),
        ],
        flow_fraction: 0.3,
        amount_in: BigUint::from(300u64),
        amount_out: BigUint::from(2400u64),
        marginal_price_product: 8.0,
    };
    // 400 WETH --pool_wt--> 1200 USDT --pool_tu--> 1200 USDC --pool_a--> 1200 DAI
    //          --pool_b--> 6000 PEPE
    let detour_to_b = PathAllocation {
        hops: vec![
            HopDescriptor::new("pool_wt".to_string(), weth.clone(), usdt.clone())
                .with_amounts(BigUint::from(1200u64), gas.clone()),
            HopDescriptor::new("pool_tu".to_string(), usdt.clone(), usdc.clone())
                .with_amounts(BigUint::from(1200u64), gas.clone()),
            HopDescriptor::new("pool_a".to_string(), usdc.clone(), dai.clone())
                .with_amounts(BigUint::from(1200u64), gas.clone()),
            HopDescriptor::new("pool_b".to_string(), dai.clone(), pepe.clone())
                .with_amounts(BigUint::from(6000u64), gas),
        ],
        flow_fraction: 0.4,
        amount_in: BigUint::from(400u64),
        amount_out: BigUint::from(6000u64),
        marginal_price_product: 15.0,
    };

    let route =
        build_split_route(&[direct_to_b, direct_to_c, detour_to_b], &market, &ord).unwrap();
    let swaps = route.swaps();

    // Lookup helpers keyed by component id.
    let swap_named =
        |id: &str, missing: &str| swaps.iter().find(|s| s.component_id() == id).expect(missing);
    let index_of = |id: &str| swaps.iter().position(|s| s.component_id() == id).unwrap();

    let pool_a_swap = swap_named("pool_a", "pool_a swap must exist");
    assert_eq!(
        *pool_a_swap.amount_in(),
        BigUint::from(2400u64),
        "pool_a must receive all USDC from both direct and USDT-detour paths"
    );
    let pool_b_swap = swap_named("pool_b", "pool_b swap must exist");
    let pool_c_swap = swap_named("pool_c", "pool_c swap must exist");
    assert_eq!(
        *pool_b_swap.amount_out(),
        BigUint::from(9000u64),
        "pool_b amount_out should be merged total from paths 1+3"
    );
    assert_eq!(
        *pool_c_swap.amount_out(),
        BigUint::from(2400u64),
        "pool_c amount_out should be path 2 only"
    );

    // pool_a must precede both downstream consumers of its DAI output.
    let (pool_a_idx, pool_b_idx, pool_c_idx) =
        (index_of("pool_a"), index_of("pool_b"), index_of("pool_c"));
    assert!(
        pool_a_idx < pool_b_idx && pool_a_idx < pool_c_idx,
        "pool_a (idx {pool_a_idx}) must appear before pool_b (idx {pool_b_idx}) \
         and pool_c (idx {pool_c_idx})"
    );
    // The detour's USDT→USDC hop feeds pool_a, so it must come first.
    let pool_tu_idx = index_of("pool_tu");
    assert!(
        pool_tu_idx < pool_a_idx,
        "pool_tu (idx {pool_tu_idx}) must appear before pool_a (idx {pool_a_idx})"
    );

    assert_eq!(swaps.len(), 6, "pool_a must appear once, not once per path");
    assert_eq!(
        route.total_gas(),
        BigUint::from(300_000u64),
        "gas must be counted once per pool, not once per path"
    );
}
/// Two paths traverse the shared pools in opposite orders. The first uses
/// `pool_a` (USDC→DAI) before `pool_b` (PEPE→UNI), and the second uses `pool_b`
/// before `pool_a`. `merge_shared_hops` still collapses each pool into a single
/// swap, but `build_split_route` must reject the resulting dependency cycle.
#[test]
fn test_build_split_route_rejects_reverse_order_shared_pools() {
    let weth = token(0x01, "WETH");
    let usdc = token(0x02, "USDC");
    let dai = token(0x03, "DAI");
    let pepe = token(0x04, "PEPE");
    let uni = token(0x05, "UNI");
    let wbtc = token(0x06, "WBTC");
    let market = make_market(vec![
        ("pool_wu", vec![weth.clone(), usdc.clone()], Box::new(MockProtocolSim::new(2.0))),
        ("pool_a", vec![usdc.clone(), dai.clone()], Box::new(MockProtocolSim::new(1.0))),
        ("pool_dp", vec![dai.clone(), pepe.clone()], Box::new(MockProtocolSim::new(5.0))),
        ("pool_b", vec![pepe.clone(), uni.clone()], Box::new(MockProtocolSim::new(1.0))),
        ("pool_uw", vec![uni.clone(), wbtc.clone()], Box::new(MockProtocolSim::new(3.0))),
        ("pool_wp", vec![weth.clone(), pepe.clone()], Box::new(MockProtocolSim::new(4.0))),
        ("pool_us", vec![uni.clone(), usdc.clone()], Box::new(MockProtocolSim::new(1.0))),
        ("pool_dw", vec![dai.clone(), wbtc.clone()], Box::new(MockProtocolSim::new(2.0))),
    ]);
    let ord = order(&weth, &wbtc, 1000, OrderSide::Sell);
    let gas = BigUint::from(50_000u64);

    // WETH → USDC → (pool_a) DAI → PEPE → (pool_b) UNI → WBTC
    let a_then_b = PathAllocation {
        hops: vec![
            HopDescriptor::new("pool_wu".to_string(), weth.clone(), usdc.clone())
                .with_amounts(BigUint::from(1200u64), gas.clone()),
            HopDescriptor::new("pool_a".to_string(), usdc.clone(), dai.clone())
                .with_amounts(BigUint::from(1200u64), gas.clone()),
            HopDescriptor::new("pool_dp".to_string(), dai.clone(), pepe.clone())
                .with_amounts(BigUint::from(6000u64), gas.clone()),
            HopDescriptor::new("pool_b".to_string(), pepe.clone(), uni.clone())
                .with_amounts(BigUint::from(6000u64), gas.clone()),
            HopDescriptor::new("pool_uw".to_string(), uni.clone(), wbtc.clone())
                .with_amounts(BigUint::from(18000u64), gas.clone()),
        ],
        flow_fraction: 0.6,
        amount_in: BigUint::from(600u64),
        amount_out: BigUint::from(18000u64),
        marginal_price_product: 30.0,
    };
    // WETH → PEPE → (pool_b) UNI → USDC → (pool_a) DAI → WBTC
    let b_then_a = PathAllocation {
        hops: vec![
            HopDescriptor::new("pool_wp".to_string(), weth.clone(), pepe.clone())
                .with_amounts(BigUint::from(1600u64), gas.clone()),
            HopDescriptor::new("pool_b".to_string(), pepe.clone(), uni.clone())
                .with_amounts(BigUint::from(1600u64), gas.clone()),
            HopDescriptor::new("pool_us".to_string(), uni.clone(), usdc.clone())
                .with_amounts(BigUint::from(1600u64), gas.clone()),
            HopDescriptor::new("pool_a".to_string(), usdc.clone(), dai.clone())
                .with_amounts(BigUint::from(1600u64), gas.clone()),
            HopDescriptor::new("pool_dw".to_string(), dai.clone(), wbtc.clone())
                .with_amounts(BigUint::from(3200u64), gas),
        ],
        flow_fraction: 0.4,
        amount_in: BigUint::from(400u64),
        amount_out: BigUint::from(3200u64),
        marginal_price_product: 8.0,
    };
    let paths = [a_then_b, b_then_a];

    // Merging on its own succeeds: each shared pool collapses into one swap.
    let merged = merge_shared_hops(&paths).unwrap();
    let pool_a_in_usdc_branch = merged[&usdc.address]
        .iter()
        .filter(|s| s.hop.component_id == "pool_a")
        .count();
    assert_eq!(pool_a_in_usdc_branch, 1, "merge_shared_hops merges pool_a into one");
    let pool_b_in_pepe_branch = merged[&pepe.address]
        .iter()
        .filter(|s| s.hop.component_id == "pool_b")
        .count();
    assert_eq!(pool_b_in_pepe_branch, 1, "merge_shared_hops merges pool_b into one");

    // Building the route must fail, because pool_a and pool_b depend on each other.
    let err = build_split_route(&paths, &market, &ord)
        .expect_err("must reject cyclic path combination");
    let is_cycle_error =
        matches!(&err, AlgorithmError::Other(msg) if msg.contains("dependency cycle"));
    assert!(is_cycle_error, "expected AlgorithmError::Other with dependency cycle, got: {err}");
}
}