use std::time::Duration;
use super::{Algorithm, AlgorithmConfig, AlgorithmError, BellmanFordAlgorithm};
use crate::{
derived::{computation::ComputationRequirements, SharedDerivedDataRef},
feed::market_data::{MarketData, StateLabel},
graph::{petgraph::StableDiGraph, PetgraphStableDiGraphManager},
types::{quote::Order, RouteResult},
};
/// Tuning knobs specific to the path-based Frank-Wolfe split router.
///
/// Default values are provided by this type's `Default` impl.
///
/// NOTE(review): the split-routing loop that consumes these fields is not
/// implemented yet, so the field descriptions below are inferred from the
/// names and should be confirmed once the solver lands.
#[derive(Debug, Clone, PartialEq)]
pub struct PathFrankWolfeConfig {
    /// Presumably the maximum number of distinct paths an order may be split
    /// across — confirm.
    pub max_paths: usize,
    /// Presumably the largest fraction of the order used when probing a
    /// path / step size — confirm.
    pub max_probe: f64,
    /// Presumably the smallest fraction of the order a single path may carry
    /// before it is dropped — confirm.
    pub min_split: f64,
    /// Presumably the number of objective evaluations budgeted per line
    /// search step — confirm.
    pub line_search_evals: usize,
}
impl Default for PathFrankWolfeConfig {
    /// Baseline tuning: up to 4 paths, 0.25 max probe, 0.05 min split and
    /// 12 line-search evaluations.
    fn default() -> Self {
        let max_paths = 4;
        let max_probe = 0.25;
        let min_split = 0.05;
        let line_search_evals = 12;
        Self {
            max_paths,
            max_probe,
            min_split,
            line_search_evals,
        }
    }
}
/// Split-routing algorithm based on a path-level Frank-Wolfe scheme.
///
/// Wraps a `BellmanFordAlgorithm` (built from the shared `AlgorithmConfig`)
/// together with Frank-Wolfe specific tuning in `PathFrankWolfeConfig`.
pub struct PathFrankWolfeAlgorithm {
    /// Inner Bellman-Ford algorithm; currently used for its timeout.
    inner: BellmanFordAlgorithm,
    /// Frank-Wolfe tuning parameters.
    // The routing loop that will read this field is not implemented yet, so
    // in non-test builds nothing reads it. The test module does read it, so
    // the lint stays active there instead of being silenced unconditionally.
    #[cfg_attr(not(test), allow(dead_code))]
    config: PathFrankWolfeConfig,
}
impl PathFrankWolfeAlgorithm {
    /// Creates the algorithm from the shared `algorithm_config` and the
    /// Frank-Wolfe specific `config`.
    ///
    /// The shared config is used to build the inner `BellmanFordAlgorithm`
    /// via `BellmanFordAlgorithm::with_config`; `config` is stored unchanged.
    pub(crate) fn new(algorithm_config: AlgorithmConfig, config: PathFrankWolfeConfig) -> Self {
        Self {
            inner: BellmanFordAlgorithm::with_config(algorithm_config),
            config,
        }
    }
}
impl Default for PathFrankWolfeAlgorithm {
    /// Builds the algorithm with both the shared and the Frank-Wolfe configs
    /// at their `Default` values.
    fn default() -> Self {
        let algorithm_config = AlgorithmConfig::default();
        let pfw_config = PathFrankWolfeConfig::default();
        Self::new(algorithm_config, pfw_config)
    }
}
impl Algorithm for PathFrankWolfeAlgorithm {
    type GraphType = StableDiGraph<()>;
    type GraphManager = PetgraphStableDiGraphManager<()>;

    /// Identifier of this algorithm.
    fn name(&self) -> &str {
        "path_frank_wolfe"
    }

    /// Intended to compute a split route for `_order` over `_graph`.
    ///
    /// # Panics
    ///
    /// Always panics: the split-routing loop is still a stub
    /// (`unimplemented!`). Do not select this algorithm for live routing yet.
    async fn find_best_route(
        &self,
        _graph: &Self::GraphType,
        _market: MarketData,
        _label: Option<StateLabel>,
        _derived: Option<SharedDerivedDataRef>,
        _order: &Order,
    ) -> Result<RouteResult, AlgorithmError> {
        // TODO: implement the path-based Frank-Wolfe split-routing loop.
        unimplemented!("PathFrankWolfe split-routing loop not yet implemented")
    }

    /// Declares the derived data this algorithm works with: `token_prices`
    /// and `spot_prices`, both registered through `allow_stale`.
    ///
    /// # Panics
    ///
    /// Panics if registering either requirement returns an error. The
    /// messages mark that as a bug, since the set of requirements is fixed.
    fn computation_requirements(&self) -> ComputationRequirements {
        let with_token_prices = ComputationRequirements::none()
            .allow_stale("token_prices")
            .expect("token_prices requirement conflicts (bug)");
        with_token_prices
            .allow_stale("spot_prices")
            .expect("spot_prices requirement conflicts (bug)")
    }

    /// Uses the same timeout as the inner Bellman-Ford algorithm.
    fn timeout(&self) -> Duration {
        self.inner.timeout()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algorithm::AlgorithmConfig;

    impl PathFrankWolfeAlgorithm {
        /// Test-only accessor for the Frank-Wolfe specific configuration.
        fn pfw_config(&self) -> &PathFrankWolfeConfig {
            &self.config
        }
    }

    /// Float comparison for values that are stored verbatim (no arithmetic),
    /// using the same `f64::EPSILON` tolerance throughout.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < f64::EPSILON
    }

    /// A custom `PathFrankWolfeConfig` passed to `new` must be kept as-is.
    #[test]
    fn test_with_pfw_config_override() {
        let algo = PathFrankWolfeAlgorithm::new(
            AlgorithmConfig::default(),
            PathFrankWolfeConfig {
                max_paths: 8,
                max_probe: 0.5,
                min_split: 0.1,
                line_search_evals: 24,
            },
        );

        let PathFrankWolfeConfig {
            max_paths,
            max_probe,
            min_split,
            line_search_evals,
        } = algo.pfw_config();

        assert_eq!(*max_paths, 8);
        assert!(approx_eq(*max_probe, 0.5));
        assert!(approx_eq(*min_split, 0.1));
        assert_eq!(*line_search_evals, 24);
    }
}