use super::error::MassiveError;
use super::transformer::{WsMessage, parse_ws_message};
use crate::event::{DataKind, MarketEvent};
use chrono::{DateTime, Utc};
use futures::{SinkExt, Stream, StreamExt};
use rustrade_instrument::exchange::ExchangeId;
use serde::Serialize;
use std::collections::HashMap;
use std::env;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::time::{Instant, interval_at};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
use tracing::{debug, trace, warn};
/// Name of the environment variable that `MassiveLive::from_env` reads the API key from.
const ENV_API_KEY: &str = "MASSIVE_API_KEY";
/// Period between client-initiated WebSocket pings sent by the event-stream task.
const PING_INTERVAL: Duration = Duration::from_secs(20);
/// Maximum age of an unanswered ping before the connection is reported as disconnected.
/// The check is performed on each ping tick, i.e. once per `PING_INTERVAL`.
const PONG_TIMEOUT: Duration = Duration::from_secs(19);
/// Market segment of the feed; each variant has its own default WebSocket endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Market {
    /// Uses the `/stocks` endpoint.
    Stocks,
    /// Uses the `/crypto` endpoint.
    Crypto,
    /// Uses the `/forex` endpoint.
    Forex,
    /// Uses the `/options` endpoint.
    Options,
}

impl Market {
    /// Returns the default WebSocket URL for this market.
    fn ws_url(&self) -> &'static str {
        match *self {
            Self::Stocks => "wss://socket.polygon.io/stocks",
            Self::Crypto => "wss://socket.polygon.io/crypto",
            Self::Forex => "wss://socket.polygon.io/forex",
            Self::Options => "wss://socket.polygon.io/options",
        }
    }
}
/// Kind of data channel that can be subscribed to for a symbol.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChannelType {
    /// Trade channel.
    Trade,
    /// Quote channel.
    Quote,
    /// Per-second aggregate channel.
    AggregateSecond,
    /// Per-minute aggregate channel.
    AggregateMinute,
}

impl ChannelType {
    /// Returns the channel-name prefix for this channel type on `market`,
    /// or `None` when the combination has no channel (forex trades).
    fn prefix(self, market: Market) -> Option<&'static str> {
        let code = match market {
            // Stocks and options share the same prefixes.
            Market::Stocks | Market::Options => match self {
                Self::Trade => "T",
                Self::Quote => "Q",
                Self::AggregateSecond => "A",
                Self::AggregateMinute => "AM",
            },
            Market::Crypto => match self {
                Self::Trade => "XT",
                Self::Quote => "XQ",
                Self::AggregateSecond => "XA",
                Self::AggregateMinute => "XAM",
            },
            Market::Forex => match self {
                Self::Trade => return None,
                Self::Quote => "C",
                Self::AggregateSecond => "CA",
                Self::AggregateMinute => "CAM",
            },
        };
        Some(code)
    }

    /// Builds the full channel name `"<prefix>.<symbol>"`, or `None` when the
    /// channel type is not available on `market`.
    fn channel_for(self, market: Market, symbol: &str) -> Option<String> {
        self.prefix(market)
            .map(|prefix| format!("{}.{}", prefix, symbol))
    }
}
/// JSON payload used to authenticate the connection.
/// `start` serializes it with `action` set to `"auth"` and `params` set to the API key.
#[derive(Serialize)]
struct AuthMessage<'a> {
/// Action name placed in the `action` field of the serialized message.
action: &'static str,
/// Borrowed API key placed in the `params` field.
params: &'a str,
}
/// JSON payload used to subscribe to channels.
/// `start` serializes it with `action` set to `"subscribe"` and `params` set to the
/// comma-joined list of channel names.
#[derive(Serialize)]
struct SubscribeMessage {
/// Action name placed in the `action` field of the serialized message.
action: &'static str,
/// Comma-separated channel names placed in the `params` field.
params: String,
}
/// Live market-data client: accumulates channel subscriptions and, on `start`,
/// connects, authenticates, subscribes and yields a stream of market events.
///
/// `K` is the caller's instrument key type attached to every emitted event.
pub struct MassiveLive<K> {
/// API key sent in the auth message; shown as `[REDACTED]` by the `Debug` impl.
api_key: String,
/// Market used to derive channel prefixes and the default WebSocket URL.
market: Market,
/// Lookup from the symbol carried by incoming messages to the caller's instrument key.
/// Messages whose symbol is not present here produce no event.
instruments: HashMap<String, K>,
/// Exchange identifier stamped on every emitted event.
exchange: ExchangeId,
/// Channel names accumulated by `subscribe`, sent comma-joined by `start`.
subscriptions: Vec<String>,
/// WebSocket URL that `start` connects to; defaults to the market's URL and can be
/// replaced with `with_ws_url`.
ws_url: String,
}
/// Manual `Debug` implementation that never prints the API key.
impl<K: std::fmt::Debug> std::fmt::Debug for MassiveLive<K> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct forces this impl
        // to be revisited instead of silently omitting the new field.
        let Self {
            api_key: _,
            market,
            instruments,
            exchange,
            subscriptions,
            ws_url,
        } = self;

        let mut out = f.debug_struct("MassiveLive");
        // The secret is replaced by a fixed placeholder.
        out.field("api_key", &"[REDACTED]");
        out.field("market", market);
        out.field("instruments", instruments);
        out.field("exchange", exchange);
        out.field("subscriptions", subscriptions);
        out.field("ws_url", ws_url);
        out.finish()
    }
}
impl<K> MassiveLive<K> {
    /// Creates a client for `market` with no subscriptions, targeting the
    /// market's default WebSocket URL.
    pub fn new(
        api_key: impl Into<String>,
        market: Market,
        exchange: ExchangeId,
        instruments: HashMap<String, K>,
    ) -> Self {
        let api_key = api_key.into();
        let ws_url = String::from(market.ws_url());
        Self {
            api_key,
            market,
            instruments,
            exchange,
            subscriptions: Vec::new(),
            ws_url,
        }
    }

    /// Creates a client whose API key is read from the `MASSIVE_API_KEY`
    /// environment variable.
    ///
    /// # Errors
    /// Returns `MassiveError::EnvVar` when the variable cannot be read.
    pub fn from_env(
        market: Market,
        exchange: ExchangeId,
        instruments: HashMap<String, K>,
    ) -> Result<Self, MassiveError> {
        match env::var(ENV_API_KEY) {
            Ok(api_key) => Ok(Self::new(api_key, market, exchange, instruments)),
            Err(_) => Err(MassiveError::EnvVar { var: ENV_API_KEY }),
        }
    }

    /// Replaces the WebSocket URL that `start` will connect to.
    #[must_use]
    pub fn with_ws_url(self, url: impl Into<String>) -> Self {
        Self {
            ws_url: url.into(),
            ..self
        }
    }

    /// Returns the market this client was created for.
    pub fn market(&self) -> Market {
        self.market
    }

    /// Queues one channel per symbol for `channel_type`.
    ///
    /// Symbols are skipped (with a warning) when the channel type has no
    /// channel on this client's market.
    pub fn subscribe(&mut self, symbols: &[&str], channel_type: ChannelType) {
        let market = self.market;
        for &symbol in symbols {
            match channel_type.channel_for(market, symbol) {
                Some(channel) => {
                    debug!(channel = %channel, "Adding subscription");
                    self.subscriptions.push(channel);
                }
                None => {
                    warn!(
                        symbol = %symbol,
                        channel_type = ?channel_type,
                        market = ?market,
                        "Channel type not supported for this market, skipping"
                    );
                }
            }
        }
    }

    /// Returns the channel names queued so far, in insertion order.
    pub fn subscriptions(&self) -> &[String] {
        self.subscriptions.as_slice()
    }
}
impl<K: Clone + Send + 'static> MassiveLive<K> {
pub async fn start(
self,
) -> Result<impl Stream<Item = Result<MarketEvent<K, DataKind>, MassiveError>>, MassiveError>
{
debug!(url = %self.ws_url, market = ?self.market, "Connecting to Massive WebSocket");
let (ws_stream, _response) = connect_async(&self.ws_url).await?;
let (mut write, mut read) = ws_stream.split();
Self::consume_connected_frame(&mut read).await?;
debug!("Sending auth message");
let auth_msg = AuthMessage {
action: "auth",
params: &self.api_key,
};
#[allow(clippy::expect_used)]
let auth_json = serde_json::to_string(&auth_msg).expect("infallible");
write.send(Message::Text(auth_json.into())).await?;
let auth_response = tokio::time::timeout(Duration::from_secs(10), read.next())
.await
.map_err(|_| MassiveError::Auth {
message: "Auth response timeout".into(),
})?
.ok_or_else(|| MassiveError::Disconnected {
reason: "Connection closed before auth response".into(),
})??;
Self::verify_auth_response(&auth_response)?;
debug!("Authentication successful");
if self.subscriptions.is_empty() {
warn!("start() called with no subscriptions; stream will receive no market data");
}
if !self.subscriptions.is_empty() {
let params = self.subscriptions.join(",");
debug!(channels = %params, "Subscribing to channels");
let sub_msg = SubscribeMessage {
action: "subscribe",
params,
};
#[allow(clippy::expect_used)]
let sub_json = serde_json::to_string(&sub_msg).expect("infallible");
write.send(Message::Text(sub_json.into())).await?;
let sub_response = tokio::time::timeout(Duration::from_secs(10), read.next())
.await
.map_err(|_| MassiveError::Disconnected {
reason: "Subscription response timeout".into(),
})?
.ok_or_else(|| MassiveError::Disconnected {
reason: "Connection closed before subscription response".into(),
})??;
Self::verify_subscription_response(&sub_response)?;
debug!("Subscription successful");
}
let ws_stream = write
.reunite(read)
.map_err(|e| MassiveError::Disconnected {
reason: format!("Failed to reunite WebSocket stream: {}", e),
})?;
Ok(Self::create_event_stream(
ws_stream,
self.instruments,
self.exchange,
))
}
/// Reads the first frame after connecting, which is expected to be a
/// `status: connected` message.
///
/// Any non-close frame is tolerated (logged at debug level) so that the
/// handshake can proceed; only a timeout, end of stream, transport error or
/// close frame is reported as an error.
async fn consume_connected_frame(
    read: &mut futures::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
) -> Result<(), MassiveError> {
    let first = match tokio::time::timeout(Duration::from_secs(10), read.next()).await {
        Err(_) => {
            return Err(MassiveError::Disconnected {
                reason: "Timeout waiting for connected frame".into(),
            });
        }
        Ok(None) => {
            return Err(MassiveError::Disconnected {
                reason: "Connection closed before connected frame".into(),
            });
        }
        Ok(Some(frame)) => frame?,
    };

    match &first {
        Message::Close(frame) => {
            let reason = frame
                .as_ref()
                .map(|f| f.reason.to_string())
                .unwrap_or_else(|| "Connection closed".into());
            Err(MassiveError::Disconnected { reason })
        }
        Message::Text(text) => {
            let is_connected_status = |entry: &serde_json::Value| {
                entry.get("ev").and_then(|v| v.as_str()) == Some("status")
                    && entry.get("status").and_then(|v| v.as_str()) == Some("connected")
            };
            // Unparseable or non-array payloads simply count as "not connected".
            let connected = serde_json::from_str::<serde_json::Value>(text)
                .ok()
                .is_some_and(|parsed| {
                    parsed
                        .as_array()
                        .is_some_and(|entries| entries.iter().any(is_connected_status))
                });
            if connected {
                debug!("Received connected frame");
            } else {
                debug!(msg = %text, "Unexpected initial frame (expected connected status)");
            }
            Ok(())
        }
        _ => {
            debug!(msg = ?first, "Unexpected message type before auth");
            Ok(())
        }
    }
}
/// Checks the frame received in reply to the auth message.
///
/// Succeeds when the JSON array contains a status entry with
/// `status == "auth_success"`. Entries are examined in order, so the first
/// `auth_success` / `auth_failed` status decides the outcome.
///
/// # Errors
/// `MassiveError::Disconnected` for a close frame; `MassiveError::Auth` for
/// everything else (rejection, malformed payload, unexpected frame type).
fn verify_auth_response(msg: &Message) -> Result<(), MassiveError> {
    let text = match msg {
        Message::Text(text) => text,
        Message::Close(frame) => {
            let reason = frame
                .as_ref()
                .map(|f| f.reason.to_string())
                .unwrap_or_else(|| "Connection closed".into());
            return Err(MassiveError::Disconnected { reason });
        }
        _ => {
            return Err(MassiveError::Auth {
                message: format!("Unexpected message type during auth: {:?}", msg),
            });
        }
    };

    let parsed: serde_json::Value = match serde_json::from_str(text) {
        Ok(value) => value,
        Err(e) => {
            return Err(MassiveError::Auth {
                message: format!("Failed to parse auth response: {}", e),
            });
        }
    };
    let Some(entries) = parsed.as_array() else {
        return Err(MassiveError::Auth {
            message: "Auth response is not an array".into(),
        });
    };

    let status_entries = entries
        .iter()
        .filter(|entry| entry.get("ev").and_then(|v| v.as_str()) == Some("status"));
    for entry in status_entries {
        match entry.get("status").and_then(|v| v.as_str()) {
            Some("auth_success") => return Ok(()),
            Some("auth_failed") => {
                let message = entry
                    .get("message")
                    .and_then(|v| v.as_str())
                    .unwrap_or("Authentication failed");
                return Err(MassiveError::Auth {
                    message: message.into(),
                });
            }
            // Other statuses (e.g. "connected") are neither success nor failure.
            _ => {}
        }
    }

    Err(MassiveError::Auth {
        message: format!("Unexpected auth response: {}", text),
    })
}
/// Checks the frame received in reply to the subscribe message.
///
/// Succeeds when the JSON array contains at least one status entry with
/// `status == "success"`.
///
/// # Errors
/// Every failure (close frame, malformed payload, no success status,
/// unexpected frame type) is reported as `MassiveError::Disconnected`.
fn verify_subscription_response(msg: &Message) -> Result<(), MassiveError> {
    match msg {
        Message::Close(frame) => {
            let reason = frame
                .as_ref()
                .map(|f| f.reason.to_string())
                .unwrap_or_else(|| "Connection closed".into());
            Err(MassiveError::Disconnected { reason })
        }
        Message::Text(text) => {
            let parsed: serde_json::Value = match serde_json::from_str(text) {
                Ok(value) => value,
                Err(e) => {
                    return Err(MassiveError::Disconnected {
                        reason: format!("Failed to parse subscription response: {}", e),
                    });
                }
            };
            let Some(entries) = parsed.as_array() else {
                return Err(MassiveError::Disconnected {
                    reason: "Subscription response is not an array".into(),
                });
            };

            let acknowledged = entries.iter().any(|entry| {
                entry.get("ev").and_then(|v| v.as_str()) == Some("status")
                    && entry.get("status").and_then(|v| v.as_str()) == Some("success")
            });
            if acknowledged {
                Ok(())
            } else {
                Err(MassiveError::Disconnected {
                    reason: format!("Subscription failed: {}", text),
                })
            }
        }
        other => Err(MassiveError::Disconnected {
            reason: format!("Unexpected message type during subscription: {:?}", other),
        }),
    }
}
/// Spawns the background task that drives the WebSocket and returns the stream
/// of events it produces.
///
/// The task:
/// - sends a ping every `PING_INTERVAL` and reports `Disconnected("Pong timeout")`
///   if the previous ping is still unanswered after `PONG_TIMEOUT` when the next
///   tick fires;
/// - converts text frames into market events (parse failures are logged and skipped);
/// - answers server pings with pongs;
/// - forwards close frames, transport errors and end-of-stream as a final
///   `Err(MassiveError::Disconnected)` item and then stops;
/// - stops as soon as the returned stream is dropped, so an idle connection
///   (e.g. one with no subscriptions) is not kept alive after the consumer is gone.
///
/// Events are passed through an unbounded channel, so a slow consumer causes
/// buffering rather than back-pressure on the socket.
fn create_event_stream(
    ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
    instruments: HashMap<String, K>,
    exchange: ExchangeId,
) -> impl Stream<Item = Result<MarketEvent<K, DataKind>, MassiveError>> {
    let (mut write, mut read) = ws_stream.split();
    let (tx, rx) = mpsc::unbounded_channel::<Result<MarketEvent<K, DataKind>, MassiveError>>();

    tokio::spawn(async move {
        // First ping fires one full interval after start, not immediately.
        let mut ping_interval = interval_at(Instant::now() + PING_INTERVAL, PING_INTERVAL);
        // `Some` while a ping is outstanding; cleared when a pong arrives.
        let mut ping_sent_at: Option<Instant> = None;

        loop {
            tokio::select! {
                // The consumer dropped the stream: nobody can receive events or
                // errors any more, so stop instead of pinging forever.
                _ = tx.closed() => {
                    debug!("Event stream receiver dropped, stopping WebSocket task");
                    break;
                }
                _ = ping_interval.tick() => {
                    if let Some(sent_at) = ping_sent_at
                        && sent_at.elapsed() > PONG_TIMEOUT
                    {
                        let _ = tx.send(Err(MassiveError::Disconnected {
                            reason: "Pong timeout".into(),
                        }));
                        break;
                    }
                    trace!("Sending ping");
                    if let Err(e) = write.send(Message::Ping(vec![].into())).await {
                        let _ = tx.send(Err(MassiveError::Disconnected {
                            reason: format!("Failed to send ping: {}", e),
                        }));
                        break;
                    }
                    ping_sent_at = Some(Instant::now());
                }
                msg = read.next() => {
                    match msg {
                        Some(Ok(Message::Text(text))) => {
                            // One receive timestamp for every event in the frame.
                            let time_received = Utc::now();
                            match parse_ws_message(&text) {
                                Ok(messages) => {
                                    for ws_msg in messages {
                                        if let Some(event) = Self::ws_message_to_event(
                                            ws_msg,
                                            &instruments,
                                            exchange,
                                            time_received,
                                        ) && tx.send(Ok(event)).is_err() {
                                            // Receiver gone: nothing left to do.
                                            return;
                                        }
                                    }
                                }
                                Err(e) => {
                                    warn!(error = %e, "Failed to parse WebSocket message");
                                }
                            }
                        }
                        Some(Ok(Message::Pong(_))) => {
                            trace!("Received pong");
                            ping_sent_at = None;
                        }
                        Some(Ok(Message::Ping(data))) => {
                            trace!("Received ping, sending pong");
                            if let Err(e) = write.send(Message::Pong(data)).await {
                                let _ = tx.send(Err(MassiveError::Disconnected {
                                    reason: format!("Failed to send pong: {}", e),
                                }));
                                break;
                            }
                        }
                        Some(Ok(Message::Close(frame))) => {
                            let reason = frame
                                .as_ref()
                                .map(|f| f.reason.to_string())
                                .unwrap_or_else(|| "Connection closed".into());
                            let _ = tx.send(Err(MassiveError::Disconnected { reason }));
                            break;
                        }
                        Some(Ok(Message::Binary(_))) => {
                            trace!("Received unexpected binary message");
                        }
                        Some(Ok(Message::Frame(_))) => {
                            // Raw frames are ignored.
                        }
                        Some(Err(e)) => {
                            let _ = tx.send(Err(MassiveError::Disconnected {
                                reason: e.to_string(),
                            }));
                            break;
                        }
                        None => {
                            let _ = tx.send(Err(MassiveError::Disconnected {
                                reason: "WebSocket stream ended".into(),
                            }));
                            break;
                        }
                    }
                }
            }
        }
    });

    // Adapt the channel receiver into a `Stream`; it ends when the task drops `tx`.
    futures::stream::unfold(rx, |mut rx| async move {
        rx.recv().await.map(|item| (item, rx))
    })
}
fn ws_message_to_event(
msg: WsMessage,
instruments: &HashMap<String, K>,
exchange: ExchangeId,
time_received: DateTime<Utc>,
) -> Option<MarketEvent<K, DataKind>> {
match msg {
WsMessage::TradeStocks(trade) | WsMessage::TradeCrypto(trade) => {
let instrument = instruments.get(&trade.symbol)?.clone();
let (time_exchange, public_trade) = trade.into_public_trade();
Some(MarketEvent {
time_exchange,
time_received,
exchange,
instrument,
kind: DataKind::Trade(public_trade),
})
}
WsMessage::QuoteStocks(quote)
| WsMessage::QuoteCrypto(quote)
| WsMessage::QuoteForex(quote) => {
let instrument = instruments.get("e.symbol)?.clone();
let (time_exchange, l1) = quote.into_order_book_l1();
Some(MarketEvent {
time_exchange,
time_received,
exchange,
instrument,
kind: DataKind::OrderBookL1(l1),
})
}
WsMessage::AggSecondStocks(agg)
| WsMessage::AggMinuteStocks(agg)
| WsMessage::AggSecondCrypto(agg)
| WsMessage::AggMinuteCrypto(agg)
| WsMessage::AggSecondForex(agg)
| WsMessage::AggMinuteForex(agg) => {
let instrument = instruments.get(&agg.symbol)?.clone();
let (time_exchange, candle) = agg.into_candle();
Some(MarketEvent {
time_exchange,
time_received,
exchange,
instrument,
kind: DataKind::Candle(candle),
})
}
WsMessage::Status(_) => {
None
}
}
}
}
/// Unit tests for URL/prefix mapping, subscription bookkeeping, construction
/// helpers and the handshake response validators. None of them open a socket.
#[cfg(test)]
mod tests {
    use super::*;

    /// Each market maps to its own endpoint.
    #[test]
    fn test_market_ws_url() {
        assert_eq!(Market::Stocks.ws_url(), "wss://socket.polygon.io/stocks");
        assert_eq!(Market::Crypto.ws_url(), "wss://socket.polygon.io/crypto");
        assert_eq!(Market::Forex.ws_url(), "wss://socket.polygon.io/forex");
        assert_eq!(Market::Options.ws_url(), "wss://socket.polygon.io/options");
    }

    #[test]
    fn test_channel_type_prefix_stocks() {
        assert_eq!(ChannelType::Trade.prefix(Market::Stocks), Some("T"));
        assert_eq!(ChannelType::Quote.prefix(Market::Stocks), Some("Q"));
        assert_eq!(
            ChannelType::AggregateSecond.prefix(Market::Stocks),
            Some("A")
        );
        assert_eq!(
            ChannelType::AggregateMinute.prefix(Market::Stocks),
            Some("AM")
        );
    }

    #[test]
    fn test_channel_type_prefix_crypto() {
        assert_eq!(ChannelType::Trade.prefix(Market::Crypto), Some("XT"));
        assert_eq!(ChannelType::Quote.prefix(Market::Crypto), Some("XQ"));
        assert_eq!(
            ChannelType::AggregateSecond.prefix(Market::Crypto),
            Some("XA")
        );
        assert_eq!(
            ChannelType::AggregateMinute.prefix(Market::Crypto),
            Some("XAM")
        );
    }

    /// Forex has no trade channel.
    #[test]
    fn test_channel_type_prefix_forex() {
        assert_eq!(ChannelType::Trade.prefix(Market::Forex), None);
        assert_eq!(ChannelType::Quote.prefix(Market::Forex), Some("C"));
        assert_eq!(
            ChannelType::AggregateSecond.prefix(Market::Forex),
            Some("CA")
        );
        assert_eq!(
            ChannelType::AggregateMinute.prefix(Market::Forex),
            Some("CAM")
        );
    }

    /// Options use the same prefixes as stocks.
    #[test]
    fn test_channel_type_prefix_options() {
        assert_eq!(ChannelType::Trade.prefix(Market::Options), Some("T"));
        assert_eq!(ChannelType::Quote.prefix(Market::Options), Some("Q"));
        assert_eq!(
            ChannelType::AggregateSecond.prefix(Market::Options),
            Some("A")
        );
        assert_eq!(
            ChannelType::AggregateMinute.prefix(Market::Options),
            Some("AM")
        );
    }

    #[test]
    fn test_channel_for() {
        assert_eq!(
            ChannelType::Trade.channel_for(Market::Crypto, "BTC-USD"),
            Some("XT.BTC-USD".to_string())
        );
        assert_eq!(
            ChannelType::Quote.channel_for(Market::Crypto, "BTC-USD"),
            Some("XQ.BTC-USD".to_string())
        );
        assert_eq!(
            ChannelType::AggregateMinute.channel_for(Market::Crypto, "BTC-USD"),
            Some("XAM.BTC-USD".to_string())
        );
        assert_eq!(
            ChannelType::Trade.channel_for(Market::Stocks, "AAPL"),
            Some("T.AAPL".to_string())
        );
        assert_eq!(
            ChannelType::Quote.channel_for(Market::Stocks, "AAPL"),
            Some("Q.AAPL".to_string())
        );
        assert_eq!(
            ChannelType::AggregateMinute.channel_for(Market::Stocks, "AAPL"),
            Some("AM.AAPL".to_string())
        );
        assert_eq!(
            ChannelType::Trade.channel_for(Market::Forex, "EUR-USD"),
            None
        );
        assert_eq!(
            ChannelType::Quote.channel_for(Market::Forex, "EUR-USD"),
            Some("C.EUR-USD".to_string())
        );
    }

    /// Repeated `subscribe` calls append to the same list.
    #[test]
    fn test_subscribe_accumulates() {
        let instruments: HashMap<String, String> = HashMap::new();
        let mut client =
            MassiveLive::new("test_key", Market::Crypto, ExchangeId::Massive, instruments);
        client.subscribe(&["BTC-USD", "ETH-USD"], ChannelType::Trade);
        client.subscribe(&["BTC-USD"], ChannelType::Quote);
        assert_eq!(client.subscriptions().len(), 3);
        assert!(client.subscriptions().contains(&"XT.BTC-USD".to_string()));
        assert!(client.subscriptions().contains(&"XT.ETH-USD".to_string()));
        assert!(client.subscriptions().contains(&"XQ.BTC-USD".to_string()));
    }

    /// Unsupported channel/market combinations are skipped, not queued.
    #[test]
    fn test_subscribe_skips_unsupported() {
        let instruments: HashMap<String, String> = HashMap::new();
        let mut client =
            MassiveLive::new("test_key", Market::Forex, ExchangeId::Massive, instruments);
        client.subscribe(&["EUR-USD"], ChannelType::Trade);
        client.subscribe(&["EUR-USD"], ChannelType::Quote);
        assert_eq!(client.subscriptions().len(), 1);
        assert!(client.subscriptions().contains(&"C.EUR-USD".to_string()));
    }

    #[test]
    fn test_from_env_missing() {
        temp_env::with_var_unset(ENV_API_KEY, || {
            let result: Result<MassiveLive<String>, _> =
                MassiveLive::from_env(Market::Crypto, ExchangeId::Massive, HashMap::new());
            assert!(matches!(result, Err(MassiveError::EnvVar { .. })));
        });
    }

    #[test]
    fn test_with_ws_url_override() {
        let instruments: HashMap<String, String> = HashMap::new();
        let client = MassiveLive::new("test_key", Market::Crypto, ExchangeId::Massive, instruments)
            .with_ws_url("wss://test.example.com/crypto");
        assert_eq!(client.ws_url, "wss://test.example.com/crypto");
    }

    /// The `Debug` output must never leak the API key.
    #[test]
    fn test_debug_redacts_api_key() {
        let instruments: HashMap<String, String> = HashMap::new();
        let client = MassiveLive::new(
            "super-secret-key",
            Market::Crypto,
            ExchangeId::Massive,
            instruments,
        );
        let rendered = format!("{:?}", client);
        assert!(!rendered.contains("super-secret-key"));
        assert!(rendered.contains("[REDACTED]"));
    }

    #[test]
    fn test_verify_auth_response_success() {
        let msg = Message::Text(
            r#"[{"ev":"status","status":"auth_success","message":"authenticated"}]"#.into(),
        );
        assert!(MassiveLive::<String>::verify_auth_response(&msg).is_ok());
    }

    #[test]
    fn test_verify_auth_response_failed() {
        let msg = Message::Text(
            r#"[{"ev":"status","status":"auth_failed","message":"invalid api key"}]"#.into(),
        );
        let result = MassiveLive::<String>::verify_auth_response(&msg);
        assert!(matches!(result, Err(MassiveError::Auth { .. })));
    }

    #[test]
    fn test_verify_auth_response_close_frame() {
        let msg = Message::Close(Some(tokio_tungstenite::tungstenite::protocol::CloseFrame {
            code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Normal,
            reason: "server shutdown".into(),
        }));
        let result = MassiveLive::<String>::verify_auth_response(&msg);
        assert!(matches!(result, Err(MassiveError::Disconnected { .. })));
    }

    /// A "connected" status arriving where the auth reply is expected is an error.
    #[test]
    fn test_verify_auth_response_connected_status_is_unexpected() {
        let msg = Message::Text(
            r#"[{"ev":"status","status":"connected","message":"Connected Successfully"}]"#.into(),
        );
        let result = MassiveLive::<String>::verify_auth_response(&msg);
        assert!(matches!(result, Err(MassiveError::Auth { .. })));
    }

    #[test]
    fn test_verify_subscription_response_success() {
        let msg = Message::Text(
            r#"[{"ev":"status","status":"success","message":"subscribed to: XT.BTC-USD"}]"#.into(),
        );
        assert!(MassiveLive::<String>::verify_subscription_response(&msg).is_ok());
    }

    #[test]
    fn test_verify_subscription_response_failure() {
        let msg = Message::Text(
            r#"[{"ev":"status","status":"error","message":"invalid symbol"}]"#.into(),
        );
        let result = MassiveLive::<String>::verify_subscription_response(&msg);
        assert!(matches!(result, Err(MassiveError::Disconnected { .. })));
    }

    #[test]
    fn test_verify_subscription_response_close_frame() {
        let msg = Message::Close(None);
        let result = MassiveLive::<String>::verify_subscription_response(&msg);
        assert!(matches!(result, Err(MassiveError::Disconnected { .. })));
    }
}