Skip to main content

lightcone_sdk/websocket/
subscriptions.rs

1//! Subscription management for WebSocket channels.
2//!
3//! Tracks active subscriptions and supports re-subscribing after reconnect.
4
5use std::collections::{HashMap, HashSet};
6
7use crate::websocket::types::SubscribeParams;
8
/// One channel subscription that can be sent to (or replayed against) the
/// WebSocket server.
///
/// Values are hashable and comparable so they can be deduplicated or stored
/// in sets.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Subscription {
    /// Book updates for every listed orderbook ID.
    BookUpdate {
        orderbook_ids: Vec<String>,
    },
    /// Trade events for every listed orderbook ID.
    Trades {
        orderbook_ids: Vec<String>,
    },
    /// Events for a single user, identified by public key.
    User {
        user: String,
    },
    /// Price history for one orderbook at one resolution, optionally with
    /// OHLCV data.
    PriceHistory {
        orderbook_id: String,
        resolution: String,
        include_ohlcv: bool,
    },
    /// Events for a single market, identified by public key.
    Market {
        market_pubkey: String,
    },
}
27
28impl Subscription {
29    /// Convert to SubscribeParams for sending
30    pub fn to_params(&self) -> SubscribeParams {
31        match self {
32            Self::BookUpdate { orderbook_ids } => SubscribeParams::book_update(orderbook_ids.clone()),
33            Self::Trades { orderbook_ids } => SubscribeParams::trades(orderbook_ids.clone()),
34            Self::User { user } => SubscribeParams::user(user.clone()),
35            Self::PriceHistory {
36                orderbook_id,
37                resolution,
38                include_ohlcv,
39            } => SubscribeParams::price_history(
40                orderbook_id.clone(),
41                resolution.clone(),
42                *include_ohlcv,
43            ),
44            Self::Market { market_pubkey } => SubscribeParams::market(market_pubkey.clone()),
45        }
46    }
47
48    /// Get the subscription type as a string
49    pub fn subscription_type(&self) -> &'static str {
50        match self {
51            Self::BookUpdate { .. } => "book_update",
52            Self::Trades { .. } => "trades",
53            Self::User { .. } => "user",
54            Self::PriceHistory { .. } => "price_history",
55            Self::Market { .. } => "market",
56        }
57    }
58}
59
60/// Manages active subscriptions
61#[derive(Debug, Default)]
62pub struct SubscriptionManager {
63    /// Book update subscriptions (orderbook_id -> subscription)
64    book_updates: HashSet<String>,
65    /// Trades subscriptions (orderbook_id -> subscription)
66    trades: HashSet<String>,
67    /// User subscriptions (user pubkey -> subscription)
68    users: HashSet<String>,
69    /// Price history subscriptions (orderbook_id:resolution -> params)
70    price_history: HashMap<String, (String, String, bool)>, // key -> (orderbook_id, resolution, include_ohlcv)
71    /// Market subscriptions (market_pubkey -> subscription)
72    markets: HashSet<String>,
73}
74
75impl SubscriptionManager {
76    /// Create a new subscription manager
77    pub fn new() -> Self {
78        Self::default()
79    }
80
81    /// Add a book update subscription
82    pub fn add_book_update(&mut self, orderbook_ids: Vec<String>) {
83        for id in orderbook_ids {
84            self.book_updates.insert(id);
85        }
86    }
87
88    /// Remove a book update subscription
89    pub fn remove_book_update(&mut self, orderbook_ids: &[String]) {
90        for id in orderbook_ids {
91            self.book_updates.remove(id);
92        }
93    }
94
95    /// Check if subscribed to book updates for an orderbook
96    pub fn is_subscribed_book_update(&self, orderbook_id: &str) -> bool {
97        self.book_updates.contains(orderbook_id)
98    }
99
100    /// Add a trades subscription
101    pub fn add_trades(&mut self, orderbook_ids: Vec<String>) {
102        for id in orderbook_ids {
103            self.trades.insert(id);
104        }
105    }
106
107    /// Remove a trades subscription
108    pub fn remove_trades(&mut self, orderbook_ids: &[String]) {
109        for id in orderbook_ids {
110            self.trades.remove(id);
111        }
112    }
113
114    /// Check if subscribed to trades for an orderbook
115    pub fn is_subscribed_trades(&self, orderbook_id: &str) -> bool {
116        self.trades.contains(orderbook_id)
117    }
118
119    /// Add a user subscription
120    pub fn add_user(&mut self, user: String) {
121        self.users.insert(user);
122    }
123
124    /// Remove a user subscription
125    pub fn remove_user(&mut self, user: &str) {
126        self.users.remove(user);
127    }
128
129    /// Check if subscribed to a user
130    pub fn is_subscribed_user(&self, user: &str) -> bool {
131        self.users.contains(user)
132    }
133
134    /// Add a price history subscription
135    pub fn add_price_history(&mut self, orderbook_id: String, resolution: String, include_ohlcv: bool) {
136        let key = format!("{}:{}", orderbook_id, resolution);
137        self.price_history
138            .insert(key, (orderbook_id, resolution, include_ohlcv));
139    }
140
141    /// Remove a price history subscription
142    pub fn remove_price_history(&mut self, orderbook_id: &str, resolution: &str) {
143        let key = format!("{}:{}", orderbook_id, resolution);
144        self.price_history.remove(&key);
145    }
146
147    /// Check if subscribed to price history for an orderbook/resolution
148    pub fn is_subscribed_price_history(&self, orderbook_id: &str, resolution: &str) -> bool {
149        let key = format!("{}:{}", orderbook_id, resolution);
150        self.price_history.contains_key(&key)
151    }
152
153    /// Add a market subscription
154    pub fn add_market(&mut self, market_pubkey: String) {
155        self.markets.insert(market_pubkey);
156    }
157
158    /// Remove a market subscription
159    pub fn remove_market(&mut self, market_pubkey: &str) {
160        self.markets.remove(market_pubkey);
161    }
162
163    /// Check if subscribed to market events
164    pub fn is_subscribed_market(&self, market_pubkey: &str) -> bool {
165        self.markets.contains(market_pubkey) || self.markets.contains("all")
166    }
167
168    /// Get all subscriptions for re-subscribing after reconnect
169    pub fn get_all_subscriptions(&self) -> Vec<Subscription> {
170        let mut subs = Vec::new();
171
172        // Group book updates
173        if !self.book_updates.is_empty() {
174            subs.push(Subscription::BookUpdate {
175                orderbook_ids: self.book_updates.iter().cloned().collect(),
176            });
177        }
178
179        // Group trades
180        if !self.trades.is_empty() {
181            subs.push(Subscription::Trades {
182                orderbook_ids: self.trades.iter().cloned().collect(),
183            });
184        }
185
186        // Users
187        for user in &self.users {
188            subs.push(Subscription::User { user: user.clone() });
189        }
190
191        // Price history
192        for (orderbook_id, resolution, include_ohlcv) in self.price_history.values() {
193            subs.push(Subscription::PriceHistory {
194                orderbook_id: orderbook_id.clone(),
195                resolution: resolution.clone(),
196                include_ohlcv: *include_ohlcv,
197            });
198        }
199
200        // Markets
201        for market_pubkey in &self.markets {
202            subs.push(Subscription::Market {
203                market_pubkey: market_pubkey.clone(),
204            });
205        }
206
207        subs
208    }
209
210    /// Clear all subscriptions
211    pub fn clear(&mut self) {
212        self.book_updates.clear();
213        self.trades.clear();
214        self.users.clear();
215        self.price_history.clear();
216        self.markets.clear();
217    }
218
219    /// Check if there are any active subscriptions
220    pub fn has_subscriptions(&self) -> bool {
221        !self.book_updates.is_empty()
222            || !self.trades.is_empty()
223            || !self.users.is_empty()
224            || !self.price_history.is_empty()
225            || !self.markets.is_empty()
226    }
227
228    /// Get count of active subscriptions
229    pub fn subscription_count(&self) -> usize {
230        self.book_updates.len()
231            + self.trades.len()
232            + self.users.len()
233            + self.price_history.len()
234            + self.markets.len()
235    }
236
237    /// Get all subscribed orderbook IDs (for book updates)
238    pub fn book_update_orderbooks(&self) -> Vec<String> {
239        self.book_updates.iter().cloned().collect()
240    }
241
242    /// Get all subscribed orderbook IDs (for trades)
243    pub fn trade_orderbooks(&self) -> Vec<String> {
244        self.trades.iter().cloned().collect()
245    }
246
247    /// Get all subscribed users
248    pub fn subscribed_users(&self) -> Vec<String> {
249        self.users.iter().cloned().collect()
250    }
251
252    /// Get all subscribed markets
253    pub fn subscribed_markets(&self) -> Vec<String> {
254        self.markets.iter().cloned().collect()
255    }
256}
257
#[cfg(test)]
mod tests {
    //! Unit tests for subscription bookkeeping and `Subscription` conversions.

    use super::*;

    #[test]
    fn test_book_update_subscriptions() {
        let mut manager = SubscriptionManager::new();

        manager.add_book_update(vec!["ob1".to_string(), "ob2".to_string()]);
        assert!(manager.is_subscribed_book_update("ob1"));
        assert!(manager.is_subscribed_book_update("ob2"));
        assert!(!manager.is_subscribed_book_update("ob3"));

        manager.remove_book_update(&["ob1".to_string()]);
        assert!(!manager.is_subscribed_book_update("ob1"));
        assert!(manager.is_subscribed_book_update("ob2"));
    }

    #[test]
    fn test_trades_subscriptions() {
        let mut manager = SubscriptionManager::new();

        manager.add_trades(vec!["ob1".to_string(), "ob2".to_string()]);
        assert!(manager.is_subscribed_trades("ob1"));
        // Trades and book updates are tracked independently.
        assert!(!manager.is_subscribed_book_update("ob1"));

        manager.remove_trades(&["ob2".to_string()]);
        assert!(!manager.is_subscribed_trades("ob2"));
        assert_eq!(manager.trade_orderbooks(), vec!["ob1".to_string()]);
    }

    #[test]
    fn test_user_subscriptions() {
        let mut manager = SubscriptionManager::new();

        manager.add_user("user1".to_string());
        assert!(manager.is_subscribed_user("user1"));
        assert!(!manager.is_subscribed_user("user2"));

        manager.remove_user("user1");
        assert!(!manager.is_subscribed_user("user1"));
    }

    #[test]
    fn test_price_history_subscriptions() {
        let mut manager = SubscriptionManager::new();

        manager.add_price_history("ob1".to_string(), "1m".to_string(), true);
        assert!(manager.is_subscribed_price_history("ob1", "1m"));
        assert!(!manager.is_subscribed_price_history("ob1", "5m"));

        manager.remove_price_history("ob1", "1m");
        assert!(!manager.is_subscribed_price_history("ob1", "1m"));
    }

    #[test]
    fn test_price_history_readd_replaces_flag() {
        let mut manager = SubscriptionManager::new();

        manager.add_price_history("ob1".to_string(), "1m".to_string(), true);
        manager.add_price_history("ob1".to_string(), "1m".to_string(), false);

        assert_eq!(manager.subscription_count(), 1);
        assert_eq!(
            manager.get_all_subscriptions(),
            vec![Subscription::PriceHistory {
                orderbook_id: "ob1".to_string(),
                resolution: "1m".to_string(),
                include_ohlcv: false,
            }]
        );
    }

    #[test]
    fn test_market_subscriptions() {
        let mut manager = SubscriptionManager::new();

        manager.add_market("market1".to_string());
        assert!(manager.is_subscribed_market("market1"));

        // Test "all" markets
        manager.add_market("all".to_string());
        assert!(manager.is_subscribed_market("any_market"));
    }

    #[test]
    fn test_market_removal() {
        let mut manager = SubscriptionManager::new();

        manager.add_market("market1".to_string());
        manager.add_market("all".to_string());
        manager.remove_market("all");

        // Without the wildcard only explicit markets match.
        assert!(manager.is_subscribed_market("market1"));
        assert!(!manager.is_subscribed_market("market2"));
    }

    #[test]
    fn test_get_all_subscriptions() {
        let mut manager = SubscriptionManager::new();

        manager.add_book_update(vec!["ob1".to_string()]);
        manager.add_user("user1".to_string());
        manager.add_price_history("ob1".to_string(), "1m".to_string(), true);

        let subs = manager.get_all_subscriptions();
        assert_eq!(subs.len(), 3);
        assert!(subs.contains(&Subscription::BookUpdate {
            orderbook_ids: vec!["ob1".to_string()],
        }));
        assert!(subs.contains(&Subscription::User {
            user: "user1".to_string(),
        }));
        assert!(subs.contains(&Subscription::PriceHistory {
            orderbook_id: "ob1".to_string(),
            resolution: "1m".to_string(),
            include_ohlcv: true,
        }));
    }

    #[test]
    fn test_subscription_count() {
        let mut manager = SubscriptionManager::new();

        assert_eq!(manager.subscription_count(), 0);
        assert!(!manager.has_subscriptions());

        manager.add_book_update(vec!["ob1".to_string(), "ob2".to_string()]);
        manager.add_user("user1".to_string());

        assert_eq!(manager.subscription_count(), 3);
        assert!(manager.has_subscriptions());
    }

    #[test]
    fn test_clear() {
        let mut manager = SubscriptionManager::new();

        manager.add_book_update(vec!["ob1".to_string()]);
        manager.add_trades(vec!["ob1".to_string()]);
        manager.add_user("user1".to_string());
        manager.add_price_history("ob1".to_string(), "1m".to_string(), false);
        manager.add_market("market1".to_string());

        manager.clear();

        assert!(!manager.has_subscriptions());
        assert_eq!(manager.subscription_count(), 0);
        assert!(manager.get_all_subscriptions().is_empty());
    }

    #[test]
    fn test_subscription_type() {
        assert_eq!(
            Subscription::BookUpdate { orderbook_ids: vec![] }.subscription_type(),
            "book_update"
        );
        assert_eq!(
            Subscription::Trades { orderbook_ids: vec![] }.subscription_type(),
            "trades"
        );
        assert_eq!(
            Subscription::User { user: "u".to_string() }.subscription_type(),
            "user"
        );
        assert_eq!(
            Subscription::PriceHistory {
                orderbook_id: "ob1".to_string(),
                resolution: "1m".to_string(),
                include_ohlcv: false,
            }
            .subscription_type(),
            "price_history"
        );
        assert_eq!(
            Subscription::Market { market_pubkey: "m".to_string() }.subscription_type(),
            "market"
        );
    }

    #[test]
    fn test_subscription_to_params() {
        let sub = Subscription::BookUpdate {
            orderbook_ids: vec!["ob1".to_string()],
        };

        let params = sub.to_params();
        let json = serde_json::to_string(&params).unwrap();
        assert!(json.contains("book_update"));
        assert!(json.contains("ob1"));
    }
}