// webull_rs/streaming/client.rs

use crate::auth::{AuthManager, AccessToken};
use crate::error::{WebullError, WebullResult};
use crate::streaming::events::{Event, EventType, ConnectionState, ConnectionStatus, ErrorEvent, HeartbeatEvent};
use crate::streaming::subscription::{SubscriptionRequest, UnsubscriptionRequest};
use crate::utils::serialization::{from_json, to_json};
use futures_util::future::{self, Either};
use futures_util::{SinkExt, StreamExt};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use serde_json::json;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::time::{sleep, timeout};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http;
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message, MaybeTlsStream, WebSocketStream};
use url::Url;
use uuid::Uuid;

18/// WebSocket client for streaming data from Webull.
19pub struct WebSocketClient {
20    /// Base URL for WebSocket connections
21    base_url: String,
22
23    /// Authentication manager
24    auth_manager: Arc<AuthManager>,
25
26    /// Connection state
27    connection_state: Arc<Mutex<ConnectionState>>,
28
29    /// Event sender
30    event_sender: Option<Sender<Event>>,
31
32    /// Last heartbeat time
33    last_heartbeat: Arc<Mutex<Instant>>,
34
35    /// Heartbeat interval in seconds
36    heartbeat_interval: u64,
37
38    /// Reconnect attempts
39    reconnect_attempts: Arc<Mutex<u32>>,
40
41    /// Maximum reconnect attempts
42    max_reconnect_attempts: u32,
43
44    /// Reconnect delay in seconds
45    reconnect_delay: u64,
46}
47
48impl WebSocketClient {
49    /// Create a new WebSocket client.
50    pub fn new(base_url: String, auth_manager: Arc<AuthManager>) -> Self {
51        Self {
52            base_url,
53            auth_manager,
54            connection_state: Arc::new(Mutex::new(ConnectionState::Disconnected)),
55            event_sender: None,
56            last_heartbeat: Arc::new(Mutex::new(Instant::now())),
57            heartbeat_interval: 30,
58            reconnect_attempts: Arc::new(Mutex::new(0)),
59            max_reconnect_attempts: 5,
60            reconnect_delay: 5,
61        }
62    }
63
64    /// Connect to the WebSocket server.
65    pub async fn connect(&mut self) -> WebullResult<Receiver<Event>> {
66        // Create a channel for events
67        let (tx, rx) = mpsc::channel(100);
68        self.event_sender = Some(tx.clone());
69
70        // Set the connection state to reconnecting
71        *self.connection_state.lock().unwrap() = ConnectionState::Reconnecting;
72
73        // Reset reconnect attempts
74        *self.reconnect_attempts.lock().unwrap() = 0;
75
76        // Start the connection task
77        let base_url = self.base_url.clone();
78        let auth_manager = self.auth_manager.clone();
79        let connection_state = self.connection_state.clone();
80        let last_heartbeat = self.last_heartbeat.clone();
81        let heartbeat_interval = self.heartbeat_interval;
82        let reconnect_attempts = self.reconnect_attempts.clone();
83        let max_reconnect_attempts = self.max_reconnect_attempts;
84        let reconnect_delay = self.reconnect_delay;
85
86        tokio::spawn(async move {
87            loop {
88                // Check if we've exceeded the maximum reconnect attempts
89                let attempts = *reconnect_attempts.lock().unwrap();
90                if attempts > max_reconnect_attempts {
91                    // Send a connection failed event
92                    let event = Event {
93                        event_type: EventType::Connection,
94                        timestamp: chrono::Utc::now(),
95                        data: crate::streaming::events::EventData::Connection(ConnectionStatus {
96                            status: ConnectionState::Failed,
97                            connection_id: None,
98                            message: Some("Maximum reconnect attempts exceeded".to_string()),
99                        }),
100                    };
101
102                    let _ = tx.send(event).await;
103
104                    // Set the connection state to failed
105                    *connection_state.lock().unwrap() = ConnectionState::Failed;
106
107                    break;
108                }
109
110                // Increment reconnect attempts
111                *reconnect_attempts.lock().unwrap() = attempts + 1;
112
113                // Get the authentication token
114                let token = match auth_manager.get_token().await {
115                    Ok(token) => token,
116                    Err(e) => {
117                        // Send an error event
118                        let event = Event {
119                            event_type: EventType::Error,
120                            timestamp: chrono::Utc::now(),
121                            data: crate::streaming::events::EventData::Error(ErrorEvent {
122                                code: "AUTH_ERROR".to_string(),
123                                message: format!("Authentication error: {}", e),
124                            }),
125                        };
126
127                        let _ = tx.send(event).await;
128
129                        // Wait before retrying
130                        sleep(Duration::from_secs(reconnect_delay)).await;
131                        continue;
132                    }
133                };
134
135                // Connect to the WebSocket server
136                match Self::connect_websocket(&base_url, &token).await {
137                    Ok(ws_stream) => {
138                        // Set the connection state to connected
139                        *connection_state.lock().unwrap() = ConnectionState::Connected;
140
141                        // Reset reconnect attempts
142                        *reconnect_attempts.lock().unwrap() = 0;
143
144                        // Send a connection established event
145                        let connection_id = Uuid::new_v4().to_string();
146                        let event = Event {
147                            event_type: EventType::Connection,
148                            timestamp: chrono::Utc::now(),
149                            data: crate::streaming::events::EventData::Connection(ConnectionStatus {
150                                status: ConnectionState::Connected,
151                                connection_id: Some(connection_id.clone()),
152                                message: Some("Connection established".to_string()),
153                            }),
154                        };
155
156                        let _ = tx.send(event).await;
157
158                        // Handle the WebSocket connection
159                        if let Err(e) = Self::handle_websocket(ws_stream, tx.clone(), last_heartbeat.clone(), heartbeat_interval).await {
160                            // Send an error event
161                            let event = Event {
162                                event_type: EventType::Error,
163                                timestamp: chrono::Utc::now(),
164                                data: crate::streaming::events::EventData::Error(ErrorEvent {
165                                    code: "WS_ERROR".to_string(),
166                                    message: format!("WebSocket error: {}", e),
167                                }),
168                            };
169
170                            let _ = tx.send(event).await;
171                        }
172
173                        // Set the connection state to disconnected
174                        *connection_state.lock().unwrap() = ConnectionState::Disconnected;
175
176                        // Send a disconnection event
177                        let event = Event {
178                            event_type: EventType::Connection,
179                            timestamp: chrono::Utc::now(),
180                            data: crate::streaming::events::EventData::Connection(ConnectionStatus {
181                                status: ConnectionState::Disconnected,
182                                connection_id: Some(connection_id),
183                                message: Some("Connection closed".to_string()),
184                            }),
185                        };
186
187                        let _ = tx.send(event).await;
188                    }
189                    Err(e) => {
190                        // Send an error event
191                        let event = Event {
192                            event_type: EventType::Error,
193                            timestamp: chrono::Utc::now(),
194                            data: crate::streaming::events::EventData::Error(ErrorEvent {
195                                code: "WS_CONNECT_ERROR".to_string(),
196                                message: format!("WebSocket connection error: {}", e),
197                            }),
198                        };
199
200                        let _ = tx.send(event).await;
201                    }
202                }
203
204                // Wait before reconnecting
205                sleep(Duration::from_secs(reconnect_delay)).await;
206
207                // Set the connection state to reconnecting
208                *connection_state.lock().unwrap() = ConnectionState::Reconnecting;
209
210                // Send a reconnecting event
211                let event = Event {
212                    event_type: EventType::Connection,
213                    timestamp: chrono::Utc::now(),
214                    data: crate::streaming::events::EventData::Connection(ConnectionStatus {
215                        status: ConnectionState::Reconnecting,
216                        connection_id: None,
217                        message: Some("Reconnecting...".to_string()),
218                    }),
219                };
220
221                let _ = tx.send(event).await;
222            }
223        });
224
225        Ok(rx)
226    }
227
228    /// Disconnect from the WebSocket server.
229    pub async fn disconnect(&mut self) -> WebullResult<()> {
230        // Set the connection state to disconnected
231        *self.connection_state.lock().unwrap() = ConnectionState::Disconnected;
232
233        // Reset reconnect attempts
234        *self.reconnect_attempts.lock().unwrap() = self.max_reconnect_attempts + 1;
235
236        Ok(())
237    }
238
239    /// Subscribe to a topic.
240    pub async fn subscribe(&self, request: SubscriptionRequest) -> WebullResult<()> {
241        // Check if we're connected
242        if *self.connection_state.lock().unwrap() != ConnectionState::Connected {
243            return Err(WebullError::InvalidRequest("Not connected to WebSocket server".to_string()));
244        }
245
246        // Send the subscription request
247        let message = json!({
248            "action": "SUBSCRIBE",
249            "request": request,
250        });
251
252        // Send the message
253        if let Some(tx) = &self.event_sender {
254            let _message_str = to_json(&message)?;
255
256            // Create a heartbeat event
257            let event = Event {
258                event_type: EventType::Heartbeat,
259                timestamp: chrono::Utc::now(),
260                data: crate::streaming::events::EventData::Heartbeat(HeartbeatEvent {
261                    id: Uuid::new_v4().to_string(),
262                }),
263            };
264
265            tx.send(event).await.map_err(|e| WebullError::InvalidRequest(format!("Failed to send message: {}", e)))?;
266        }
267
268        Ok(())
269    }
270
271    /// Unsubscribe from a topic.
272    pub async fn unsubscribe(&self, request: UnsubscriptionRequest) -> WebullResult<()> {
273        // Check if we're connected
274        if *self.connection_state.lock().unwrap() != ConnectionState::Connected {
275            return Err(WebullError::InvalidRequest("Not connected to WebSocket server".to_string()));
276        }
277
278        // Send the unsubscription request
279        let message = json!({
280            "action": "UNSUBSCRIBE",
281            "request": request,
282        });
283
284        // Send the message
285        if let Some(tx) = &self.event_sender {
286            let _message_str = to_json(&message)?;
287
288            // Create a heartbeat event
289            let event = Event {
290                event_type: EventType::Heartbeat,
291                timestamp: chrono::Utc::now(),
292                data: crate::streaming::events::EventData::Heartbeat(HeartbeatEvent {
293                    id: Uuid::new_v4().to_string(),
294                }),
295            };
296
297            tx.send(event).await.map_err(|e| WebullError::InvalidRequest(format!("Failed to send message: {}", e)))?;
298        }
299
300        Ok(())
301    }
302
303    /// Connect to the WebSocket server.
304    async fn connect_websocket(base_url: &str, token: &AccessToken) -> WebullResult<WebSocketStream<MaybeTlsStream<TcpStream>>> {
305        // Create the WebSocket URL
306        let ws_url = format!("{}/ws", base_url.replace("http", "ws"));
307        let url = Url::parse(&ws_url).map_err(|e| WebullError::InvalidRequest(format!("Invalid WebSocket URL: {}", e)))?;
308
309        // Create the request headers
310        let mut headers = HeaderMap::new();
311        headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {}", token.token)).unwrap());
312
313        // Connect to the WebSocket server
314        let (ws_stream, _) = connect_async(url).await.map_err(|e| WebullError::InvalidRequest(format!("WebSocket connection error: {}", e)))?;
315
316        Ok(ws_stream)
317    }
318
319    /// Handle the WebSocket connection.
320    async fn handle_websocket(
321        mut ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
322        tx: Sender<Event>,
323        last_heartbeat: Arc<Mutex<Instant>>,
324        heartbeat_interval: u64,
325    ) -> WebullResult<()> {
326        // Start the heartbeat task
327        let tx_clone = tx.clone();
328        let last_heartbeat_clone = last_heartbeat.clone();
329
330        tokio::spawn(async move {
331            loop {
332                // Sleep for the heartbeat interval
333                sleep(Duration::from_secs(heartbeat_interval)).await;
334
335                // Check if we need to send a heartbeat
336                let now = Instant::now();
337                let last = *last_heartbeat_clone.lock().unwrap();
338
339                if now.duration_since(last).as_secs() >= heartbeat_interval {
340                    // Create a heartbeat message
341                    let heartbeat = json!({
342                        "type": "HEARTBEAT",
343                        "id": Uuid::new_v4().to_string(),
344                    });
345
346                    // Send the heartbeat message
347                    let _message = Message::Text(to_json(&heartbeat).unwrap());
348
349                    // Create a heartbeat event
350                    let event = Event {
351                        event_type: EventType::Heartbeat,
352                        timestamp: chrono::Utc::now(),
353                        data: crate::streaming::events::EventData::Heartbeat(HeartbeatEvent {
354                            id: Uuid::new_v4().to_string(),
355                        }),
356                    };
357
358                    // Send the heartbeat event
359                    if tx_clone.send(event).await.is_err() {
360                        // Channel closed, exit the task
361                        break;
362                    }
363
364                    // Update the last heartbeat time
365                    *last_heartbeat_clone.lock().unwrap() = now;
366                }
367            }
368        });
369
370        // Handle incoming messages
371        while let Some(message) = ws_stream.next().await {
372            match message {
373                Ok(Message::Text(text)) => {
374                    // Parse the message
375                    match from_json::<Event>(&text) {
376                        Ok(event) => {
377                            // Send the event
378                            if tx.send(event).await.is_err() {
379                                // Channel closed, exit the loop
380                                break;
381                            }
382                        }
383                        Err(e) => {
384                            // Send an error event
385                            let event = Event {
386                                event_type: EventType::Error,
387                                timestamp: chrono::Utc::now(),
388                                data: crate::streaming::events::EventData::Error(ErrorEvent {
389                                    code: "PARSE_ERROR".to_string(),
390                                    message: format!("Failed to parse message: {}", e),
391                                }),
392                            };
393
394                            if tx.send(event).await.is_err() {
395                                // Channel closed, exit the loop
396                                break;
397                            }
398                        }
399                    }
400                }
401                Ok(Message::Binary(_)) => {
402                    // Ignore binary messages
403                }
404                Ok(Message::Ping(data)) => {
405                    // Respond with a pong
406                    if let Err(e) = ws_stream.send(Message::Pong(data)).await {
407                        // Send an error event
408                        let event = Event {
409                            event_type: EventType::Error,
410                            timestamp: chrono::Utc::now(),
411                            data: crate::streaming::events::EventData::Error(ErrorEvent {
412                                code: "PONG_ERROR".to_string(),
413                                message: format!("Failed to send pong: {}", e),
414                            }),
415                        };
416
417                        if tx.send(event).await.is_err() {
418                            // Channel closed, exit the loop
419                            break;
420                        }
421                    }
422
423                    // Update the last heartbeat time
424                    *last_heartbeat.lock().unwrap() = Instant::now();
425                }
426                Ok(Message::Pong(_)) => {
427                    // Update the last heartbeat time
428                    *last_heartbeat.lock().unwrap() = Instant::now();
429                }
430                Ok(Message::Close(_)) => {
431                    // Connection closed
432                    break;
433                },
434                Ok(Message::Frame(_)) => {
435                    // Ignore frame messages
436                }
437                Err(e) => {
438                    // Send an error event
439                    let event = Event {
440                        event_type: EventType::Error,
441                        timestamp: chrono::Utc::now(),
442                        data: crate::streaming::events::EventData::Error(ErrorEvent {
443                            code: "WS_ERROR".to_string(),
444                            message: format!("WebSocket error: {}", e),
445                        }),
446                    };
447
448                    if tx.send(event).await.is_err() {
449                        // Channel closed, exit the loop
450                        break;
451                    }
452
453                    // Exit the loop on error
454                    break;
455                }
456            }
457        }
458
459        Ok(())
460    }
461}