// drm_core/strategy/traits.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use tokio::sync::{broadcast, Mutex};
5use tokio::time::Duration;
6
7use crate::error::DrmError;
8use crate::exchange::Exchange;
9use crate::models::{Market, Order, OrderSide, Position};
10
/// Lifecycle state of a strategy.
///
/// `BaseStrategy` begins in `Stopped`, enters `Running` when its run loop starts,
/// and toggles between `Running` and `Paused` through `pause` / `resume`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StrategyState {
    /// Not running: the initial state, and the state after the run loop exits.
    Stopped,
    /// The run loop is actively refreshing state and ticking.
    Running,
    /// The run loop is alive but only sleeps until resumed.
    Paused,
}
17
18#[derive(Debug, Clone)]
19pub enum StrategyEvent {
20    Started,
21    Stopped,
22    Paused,
23    Resumed,
24    Order(Order),
25    Error(String),
26    Tick,
27}
28
/// Generic tuning knobs shared by strategies built on `BaseStrategy`.
#[derive(Clone, Debug)]
pub struct StrategyConfig {
    /// Sleep between run-loop iterations, in milliseconds.
    pub tick_interval_ms: u64,
    /// Upper bound on position size.
    /// NOTE(review): not read in this file; confirm units with the consuming strategy.
    pub max_position_size: f64,
    /// Quoted spread in basis points (presumably 1 bp = 0.01%).
    pub spread_bps: u32,
    /// When `true`, `BaseStrategy::log` prints messages to stdout.
    pub verbose: bool,
}

impl Default for StrategyConfig {
    /// 1000 ms ticks, max position 100.0, 100 bps spread, logging off.
    fn default() -> Self {
        StrategyConfig {
            verbose: false,
            spread_bps: 100,
            max_position_size: 100.0,
            tick_interval_ms: 1000,
        }
    }
}
47
/// Settings for a market-making strategy.
///
/// NOTE(review): this type is not referenced elsewhere in the visible code; the field
/// descriptions below are inferred from names and should be confirmed against the
/// strategy that consumes them.
#[derive(Clone, Debug)]
pub struct MarketMakingConfig {
    /// Presumably the maximum total exposure allowed.
    pub max_exposure: f64,
    /// Presumably the polling interval, in milliseconds.
    pub check_interval_ms: u64,
    /// Presumably the minimum quoted spread, in basis points.
    pub min_spread_bps: u32,
    /// Presumably the maximum size of a single order.
    pub max_order_size: f64,
    /// Verbose-logging toggle.
    pub verbose: bool,
}

impl Default for MarketMakingConfig {
    /// Exposure 1000.0, 2000 ms checks, 50 bps minimum spread, order size 100.0, logging off.
    fn default() -> Self {
        MarketMakingConfig {
            verbose: false,
            max_order_size: 100.0,
            min_spread_bps: 50,
            check_interval_ms: 2000,
            max_exposure: 1000.0,
        }
    }
}
68
69#[derive(Debug, Clone)]
70pub struct AccountState {
71    pub balance: HashMap<String, f64>,
72    pub positions: Vec<Position>,
73}
74
75#[async_trait]
76pub trait Strategy: Send + Sync {
77    fn name(&self) -> &str;
78    fn config(&self) -> &StrategyConfig;
79    fn state(&self) -> StrategyState;
80
81    async fn on_tick(&mut self) -> Result<(), DrmError>;
82
83    async fn start(&mut self) -> Result<(), DrmError>;
84    async fn stop(&mut self) -> Result<(), DrmError>;
85    fn pause(&mut self);
86    fn resume(&mut self);
87}
88
89pub struct BaseStrategy<E: Exchange + 'static> {
90    pub exchange: Arc<E>,
91    pub market_id: String,
92    pub market: Option<Market>,
93    pub state: StrategyState,
94    pub config: StrategyConfig,
95    pub positions: Vec<Position>,
96    pub open_orders: Vec<Order>,
97    pub event_tx: broadcast::Sender<StrategyEvent>,
98    tick_handle: Option<tokio::task::JoinHandle<()>>,
99    stop_signal: Arc<Mutex<bool>>,
100}
101
102impl<E: Exchange + 'static> BaseStrategy<E> {
103    pub fn new(exchange: Arc<E>, market_id: String, config: StrategyConfig) -> Self {
104        let (event_tx, _) = broadcast::channel(100);
105
106        Self {
107            exchange,
108            market_id,
109            market: None,
110            state: StrategyState::Stopped,
111            config,
112            positions: Vec::new(),
113            open_orders: Vec::new(),
114            event_tx,
115            tick_handle: None,
116            stop_signal: Arc::new(Mutex::new(false)),
117        }
118    }
119
120    pub fn subscribe(&self) -> broadcast::Receiver<StrategyEvent> {
121        self.event_tx.subscribe()
122    }
123
124    pub async fn refresh_state(&mut self) -> Result<(), DrmError> {
125        let (positions, orders) = tokio::try_join!(
126            self.exchange.fetch_positions(Some(&self.market_id)),
127            self.exchange.fetch_open_orders(None),
128        )?;
129
130        self.positions = positions;
131        self.open_orders = orders
132            .into_iter()
133            .filter(|o| o.market_id == self.market_id)
134            .collect();
135
136        Ok(())
137    }
138
139    pub async fn cancel_all_orders(&mut self) -> Result<(), DrmError> {
140        for order in self.open_orders.drain(..) {
141            let _ = self
142                .exchange
143                .cancel_order(&order.id, Some(&self.market_id))
144                .await;
145        }
146        Ok(())
147    }
148
149    pub fn get_position(&self, outcome: &str) -> Option<&Position> {
150        self.positions.iter().find(|p| p.outcome == outcome)
151    }
152
153    pub fn get_net_position(&self) -> f64 {
154        let market = match &self.market {
155            Some(m) if m.outcomes.len() == 2 => m,
156            _ => return 0.0,
157        };
158
159        let pos1 = self
160            .get_position(&market.outcomes[0])
161            .map(|p| p.size)
162            .unwrap_or(0.0);
163
164        let pos2 = self
165            .get_position(&market.outcomes[1])
166            .map(|p| p.size)
167            .unwrap_or(0.0);
168
169        pos1 - pos2
170    }
171
172    pub async fn place_order(
173        &mut self,
174        outcome: &str,
175        side: OrderSide,
176        price: f64,
177        size: f64,
178        token_id: Option<&str>,
179    ) -> Result<Order, DrmError> {
180        let mut params = HashMap::new();
181        if let Some(tid) = token_id {
182            params.insert("token_id".to_string(), tid.to_string());
183        }
184
185        let order = self
186            .exchange
187            .create_order(&self.market_id, outcome, side, price, size, params)
188            .await?;
189
190        self.open_orders.push(order.clone());
191        let _ = self.event_tx.send(StrategyEvent::Order(order.clone()));
192
193        Ok(order)
194    }
195
196    pub fn log(&self, message: &str) {
197        if self.config.verbose {
198            println!("[{}:{}] {}", self.exchange.id(), self.market_id, message);
199        }
200    }
201
202    pub fn is_running(&self) -> bool {
203        self.state == StrategyState::Running
204    }
205
206    pub async fn signal_stop(&self) {
207        let mut stop = self.stop_signal.lock().await;
208        *stop = true;
209    }
210
211    pub async fn should_stop(&self) -> bool {
212        *self.stop_signal.lock().await
213    }
214
215    pub async fn reset_stop_signal(&self) {
216        let mut stop = self.stop_signal.lock().await;
217        *stop = false;
218    }
219
220    pub async fn run_loop<F, Fut>(&mut self, mut on_tick: F) -> Result<(), DrmError>
221    where
222        F: FnMut(&mut Self) -> Fut + Send,
223        Fut: std::future::Future<Output = Result<(), DrmError>> + Send,
224    {
225        self.reset_stop_signal().await;
226        self.state = StrategyState::Running;
227        let _ = self.event_tx.send(StrategyEvent::Started);
228        self.log("Strategy started");
229
230        self.market = Some(self.exchange.fetch_market(&self.market_id).await?);
231        self.log(&format!("Loaded market: {}", self.market_id));
232
233        let tick_interval = Duration::from_millis(self.config.tick_interval_ms);
234
235        loop {
236            if self.should_stop().await {
237                break;
238            }
239
240            if self.state == StrategyState::Paused {
241                tokio::time::sleep(tick_interval).await;
242                continue;
243            }
244
245            if let Err(e) = self.refresh_state().await {
246                self.log(&format!("Failed to refresh state: {e}"));
247                let _ = self.event_tx.send(StrategyEvent::Error(e.to_string()));
248            }
249
250            if let Err(e) = on_tick(self).await {
251                self.log(&format!("Tick error: {e}"));
252                let _ = self.event_tx.send(StrategyEvent::Error(e.to_string()));
253            } else {
254                let _ = self.event_tx.send(StrategyEvent::Tick);
255            }
256
257            tokio::time::sleep(tick_interval).await;
258        }
259
260        self.state = StrategyState::Stopped;
261        let _ = self.event_tx.send(StrategyEvent::Stopped);
262        self.log("Strategy stopped");
263
264        Ok(())
265    }
266
267    pub fn pause(&mut self) {
268        if self.state == StrategyState::Running {
269            self.state = StrategyState::Paused;
270            let _ = self.event_tx.send(StrategyEvent::Paused);
271            self.log("Strategy paused");
272        }
273    }
274
275    pub fn resume(&mut self) {
276        if self.state == StrategyState::Paused {
277            self.state = StrategyState::Running;
278            let _ = self.event_tx.send(StrategyEvent::Resumed);
279            self.log("Strategy resumed");
280        }
281    }
282
283    pub async fn get_account_state(&self) -> Result<AccountState, DrmError> {
284        let balance = self.exchange.fetch_balance().await?;
285        let positions = self.exchange.fetch_positions(Some(&self.market_id)).await?;
286
287        if self.config.verbose {
288            let usdc_balance = balance.get("USDC").copied().unwrap_or(0.0);
289            self.log(&format!("USDC Balance: ${usdc_balance:.2}"));
290            self.log(&format!("Positions: {} open", positions.len()));
291            for pos in &positions {
292                self.log(&format!(
293                    "  {}: {} shares @ avg ${:.4}",
294                    pos.outcome, pos.size, pos.average_price
295                ));
296            }
297        }
298
299        Ok(AccountState { balance, positions })
300    }
301
302    pub fn calculate_order_size(&self, price: f64, max_exposure: f64) -> f64 {
303        let market = match &self.market {
304            Some(m) => m,
305            None => return 5.0,
306        };
307
308        let base_size = if market.liquidity > 0.0 {
309            (20.0_f64).min(market.liquidity * 0.01)
310        } else {
311            5.0
312        };
313
314        let position_cost = base_size * price;
315        if position_cost > max_exposure {
316            base_size * (max_exposure / position_cost)
317        } else {
318            base_size
319        }
320    }
321
322    pub fn calculate_spread_prices(&self, mid_price: f64, spread_bps: u32) -> (f64, f64) {
323        let half_spread = mid_price * (spread_bps as f64 / 10000.0) / 2.0;
324        let bid = mid_price - half_spread;
325        let ask = mid_price + half_spread;
326        (bid, ask)
327    }
328}
329
330impl<E: Exchange + 'static> Drop for BaseStrategy<E> {
331    fn drop(&mut self) {
332        if let Some(handle) = self.tick_handle.take() {
333            handle.abort();
334        }
335    }
336}