use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use futures_util::{SinkExt, Stream, StreamExt};
use serde_json::{json, Value};
use tokio::sync::{broadcast, RwLock};
use tokio::time::timeout;
use tokio_tungstenite::{connect_async, tungstenite::Message};
use crate::core::types::*;
use crate::core::traits::WebSocketConnector;
use super::auth::AlpacaAuth;
/// A subscribable Alpaca stream channel together with the symbols it targets.
///
/// Every variant maps to the JSON key used in `subscribe` / `unsubscribe` frames;
/// see [`AlpacaChannel::to_key_and_symbols`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AlpacaChannel {
    /// Key `"bars"`.
    Bars(Vec<String>),
    /// Key `"quotes"`.
    Quotes(Vec<String>),
    /// Key `"trades"`.
    Trades(Vec<String>),
    /// Key `"statuses"`.
    Statuses(Vec<String>),
    /// Key `"lulds"`.
    Lulds(Vec<String>),
    /// Key `"trade_updates"`; not symbol-scoped, always expands to `["*"]`.
    TradeUpdates,
    /// Key `"news"`.
    News(Vec<String>),
}

impl AlpacaChannel {
    /// Return the frame key for this channel and an owned copy of its symbol list.
    pub fn to_key_and_symbols(&self) -> (&'static str, Vec<String>) {
        let (key, symbols): (&'static str, &[String]) = match self {
            Self::Bars(list) => ("bars", list.as_slice()),
            Self::Quotes(list) => ("quotes", list.as_slice()),
            Self::Trades(list) => ("trades", list.as_slice()),
            Self::Statuses(list) => ("statuses", list.as_slice()),
            Self::Lulds(list) => ("lulds", list.as_slice()),
            Self::News(list) => ("news", list.as_slice()),
            // No symbol list of its own: use the wildcard.
            Self::TradeUpdates => return ("trade_updates", vec![String::from("*")]),
        };
        (key, symbols.to_vec())
    }

    /// [`AlpacaChannel::Bars`] for one symbol.
    pub fn bars_for(symbol: impl Into<String>) -> Self {
        Self::Bars(Self::one_symbol(symbol))
    }

    /// [`AlpacaChannel::Quotes`] for one symbol.
    pub fn quotes_for(symbol: impl Into<String>) -> Self {
        Self::Quotes(Self::one_symbol(symbol))
    }

    /// [`AlpacaChannel::Trades`] for one symbol.
    pub fn trades_for(symbol: impl Into<String>) -> Self {
        Self::Trades(Self::one_symbol(symbol))
    }

    /// Wrap a single symbol into the list form every symbol-scoped variant stores.
    fn one_symbol(symbol: impl Into<String>) -> Vec<String> {
        let owned: String = symbol.into();
        vec![owned]
    }
}
/// WebSocket connector for Alpaca streaming endpoints.
///
/// `connect` spawns a background reader task that receives clones of the `status` and
/// `broadcast_tx` handles, so that task can keep updating state after `connect` returns.
pub struct AlpacaWebSocket {
/// Credentials sent in the `auth` handshake message.
auth: AlpacaAuth,
/// Endpoint URL chosen by the constructor.
ws_url: String,
/// Current connection state, shared with the reader task.
status: Arc<RwLock<ConnectionStatus>>,
/// Subscription requests recorded locally by `subscribe` / `unsubscribe`.
subscriptions: Arc<RwLock<Vec<SubscriptionRequest>>>,
/// Sender for parsed stream events; `None` before the first successful connect and after
/// `disconnect`. `event_stream` subscribes to it. A `std::sync::Mutex` is used (presumably
/// because `event_stream` is synchronous); it is never held across an `.await` in this file.
broadcast_tx: Arc<std::sync::Mutex<Option<broadcast::Sender<WebSocketResult<StreamEvent>>>>>,
}
impl AlpacaWebSocket {
/// Connector for the `v2/iex` stock market-data stream.
pub fn new(auth: AlpacaAuth) -> Self {
    Self::with_url(auth, "wss://stream.data.alpaca.markets/v2/iex")
}

/// Connector for the `v2/sip` stock market-data stream.
pub fn live(auth: AlpacaAuth) -> Self {
    Self::with_url(auth, "wss://stream.data.alpaca.markets/v2/sip")
}

/// Connector for the `v2/test` market-data stream.
pub fn test(auth: AlpacaAuth) -> Self {
    Self::with_url(auth, "wss://stream.data.alpaca.markets/v2/test")
}

/// Connector for `wss://api.alpaca.markets/stream`.
///
/// NOTE(review): `connect` runs the market-data handshake. It waits for a
/// `{"T":"success","msg":"connected"}` greeting and an `authenticated` reply. Confirm that
/// this endpoint speaks that protocol; if it does not, connecting will time out.
pub fn trading(auth: AlpacaAuth) -> Self {
    Self::with_url(auth, "wss://api.alpaca.markets/stream")
}

/// Connector for the `v1beta3/crypto/us` market-data stream.
pub fn crypto(auth: AlpacaAuth) -> Self {
    Self::with_url(auth, "wss://stream.data.alpaca.markets/v1beta3/crypto/us")
}

/// Connector for an arbitrary stream URL.
///
/// Use it for endpoints without a dedicated constructor, e.g. another data feed or a
/// sandbox host. The connector starts `Disconnected`, with no recorded subscriptions and no
/// event channel.
pub fn with_url(auth: AlpacaAuth, url: impl Into<String>) -> Self {
    Self {
        auth,
        ws_url: url.into(),
        status: Arc::new(RwLock::new(ConnectionStatus::Disconnected)),
        subscriptions: Arc::new(RwLock::new(Vec::new())),
        broadcast_tx: Arc::new(std::sync::Mutex::new(None)),
    }
}
pub fn build_subscribe_message(channels: &[AlpacaChannel]) -> serde_json::Value {
let mut msg = serde_json::json!({ "action": "subscribe" });
for channel in channels {
let (key, symbols) = channel.to_key_and_symbols();
msg[key] = serde_json::Value::Array(
symbols.iter().map(|s| serde_json::Value::String(s.clone())).collect(),
);
}
msg
}
pub fn build_unsubscribe_message(channels: &[AlpacaChannel]) -> serde_json::Value {
let mut msg = serde_json::json!({ "action": "unsubscribe" });
for channel in channels {
let (key, symbols) = channel.to_key_and_symbols();
msg[key] = serde_json::Value::Array(
symbols.iter().map(|s| serde_json::Value::String(s.clone())).collect(),
);
}
msg
}
async fn do_connect(&self) -> WebSocketResult<()> {
let (ws_stream, _response) = timeout(Duration::from_secs(15), connect_async(&self.ws_url))
.await
.map_err(|_| WebSocketError::Timeout)?
.map_err(|e| WebSocketError::ConnectionError(format!("WS connect failed: {}", e)))?;
let (mut write, mut read) = ws_stream.split();
Self::wait_for_message(&mut read, "connected", Duration::from_secs(10)).await?;
let key = self.auth.api_key_id.as_deref().unwrap_or_default();
let secret = self.auth.api_secret_key.as_deref().unwrap_or_default();
let auth_msg = json!({
"action": "auth",
"key": key,
"secret": secret
});
write
.send(Message::Text(auth_msg.to_string()))
.await
.map_err(|e| WebSocketError::Auth(format!("Failed to send auth: {}", e)))?;
Self::wait_for_message(&mut read, "authenticated", Duration::from_secs(10)).await?;
let (tx, _) = broadcast::channel::<WebSocketResult<StreamEvent>>(512);
{
let mut guard = self.broadcast_tx.lock().unwrap();
*guard = Some(tx.clone());
}
let broadcast_tx = self.broadcast_tx.clone();
let status = self.status.clone();
drop(write);
tokio::spawn(async move {
while let Some(msg_result) = read.next().await {
match msg_result {
Ok(Message::Text(text)) => {
if let Ok(value) = serde_json::from_str::<Value>(&text) {
let items: Vec<Value> = if let Some(arr) = value.as_array() {
arr.clone()
} else {
vec![value]
};
for raw in &items {
if let Some(event) = Self::parse_event(raw) {
if let Some(tx) = broadcast_tx.lock().unwrap().as_ref() {
let _ = tx.send(Ok(event));
}
}
}
}
}
Ok(Message::Close(_)) | Err(_) => {
*status.write().await = ConnectionStatus::Disconnected;
break;
}
_ => {}
}
}
});
Ok(())
}
async fn wait_for_message<S>(
read: &mut S,
expected_msg: &str,
dur: Duration,
) -> WebSocketResult<()>
where
S: Stream<Item = Result<Message, tokio_tungstenite::tungstenite::Error>> + Unpin,
{
let result = timeout(dur, async {
while let Some(msg_result) = read.next().await {
match msg_result {
Ok(Message::Text(text)) => {
if let Ok(value) = serde_json::from_str::<Value>(&text) {
let items: Vec<&Value> = if let Some(arr) = value.as_array() {
arr.iter().collect()
} else {
vec![&value]
};
for item in items {
let t = item.get("T").and_then(|v| v.as_str()).unwrap_or_default();
let msg = item.get("msg").and_then(|v| v.as_str()).unwrap_or_default();
if t == "error" {
return Err(WebSocketError::Auth(format!(
"Alpaca WS error: {}",
item.get("msg").and_then(|v| v.as_str()).unwrap_or("unknown")
)));
}
if t == "success" && msg == expected_msg {
return Ok(());
}
}
}
}
Ok(Message::Close(_)) => {
return Err(WebSocketError::ConnectionError(
"Connection closed before receiving expected message".to_string(),
));
}
Err(e) => {
return Err(WebSocketError::ConnectionError(format!(
"WS read error: {}", e
)));
}
_ => {}
}
}
Err(WebSocketError::ConnectionError(
"WebSocket stream ended unexpectedly".to_string(),
))
})
.await;
match result {
Ok(inner) => inner,
Err(_) => Err(WebSocketError::Timeout),
}
}
fn parse_event(value: &Value) -> Option<StreamEvent> {
let msg_type = value.get("T").and_then(|v| v.as_str())?;
match msg_type {
"t" => {
let symbol = value.get("S").and_then(|v| v.as_str()).unwrap_or_default();
let price = value.get("p").and_then(|v| v.as_f64()).unwrap_or_default();
let size = value.get("s").and_then(|v| v.as_f64()).unwrap_or_default();
let taker_side = value.get("tks").and_then(|v| v.as_str()).unwrap_or("B");
let trade = PublicTrade {
id: value
.get("i")
.and_then(|v| v.as_u64())
.map(|n| n.to_string())
.unwrap_or_default(),
symbol: symbol.to_string(),
price,
quantity: size,
side: if taker_side == "S" { TradeSide::Sell } else { TradeSide::Buy },
timestamp: crate::core::utils::timestamp_millis() as i64,
};
Some(StreamEvent::Trade(trade))
}
"q" => {
let symbol = value.get("S").and_then(|v| v.as_str()).unwrap_or_default();
let bid_price = value.get("bp").and_then(|v| v.as_f64()).unwrap_or_default();
let bid_size = value.get("bs").and_then(|v| v.as_f64()).unwrap_or_default();
let ask_price = value.get("ap").and_then(|v| v.as_f64()).unwrap_or_default();
let ask_size = value.get("as").and_then(|v| v.as_f64()).unwrap_or_default();
let ticker = Ticker {
symbol: symbol.to_string(),
last_price: (bid_price + ask_price) / 2.0,
bid_price: Some(bid_price),
ask_price: Some(ask_price),
high_24h: None,
low_24h: None,
volume_24h: Some(bid_size + ask_size),
quote_volume_24h: None,
price_change_24h: None,
price_change_percent_24h: None,
timestamp: crate::core::utils::timestamp_millis() as i64,
};
Some(StreamEvent::Ticker(ticker))
}
"b" => {
let symbol = value.get("S").and_then(|v| v.as_str()).unwrap_or_default();
let open = value.get("o").and_then(|v| v.as_f64()).unwrap_or_default();
let high = value.get("h").and_then(|v| v.as_f64()).unwrap_or_default();
let low = value.get("l").and_then(|v| v.as_f64()).unwrap_or_default();
let close = value.get("c").and_then(|v| v.as_f64()).unwrap_or_default();
let volume = value.get("v").and_then(|v| v.as_f64()).unwrap_or_default();
let _ = symbol; let bar = Kline {
open,
high,
low,
close,
volume,
quote_volume: None,
open_time: crate::core::utils::timestamp_millis() as i64,
close_time: Some(crate::core::utils::timestamp_millis() as i64),
trades: None,
};
Some(StreamEvent::Kline(bar))
}
"s" | "l" | "tu" => None,
_ => None,
}
}
}
#[async_trait]
impl WebSocketConnector for AlpacaWebSocket {
/// Connect and authenticate, moving the status `Connecting` -> `Connected`, or back to
/// `Disconnected` if the attempt fails. `_account_type` is ignored: the endpoint is fixed by
/// the constructor.
async fn connect(&self, _account_type: AccountType) -> WebSocketResult<()> {
    *self.status.write().await = ConnectionStatus::Connecting;
    let outcome = self.do_connect().await;
    let settled = match &outcome {
        Ok(()) => ConnectionStatus::Connected,
        Err(_) => ConnectionStatus::Disconnected,
    };
    *self.status.write().await = settled;
    outcome
}
/// Mark the connector `Disconnected`, drop the published event sender and forget all
/// recorded subscriptions. Always succeeds.
///
/// This method does not itself close the socket. Dropping the sender lets streams obtained
/// from `event_stream` terminate once no other sender for that channel remains.
async fn disconnect(&self) -> WebSocketResult<()> {
    *self.status.write().await = ConnectionStatus::Disconnected;
    drop(self.broadcast_tx.lock().unwrap().take());
    self.subscriptions.write().await.clear();
    Ok(())
}
/// Non-blocking snapshot of the connection state.
///
/// NOTE(review): if the status lock cannot be taken without waiting (e.g. a writer holds
/// it), this reports `Disconnected` even when the real state differs.
fn connection_status(&self) -> ConnectionStatus {
    self.status
        .try_read()
        .map(|current| *current)
        .unwrap_or(ConnectionStatus::Disconnected)
}
/// Record `request` as an active subscription.
///
/// The call is idempotent: a request equal to one already recorded is not added again, so
/// `active_subscriptions` never lists duplicates. Only local bookkeeping happens here; no
/// subscribe frame is sent to the server.
///
/// # Errors
/// [`WebSocketError::NotConnected`] unless the status is `Connected`.
async fn subscribe(&self, request: SubscriptionRequest) -> WebSocketResult<()> {
    if *self.status.read().await != ConnectionStatus::Connected {
        return Err(WebSocketError::NotConnected);
    }
    let mut subscriptions = self.subscriptions.write().await;
    if !subscriptions.contains(&request) {
        subscriptions.push(request);
    }
    Ok(())
}
/// Forget every recorded subscription equal to `request`. Local bookkeeping only; it works
/// whether or not the connector is connected and always succeeds.
async fn unsubscribe(&self, request: SubscriptionRequest) -> WebSocketResult<()> {
    let mut recorded = self.subscriptions.write().await;
    recorded.retain(|existing| *existing != request);
    Ok(())
}
/// Subscribe to the events published by the current connection.
///
/// Without a published sender (before a successful `connect`, or after `disconnect`), this
/// returns a stream that ends immediately. A receiver that falls behind yields a
/// `ConnectionError("Event stream lagged")` item and then continues.
fn event_stream(&self) -> Pin<Box<dyn Stream<Item = WebSocketResult<StreamEvent>> + Send>> {
    let receiver = match self.broadcast_tx.lock().unwrap().as_ref() {
        Some(sender) => sender.subscribe(),
        None => return Box::pin(futures_util::stream::empty()),
    };
    let events = tokio_stream::wrappers::BroadcastStream::new(receiver).filter_map(
        |delivery| async move {
            match delivery {
                Ok(event) => Some(event),
                Err(tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(_)) => {
                    Some(Err(WebSocketError::ConnectionError(
                        "Event stream lagged".to_string(),
                    )))
                }
            }
        },
    );
    Box::pin(events)
}
/// Snapshot of the recorded subscriptions.
///
/// NOTE(review): returns an empty list if the lock cannot be taken without waiting.
fn active_subscriptions(&self) -> Vec<SubscriptionRequest> {
    self.subscriptions
        .try_read()
        .map(|recorded| recorded.to_vec())
        .unwrap_or_default()
}
}
/// Helpers for Alpaca channels that `WebSocketConnector::subscribe` does not cover.
///
/// NOTE(review): none of these transmit anything. Each builds the frame it would send,
/// discards it, and returns `Ok(())`. Treat them as placeholders until the connector keeps
/// a way to write to the socket.
impl AlpacaWebSocket {
    /// Placeholder news subscription.
    ///
    /// The `subscribe` frame for `symbols` is built but not sent. A stand-in `NEWS/USD`
    /// ticker request is recorded locally and will appear in `active_subscriptions`.
    pub async fn subscribe_news(&self, symbols: Vec<String>) -> WebSocketResult<()> {
        let _unsent = Self::build_subscribe_message(&[AlpacaChannel::News(symbols)]);
        let stand_in = SubscriptionRequest::ticker(Symbol::new("NEWS", "USD"));
        self.subscriptions.write().await.push(stand_in);
        Ok(())
    }

    /// Placeholder status subscription; the `statuses` frame is built but not sent.
    pub async fn subscribe_status(&self, symbols: Vec<String>) -> WebSocketResult<()> {
        let _unsent = Self::build_subscribe_message(&[AlpacaChannel::Statuses(symbols)]);
        Ok(())
    }

    /// Placeholder LULD subscription; the `lulds` frame is built but not sent.
    pub async fn subscribe_luld(&self, symbols: Vec<String>) -> WebSocketResult<()> {
        let _unsent = Self::build_subscribe_message(&[AlpacaChannel::Lulds(symbols)]);
        Ok(())
    }

    /// Placeholder trade-updates subscription; the `listen` frame is built but not sent.
    pub async fn subscribe_trade_updates(&self) -> WebSocketResult<()> {
        let _unsent = serde_json::json!({
            "action": "listen",
            "data": { "streams": ["trade_updates"] }
        });
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dummy credentials; none of these tests touch the network.
    fn test_auth() -> AlpacaAuth {
        AlpacaAuth::new("test_key", "test_secret")
    }

    #[tokio::test]
    async fn test_create_websocket() {
        let ws = AlpacaWebSocket::new(test_auth());
        assert_eq!(ws.connection_status(), ConnectionStatus::Disconnected);
        assert_eq!(ws.active_subscriptions().len(), 0);
    }

    #[tokio::test]
    async fn test_subscribe_before_connect() {
        let ws = AlpacaWebSocket::new(test_auth());
        let request = SubscriptionRequest::ticker(Symbol::new("AAPL", "USD"));
        let result = ws.subscribe(request).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), WebSocketError::NotConnected));
    }

    #[tokio::test]
    async fn test_event_stream_is_empty_before_connect() {
        let ws = AlpacaWebSocket::new(test_auth());
        assert!(ws.event_stream().next().await.is_none());
    }

    #[tokio::test]
    async fn test_disconnect_without_connect() {
        let ws = AlpacaWebSocket::new(test_auth());
        assert!(ws.disconnect().await.is_ok());
        assert_eq!(ws.connection_status(), ConnectionStatus::Disconnected);
        assert!(ws.active_subscriptions().is_empty());
    }

    #[test]
    fn test_parse_trade_event() {
        let raw = serde_json::json!({
            "T": "t",
            "S": "AAPL",
            "p": 185.50,
            "s": 100.0,
            "tks": "B",
            "i": 12345678
        });
        // `match` instead of `if let`, so a wrong variant fails the test.
        match AlpacaWebSocket::parse_event(&raw) {
            Some(StreamEvent::Trade(trade)) => {
                assert_eq!(trade.symbol, "AAPL");
                assert_eq!(trade.price, 185.50);
                assert_eq!(trade.quantity, 100.0);
                assert!(matches!(trade.side, TradeSide::Buy));
            }
            _ => panic!("expected StreamEvent::Trade"),
        }
    }

    #[test]
    fn test_parse_quote_event() {
        let raw = serde_json::json!({
            "T": "q",
            "S": "AAPL",
            "bp": 100.0,
            "bs": 2.0,
            "ap": 102.0,
            "as": 3.0
        });
        match AlpacaWebSocket::parse_event(&raw) {
            Some(StreamEvent::Ticker(ticker)) => {
                assert_eq!(ticker.symbol, "AAPL");
                assert_eq!(ticker.bid_price, Some(100.0));
                assert_eq!(ticker.ask_price, Some(102.0));
                assert_eq!(ticker.last_price, 101.0);
            }
            _ => panic!("expected StreamEvent::Ticker"),
        }
    }

    #[test]
    fn test_parse_ignores_unsupported_messages() {
        let control = serde_json::json!({"T": "subscription", "trades": ["AAPL"]});
        assert!(AlpacaWebSocket::parse_event(&control).is_none());
        let status = serde_json::json!({"T": "s", "S": "AAPL"});
        assert!(AlpacaWebSocket::parse_event(&status).is_none());
        let untagged = serde_json::json!({"S": "AAPL", "p": 1.0});
        assert!(AlpacaWebSocket::parse_event(&untagged).is_none());
    }

    #[test]
    fn test_build_messages() {
        let msg = AlpacaWebSocket::build_subscribe_message(&[
            AlpacaChannel::bars_for("AAPL"),
            AlpacaChannel::quotes_for("MSFT"),
        ]);
        assert_eq!(
            msg,
            serde_json::json!({"action": "subscribe", "bars": ["AAPL"], "quotes": ["MSFT"]})
        );
        let msg = AlpacaWebSocket::build_unsubscribe_message(&[AlpacaChannel::TradeUpdates]);
        assert_eq!(
            msg,
            serde_json::json!({"action": "unsubscribe", "trade_updates": ["*"]})
        );
    }
}