1use serde::{Deserialize, Serialize};
4
5use crate::models::symbols::Symbols;
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
11#[serde(rename_all = "lowercase")]
12pub enum Channel {
13 Trades,
15 Candles,
17 Books,
19 Aggregates,
21 Indices,
23}
24
25impl Channel {
26 pub fn as_str(&self) -> &'static str {
28 match self {
29 Channel::Trades => "trades",
30 Channel::Candles => "candles",
31 Channel::Books => "books",
32 Channel::Aggregates => "aggregates",
33 Channel::Indices => "indices",
34 }
35 }
36}
37
38#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, bon::Builder)]
51pub struct SubscribeRequest {
52 pub channel: String,
54
55 #[serde(skip_serializing_if = "Option::is_none")]
59 pub symbol: Option<String>,
60
61 #[serde(skip_serializing_if = "Option::is_none")]
64 pub symbols: Option<Vec<String>>,
65
66 #[serde(rename = "afterHours", skip_serializing_if = "Option::is_none")]
69 pub after_hours: Option<bool>,
70
71 #[serde(rename = "intradayOddLot", skip_serializing_if = "Option::is_none")]
74 pub intraday_odd_lot: Option<bool>,
75}
76
77impl SubscribeRequest {
78 pub fn new(channel: Channel, symbol: impl Into<String>) -> Self {
82 Self {
83 channel: channel.as_str().to_string(),
84 symbol: Some(symbol.into()),
85 ..Default::default()
86 }
87 }
88
89 pub fn with_symbols(channel: Channel, symbols: impl Into<Symbols>) -> Self {
101 let spec = symbols.into().normalized();
102 let mut req = Self {
103 channel: channel.as_str().to_string(),
104 ..Default::default()
105 };
106 match spec {
107 Symbols::Single(s) => req.symbol = Some(s),
108 Symbols::Many(v) => {
109 if !v.is_empty() {
110 req.symbols = Some(v);
111 }
112 }
113 }
114 req
115 }
116
117 pub fn expand(self) -> Vec<SubscribeRequest> {
130 match self.symbols {
131 Some(symbols) => symbols
132 .into_iter()
133 .map(|s| SubscribeRequest {
134 channel: self.channel.clone(),
135 symbol: Some(s),
136 symbols: None,
137 after_hours: self.after_hours,
138 intraday_odd_lot: self.intraday_odd_lot,
139 })
140 .collect(),
141 None => vec![self],
142 }
143 }
144
145 pub fn key(&self) -> String {
157 let base = match &self.symbol {
158 Some(symbol) => format!("{}:{}", self.channel, symbol),
159 None => self.channel.clone(),
160 };
161 if self.after_hours == Some(true) {
162 format!("{base}:afterhours")
163 } else if self.intraday_odd_lot == Some(true) {
164 format!("{base}:oddlot")
165 } else {
166 base
167 }
168 }
169}
170
171#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
173pub struct UnsubscribeRequest {
174 #[serde(skip_serializing_if = "Option::is_none")]
176 pub id: Option<String>,
177
178 #[serde(skip_serializing_if = "Option::is_none")]
180 pub ids: Option<Vec<String>>,
181}
182
183impl UnsubscribeRequest {
184 pub fn by_id(id: impl Into<String>) -> Self {
186 Self {
187 id: Some(id.into()),
188 ids: None,
189 }
190 }
191
192 pub fn by_ids(ids: Vec<String>) -> Self {
194 Self {
195 id: None,
196 ids: Some(ids),
197 }
198 }
199}
200
201#[derive(Debug, Clone, Serialize, Deserialize)]
203pub struct WebSocketMessage {
204 pub event: String,
206
207 #[serde(default)]
209 pub data: Option<serde_json::Value>,
210
211 #[serde(default)]
213 pub channel: Option<String>,
214
215 #[serde(default)]
217 pub symbol: Option<String>,
218
219 #[serde(default)]
221 pub id: Option<String>,
222}
223
224impl WebSocketMessage {
225 pub fn is_authenticated(&self) -> bool {
227 self.event == "authenticated"
228 }
229
230 pub fn is_error(&self) -> bool {
232 self.event == "error"
233 }
234
235 pub fn is_data(&self) -> bool {
237 self.event == "data"
238 }
239
240 pub fn is_pong(&self) -> bool {
245 self.event == "pong"
246 }
247
248 pub fn is_heartbeat(&self) -> bool {
253 self.event == "heartbeat"
254 }
255
256 pub fn is_subscribed(&self) -> bool {
258 self.event == "subscribed"
259 }
260
261 pub fn error_message(&self) -> Option<String> {
263 if !self.is_error() {
264 return None;
265 }
266 self.data
267 .as_ref()
268 .and_then(|d| d.get("message"))
269 .and_then(|m| m.as_str())
270 .map(|s| s.to_string())
271 }
272
273}
274
275#[derive(Debug, Clone, Serialize, Deserialize)]
277pub struct AuthRequest {
278 #[serde(skip_serializing_if = "Option::is_none")]
280 pub apikey: Option<String>,
281
282 #[serde(skip_serializing_if = "Option::is_none")]
284 pub token: Option<String>,
285
286 #[serde(rename = "sdkToken", skip_serializing_if = "Option::is_none")]
288 pub sdk_token: Option<String>,
289
290 #[serde(rename = "heartbeatIntervalMs", skip_serializing_if = "Option::is_none")]
300 pub heartbeat_interval_ms: Option<u64>,
301}
302
303impl AuthRequest {
304 pub fn with_api_key(api_key: impl Into<String>) -> Self {
306 Self {
307 apikey: Some(api_key.into()),
308 token: None,
309 sdk_token: None,
310 heartbeat_interval_ms: None,
311 }
312 }
313
314 pub fn with_token(token: impl Into<String>) -> Self {
316 Self {
317 apikey: None,
318 token: Some(token.into()),
319 sdk_token: None,
320 heartbeat_interval_ms: None,
321 }
322 }
323
324 pub fn with_sdk_token(sdk_token: impl Into<String>) -> Self {
326 Self {
327 apikey: None,
328 token: None,
329 sdk_token: Some(sdk_token.into()),
330 heartbeat_interval_ms: None,
331 }
332 }
333}
334
335#[derive(Debug, Clone, Serialize, Deserialize)]
337pub struct WebSocketRequest {
338 pub event: String,
340
341 #[serde(skip_serializing_if = "Option::is_none")]
343 pub data: Option<serde_json::Value>,
344}
345
346impl WebSocketRequest {
347 pub fn auth(auth: AuthRequest) -> Self {
349 Self {
350 event: "auth".to_string(),
351 data: Some(serde_json::to_value(auth).unwrap()),
352 }
353 }
354
355 pub fn subscribe(sub: SubscribeRequest) -> Self {
357 Self {
358 event: "subscribe".to_string(),
359 data: Some(serde_json::to_value(sub).unwrap()),
360 }
361 }
362
363 pub fn unsubscribe(unsub: UnsubscribeRequest) -> Self {
365 Self {
366 event: "unsubscribe".to_string(),
367 data: Some(serde_json::to_value(unsub).unwrap()),
368 }
369 }
370
371 pub fn ping(state: Option<String>) -> Self {
373 Self {
374 event: "ping".to_string(),
375 data: state.map(|s| serde_json::json!({"state": s})),
376 }
377 }
378
379 pub fn subscriptions() -> Self {
381 Self {
382 event: "subscriptions".to_string(),
383 data: None,
384 }
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_channel_serialization() {
        let channel = Channel::Trades;
        let json = serde_json::to_string(&channel).unwrap();
        assert_eq!(json, "\"trades\"");
    }

    #[test]
    fn test_channel_deserialization() {
        let channel: Channel = serde_json::from_str("\"candles\"").unwrap();
        assert_eq!(channel, Channel::Candles);
    }

    // `as_str` is used to fill `SubscribeRequest::channel`, so it must agree
    // with the serde wire name for every variant.
    #[test]
    fn channel_as_str_matches_serde_name() {
        for channel in [
            Channel::Trades,
            Channel::Candles,
            Channel::Books,
            Channel::Aggregates,
            Channel::Indices,
        ] {
            let json = serde_json::to_string(&channel).unwrap();
            assert_eq!(json, format!("\"{}\"", channel.as_str()));
        }
    }

    #[test]
    fn test_subscribe_request() {
        let req = SubscribeRequest::new(Channel::Trades, "2330");
        assert_eq!(req.channel, "trades");
        assert_eq!(req.symbol.as_deref(), Some("2330"));
        assert_eq!(req.key(), "trades:2330");
    }

    #[test]
    fn test_subscribe_request_serialization() {
        let req = SubscribeRequest::new(Channel::Trades, "2330");
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"channel\":\"trades\""));
        assert!(json.contains("\"symbol\":\"2330\""));
        assert!(!json.contains("afterHours"));
        assert!(!json.contains("intradayOddLot"));
        assert!(!json.contains("\"symbols\""));
    }

    #[test]
    fn symbol_spec_accepts_common_input_shapes() {
        let s1: Symbols = "2330".into();
        let s2: Symbols = "2330".to_string().into();
        let owned = "2330".to_string();
        let s3: Symbols = (&owned).into();
        assert!(matches!(s1, Symbols::Single(ref v) if v == "2330"));
        assert!(matches!(s2, Symbols::Single(ref v) if v == "2330"));
        assert!(matches!(s3, Symbols::Single(ref v) if v == "2330"));

        let m1: Symbols = vec!["A".to_string(), "B".to_string()].into();
        let m2: Symbols = vec!["A", "B"].into();
        let m3: Symbols = ["A", "B"].into();
        let m4: Symbols = ["A".to_string(), "B".to_string()].into();
        let arr: &[&str] = &["A", "B"];
        let m5: Symbols = arr.into();
        for v in [m1, m2, m3, m4, m5] {
            assert!(matches!(v, Symbols::Many(ref x) if x == &["A", "B"]));
        }
    }

    #[test]
    fn subscribe_request_with_symbols_serializes_batch() {
        let req = SubscribeRequest::with_symbols(Channel::Aggregates, vec!["2330", "0050", "2603"]);
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["channel"], "aggregates");
        assert_eq!(json["symbols"], serde_json::json!(["2330", "0050", "2603"]));
        assert!(json.get("symbol").is_none());
    }

    #[test]
    fn subscribe_request_with_symbols_single_routes_to_symbol_field() {
        let req = SubscribeRequest::with_symbols(Channel::Trades, "2330");
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["symbol"], "2330");
        assert!(json.get("symbols").is_none());
    }

    #[test]
    fn bon_builder_round_trips_through_new() {
        let via_builder = SubscribeRequest::builder()
            .channel("trades".to_string())
            .symbol("2330".to_string())
            .build();
        let via_new = SubscribeRequest::new(Channel::Trades, "2330");
        assert_eq!(via_builder, via_new);
    }

    #[test]
    fn with_symbols_dedups_duplicates() {
        let req = SubscribeRequest::with_symbols(Channel::Trades, vec!["2330", "2330"]);
        assert_eq!(req.symbol.as_deref(), Some("2330"));
        assert!(req.symbols.is_none());
        assert_eq!(req.expand().len(), 1);
    }

    #[test]
    fn with_symbols_collapses_whitespace_differences() {
        let req = SubscribeRequest::with_symbols(Channel::Trades, vec!["2330", " 2330 ", "2330\n"]);
        assert_eq!(req.symbol.as_deref(), Some("2330"));
        assert!(req.symbols.is_none());
        assert_eq!(req.expand().len(), 1);
    }

    #[test]
    fn with_symbols_keeps_distinct_in_insertion_order() {
        let req = SubscribeRequest::with_symbols(Channel::Trades, vec!["2330", "2454", "2317"]);
        assert_eq!(
            req.symbols.as_deref(),
            Some(&["2330".to_string(), "2454".to_string(), "2317".to_string()][..])
        );
        assert_eq!(req.expand().len(), 3);
    }

    #[test]
    fn with_symbols_empty_input_yields_no_symbol_field() {
        let req = SubscribeRequest::with_symbols(Channel::Trades, Vec::<String>::new());
        assert!(req.symbol.is_none());
        assert!(req.symbols.is_none());
    }

    #[test]
    fn expand_batch_into_per_symbol_requests() {
        let batch = SubscribeRequest::with_symbols(Channel::Aggregates, vec!["A", "B", "C"]);
        let expanded = batch.expand();
        assert_eq!(expanded.len(), 3);
        for (entry, sym) in expanded.iter().zip(["A", "B", "C"]) {
            assert_eq!(entry.channel, "aggregates");
            assert_eq!(entry.symbol.as_deref(), Some(sym));
            assert!(entry.symbols.is_none());
        }
    }

    #[test]
    fn expand_preserves_modifier_flags_per_entry() {
        let mut batch = SubscribeRequest::with_symbols(Channel::Trades, ["2330", "2454"]);
        batch.intraday_odd_lot = Some(true);
        let expanded = batch.expand();
        // Without this the loop below would pass vacuously on an empty result.
        assert_eq!(expanded.len(), 2);
        for entry in &expanded {
            assert_eq!(entry.intraday_odd_lot, Some(true));
            assert!(entry.key().contains("oddlot"));
        }
    }

    #[test]
    fn expand_single_symbol_passes_through() {
        let single = SubscribeRequest::new(Channel::Trades, "2330");
        let expanded = single.expand();
        assert_eq!(expanded.len(), 1);
        assert_eq!(expanded[0].symbol.as_deref(), Some("2330"));
    }

    #[test]
    fn test_subscribe_request_after_hours_key_and_wire() {
        let req = SubscribeRequest {
            channel: "trades".to_string(),
            symbol: Some("TXF1!".to_string()),
            after_hours: Some(true),
            ..Default::default()
        };
        assert_eq!(req.key(), "trades:TXF1!:afterhours");
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"afterHours\":true"));
    }

    #[test]
    fn test_subscribe_request_oddlot_key_and_wire() {
        let req = SubscribeRequest {
            channel: "trades".to_string(),
            symbol: Some("2330".to_string()),
            intraday_odd_lot: Some(true),
            ..Default::default()
        };
        assert_eq!(req.key(), "trades:2330:oddlot");
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"intradayOddLot\":true"));
    }

    #[test]
    fn test_subscribe_request_deserialize_without_modifiers() {
        let json = r#"{"channel":"trades","symbol":"2330"}"#;
        let req: SubscribeRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.after_hours, None);
        assert_eq!(req.intraday_odd_lot, None);
        assert_eq!(req.key(), "trades:2330");
    }

    #[test]
    fn test_unsubscribe_request() {
        let req = UnsubscribeRequest::by_id("sub-123");
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"id\":\"sub-123\""));
    }

    #[test]
    fn test_unsubscribe_request_by_ids() {
        let req = UnsubscribeRequest::by_ids(vec!["a".to_string(), "b".to_string()]);
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["ids"], serde_json::json!(["a", "b"]));
        assert!(json.get("id").is_none());
    }

    #[test]
    fn test_websocket_message_deserialization() {
        let json = r#"{
            "event": "data",
            "channel": "trades",
            "symbol": "2330",
            "data": {"price": 583.0, "size": 1000}
        }"#;
        let msg: WebSocketMessage = serde_json::from_str(json).unwrap();
        assert!(msg.is_data());
        assert_eq!(msg.channel.as_deref(), Some("trades"));
        assert_eq!(msg.symbol.as_deref(), Some("2330"));
    }

    #[test]
    fn test_websocket_error_message() {
        let json = r#"{
            "event": "error",
            "data": {"message": "Unauthorized"}
        }"#;
        let msg: WebSocketMessage = serde_json::from_str(json).unwrap();
        assert!(msg.is_error());
        assert_eq!(msg.error_message(), Some("Unauthorized".to_string()));
    }

    #[test]
    fn test_websocket_authenticated() {
        let json = r#"{"event": "authenticated"}"#;
        let msg: WebSocketMessage = serde_json::from_str(json).unwrap();
        assert!(msg.is_authenticated());
    }

    #[test]
    fn test_auth_request_api_key() {
        let req = AuthRequest::with_api_key("my-api-key");
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"apikey\":\"my-api-key\""));
        assert!(!json.contains("token"));
        assert!(!json.contains("sdkToken"));
    }

    #[test]
    fn test_auth_request_sdk_token() {
        let req = AuthRequest::with_sdk_token("my-sdk-token");
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"sdkToken\":\"my-sdk-token\""));
    }

    #[test]
    fn test_auth_request_heartbeat_interval_omitted_by_default() {
        let req = AuthRequest::with_api_key("k");
        let json = serde_json::to_string(&req).unwrap();
        assert!(!json.contains("heartbeatIntervalMs"));
    }

    #[test]
    fn test_auth_request_heartbeat_interval_serialized_when_set() {
        let mut req = AuthRequest::with_api_key("k");
        req.heartbeat_interval_ms = Some(30_000);
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["heartbeatIntervalMs"], 30_000);
        assert_eq!(json["apikey"], "k");
    }

    #[test]
    fn test_websocket_request_auth() {
        let req = WebSocketRequest::auth(AuthRequest::with_api_key("test"));
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"event\":\"auth\""));
        assert!(json.contains("\"apikey\":\"test\""));
    }

    #[test]
    fn test_websocket_request_subscribe() {
        let req = WebSocketRequest::subscribe(SubscribeRequest::new(Channel::Trades, "2330"));
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"event\":\"subscribe\""));
        assert!(json.contains("\"channel\":\"trades\""));
    }

    #[test]
    fn test_websocket_request_ping() {
        let req = WebSocketRequest::ping(Some("test-state".to_string()));
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"event\":\"ping\""));
        assert!(json.contains("\"state\":\"test-state\""));
    }
}