// File: alpaca_websocket/client.rs

1//! WebSocket client for Alpaca streaming data.
2
3#![allow(missing_docs)]
4
5use crate::{messages::*, streams::*};
6use alpaca_base::types::Quote;
7use alpaca_base::{AlpacaError, Result, auth::Credentials, types::Environment};
8use futures_util::{
9    sink::SinkExt,
10    stream::{SplitSink, SplitStream, StreamExt},
11};
12use serde_json;
13use std::sync::Once;
14use std::time::Duration;
15use tokio::{
16    net::TcpStream,
17    sync::mpsc,
18    time::{interval, sleep},
19};
20use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
21use tracing::{debug, error, info, warn};
22
23static CRYPTO_PROVIDER_INIT: Once = Once::new();
24
25/// Initialize the rustls crypto provider (ring).
26/// This must be called before any TLS connections are made.
27fn init_crypto_provider() {
28    CRYPTO_PROVIDER_INIT.call_once(|| {
29        let _ = rustls::crypto::ring::default_provider().install_default();
30    });
31}
32
33type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
34type WsSink = SplitSink<WsStream, Message>;
35type WsReceiver = SplitStream<WsStream>;
36
37/// WebSocket client for Alpaca API
38#[derive(Debug)]
39pub struct AlpacaWebSocketClient {
40    credentials: Credentials,
41    environment: Environment,
42    url: String,
43}
44
/// Market-data feed; selects which Alpaca streaming endpoint a client connects to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataFeed {
    /// IEX exchange data (single venue; available on Alpaca's free plan)
    Iex,
    /// SIP data (consolidated feed across US exchanges; paid subscription)
    Sip,
    /// Crypto data
    Crypto,
}
55
56impl AlpacaWebSocketClient {
57    /// Create a new WebSocket client for stocks
58    pub fn new(credentials: Credentials, environment: Environment) -> Self {
59        let url = match environment {
60            Environment::Paper => "wss://stream.data.alpaca.markets/v2/iex",
61            Environment::Live => "wss://stream.data.alpaca.markets/v2/sip",
62        };
63
64        Self {
65            credentials,
66            environment,
67            url: url.to_string(),
68        }
69    }
70
71    /// Create a new client from environment variables
72    pub fn from_env(environment: Environment) -> Result<Self> {
73        let credentials = Credentials::from_env()?;
74        Ok(Self::new(credentials, environment))
75    }
76
77    /// Create a WebSocket client for a specific data feed
78    pub fn with_feed(credentials: Credentials, environment: Environment, feed: DataFeed) -> Self {
79        let url = match feed {
80            DataFeed::Iex => "wss://stream.data.alpaca.markets/v2/iex",
81            DataFeed::Sip => "wss://stream.data.alpaca.markets/v2/sip",
82            DataFeed::Crypto => "wss://stream.data.alpaca.markets/v1beta3/crypto/us",
83        };
84
85        Self {
86            credentials,
87            environment,
88            url: url.to_string(),
89        }
90    }
91
92    /// Create a crypto WebSocket client
93    pub fn crypto(credentials: Credentials, environment: Environment) -> Self {
94        Self::with_feed(credentials, environment, DataFeed::Crypto)
95    }
96
97    /// Create a crypto client from environment variables
98    pub fn crypto_from_env(environment: Environment) -> Result<Self> {
99        let credentials = Credentials::from_env()?;
100        Ok(Self::crypto(credentials, environment))
101    }
102
103    /// Create a trading WebSocket client
104    pub fn trading(credentials: Credentials, environment: Environment) -> Self {
105        let url = environment.websocket_url();
106        Self {
107            credentials,
108            environment,
109            url: url.to_string(),
110        }
111    }
112
113    /// Connect to the WebSocket and return a stream of messages
114    pub async fn connect(&self) -> Result<AlpacaStream> {
115        // Initialize crypto provider for TLS
116        init_crypto_provider();
117
118        let (sender, receiver) = mpsc::unbounded_channel();
119        info!("Connecting to WebSocket: {}", self.url);
120        let (ws_stream, _) = connect_async(&self.url).await?;
121        let (mut sink, mut stream) = ws_stream.split();
122
123        // Authenticate
124        self.authenticate(&mut sink).await?;
125
126        // Spawn message handler
127        let credentials = self.credentials.clone();
128        tokio::spawn(async move {
129            Self::handle_messages(&mut stream, sender, credentials).await;
130        });
131
132        Ok(AlpacaStream::new(receiver))
133    }
134
135    /// Connect with automatic reconnection
136    pub async fn connect_with_reconnect(&self, max_retries: u32) -> Result<AlpacaStream> {
137        let mut attempts = 0;
138        let mut delay = Duration::from_secs(1);
139
140        loop {
141            match self.connect().await {
142                Ok(stream) => {
143                    info!("Successfully connected to WebSocket");
144                    return Ok(stream);
145                }
146                Err(e) => {
147                    attempts += 1;
148                    if attempts >= max_retries {
149                        error!("Failed to connect after {} attempts", attempts);
150                        return Err(AlpacaError::WebSocket(format!(
151                            "Connection failed after {} attempts: {}",
152                            attempts, e
153                        )));
154                    }
155
156                    warn!(
157                        "Connection attempt {} failed: {}. Retrying in {:?}",
158                        attempts, e, delay
159                    );
160                    sleep(delay).await;
161                    delay = std::cmp::min(delay * 2, Duration::from_secs(60));
162                }
163            }
164        }
165    }
166
167    /// Subscribe to market data
168    pub async fn subscribe_market_data(
169        &self,
170        subscription: SubscribeMessage,
171    ) -> Result<MarketDataStream> {
172        // Initialize crypto provider for TLS
173        init_crypto_provider();
174
175        let (sender, receiver) = mpsc::unbounded_channel();
176        info!("Connecting to WebSocket: {}", self.url);
177        let (ws_stream, _) = connect_async(&self.url).await?;
178        let (mut sink, mut stream) = ws_stream.split();
179
180        // Wait for "connected" message from server
181        if let Some(Ok(Message::Text(text))) = stream.next().await {
182            debug!("Server: {}", text);
183        }
184
185        // Authenticate
186        self.authenticate(&mut sink).await?;
187
188        // Wait for authentication response
189        if let Some(Ok(Message::Text(text))) = stream.next().await {
190            debug!("Auth response: {}", text);
191        }
192
193        // Send subscription message - Alpaca uses {"action": "subscribe", ...}
194        let sub_msg = serde_json::json!({
195            "action": "subscribe",
196            "trades": subscription.trades.unwrap_or_default(),
197            "quotes": subscription.quotes.unwrap_or_default(),
198            "bars": subscription.bars.unwrap_or_default()
199        });
200        let sub_json = serde_json::to_string(&sub_msg)?;
201        debug!("Sending subscription: {}", sub_json);
202        sink.send(Message::Text(sub_json.into())).await?;
203
204        // Wait for subscription confirmation
205        if let Some(Ok(Message::Text(text))) = stream.next().await {
206            debug!("Subscription response: {}", text);
207        }
208
209        // Spawn message handler that converts to MarketDataUpdate
210        let credentials = self.credentials.clone();
211        tokio::spawn(async move {
212            let _ = credentials; // Keep credentials alive if needed
213            debug!("Handler started, waiting for messages...");
214            while let Some(message) = stream.next().await {
215                match message {
216                    Ok(Message::Text(text)) => {
217                        // Parse array of messages (Alpaca sends arrays)
218                        if let Ok(messages) = serde_json::from_str::<Vec<serde_json::Value>>(&text)
219                        {
220                            for msg_value in messages {
221                                if let Some(msg_type) = msg_value.get("T").and_then(|t| t.as_str())
222                                {
223                                    let update = match msg_type {
224                                        "t" => {
225                                            // Trade message
226                                            if let Ok(trade_msg) =
227                                                serde_json::from_value::<TradeMessage>(
228                                                    msg_value.clone(),
229                                                )
230                                            {
231                                                Some(MarketDataUpdate::Trade {
232                                                    symbol: trade_msg.symbol.clone(),
233                                                    trade: trade_msg.into(),
234                                                })
235                                            } else {
236                                                None
237                                            }
238                                        }
239                                        "q" => {
240                                            // Quote message - try crypto format first
241                                            if let Ok(quote_msg) =
242                                                serde_json::from_value::<CryptoQuoteMessage>(
243                                                    msg_value.clone(),
244                                                )
245                                            {
246                                                Some(MarketDataUpdate::Quote {
247                                                    symbol: quote_msg.symbol.clone(),
248                                                    quote: Quote {
249                                                        timestamp: quote_msg.timestamp,
250                                                        timeframe: "real-time".to_string(),
251                                                        bid_price: quote_msg.bid_price,
252                                                        bid_size: quote_msg.bid_size as u32,
253                                                        ask_price: quote_msg.ask_price,
254                                                        ask_size: quote_msg.ask_size as u32,
255                                                        bid_exchange: String::new(),
256                                                        ask_exchange: String::new(),
257                                                    },
258                                                })
259                                            } else if let Ok(quote_msg) =
260                                                serde_json::from_value::<QuoteMessage>(
261                                                    msg_value.clone(),
262                                                )
263                                            {
264                                                Some(MarketDataUpdate::Quote {
265                                                    symbol: quote_msg.symbol.clone(),
266                                                    quote: quote_msg.into(),
267                                                })
268                                            } else {
269                                                None
270                                            }
271                                        }
272                                        "b" => {
273                                            // Bar message
274                                            if let Ok(bar_msg) = serde_json::from_value::<BarMessage>(
275                                                msg_value.clone(),
276                                            ) {
277                                                Some(MarketDataUpdate::Bar {
278                                                    symbol: bar_msg.symbol.clone(),
279                                                    bar: bar_msg.into(),
280                                                })
281                                            } else {
282                                                None
283                                            }
284                                        }
285                                        _ => {
286                                            debug!("Ignoring message type: {}", msg_type);
287                                            None
288                                        }
289                                    };
290
291                                    if let Some(u) = update
292                                        && sender.send(u).is_err()
293                                    {
294                                        debug!("Channel closed");
295                                        break;
296                                    }
297                                }
298                            }
299                        }
300                    }
301                    Ok(Message::Close(_)) => {
302                        info!("WebSocket connection closed");
303                        break;
304                    }
305                    Err(e) => {
306                        error!("WebSocket error: {}", e);
307                        break;
308                    }
309                    _ => {}
310                }
311            }
312            info!("Market data handler exiting");
313        });
314
315        Ok(MarketDataStream::new(receiver))
316    }
317
318    /// Subscribe to trading updates
319    pub async fn subscribe_trading_updates(&self) -> Result<TradingStream> {
320        let stream = self.connect().await?;
321        let (sender, receiver) = mpsc::unbounded_channel();
322
323        tokio::spawn(async move {
324            let mut trading_stream = stream.trading_updates();
325            while let Some(update) = trading_stream.next().await {
326                if sender.send(update).is_err() {
327                    break;
328                }
329            }
330        });
331
332        Ok(TradingStream::new(receiver))
333    }
334
335    /// Authenticate with the WebSocket
336    async fn authenticate(&self, sink: &mut WsSink) -> Result<()> {
337        // Alpaca uses {"action": "auth", "key": "...", "secret": "..."}
338        let auth_msg = serde_json::json!({
339            "action": "auth",
340            "key": self.credentials.api_key,
341            "secret": self.credentials.secret_key
342        });
343
344        let auth_json = serde_json::to_string(&auth_msg)?;
345        debug!("Sending auth: {}", auth_json);
346        sink.send(Message::Text(auth_json.into())).await?;
347
348        debug!("Sent authentication message");
349        Ok(())
350    }
351
352    /// Handle incoming WebSocket messages
353    async fn handle_messages(
354        stream: &mut WsReceiver,
355        sender: mpsc::UnboundedSender<WebSocketMessage>,
356        _credentials: Credentials,
357    ) {
358        while let Some(message) = stream.next().await {
359            match message {
360                Ok(Message::Text(text)) => match Self::parse_message(&text) {
361                    Ok(msg) => {
362                        debug!("Received message: {:?}", msg);
363                        if sender.send(msg).is_err() {
364                            warn!("Failed to send message to channel");
365                            break;
366                        }
367                    }
368                    Err(e) => {
369                        warn!("Failed to parse message: {} - Raw: {}", e, text);
370                    }
371                },
372                Ok(Message::Close(_)) => {
373                    info!("WebSocket connection closed");
374                    break;
375                }
376                Ok(Message::Ping(_data)) => {
377                    debug!("Received ping, sending pong");
378                    // Note: tokio-tungstenite handles pong automatically
379                }
380                Ok(Message::Pong(_)) => {
381                    debug!("Received pong");
382                }
383                Ok(Message::Binary(_)) => {
384                    warn!("Received unexpected binary message");
385                }
386                Ok(Message::Frame(_)) => {
387                    debug!("Received frame message");
388                }
389                Err(e) => {
390                    error!("WebSocket error: {}", e);
391                    break;
392                }
393            }
394        }
395
396        info!("Message handler exiting");
397    }
398
399    /// Parse incoming WebSocket message
400    fn parse_message(text: &str) -> Result<WebSocketMessage> {
401        // Handle array of messages
402        if text.starts_with('[') {
403            let messages: Vec<serde_json::Value> = serde_json::from_str(text)?;
404            if let Some(first_msg) = messages.first() {
405                return serde_json::from_value(first_msg.clone())
406                    .map_err(|e| AlpacaError::Json(e.to_string()));
407            }
408        }
409
410        // Handle single message
411        serde_json::from_str(text).map_err(|e| AlpacaError::Json(e.to_string()))
412    }
413
414    /// Send subscription message
415    pub async fn send_subscription(&self, subscription: SubscribeMessage) -> Result<()> {
416        // This would need to be implemented with a persistent connection
417        // For now, this is a placeholder
418        debug!("Would send subscription: {:?}", subscription);
419        Ok(())
420    }
421
422    /// Send unsubscription message
423    pub async fn send_unsubscription(&self, unsubscription: UnsubscribeMessage) -> Result<()> {
424        // This would need to be implemented with a persistent connection
425        // For now, this is a placeholder
426        debug!("Would send unsubscription: {:?}", unsubscription);
427        Ok(())
428    }
429
430    /// Get the WebSocket URL
431    pub fn url(&self) -> &str {
432        &self.url
433    }
434
435    /// Get the environment
436    pub fn environment(&self) -> &Environment {
437        &self.environment
438    }
439}
440
441/// WebSocket connection manager with automatic reconnection
442pub struct WebSocketManager {
443    client: AlpacaWebSocketClient,
444    max_retries: u32,
445    heartbeat_interval: Duration,
446}
447
448impl WebSocketManager {
449    /// Create a new WebSocket manager
450    pub fn new(client: AlpacaWebSocketClient) -> Self {
451        Self {
452            client,
453            max_retries: 5,
454            heartbeat_interval: Duration::from_secs(30),
455        }
456    }
457
458    /// Set maximum retry attempts
459    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
460        self.max_retries = max_retries;
461        self
462    }
463
464    /// Set heartbeat interval
465    pub fn with_heartbeat_interval(mut self, interval: Duration) -> Self {
466        self.heartbeat_interval = interval;
467        self
468    }
469
470    /// Start the managed connection
471    pub async fn start(&self) -> Result<AlpacaStream> {
472        let stream = self.client.connect_with_reconnect(self.max_retries).await?;
473
474        // Start heartbeat
475        self.start_heartbeat().await;
476
477        Ok(stream)
478    }
479
480    /// Start heartbeat to keep connection alive
481    async fn start_heartbeat(&self) {
482        let mut interval = interval(self.heartbeat_interval);
483
484        tokio::spawn(async move {
485            loop {
486                interval.tick().await;
487                debug!("Heartbeat tick");
488                // In a real implementation, you might send a ping message here
489            }
490        });
491    }
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use alpaca_base::types::Environment;

    /// Throwaway credentials; none of these tests open a connection.
    fn test_credentials() -> Credentials {
        Credentials::new("test_key".to_string(), "test_secret".to_string())
    }

    #[test]
    fn test_client_creation() {
        let client = AlpacaWebSocketClient::new(test_credentials(), Environment::Paper);

        assert!(client.url().contains("stream.data.alpaca.markets"));
    }

    #[test]
    fn test_paper_and_live_select_iex_and_sip() {
        let paper = AlpacaWebSocketClient::new(test_credentials(), Environment::Paper);
        let live = AlpacaWebSocketClient::new(test_credentials(), Environment::Live);

        assert_eq!(paper.url(), "wss://stream.data.alpaca.markets/v2/iex");
        assert_eq!(live.url(), "wss://stream.data.alpaca.markets/v2/sip");
    }

    #[test]
    fn test_with_feed_urls() {
        let crypto = AlpacaWebSocketClient::crypto(test_credentials(), Environment::Paper);
        assert_eq!(
            crypto.url(),
            "wss://stream.data.alpaca.markets/v1beta3/crypto/us"
        );

        let sip = AlpacaWebSocketClient::with_feed(
            test_credentials(),
            Environment::Paper,
            DataFeed::Sip,
        );
        assert_eq!(sip.url(), "wss://stream.data.alpaca.markets/v2/sip");
    }

    #[test]
    fn test_trading_client() {
        let client = AlpacaWebSocketClient::trading(test_credentials(), Environment::Paper);

        assert!(client.url().contains("paper-api.alpaca.markets"));
    }

    #[test]
    fn test_parse_message() {
        let json = r#"{"T":"success","msg":"authenticated"}"#;
        let result = AlpacaWebSocketClient::parse_message(json);
        assert!(result.is_ok());
    }

    #[test]
    fn test_parse_message_array() {
        // Alpaca batches control messages into JSON arrays.
        let json = r#"[{"T":"success","msg":"authenticated"}]"#;
        let result = AlpacaWebSocketClient::parse_message(json);
        assert!(result.is_ok());
    }
}
519        assert!(result.is_ok());
520    }
521}