1use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7use crate::error::Result;
8use async_trait::async_trait;
9use rust_decimal::Decimal;
10use rust_decimal::prelude::*;
11use rust_decimal_macros::dec;
12
/// Outcome of simulating a token swap, as produced by a [`Simulator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimulationResult {
    /// Whether the simulation completed. `BasicSimulator` always sets `true`
    /// on a successful return.
    pub success: bool,
    /// Token being sold.
    pub from_token: String,
    /// Token being bought.
    pub to_token: String,
    /// Amount of `from_token` sold, copied from the request.
    pub input_amount: Decimal,
    /// Estimated amount of `to_token` received after fees and price impact.
    pub output_amount: Decimal,
    /// Estimated price impact in percent (0-100), not as a fraction.
    pub price_impact_percent: Decimal,
    /// Estimated gas cost in USD.
    pub gas_cost_usd: Decimal,
    /// Lowest acceptable output after applying the request's slippage tolerance.
    pub min_output: Decimal,
    /// Exchange the quote is for.
    pub exchange: String,
    /// Ordered token hops of the swap route.
    pub route: Vec<String>,
    /// Human-readable warnings about the quote (e.g. high price impact).
    pub warnings: Vec<String>,
}
39
40impl SimulationResult {
41 pub fn has_high_impact(&self, threshold: Decimal) -> bool {
43 self.price_impact_percent > threshold
44 }
45
46 pub fn total_cost_usd(&self, input_price_usd: Decimal) -> Decimal {
48 let impact_cost = self.input_amount * input_price_usd * (self.price_impact_percent / dec!(100.0));
49 self.gas_cost_usd + impact_cost
50 }
51}
52
/// Parameters describing a swap to simulate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimulationRequest {
    /// Token being sold.
    pub from_token: String,
    /// Token being bought.
    pub to_token: String,
    /// Amount of `from_token` to sell, in token units (multiplied by the token's
    /// USD price to get the trade's USD value).
    pub amount: Decimal,
    /// Acceptable slippage in percent (e.g. `1.0` means 1%), used to derive
    /// `SimulationResult::min_output`.
    pub slippage_tolerance: Decimal,
    /// Chain identifier such as `"solana"`. NOTE(review): `BasicSimulator`
    /// does not read this field, and `MultiChainSimulator` routes on a
    /// separate `chain` argument.
    pub chain: String,
    /// Preferred exchange; `BasicSimulator` reports `"Jupiter"` when `None`.
    pub exchange: Option<String>,
}
69
/// Estimates the outcome of a token swap described by a [`SimulationRequest`].
///
/// Implementations must be `Send + Sync` so they can be shared across async
/// tasks (e.g. boxed inside `MultiChainSimulator`).
#[async_trait]
pub trait Simulator: Send + Sync {
    /// Simulates `request` and returns the estimated result.
    async fn simulate(&self, request: &SimulationRequest) -> Result<SimulationResult>;

    /// Chain identifiers this simulator reports as supported (e.g. `"solana"`).
    fn supported_chains(&self) -> Vec<String>;
}
79
/// Provides USD prices and pair liquidity used to quote swaps.
#[async_trait]
pub trait PriceSource: Send + Sync {
    /// USD price of one unit of `token`.
    async fn get_price_usd(&self, token: &str) -> Result<Decimal>;

    /// USD liquidity available for the `token_a`/`token_b` pair; used to
    /// estimate price impact.
    async fn get_liquidity_usd(&self, token_a: &str, token_b: &str) -> Result<Decimal>;
}
88
89pub struct MockPriceSource;
91#[async_trait]
92impl PriceSource for MockPriceSource {
93 async fn get_price_usd(&self, _token: &str) -> Result<Decimal> { Ok(Decimal::ONE) }
94 async fn get_liquidity_usd(&self, _token_a: &str, _token_b: &str) -> Result<Decimal> { Ok(dec!(10_000_000.0)) }
95}
96
/// Heuristic [`Simulator`] that quotes swaps purely from a [`PriceSource`].
///
/// Output is derived from USD prices, a constant swap fee and a
/// liquidity-proportional price-impact estimate. It submits nothing on-chain.
pub struct BasicSimulator {
    /// Flat gas cost (USD) reported on every result; the constructors set 0.5.
    default_gas_usd: Decimal,
    /// Source of USD prices and pair liquidity.
    price_source: Arc<dyn PriceSource>,
}
104
105impl BasicSimulator {
106 pub fn new() -> Self {
108 Self {
109 default_gas_usd: dec!(0.5),
110 price_source: Arc::new(MockPriceSource),
111 }
112 }
113
114 pub fn with_source(source: Arc<dyn PriceSource>) -> Self {
116 Self {
117 default_gas_usd: dec!(0.5),
118 price_source: source,
119 }
120 }
121
122 fn estimate_price_impact(amount_usd: Decimal, liquidity_usd: Decimal) -> Decimal {
124 if liquidity_usd.is_zero() {
125 return dec!(100.0);
126 }
127 (amount_usd / liquidity_usd * dec!(100.0)).min(dec!(100.0))
128 }
129}
130
131impl Default for BasicSimulator {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
137#[async_trait]
138impl Simulator for BasicSimulator {
139 async fn simulate(&self, request: &SimulationRequest) -> Result<SimulationResult> {
140 let price_from = self.price_source.get_price_usd(&request.from_token).await.unwrap_or(Decimal::ONE);
142 let amount_usd = request.amount * price_from;
143
144 let price_to = self.price_source.get_price_usd(&request.to_token).await.unwrap_or(Decimal::ONE);
145
146 let liquidity = self.price_source.get_liquidity_usd(&request.from_token, &request.to_token)
148 .await.unwrap_or(dec!(1000000.0));
149
150 let price_impact = Self::estimate_price_impact(amount_usd, liquidity);
151
152 let gross_output_tokens = (request.amount * price_from) / price_to;
154 let fee_rate = dec!(1.0) - dec!(0.003);
155 let impact_rate = dec!(1.0) - (price_impact / dec!(100.0));
156 let net_output_tokens = gross_output_tokens * fee_rate * impact_rate;
157
158 let min_output = net_output_tokens * (dec!(1.0) - request.slippage_tolerance / dec!(100.0));
159
160 let mut warnings = Vec::new();
161 if price_impact > Decimal::ONE {
162 warnings.push("High price impact detected".to_string());
163 }
164
165 Ok(SimulationResult {
166 success: true,
167 from_token: request.from_token.clone(),
168 to_token: request.to_token.clone(),
169 input_amount: request.amount,
170 output_amount: net_output_tokens,
171 price_impact_percent: price_impact,
172 gas_cost_usd: self.default_gas_usd,
173 min_output,
174 exchange: request.exchange.clone().unwrap_or_else(|| "Jupiter".to_string()),
175 route: vec![request.from_token.clone(), request.to_token.clone()],
176 warnings,
177 })
178 }
179
180 fn supported_chains(&self) -> Vec<String> {
181 vec!["solana".to_string(), "ethereum".to_string()]
182 }
183}
184
/// Routes simulation requests to a per-chain [`Simulator`], keyed by chain name.
pub struct MultiChainSimulator {
    /// Registered simulators keyed by the chain name passed to `add_chain`.
    simulators: std::collections::HashMap<String, Box<dyn Simulator>>,
}
190
191impl MultiChainSimulator {
192 pub fn new() -> Self {
194 Self {
195 simulators: std::collections::HashMap::new(),
196 }
197 }
198
199 pub fn add_chain(&mut self, chain: impl Into<String>, simulator: Box<dyn Simulator>) {
201 self.simulators.insert(chain.into(), simulator);
202 }
203
204 pub async fn simulate_on_chain(
206 &self,
207 chain: &str,
208 request: &SimulationRequest,
209 ) -> Result<SimulationResult> {
210 let simulator = self
211 .simulators
212 .get(chain)
213 .ok_or_else(|| crate::error::Error::Simulation(format!("Unsupported chain: {}", chain)))?;
214
215 simulator.simulate(request).await
216 }
217}
218
219impl Default for MultiChainSimulator {
220 fn default() -> Self {
221 Self::new()
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Price source with 1 USD prices but only 1,000 USD of liquidity, so a
    /// 100 USD trade has a 10% price impact.
    struct ThinPoolSource;

    #[async_trait]
    impl PriceSource for ThinPoolSource {
        async fn get_price_usd(&self, _token: &str) -> Result<Decimal> {
            Ok(Decimal::ONE)
        }

        async fn get_liquidity_usd(&self, _token_a: &str, _token_b: &str) -> Result<Decimal> {
            Ok(dec!(1000.0))
        }
    }

    /// 100 USDC -> SOL on Solana with a 1% slippage tolerance.
    fn usdc_to_sol(exchange: Option<String>) -> SimulationRequest {
        SimulationRequest {
            from_token: "USDC".to_string(),
            to_token: "SOL".to_string(),
            amount: dec!(100.0),
            slippage_tolerance: dec!(1.0),
            chain: "solana".to_string(),
            exchange,
        }
    }

    #[tokio::test]
    async fn test_basic_simulation() {
        let simulator = BasicSimulator::new();

        let result = simulator
            .simulate(&usdc_to_sol(None))
            .await
            .expect("simulation should succeed");

        assert!(result.success);
        assert!(result.output_amount > Decimal::ZERO);
        assert!(result.min_output < result.output_amount);
    }

    #[tokio::test]
    async fn test_mock_quote_math_is_exact() {
        let result = BasicSimulator::new()
            .simulate(&usdc_to_sol(None))
            .await
            .expect("simulation should succeed");

        // Mock source: 1 USD prices, 10M USD liquidity -> 100 / 10M * 100 = 0.001%.
        assert_eq!(result.price_impact_percent, dec!(0.001));
        // 100 * 0.997 (fee) * 0.99999 (impact)
        assert_eq!(result.output_amount, dec!(99.699003));
        // 1% slippage tolerance
        assert_eq!(result.min_output, dec!(98.70201297));
        assert_eq!(result.gas_cost_usd, dec!(0.5));
        assert_eq!(result.input_amount, dec!(100.0));
        assert_eq!(result.exchange, "Jupiter");
        assert_eq!(result.route, vec!["USDC".to_string(), "SOL".to_string()]);
        assert!(result.warnings.is_empty());
        assert!(!result.has_high_impact(Decimal::ONE));
        // 0.5 gas + 100 * 1 * 0.00001 impact
        assert_eq!(result.total_cost_usd(Decimal::ONE), dec!(0.501));
    }

    #[tokio::test]
    async fn test_explicit_exchange_is_kept() {
        let result = BasicSimulator::new()
            .simulate(&usdc_to_sol(Some("Raydium".to_string())))
            .await
            .expect("simulation should succeed");

        assert_eq!(result.exchange, "Raydium");
    }

    #[tokio::test]
    async fn test_thin_pool_reports_high_impact() {
        let simulator = BasicSimulator::with_source(Arc::new(ThinPoolSource));

        let result = simulator
            .simulate(&usdc_to_sol(None))
            .await
            .expect("simulation should succeed");

        assert_eq!(result.price_impact_percent, dec!(10));
        // 100 * 0.997 (fee) * 0.9 (impact)
        assert_eq!(result.output_amount, dec!(89.73));
        assert!(result.has_high_impact(Decimal::ONE));
        assert!(result
            .warnings
            .iter()
            .any(|w| w == "High price impact detected"));
    }

    #[test]
    fn test_price_impact_edge_cases() {
        // No liquidity -> full impact instead of dividing by zero.
        assert_eq!(
            BasicSimulator::estimate_price_impact(dec!(100.0), Decimal::ZERO),
            dec!(100.0)
        );
        // Trades larger than the pool are capped at 100%.
        assert_eq!(
            BasicSimulator::estimate_price_impact(dec!(5_000_000.0), dec!(1_000_000.0)),
            dec!(100.0)
        );
        assert_eq!(
            BasicSimulator::estimate_price_impact(dec!(100.0), dec!(10_000.0)),
            dec!(1)
        );
    }

    #[tokio::test]
    async fn test_multi_chain_routing() {
        let mut router = MultiChainSimulator::new();
        router.add_chain("solana", Box::new(BasicSimulator::new()));
        let request = usdc_to_sol(None);

        assert!(router.simulate_on_chain("solana", &request).await.is_ok());
        assert!(router.simulate_on_chain("bitcoin", &request).await.is_err());
    }
}