use std::time::Duration;
use num_bigint::BigUint;
use num_traits::ToPrimitive;
use super::{
split_primitives::{split_amount, PathAllocation},
Algorithm, AlgorithmConfig, AlgorithmError, BellmanFordAlgorithm,
};
use crate::{
derived::{computation::ComputationRequirements, SharedDerivedDataRef},
feed::market_data::{MarketData, StateLabel},
graph::{petgraph::StableDiGraph, PetgraphStableDiGraphManager},
types::{quote::Order, RouteResult},
};
/// Tuning parameters for [`PathFrankWolfeAlgorithm`].
#[derive(Debug, Clone)]
pub struct PathFrankWolfeConfig {
    /// Maximum number of paths (presumably the cap on paths an order may be split across —
    /// TODO confirm; not read in this file).
    pub max_paths: usize,
    /// Probe cap: passed to `split_amount` together with the total order amount to obtain the
    /// largest allowed probe amount (presumably a fraction of the order — TODO confirm).
    pub max_probe: f64,
    /// Minimum split size (semantics not visible in this file — TODO confirm).
    pub min_split: f64,
    /// Number of evaluations for the line search (not read in this file).
    pub line_search_evals: usize,
}

impl Default for PathFrankWolfeConfig {
    /// Defaults: `max_paths = 4`, `max_probe = 0.25`, `min_split = 0.05`,
    /// `line_search_evals = 12`.
    fn default() -> Self {
        Self {
            max_paths: 4,
            max_probe: 0.25,
            min_split: 0.05,
            line_search_evals: 12,
        }
    }
}
/// Path-based Frank-Wolfe routing algorithm.
///
/// Route search is currently delegated to a wrapped [`BellmanFordAlgorithm`]; `config`
/// carries the Frank-Wolfe specific tuning parameters.
pub struct PathFrankWolfeAlgorithm {
    /// Bellman-Ford engine that route search and the timeout are delegated to.
    inner: BellmanFordAlgorithm,
    /// Frank-Wolfe tuning parameters.
    // Only read by helpers that are not yet wired into routing (and by tests), hence the allow.
    #[allow(dead_code)]
    config: PathFrankWolfeConfig,
}

impl PathFrankWolfeAlgorithm {
    /// Creates the algorithm: `algorithm_config` configures the wrapped Bellman-Ford search,
    /// `config` supplies the Frank-Wolfe parameters.
    pub(crate) fn new(algorithm_config: AlgorithmConfig, config: PathFrankWolfeConfig) -> Self {
        Self {
            inner: BellmanFordAlgorithm::with_config(algorithm_config),
            config,
        }
    }
}

impl Default for PathFrankWolfeAlgorithm {
    /// Builds the algorithm from the default [`AlgorithmConfig`] and [`PathFrankWolfeConfig`].
    fn default() -> Self {
        Self::new(Default::default(), Default::default())
    }
}
impl PathFrankWolfeAlgorithm {
#[allow(dead_code)]
fn compute_probe_amount(
&self,
total_amount: &BigUint,
price_impact: f64,
gas_cost_output_tokens: f64,
) -> Option<BigUint> {
if price_impact <= 0.0 {
return None;
}
let gas_floor = gas_cost_output_tokens / price_impact;
let probe_amount = BigUint::from(gas_floor.ceil() as u128);
let (max_probe_amount, _remainder) = split_amount(total_amount, self.config.max_probe);
if probe_amount > max_probe_amount {
return None;
}
Some(probe_amount)
}
#[allow(dead_code)]
fn compute_average_price_impact(paths: &[PathAllocation]) -> Result<f64, AlgorithmError> {
let mut weighted_price_impact = 0.0;
for path in paths {
let amount_in = path.amount_in.to_f64().ok_or_else(|| {
AlgorithmError::Other(format!("amount_in too large for f64: {}", path.amount_in))
})?;
let amount_out = path
.amount_out
.to_f64()
.ok_or_else(|| {
AlgorithmError::Other(format!(
"amount_out too large for f64: {}",
path.amount_out
))
})?;
if amount_in <= 0.0 {
return Err(AlgorithmError::Other(format!("non-positive amount_in ({amount_in})")));
}
if path.marginal_price_product <= 0.0 {
return Err(AlgorithmError::Other(format!(
"non-positive marginal_price_product ({})",
path.marginal_price_product
)));
}
let ideal_out = amount_in * path.marginal_price_product;
let price_impact = 1.0 - amount_out / ideal_out;
weighted_price_impact += path.flow_fraction * price_impact;
}
Ok(weighted_price_impact)
}
}
impl Algorithm for PathFrankWolfeAlgorithm {
    type GraphType = StableDiGraph<()>;
    type GraphManager = PetgraphStableDiGraphManager<()>;

    /// Identifier reported for this algorithm.
    fn name(&self) -> &str {
        "path_frank_wolfe"
    }

    /// Finds a route for `order`.
    ///
    /// Every argument is currently forwarded unchanged to the wrapped Bellman-Ford algorithm,
    /// and its result (or error) is returned as-is.
    async fn find_best_route(
        &self,
        graph: &Self::GraphType,
        market: MarketData,
        label: Option<StateLabel>,
        derived: Option<SharedDerivedDataRef>,
        order: &Order,
    ) -> Result<RouteResult, AlgorithmError> {
        let bellman_ford = &self.inner;
        bellman_ford.find_best_route(graph, market, label, derived, order).await
    }

    /// Declares `token_prices` and `spot_prices` as allowed to be stale.
    ///
    /// # Panics
    ///
    /// Panics if either requirement conflicts with the set built so far; that is treated as a
    /// programming error.
    fn computation_requirements(&self) -> ComputationRequirements {
        let with_token_prices = ComputationRequirements::none()
            .allow_stale("token_prices")
            .expect("token_prices requirement conflicts (bug)");
        with_token_prices
            .allow_stale("spot_prices")
            .expect("spot_prices requirement conflicts (bug)")
    }

    /// Timeout inherited from the wrapped Bellman-Ford algorithm.
    fn timeout(&self) -> Duration {
        let bellman_ford = &self.inner;
        bellman_ford.timeout()
    }
}
// Unit tests for the Frank-Wolfe helpers. Parent-module names, including its imports
// (`BigUint`, `ToPrimitive`, `PathAllocation`, `AlgorithmConfig`, ...), arrive via the glob.
#[cfg(test)]
mod tests {
    use super::*;

    impl PathFrankWolfeAlgorithm {
        /// Test-only accessor for the otherwise private Frank-Wolfe config.
        fn pfw_config(&self) -> &PathFrankWolfeConfig {
            &self.config
        }
    }

    /// A custom config passed to `new` is stored unchanged.
    #[test]
    fn test_with_pfw_config_override() {
        let pfw_config = PathFrankWolfeConfig {
            max_paths: 8,
            max_probe: 0.5,
            min_split: 0.1,
            line_search_evals: 24,
        };
        let algo = PathFrankWolfeAlgorithm::new(AlgorithmConfig::default(), pfw_config);
        assert_eq!(algo.pfw_config().max_paths, 8);
        assert!((algo.pfw_config().max_probe - 0.5).abs() < f64::EPSILON);
        assert!((algo.pfw_config().min_split - 0.1).abs() < f64::EPSILON);
        assert_eq!(algo.pfw_config().line_search_evals, 24);
    }

    /// Gas floor 100_000 / 0.001 = 1e8 exceeds the probe cap, so no probe is returned.
    #[test]
    fn test_probe_amount_low_impact() {
        let total = BigUint::from(1_000_000u64);
        let algo = PathFrankWolfeAlgorithm::default();
        let result = algo.compute_probe_amount(&total, 0.001, 100_000.0);
        assert!(result.is_none());
    }

    /// The probe amount is inversely proportional to price impact: halving the impact
    /// doubles the probe.
    #[test]
    fn test_probe_amount_scaling() {
        let total = BigUint::from(10_000_000u64);
        let algo = PathFrankWolfeAlgorithm::default();
        let gas_cost = 1000.0;
        let probe_high_pi = algo
            .compute_probe_amount(&total, 0.10, gas_cost)
            .unwrap();
        let probe_low_pi = algo
            .compute_probe_amount(&total, 0.05, gas_cost)
            .unwrap();
        assert!(probe_high_pi < probe_low_pi);
        let ratio = probe_high_pi.to_f64().unwrap() / probe_low_pi.to_f64().unwrap();
        assert!(
            (ratio - 0.5).abs() < 0.01,
            "expected ratio ~0.5 (inverse proportionality), got {ratio}"
        );
    }

    /// 1000 / 0.10 = 10_000, well within the default cap for a 1_000_000 order.
    #[test]
    fn test_probe_amount_within_cap() {
        let total = BigUint::from(1_000_000u64);
        let algo = PathFrankWolfeAlgorithm::default();
        let probe_amount = algo
            .compute_probe_amount(&total, 0.10, 1000.0)
            .unwrap();
        assert_eq!(probe_amount, BigUint::from(10_000u64));
    }

    /// A zero price impact yields no probe.
    #[test]
    fn test_probe_amount_zero_price_impact() {
        let total = BigUint::from(1_000_000u64);
        let algo = PathFrankWolfeAlgorithm::default();
        assert!(algo
            .compute_probe_amount(&total, 0.0, 1000.0)
            .is_none());
    }

    /// Splitting the same input across more identical paths lowers the average impact.
    /// Expected values follow `1 - amount_out / (amount_in * 2.0)`,
    /// e.g. 1 - 181_818 / 200_000 ≈ 0.09091 for the single-path case.
    #[test]
    fn test_average_price_impact_redistribution() {
        let iter_0 = [PathAllocation {
            hops: vec![],
            flow_fraction: 1.0,
            amount_in: BigUint::from(100_000u64),
            amount_out: BigUint::from(181_818u64),
            marginal_price_product: 2.0,
        }];
        let iter_1 = [
            PathAllocation {
                hops: vec![],
                flow_fraction: 0.5,
                amount_in: BigUint::from(50_000u64),
                amount_out: BigUint::from(95_238u64),
                marginal_price_product: 2.0,
            },
            PathAllocation {
                hops: vec![],
                flow_fraction: 0.5,
                amount_in: BigUint::from(50_000u64),
                amount_out: BigUint::from(95_238u64),
                marginal_price_product: 2.0,
            },
        ];
        let third = 1.0 / 3.0;
        let iter_2 = [
            PathAllocation {
                hops: vec![],
                flow_fraction: third,
                amount_in: BigUint::from(33_333u64),
                amount_out: BigUint::from(64_514u64),
                marginal_price_product: 2.0,
            },
            PathAllocation {
                hops: vec![],
                flow_fraction: third,
                amount_in: BigUint::from(33_333u64),
                amount_out: BigUint::from(64_514u64),
                marginal_price_product: 2.0,
            },
            PathAllocation {
                hops: vec![],
                flow_fraction: third,
                amount_in: BigUint::from(33_334u64),
                amount_out: BigUint::from(64_516u64),
                marginal_price_product: 2.0,
            },
        ];
        let pi_0 = PathFrankWolfeAlgorithm::compute_average_price_impact(&iter_0).unwrap();
        let pi_1 = PathFrankWolfeAlgorithm::compute_average_price_impact(&iter_1).unwrap();
        let pi_2 = PathFrankWolfeAlgorithm::compute_average_price_impact(&iter_2).unwrap();
        assert!(pi_1 < pi_0, "price impact should decrease after first split: {pi_1} >= {pi_0}");
        assert!(pi_2 < pi_1, "price impact should decrease after second split: {pi_2} >= {pi_1}");
        assert!((pi_0 - 0.09091).abs() < 1e-5, "expected ~0.0909, got {pi_0}");
        assert!((pi_1 - 0.04762).abs() < 1e-5, "expected ~0.0476, got {pi_1}");
        assert!((pi_2 - 0.03228).abs() < 1e-5, "expected ~0.0323, got {pi_2}");
    }

    /// Per-path impacts 0.1 and 0.5, weighted 0.9 / 0.1, average to 0.09 + 0.05 = 0.14.
    #[test]
    fn test_average_price_impact_weighting() {
        let allocations = [
            PathAllocation {
                hops: vec![],
                flow_fraction: 0.9,
                amount_in: BigUint::from(1000u64),
                amount_out: BigUint::from(900u64),
                marginal_price_product: 1.0,
            },
            PathAllocation {
                hops: vec![],
                flow_fraction: 0.1,
                amount_in: BigUint::from(100u64),
                amount_out: BigUint::from(50u64),
                marginal_price_product: 1.0,
            },
        ];
        let pi = PathFrankWolfeAlgorithm::compute_average_price_impact(&allocations).unwrap();
        assert!((pi - 0.14).abs() < 1e-10, "expected 0.14, got {pi}");
    }

    // Placeholder: the body is intentionally empty until an iteration loop exists, so running
    // it with `--ignored` proves nothing.
    #[test]
    #[ignore = "placeholder: no Frank-Wolfe iteration loop exists yet; find_best_route delegates to Bellman-Ford"]
    fn test_pi_exit_criterion_stops_loop_early() {}
}