aagt_core/trading/
simulation.rs1use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7use crate::error::Result;
8use async_trait::async_trait;
9use rust_decimal::Decimal;
10use rust_decimal_macros::dec;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SimulationResult {
15 pub success: bool,
17 pub from_token: String,
19 pub to_token: String,
21 pub input_amount: Decimal,
23 pub output_amount: Decimal,
25 pub price_impact_percent: Decimal,
27 pub gas_cost_usd: Decimal,
29 pub min_output: Decimal,
31 pub exchange: String,
33 pub route: Vec<String>,
35 pub warnings: Vec<String>,
37}
38
39impl SimulationResult {
40 pub fn has_high_impact(&self, threshold: Decimal) -> bool {
42 self.price_impact_percent > threshold
43 }
44
45 pub fn total_cost_usd(&self, input_price_usd: Decimal) -> Decimal {
47 let impact_cost = self.input_amount * input_price_usd * (self.price_impact_percent / dec!(100.0));
48 self.gas_cost_usd + impact_cost
49 }
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct SimulationRequest {
55 pub from_token: String,
57 pub to_token: String,
59 pub amount: Decimal,
61 pub slippage_tolerance: Decimal,
63 pub chain: String,
65 pub exchange: Option<String>,
67}
68
69#[async_trait]
71pub trait Simulator: Send + Sync {
72 async fn simulate(&self, request: &SimulationRequest) -> Result<SimulationResult>;
74
75 fn supported_chains(&self) -> Vec<String>;
77}
78
79#[async_trait]
81pub trait PriceSource: Send + Sync {
82 async fn get_price_usd(&self, token: &str) -> Result<Decimal>;
84 async fn get_liquidity_usd(&self, token_a: &str, token_b: &str) -> Result<Decimal>;
86}
87
88pub struct MockPriceSource;
90#[async_trait]
91impl PriceSource for MockPriceSource {
92 async fn get_price_usd(&self, _token: &str) -> Result<Decimal> { Ok(Decimal::ONE) }
93 async fn get_liquidity_usd(&self, _token_a: &str, _token_b: &str) -> Result<Decimal> { Ok(dec!(10_000_000.0)) }
94}
95
96pub struct BasicSimulator {
98 default_gas_usd: Decimal,
100 price_source: Arc<dyn PriceSource>,
102}
103
104impl BasicSimulator {
105 pub fn new() -> Self {
107 Self {
108 default_gas_usd: dec!(0.5),
109 price_source: Arc::new(MockPriceSource),
110 }
111 }
112
113 pub fn with_source(source: Arc<dyn PriceSource>) -> Self {
115 Self {
116 default_gas_usd: dec!(0.5),
117 price_source: source,
118 }
119 }
120
121 fn estimate_price_impact(amount_usd: Decimal, liquidity_usd: Decimal) -> Decimal {
123 if liquidity_usd.is_zero() {
124 return dec!(100.0);
125 }
126 (amount_usd / liquidity_usd * dec!(100.0)).min(dec!(100.0))
127 }
128}
129
130impl Default for BasicSimulator {
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
136#[async_trait]
137impl Simulator for BasicSimulator {
138 async fn simulate(&self, request: &SimulationRequest) -> Result<SimulationResult> {
139 let price_from = self.price_source.get_price_usd(&request.from_token).await.unwrap_or(Decimal::ONE);
141 let amount_usd = request.amount * price_from;
142
143 let price_to = self.price_source.get_price_usd(&request.to_token).await.unwrap_or(Decimal::ONE);
144
145 let liquidity = self.price_source.get_liquidity_usd(&request.from_token, &request.to_token)
147 .await.unwrap_or(dec!(1000000.0));
148
149 let price_impact = Self::estimate_price_impact(amount_usd, liquidity);
150
151 let gross_output_tokens = (request.amount * price_from) / price_to;
153 let fee_rate = dec!(1.0) - dec!(0.003);
154 let impact_rate = dec!(1.0) - (price_impact / dec!(100.0));
155 let net_output_tokens = gross_output_tokens * fee_rate * impact_rate;
156
157 let min_output = net_output_tokens * (dec!(1.0) - request.slippage_tolerance / dec!(100.0));
158
159 let mut warnings = Vec::new();
160 if price_impact > Decimal::ONE {
161 warnings.push("High price impact detected".to_string());
162 }
163
164 Ok(SimulationResult {
165 success: true,
166 from_token: request.from_token.clone(),
167 to_token: request.to_token.clone(),
168 input_amount: request.amount,
169 output_amount: net_output_tokens,
170 price_impact_percent: price_impact,
171 gas_cost_usd: self.default_gas_usd,
172 min_output,
173 exchange: request.exchange.clone().unwrap_or_else(|| "Jupiter".to_string()),
174 route: vec![request.from_token.clone(), request.to_token.clone()],
175 warnings,
176 })
177 }
178
179 fn supported_chains(&self) -> Vec<String> {
180 vec!["solana".to_string(), "ethereum".to_string()]
181 }
182}
183
184pub struct MultiChainSimulator {
186 simulators: std::collections::HashMap<String, Box<dyn Simulator>>,
188}
189
190impl MultiChainSimulator {
191 pub fn new() -> Self {
193 Self {
194 simulators: std::collections::HashMap::new(),
195 }
196 }
197
198 pub fn add_chain(&mut self, chain: impl Into<String>, simulator: Box<dyn Simulator>) {
200 self.simulators.insert(chain.into(), simulator);
201 }
202
203 pub async fn simulate_on_chain(
205 &self,
206 chain: &str,
207 request: &SimulationRequest,
208 ) -> Result<SimulationResult> {
209 let simulator = self
210 .simulators
211 .get(chain)
212 .ok_or_else(|| crate::error::Error::Simulation(format!("Unsupported chain: {}", chain)))?;
213
214 simulator.simulate(request).await
215 }
216}
217
218impl Default for MultiChainSimulator {
219 fn default() -> Self {
220 Self::new()
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// A small USDC -> SOL swap on Solana with 1% slippage tolerance.
    fn usdc_to_sol_request() -> SimulationRequest {
        SimulationRequest {
            from_token: "USDC".to_string(),
            to_token: "SOL".to_string(),
            amount: dec!(100.0),
            slippage_tolerance: dec!(1.0),
            chain: "solana".to_string(),
            exchange: None,
        }
    }

    #[tokio::test]
    async fn test_basic_simulation() {
        let request = usdc_to_sol_request();
        let result = BasicSimulator::new()
            .simulate(&request)
            .await
            .expect("simulation should succeed");

        assert!(result.success);
        assert!(result.output_amount > Decimal::ZERO);
        assert!(result.min_output < result.output_amount);
    }
}