use std::collections::HashMap;
use num_bigint::BigUint;
use num_traits::Zero;
use tycho_simulation::tycho_common::{
dto::ProtocolStateDelta,
models::token::Token,
simulation::{
errors::{SimulationError, TransitionError},
protocol_sim::{Balances, GetAmountOutResult, ProtocolSim},
},
Bytes,
};
use crate::types::ComponentId;
/// Describes a single hop: a component identifier together with the hop's
/// input and output tokens.
pub(crate) struct HopDescriptor {
/// Identifier of the component this hop refers to.
pub(crate) component_id: ComponentId,
/// Input token of the hop.
pub(crate) token_in: Token,
/// Output token of the hop.
pub(crate) token_out: Token,
}
/// A sequence of hops plus the amounts and flow share associated with it.
///
/// NOTE(review): nothing in this part of the file constructs or reads this
/// type, so the field notes below are assumptions to confirm against callers.
pub(crate) struct PathAllocation {
/// Ordered hops making up the path.
pub(crate) hops: Vec<HopDescriptor>,
/// Presumably the share of the total input routed through this path
/// (expected in `[0.0, 1.0]`) — TODO confirm.
pub(crate) flow_fraction: f64,
/// Amount entering the path.
pub(crate) amount_in: BigUint,
/// Amount leaving the path.
pub(crate) amount_out: BigUint,
/// Presumably the product of the per-hop marginal prices along the path —
/// TODO confirm.
pub(crate) marginal_price_product: f64,
}
/// Outcome of a simulation: output amount, gas figure and a marginal price
/// product.
///
/// NOTE(review): the producer of this type is not visible here; units and the
/// exact meaning of each field should be confirmed against it.
pub(crate) struct SimResult {
/// Output amount produced by the simulation.
pub(crate) amount_out: BigUint,
/// Gas figure attached to the simulation (presumably gas units — confirm).
pub(crate) gas: u64,
/// Presumably the product of marginal prices over the simulated hops —
/// TODO confirm.
pub(crate) marginal_price_product: f64,
}
/// Set of replacement simulations keyed by [`ComponentId`].
///
/// Each entry owns a boxed [`ProtocolSim`]; the derived `Default` yields an
/// empty map.
#[derive(Default)]
pub(crate) struct MarketOverrides(HashMap<ComponentId, Box<dyn ProtocolSim>>);
impl MarketOverrides {
pub(crate) fn empty() -> Self {
Self::default()
}
pub(crate) fn with_override(mut self, id: ComponentId, sim: Box<dyn ProtocolSim>) -> Self {
self.0.insert(id, sim);
self
}
pub(crate) fn with_zero_gas(mut self, id: ComponentId, sim: Box<dyn ProtocolSim>) -> Self {
self.0
.insert(id, Box::new(ZeroGasSim(sim)));
self
}
pub(crate) fn get(&self, id: &ComponentId) -> Option<&dyn ProtocolSim> {
self.0.get(id).map(|b| b.as_ref())
}
}
/// Wrapper around another boxed [`ProtocolSim`].
///
/// Its `ProtocolSim` implementation in this file forwards every call to the
/// wrapped simulation and sets the gas of `get_amount_out` results to zero.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct ZeroGasSim(Box<dyn ProtocolSim>);
// `ProtocolSim` for `ZeroGasSim`: every method delegates to the wrapped
// simulation. Only `get_amount_out` changes the result (gas is zeroed and the
// successor state is re-wrapped), and `clone_box` / `eq` keep the wrapper intact.
#[typetag::serde]
impl ProtocolSim for ZeroGasSim {
    /// Fee of the wrapped simulation.
    fn fee(&self) -> f64 {
        self.0.fee()
    }

    /// Spot price as reported by the wrapped simulation.
    fn spot_price(&self, base: &Token, quote: &Token) -> Result<f64, SimulationError> {
        self.0.spot_price(base, quote)
    }

    /// Runs the wrapped simulation's swap, then reports zero gas.
    ///
    /// The returned `new_state` is wrapped in `ZeroGasSim` again so that swaps
    /// chained on the successor state stay gas-free as well.
    fn get_amount_out(
        &self,
        amount_in: BigUint,
        token_in: &Token,
        token_out: &Token,
    ) -> Result<GetAmountOutResult, SimulationError> {
        let mut quote = self.0.get_amount_out(amount_in, token_in, token_out)?;
        let successor = ZeroGasSim(quote.new_state);
        quote.new_state = Box::new(successor);
        quote.gas = BigUint::ZERO;
        Ok(quote)
    }

    /// Trade limits of the wrapped simulation.
    fn get_limits(
        &self,
        sell_token: Bytes,
        buy_token: Bytes,
    ) -> Result<(BigUint, BigUint), SimulationError> {
        self.0.get_limits(sell_token, buy_token)
    }

    /// Applies `delta` to the wrapped simulation in place.
    fn delta_transition(
        &mut self,
        delta: ProtocolStateDelta,
        tokens: &HashMap<Bytes, Token>,
        balances: &Balances,
    ) -> Result<(), TransitionError> {
        self.0.delta_transition(delta, tokens, balances)
    }

    /// Clones the wrapped simulation and wraps the copy again.
    fn clone_box(&self) -> Box<dyn ProtocolSim> {
        let duplicate = self.0.clone_box();
        Box::new(ZeroGasSim(duplicate))
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    /// Equal only to another `ZeroGasSim` whose wrapped simulation compares
    /// equal; any other `ProtocolSim` type is unequal.
    fn eq(&self, other: &dyn ProtocolSim) -> bool {
        match other.as_any().downcast_ref::<Self>() {
            Some(peer) => self.0.eq(&*peer.0),
            None => false,
        }
    }
}
/// Golden-section search for the maximizer of `f` on `[lo, hi]`.
///
/// The bracket is narrowed by the inverse golden ratio on every step, reusing
/// one of the two interior probes so each step costs a single evaluation of
/// `f`. The result is only meaningful as a maximum when `f` is unimodal on the
/// interval.
///
/// `max_evals` is the evaluation budget. The two initial probes are always
/// evaluated, so `f` is called at least twice even when `max_evals < 2`.
///
/// Returns the interior probe with the larger value (the left one on ties).
pub(crate) fn golden_section_search(
    f: impl Fn(f64) -> f64,
    mut lo: f64,
    mut hi: f64,
    max_evals: usize,
) -> f64 {
    let inv_phi = (5_f64.sqrt() - 1.0) / 2.0;

    // Interior probes, with `left < right` whenever `lo < hi`.
    let mut left = hi - inv_phi * (hi - lo);
    let mut right = lo + inv_phi * (hi - lo);
    let mut left_val = f(left);
    let mut right_val = f(right);

    // Two evaluations are already spent on the initial probes.
    let mut budget = max_evals.saturating_sub(2);
    while budget > 0 {
        budget -= 1;
        if left_val < right_val {
            // Maximum lies in [left, hi]: the old right probe becomes the new left one.
            lo = left;
            (left, left_val) = (right, right_val);
            right = lo + inv_phi * (hi - lo);
            right_val = f(right);
        } else {
            // Maximum lies in [lo, right]: the old left probe becomes the new right one.
            hi = right;
            (right, right_val) = (left, left_val);
            left = hi - inv_phi * (hi - lo);
            left_val = f(left);
        }
    }

    if left_val >= right_val {
        left
    } else {
        right
    }
}
/// Splits `total` into a `fraction`-sized part and the exact remainder.
///
/// `fraction` is clamped to `[0.0, 1.0]` and converted to a fixed-point
/// numerator over `10^18`, so the multiplication and division are done in
/// integer arithmetic on `BigUint`.
///
/// Returns `(part, remainder)`. The remainder is obtained by subtraction, so
/// `part + remainder == total` always holds; `part` is rounded down.
pub(crate) fn split_amount(total: &BigUint, fraction: f64) -> (BigUint, BigUint) {
    let clamped = fraction.clamp(0.0, 1.0);
    let scale: u64 = 1_000_000_000_000_000_000;
    // Float-to-int `as` casts truncate toward zero and map NaN to 0, so a NaN
    // fraction produces an empty part rather than a panic.
    let numerator = (clamped * scale as f64) as u64;
    let part = (total * BigUint::from(numerator)) / BigUint::from(scale);
    // `numerator <= scale` implies `part <= total`, so this cannot underflow.
    let remainder = total - &part;
    (part, remainder)
}
/// Errors returned by the fraction helpers in this file
/// (`normalize_fractions`, `fractions_to_amounts`).
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub(crate) enum SplitMathError {
/// The input slice contained no fractions.
#[error("fractions slice must not be empty")]
EmptyFractions,
/// The fractions sum to zero, so dividing by the sum is impossible.
#[error("all fractions are zero, cannot normalize")]
AllZeroFractions,
/// At least one fraction is below zero.
#[error("fractions must not be negative")]
NegativeFraction,
}
/// Rescales `fractions` in place so that they sum to 1.
///
/// # Errors
/// Checked in this order; the slice is left untouched on error:
/// - [`SplitMathError::EmptyFractions`] if the slice is empty,
/// - [`SplitMathError::NegativeFraction`] if any entry is `< 0.0`,
/// - [`SplitMathError::AllZeroFractions`] if the entries sum to zero.
pub(crate) fn normalize_fractions(fractions: &mut [f64]) -> Result<(), SplitMathError> {
    if fractions.is_empty() {
        return Err(SplitMathError::EmptyFractions);
    }

    for value in fractions.iter() {
        if *value < 0.0 {
            return Err(SplitMathError::NegativeFraction);
        }
    }

    let total = fractions.iter().sum::<f64>();
    if total == 0.0 {
        return Err(SplitMathError::AllZeroFractions);
    }

    fractions.iter_mut().for_each(|value| *value /= total);
    Ok(())
}
/// Converts `fractions` of `total` into integer amounts that sum exactly to
/// `total`.
///
/// Every fraction except the last is converted with `split_amount` (rounded
/// down). The last amount is whatever is left of `total`, so all rounding dust
/// lands in the final entry and the value of the final fraction is never read.
///
/// # Errors
/// Returns [`SplitMathError::EmptyFractions`] if `fractions` is empty.
///
/// # Panics
/// Panics on the final subtraction if the leading amounts add up to more than
/// `total`, which can happen when the fractions are not normalized (their sum
/// exceeds 1).
pub(crate) fn fractions_to_amounts(
    total: &BigUint,
    fractions: &[f64],
) -> Result<Vec<BigUint>, SplitMathError> {
    let Some((_last, leading)) = fractions.split_last() else {
        return Err(SplitMathError::EmptyFractions);
    };

    let mut amounts = Vec::with_capacity(fractions.len());
    let mut running_sum = BigUint::zero();
    for &frac in leading {
        let (part, _) = split_amount(total, frac);
        running_sum += &part;
        amounts.push(part);
    }
    // The final entry absorbs the rounding remainder so the sum is exact.
    amounts.push(total - &running_sum);
    Ok(amounts)
}
// Unit tests for the split/normalization helpers and the golden-section search.
#[cfg(test)]
mod tests {
use rstest::rstest;
use super::*;
// `split_amount` must conserve the total for a range of interior fractions.
#[test]
fn test_split_amount_exact_sum() {
let total = BigUint::from(1_000_000_000_000_000_000_u64);
for fraction in [0.1, 0.5, 0.9, 0.999] {
let (part, remainder) = split_amount(&total, fraction);
assert_eq!(
&part + &remainder,
total,
"part + remainder must equal total for fraction={fraction}"
);
}
}
// A zero fraction leaves everything in the remainder.
#[test]
fn test_split_amount_edge_fraction_zero() {
let total = BigUint::from(1_000_000_000_000_000_000_u64);
let (part, remainder) = split_amount(&total, 0.0);
assert!(part.is_zero());
assert_eq!(remainder, total);
}
// Fractions above 1.0 are clamped to 1.0, so the whole total goes to `part`.
#[test]
fn test_split_amount_clamps_above_one() {
let total = BigUint::from(1_000_000_000_000_000_000_u64);
let (part, remainder) = split_amount(&total, 1.5);
assert_eq!(part, total);
assert!(remainder.is_zero());
}
// Negative fractions are clamped to 0.0, so `part` is empty.
#[test]
fn test_split_amount_clamps_negative() {
let total = BigUint::from(1_000_000_000_000_000_000_u64);
let (part, remainder) = split_amount(&total, -0.5);
assert!(part.is_zero());
assert_eq!(remainder, total);
}
// The total is not a multiple of the fractions, so rounding occurs; the
// amounts must still add up exactly because the last entry takes the rest.
#[test]
fn test_fractions_to_amounts_exact_sum() {
let total = BigUint::from(999_999_999_999_999_999_u64);
let fractions = [0.3, 0.5, 0.2];
let amounts = fractions_to_amounts(&total, &fractions).unwrap();
assert_eq!(amounts.len(), 3);
let sum: BigUint = amounts.iter().sum();
assert_eq!(sum, total, "amounts must sum exactly to total");
}
// An empty fraction slice is rejected instead of producing an empty result.
#[test]
fn test_fractions_to_amounts_empty() {
let total = BigUint::from(1_000_u64);
let err = fractions_to_amounts(&total, &[]).unwrap_err();
assert_eq!(err, SplitMathError::EmptyFractions);
}
// Valid inputs: one that already sums to 1 and one that sums to 0.99; after
// normalization both must sum to 1 within `f64::EPSILON`.
#[rstest]
#[case::already_normalized(&[0.3, 0.5, 0.2])]
#[case::drift(&[0.33, 0.33, 0.33])]
fn test_normalize_fractions(#[case] input: &[f64]) {
let mut fractions = input.to_vec();
normalize_fractions(&mut fractions).unwrap();
let sum: f64 = fractions.iter().sum();
assert!((sum - 1.0).abs() < f64::EPSILON);
}
// Invalid inputs: one case per `SplitMathError` variant.
#[rstest]
#[case::empty(&[], SplitMathError::EmptyFractions)]
#[case::all_zeros(&[0.0, 0.0, 0.0], SplitMathError::AllZeroFractions)]
#[case::negative(&[-0.5, 0.5], SplitMathError::NegativeFraction)]
fn test_normalize_fractions_invalid(#[case] input: &[f64], #[case] expected: SplitMathError) {
let mut fractions = input.to_vec();
let err = normalize_fractions(&mut fractions).unwrap_err();
assert_eq!(err, expected);
}
// Concave parabola with its peak at x = 0.3: the search maximizes, so it must
// land on the peak within the stated tolerance.
#[test]
fn test_golden_section_finds_maximum() {
let f = |x: f64| -(x - 0.3) * (x - 0.3);
let result = golden_section_search(f, 0.0, 1.0, 100);
assert!((result - 0.3).abs() < 1e-4, "expected ~0.3, got {result}");
}
}