// alpaca_websocket/client.rs
#![allow(missing_docs)]

use crate::{messages::*, streams::*};
use alpaca_base::{AlpacaError, Result, auth::Credentials, types::Environment};
use futures_util::{
    sink::SinkExt,
    stream::{SplitSink, SplitStream, StreamExt},
};
use serde_json;
use std::{sync::Mutex, time::Duration};
use tokio::{
    net::TcpStream,
    sync::mpsc,
    task::JoinHandle,
    time::{interval, sleep},
};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
use tracing::{debug, error, info, warn};

21type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
22type WsSink = SplitSink<WsStream, Message>;
23type WsReceiver = SplitStream<WsStream>;
24
25#[derive(Debug)]
27pub struct AlpacaWebSocketClient {
28 credentials: Credentials,
29 environment: Environment,
30 url: String,
31}
32
33impl AlpacaWebSocketClient {
34 pub fn new(credentials: Credentials, environment: Environment) -> Self {
36 let url = match environment {
37 Environment::Paper => "wss://stream.data.alpaca.markets/v2/iex",
38 Environment::Live => "wss://stream.data.alpaca.markets/v2/sip",
39 };
40
41 Self {
42 credentials,
43 environment,
44 url: url.to_string(),
45 }
46 }
47
48 pub fn from_env(environment: Environment) -> Result<Self> {
50 let credentials = Credentials::from_env()?;
51 Ok(Self::new(credentials, environment))
52 }
53
54 pub fn trading(credentials: Credentials, environment: Environment) -> Self {
56 let url = environment.websocket_url();
57 Self {
58 credentials,
59 environment,
60 url: url.to_string(),
61 }
62 }
63
64 pub async fn connect(&self) -> Result<AlpacaStream> {
66 let (sender, receiver) = mpsc::unbounded_channel();
67 info!("Connecting to WebSocket: {}", self.url);
68 let (ws_stream, _) = connect_async(&self.url).await?;
69 let (mut sink, mut stream) = ws_stream.split();
70
71 self.authenticate(&mut sink).await?;
73
74 let credentials = self.credentials.clone();
76 tokio::spawn(async move {
77 Self::handle_messages(&mut stream, sender, credentials).await;
78 });
79
80 Ok(AlpacaStream::new(receiver))
81 }
82
83 pub async fn connect_with_reconnect(&self, max_retries: u32) -> Result<AlpacaStream> {
85 let mut attempts = 0;
86 let mut delay = Duration::from_secs(1);
87
88 loop {
89 match self.connect().await {
90 Ok(stream) => {
91 info!("Successfully connected to WebSocket");
92 return Ok(stream);
93 }
94 Err(e) => {
95 attempts += 1;
96 if attempts >= max_retries {
97 error!("Failed to connect after {} attempts", attempts);
98 return Err(AlpacaError::WebSocket(format!(
99 "Connection failed after {} attempts: {}",
100 attempts, e
101 )));
102 }
103
104 warn!(
105 "Connection attempt {} failed: {}. Retrying in {:?}",
106 attempts, e, delay
107 );
108 sleep(delay).await;
109 delay = std::cmp::min(delay * 2, Duration::from_secs(60));
110 }
111 }
112 }
113 }
114
115 pub async fn subscribe_market_data(
117 &self,
118 _subscription: SubscribeMessage,
119 ) -> Result<MarketDataStream> {
120 let stream = self.connect().await?;
121 let (sender, receiver) = mpsc::unbounded_channel();
122
123 tokio::spawn(async move {
128 let mut market_data_stream = stream.market_data();
129 while let Some(update) = market_data_stream.next().await {
130 if sender.send(update).is_err() {
131 break;
132 }
133 }
134 });
135
136 Ok(MarketDataStream::new(receiver))
137 }
138
139 pub async fn subscribe_trading_updates(&self) -> Result<TradingStream> {
141 let stream = self.connect().await?;
142 let (sender, receiver) = mpsc::unbounded_channel();
143
144 tokio::spawn(async move {
145 let mut trading_stream = stream.trading_updates();
146 while let Some(update) = trading_stream.next().await {
147 if sender.send(update).is_err() {
148 break;
149 }
150 }
151 });
152
153 Ok(TradingStream::new(receiver))
154 }
155
156 async fn authenticate(&self, sink: &mut WsSink) -> Result<()> {
158 let auth_msg = WebSocketMessage::Auth(AuthMessage {
159 key: self.credentials.api_key.clone(),
160 secret: self.credentials.secret_key.clone(),
161 });
162
163 let auth_json = serde_json::to_string(&auth_msg)?;
164 sink.send(Message::Text(auth_json.into())).await?;
165
166 debug!("Sent authentication message");
167 Ok(())
168 }
169
170 async fn handle_messages(
172 stream: &mut WsReceiver,
173 sender: mpsc::UnboundedSender<WebSocketMessage>,
174 _credentials: Credentials,
175 ) {
176 while let Some(message) = stream.next().await {
177 match message {
178 Ok(Message::Text(text)) => match Self::parse_message(&text) {
179 Ok(msg) => {
180 debug!("Received message: {:?}", msg);
181 if sender.send(msg).is_err() {
182 warn!("Failed to send message to channel");
183 break;
184 }
185 }
186 Err(e) => {
187 warn!("Failed to parse message: {} - Raw: {}", e, text);
188 }
189 },
190 Ok(Message::Close(_)) => {
191 info!("WebSocket connection closed");
192 break;
193 }
194 Ok(Message::Ping(_data)) => {
195 debug!("Received ping, sending pong");
196 }
198 Ok(Message::Pong(_)) => {
199 debug!("Received pong");
200 }
201 Ok(Message::Binary(_)) => {
202 warn!("Received unexpected binary message");
203 }
204 Ok(Message::Frame(_)) => {
205 debug!("Received frame message");
206 }
207 Err(e) => {
208 error!("WebSocket error: {}", e);
209 break;
210 }
211 }
212 }
213
214 info!("Message handler exiting");
215 }
216
217 fn parse_message(text: &str) -> Result<WebSocketMessage> {
219 if text.starts_with('[') {
221 let messages: Vec<serde_json::Value> = serde_json::from_str(text)?;
222 if let Some(first_msg) = messages.first() {
223 return serde_json::from_value(first_msg.clone())
224 .map_err(|e| AlpacaError::Json(e.to_string()));
225 }
226 }
227
228 serde_json::from_str(text).map_err(|e| AlpacaError::Json(e.to_string()))
230 }
231
232 pub async fn send_subscription(&self, subscription: SubscribeMessage) -> Result<()> {
234 debug!("Would send subscription: {:?}", subscription);
237 Ok(())
238 }
239
240 pub async fn send_unsubscription(&self, unsubscription: UnsubscribeMessage) -> Result<()> {
242 debug!("Would send unsubscription: {:?}", unsubscription);
245 Ok(())
246 }
247
248 pub fn url(&self) -> &str {
250 &self.url
251 }
252
253 pub fn environment(&self) -> &Environment {
255 &self.environment
256 }
257}
258
259pub struct WebSocketManager {
261 client: AlpacaWebSocketClient,
262 max_retries: u32,
263 heartbeat_interval: Duration,
264}
265
266impl WebSocketManager {
267 pub fn new(client: AlpacaWebSocketClient) -> Self {
269 Self {
270 client,
271 max_retries: 5,
272 heartbeat_interval: Duration::from_secs(30),
273 }
274 }
275
276 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
278 self.max_retries = max_retries;
279 self
280 }
281
282 pub fn with_heartbeat_interval(mut self, interval: Duration) -> Self {
284 self.heartbeat_interval = interval;
285 self
286 }
287
288 pub async fn start(&self) -> Result<AlpacaStream> {
290 let stream = self.client.connect_with_reconnect(self.max_retries).await?;
291
292 self.start_heartbeat().await;
294
295 Ok(stream)
296 }
297
298 async fn start_heartbeat(&self) {
300 let mut interval = interval(self.heartbeat_interval);
301
302 tokio::spawn(async move {
303 loop {
304 interval.tick().await;
305 debug!("Heartbeat tick");
306 }
308 });
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use alpaca_base::types::Environment;

    /// Placeholder key pair shared by every test in this module.
    fn dummy_credentials() -> Credentials {
        Credentials::new("test_key".to_string(), "test_secret".to_string())
    }

    #[test]
    fn test_client_creation() {
        let client = AlpacaWebSocketClient::new(dummy_credentials(), Environment::Paper);
        assert!(client.url().contains("stream.data.alpaca.markets"));
    }

    #[test]
    fn test_trading_client() {
        let client = AlpacaWebSocketClient::trading(dummy_credentials(), Environment::Paper);
        assert!(client.url().contains("paper-api.alpaca.markets"));
    }

    #[test]
    fn test_parse_message() {
        let payload = r#"{"T":"success","msg":"authenticated"}"#;
        let parsed = AlpacaWebSocketClient::parse_message(payload);
        assert!(parsed.is_ok());
    }
}
339}