Skip to main content

aagt_core/trading/
simulation.rs

1//! Trade simulation system
2//!
3//! Allows simulating trades before execution to estimate costs, slippage, etc.
4
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7use crate::error::Result;
8use async_trait::async_trait;
9use rust_decimal::Decimal;
10use rust_decimal_macros::dec;
11
12/// Result of a trade simulation
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SimulationResult {
15    /// Whether simulation was successful
16    pub success: bool,
17    /// Input token
18    pub from_token: String,
19    /// Output token
20    pub to_token: String,
21    /// Input amount
22    pub input_amount: Decimal,
23    /// Expected output amount
24    pub output_amount: Decimal,
25    /// Estimated price impact percentage
26    pub price_impact_percent: Decimal,
27    /// Estimated gas cost in USD
28    pub gas_cost_usd: Decimal,
29    /// Minimum output with slippage
30    pub min_output: Decimal,
31    /// Exchange/DEX being used
32    pub exchange: String,
33    /// Route taken (for multi-hop swaps)
34    pub route: Vec<String>,
35    /// Warnings if any
36    pub warnings: Vec<String>,
37}
38
39impl SimulationResult {
40    /// Check if this trade has high price impact
41    pub fn has_high_impact(&self, threshold: Decimal) -> bool {
42        self.price_impact_percent > threshold
43    }
44
45    /// Get total cost (gas + price impact)
46    pub fn total_cost_usd(&self, input_price_usd: Decimal) -> Decimal {
47        let impact_cost = self.input_amount * input_price_usd * (self.price_impact_percent / dec!(100.0));
48        self.gas_cost_usd + impact_cost
49    }
50}
51
52/// Request for trade simulation
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct SimulationRequest {
55    /// Token to sell
56    pub from_token: String,
57    /// Token to buy
58    pub to_token: String,
59    /// Amount to swap
60    pub amount: Decimal,
61    /// Slippage tolerance percentage
62    pub slippage_tolerance: Decimal,
63    /// Chain to simulate on
64    pub chain: String,
65    /// Optional: specific exchange to use
66    pub exchange: Option<String>,
67}
68
69/// Trait for implementing simulators
70#[async_trait]
71pub trait Simulator: Send + Sync {
72    /// Simulate a trade
73    async fn simulate(&self, request: &SimulationRequest) -> Result<SimulationResult>;
74
75    /// Get supported chains
76    fn supported_chains(&self) -> Vec<String>;
77}
78
79/// Abstract pricing source for simulations
80#[async_trait]
81pub trait PriceSource: Send + Sync {
82    /// Get exact price in USD
83    async fn get_price_usd(&self, token: &str) -> Result<Decimal>;
84    /// Get liquidity in USD for a pair
85    async fn get_liquidity_usd(&self, token_a: &str, token_b: &str) -> Result<Decimal>;
86}
87
88/// Mock Price Source for testing/default
89pub struct MockPriceSource;
90#[async_trait]
91impl PriceSource for MockPriceSource {
92    async fn get_price_usd(&self, _token: &str) -> Result<Decimal> { Ok(Decimal::ONE) }
93    async fn get_liquidity_usd(&self, _token_a: &str, _token_b: &str) -> Result<Decimal> { Ok(dec!(10_000_000.0)) }
94}
95
96/// A basic simulator that estimates based on liquidity
97pub struct BasicSimulator {
98    /// Default gas cost per chain
99    default_gas_usd: Decimal,
100    /// Price source
101    price_source: Arc<dyn PriceSource>,
102}
103
104impl BasicSimulator {
105    /// Create with default settings
106    pub fn new() -> Self {
107        Self {
108            default_gas_usd: dec!(0.5),
109            price_source: Arc::new(MockPriceSource),
110        }
111    }
112
113    /// Create with custom price source
114    pub fn with_source(source: Arc<dyn PriceSource>) -> Self {
115        Self {
116            default_gas_usd: dec!(0.5),
117            price_source: source,
118        }
119    }
120
121    /// Estimate price impact based on amount and liquidity
122    fn estimate_price_impact(amount_usd: Decimal, liquidity_usd: Decimal) -> Decimal {
123        if liquidity_usd.is_zero() {
124            return dec!(100.0);
125        }
126        (amount_usd / liquidity_usd * dec!(100.0)).min(dec!(100.0))
127    }
128}
129
130impl Default for BasicSimulator {
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
136#[async_trait]
137impl Simulator for BasicSimulator {
138    async fn simulate(&self, request: &SimulationRequest) -> Result<SimulationResult> {
139        // 1. Get Prices
140        let price_from = self.price_source.get_price_usd(&request.from_token).await.unwrap_or(Decimal::ONE);
141        let amount_usd = request.amount * price_from;
142        
143        let price_to = self.price_source.get_price_usd(&request.to_token).await.unwrap_or(Decimal::ONE);
144        
145        // 2. Get Liquidity and Impact
146        let liquidity = self.price_source.get_liquidity_usd(&request.from_token, &request.to_token)
147            .await.unwrap_or(dec!(1000000.0));
148            
149        let price_impact = Self::estimate_price_impact(amount_usd, liquidity);
150        
151        // 3. Calculate Output
152        let gross_output_tokens = (request.amount * price_from) / price_to;
153        let fee_rate = dec!(1.0) - dec!(0.003);
154        let impact_rate = dec!(1.0) - (price_impact / dec!(100.0));
155        let net_output_tokens = gross_output_tokens * fee_rate * impact_rate;
156        
157        let min_output = net_output_tokens * (dec!(1.0) - request.slippage_tolerance / dec!(100.0));
158
159        let mut warnings = Vec::new();
160        if price_impact > Decimal::ONE {
161            warnings.push("High price impact detected".to_string());
162        }
163
164        Ok(SimulationResult {
165            success: true,
166            from_token: request.from_token.clone(),
167            to_token: request.to_token.clone(),
168            input_amount: request.amount,
169            output_amount: net_output_tokens,
170            price_impact_percent: price_impact,
171            gas_cost_usd: self.default_gas_usd,
172            min_output,
173            exchange: request.exchange.clone().unwrap_or_else(|| "Jupiter".to_string()),
174            route: vec![request.from_token.clone(), request.to_token.clone()],
175            warnings,
176        })
177    }
178
179    fn supported_chains(&self) -> Vec<String> {
180        vec!["solana".to_string(), "ethereum".to_string()]
181    }
182}
183
184/// Multi-chain simulator that delegates to chain-specific simulators
185pub struct MultiChainSimulator {
186    /// Chain-specific simulators
187    simulators: std::collections::HashMap<String, Box<dyn Simulator>>,
188}
189
190impl MultiChainSimulator {
191    /// Create with no simulators
192    pub fn new() -> Self {
193        Self {
194            simulators: std::collections::HashMap::new(),
195        }
196    }
197
198    /// Add a chain-specific simulator
199    pub fn add_chain(&mut self, chain: impl Into<String>, simulator: Box<dyn Simulator>) {
200        self.simulators.insert(chain.into(), simulator);
201    }
202
203    /// Simulate on specific chain
204    pub async fn simulate_on_chain(
205        &self,
206        chain: &str,
207        request: &SimulationRequest,
208    ) -> Result<SimulationResult> {
209        let simulator = self
210            .simulators
211            .get(chain)
212            .ok_or_else(|| crate::error::Error::Simulation(format!("Unsupported chain: {}", chain)))?;
213
214        simulator.simulate(request).await
215    }
216}
217
218impl Default for MultiChainSimulator {
219    fn default() -> Self {
220        Self::new()
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Price source with unit prices and configurable liquidity, used to force
    /// specific price-impact scenarios.
    struct FixedLiquiditySource {
        liquidity_usd: Decimal,
    }

    #[async_trait]
    impl PriceSource for FixedLiquiditySource {
        async fn get_price_usd(&self, _token: &str) -> Result<Decimal> {
            Ok(Decimal::ONE)
        }

        async fn get_liquidity_usd(&self, _token_a: &str, _token_b: &str) -> Result<Decimal> {
            Ok(self.liquidity_usd)
        }
    }

    /// A USDC -> SOL request on Solana with 1% slippage tolerance.
    fn usdc_to_sol(amount: Decimal) -> SimulationRequest {
        SimulationRequest {
            from_token: "USDC".to_string(),
            to_token: "SOL".to_string(),
            amount,
            slippage_tolerance: dec!(1.0),
            chain: "solana".to_string(),
            exchange: None,
        }
    }

    #[tokio::test]
    async fn test_basic_simulation() {
        let simulator = BasicSimulator::new();
        let request = usdc_to_sol(dec!(100.0));

        let result = simulator.simulate(&request).await.expect("simulation should succeed");

        assert!(result.success);
        assert!(result.output_amount > Decimal::ZERO);
        assert!(result.min_output < result.output_amount);
        assert_eq!(result.exchange, "Jupiter");
        assert_eq!(result.route, vec!["USDC".to_string(), "SOL".to_string()]);
        assert!(result.warnings.is_empty());
    }

    #[tokio::test]
    async fn test_high_price_impact_warns() {
        let source = Arc::new(FixedLiquiditySource { liquidity_usd: dec!(1000.0) });
        let simulator = BasicSimulator::with_source(source);

        let result = simulator
            .simulate(&usdc_to_sol(dec!(100.0)))
            .await
            .expect("simulation should succeed");

        // 100 USD against 1000 USD of liquidity => 10% impact.
        assert_eq!(result.price_impact_percent, dec!(10.0));
        assert!(result.has_high_impact(dec!(5.0)));
        assert!(result.warnings.iter().any(|w| w == "High price impact detected"));
    }

    #[test]
    fn test_estimate_price_impact_bounds() {
        // Zero liquidity is treated as maximal impact.
        assert_eq!(
            BasicSimulator::estimate_price_impact(dec!(100.0), Decimal::ZERO),
            dec!(100.0)
        );
        // Linear ratio: 1000 / 100000 * 100 = 1%.
        assert_eq!(
            BasicSimulator::estimate_price_impact(dec!(1000.0), dec!(100000.0)),
            dec!(1.0)
        );
        // Trades larger than the pool are capped at 100%.
        assert_eq!(
            BasicSimulator::estimate_price_impact(dec!(500.0), dec!(100.0)),
            dec!(100.0)
        );
    }

    #[test]
    fn test_total_cost_usd() {
        let result = SimulationResult {
            success: true,
            from_token: "USDC".to_string(),
            to_token: "SOL".to_string(),
            input_amount: dec!(100.0),
            output_amount: dec!(99.0),
            price_impact_percent: dec!(1.0),
            gas_cost_usd: dec!(0.5),
            min_output: dec!(98.0),
            exchange: "Jupiter".to_string(),
            route: vec!["USDC".to_string(), "SOL".to_string()],
            warnings: Vec::new(),
        };

        // 100 tokens * 2 USD * 1% impact = 2 USD, plus 0.5 USD gas.
        assert_eq!(result.total_cost_usd(dec!(2.0)), dec!(2.5));
        // The threshold comparison is strict.
        assert!(!result.has_high_impact(dec!(1.0)));
    }

    #[tokio::test]
    async fn test_multi_chain_routing() {
        let mut router = MultiChainSimulator::new();
        router.add_chain("solana", Box::new(BasicSimulator::new()));
        let request = usdc_to_sol(dec!(100.0));

        assert!(router.simulate_on_chain("solana", &request).await.is_ok());
        assert!(router.simulate_on_chain("ethereum", &request).await.is_err());
    }
}