alpaca_websocket/
client.rs

1//! WebSocket client for Alpaca streaming data.
2
3#![allow(missing_docs)]
4
5use crate::{messages::*, streams::*};
6use alpaca_base::{AlpacaError, Result, auth::Credentials, types::Environment};
7use futures_util::{
8    sink::SinkExt,
9    stream::{SplitSink, SplitStream, StreamExt},
10};
11use serde_json;
12use std::time::Duration;
13use tokio::{
14    net::TcpStream,
15    sync::mpsc,
16    time::{interval, sleep},
17};
18use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
19use tracing::{debug, error, info, warn};
20
/// Full-duplex WebSocket connection produced by `connect_async` (plain TCP or TLS).
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Write half of a split [`WsStream`]; used to send outgoing frames such as the auth message.
type WsSink = SplitSink<WsStream, Message>;
/// Read half of a split [`WsStream`]; drained by the spawned message-handler task.
type WsReceiver = SplitStream<WsStream>;
24
/// WebSocket client for Alpaca API
///
/// Holds the credentials used to authenticate, the target [`Environment`], and the
/// stream URL chosen for it. The client does not keep a connection open itself: every
/// call to `connect` opens a fresh socket.
// NOTE(review): the derived `Debug` includes `credentials`; confirm that `Credentials`'
// own `Debug` impl redacts the secret before this type is ever logged.
#[derive(Debug)]
pub struct AlpacaWebSocketClient {
    /// API key/secret sent in the authentication message right after connecting.
    credentials: Credentials,
    /// Environment this client was constructed for.
    environment: Environment,
    /// WebSocket endpoint passed to `connect_async`.
    url: String,
}
32
33impl AlpacaWebSocketClient {
34    /// Create a new WebSocket client
35    pub fn new(credentials: Credentials, environment: Environment) -> Self {
36        let url = match environment {
37            Environment::Paper => "wss://stream.data.alpaca.markets/v2/iex",
38            Environment::Live => "wss://stream.data.alpaca.markets/v2/sip",
39        };
40
41        Self {
42            credentials,
43            environment,
44            url: url.to_string(),
45        }
46    }
47
48    /// Create a new client from environment variables
49    pub fn from_env(environment: Environment) -> Result<Self> {
50        let credentials = Credentials::from_env()?;
51        Ok(Self::new(credentials, environment))
52    }
53
54    /// Create a trading WebSocket client
55    pub fn trading(credentials: Credentials, environment: Environment) -> Self {
56        let url = environment.websocket_url();
57        Self {
58            credentials,
59            environment,
60            url: url.to_string(),
61        }
62    }
63
64    /// Connect to the WebSocket and return a stream of messages
65    pub async fn connect(&self) -> Result<AlpacaStream> {
66        let (sender, receiver) = mpsc::unbounded_channel();
67        info!("Connecting to WebSocket: {}", self.url);
68        let (ws_stream, _) = connect_async(&self.url).await?;
69        let (mut sink, mut stream) = ws_stream.split();
70
71        // Authenticate
72        self.authenticate(&mut sink).await?;
73
74        // Spawn message handler
75        let credentials = self.credentials.clone();
76        tokio::spawn(async move {
77            Self::handle_messages(&mut stream, sender, credentials).await;
78        });
79
80        Ok(AlpacaStream::new(receiver))
81    }
82
83    /// Connect with automatic reconnection
84    pub async fn connect_with_reconnect(&self, max_retries: u32) -> Result<AlpacaStream> {
85        let mut attempts = 0;
86        let mut delay = Duration::from_secs(1);
87
88        loop {
89            match self.connect().await {
90                Ok(stream) => {
91                    info!("Successfully connected to WebSocket");
92                    return Ok(stream);
93                }
94                Err(e) => {
95                    attempts += 1;
96                    if attempts >= max_retries {
97                        error!("Failed to connect after {} attempts", attempts);
98                        return Err(AlpacaError::WebSocket(format!(
99                            "Connection failed after {} attempts: {}",
100                            attempts, e
101                        )));
102                    }
103
104                    warn!(
105                        "Connection attempt {} failed: {}. Retrying in {:?}",
106                        attempts, e, delay
107                    );
108                    sleep(delay).await;
109                    delay = std::cmp::min(delay * 2, Duration::from_secs(60));
110                }
111            }
112        }
113    }
114
115    /// Subscribe to market data
116    pub async fn subscribe_market_data(
117        &self,
118        _subscription: SubscribeMessage,
119    ) -> Result<MarketDataStream> {
120        let stream = self.connect().await?;
121        let (sender, receiver) = mpsc::unbounded_channel();
122
123        // Send subscription
124        // Note: In a real implementation, you'd need to send the subscription message
125        // through the WebSocket connection. This is simplified for the example.
126
127        tokio::spawn(async move {
128            let mut market_data_stream = stream.market_data();
129            while let Some(update) = market_data_stream.next().await {
130                if sender.send(update).is_err() {
131                    break;
132                }
133            }
134        });
135
136        Ok(MarketDataStream::new(receiver))
137    }
138
139    /// Subscribe to trading updates
140    pub async fn subscribe_trading_updates(&self) -> Result<TradingStream> {
141        let stream = self.connect().await?;
142        let (sender, receiver) = mpsc::unbounded_channel();
143
144        tokio::spawn(async move {
145            let mut trading_stream = stream.trading_updates();
146            while let Some(update) = trading_stream.next().await {
147                if sender.send(update).is_err() {
148                    break;
149                }
150            }
151        });
152
153        Ok(TradingStream::new(receiver))
154    }
155
156    /// Authenticate with the WebSocket
157    async fn authenticate(&self, sink: &mut WsSink) -> Result<()> {
158        let auth_msg = WebSocketMessage::Auth(AuthMessage {
159            key: self.credentials.api_key.clone(),
160            secret: self.credentials.secret_key.clone(),
161        });
162
163        let auth_json = serde_json::to_string(&auth_msg)?;
164        sink.send(Message::Text(auth_json.into())).await?;
165
166        debug!("Sent authentication message");
167        Ok(())
168    }
169
170    /// Handle incoming WebSocket messages
171    async fn handle_messages(
172        stream: &mut WsReceiver,
173        sender: mpsc::UnboundedSender<WebSocketMessage>,
174        _credentials: Credentials,
175    ) {
176        while let Some(message) = stream.next().await {
177            match message {
178                Ok(Message::Text(text)) => match Self::parse_message(&text) {
179                    Ok(msg) => {
180                        debug!("Received message: {:?}", msg);
181                        if sender.send(msg).is_err() {
182                            warn!("Failed to send message to channel");
183                            break;
184                        }
185                    }
186                    Err(e) => {
187                        warn!("Failed to parse message: {} - Raw: {}", e, text);
188                    }
189                },
190                Ok(Message::Close(_)) => {
191                    info!("WebSocket connection closed");
192                    break;
193                }
194                Ok(Message::Ping(_data)) => {
195                    debug!("Received ping, sending pong");
196                    // Note: tokio-tungstenite handles pong automatically
197                }
198                Ok(Message::Pong(_)) => {
199                    debug!("Received pong");
200                }
201                Ok(Message::Binary(_)) => {
202                    warn!("Received unexpected binary message");
203                }
204                Ok(Message::Frame(_)) => {
205                    debug!("Received frame message");
206                }
207                Err(e) => {
208                    error!("WebSocket error: {}", e);
209                    break;
210                }
211            }
212        }
213
214        info!("Message handler exiting");
215    }
216
217    /// Parse incoming WebSocket message
218    fn parse_message(text: &str) -> Result<WebSocketMessage> {
219        // Handle array of messages
220        if text.starts_with('[') {
221            let messages: Vec<serde_json::Value> = serde_json::from_str(text)?;
222            if let Some(first_msg) = messages.first() {
223                return serde_json::from_value(first_msg.clone())
224                    .map_err(|e| AlpacaError::Json(e.to_string()));
225            }
226        }
227
228        // Handle single message
229        serde_json::from_str(text).map_err(|e| AlpacaError::Json(e.to_string()))
230    }
231
232    /// Send subscription message
233    pub async fn send_subscription(&self, subscription: SubscribeMessage) -> Result<()> {
234        // This would need to be implemented with a persistent connection
235        // For now, this is a placeholder
236        debug!("Would send subscription: {:?}", subscription);
237        Ok(())
238    }
239
240    /// Send unsubscription message
241    pub async fn send_unsubscription(&self, unsubscription: UnsubscribeMessage) -> Result<()> {
242        // This would need to be implemented with a persistent connection
243        // For now, this is a placeholder
244        debug!("Would send unsubscription: {:?}", unsubscription);
245        Ok(())
246    }
247
248    /// Get the WebSocket URL
249    pub fn url(&self) -> &str {
250        &self.url
251    }
252
253    /// Get the environment
254    pub fn environment(&self) -> &Environment {
255        &self.environment
256    }
257}
258
259/// WebSocket connection manager with automatic reconnection
260pub struct WebSocketManager {
261    client: AlpacaWebSocketClient,
262    max_retries: u32,
263    heartbeat_interval: Duration,
264}
265
266impl WebSocketManager {
267    /// Create a new WebSocket manager
268    pub fn new(client: AlpacaWebSocketClient) -> Self {
269        Self {
270            client,
271            max_retries: 5,
272            heartbeat_interval: Duration::from_secs(30),
273        }
274    }
275
276    /// Set maximum retry attempts
277    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
278        self.max_retries = max_retries;
279        self
280    }
281
282    /// Set heartbeat interval
283    pub fn with_heartbeat_interval(mut self, interval: Duration) -> Self {
284        self.heartbeat_interval = interval;
285        self
286    }
287
288    /// Start the managed connection
289    pub async fn start(&self) -> Result<AlpacaStream> {
290        let stream = self.client.connect_with_reconnect(self.max_retries).await?;
291
292        // Start heartbeat
293        self.start_heartbeat().await;
294
295        Ok(stream)
296    }
297
298    /// Start heartbeat to keep connection alive
299    async fn start_heartbeat(&self) {
300        let mut interval = interval(self.heartbeat_interval);
301
302        tokio::spawn(async move {
303            loop {
304                interval.tick().await;
305                debug!("Heartbeat tick");
306                // In a real implementation, you might send a ping message here
307            }
308        });
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use alpaca_base::types::Environment;

    /// Dummy credentials shared by the tests; never used to contact a server.
    fn test_credentials() -> Credentials {
        Credentials::new("test_key".to_string(), "test_secret".to_string())
    }

    #[test]
    fn test_client_creation() {
        let client = AlpacaWebSocketClient::new(test_credentials(), Environment::Paper);

        assert!(client.url().contains("stream.data.alpaca.markets"));
    }

    #[test]
    fn test_market_data_feed_per_environment() {
        let paper = AlpacaWebSocketClient::new(test_credentials(), Environment::Paper);
        assert!(paper.url().ends_with("/v2/iex"));

        let live = AlpacaWebSocketClient::new(test_credentials(), Environment::Live);
        assert!(live.url().ends_with("/v2/sip"));
    }

    #[test]
    fn test_trading_client() {
        let client = AlpacaWebSocketClient::trading(test_credentials(), Environment::Paper);

        assert!(client.url().contains("paper-api.alpaca.markets"));
    }

    #[test]
    fn test_parse_message() {
        let json = r#"{"T":"success","msg":"authenticated"}"#;
        let result = AlpacaWebSocketClient::parse_message(json);
        assert!(result.is_ok());
    }

    #[test]
    fn test_parse_message_batched_array() {
        let json = r#"[{"T":"success","msg":"authenticated"}]"#;
        assert!(AlpacaWebSocketClient::parse_message(json).is_ok());
    }

    #[test]
    fn test_parse_message_rejects_invalid_json() {
        assert!(AlpacaWebSocketClient::parse_message("not json").is_err());
    }
}