1#![allow(missing_docs)]
4
5use crate::{messages::*, streams::*};
6use alpaca_base::types::Quote;
7use alpaca_base::{AlpacaError, Result, auth::Credentials, types::Environment};
8use futures_util::{
9 sink::SinkExt,
10 stream::{SplitSink, SplitStream, StreamExt},
11};
12use serde_json;
13use std::sync::Once;
14use std::time::Duration;
15use tokio::{
16 net::TcpStream,
17 sync::mpsc,
18 time::{interval, sleep},
19};
20use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
21use tracing::{debug, error, info, warn};
22
23static CRYPTO_PROVIDER_INIT: Once = Once::new();
24
25fn init_crypto_provider() {
28 CRYPTO_PROVIDER_INIT.call_once(|| {
29 let _ = rustls::crypto::ring::default_provider().install_default();
30 });
31}
32
/// Full-duplex WebSocket connection as returned by `connect_async` (plain or TLS TCP).
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Write half of a [`WsStream`] after `split()`; used to send auth/subscribe frames.
type WsSink = SplitSink<WsStream, Message>;
/// Read half of a [`WsStream`] after `split()`; consumed by the background handler task.
type WsReceiver = SplitStream<WsStream>;
36
/// Client for Alpaca WebSocket streams (market data or trading updates).
///
/// Holds the credentials sent in the `auth` action, the environment it was
/// constructed with, and the endpoint URL chosen by its constructor.
///
/// NOTE(review): the derived `Debug` includes `credentials`; confirm that
/// `Credentials`' own `Debug` impl redacts the secret key.
#[derive(Debug)]
pub struct AlpacaWebSocketClient {
    // API key and secret sent in the `auth` message.
    credentials: Credentials,
    // Environment passed to the constructor; exposed via `environment()`.
    environment: Environment,
    // `wss://` endpoint opened by `connect` and `subscribe_market_data`.
    url: String,
}
44
/// Market-data feed selectable via `AlpacaWebSocketClient::with_feed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataFeed {
    /// IEX feed: `wss://stream.data.alpaca.markets/v2/iex`.
    Iex,
    /// SIP feed: `wss://stream.data.alpaca.markets/v2/sip`.
    Sip,
    /// Crypto feed: `wss://stream.data.alpaca.markets/v1beta3/crypto/us`.
    Crypto,
}
55
56impl AlpacaWebSocketClient {
57 pub fn new(credentials: Credentials, environment: Environment) -> Self {
59 let url = match environment {
60 Environment::Paper => "wss://stream.data.alpaca.markets/v2/iex",
61 Environment::Live => "wss://stream.data.alpaca.markets/v2/sip",
62 };
63
64 Self {
65 credentials,
66 environment,
67 url: url.to_string(),
68 }
69 }
70
71 pub fn from_env(environment: Environment) -> Result<Self> {
73 let credentials = Credentials::from_env()?;
74 Ok(Self::new(credentials, environment))
75 }
76
77 pub fn with_feed(credentials: Credentials, environment: Environment, feed: DataFeed) -> Self {
79 let url = match feed {
80 DataFeed::Iex => "wss://stream.data.alpaca.markets/v2/iex",
81 DataFeed::Sip => "wss://stream.data.alpaca.markets/v2/sip",
82 DataFeed::Crypto => "wss://stream.data.alpaca.markets/v1beta3/crypto/us",
83 };
84
85 Self {
86 credentials,
87 environment,
88 url: url.to_string(),
89 }
90 }
91
92 pub fn crypto(credentials: Credentials, environment: Environment) -> Self {
94 Self::with_feed(credentials, environment, DataFeed::Crypto)
95 }
96
97 pub fn crypto_from_env(environment: Environment) -> Result<Self> {
99 let credentials = Credentials::from_env()?;
100 Ok(Self::crypto(credentials, environment))
101 }
102
103 pub fn trading(credentials: Credentials, environment: Environment) -> Self {
105 let url = environment.websocket_url();
106 Self {
107 credentials,
108 environment,
109 url: url.to_string(),
110 }
111 }
112
113 pub async fn connect(&self) -> Result<AlpacaStream> {
115 init_crypto_provider();
117
118 let (sender, receiver) = mpsc::unbounded_channel();
119 info!("Connecting to WebSocket: {}", self.url);
120 let (ws_stream, _) = connect_async(&self.url).await?;
121 let (mut sink, mut stream) = ws_stream.split();
122
123 self.authenticate(&mut sink).await?;
125
126 let credentials = self.credentials.clone();
128 tokio::spawn(async move {
129 Self::handle_messages(&mut stream, sender, credentials).await;
130 });
131
132 Ok(AlpacaStream::new(receiver))
133 }
134
135 pub async fn connect_with_reconnect(&self, max_retries: u32) -> Result<AlpacaStream> {
137 let mut attempts = 0;
138 let mut delay = Duration::from_secs(1);
139
140 loop {
141 match self.connect().await {
142 Ok(stream) => {
143 info!("Successfully connected to WebSocket");
144 return Ok(stream);
145 }
146 Err(e) => {
147 attempts += 1;
148 if attempts >= max_retries {
149 error!("Failed to connect after {} attempts", attempts);
150 return Err(AlpacaError::WebSocket(format!(
151 "Connection failed after {} attempts: {}",
152 attempts, e
153 )));
154 }
155
156 warn!(
157 "Connection attempt {} failed: {}. Retrying in {:?}",
158 attempts, e, delay
159 );
160 sleep(delay).await;
161 delay = std::cmp::min(delay * 2, Duration::from_secs(60));
162 }
163 }
164 }
165 }
166
167 pub async fn subscribe_market_data(
169 &self,
170 subscription: SubscribeMessage,
171 ) -> Result<MarketDataStream> {
172 init_crypto_provider();
174
175 let (sender, receiver) = mpsc::unbounded_channel();
176 info!("Connecting to WebSocket: {}", self.url);
177 let (ws_stream, _) = connect_async(&self.url).await?;
178 let (mut sink, mut stream) = ws_stream.split();
179
180 if let Some(Ok(Message::Text(text))) = stream.next().await {
182 debug!("Server: {}", text);
183 }
184
185 self.authenticate(&mut sink).await?;
187
188 if let Some(Ok(Message::Text(text))) = stream.next().await {
190 debug!("Auth response: {}", text);
191 }
192
193 let sub_msg = serde_json::json!({
195 "action": "subscribe",
196 "trades": subscription.trades.unwrap_or_default(),
197 "quotes": subscription.quotes.unwrap_or_default(),
198 "bars": subscription.bars.unwrap_or_default()
199 });
200 let sub_json = serde_json::to_string(&sub_msg)?;
201 debug!("Sending subscription: {}", sub_json);
202 sink.send(Message::Text(sub_json.into())).await?;
203
204 if let Some(Ok(Message::Text(text))) = stream.next().await {
206 debug!("Subscription response: {}", text);
207 }
208
209 let credentials = self.credentials.clone();
211 tokio::spawn(async move {
212 let _ = credentials; debug!("Handler started, waiting for messages...");
214 while let Some(message) = stream.next().await {
215 match message {
216 Ok(Message::Text(text)) => {
217 if let Ok(messages) = serde_json::from_str::<Vec<serde_json::Value>>(&text)
219 {
220 for msg_value in messages {
221 if let Some(msg_type) = msg_value.get("T").and_then(|t| t.as_str())
222 {
223 let update = match msg_type {
224 "t" => {
225 if let Ok(trade_msg) =
227 serde_json::from_value::<TradeMessage>(
228 msg_value.clone(),
229 )
230 {
231 Some(MarketDataUpdate::Trade {
232 symbol: trade_msg.symbol.clone(),
233 trade: trade_msg.into(),
234 })
235 } else {
236 None
237 }
238 }
239 "q" => {
240 if let Ok(quote_msg) =
242 serde_json::from_value::<CryptoQuoteMessage>(
243 msg_value.clone(),
244 )
245 {
246 Some(MarketDataUpdate::Quote {
247 symbol: quote_msg.symbol.clone(),
248 quote: Quote {
249 timestamp: quote_msg.timestamp,
250 timeframe: "real-time".to_string(),
251 bid_price: quote_msg.bid_price,
252 bid_size: quote_msg.bid_size as u32,
253 ask_price: quote_msg.ask_price,
254 ask_size: quote_msg.ask_size as u32,
255 bid_exchange: String::new(),
256 ask_exchange: String::new(),
257 },
258 })
259 } else if let Ok(quote_msg) =
260 serde_json::from_value::<QuoteMessage>(
261 msg_value.clone(),
262 )
263 {
264 Some(MarketDataUpdate::Quote {
265 symbol: quote_msg.symbol.clone(),
266 quote: quote_msg.into(),
267 })
268 } else {
269 None
270 }
271 }
272 "b" => {
273 if let Ok(bar_msg) = serde_json::from_value::<BarMessage>(
275 msg_value.clone(),
276 ) {
277 Some(MarketDataUpdate::Bar {
278 symbol: bar_msg.symbol.clone(),
279 bar: bar_msg.into(),
280 })
281 } else {
282 None
283 }
284 }
285 _ => {
286 debug!("Ignoring message type: {}", msg_type);
287 None
288 }
289 };
290
291 if let Some(u) = update
292 && sender.send(u).is_err()
293 {
294 debug!("Channel closed");
295 break;
296 }
297 }
298 }
299 }
300 }
301 Ok(Message::Close(_)) => {
302 info!("WebSocket connection closed");
303 break;
304 }
305 Err(e) => {
306 error!("WebSocket error: {}", e);
307 break;
308 }
309 _ => {}
310 }
311 }
312 info!("Market data handler exiting");
313 });
314
315 Ok(MarketDataStream::new(receiver))
316 }
317
318 pub async fn subscribe_trading_updates(&self) -> Result<TradingStream> {
320 let stream = self.connect().await?;
321 let (sender, receiver) = mpsc::unbounded_channel();
322
323 tokio::spawn(async move {
324 let mut trading_stream = stream.trading_updates();
325 while let Some(update) = trading_stream.next().await {
326 if sender.send(update).is_err() {
327 break;
328 }
329 }
330 });
331
332 Ok(TradingStream::new(receiver))
333 }
334
335 async fn authenticate(&self, sink: &mut WsSink) -> Result<()> {
337 let auth_msg = serde_json::json!({
339 "action": "auth",
340 "key": self.credentials.api_key,
341 "secret": self.credentials.secret_key
342 });
343
344 let auth_json = serde_json::to_string(&auth_msg)?;
345 debug!("Sending auth: {}", auth_json);
346 sink.send(Message::Text(auth_json.into())).await?;
347
348 debug!("Sent authentication message");
349 Ok(())
350 }
351
352 async fn handle_messages(
354 stream: &mut WsReceiver,
355 sender: mpsc::UnboundedSender<WebSocketMessage>,
356 _credentials: Credentials,
357 ) {
358 while let Some(message) = stream.next().await {
359 match message {
360 Ok(Message::Text(text)) => match Self::parse_message(&text) {
361 Ok(msg) => {
362 debug!("Received message: {:?}", msg);
363 if sender.send(msg).is_err() {
364 warn!("Failed to send message to channel");
365 break;
366 }
367 }
368 Err(e) => {
369 warn!("Failed to parse message: {} - Raw: {}", e, text);
370 }
371 },
372 Ok(Message::Close(_)) => {
373 info!("WebSocket connection closed");
374 break;
375 }
376 Ok(Message::Ping(_data)) => {
377 debug!("Received ping, sending pong");
378 }
380 Ok(Message::Pong(_)) => {
381 debug!("Received pong");
382 }
383 Ok(Message::Binary(_)) => {
384 warn!("Received unexpected binary message");
385 }
386 Ok(Message::Frame(_)) => {
387 debug!("Received frame message");
388 }
389 Err(e) => {
390 error!("WebSocket error: {}", e);
391 break;
392 }
393 }
394 }
395
396 info!("Message handler exiting");
397 }
398
399 fn parse_message(text: &str) -> Result<WebSocketMessage> {
401 if text.starts_with('[') {
403 let messages: Vec<serde_json::Value> = serde_json::from_str(text)?;
404 if let Some(first_msg) = messages.first() {
405 return serde_json::from_value(first_msg.clone())
406 .map_err(|e| AlpacaError::Json(e.to_string()));
407 }
408 }
409
410 serde_json::from_str(text).map_err(|e| AlpacaError::Json(e.to_string()))
412 }
413
414 pub async fn send_subscription(&self, subscription: SubscribeMessage) -> Result<()> {
416 debug!("Would send subscription: {:?}", subscription);
419 Ok(())
420 }
421
422 pub async fn send_unsubscription(&self, unsubscription: UnsubscribeMessage) -> Result<()> {
424 debug!("Would send unsubscription: {:?}", unsubscription);
427 Ok(())
428 }
429
430 pub fn url(&self) -> &str {
432 &self.url
433 }
434
435 pub fn environment(&self) -> &Environment {
437 &self.environment
438 }
439}
440
441pub struct WebSocketManager {
443 client: AlpacaWebSocketClient,
444 max_retries: u32,
445 heartbeat_interval: Duration,
446}
447
448impl WebSocketManager {
449 pub fn new(client: AlpacaWebSocketClient) -> Self {
451 Self {
452 client,
453 max_retries: 5,
454 heartbeat_interval: Duration::from_secs(30),
455 }
456 }
457
458 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
460 self.max_retries = max_retries;
461 self
462 }
463
464 pub fn with_heartbeat_interval(mut self, interval: Duration) -> Self {
466 self.heartbeat_interval = interval;
467 self
468 }
469
470 pub async fn start(&self) -> Result<AlpacaStream> {
472 let stream = self.client.connect_with_reconnect(self.max_retries).await?;
473
474 self.start_heartbeat().await;
476
477 Ok(stream)
478 }
479
480 async fn start_heartbeat(&self) {
482 let mut interval = interval(self.heartbeat_interval);
483
484 tokio::spawn(async move {
485 loop {
486 interval.tick().await;
487 debug!("Heartbeat tick");
488 }
490 });
491 }
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use alpaca_base::types::Environment;

    /// Dummy credentials; no test here opens a network connection.
    fn test_credentials() -> Credentials {
        Credentials::new("test_key".to_string(), "test_secret".to_string())
    }

    #[test]
    fn test_client_creation() {
        let client = AlpacaWebSocketClient::new(test_credentials(), Environment::Paper);

        assert!(client.url().contains("stream.data.alpaca.markets"));
    }

    #[test]
    fn test_new_selects_feed_from_environment() {
        let paper = AlpacaWebSocketClient::new(test_credentials(), Environment::Paper);
        assert_eq!(paper.url(), "wss://stream.data.alpaca.markets/v2/iex");

        let live = AlpacaWebSocketClient::new(test_credentials(), Environment::Live);
        assert_eq!(live.url(), "wss://stream.data.alpaca.markets/v2/sip");
    }

    #[test]
    fn test_with_feed_urls() {
        let cases = [
            (DataFeed::Iex, "wss://stream.data.alpaca.markets/v2/iex"),
            (DataFeed::Sip, "wss://stream.data.alpaca.markets/v2/sip"),
            (DataFeed::Crypto, "wss://stream.data.alpaca.markets/v1beta3/crypto/us"),
        ];
        for (feed, expected) in cases {
            let client =
                AlpacaWebSocketClient::with_feed(test_credentials(), Environment::Paper, feed);
            assert_eq!(client.url(), expected, "feed {:?}", feed);
        }
    }

    #[test]
    fn test_crypto_client() {
        let client = AlpacaWebSocketClient::crypto(test_credentials(), Environment::Live);

        assert_eq!(
            client.url(),
            "wss://stream.data.alpaca.markets/v1beta3/crypto/us"
        );
    }

    #[test]
    fn test_trading_client() {
        let client = AlpacaWebSocketClient::trading(test_credentials(), Environment::Paper);

        assert!(client.url().contains("paper-api.alpaca.markets"));
    }

    #[test]
    fn test_parse_message() {
        let json = r#"{"T":"success","msg":"authenticated"}"#;
        let result = AlpacaWebSocketClient::parse_message(json);
        assert!(result.is_ok());
    }

    #[test]
    fn test_parse_message_array_frame() {
        let json = r#"[{"T":"success","msg":"authenticated"}]"#;
        let result = AlpacaWebSocketClient::parse_message(json);
        assert!(result.is_ok());
    }

    #[test]
    fn test_parse_message_rejects_invalid_json() {
        assert!(AlpacaWebSocketClient::parse_message("not json").is_err());
    }
}