1use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use crate::websocket::error::WebSocketError;
10use crate::websocket::state::{LocalOrderbook, PriceHistory, UserState};
11use crate::websocket::state::price::PriceHistoryKey;
12use crate::websocket::types::{
13 BookUpdateData, ErrorData, MarketEventData, MessageType, PriceHistoryData, RawWsMessage,
14 TradeData, UserEventData, WsEvent,
15};
16
17#[derive(Debug)]
19pub struct MessageHandler {
20 orderbooks: Arc<RwLock<HashMap<String, LocalOrderbook>>>,
22 user_states: Arc<RwLock<HashMap<String, UserState>>>,
24 price_histories: Arc<RwLock<HashMap<PriceHistoryKey, PriceHistory>>>,
26 subscribed_user: Arc<RwLock<Option<String>>>,
28}
29
30impl MessageHandler {
31 pub fn new(
33 orderbooks: Arc<RwLock<HashMap<String, LocalOrderbook>>>,
34 user_states: Arc<RwLock<HashMap<String, UserState>>>,
35 price_histories: Arc<RwLock<HashMap<PriceHistoryKey, PriceHistory>>>,
36 subscribed_user: Arc<RwLock<Option<String>>>,
37 ) -> Self {
38 Self {
39 orderbooks,
40 user_states,
41 price_histories,
42 subscribed_user,
43 }
44 }
45
46 pub async fn handle_message(&self, text: &str) -> Vec<WsEvent> {
48 let raw_msg: RawWsMessage = match serde_json::from_str(text) {
50 Ok(msg) => msg,
51 Err(e) => {
52 tracing::warn!("Failed to parse WebSocket message: {}", e);
53 return vec![WsEvent::Error {
54 error: WebSocketError::MessageParseError(e.to_string()),
55 }];
56 }
57 };
58
59 let msg_type = MessageType::from(raw_msg.type_.as_str());
61 match msg_type {
62 MessageType::BookUpdate => self.handle_book_update(&raw_msg).await,
63 MessageType::Trades => self.handle_trade(&raw_msg).await,
64 MessageType::User => self.handle_user_event(&raw_msg).await,
65 MessageType::PriceHistory => self.handle_price_history(&raw_msg).await,
66 MessageType::Market => self.handle_market_event(&raw_msg).await,
67 MessageType::Error => self.handle_error(&raw_msg).await,
68 MessageType::Pong => vec![WsEvent::Pong],
69 MessageType::Unknown => {
70 tracing::warn!("Unknown message type: {}", raw_msg.type_);
71 vec![]
72 }
73 }
74 }
75
76 async fn handle_book_update(&self, raw_msg: &RawWsMessage) -> Vec<WsEvent> {
78 let data: BookUpdateData = match serde_json::from_value(raw_msg.data.clone()) {
79 Ok(data) => data,
80 Err(e) => {
81 tracing::warn!("Failed to parse book update: {}", e);
82 return vec![WsEvent::Error {
83 error: WebSocketError::MessageParseError(e.to_string()),
84 }];
85 }
86 };
87
88 if data.resync {
90 tracing::info!("Resync required for orderbook: {}", data.orderbook_id);
91 return vec![WsEvent::ResyncRequired {
92 orderbook_id: data.orderbook_id.clone(),
93 }];
94 }
95
96 let orderbook_id = data.orderbook_id.clone();
97 let is_snapshot = data.is_snapshot;
98
99 let mut orderbooks = self.orderbooks.write().await;
101 let book = orderbooks
102 .entry(orderbook_id.clone())
103 .or_insert_with(|| LocalOrderbook::new(orderbook_id.clone()));
104
105 match book.apply_update(&data) {
106 Ok(()) => {
107 vec![WsEvent::BookUpdate {
108 orderbook_id,
109 is_snapshot,
110 }]
111 }
112 Err(WebSocketError::SequenceGap { expected, received }) => {
113 tracing::warn!(
114 "Sequence gap in orderbook {}: expected {}, received {}",
115 orderbook_id,
116 expected,
117 received
118 );
119 book.clear();
121 vec![WsEvent::ResyncRequired { orderbook_id }]
122 }
123 Err(e) => {
124 vec![WsEvent::Error { error: e }]
125 }
126 }
127 }
128
129 async fn handle_trade(&self, raw_msg: &RawWsMessage) -> Vec<WsEvent> {
131 let data: TradeData = match serde_json::from_value(raw_msg.data.clone()) {
132 Ok(data) => data,
133 Err(e) => {
134 tracing::warn!("Failed to parse trade: {}", e);
135 return vec![WsEvent::Error {
136 error: WebSocketError::MessageParseError(e.to_string()),
137 }];
138 }
139 };
140
141 vec![WsEvent::Trade {
142 orderbook_id: data.orderbook_id.clone(),
143 trade: data,
144 }]
145 }
146
147 async fn handle_user_event(&self, raw_msg: &RawWsMessage) -> Vec<WsEvent> {
149 let data: UserEventData = match serde_json::from_value(raw_msg.data.clone()) {
150 Ok(data) => data,
151 Err(e) => {
152 tracing::warn!("Failed to parse user event: {}", e);
153 return vec![WsEvent::Error {
154 error: WebSocketError::MessageParseError(e.to_string()),
155 }];
156 }
157 };
158
159 let event_type = data.event_type.clone();
160
161 let user = {
163 let subscribed_user = self.subscribed_user.read().await;
164 subscribed_user
165 .clone()
166 .unwrap_or_else(|| "unknown".to_string())
167 };
168
169 let needs_warning = {
171 let user_states = self.user_states.read().await;
172 !user_states.contains_key(&user)
173 };
174
175 if !needs_warning {
177 let mut user_states = self.user_states.write().await;
178 if let Some(state) = user_states.get_mut(&user) {
179 state.apply_event(&data);
180 }
181 }
182
183 if needs_warning {
185 tracing::warn!(
186 "Received user event '{}' for user '{}' but no subscription exists. \
187 Call subscribe_user() before receiving events to avoid data loss.",
188 event_type,
189 user
190 );
191 }
192
193 vec![WsEvent::UserUpdate { event_type, user }]
194 }
195
196 async fn handle_price_history(&self, raw_msg: &RawWsMessage) -> Vec<WsEvent> {
198 let data: PriceHistoryData = match serde_json::from_value(raw_msg.data.clone()) {
199 Ok(data) => data,
200 Err(e) => {
201 tracing::warn!("Failed to parse price history: {}", e);
202 return vec![WsEvent::Error {
203 error: WebSocketError::MessageParseError(e.to_string()),
204 }];
205 }
206 };
207
208 if data.event_type == "heartbeat" {
210 let mut histories = self.price_histories.write().await;
212 for history in histories.values_mut() {
213 history.apply_heartbeat(&data);
214 }
215 return vec![];
216 }
217
218 let orderbook_id = match &data.orderbook_id {
219 Some(id) => id.clone(),
220 None => {
221 tracing::warn!("Price history message missing orderbook_id");
222 return vec![];
223 }
224 };
225
226 let resolution = data.resolution.clone().unwrap_or_else(|| "1m".to_string());
227
228 let mut histories = self.price_histories.write().await;
230 let key = PriceHistoryKey::new(orderbook_id.clone(), resolution.clone());
231
232 if let Some(history) = histories.get_mut(&key) {
233 history.apply_event(&data);
234 } else {
235 if data.event_type == "snapshot" {
237 let mut history = PriceHistory::new(
238 orderbook_id.clone(),
239 resolution.clone(),
240 data.include_ohlcv.unwrap_or(false),
241 );
242 history.apply_event(&data);
243 histories.insert(key, history);
244 } else {
245 tracing::warn!(
246 "Received price history event '{}' for orderbook '{}' resolution '{}' \
247 but no subscription exists. Event dropped.",
248 data.event_type,
249 orderbook_id,
250 resolution
251 );
252 }
253 }
254
255 vec![WsEvent::PriceUpdate {
256 orderbook_id,
257 resolution,
258 }]
259 }
260
261 async fn handle_market_event(&self, raw_msg: &RawWsMessage) -> Vec<WsEvent> {
263 let data: MarketEventData = match serde_json::from_value(raw_msg.data.clone()) {
264 Ok(data) => data,
265 Err(e) => {
266 tracing::warn!("Failed to parse market event: {}", e);
267 return vec![WsEvent::Error {
268 error: WebSocketError::MessageParseError(e.to_string()),
269 }];
270 }
271 };
272
273 vec![WsEvent::MarketEvent {
274 event_type: data.event_type,
275 market_pubkey: data.market_pubkey,
276 }]
277 }
278
279 async fn handle_error(&self, raw_msg: &RawWsMessage) -> Vec<WsEvent> {
281 let data: ErrorData = match serde_json::from_value(raw_msg.data.clone()) {
282 Ok(data) => data,
283 Err(e) => {
284 tracing::warn!("Failed to parse error: {}", e);
285 return vec![WsEvent::Error {
286 error: WebSocketError::MessageParseError(e.to_string()),
287 }];
288 }
289 };
290
291 tracing::error!("Server error: {} (code: {})", data.error, data.code);
292
293 vec![WsEvent::Error {
294 error: WebSocketError::ServerError {
295 code: data.code,
296 message: data.error,
297 },
298 }]
299 }
300
301 pub async fn init_orderbook(&self, orderbook_id: &str) {
311 let mut orderbooks = self.orderbooks.write().await;
312 orderbooks
313 .entry(orderbook_id.to_string())
314 .or_insert_with(|| LocalOrderbook::new(orderbook_id.to_string()));
315 }
316
317 pub async fn init_user_state(&self, user: &str) {
328 *self.subscribed_user.write().await = Some(user.to_string());
330
331 let mut user_states = self.user_states.write().await;
332 user_states
333 .entry(user.to_string())
334 .or_insert_with(|| UserState::new(user.to_string()));
335 }
336
337 pub async fn clear_subscribed_user(&self, user: &str) {
339 let mut subscribed = self.subscribed_user.write().await;
340 if subscribed.as_deref() == Some(user) {
341 *subscribed = None;
342 }
343 }
344
345 pub async fn init_price_history(
363 &self,
364 orderbook_id: &str,
365 resolution: &str,
366 include_ohlcv: bool,
367 ) {
368 let key = PriceHistoryKey::new(orderbook_id.to_string(), resolution.to_string());
369 let mut histories = self.price_histories.write().await;
370 histories.entry(key).or_insert_with(|| {
371 PriceHistory::new(orderbook_id.to_string(), resolution.to_string(), include_ohlcv)
372 });
373 }
374
375 pub async fn clear_orderbook(&self, orderbook_id: &str) {
377 let mut orderbooks = self.orderbooks.write().await;
378 if let Some(book) = orderbooks.get_mut(orderbook_id) {
379 book.clear();
380 }
381 }
382
383 pub async fn clear_user_state(&self, user: &str) {
385 let mut user_states = self.user_states.write().await;
386 if let Some(state) = user_states.get_mut(user) {
387 state.clear();
388 }
389 }
390
391 pub async fn clear_price_history(&self, orderbook_id: &str, resolution: &str) {
393 let key = PriceHistoryKey::new(orderbook_id.to_string(), resolution.to_string());
394 let mut histories = self.price_histories.write().await;
395 if let Some(history) = histories.get_mut(&key) {
396 history.clear();
397 }
398 }
399
400 pub async fn clear_all(&self) {
402 self.orderbooks.write().await.clear();
403 self.user_states.write().await.clear();
404 self.price_histories.write().await.clear();
405 }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a handler backed by empty state maps and no subscribed user.
    async fn create_handler() -> MessageHandler {
        let orderbooks = Arc::new(RwLock::new(HashMap::new()));
        let user_states = Arc::new(RwLock::new(HashMap::new()));
        let price_histories = Arc::new(RwLock::new(HashMap::new()));
        let subscribed_user = Arc::new(RwLock::new(None));

        MessageHandler::new(orderbooks, user_states, price_histories, subscribed_user)
    }

    /// Feeds `msg` to a fresh handler and returns the one event it must produce.
    async fn single_event(msg: &str) -> WsEvent {
        let handler = create_handler().await;
        let mut events = handler.handle_message(msg).await;
        assert_eq!(events.len(), 1);
        events.remove(0)
    }

    #[tokio::test]
    async fn test_handle_book_update_snapshot() {
        let event = single_event(
            r#"{
                "type": "book_update",
                "version": 0.1,
                "data": {
                    "orderbook_id": "ob1",
                    "timestamp": "2024-01-01T00:00:00.000Z",
                    "sequence": 0,
                    "bids": [{"side": "bid", "price": "0.500000", "size": "0.001000"}],
                    "asks": [{"side": "ask", "price": "0.510000", "size": "0.000500"}],
                    "is_snapshot": true
                }
            }"#,
        )
        .await;

        match &event {
            WsEvent::BookUpdate { orderbook_id, is_snapshot } => {
                assert_eq!(orderbook_id, "ob1");
                assert!(*is_snapshot);
            }
            _ => panic!("Expected BookUpdate event"),
        }
    }

    #[tokio::test]
    async fn test_handle_resync() {
        let event = single_event(
            r#"{
                "type": "book_update",
                "version": 0.1,
                "data": {
                    "orderbook_id": "ob1",
                    "resync": true,
                    "message": "Please re-subscribe to get fresh snapshot"
                }
            }"#,
        )
        .await;

        match &event {
            WsEvent::ResyncRequired { orderbook_id } => assert_eq!(orderbook_id, "ob1"),
            _ => panic!("Expected ResyncRequired event"),
        }
    }

    #[tokio::test]
    async fn test_handle_trade() {
        let event = single_event(
            r#"{
                "type": "trades",
                "version": 0.1,
                "data": {
                    "orderbook_id": "ob1",
                    "price": "0.505000",
                    "size": "0.000250",
                    "side": "bid",
                    "timestamp": "2024-01-01T00:00:00.000Z",
                    "trade_id": "trade123",
                    "sequence": 1
                }
            }"#,
        )
        .await;

        match &event {
            WsEvent::Trade { orderbook_id, trade } => {
                assert_eq!(orderbook_id, "ob1");
                assert_eq!(trade.price, "0.505000");
                assert_eq!(trade.size, "0.000250");
            }
            _ => panic!("Expected Trade event"),
        }
    }

    #[tokio::test]
    async fn test_handle_pong() {
        let event = single_event(
            r#"{
                "type": "pong",
                "version": 0.1,
                "data": {}
            }"#,
        )
        .await;

        assert!(matches!(event, WsEvent::Pong));
    }

    #[tokio::test]
    async fn test_handle_error() {
        let event = single_event(
            r#"{
                "type": "error",
                "version": 0.1,
                "data": {
                    "error": "Engine unavailable",
                    "code": "ENGINE_UNAVAILABLE"
                }
            }"#,
        )
        .await;

        match &event {
            WsEvent::Error {
                error: WebSocketError::ServerError { code, message },
            } => {
                assert_eq!(code, "ENGINE_UNAVAILABLE");
                assert_eq!(message, "Engine unavailable");
            }
            WsEvent::Error { .. } => panic!("Expected ServerError"),
            _ => panic!("Expected Error event"),
        }
    }

    #[tokio::test]
    async fn test_handle_invalid_json() {
        let event = single_event("not valid json").await;

        assert!(matches!(event, WsEvent::Error { .. }));
    }
}