1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use tokio::sync::{broadcast, Mutex};
5use tokio::time::Duration;
6
7use crate::error::DrmError;
8use crate::exchange::Exchange;
9use crate::models::{Market, Order, OrderSide, Position};
10
/// Lifecycle state of a strategy.
///
/// `BaseStrategy::run_loop` moves the strategy from `Stopped` to `Running` and
/// back. `pause`/`resume` toggle between `Running` and `Paused`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StrategyState {
    /// Not executing ticks (initial state, and the state after the loop exits).
    Stopped,
    /// Executing ticks.
    Running,
    /// Loop still alive but skipping ticks until resumed.
    Paused,
}
17
18#[derive(Debug, Clone)]
19pub enum StrategyEvent {
20 Started,
21 Stopped,
22 Paused,
23 Resumed,
24 Order(Order),
25 Error(String),
26 Tick,
27}
28
/// Tunables shared by every strategy built on `BaseStrategy`.
#[derive(Clone, Debug)]
pub struct StrategyConfig {
    /// Delay slept after each loop iteration, in milliseconds.
    pub tick_interval_ms: u64,
    /// Position size limit. NOTE(review): not read by `BaseStrategy` itself.
    pub max_position_size: f64,
    /// Spread in basis points. NOTE(review): not read by `BaseStrategy` itself.
    pub spread_bps: u32,
    /// When true, `BaseStrategy::log` prints to stdout.
    pub verbose: bool,
}

impl Default for StrategyConfig {
    /// 1000 ms ticks, `max_position_size` 100.0, 100 bps spread, logging off.
    fn default() -> Self {
        StrategyConfig {
            verbose: false,
            spread_bps: 100,
            max_position_size: 100.0,
            tick_interval_ms: 1_000,
        }
    }
}
47
/// Tunables for a market-making strategy.
///
/// NOTE(review): none of these fields are read by the code in this file;
/// the unit descriptions below follow the field names and should be confirmed
/// against the consuming strategy.
#[derive(Clone, Debug)]
pub struct MarketMakingConfig {
    /// Maximum allowed exposure.
    pub max_exposure: f64,
    /// Interval between checks (milliseconds, per the field name).
    pub check_interval_ms: u64,
    /// Minimum spread (basis points, per the field name).
    pub min_spread_bps: u32,
    /// Upper bound on a single order's size.
    pub max_order_size: f64,
    /// Enables verbose output.
    pub verbose: bool,
}

impl Default for MarketMakingConfig {
    /// Exposure 1000.0, 2000 ms checks, 50 bps minimum spread, order cap 100.0, quiet.
    fn default() -> Self {
        MarketMakingConfig {
            verbose: false,
            max_order_size: 100.0,
            min_spread_bps: 50,
            check_interval_ms: 2_000,
            max_exposure: 1_000.0,
        }
    }
}
68
69#[derive(Debug, Clone)]
70pub struct AccountState {
71 pub balance: HashMap<String, f64>,
72 pub positions: Vec<Position>,
73}
74
75#[async_trait]
76pub trait Strategy: Send + Sync {
77 fn name(&self) -> &str;
78 fn config(&self) -> &StrategyConfig;
79 fn state(&self) -> StrategyState;
80
81 async fn on_tick(&mut self) -> Result<(), DrmError>;
82
83 async fn start(&mut self) -> Result<(), DrmError>;
84 async fn stop(&mut self) -> Result<(), DrmError>;
85 fn pause(&mut self);
86 fn resume(&mut self);
87}
88
89pub struct BaseStrategy<E: Exchange + 'static> {
90 pub exchange: Arc<E>,
91 pub market_id: String,
92 pub market: Option<Market>,
93 pub state: StrategyState,
94 pub config: StrategyConfig,
95 pub positions: Vec<Position>,
96 pub open_orders: Vec<Order>,
97 pub event_tx: broadcast::Sender<StrategyEvent>,
98 tick_handle: Option<tokio::task::JoinHandle<()>>,
99 stop_signal: Arc<Mutex<bool>>,
100}
101
102impl<E: Exchange + 'static> BaseStrategy<E> {
103 pub fn new(exchange: Arc<E>, market_id: String, config: StrategyConfig) -> Self {
104 let (event_tx, _) = broadcast::channel(100);
105
106 Self {
107 exchange,
108 market_id,
109 market: None,
110 state: StrategyState::Stopped,
111 config,
112 positions: Vec::new(),
113 open_orders: Vec::new(),
114 event_tx,
115 tick_handle: None,
116 stop_signal: Arc::new(Mutex::new(false)),
117 }
118 }
119
120 pub fn subscribe(&self) -> broadcast::Receiver<StrategyEvent> {
121 self.event_tx.subscribe()
122 }
123
124 pub async fn refresh_state(&mut self) -> Result<(), DrmError> {
125 let (positions, orders) = tokio::try_join!(
126 self.exchange.fetch_positions(Some(&self.market_id)),
127 self.exchange.fetch_open_orders(None),
128 )?;
129
130 self.positions = positions;
131 self.open_orders = orders
132 .into_iter()
133 .filter(|o| o.market_id == self.market_id)
134 .collect();
135
136 Ok(())
137 }
138
139 pub async fn cancel_all_orders(&mut self) -> Result<(), DrmError> {
140 for order in self.open_orders.drain(..) {
141 let _ = self
142 .exchange
143 .cancel_order(&order.id, Some(&self.market_id))
144 .await;
145 }
146 Ok(())
147 }
148
149 pub fn get_position(&self, outcome: &str) -> Option<&Position> {
150 self.positions.iter().find(|p| p.outcome == outcome)
151 }
152
153 pub fn get_net_position(&self) -> f64 {
154 let market = match &self.market {
155 Some(m) if m.outcomes.len() == 2 => m,
156 _ => return 0.0,
157 };
158
159 let pos1 = self
160 .get_position(&market.outcomes[0])
161 .map(|p| p.size)
162 .unwrap_or(0.0);
163
164 let pos2 = self
165 .get_position(&market.outcomes[1])
166 .map(|p| p.size)
167 .unwrap_or(0.0);
168
169 pos1 - pos2
170 }
171
172 pub async fn place_order(
173 &mut self,
174 outcome: &str,
175 side: OrderSide,
176 price: f64,
177 size: f64,
178 token_id: Option<&str>,
179 ) -> Result<Order, DrmError> {
180 let mut params = HashMap::new();
181 if let Some(tid) = token_id {
182 params.insert("token_id".to_string(), tid.to_string());
183 }
184
185 let order = self
186 .exchange
187 .create_order(&self.market_id, outcome, side, price, size, params)
188 .await?;
189
190 self.open_orders.push(order.clone());
191 let _ = self.event_tx.send(StrategyEvent::Order(order.clone()));
192
193 Ok(order)
194 }
195
196 pub fn log(&self, message: &str) {
197 if self.config.verbose {
198 println!("[{}:{}] {}", self.exchange.id(), self.market_id, message);
199 }
200 }
201
202 pub fn is_running(&self) -> bool {
203 self.state == StrategyState::Running
204 }
205
206 pub async fn signal_stop(&self) {
207 let mut stop = self.stop_signal.lock().await;
208 *stop = true;
209 }
210
211 pub async fn should_stop(&self) -> bool {
212 *self.stop_signal.lock().await
213 }
214
215 pub async fn reset_stop_signal(&self) {
216 let mut stop = self.stop_signal.lock().await;
217 *stop = false;
218 }
219
220 pub async fn run_loop<F, Fut>(&mut self, mut on_tick: F) -> Result<(), DrmError>
221 where
222 F: FnMut(&mut Self) -> Fut + Send,
223 Fut: std::future::Future<Output = Result<(), DrmError>> + Send,
224 {
225 self.reset_stop_signal().await;
226 self.state = StrategyState::Running;
227 let _ = self.event_tx.send(StrategyEvent::Started);
228 self.log("Strategy started");
229
230 self.market = Some(self.exchange.fetch_market(&self.market_id).await?);
231 self.log(&format!("Loaded market: {}", self.market_id));
232
233 let tick_interval = Duration::from_millis(self.config.tick_interval_ms);
234
235 loop {
236 if self.should_stop().await {
237 break;
238 }
239
240 if self.state == StrategyState::Paused {
241 tokio::time::sleep(tick_interval).await;
242 continue;
243 }
244
245 if let Err(e) = self.refresh_state().await {
246 self.log(&format!("Failed to refresh state: {e}"));
247 let _ = self.event_tx.send(StrategyEvent::Error(e.to_string()));
248 }
249
250 if let Err(e) = on_tick(self).await {
251 self.log(&format!("Tick error: {e}"));
252 let _ = self.event_tx.send(StrategyEvent::Error(e.to_string()));
253 } else {
254 let _ = self.event_tx.send(StrategyEvent::Tick);
255 }
256
257 tokio::time::sleep(tick_interval).await;
258 }
259
260 self.state = StrategyState::Stopped;
261 let _ = self.event_tx.send(StrategyEvent::Stopped);
262 self.log("Strategy stopped");
263
264 Ok(())
265 }
266
267 pub fn pause(&mut self) {
268 if self.state == StrategyState::Running {
269 self.state = StrategyState::Paused;
270 let _ = self.event_tx.send(StrategyEvent::Paused);
271 self.log("Strategy paused");
272 }
273 }
274
275 pub fn resume(&mut self) {
276 if self.state == StrategyState::Paused {
277 self.state = StrategyState::Running;
278 let _ = self.event_tx.send(StrategyEvent::Resumed);
279 self.log("Strategy resumed");
280 }
281 }
282
283 pub async fn get_account_state(&self) -> Result<AccountState, DrmError> {
284 let balance = self.exchange.fetch_balance().await?;
285 let positions = self.exchange.fetch_positions(Some(&self.market_id)).await?;
286
287 if self.config.verbose {
288 let usdc_balance = balance.get("USDC").copied().unwrap_or(0.0);
289 self.log(&format!("USDC Balance: ${usdc_balance:.2}"));
290 self.log(&format!("Positions: {} open", positions.len()));
291 for pos in &positions {
292 self.log(&format!(
293 " {}: {} shares @ avg ${:.4}",
294 pos.outcome, pos.size, pos.average_price
295 ));
296 }
297 }
298
299 Ok(AccountState { balance, positions })
300 }
301
302 pub fn calculate_order_size(&self, price: f64, max_exposure: f64) -> f64 {
303 let market = match &self.market {
304 Some(m) => m,
305 None => return 5.0,
306 };
307
308 let base_size = if market.liquidity > 0.0 {
309 (20.0_f64).min(market.liquidity * 0.01)
310 } else {
311 5.0
312 };
313
314 let position_cost = base_size * price;
315 if position_cost > max_exposure {
316 base_size * (max_exposure / position_cost)
317 } else {
318 base_size
319 }
320 }
321
322 pub fn calculate_spread_prices(&self, mid_price: f64, spread_bps: u32) -> (f64, f64) {
323 let half_spread = mid_price * (spread_bps as f64 / 10000.0) / 2.0;
324 let bid = mid_price - half_spread;
325 let ask = mid_price + half_spread;
326 (bid, ask)
327 }
328}
329
330impl<E: Exchange + 'static> Drop for BaseStrategy<E> {
331 fn drop(&mut self) {
332 if let Some(handle) = self.tick_handle.take() {
333 handle.abort();
334 }
335 }
336}