use serde::{Deserialize, Serialize};
use crate::models::symbols::Symbols;
/// Channels that a [`SubscribeRequest`] can target.
///
/// Serde represents each variant as its lowercase name, so `Channel::Trades`
/// serializes to `"trades"` and `"candles"` deserializes to `Channel::Candles`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Channel {
    /// Wire name `"trades"`.
    Trades,
    /// Wire name `"candles"`.
    Candles,
    /// Wire name `"books"`.
    Books,
    /// Wire name `"aggregates"`.
    Aggregates,
    /// Wire name `"indices"`.
    Indices,
}
impl Channel {
pub fn as_str(&self) -> &'static str {
match self {
Channel::Trades => "trades",
Channel::Candles => "candles",
Channel::Books => "books",
Channel::Aggregates => "aggregates",
Channel::Indices => "indices",
}
}
}
/// Payload of a `subscribe` request.
///
/// Optional fields that are `None` are omitted from the serialized JSON.
/// Field names are emitted in camelCase: `channel`, `symbol` and `symbols`
/// are unchanged, while `after_hours` becomes `afterHours` and
/// `intraday_odd_lot` becomes `intradayOddLot`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, bon::Builder)]
#[serde(rename_all = "camelCase")]
pub struct SubscribeRequest {
    /// Channel name, typically produced by [`Channel::as_str`].
    pub channel: String,
    /// A single symbol to subscribe to.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub symbol: Option<String>,
    /// A batch of symbols to subscribe to in one request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub symbols: Option<Vec<String>>,
    /// After-hours modifier; serialized as `afterHours`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub after_hours: Option<bool>,
    /// Intraday odd-lot modifier; serialized as `intradayOddLot`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub intraday_odd_lot: Option<bool>,
}
impl SubscribeRequest {
    /// Builds a subscription to a single `symbol` on `channel`.
    ///
    /// Only `channel` and `symbol` are set; `symbols` and both modifier flags
    /// are left as `None`.
    pub fn new(channel: Channel, symbol: impl Into<String>) -> Self {
        Self {
            channel: channel.as_str().to_string(),
            symbol: Some(symbol.into()),
            ..Default::default()
        }
    }

    /// Builds a subscription from any input convertible into [`Symbols`].
    ///
    /// The input is first passed through `Symbols::normalized()`. The result is
    /// then routed as follows:
    /// * `Symbols::Single` goes to the `symbol` field.
    /// * A non-empty `Symbols::Many` goes to `symbols`.
    /// * An empty `Symbols::Many` leaves both fields `None`.
    pub fn with_symbols(channel: Channel, symbols: impl Into<Symbols>) -> Self {
        let mut req = Self {
            channel: channel.as_str().to_string(),
            ..Default::default()
        };
        match symbols.into().normalized() {
            Symbols::Single(s) => req.symbol = Some(s),
            Symbols::Many(v) if !v.is_empty() => req.symbols = Some(v),
            // An empty batch is not sent as `"symbols": []`; both fields stay unset.
            Symbols::Many(_) => {}
        }
        req
    }

    /// Splits a batch request into one request per entry of `symbols`.
    ///
    /// Each produced request keeps `channel`, `after_hours` and
    /// `intraday_odd_lot`, carries a single `symbol`, and has `symbols: None`.
    /// If `symbols` is `None`, the request is returned unchanged as a
    /// one-element vector.
    ///
    /// Note: when `symbols` is `Some`, any value in `symbol` is not carried
    /// over. `Some(vec![])` yields an empty vector.
    pub fn expand(self) -> Vec<SubscribeRequest> {
        match self.symbols {
            Some(symbols) => symbols
                .into_iter()
                .map(|s| SubscribeRequest {
                    channel: self.channel.clone(),
                    symbol: Some(s),
                    symbols: None,
                    after_hours: self.after_hours,
                    intraday_odd_lot: self.intraday_odd_lot,
                })
                .collect(),
            None => vec![self],
        }
    }

    /// Returns a string key identifying this subscription.
    ///
    /// The key is built in this order:
    /// 1. `channel`, or `channel:symbol` when `symbol` is set.
    /// 2. `:afterhours` if `after_hours == Some(true)`.
    /// 3. `:oddlot` if `intraday_odd_lot == Some(true)`.
    ///
    /// Each suffix is appended independently. A request with both modifiers
    /// therefore gets a key distinct from one with only `after_hours` set.
    pub fn key(&self) -> String {
        let mut key = self.channel.clone();
        if let Some(symbol) = &self.symbol {
            key.push(':');
            key.push_str(symbol);
        }
        if self.after_hours == Some(true) {
            key.push_str(":afterhours");
        }
        if self.intraday_odd_lot == Some(true) {
            key.push_str(":oddlot");
        }
        key
    }
}
/// Payload of an `unsubscribe` request.
///
/// Either a single `id` or a list of `ids` is supplied. Whichever field is
/// `None` is omitted from the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct UnsubscribeRequest {
    /// A single identifier to unsubscribe.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Several identifiers to unsubscribe at once.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ids: Option<Vec<String>>,
}
impl UnsubscribeRequest {
    /// Builds a request targeting a single `id`; `ids` is left unset.
    pub fn by_id(id: impl Into<String>) -> Self {
        let id = Some(id.into());
        Self { ids: None, id }
    }

    /// Builds a request targeting every entry of `ids`; `id` is left unset.
    pub fn by_ids(ids: Vec<String>) -> Self {
        let ids = Some(ids);
        Self { id: None, ids }
    }
}
/// A WebSocket message decoded from JSON; `event` names its kind.
///
/// Every field except `event` is optional and falls back to `None` when it
/// is absent from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebSocketMessage {
    /// Event name, e.g. `"data"`, `"error"` or `"authenticated"`.
    pub event: String,
    /// Event payload as arbitrary JSON.
    #[serde(default)]
    pub data: Option<serde_json::Value>,
    /// Channel the message relates to, if any.
    #[serde(default)]
    pub channel: Option<String>,
    /// Symbol the message relates to, if any.
    #[serde(default)]
    pub symbol: Option<String>,
    /// Identifier carried by the message, if any.
    #[serde(default)]
    pub id: Option<String>,
}
impl WebSocketMessage {
    /// `true` when `event` is `"authenticated"`.
    pub fn is_authenticated(&self) -> bool {
        matches!(self.event.as_str(), "authenticated")
    }

    /// `true` when `event` is `"error"`.
    pub fn is_error(&self) -> bool {
        matches!(self.event.as_str(), "error")
    }

    /// `true` when `event` is `"data"`.
    pub fn is_data(&self) -> bool {
        matches!(self.event.as_str(), "data")
    }

    /// `true` when `event` is `"pong"`.
    pub fn is_pong(&self) -> bool {
        matches!(self.event.as_str(), "pong")
    }

    /// `true` when `event` is `"heartbeat"`.
    pub fn is_heartbeat(&self) -> bool {
        matches!(self.event.as_str(), "heartbeat")
    }

    /// `true` when `event` is `"subscribed"`.
    pub fn is_subscribed(&self) -> bool {
        matches!(self.event.as_str(), "subscribed")
    }

    /// Returns the string at `data.message` for `"error"` events.
    ///
    /// Returns `None` in any of these cases:
    /// * the event is not `"error"`;
    /// * `data` is absent;
    /// * `data` has no string-valued `message` field.
    pub fn error_message(&self) -> Option<String> {
        if self.is_error() {
            let data = self.data.as_ref()?;
            let message = data.get("message")?.as_str()?;
            Some(message.to_owned())
        } else {
            None
        }
    }
}
/// Payload of an `auth` request.
///
/// Each credential field (`apikey`, `token`, `sdkToken`) and the optional
/// heartbeat interval is omitted from the serialized JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthRequest {
    /// API key credential, serialized as `apikey`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub apikey: Option<String>,
    /// Token credential, serialized as `token`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token: Option<String>,
    /// SDK token credential, serialized as `sdkToken`.
    #[serde(rename = "sdkToken", skip_serializing_if = "Option::is_none")]
    pub sdk_token: Option<String>,
    /// Requested heartbeat interval in milliseconds, serialized as `heartbeatIntervalMs`.
    #[serde(rename = "heartbeatIntervalMs", skip_serializing_if = "Option::is_none")]
    pub heartbeat_interval_ms: Option<u64>,
}
impl AuthRequest {
    /// Authenticates with an API key. Every other field is `None`.
    pub fn with_api_key(api_key: impl Into<String>) -> Self {
        let apikey = Some(api_key.into());
        Self {
            apikey,
            token: None,
            sdk_token: None,
            heartbeat_interval_ms: None,
        }
    }

    /// Authenticates with a token. Every other field is `None`.
    pub fn with_token(token: impl Into<String>) -> Self {
        let token = Some(token.into());
        Self {
            apikey: None,
            token,
            sdk_token: None,
            heartbeat_interval_ms: None,
        }
    }

    /// Authenticates with an SDK token. Every other field is `None`.
    pub fn with_sdk_token(sdk_token: impl Into<String>) -> Self {
        let sdk_token = Some(sdk_token.into());
        Self {
            apikey: None,
            token: None,
            sdk_token,
            heartbeat_interval_ms: None,
        }
    }
}
/// An outgoing request envelope: an `event` name plus an optional JSON payload.
///
/// `data` is omitted from the serialized JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebSocketRequest {
    /// Event name, e.g. `"auth"`, `"subscribe"` or `"ping"`.
    pub event: String,
    /// Event-specific payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}
impl WebSocketRequest {
pub fn auth(auth: AuthRequest) -> Self {
Self {
event: "auth".to_string(),
data: Some(serde_json::to_value(auth).unwrap()),
}
}
pub fn subscribe(sub: SubscribeRequest) -> Self {
Self {
event: "subscribe".to_string(),
data: Some(serde_json::to_value(sub).unwrap()),
}
}
pub fn unsubscribe(unsub: UnsubscribeRequest) -> Self {
Self {
event: "unsubscribe".to_string(),
data: Some(serde_json::to_value(unsub).unwrap()),
}
}
pub fn ping(state: Option<String>) -> Self {
Self {
event: "ping".to_string(),
data: state.map(|s| serde_json::json!({"state": s})),
}
}
pub fn subscriptions() -> Self {
Self {
event: "subscriptions".to_string(),
data: None,
}
}
}
/// Unit tests for request/response (de)serialization and the
/// `SubscribeRequest` helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_channel_serialization() {
        let channel = Channel::Trades;
        let json = serde_json::to_string(&channel).unwrap();
        assert_eq!(json, "\"trades\"");
    }

    #[test]
    fn test_channel_deserialization() {
        let channel: Channel = serde_json::from_str("\"candles\"").unwrap();
        assert_eq!(channel, Channel::Candles);
    }

    #[test]
    fn channel_as_str_matches_serde_name() {
        // `as_str` and the serde representation must never drift apart.
        for channel in [
            Channel::Trades,
            Channel::Candles,
            Channel::Books,
            Channel::Aggregates,
            Channel::Indices,
        ] {
            let json = serde_json::to_string(&channel).unwrap();
            assert_eq!(json, format!("\"{}\"", channel.as_str()));
        }
    }

    #[test]
    fn test_subscribe_request() {
        let req = SubscribeRequest::new(Channel::Trades, "2330");
        assert_eq!(req.channel, "trades");
        assert_eq!(req.symbol.as_deref(), Some("2330"));
        assert_eq!(req.key(), "trades:2330");
    }

    #[test]
    fn test_subscribe_request_serialization() {
        let req = SubscribeRequest::new(Channel::Trades, "2330");
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"channel\":\"trades\""));
        assert!(json.contains("\"symbol\":\"2330\""));
        assert!(!json.contains("afterHours"));
        assert!(!json.contains("intradayOddLot"));
        assert!(!json.contains("\"symbols\""));
    }

    #[test]
    fn symbol_spec_accepts_common_input_shapes() {
        let s1: Symbols = "2330".into();
        let s2: Symbols = "2330".to_string().into();
        let owned = "2330".to_string();
        let s3: Symbols = (&owned).into();
        assert!(matches!(s1, Symbols::Single(ref v) if v == "2330"));
        assert!(matches!(s2, Symbols::Single(ref v) if v == "2330"));
        assert!(matches!(s3, Symbols::Single(ref v) if v == "2330"));
        let m1: Symbols = vec!["A".to_string(), "B".to_string()].into();
        let m2: Symbols = vec!["A", "B"].into();
        let m3: Symbols = ["A", "B"].into();
        let m4: Symbols = ["A".to_string(), "B".to_string()].into();
        let arr: &[&str] = &["A", "B"];
        let m5: Symbols = arr.into();
        for v in [m1, m2, m3, m4, m5] {
            assert!(matches!(v, Symbols::Many(ref x) if x == &["A", "B"]));
        }
    }

    #[test]
    fn subscribe_request_with_symbols_serializes_batch() {
        let req =
            SubscribeRequest::with_symbols(Channel::Aggregates, vec!["2330", "0050", "2603"]);
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["channel"], "aggregates");
        assert_eq!(json["symbols"], serde_json::json!(["2330", "0050", "2603"]));
        assert!(json.get("symbol").is_none());
    }

    #[test]
    fn subscribe_request_with_symbols_single_routes_to_symbol_field() {
        let req = SubscribeRequest::with_symbols(Channel::Trades, "2330");
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["symbol"], "2330");
        assert!(json.get("symbols").is_none());
    }

    #[test]
    fn bon_builder_round_trips_through_new() {
        let via_builder = SubscribeRequest::builder()
            .channel("trades".to_string())
            .symbol("2330".to_string())
            .build();
        let via_new = SubscribeRequest::new(Channel::Trades, "2330");
        assert_eq!(via_builder, via_new);
    }

    #[test]
    fn with_symbols_dedups_duplicates() {
        let req = SubscribeRequest::with_symbols(Channel::Trades, vec!["2330", "2330"]);
        assert_eq!(req.symbol.as_deref(), Some("2330"));
        assert!(req.symbols.is_none());
        assert_eq!(req.expand().len(), 1);
    }

    #[test]
    fn with_symbols_collapses_whitespace_differences() {
        let req =
            SubscribeRequest::with_symbols(Channel::Trades, vec!["2330", " 2330 ", "2330\n"]);
        assert_eq!(req.symbol.as_deref(), Some("2330"));
        assert!(req.symbols.is_none());
        assert_eq!(req.expand().len(), 1);
    }

    #[test]
    fn with_symbols_keeps_distinct_in_insertion_order() {
        let req = SubscribeRequest::with_symbols(Channel::Trades, vec!["2330", "2454", "2317"]);
        assert_eq!(
            req.symbols.as_deref(),
            Some(&["2330".to_string(), "2454".to_string(), "2317".to_string()][..])
        );
        assert_eq!(req.expand().len(), 3);
    }

    #[test]
    fn with_symbols_empty_input_yields_no_symbol_field() {
        let req = SubscribeRequest::with_symbols(Channel::Trades, Vec::<String>::new());
        assert!(req.symbol.is_none());
        assert!(req.symbols.is_none());
    }

    #[test]
    fn expand_batch_into_per_symbol_requests() {
        let batch = SubscribeRequest::with_symbols(Channel::Aggregates, vec!["A", "B", "C"]);
        let expanded = batch.expand();
        assert_eq!(expanded.len(), 3);
        for (i, sym) in ["A", "B", "C"].iter().enumerate() {
            assert_eq!(expanded[i].channel, "aggregates");
            assert_eq!(expanded[i].symbol.as_deref(), Some(*sym));
            assert!(expanded[i].symbols.is_none());
        }
    }

    #[test]
    fn expand_preserves_modifier_flags_per_entry() {
        let mut batch = SubscribeRequest::with_symbols(Channel::Trades, ["2330", "2454"]);
        batch.intraday_odd_lot = Some(true);
        let expanded = batch.expand();
        // Without this, an empty result would make the loop below pass vacuously.
        assert_eq!(expanded.len(), 2);
        for entry in &expanded {
            assert_eq!(entry.intraday_odd_lot, Some(true));
            assert!(entry.key().contains("oddlot"));
        }
    }

    #[test]
    fn expand_single_symbol_passes_through() {
        let single = SubscribeRequest::new(Channel::Trades, "2330");
        let expanded = single.expand();
        assert_eq!(expanded.len(), 1);
        assert_eq!(expanded[0].symbol.as_deref(), Some("2330"));
    }

    #[test]
    fn test_subscribe_request_after_hours_key_and_wire() {
        let req = SubscribeRequest {
            channel: "trades".to_string(),
            symbol: Some("TXF1!".to_string()),
            after_hours: Some(true),
            ..Default::default()
        };
        assert_eq!(req.key(), "trades:TXF1!:afterhours");
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"afterHours\":true"));
    }

    #[test]
    fn test_subscribe_request_oddlot_key_and_wire() {
        let req = SubscribeRequest {
            channel: "trades".to_string(),
            symbol: Some("2330".to_string()),
            intraday_odd_lot: Some(true),
            ..Default::default()
        };
        assert_eq!(req.key(), "trades:2330:oddlot");
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"intradayOddLot\":true"));
    }

    #[test]
    fn test_subscribe_request_deserialize_without_modifiers() {
        let json = r#"{"channel":"trades","symbol":"2330"}"#;
        let req: SubscribeRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.after_hours, None);
        assert_eq!(req.intraday_odd_lot, None);
        assert_eq!(req.key(), "trades:2330");
    }

    #[test]
    fn test_unsubscribe_request() {
        let req = UnsubscribeRequest::by_id("sub-123");
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"id\":\"sub-123\""));
        assert!(!json.contains("\"ids\""));
    }

    #[test]
    fn test_unsubscribe_request_by_ids() {
        let req = UnsubscribeRequest::by_ids(vec!["a".to_string(), "b".to_string()]);
        let json = serde_json::to_string(&req).unwrap();
        assert_eq!(json, r#"{"ids":["a","b"]}"#);
    }

    #[test]
    fn test_websocket_message_deserialization() {
        let json = r#"{
            "event": "data",
            "channel": "trades",
            "symbol": "2330",
            "data": {"price": 583.0, "size": 1000}
        }"#;
        let msg: WebSocketMessage = serde_json::from_str(json).unwrap();
        assert!(msg.is_data());
        assert_eq!(msg.channel.as_deref(), Some("trades"));
        assert_eq!(msg.symbol.as_deref(), Some("2330"));
    }

    #[test]
    fn test_websocket_error_message() {
        let json = r#"{
            "event": "error",
            "data": {"message": "Unauthorized"}
        }"#;
        let msg: WebSocketMessage = serde_json::from_str(json).unwrap();
        assert!(msg.is_error());
        assert_eq!(msg.error_message(), Some("Unauthorized".to_string()));
    }

    #[test]
    fn test_error_message_is_none_for_non_error_event() {
        let json = r#"{"event": "data", "data": {"message": "not an error"}}"#;
        let msg: WebSocketMessage = serde_json::from_str(json).unwrap();
        assert!(!msg.is_error());
        assert_eq!(msg.error_message(), None);
    }

    #[test]
    fn test_websocket_authenticated() {
        let json = r#"{"event": "authenticated"}"#;
        let msg: WebSocketMessage = serde_json::from_str(json).unwrap();
        assert!(msg.is_authenticated());
    }

    #[test]
    fn test_auth_request_api_key() {
        let req = AuthRequest::with_api_key("my-api-key");
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"apikey\":\"my-api-key\""));
        assert!(!json.contains("token"));
        assert!(!json.contains("sdkToken"));
    }

    #[test]
    fn test_auth_request_token() {
        let req = AuthRequest::with_token("t");
        let json = serde_json::to_string(&req).unwrap();
        assert_eq!(json, r#"{"token":"t"}"#);
    }

    #[test]
    fn test_auth_request_sdk_token() {
        let req = AuthRequest::with_sdk_token("my-sdk-token");
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"sdkToken\":\"my-sdk-token\""));
    }

    #[test]
    fn test_auth_request_heartbeat_interval_omitted_by_default() {
        let req = AuthRequest::with_api_key("k");
        let json = serde_json::to_string(&req).unwrap();
        assert!(!json.contains("heartbeatIntervalMs"));
    }

    #[test]
    fn test_auth_request_heartbeat_interval_serialized_when_set() {
        let mut req = AuthRequest::with_api_key("k");
        req.heartbeat_interval_ms = Some(30_000);
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["heartbeatIntervalMs"], 30_000);
        assert_eq!(json["apikey"], "k");
    }

    #[test]
    fn test_websocket_request_auth() {
        let req = WebSocketRequest::auth(AuthRequest::with_api_key("test"));
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"event\":\"auth\""));
        assert!(json.contains("\"apikey\":\"test\""));
    }

    #[test]
    fn test_websocket_request_subscribe() {
        let req = WebSocketRequest::subscribe(SubscribeRequest::new(Channel::Trades, "2330"));
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"event\":\"subscribe\""));
        assert!(json.contains("\"channel\":\"trades\""));
    }

    #[test]
    fn test_websocket_request_unsubscribe() {
        let req = WebSocketRequest::unsubscribe(UnsubscribeRequest::by_id("sub-1"));
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["event"], "unsubscribe");
        assert_eq!(json["data"]["id"], "sub-1");
    }

    #[test]
    fn test_websocket_request_ping() {
        let req = WebSocketRequest::ping(Some("test-state".to_string()));
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"event\":\"ping\""));
        assert!(json.contains("\"state\":\"test-state\""));
    }

    #[test]
    fn test_websocket_request_ping_without_state_omits_data() {
        let req = WebSocketRequest::ping(None);
        assert_eq!(serde_json::to_string(&req).unwrap(), r#"{"event":"ping"}"#);
    }

    #[test]
    fn test_websocket_request_subscriptions() {
        let req = WebSocketRequest::subscriptions();
        assert_eq!(
            serde_json::to_string(&req).unwrap(),
            r#"{"event":"subscriptions"}"#
        );
    }
}