use num_bigint::BigUint;
use tycho_common::Bytes;
use crate::encoding::{evm::constants::GROUPABLE_PROTOCOLS, models::Swap};
/// A run of consecutive swaps that [`group_swaps`] has merged into a single unit.
#[derive(Clone, Debug)]
pub struct SwapGroup {
/// Address of the input token of the first swap in the group.
pub token_in: Bytes,
/// Address of the output token of the last swap in the group.
pub token_out: Bytes,
/// Protocol system shared by the group's swaps (`group_swaps` stores
/// `uniswap_v4_hooks` as `uniswap_v4`).
pub protocol_system: String,
/// The swaps that make up the group, in their original order.
pub swaps: Vec<Swap>,
/// Split value taken from the first swap of the group.
pub split: f64,
/// Gas estimate for the whole group, as computed by `compute_group_gas`.
pub estimated_gas: BigUint,
}
impl PartialEq for SwapGroup {
    /// Two groups are equal when every field compares equal.
    ///
    /// `self` is destructured exhaustively so that adding a field to
    /// `SwapGroup` forces this comparison to be revisited.
    fn eq(&self, other: &Self) -> bool {
        let Self { token_in, token_out, protocol_system, swaps, split, estimated_gas } = self;
        (token_in, token_out, protocol_system, swaps, split, estimated_gas) ==
            (
                &other.token_in,
                &other.token_out,
                &other.protocol_system,
                &other.swaps,
                &other.split,
                &other.estimated_gas,
            )
    }
}
/// Groups consecutive swaps so that each group can be treated as a single unit.
///
/// A swap is appended to the currently open group only when all of the following hold:
/// * its protocol system equals the previous swap's (`uniswap_v4_hooks` is treated as
///   `uniswap_v4`),
/// * that protocol system is listed in `GROUPABLE_PROTOCOLS`,
/// * its split is `0.0` and its input token is the previous swap's output token,
/// * its output token differs from the group's input token (no cycle inside a group).
///
/// Otherwise the open group (if any) is closed and a new group is started with this swap.
///
/// # Returns
/// The groups in input order. Each group's `split` is the split of its first swap and its
/// `estimated_gas` is filled in by `compute_group_gas` when the group is closed. An empty
/// input yields an empty vector.
pub fn group_swaps(swaps: &[Swap]) -> Vec<SwapGroup> {
    let mut grouped_swaps: Vec<SwapGroup> = Vec::new();
    let mut current_group: Option<SwapGroup> = None;
    let mut last_swap_protocol = String::new();
    let mut last_swap_out_token = Bytes::default();

    for swap in swaps {
        // Hook pools are grouped under the plain `uniswap_v4` protocol system.
        let current_swap_protocol = match swap.component().protocol_system.as_str() {
            "uniswap_v4_hooks" => "uniswap_v4".to_string(),
            other => other.to_string(),
        };
        let groupable_protocol = GROUPABLE_PROTOCOLS.contains(&current_swap_protocol.as_str());
        // A swap continues the chain only if it is not split and consumes the previous output.
        let no_split = swap.split() == 0.0 && swap.token_in().address == last_swap_out_token;
        let continues_chain =
            current_swap_protocol == last_swap_protocol && groupable_protocol && no_split;

        match current_group.as_mut() {
            // Extend only an existing group, and only if that does not close a cycle back to
            // the group's own input token.
            Some(group) if continues_chain && swap.token_out().address != group.token_in => {
                group.swaps.push(swap.clone());
                group.token_out = swap.token_out().address.clone();
            }
            // No open group, or the swap cannot join it: close the open group (if any) and
            // start a new one, so a swap is never dropped.
            _ => {
                if let Some(mut group) = current_group.take() {
                    group.estimated_gas = compute_group_gas(&group.swaps);
                    grouped_swaps.push(group);
                }
                current_group = Some(SwapGroup {
                    token_in: swap.token_in().address.clone(),
                    token_out: swap.token_out().address.clone(),
                    protocol_system: current_swap_protocol.clone(),
                    swaps: vec![swap.clone()],
                    split: swap.split(),
                    estimated_gas: BigUint::ZERO,
                });
            }
        }

        last_swap_protocol = current_swap_protocol;
        last_swap_out_token = swap.token_out().address.clone();
    }

    // Close the trailing group, which the loop never gets a chance to flush.
    if let Some(mut group) = current_group.take() {
        group.estimated_gas = compute_group_gas(&group.swaps);
        grouped_swaps.push(group);
    }
    grouped_swaps
}
/// Estimates the gas of a group of chained swaps.
///
/// Starts from the sum of every swap's own estimate and then, for each pair of adjacent
/// swaps, subtracts the gas usage of the first swap's output token and of the second swap's
/// input token. Each subtraction saturates at zero. Groups with zero or one swap have no
/// adjacent pairs, so their plain sum is returned.
///
/// NOTE(review): presumably the deduction reflects intermediate token transfers that do not
/// happen when the swaps are executed as one group; confirm against the executor.
fn compute_group_gas(swaps: &[Swap]) -> BigUint {
    let gross: BigUint = swaps
        .iter()
        .map(|swap| swap.estimated_gas().clone())
        .sum();

    swaps
        .windows(2)
        .flat_map(|pair| [pair[0].token_out().gas_usage(), pair[1].token_in().gas_usage()])
        .fold(gross, |remaining, saving| {
            if remaining >= saving {
                remaining - saving
            } else {
                BigUint::ZERO
            }
        })
}
// Unit tests for `group_swaps` and the gas estimate it attaches to each group.
#[cfg(test)]
mod tests {
use std::str::FromStr;
use alloy::primitives::hex;
use tycho_common::{
models::{protocol::ProtocolComponent, token::Token},
Bytes,
};
use super::*;
use crate::encoding::models::{default_token, Swap};
// Address bytes used as the WETH token throughout these tests.
fn weth() -> Bytes {
Bytes::from(hex!("c02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").to_vec())
}
// Builds a token at `address` whose gas list is `[Some(gas)]`; the other
// constructor arguments are fixed placeholder values.
fn token_with_gas(address: Bytes, gas: u64) -> Token {
Token::new(&address, "", 0, 0, &[Some(gas)], Default::default(), 100)
}
// Two chained uniswap_v4 swaps followed by a uniswap_v2 swap: the v4 swaps are
// expected in one group (weth -> usdc) and the v2 swap in a group of its own.
#[test]
fn test_group_swaps_simple() {
let weth = weth();
let wbtc = Bytes::from_str("0x2260fac5e5542a773aa44fbcfedf7c193bc2c599").unwrap();
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
let swap_weth_wbtc = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
default_token(weth.clone()),
default_token(wbtc.clone()),
BigUint::from(200_000u64),
);
let swap_wbtc_usdc = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
default_token(wbtc.clone()),
default_token(usdc.clone()),
BigUint::from(220_000u64),
);
let swap_usdc_dai = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v2".to_string(), ..Default::default() },
default_token(usdc.clone()),
default_token(dai.clone()),
BigUint::ZERO,
);
let swaps = vec![swap_weth_wbtc.clone(), swap_wbtc_usdc.clone(), swap_usdc_dai.clone()];
let grouped_swaps = group_swaps(&swaps);
// The grouped gas is 120_000 below the 420_000 sum of the two v4 swaps.
// NOTE(review): presumably the wbtc gas usage of `default_token`, deducted once per side.
assert_eq!(
grouped_swaps,
vec![
SwapGroup {
swaps: vec![swap_weth_wbtc, swap_wbtc_usdc],
token_in: weth,
token_out: usdc.clone(),
protocol_system: "uniswap_v4".to_string(),
split: 0f64,
estimated_gas: BigUint::from(300_000u64),
},
SwapGroup {
swaps: vec![swap_usdc_dai],
token_in: usdc,
token_out: dai,
protocol_system: "uniswap_v2".to_string(),
split: 0f64,
estimated_gas: BigUint::ZERO,
}
]
);
}
// All swaps are uniswap_v4, but the second one carries a 0.5 split: it must not
// join the first group, and the third swap (weth in, while the previous output
// was usdc) must start a new group that the fourth swap then extends.
#[test]
fn test_group_swaps_complex_split() {
let weth = weth();
let wbtc = Bytes::from_str("0x2260fac5e5542a773aa44fbcfedf7c193bc2c599").unwrap();
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
let swap_wbtc_weth = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
default_token(wbtc.clone()),
default_token(weth.clone()),
BigUint::ZERO,
);
let swap_weth_usdc = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
default_token(weth.clone()),
default_token(usdc.clone()),
BigUint::ZERO,
)
.with_split(0.5f64);
let swap_weth_dai = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
default_token(weth.clone()),
default_token(dai.clone()),
BigUint::from(220_000u64),
);
let swap_dai_usdc = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
default_token(dai.clone()),
default_token(usdc.clone()),
BigUint::from(250_000u64),
);
let swaps = vec![
swap_wbtc_weth.clone(),
swap_weth_usdc.clone(),
swap_weth_dai.clone(),
swap_dai_usdc.clone(),
];
let grouped_swaps = group_swaps(&swaps);
assert_eq!(
grouped_swaps,
vec![
SwapGroup {
swaps: vec![swap_wbtc_weth],
token_in: wbtc.clone(),
token_out: weth.clone(),
protocol_system: "uniswap_v4".to_string(),
split: 0f64,
estimated_gas: BigUint::ZERO,
},
SwapGroup {
swaps: vec![swap_weth_usdc],
token_in: weth.clone(),
token_out: usdc.clone(),
protocol_system: "uniswap_v4".to_string(),
split: 0.5f64,
estimated_gas: BigUint::ZERO,
},
SwapGroup {
swaps: vec![swap_weth_dai, swap_dai_usdc],
token_in: weth,
token_out: usdc,
protocol_system: "uniswap_v4".to_string(),
split: 0f64,
estimated_gas: BigUint::from(350_000u64),
}
]
);
}
// Two parallel two-hop routes over different protocols (vm:balancer_v3 and
// uniswap_v4). Each route is expected as one group, and the first group keeps
// the 0.5 split of its first swap.
#[test]
fn test_group_swaps_complex_split_multi_protocol() {
let weth = weth();
let wbtc = Bytes::from_str("0x2260fac5e5542a773aa44fbcfedf7c193bc2c599").unwrap();
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
let swap_weth_wbtc = Swap::new(
ProtocolComponent {
protocol_system: "vm:balancer_v3".to_string(),
..Default::default()
},
default_token(weth.clone()),
default_token(wbtc.clone()),
BigUint::from(220_000u64),
)
.with_split(0.5f64);
let swap_wbtc_usdc = Swap::new(
ProtocolComponent {
protocol_system: "vm:balancer_v3".to_string(),
..Default::default()
},
default_token(wbtc.clone()),
default_token(usdc.clone()),
BigUint::from(220_000u64),
);
let swap_weth_dai = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
default_token(weth.clone()),
default_token(dai.clone()),
BigUint::from(250_000u64),
);
let swap_dai_usdc = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
default_token(dai.clone()),
default_token(usdc.clone()),
BigUint::from(250_000u64),
);
let swaps = vec![
swap_weth_wbtc.clone(),
swap_wbtc_usdc.clone(),
swap_weth_dai.clone(),
swap_dai_usdc.clone(),
];
let grouped_swaps = group_swaps(&swaps);
assert_eq!(
grouped_swaps,
vec![
SwapGroup {
swaps: vec![swap_weth_wbtc, swap_wbtc_usdc],
token_in: weth.clone(),
token_out: usdc.clone(),
protocol_system: "vm:balancer_v3".to_string(),
split: 0.5f64,
estimated_gas: BigUint::from(320_000u64),
},
SwapGroup {
swaps: vec![swap_weth_dai, swap_dai_usdc],
token_in: weth,
token_out: usdc,
protocol_system: "uniswap_v4".to_string(),
split: 0f64,
estimated_gas: BigUint::from(380_000u64),
}
]
);
}
// usdc -> weth -> usdc on uniswap_v4: the second swap would return to the
// group's input token, so the two swaps must stay in separate groups and each
// keeps its own 220_000 gas estimate.
#[test]
fn test_group_swaps_cyclic_two_hops() {
let weth = weth();
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
let swap_usdc_weth = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
default_token(usdc.clone()),
default_token(weth.clone()),
BigUint::from(220_000u64),
);
let swap_weth_usdc = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
default_token(weth.clone()),
default_token(usdc.clone()),
BigUint::from(220_000u64),
);
let grouped_swaps = group_swaps(&[swap_usdc_weth.clone(), swap_weth_usdc.clone()]);
assert_eq!(
grouped_swaps,
vec![
SwapGroup {
swaps: vec![swap_usdc_weth],
token_in: usdc.clone(),
token_out: weth.clone(),
protocol_system: "uniswap_v4".to_string(),
split: 0f64,
estimated_gas: BigUint::from(220_000u64),
},
SwapGroup {
swaps: vec![swap_weth_usdc],
token_in: weth,
token_out: usdc,
protocol_system: "uniswap_v4".to_string(),
split: 0f64,
estimated_gas: BigUint::from(220_000u64),
},
]
);
}
// usdc -> weth -> dai -> usdc on uniswap_v4: the first two hops are grouped,
// while the last hop closes the cycle and is expected in its own group.
#[test]
fn test_group_swaps_cyclic_three_hops() {
let weth = weth();
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
let swap_usdc_weth = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
default_token(usdc.clone()),
default_token(weth.clone()),
BigUint::from(220_000u64),
);
let swap_weth_dai = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
default_token(weth.clone()),
default_token(dai.clone()),
BigUint::from(250_000u64),
);
let swap_dai_usdc = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
default_token(dai.clone()),
default_token(usdc.clone()),
BigUint::ZERO,
);
let grouped_swaps =
group_swaps(&[swap_usdc_weth.clone(), swap_weth_dai.clone(), swap_dai_usdc.clone()]);
assert_eq!(
grouped_swaps,
vec![
SwapGroup {
swaps: vec![swap_usdc_weth, swap_weth_dai],
token_in: usdc.clone(),
token_out: dai.clone(),
protocol_system: "uniswap_v4".to_string(),
split: 0f64,
estimated_gas: BigUint::from(350_000u64),
},
SwapGroup {
swaps: vec![swap_dai_usdc],
token_in: dai,
token_out: usdc,
protocol_system: "uniswap_v4".to_string(),
split: 0f64,
estimated_gas: BigUint::ZERO,
},
]
);
}
// A uniswap_v4 swap followed by a uniswap_v4_hooks swap is expected in a single
// group reported as "uniswap_v4"; the trailing uniswap_v2 swap stays separate.
#[test]
fn test_group_swaps_uniswap_v4_with_hooks() {
let weth = weth();
let wbtc = Bytes::from_str("0x2260fac5e5542a773aa44fbcfedf7c193bc2c599").unwrap();
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
let swap_weth_wbtc = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
default_token(weth.clone()),
default_token(wbtc.clone()),
BigUint::ZERO,
);
let swap_wbtc_usdc = Swap::new(
ProtocolComponent {
protocol_system: "uniswap_v4_hooks".to_string(),
..Default::default()
},
default_token(wbtc.clone()),
default_token(usdc.clone()),
BigUint::ZERO,
);
let swap_usdc_dai = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v2".to_string(), ..Default::default() },
default_token(usdc.clone()),
default_token(dai.clone()),
BigUint::ZERO,
);
let swaps = vec![swap_weth_wbtc.clone(), swap_wbtc_usdc.clone(), swap_usdc_dai.clone()];
let grouped_swaps = group_swaps(&swaps);
assert_eq!(grouped_swaps.len(), 2);
assert_eq!(grouped_swaps[0].swaps.len(), 2);
assert_eq!(grouped_swaps[0].token_in, weth);
assert_eq!(grouped_swaps[0].token_out, usdc.clone());
assert_eq!(grouped_swaps[0].protocol_system, "uniswap_v4");
assert_eq!(grouped_swaps[1].swaps.len(), 1);
assert_eq!(grouped_swaps[1].token_in, usdc);
assert_eq!(grouped_swaps[1].token_out, dai);
assert_eq!(grouped_swaps[1].protocol_system, "uniswap_v2");
}
// Three chained uniswap_v4 swaps with per-token gas values of 10/20/30/40.
// The swap estimates sum to 4500; `compute_group_gas` deducts the intermediate
// tokens on both sides of each hop (wbtc 20 + 20, usdc 30 + 30 = 100), giving 4400.
#[test]
fn test_group_swaps_estimated_gas() {
let weth = token_with_gas(weth(), 10);
let wbtc = token_with_gas(
Bytes::from_str("0x2260fac5e5542a773aa44fbcfedf7c193bc2c599").unwrap(),
20,
);
let usdc = token_with_gas(
Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(),
30,
);
let dai = token_with_gas(
Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap(),
40,
);
let swap_weth_wbtc = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
weth,
wbtc.clone(),
BigUint::from(1000u64),
);
let swap_wbtc_usdc = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
wbtc,
usdc.clone(),
BigUint::from(1500u64),
);
let swap_usdc_dai = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
usdc,
dai,
BigUint::from(2000u64),
);
let grouped_swaps = group_swaps(&[swap_weth_wbtc, swap_wbtc_usdc, swap_usdc_dai]);
assert_eq!(grouped_swaps.len(), 1);
assert_eq!(grouped_swaps[0].estimated_gas, BigUint::from(4400u64));
}
// A group holding a single swap keeps that swap's estimate unchanged: no token
// gas is deducted.
#[test]
fn test_group_swaps_estimated_gas_single_swap_group() {
let weth = token_with_gas(weth(), 10);
let usdc = token_with_gas(
Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(),
30,
);
let swap = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v2".to_string(), ..Default::default() },
weth,
usdc,
BigUint::from(1234u64),
);
let grouped_swaps = group_swaps(&[swap]);
assert_eq!(grouped_swaps.len(), 1);
assert_eq!(grouped_swaps[0].estimated_gas, BigUint::from(1234u64));
}
// Token gas (10_000 each) far exceeds the summed swap estimates (100), so the
// deduction must clamp the group estimate to zero rather than underflow.
#[test]
fn test_group_swaps_estimated_gas_saturates_to_zero() {
let weth = token_with_gas(weth(), 10_000);
let wbtc = token_with_gas(
Bytes::from_str("0x2260fac5e5542a773aa44fbcfedf7c193bc2c599").unwrap(),
10_000,
);
let usdc = token_with_gas(
Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(),
10_000,
);
let swap_weth_wbtc = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
weth,
wbtc.clone(),
BigUint::from(50u64),
);
let swap_wbtc_usdc = Swap::new(
ProtocolComponent { protocol_system: "uniswap_v4".to_string(), ..Default::default() },
wbtc,
usdc,
BigUint::from(50u64),
);
let grouped_swaps = group_swaps(&[swap_weth_wbtc, swap_wbtc_usdc]);
assert_eq!(grouped_swaps.len(), 1);
assert_eq!(grouped_swaps[0].estimated_gas, BigUint::ZERO);
}
}