use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

use futures_util::{SinkExt, StreamExt};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use serde_json::json;
use tokio::net::TcpStream;
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::time::sleep;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message, MaybeTlsStream, WebSocketStream};
use url::Url;
use uuid::Uuid;

use crate::auth::{AuthManager, AccessToken};
use crate::error::{WebullError, WebullResult};
use crate::streaming::events::{Event, EventType, ConnectionState, ConnectionStatus, ErrorEvent, HeartbeatEvent};
use crate::streaming::subscription::{SubscriptionRequest, UnsubscriptionRequest};
use crate::utils::serialization::{from_json, to_json};
17
18pub struct WebSocketClient {
20 base_url: String,
22
23 auth_manager: Arc<AuthManager>,
25
26 connection_state: Arc<Mutex<ConnectionState>>,
28
29 event_sender: Option<Sender<Event>>,
31
32 last_heartbeat: Arc<Mutex<Instant>>,
34
35 heartbeat_interval: u64,
37
38 reconnect_attempts: Arc<Mutex<u32>>,
40
41 max_reconnect_attempts: u32,
43
44 reconnect_delay: u64,
46}
47
48impl WebSocketClient {
49 pub fn new(base_url: String, auth_manager: Arc<AuthManager>) -> Self {
51 Self {
52 base_url,
53 auth_manager,
54 connection_state: Arc::new(Mutex::new(ConnectionState::Disconnected)),
55 event_sender: None,
56 last_heartbeat: Arc::new(Mutex::new(Instant::now())),
57 heartbeat_interval: 30,
58 reconnect_attempts: Arc::new(Mutex::new(0)),
59 max_reconnect_attempts: 5,
60 reconnect_delay: 5,
61 }
62 }
63
64 pub async fn connect(&mut self) -> WebullResult<Receiver<Event>> {
66 let (tx, rx) = mpsc::channel(100);
68 self.event_sender = Some(tx.clone());
69
70 *self.connection_state.lock().unwrap() = ConnectionState::Reconnecting;
72
73 *self.reconnect_attempts.lock().unwrap() = 0;
75
76 let base_url = self.base_url.clone();
78 let auth_manager = self.auth_manager.clone();
79 let connection_state = self.connection_state.clone();
80 let last_heartbeat = self.last_heartbeat.clone();
81 let heartbeat_interval = self.heartbeat_interval;
82 let reconnect_attempts = self.reconnect_attempts.clone();
83 let max_reconnect_attempts = self.max_reconnect_attempts;
84 let reconnect_delay = self.reconnect_delay;
85
86 tokio::spawn(async move {
87 loop {
88 let attempts = *reconnect_attempts.lock().unwrap();
90 if attempts > max_reconnect_attempts {
91 let event = Event {
93 event_type: EventType::Connection,
94 timestamp: chrono::Utc::now(),
95 data: crate::streaming::events::EventData::Connection(ConnectionStatus {
96 status: ConnectionState::Failed,
97 connection_id: None,
98 message: Some("Maximum reconnect attempts exceeded".to_string()),
99 }),
100 };
101
102 let _ = tx.send(event).await;
103
104 *connection_state.lock().unwrap() = ConnectionState::Failed;
106
107 break;
108 }
109
110 *reconnect_attempts.lock().unwrap() = attempts + 1;
112
113 let token = match auth_manager.get_token().await {
115 Ok(token) => token,
116 Err(e) => {
117 let event = Event {
119 event_type: EventType::Error,
120 timestamp: chrono::Utc::now(),
121 data: crate::streaming::events::EventData::Error(ErrorEvent {
122 code: "AUTH_ERROR".to_string(),
123 message: format!("Authentication error: {}", e),
124 }),
125 };
126
127 let _ = tx.send(event).await;
128
129 sleep(Duration::from_secs(reconnect_delay)).await;
131 continue;
132 }
133 };
134
135 match Self::connect_websocket(&base_url, &token).await {
137 Ok(ws_stream) => {
138 *connection_state.lock().unwrap() = ConnectionState::Connected;
140
141 *reconnect_attempts.lock().unwrap() = 0;
143
144 let connection_id = Uuid::new_v4().to_string();
146 let event = Event {
147 event_type: EventType::Connection,
148 timestamp: chrono::Utc::now(),
149 data: crate::streaming::events::EventData::Connection(ConnectionStatus {
150 status: ConnectionState::Connected,
151 connection_id: Some(connection_id.clone()),
152 message: Some("Connection established".to_string()),
153 }),
154 };
155
156 let _ = tx.send(event).await;
157
158 if let Err(e) = Self::handle_websocket(ws_stream, tx.clone(), last_heartbeat.clone(), heartbeat_interval).await {
160 let event = Event {
162 event_type: EventType::Error,
163 timestamp: chrono::Utc::now(),
164 data: crate::streaming::events::EventData::Error(ErrorEvent {
165 code: "WS_ERROR".to_string(),
166 message: format!("WebSocket error: {}", e),
167 }),
168 };
169
170 let _ = tx.send(event).await;
171 }
172
173 *connection_state.lock().unwrap() = ConnectionState::Disconnected;
175
176 let event = Event {
178 event_type: EventType::Connection,
179 timestamp: chrono::Utc::now(),
180 data: crate::streaming::events::EventData::Connection(ConnectionStatus {
181 status: ConnectionState::Disconnected,
182 connection_id: Some(connection_id),
183 message: Some("Connection closed".to_string()),
184 }),
185 };
186
187 let _ = tx.send(event).await;
188 }
189 Err(e) => {
190 let event = Event {
192 event_type: EventType::Error,
193 timestamp: chrono::Utc::now(),
194 data: crate::streaming::events::EventData::Error(ErrorEvent {
195 code: "WS_CONNECT_ERROR".to_string(),
196 message: format!("WebSocket connection error: {}", e),
197 }),
198 };
199
200 let _ = tx.send(event).await;
201 }
202 }
203
204 sleep(Duration::from_secs(reconnect_delay)).await;
206
207 *connection_state.lock().unwrap() = ConnectionState::Reconnecting;
209
210 let event = Event {
212 event_type: EventType::Connection,
213 timestamp: chrono::Utc::now(),
214 data: crate::streaming::events::EventData::Connection(ConnectionStatus {
215 status: ConnectionState::Reconnecting,
216 connection_id: None,
217 message: Some("Reconnecting...".to_string()),
218 }),
219 };
220
221 let _ = tx.send(event).await;
222 }
223 });
224
225 Ok(rx)
226 }
227
228 pub async fn disconnect(&mut self) -> WebullResult<()> {
230 *self.connection_state.lock().unwrap() = ConnectionState::Disconnected;
232
233 *self.reconnect_attempts.lock().unwrap() = self.max_reconnect_attempts + 1;
235
236 Ok(())
237 }
238
239 pub async fn subscribe(&self, request: SubscriptionRequest) -> WebullResult<()> {
241 if *self.connection_state.lock().unwrap() != ConnectionState::Connected {
243 return Err(WebullError::InvalidRequest("Not connected to WebSocket server".to_string()));
244 }
245
246 let message = json!({
248 "action": "SUBSCRIBE",
249 "request": request,
250 });
251
252 if let Some(tx) = &self.event_sender {
254 let _message_str = to_json(&message)?;
255
256 let event = Event {
258 event_type: EventType::Heartbeat,
259 timestamp: chrono::Utc::now(),
260 data: crate::streaming::events::EventData::Heartbeat(HeartbeatEvent {
261 id: Uuid::new_v4().to_string(),
262 }),
263 };
264
265 tx.send(event).await.map_err(|e| WebullError::InvalidRequest(format!("Failed to send message: {}", e)))?;
266 }
267
268 Ok(())
269 }
270
271 pub async fn unsubscribe(&self, request: UnsubscriptionRequest) -> WebullResult<()> {
273 if *self.connection_state.lock().unwrap() != ConnectionState::Connected {
275 return Err(WebullError::InvalidRequest("Not connected to WebSocket server".to_string()));
276 }
277
278 let message = json!({
280 "action": "UNSUBSCRIBE",
281 "request": request,
282 });
283
284 if let Some(tx) = &self.event_sender {
286 let _message_str = to_json(&message)?;
287
288 let event = Event {
290 event_type: EventType::Heartbeat,
291 timestamp: chrono::Utc::now(),
292 data: crate::streaming::events::EventData::Heartbeat(HeartbeatEvent {
293 id: Uuid::new_v4().to_string(),
294 }),
295 };
296
297 tx.send(event).await.map_err(|e| WebullError::InvalidRequest(format!("Failed to send message: {}", e)))?;
298 }
299
300 Ok(())
301 }
302
303 async fn connect_websocket(base_url: &str, token: &AccessToken) -> WebullResult<WebSocketStream<MaybeTlsStream<TcpStream>>> {
305 let ws_url = format!("{}/ws", base_url.replace("http", "ws"));
307 let url = Url::parse(&ws_url).map_err(|e| WebullError::InvalidRequest(format!("Invalid WebSocket URL: {}", e)))?;
308
309 let mut headers = HeaderMap::new();
311 headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {}", token.token)).unwrap());
312
313 let (ws_stream, _) = connect_async(url).await.map_err(|e| WebullError::InvalidRequest(format!("WebSocket connection error: {}", e)))?;
315
316 Ok(ws_stream)
317 }
318
319 async fn handle_websocket(
321 mut ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
322 tx: Sender<Event>,
323 last_heartbeat: Arc<Mutex<Instant>>,
324 heartbeat_interval: u64,
325 ) -> WebullResult<()> {
326 let tx_clone = tx.clone();
328 let last_heartbeat_clone = last_heartbeat.clone();
329
330 tokio::spawn(async move {
331 loop {
332 sleep(Duration::from_secs(heartbeat_interval)).await;
334
335 let now = Instant::now();
337 let last = *last_heartbeat_clone.lock().unwrap();
338
339 if now.duration_since(last).as_secs() >= heartbeat_interval {
340 let heartbeat = json!({
342 "type": "HEARTBEAT",
343 "id": Uuid::new_v4().to_string(),
344 });
345
346 let _message = Message::Text(to_json(&heartbeat).unwrap());
348
349 let event = Event {
351 event_type: EventType::Heartbeat,
352 timestamp: chrono::Utc::now(),
353 data: crate::streaming::events::EventData::Heartbeat(HeartbeatEvent {
354 id: Uuid::new_v4().to_string(),
355 }),
356 };
357
358 if tx_clone.send(event).await.is_err() {
360 break;
362 }
363
364 *last_heartbeat_clone.lock().unwrap() = now;
366 }
367 }
368 });
369
370 while let Some(message) = ws_stream.next().await {
372 match message {
373 Ok(Message::Text(text)) => {
374 match from_json::<Event>(&text) {
376 Ok(event) => {
377 if tx.send(event).await.is_err() {
379 break;
381 }
382 }
383 Err(e) => {
384 let event = Event {
386 event_type: EventType::Error,
387 timestamp: chrono::Utc::now(),
388 data: crate::streaming::events::EventData::Error(ErrorEvent {
389 code: "PARSE_ERROR".to_string(),
390 message: format!("Failed to parse message: {}", e),
391 }),
392 };
393
394 if tx.send(event).await.is_err() {
395 break;
397 }
398 }
399 }
400 }
401 Ok(Message::Binary(_)) => {
402 }
404 Ok(Message::Ping(data)) => {
405 if let Err(e) = ws_stream.send(Message::Pong(data)).await {
407 let event = Event {
409 event_type: EventType::Error,
410 timestamp: chrono::Utc::now(),
411 data: crate::streaming::events::EventData::Error(ErrorEvent {
412 code: "PONG_ERROR".to_string(),
413 message: format!("Failed to send pong: {}", e),
414 }),
415 };
416
417 if tx.send(event).await.is_err() {
418 break;
420 }
421 }
422
423 *last_heartbeat.lock().unwrap() = Instant::now();
425 }
426 Ok(Message::Pong(_)) => {
427 *last_heartbeat.lock().unwrap() = Instant::now();
429 }
430 Ok(Message::Close(_)) => {
431 break;
433 },
434 Ok(Message::Frame(_)) => {
435 }
437 Err(e) => {
438 let event = Event {
440 event_type: EventType::Error,
441 timestamp: chrono::Utc::now(),
442 data: crate::streaming::events::EventData::Error(ErrorEvent {
443 code: "WS_ERROR".to_string(),
444 message: format!("WebSocket error: {}", e),
445 }),
446 };
447
448 if tx.send(event).await.is_err() {
449 break;
451 }
452
453 break;
455 }
456 }
457 }
458
459 Ok(())
460 }
461}