Skip to main content

aagt_core/
simulation.rs

1//! Trade simulation system
2//!
3//! Allows simulating trades before execution to estimate costs, slippage, etc.
4
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7use crate::error::Result;
8use async_trait::async_trait;
9use rust_decimal::Decimal;
10use rust_decimal::prelude::*;
11use rust_decimal_macros::dec;
12
13/// Result of a trade simulation
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct SimulationResult {
16    /// Whether simulation was successful
17    pub success: bool,
18    /// Input token
19    pub from_token: String,
20    /// Output token
21    pub to_token: String,
22    /// Input amount
23    pub input_amount: Decimal,
24    /// Expected output amount
25    pub output_amount: Decimal,
26    /// Estimated price impact percentage
27    pub price_impact_percent: Decimal,
28    /// Estimated gas cost in USD
29    pub gas_cost_usd: Decimal,
30    /// Minimum output with slippage
31    pub min_output: Decimal,
32    /// Exchange/DEX being used
33    pub exchange: String,
34    /// Route taken (for multi-hop swaps)
35    pub route: Vec<String>,
36    /// Warnings if any
37    pub warnings: Vec<String>,
38}
39
40impl SimulationResult {
41    /// Check if this trade has high price impact
42    pub fn has_high_impact(&self, threshold: Decimal) -> bool {
43        self.price_impact_percent > threshold
44    }
45
46    /// Get total cost (gas + price impact)
47    pub fn total_cost_usd(&self, input_price_usd: Decimal) -> Decimal {
48        let impact_cost = self.input_amount * input_price_usd * (self.price_impact_percent / dec!(100.0));
49        self.gas_cost_usd + impact_cost
50    }
51}
52
53/// Request for trade simulation
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct SimulationRequest {
56    /// Token to sell
57    pub from_token: String,
58    /// Token to buy
59    pub to_token: String,
60    /// Amount to swap
61    pub amount: Decimal,
62    /// Slippage tolerance percentage
63    pub slippage_tolerance: Decimal,
64    /// Chain to simulate on
65    pub chain: String,
66    /// Optional: specific exchange to use
67    pub exchange: Option<String>,
68}
69
70/// Trait for implementing simulators
71#[async_trait]
72pub trait Simulator: Send + Sync {
73    /// Simulate a trade
74    async fn simulate(&self, request: &SimulationRequest) -> Result<SimulationResult>;
75
76    /// Get supported chains
77    fn supported_chains(&self) -> Vec<String>;
78}
79
80/// Abstract pricing source for simulations
81#[async_trait]
82pub trait PriceSource: Send + Sync {
83    /// Get exact price in USD
84    async fn get_price_usd(&self, token: &str) -> Result<Decimal>;
85    /// Get liquidity in USD for a pair
86    async fn get_liquidity_usd(&self, token_a: &str, token_b: &str) -> Result<Decimal>;
87}
88
89/// Mock Price Source for testing/default
90pub struct MockPriceSource;
91#[async_trait]
92impl PriceSource for MockPriceSource {
93    async fn get_price_usd(&self, _token: &str) -> Result<Decimal> { Ok(Decimal::ONE) }
94    async fn get_liquidity_usd(&self, _token_a: &str, _token_b: &str) -> Result<Decimal> { Ok(dec!(10_000_000.0)) }
95}
96
97/// A basic simulator that estimates based on liquidity
98pub struct BasicSimulator {
99    /// Default gas cost per chain
100    default_gas_usd: Decimal,
101    /// Price source
102    price_source: Arc<dyn PriceSource>,
103}
104
105impl BasicSimulator {
106    /// Create with default settings
107    pub fn new() -> Self {
108        Self {
109            default_gas_usd: dec!(0.5),
110            price_source: Arc::new(MockPriceSource),
111        }
112    }
113
114    /// Create with custom price source
115    pub fn with_source(source: Arc<dyn PriceSource>) -> Self {
116        Self {
117            default_gas_usd: dec!(0.5),
118            price_source: source,
119        }
120    }
121
122    /// Estimate price impact based on amount and liquidity
123    fn estimate_price_impact(amount_usd: Decimal, liquidity_usd: Decimal) -> Decimal {
124        if liquidity_usd.is_zero() {
125            return dec!(100.0);
126        }
127        (amount_usd / liquidity_usd * dec!(100.0)).min(dec!(100.0))
128    }
129}
130
131impl Default for BasicSimulator {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137#[async_trait]
138impl Simulator for BasicSimulator {
139    async fn simulate(&self, request: &SimulationRequest) -> Result<SimulationResult> {
140        // 1. Get Prices
141        let price_from = self.price_source.get_price_usd(&request.from_token).await.unwrap_or(Decimal::ONE);
142        let amount_usd = request.amount * price_from;
143        
144        let price_to = self.price_source.get_price_usd(&request.to_token).await.unwrap_or(Decimal::ONE);
145        
146        // 2. Get Liquidity and Impact
147        let liquidity = self.price_source.get_liquidity_usd(&request.from_token, &request.to_token)
148            .await.unwrap_or(dec!(1000000.0));
149            
150        let price_impact = Self::estimate_price_impact(amount_usd, liquidity);
151        
152        // 3. Calculate Output
153        let gross_output_tokens = (request.amount * price_from) / price_to;
154        let fee_rate = dec!(1.0) - dec!(0.003);
155        let impact_rate = dec!(1.0) - (price_impact / dec!(100.0));
156        let net_output_tokens = gross_output_tokens * fee_rate * impact_rate;
157        
158        let min_output = net_output_tokens * (dec!(1.0) - request.slippage_tolerance / dec!(100.0));
159
160        let mut warnings = Vec::new();
161        if price_impact > Decimal::ONE {
162            warnings.push("High price impact detected".to_string());
163        }
164
165        Ok(SimulationResult {
166            success: true,
167            from_token: request.from_token.clone(),
168            to_token: request.to_token.clone(),
169            input_amount: request.amount,
170            output_amount: net_output_tokens,
171            price_impact_percent: price_impact,
172            gas_cost_usd: self.default_gas_usd,
173            min_output,
174            exchange: request.exchange.clone().unwrap_or_else(|| "Jupiter".to_string()),
175            route: vec![request.from_token.clone(), request.to_token.clone()],
176            warnings,
177        })
178    }
179
180    fn supported_chains(&self) -> Vec<String> {
181        vec!["solana".to_string(), "ethereum".to_string()]
182    }
183}
184
185/// Multi-chain simulator that delegates to chain-specific simulators
186pub struct MultiChainSimulator {
187    /// Chain-specific simulators
188    simulators: std::collections::HashMap<String, Box<dyn Simulator>>,
189}
190
191impl MultiChainSimulator {
192    /// Create with no simulators
193    pub fn new() -> Self {
194        Self {
195            simulators: std::collections::HashMap::new(),
196        }
197    }
198
199    /// Add a chain-specific simulator
200    pub fn add_chain(&mut self, chain: impl Into<String>, simulator: Box<dyn Simulator>) {
201        self.simulators.insert(chain.into(), simulator);
202    }
203
204    /// Simulate on specific chain
205    pub async fn simulate_on_chain(
206        &self,
207        chain: &str,
208        request: &SimulationRequest,
209    ) -> Result<SimulationResult> {
210        let simulator = self
211            .simulators
212            .get(chain)
213            .ok_or_else(|| crate::error::Error::Simulation(format!("Unsupported chain: {}", chain)))?;
214
215        simulator.simulate(request).await
216    }
217}
218
219impl Default for MultiChainSimulator {
220    fn default() -> Self {
221        Self::new()
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// A small USDC -> SOL swap against the mock price source should succeed,
    /// yield a positive output, and have a minimum output below that output.
    #[tokio::test]
    async fn test_basic_simulation() {
        let swap = SimulationRequest {
            from_token: "USDC".to_string(),
            to_token: "SOL".to_string(),
            amount: dec!(100.0),
            slippage_tolerance: dec!(1.0),
            chain: "solana".to_string(),
            exchange: None,
        };
        let sim = BasicSimulator::new();

        let outcome = sim
            .simulate(&swap)
            .await
            .expect("simulation should succeed");

        assert!(outcome.success);
        assert!(outcome.output_amount > Decimal::ZERO);
        assert!(outcome.min_output < outcome.output_amount);
    }
}
248}