Skip to main content

bullet_rust_sdk/ws/
models.rs

//! WebSocket message models for the Trading SDK.
//!
//! NOTE: This module is a temporary home until we have a better solution.
//! [`ServerMessage`] is used only by the trading SDK, to deserialize messages
//! received from the server. It does NOT live in `trading-api-types` because:
//! 1. The server uses optimized types (`&'static str`, `Arc<str>`) that cannot be
//!    deserialized.
//! 2. These types are only needed by SDK clients, not by the server.
//!
//! IMPORTANT: When new message types are added to the server, they must be manually
//! added to the `ServerMessage` enum below (or to `TaggedMessage`, for control
//! responses distinguished by their `"e"` tag).
11
12use serde::{Deserialize, Serialize};
13
14use crate::types::{
15    AggTradeMessage, BookTickerMessage, DepthUpdate, ErrorMessage, ForceOrderMessage,
16    MarkPriceMessage, OrderUpdateMessage, PongMessage, RequestId, StatusMessage,
17};
18
19/// Result message for subscribe/unsubscribe success
20#[derive(Serialize, Deserialize, Clone, Debug)]
21pub struct MethodResult {
22    #[serde(default)]
23    pub id: Option<RequestId>,
24    /// Event time (ms)
25    #[serde(rename = "E")]
26    pub event_time: u64,
27    pub result: String,
28}
29
30/// Result message for list_subscriptions
31#[derive(Serialize, Deserialize, Clone, Debug)]
32pub struct ListSubscriptionsResult {
33    #[serde(default)]
34    pub id: Option<RequestId>,
35    /// Event time (ms)
36    #[serde(rename = "E")]
37    pub event_time: u64,
38    pub result: Vec<String>,
39}
40
41/// Tagged messages from the server (have an "e" event type field)
42#[derive(Serialize, Deserialize, Clone, Debug, strum::AsRefStr)]
43#[strum(serialize_all = "camelCase")]
44#[serde(tag = "e", rename_all = "snake_case")]
45pub enum TaggedMessage {
46    Status(StatusMessage),
47    Pong(PongMessage),
48    Error(ErrorMessage),
49    Subscribe(MethodResult),
50    Unsubscribe(MethodResult),
51    ListSubscriptions(ListSubscriptionsResult),
52}
53
54/// All possible server messages.
55///
56/// Uses untagged deserialization - serde tries each variant in order until one matches.
57/// The `Unknown` variant captures any message that doesn't match known types.
58#[derive(Serialize, Deserialize, Clone, Debug, strum::AsRefStr)]
59#[strum(serialize_all = "camelCase")]
60#[serde(untagged)]
61pub enum ServerMessage {
62    // Tagged messages with "e" field - try these first
63    Tagged(TaggedMessage),
64
65    // Binance-style messages (identified by "e" event type field)
66    DepthUpdate(DepthUpdate),
67    AggTrade(AggTradeMessage),
68    BookTicker(BookTickerMessage),
69    MarkPrice(MarkPriceMessage),
70    ForceOrder(ForceOrderMessage),
71    OrderUpdate(OrderUpdateMessage),
72
73    // Untagged error response (e.g., order errors without "e" field)
74    Error(ErrorMessage),
75
76    /// Failed to parse message - contains (error message, raw text)
77    #[serde(skip)]
78    Unknown(String, String),
79}
80
81impl ServerMessage {
82    /// Returns true if this is an error message
83    pub fn is_error(&self) -> bool {
84        matches!(self, ServerMessage::Tagged(TaggedMessage::Error(_)) | ServerMessage::Error(_))
85    }
86
87    /// Returns the request ID if present
88    pub fn request_id(&self) -> Option<RequestId> {
89        match self {
90            ServerMessage::Tagged(msg) => match msg {
91                TaggedMessage::Pong(m) => m.id,
92                TaggedMessage::Error(m) => m.id,
93                TaggedMessage::Subscribe(m) => m.id,
94                TaggedMessage::Unsubscribe(m) => m.id,
95                TaggedMessage::ListSubscriptions(m) => m.id,
96                _ => None,
97            },
98            ServerMessage::Error(m) => m.id,
99            _ => None,
100        }
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `json` into a [`ServerMessage`], panicking with the serde
    /// error and the offending payload on failure.
    fn parse(json: &str) -> ServerMessage {
        serde_json::from_str(json)
            .unwrap_or_else(|e| panic!("failed to deserialize ServerMessage: {e}\n{json}"))
    }

    #[test]
    fn test_depth_update() {
        let json = r#"{
            "e": "depthUpdate",
            "E": 1234567890,
            "T": 1234567890,
            "s": "BTCUSDT",
            "U": 100,
            "u": 200,
            "pu": 99,
            "b": [["50000.00", "1.5"]],
            "a": [["50001.00", "2.0"]],
            "mt": "s"
        }"#;

        match parse(json) {
            ServerMessage::DepthUpdate(d) => {
                assert_eq!(d.symbol, "BTCUSDT");
                assert_eq!(d.bids.len(), 1);
                assert_eq!(d.asks.len(), 1);
            }
            other => panic!("expected DepthUpdate, got {other:?}"),
        }
    }

    #[test]
    fn test_agg_trade() {
        let json = r#"{
            "e": "aggTrade",
            "E": 1234567890,
            "s": "BTCUSDT",
            "a": 12345,
            "p": "50000.00",
            "q": "1.5",
            "f": 100,
            "l": 105,
            "T": 1234567890,
            "m": true,
            "th": "0xabc123",
            "ua": "0xdef456",
            "oi": 999,
            "mk": true,
            "ff": false,
            "lq": false,
            "fe": "0.001",
            "nf": "0.001",
            "fa": "USDT",
            "sd": "BUY"
        }"#;

        match parse(json) {
            ServerMessage::AggTrade(t) => {
                assert_eq!(t.symbol, "BTCUSDT");
                assert_eq!(t.price, "50000.00");
                assert!(t.is_buyer_maker);
            }
            other => panic!("expected AggTrade, got {other:?}"),
        }
    }

    #[test]
    fn test_book_ticker() {
        let json = r#"{
            "e": "bookTicker",
            "u": 12345,
            "E": 1234567890,
            "T": 1234567890,
            "s": "ETHUSDT",
            "b": "3000.00",
            "B": "10.5",
            "a": "3001.00",
            "A": "8.2",
            "mt": "u"
        }"#;

        match parse(json) {
            ServerMessage::BookTicker(b) => {
                assert_eq!(b.symbol, "ETHUSDT");
                assert_eq!(b.best_bid_price, "3000.00");
                assert_eq!(b.best_ask_price, "3001.00");
            }
            other => panic!("expected BookTicker, got {other:?}"),
        }
    }

    #[test]
    fn test_mark_price() {
        let json = r#"{
            "e": "markPriceUpdate",
            "E": 1234567890,
            "s": "BTCUSDT",
            "p": "50000.00",
            "i": "49999.00",
            "r": "0.0001"
        }"#;

        match parse(json) {
            ServerMessage::MarkPrice(m) => {
                assert_eq!(m.symbol, "BTCUSDT");
                assert_eq!(m.mark_price, "50000.00");
                assert_eq!(m.funding_rate, "0.0001");
            }
            other => panic!("expected MarkPrice, got {other:?}"),
        }
    }

    #[test]
    fn test_force_order() {
        let json = r#"{
            "e": "liquidation",
            "E": 1234567890,
            "o": {
                "s": "BTCUSDT",
                "S": "SELL",
                "o": "LIMIT",
                "f": "IOC",
                "p": "49000.00",
                "ap": "49000.00",
                "X": "FILLED",
                "l": "1.0",
                "T": 1234567890,
                "th": "0xabc",
                "ua": "0xdef",
                "oi": 123,
                "ti": 456
            }
        }"#;

        match parse(json) {
            ServerMessage::ForceOrder(f) => {
                assert_eq!(f.order.symbol, "BTCUSDT");
                assert_eq!(f.order.side, "SELL");
            }
            other => panic!("expected ForceOrder, got {other:?}"),
        }
    }

    #[test]
    fn test_order_update() {
        let json = r#"{
            "e": "orderTradeUpdate",
            "E": 1234567890,
            "o": {
                "s": "BTCUSDT",
                "i": 12345,
                "X": "NEW",
                "x": "NEW",
                "T": 1234567890,
                "th": "0xabc",
                "ua": "0xdef",
                "S": "BUY",
                "o": "LIMIT",
                "f": "GTC",
                "p": "50000.00",
                "q": "1.0"
            }
        }"#;

        match parse(json) {
            ServerMessage::OrderUpdate(o) => assert_eq!(o.event_time, 1234567890),
            other => panic!("expected OrderUpdate, got {other:?}"),
        }
    }

    #[test]
    fn test_status_message() {
        let json = r#"{
            "e": "status",
            "E": 1234567890,
            "status": "connected",
            "clientId": "client-123"
        }"#;

        let msg = parse(json);
        assert!(!msg.is_error());
        assert_eq!(msg.request_id(), None);

        match msg {
            ServerMessage::Tagged(TaggedMessage::Status(s)) => {
                assert_eq!(s.status, "connected");
                assert_eq!(s.client_id, "client-123");
                assert_eq!(s.event_time, 1234567890);
            }
            other => panic!("expected Tagged(Status), got {other:?}"),
        }
    }

    #[test]
    fn test_pong_message() {
        let json = r#"{
            "e": "pong",
            "id": 42,
            "E": 1234567890
        }"#;

        let msg = parse(json);
        assert!(
            matches!(msg, ServerMessage::Tagged(TaggedMessage::Pong(_))),
            "expected Tagged(Pong), got {msg:?}"
        );
        assert!(!msg.is_error());
        assert_eq!(msg.request_id(), Some(RequestId::from(42)));
    }

    #[test]
    fn test_error_message() {
        let json = r#"{
            "e": "error",
            "id": 1,
            "E": 1234567890,
            "error": {
                "code": -1004,
                "msg": "Invalid subscription format"
            }
        }"#;

        let msg = parse(json);
        assert!(msg.is_error(), "Expected error message, got: {msg:?}");
        assert_eq!(msg.request_id(), Some(RequestId::from(1)));
    }

    #[test]
    fn test_order_error() {
        // Order errors come without the "e" tag
        let json = r#"{
            "id": 2,
            "E": 1234567890,
            "error": {
                "code": -2010,
                "msg": "Transaction execution unsuccessful"
            }
        }"#;

        let msg = parse(json);
        assert!(msg.is_error(), "Expected error message, got: {msg:?}");
        assert_eq!(msg.request_id(), Some(RequestId::from(2)));
        assert!(matches!(msg, ServerMessage::Error(_)), "expected Error, got {msg:?}");
    }

    #[test]
    fn test_subscribe_success() {
        let json = r#"{
            "e": "subscribe",
            "id": 5,
            "E": 1234567890,
            "result": "success"
        }"#;

        let msg = parse(json);
        assert_eq!(msg.request_id(), Some(RequestId::from(5)));

        match msg {
            ServerMessage::Tagged(TaggedMessage::Subscribe(s)) => assert_eq!(s.result, "success"),
            other => panic!("expected Tagged(Subscribe), got {other:?}"),
        }
    }

    #[test]
    fn test_unsubscribe_success() {
        let json = r#"{
            "e": "unsubscribe",
            "id": 6,
            "E": 1234567890,
            "result": "success"
        }"#;

        let msg = parse(json);
        assert!(
            matches!(msg, ServerMessage::Tagged(TaggedMessage::Unsubscribe(_))),
            "expected Tagged(Unsubscribe), got {msg:?}"
        );
        assert_eq!(msg.request_id(), Some(RequestId::from(6)));
    }

    #[test]
    fn test_list_subscriptions() {
        let json = r#"{
            "e": "list_subscriptions",
            "id": 7,
            "E": 1234567890,
            "result": ["btcusdt@depth10", "ethusdt@aggTrade"]
        }"#;

        match parse(json) {
            ServerMessage::Tagged(TaggedMessage::ListSubscriptions(l)) => {
                assert_eq!(l.result.len(), 2);
                assert_eq!(l.result[0], "btcusdt@depth10");
            }
            other => panic!("expected Tagged(ListSubscriptions), got {other:?}"),
        }
    }

    #[test]
    fn test_unparseable_input_is_an_error() {
        // `Unknown` is `#[serde(skip)]`, so deserialization never falls back to it:
        // malformed or non-object input surfaces as a serde error instead.
        assert!(serde_json::from_str::<ServerMessage>("this is not json").is_err());
        assert!(serde_json::from_str::<ServerMessage>("42").is_err());
    }

    #[test]
    fn test_unknown_variant() {
        let msg = ServerMessage::Unknown("parse error".to_string(), "raw".to_string());
        assert!(!msg.is_error());
        assert_eq!(msg.request_id(), None);
        // Skipped variants cannot be serialized.
        assert!(serde_json::to_string(&msg).is_err());
    }
}