Skip to main content

lightcone_sdk/websocket/
handlers.rs

1//! Message handlers for WebSocket events.
2//!
3//! Routes incoming messages to appropriate handlers and emits events.
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use crate::websocket::error::WebSocketError;
10use crate::websocket::state::{LocalOrderbook, PriceHistory, UserState};
11use crate::websocket::state::price::PriceHistoryKey;
12use crate::websocket::types::{
13    BookUpdateData, ErrorData, MarketEventData, MessageType, PriceHistoryData, RawWsMessage,
14    TradeData, UserEventData, WsEvent,
15};
16
17/// Handles incoming WebSocket messages
18#[derive(Debug)]
19pub struct MessageHandler {
20    /// Local orderbook state
21    orderbooks: Arc<RwLock<HashMap<String, LocalOrderbook>>>,
22    /// Local user state
23    user_states: Arc<RwLock<HashMap<String, UserState>>>,
24    /// Price history state
25    price_histories: Arc<RwLock<HashMap<PriceHistoryKey, PriceHistory>>>,
26    /// Currently subscribed user (single user per connection)
27    subscribed_user: Arc<RwLock<Option<String>>>,
28}
29
30impl MessageHandler {
31    /// Create a new message handler with shared state
32    pub fn new(
33        orderbooks: Arc<RwLock<HashMap<String, LocalOrderbook>>>,
34        user_states: Arc<RwLock<HashMap<String, UserState>>>,
35        price_histories: Arc<RwLock<HashMap<PriceHistoryKey, PriceHistory>>>,
36        subscribed_user: Arc<RwLock<Option<String>>>,
37    ) -> Self {
38        Self {
39            orderbooks,
40            user_states,
41            price_histories,
42            subscribed_user,
43        }
44    }
45
46    /// Handle an incoming message and return events
47    pub async fn handle_message(&self, text: &str) -> Vec<WsEvent> {
48        // Parse the raw message first
49        let raw_msg: RawWsMessage = match serde_json::from_str(text) {
50            Ok(msg) => msg,
51            Err(e) => {
52                tracing::warn!("Failed to parse WebSocket message: {}", e);
53                return vec![WsEvent::Error {
54                    error: WebSocketError::MessageParseError(e.to_string()),
55                }];
56            }
57        };
58
59        // Route by message type
60        let msg_type = MessageType::from(raw_msg.type_.as_str());
61        match msg_type {
62            MessageType::BookUpdate => self.handle_book_update(&raw_msg).await,
63            MessageType::Trades => self.handle_trade(&raw_msg).await,
64            MessageType::User => self.handle_user_event(&raw_msg).await,
65            MessageType::PriceHistory => self.handle_price_history(&raw_msg).await,
66            MessageType::Market => self.handle_market_event(&raw_msg).await,
67            MessageType::Error => self.handle_error(&raw_msg).await,
68            MessageType::Pong => vec![WsEvent::Pong],
69            MessageType::Unknown => {
70                tracing::warn!("Unknown message type: {}", raw_msg.type_);
71                vec![]
72            }
73        }
74    }
75
76    /// Handle book update message
77    async fn handle_book_update(&self, raw_msg: &RawWsMessage) -> Vec<WsEvent> {
78        let data: BookUpdateData = match serde_json::from_value(raw_msg.data.clone()) {
79            Ok(data) => data,
80            Err(e) => {
81                tracing::warn!("Failed to parse book update: {}", e);
82                return vec![WsEvent::Error {
83                    error: WebSocketError::MessageParseError(e.to_string()),
84                }];
85            }
86        };
87
88        // Check for resync signal
89        if data.resync {
90            tracing::info!("Resync required for orderbook: {}", data.orderbook_id);
91            return vec![WsEvent::ResyncRequired {
92                orderbook_id: data.orderbook_id.clone(),
93            }];
94        }
95
96        let orderbook_id = data.orderbook_id.clone();
97        let is_snapshot = data.is_snapshot;
98
99        // Update local state
100        let mut orderbooks = self.orderbooks.write().await;
101        let book = orderbooks
102            .entry(orderbook_id.clone())
103            .or_insert_with(|| LocalOrderbook::new(orderbook_id.clone()));
104
105        match book.apply_update(&data) {
106            Ok(()) => {
107                vec![WsEvent::BookUpdate {
108                    orderbook_id,
109                    is_snapshot,
110                }]
111            }
112            Err(WebSocketError::SequenceGap { expected, received }) => {
113                tracing::warn!(
114                    "Sequence gap in orderbook {}: expected {}, received {}",
115                    orderbook_id,
116                    expected,
117                    received
118                );
119                // Clear the orderbook state on sequence gap
120                book.clear();
121                vec![WsEvent::ResyncRequired { orderbook_id }]
122            }
123            Err(e) => {
124                vec![WsEvent::Error { error: e }]
125            }
126        }
127    }
128
129    /// Handle trade message
130    async fn handle_trade(&self, raw_msg: &RawWsMessage) -> Vec<WsEvent> {
131        let data: TradeData = match serde_json::from_value(raw_msg.data.clone()) {
132            Ok(data) => data,
133            Err(e) => {
134                tracing::warn!("Failed to parse trade: {}", e);
135                return vec![WsEvent::Error {
136                    error: WebSocketError::MessageParseError(e.to_string()),
137                }];
138            }
139        };
140
141        vec![WsEvent::Trade {
142            orderbook_id: data.orderbook_id.clone(),
143            trade: data,
144        }]
145    }
146
147    /// Handle user event message
148    async fn handle_user_event(&self, raw_msg: &RawWsMessage) -> Vec<WsEvent> {
149        let data: UserEventData = match serde_json::from_value(raw_msg.data.clone()) {
150            Ok(data) => data,
151            Err(e) => {
152                tracing::warn!("Failed to parse user event: {}", e);
153                return vec![WsEvent::Error {
154                    error: WebSocketError::MessageParseError(e.to_string()),
155                }];
156            }
157        };
158
159        let event_type = data.event_type.clone();
160
161        // Use the tracked subscribed user (single user per connection)
162        let user = {
163            let subscribed_user = self.subscribed_user.read().await;
164            subscribed_user
165                .clone()
166                .unwrap_or_else(|| "unknown".to_string())
167        };
168
169        // Check if user state exists (read lock, released quickly)
170        let needs_warning = {
171            let user_states = self.user_states.read().await;
172            !user_states.contains_key(&user)
173        };
174
175        // Update local state for the subscribed user (write lock only if needed)
176        if !needs_warning {
177            let mut user_states = self.user_states.write().await;
178            if let Some(state) = user_states.get_mut(&user) {
179                state.apply_event(&data);
180            }
181        }
182
183        // Log AFTER releasing lock to avoid holding lock during I/O
184        if needs_warning {
185            tracing::warn!(
186                "Received user event '{}' for user '{}' but no subscription exists. \
187                 Call subscribe_user() before receiving events to avoid data loss.",
188                event_type,
189                user
190            );
191        }
192
193        vec![WsEvent::UserUpdate { event_type, user }]
194    }
195
196    /// Handle price history message
197    async fn handle_price_history(&self, raw_msg: &RawWsMessage) -> Vec<WsEvent> {
198        let data: PriceHistoryData = match serde_json::from_value(raw_msg.data.clone()) {
199            Ok(data) => data,
200            Err(e) => {
201                tracing::warn!("Failed to parse price history: {}", e);
202                return vec![WsEvent::Error {
203                    error: WebSocketError::MessageParseError(e.to_string()),
204                }];
205            }
206        };
207
208        // Heartbeats don't have orderbook_id
209        if data.event_type == "heartbeat" {
210            // Update all price histories with heartbeat
211            let mut histories = self.price_histories.write().await;
212            for history in histories.values_mut() {
213                history.apply_heartbeat(&data);
214            }
215            return vec![];
216        }
217
218        let orderbook_id = match &data.orderbook_id {
219            Some(id) => id.clone(),
220            None => {
221                tracing::warn!("Price history message missing orderbook_id");
222                return vec![];
223            }
224        };
225
226        let resolution = data.resolution.clone().unwrap_or_else(|| "1m".to_string());
227
228        // Update local state
229        let mut histories = self.price_histories.write().await;
230        let key = PriceHistoryKey::new(orderbook_id.clone(), resolution.clone());
231
232        if let Some(history) = histories.get_mut(&key) {
233            history.apply_event(&data);
234        } else {
235            // Create new history if this is a snapshot
236            if data.event_type == "snapshot" {
237                let mut history = PriceHistory::new(
238                    orderbook_id.clone(),
239                    resolution.clone(),
240                    data.include_ohlcv.unwrap_or(false),
241                );
242                history.apply_event(&data);
243                histories.insert(key, history);
244            } else {
245                tracing::warn!(
246                    "Received price history event '{}' for orderbook '{}' resolution '{}' \
247                     but no subscription exists. Event dropped.",
248                    data.event_type,
249                    orderbook_id,
250                    resolution
251                );
252            }
253        }
254
255        vec![WsEvent::PriceUpdate {
256            orderbook_id,
257            resolution,
258        }]
259    }
260
261    /// Handle market event message
262    async fn handle_market_event(&self, raw_msg: &RawWsMessage) -> Vec<WsEvent> {
263        let data: MarketEventData = match serde_json::from_value(raw_msg.data.clone()) {
264            Ok(data) => data,
265            Err(e) => {
266                tracing::warn!("Failed to parse market event: {}", e);
267                return vec![WsEvent::Error {
268                    error: WebSocketError::MessageParseError(e.to_string()),
269                }];
270            }
271        };
272
273        vec![WsEvent::MarketEvent {
274            event_type: data.event_type,
275            market_pubkey: data.market_pubkey,
276        }]
277    }
278
279    /// Handle error message from server
280    async fn handle_error(&self, raw_msg: &RawWsMessage) -> Vec<WsEvent> {
281        let data: ErrorData = match serde_json::from_value(raw_msg.data.clone()) {
282            Ok(data) => data,
283            Err(e) => {
284                tracing::warn!("Failed to parse error: {}", e);
285                return vec![WsEvent::Error {
286                    error: WebSocketError::MessageParseError(e.to_string()),
287                }];
288            }
289        };
290
291        tracing::error!("Server error: {} (code: {})", data.error, data.code);
292
293        vec![WsEvent::Error {
294            error: WebSocketError::ServerError {
295                code: data.code,
296                message: data.error,
297            },
298        }]
299    }
300
301    /// Initialize orderbook state for a subscription.
302    ///
303    /// This must be called before subscribing to orderbook updates to ensure
304    /// the local state exists when the first snapshot arrives. If not called,
305    /// the handler will create the state on first message, but there may be
306    /// a brief window where `get_orderbook()` returns `None`.
307    ///
308    /// Uses atomic entry API to avoid race conditions with message handlers.
309    /// Thread-safe: multiple concurrent calls are safe due to interior write lock.
310    pub async fn init_orderbook(&self, orderbook_id: &str) {
311        let mut orderbooks = self.orderbooks.write().await;
312        orderbooks
313            .entry(orderbook_id.to_string())
314            .or_insert_with(|| LocalOrderbook::new(orderbook_id.to_string()));
315    }
316
317    /// Initialize user state for a subscription.
318    ///
319    /// This must be called before subscribing to user events to ensure state
320    /// is ready when the first event arrives. Also sets the tracked user for
321    /// this connection (single user per connection model).
322    ///
323    /// Thread-safe: uses interior write locks for both the subscribed user
324    /// tracking and the user state map. Multiple concurrent calls are safe
325    /// but will serialize on lock acquisition. State mutations from incoming
326    /// events are applied atomically via the same lock.
327    pub async fn init_user_state(&self, user: &str) {
328        // Track the subscribed user
329        *self.subscribed_user.write().await = Some(user.to_string());
330
331        let mut user_states = self.user_states.write().await;
332        user_states
333            .entry(user.to_string())
334            .or_insert_with(|| UserState::new(user.to_string()));
335    }
336
337    /// Clear the subscribed user
338    pub async fn clear_subscribed_user(&self, user: &str) {
339        let mut subscribed = self.subscribed_user.write().await;
340        if subscribed.as_deref() == Some(user) {
341            *subscribed = None;
342        }
343    }
344
345    /// Initialize price history state for a subscription.
346    ///
347    /// Creates an empty price history container for the given orderbook and
348    /// resolution. The actual price data will be populated when the snapshot
349    /// message arrives from the server.
350    ///
351    /// # Arguments
352    ///
353    /// * `orderbook_id` - The orderbook identifier (e.g., "BTC-USD:main")
354    /// * `resolution` - The candle resolution (e.g., "1m", "5m", "1h", "1d")
355    /// * `include_ohlcv` - Whether OHLCV candle data should be tracked
356    ///
357    /// # Example
358    ///
359    /// ```ignore
360    /// handler.init_price_history("BTC-USD:main", "1m", true).await;
361    /// ```
362    pub async fn init_price_history(
363        &self,
364        orderbook_id: &str,
365        resolution: &str,
366        include_ohlcv: bool,
367    ) {
368        let key = PriceHistoryKey::new(orderbook_id.to_string(), resolution.to_string());
369        let mut histories = self.price_histories.write().await;
370        histories.entry(key).or_insert_with(|| {
371            PriceHistory::new(orderbook_id.to_string(), resolution.to_string(), include_ohlcv)
372        });
373    }
374
375    /// Clear orderbook state
376    pub async fn clear_orderbook(&self, orderbook_id: &str) {
377        let mut orderbooks = self.orderbooks.write().await;
378        if let Some(book) = orderbooks.get_mut(orderbook_id) {
379            book.clear();
380        }
381    }
382
383    /// Clear user state
384    pub async fn clear_user_state(&self, user: &str) {
385        let mut user_states = self.user_states.write().await;
386        if let Some(state) = user_states.get_mut(user) {
387            state.clear();
388        }
389    }
390
391    /// Clear price history state
392    pub async fn clear_price_history(&self, orderbook_id: &str, resolution: &str) {
393        let key = PriceHistoryKey::new(orderbook_id.to_string(), resolution.to_string());
394        let mut histories = self.price_histories.write().await;
395        if let Some(history) = histories.get_mut(&key) {
396            history.clear();
397        }
398    }
399
400    /// Clear all state
401    pub async fn clear_all(&self) {
402        self.orderbooks.write().await.clear();
403        self.user_states.write().await.clear();
404        self.price_histories.write().await.clear();
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a handler over fresh, empty shared state.
    fn empty_handler() -> MessageHandler {
        let orderbooks = Arc::new(RwLock::new(HashMap::new()));
        let user_states = Arc::new(RwLock::new(HashMap::new()));
        let price_histories = Arc::new(RwLock::new(HashMap::new()));
        let subscribed_user = Arc::new(RwLock::new(None));
        MessageHandler::new(orderbooks, user_states, price_histories, subscribed_user)
    }

    /// Asserts that exactly one event was produced and returns it.
    fn single_event(events: &[WsEvent]) -> &WsEvent {
        assert_eq!(events.len(), 1);
        &events[0]
    }

    #[tokio::test]
    async fn test_handle_book_update_snapshot() {
        let payload = r#"{
            "type": "book_update",
            "version": 0.1,
            "data": {
                "orderbook_id": "ob1",
                "timestamp": "2024-01-01T00:00:00.000Z",
                "sequence": 0,
                "bids": [{"side": "bid", "price": "0.500000", "size": "0.001000"}],
                "asks": [{"side": "ask", "price": "0.510000", "size": "0.000500"}],
                "is_snapshot": true
            }
        }"#;

        let events = empty_handler().handle_message(payload).await;

        if let WsEvent::BookUpdate { orderbook_id, is_snapshot } = single_event(&events) {
            assert_eq!(orderbook_id, "ob1");
            assert!(*is_snapshot);
        } else {
            panic!("Expected BookUpdate event");
        }
    }

    #[tokio::test]
    async fn test_handle_resync() {
        let payload = r#"{
            "type": "book_update",
            "version": 0.1,
            "data": {
                "orderbook_id": "ob1",
                "resync": true,
                "message": "Please re-subscribe to get fresh snapshot"
            }
        }"#;

        let events = empty_handler().handle_message(payload).await;

        if let WsEvent::ResyncRequired { orderbook_id } = single_event(&events) {
            assert_eq!(orderbook_id, "ob1");
        } else {
            panic!("Expected ResyncRequired event");
        }
    }

    #[tokio::test]
    async fn test_handle_trade() {
        let payload = r#"{
            "type": "trades",
            "version": 0.1,
            "data": {
                "orderbook_id": "ob1",
                "price": "0.505000",
                "size": "0.000250",
                "side": "bid",
                "timestamp": "2024-01-01T00:00:00.000Z",
                "trade_id": "trade123",
                "sequence": 1
            }
        }"#;

        let events = empty_handler().handle_message(payload).await;

        if let WsEvent::Trade { orderbook_id, trade } = single_event(&events) {
            assert_eq!(orderbook_id, "ob1");
            assert_eq!(trade.price, "0.505000");
            assert_eq!(trade.size, "0.000250");
        } else {
            panic!("Expected Trade event");
        }
    }

    #[tokio::test]
    async fn test_handle_pong() {
        let payload = r#"{
            "type": "pong",
            "version": 0.1,
            "data": {}
        }"#;

        let events = empty_handler().handle_message(payload).await;

        assert!(matches!(single_event(&events), WsEvent::Pong));
    }

    #[tokio::test]
    async fn test_handle_error() {
        let payload = r#"{
            "type": "error",
            "version": 0.1,
            "data": {
                "error": "Engine unavailable",
                "code": "ENGINE_UNAVAILABLE"
            }
        }"#;

        let events = empty_handler().handle_message(payload).await;

        let WsEvent::Error { error } = single_event(&events) else {
            panic!("Expected Error event");
        };
        if let WebSocketError::ServerError { code, message } = error {
            assert_eq!(code, "ENGINE_UNAVAILABLE");
            assert_eq!(message, "Engine unavailable");
        } else {
            panic!("Expected ServerError");
        }
    }

    #[tokio::test]
    async fn test_handle_invalid_json() {
        let events = empty_handler().handle_message("not valid json").await;

        assert!(matches!(single_event(&events), WsEvent::Error { .. }));
    }
}