1use crate::datamanager::DataManager;
10use crate::errors::{Result, TqError};
11use futures::{SinkExt, StreamExt};
12use reqwest::header::HeaderMap;
13use serde::Serialize;
14use serde_json::Value;
15use std::sync::{Arc, RwLock};
16use std::time::Duration;
17use tokio::sync::Mutex;
18use tokio::time::sleep;
19use tracing::{debug, error, info, trace, warn};
20use yawc::frame::{FrameView, OpCode};
21
/// Lifecycle state of a [`TqWebsocket`] connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebSocketStatus {
    /// A connection attempt is in progress (`init` has been called).
    Connecting,
    /// The socket is connected; only in this state does `send` write directly.
    Open,
    /// `close` has been requested and teardown is in progress.
    Closing,
    /// No connection is active (initial state, and after close / end of stream).
    Closed,
}
34
35#[derive(Debug, Clone)]
37pub struct WebSocketConfig {
38 pub headers: HeaderMap,
40 pub reconnect_interval: Duration,
42 pub reconnect_max_times: usize,
44}
45
46impl Default for WebSocketConfig {
47 fn default() -> Self {
48 WebSocketConfig {
49 headers: HeaderMap::new(),
50 reconnect_interval: Duration::from_secs(3),
51 reconnect_max_times: 2,
52 }
53 }
54}
55
56pub struct TqWebsocket {
60 url: String,
61 config: WebSocketConfig,
62 status: Arc<RwLock<WebSocketStatus>>,
63 queue: Arc<Mutex<Vec<String>>>,
64 reconnect_times: Arc<RwLock<usize>>,
65 should_reconnect: Arc<RwLock<bool>>,
66
67 ws: Arc<Mutex<Option<yawc::WebSocket>>>,
69
70 on_message: Arc<RwLock<Option<Box<dyn Fn(Value) + Send + Sync>>>>,
72 on_open: Arc<RwLock<Option<Box<dyn Fn() + Send + Sync>>>>,
73 on_close: Arc<RwLock<Option<Box<dyn Fn() + Send + Sync>>>>,
74 on_error: Arc<RwLock<Option<Box<dyn Fn(String) + Send + Sync>>>>,
75}
76
77impl TqWebsocket {
78 pub fn new(url: String, config: WebSocketConfig) -> Self {
80 TqWebsocket {
81 url,
82 config,
83 status: Arc::new(RwLock::new(WebSocketStatus::Closed)),
84 queue: Arc::new(Mutex::new(Vec::new())),
85 reconnect_times: Arc::new(RwLock::new(0)),
86 should_reconnect: Arc::new(RwLock::new(true)),
87 ws: Arc::new(Mutex::new(None)),
88 on_message: Arc::new(RwLock::new(None)),
89 on_open: Arc::new(RwLock::new(None)),
90 on_close: Arc::new(RwLock::new(None)),
91 on_error: Arc::new(RwLock::new(None)),
92 }
93 }
94
95 pub async fn init(&self, is_reconnection: bool) -> Result<()> {
97 info!(
98 "正在连接 WebSocket: {} (重连: {})",
99 self.url, is_reconnection
100 );
101 *self.status.write().unwrap() = WebSocketStatus::Connecting;
102
103 let options = yawc::Options::default()
105 .client_no_context_takeover()
106 .server_no_context_takeover();
107
108 let parsed_url = url::Url::parse(&self.url)
109 .map_err(|e| TqError::WebSocketError(format!("Invalid URL: {}", e)))?;
110
111 let mut http_builder = yawc::HttpRequestBuilder::new();
113 for (key, value) in self.config.headers.iter() {
114 if let Ok(value_str) = value.to_str() {
115 http_builder = http_builder.header(key.as_str(), value_str);
116 }
117 }
118
119 let ws = match yawc::WebSocket::connect(parsed_url)
121 .with_options(options)
122 .with_request(http_builder)
123 .await
124 {
125 Ok(ws) => ws,
126 Err(e) => {
127 error!(url = %self.url, error = %e, "WebSocket 连接失败");
128
129 if let Some(callback) = self.on_error.read().unwrap().as_ref() {
131 callback(format!("Connection failed: {}", e));
132 }
133
134 if *self.should_reconnect.read().unwrap() {
136 self.handle_reconnect().await;
137 }
138
139 return Err(TqError::WebSocketError(format!("Connection failed: {}", e)));
140 }
141 };
142
143 {
145 let mut ws_guard = self.ws.lock().await;
146 *ws_guard = Some(ws);
147 }
148
149 *self.status.write().unwrap() = WebSocketStatus::Open;
150
151 if !is_reconnection {
153 *self.reconnect_times.write().unwrap() = 0;
154 }
155
156 if let Some(callback) = self.on_open.read().unwrap().as_ref() {
158 callback();
159 }
160
161 self.start_receive_loop().await;
163 self.flush_queue().await;
165
166 info!("WebSocket 连接成功");
167 Ok(())
168 }
169
170 pub async fn send<T: Serialize>(&self, obj: &T) -> Result<()> {
172 let json_str = serde_json::to_string(obj)?;
173
174 if self.is_ready() {
175 debug!("WebSocket 发送消息: {}", json_str);
176
177 let mut ws_guard = self.ws.lock().await;
178 if let Some(ws) = ws_guard.as_mut() {
179 let frame = FrameView::text(json_str.into_bytes());
181 match ws.send(frame).await {
182 Ok(_) => {
183 Ok(())
184 }
185 Err(e) => {
186 error!("消息发送失败: {}", e);
187 Err(TqError::WebSocketError(format!("Send failed: {}", e)))
188 }
189 }
190 } else {
191 debug!("WebSocket 连接不存在,消息加入队列");
192 drop(ws_guard);
193 self.queue.lock().await.push(json_str);
194 Ok(())
195 }
196 } else {
197 debug!("WebSocket 未就绪,消息加入队列: {}", json_str);
198 self.queue.lock().await.push(json_str);
199 Ok(())
200 }
201 }
202
203 pub fn is_ready(&self) -> bool {
205 *self.status.read().unwrap() == WebSocketStatus::Open
206 }
207
208 pub async fn close(&self) -> Result<()> {
210 info!("正在关闭 WebSocket 连接");
211 *self.should_reconnect.write().unwrap() = false;
212
213 *self.status.write().unwrap() = WebSocketStatus::Closing;
215
216 let mut ws_guard = self.ws.lock().await;
218 if let Some(ws) = ws_guard.take() {
219 drop(ws);
221 info!("WebSocket 连接已关闭");
222 }
223 drop(ws_guard);
224
225 tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
227
228 *self.status.write().unwrap() = WebSocketStatus::Closed;
229
230 if let Some(callback) = self.on_close.read().unwrap().as_ref() {
232 callback();
233 }
234
235 Ok(())
236 }
237
238 async fn flush_queue(&self) {
240 let mut queue = self.queue.lock().await;
241 if queue.is_empty() {
242 return;
243 }
244
245 debug!("发送队列中的 {} 条消息", queue.len());
246
247 let mut ws_guard = self.ws.lock().await;
248 if let Some(ws) = ws_guard.as_mut() {
249 for msg in queue.drain(..) {
250 debug!("发送队列消息: {}", msg);
251 let frame = FrameView::text(msg.into_bytes());
252 match ws.send(frame).await {
253 Ok(_) => {
254 debug!("队列消息发送成功");
255 }
256 Err(e) => {
257 error!("队列消息发送失败: {}", e);
258 }
259 }
260 }
261 }
262 }
263
264 pub fn on_message<F>(&self, callback: F)
266 where
267 F: Fn(Value) + Send + Sync + 'static,
268 {
269 *self.on_message.write().unwrap() = Some(Box::new(callback));
270 }
271
272 pub fn on_open<F>(&self, callback: F)
274 where
275 F: Fn() + Send + Sync + 'static,
276 {
277 *self.on_open.write().unwrap() = Some(Box::new(callback));
278 }
279
280 pub fn on_close<F>(&self, callback: F)
282 where
283 F: Fn() + Send + Sync + 'static,
284 {
285 *self.on_close.write().unwrap() = Some(Box::new(callback));
286 }
287
288 pub fn on_error<F>(&self, callback: F)
290 where
291 F: Fn(String) + Send + Sync + 'static,
292 {
293 *self.on_error.write().unwrap() = Some(Box::new(callback));
294 }
295
296 async fn start_receive_loop(&self) {
298 let status = Arc::clone(&self.status);
299 let ws = Arc::clone(&self.ws);
300 let _on_message = Arc::clone(&self.on_message);
301 let _on_error = Arc::clone(&self.on_error);
302 let _on_close = Arc::clone(&self.on_close);
303 let _should_reconnect = Arc::clone(&self.should_reconnect);
304 let _url = self.url.clone();
305
306 tokio::spawn(async move {
307 debug!("启动 WebSocket 消息接收循环");
308
309 loop {
310 let current_status = *status.read().unwrap();
312 if current_status != WebSocketStatus::Open {
313 debug!("WebSocket 状态不是 Open,退出接收循环");
314 break;
315 }
316
317 let mut ws_guard = ws.lock().await;
319 if let Some(ws_instance) = ws_guard.as_mut() {
320 let timeout_duration = tokio::time::Duration::from_secs(1);
323 let next_result =
324 tokio::time::timeout(timeout_duration, ws_instance.next()).await;
325
326 match next_result {
327 Ok(Some(frame)) => {
328 match frame.opcode {
330 OpCode::Text | OpCode::Binary => {
331 match String::from_utf8(frame.payload.to_vec()) {
333 Ok(text) => {
334 debug!("WebSocket Recv Text: {}", text);
335
336 match serde_json::from_str::<Value>(&text) {
338 Ok(json_value) => {
339 if let Some(callback) =
341 _on_message.read().unwrap().as_ref()
342 {
343 callback(json_value);
344 }
345 }
346 Err(e) => {
347 warn!("解析 JSON 失败: {}", e);
348 }
349 }
350 }
351 Err(e) => {
352 warn!("消息不是有效的 UTF-8: {}", e);
353 }
354 }
355
356 let frame =
357 FrameView::text(r#"{"aid": "peek_message"}"#.as_bytes());
358 match ws_instance.send(frame).await {
359 Ok(_) => {
360 debug!("Websocket Send -> peek_message");
361 }
362 Err(e) => {
363 error!("Websocket Send `peek_message` failed: {}", e);
364 break;
365 }
366 }
367 }
368 OpCode::Close => {
369 info!("WebSocket 收到关闭帧");
370 *status.write().unwrap() = WebSocketStatus::Closed;
371
372 if let Some(callback) = _on_close.read().unwrap().as_ref() {
374 callback();
375 }
376
377 drop(ws_guard);
378 break;
379 }
380 OpCode::Ping => {
381 debug!("WebSocket 收到 Ping(自动处理)");
382 }
383 OpCode::Pong => {
384 debug!("WebSocket 收到 Pong");
385 }
386 OpCode::Continuation => {
387 debug!("WebSocket 收到 Continuation 帧");
388 }
389 }
390 }
391 Ok(None) => {
392 info!("WebSocket Stream 结束,连接已关闭");
394 *status.write().unwrap() = WebSocketStatus::Closed;
395
396 if let Some(callback) = _on_close.read().unwrap().as_ref() {
397 callback();
398 }
399
400 drop(ws_guard);
401 break;
402 }
403 Err(_) => {
404 trace!("WebSocket 接收超时,继续等待");
406 }
407 }
408 } else {
409 trace!("WebSocket 实例不存在,退出接收循环");
410 break;
411 }
412
413 drop(ws_guard);
414 }
415
416 info!("WebSocket 消息接收循环结束");
417 });
418 }
419
420 async fn handle_reconnect(&self) {
422 let (should_reconnect, times) = {
424 let mut reconnect_times = self.reconnect_times.write().unwrap();
425
426 if *reconnect_times >= self.config.reconnect_max_times {
427 error!(
428 "已达到最大重连次数 {},停止重连",
429 self.config.reconnect_max_times
430 );
431 return;
432 }
433
434 *reconnect_times += 1;
435 (true, *reconnect_times)
436 }; if should_reconnect {
439 info!(
440 "第 {} 次尝试重连(最多 {} 次)",
441 times, self.config.reconnect_max_times
442 );
443
444 sleep(self.config.reconnect_interval).await;
446
447 info!("重连等待完成,请外部调用 init(true) 进行重连");
450 }
451 }
452}
453
454pub struct TqQuoteWebsocket {
456 base: Arc<TqWebsocket>,
457 _dm: Arc<DataManager>,
458 subscribe_quote: Arc<RwLock<Option<Value>>>,
459 charts: Arc<RwLock<std::collections::HashMap<String, Value>>>,
460}
461
462impl TqQuoteWebsocket {
463 pub fn new(url: String, dm: Arc<DataManager>, config: WebSocketConfig) -> Self {
465 let base = Arc::new(TqWebsocket::new(url, config));
466 let dm_clone = Arc::clone(&dm);
467 let subscribe_quote: Arc<RwLock<Option<Value>>> = Arc::new(RwLock::new(None));
468 let charts: Arc<RwLock<std::collections::HashMap<String, Value>>> =
469 Arc::new(RwLock::new(std::collections::HashMap::new()));
470
471 base.on_message({
473 let dm = Arc::clone(&dm_clone);
474 move |data: Value| {
475 if let Some(aid) = data.get("aid").and_then(|v| v.as_str()) {
476 if aid == "rtn_data" {
477 if let Some(payload) = data.get("data") {
478 dm.merge_data(payload.clone(), true, true);
479 }
480 }
481 }
482 }
483 });
484
485 {
487 let subscribe_quote_clone = Arc::clone(&subscribe_quote);
488 let charts_clone = Arc::clone(&charts);
489 let base_clone = Arc::clone(&base);
490
491 base.on_open(move || {
492 if let Some(sub) = subscribe_quote_clone.read().unwrap().as_ref() {
494 debug!("重连后重发订阅请求");
495 let base_for_send = Arc::clone(&base_clone);
496 let sub_clone = sub.clone();
497 tokio::spawn(async move {
498 let _ = base_for_send.send(&sub_clone).await;
499 });
500 }
501
502 let charts_guard = charts_clone.read().unwrap();
504 for (chart_id, chart) in charts_guard.iter() {
505 if let Some(view_width) = chart.get("view_width").and_then(|v| v.as_f64()) {
506 if view_width > 0.0 {
507 debug!("重连后重发图表请求: {}", chart_id);
508 let base_for_send = Arc::clone(&base_clone);
509 let chart_clone = chart.clone();
510 tokio::spawn(async move {
511 let _ = base_for_send.send(&chart_clone).await;
512 });
513 }
514 }
515 }
516 });
517 }
518
519 TqQuoteWebsocket {
520 base,
521 _dm: dm_clone,
522 subscribe_quote,
523 charts,
524 }
525 }
526
527 pub async fn init(&self, is_reconnection: bool) -> Result<()> {
529 self.base.init(is_reconnection).await
530 }
531
532 pub async fn send<T: Serialize>(&self, obj: &T) -> Result<()> {
534 let json_str = serde_json::to_string(obj)?;
536 let value: Value = serde_json::from_str(&json_str)?;
537
538 if let Some(aid) = value.get("aid").and_then(|v| v.as_str()) {
539 match aid {
540 "subscribe_quote" => {
541 let mut should_send = false;
543 let mut subscribe_guard = self.subscribe_quote.write().unwrap();
544
545 if let Some(old_sub) = subscribe_guard.as_ref() {
546 let old_list = old_sub.get("ins_list");
548 let new_list = value.get("ins_list");
549
550 if old_list != new_list {
551 debug!("订阅列表变化,更新订阅");
552 *subscribe_guard = Some(value.clone());
553 should_send = true;
554 } else {
555 debug!("订阅列表未变化,跳过");
556 }
557 } else {
558 debug!("首次订阅");
559 *subscribe_guard = Some(value.clone());
560 should_send = true;
561 }
562
563 drop(subscribe_guard);
564
565 if should_send {
566 return self.base.send(&value).await;
567 } else {
568 return Ok(());
569 }
570 }
571 "set_chart" => {
572 if let Some(chart_id) = value.get("chart_id").and_then(|v| v.as_str()) {
574 let mut charts_guard = self.charts.write().unwrap();
575
576 if let Some(view_width) = value.get("view_width").and_then(|v| v.as_f64()) {
577 if view_width == 0.0 {
578 trace!("删除图表: {}", chart_id);
579 charts_guard.remove(chart_id);
580 } else {
581 trace!("保存图表请求: {}", chart_id);
582 charts_guard.insert(chart_id.to_string(), value.clone());
583 }
584 }
585
586 drop(charts_guard);
587 return self.base.send(&value).await;
588 }
589 }
590 _ => {}
591 }
592 }
593
594 self.base.send(obj).await
596 }
597
598 pub fn is_ready(&self) -> bool {
600 self.base.is_ready()
601 }
602
603 pub async fn close(&self) -> Result<()> {
605 self.base.close().await
606 }
607}
608
609pub struct TqTradeWebsocket {
611 base: Arc<TqWebsocket>,
612 _dm: Arc<DataManager>,
613 req_login: Arc<RwLock<Option<Value>>>,
614 on_notify: Arc<RwLock<Option<Box<dyn Fn(crate::types::Notification) + Send + Sync>>>>,
615}
616
617impl TqTradeWebsocket {
618 pub fn new(url: String, dm: Arc<DataManager>, config: WebSocketConfig) -> Self {
620 let base = Arc::new(TqWebsocket::new(url, config));
621 let dm_clone = Arc::clone(&dm);
622 let req_login: Arc<RwLock<Option<Value>>> = Arc::new(RwLock::new(None));
623 let on_notify: Arc<RwLock<Option<Box<dyn Fn(crate::types::Notification) + Send + Sync>>>> =
624 Arc::new(RwLock::new(None));
625
626 {
628 let dm = Arc::clone(&dm_clone);
629 let on_notify_clone = Arc::clone(&on_notify);
630
631 base.on_message(move |data: Value| {
632 if let Some(aid) = data.get("aid").and_then(|v| v.as_str()) {
633 match aid {
634 "rtn_data" => {
635 if let Some(payload) = data.get("data") {
636 if let Some(array) = payload.as_array() {
638 let (notifies, cleaned_data) =
639 Self::separate_notifies(array.clone());
640 debug!("notifies: {:?}", notifies);
641
642 if let Some(callback) = on_notify_clone.read().unwrap().as_ref()
644 {
645 for notify in notifies {
646 callback(notify);
647 }
648 }
649
650 dm.merge_data(Value::Array(cleaned_data), true, true);
652 } else {
653 dm.merge_data(payload.clone(), true, true);
654 }
655 }
656 }
657 "rtn_brokers" => {
658 debug!("收到期货公司列表");
660 }
661 "qry_settlement_info" => {
662 if let (Some(settlement_info), Some(user_name), Some(trading_day)) = (
664 data.get("settlement_info").and_then(|v| v.as_str()),
665 data.get("user_name").and_then(|v| v.as_str()),
666 data.get("trading_day").and_then(|v| v.as_str()),
667 ) {
668 debug!(
669 "收到结算单: user={}, trading_day={}",
670 user_name, trading_day
671 );
672
673 let settlement = Self::parse_settlement_content(settlement_info);
675
676 let settlement_data = serde_json::json!({
678 "trade": {
679 user_name: {
680 "his_settlements": {
681 trading_day: settlement
682 }
683 }
684 }
685 });
686
687 dm.merge_data(settlement_data, true, true);
688 }
689 }
690 _ => {}
691 }
692 }
693 });
694 }
695
696 {
698 let req_login_clone = Arc::clone(&req_login);
699 let base_clone = Arc::clone(&base);
700
701 base.on_open(move || {
702 if let Some(login) = req_login_clone.read().unwrap().as_ref() {
703 debug!("重连后重发登录请求");
704 let base_for_send = Arc::clone(&base_clone);
705 let login_clone = login.clone();
706 tokio::spawn(async move {
707 let _ = base_for_send.send(&login_clone).await;
708 });
709 }
710 });
711 }
712
713 TqTradeWebsocket {
714 base,
715 _dm: dm_clone,
716 req_login,
717 on_notify,
718 }
719 }
720
721 fn separate_notifies(data: Vec<Value>) -> (Vec<crate::types::Notification>, Vec<Value>) {
725 let mut notifies = Vec::new();
726 let mut cleaned_data = Vec::new();
727
728 for mut item in data {
729 if let Some(obj) = item.as_object_mut() {
730 if let Some(notify_data) = obj.remove("notify") {
732 if let Some(notify_map) = notify_data.as_object() {
733 for (_key, notify_value) in notify_map {
734 if let Some(n) = notify_value.as_object() {
735 let notification = crate::types::Notification {
736 code: n
737 .get("code")
738 .and_then(|v| v.as_str())
739 .unwrap_or("")
740 .to_string(),
741 level: n
742 .get("level")
743 .and_then(|v| v.as_str())
744 .unwrap_or("")
745 .to_string(),
746 r#type: n
747 .get("type")
748 .and_then(|v| v.as_str())
749 .unwrap_or("")
750 .to_string(),
751 content: n
752 .get("content")
753 .and_then(|v| v.as_str())
754 .unwrap_or("")
755 .to_string(),
756 bid: n
757 .get("bid")
758 .and_then(|v| v.as_str())
759 .unwrap_or("")
760 .to_string(),
761 user_id: n
762 .get("user_id")
763 .and_then(|v| v.as_str())
764 .unwrap_or("")
765 .to_string(),
766 };
767 notifies.push(notification);
768 }
769 }
770 }
771 }
772 }
773
774 cleaned_data.push(item);
776 }
777
778 (notifies, cleaned_data)
779 }
780
781 fn parse_settlement_content(content: &str) -> serde_json::Value {
785 serde_json::json!({
786 "content": content,
787 "parsed": false })
789 }
790
791 pub fn on_notify<F>(&self, callback: F)
793 where
794 F: Fn(crate::types::Notification) + Send + Sync + 'static,
795 {
796 *self.on_notify.write().unwrap() = Some(Box::new(callback));
797 }
798
799 pub async fn init(&self, is_reconnection: bool) -> Result<()> {
801 self.base.init(is_reconnection).await
802 }
803
804 pub async fn send<T: Serialize>(&self, obj: &T) -> Result<()> {
806 let json_str = serde_json::to_string(obj)?;
808 let value: Value = serde_json::from_str(&json_str)?;
809
810 if let Some(aid) = value.get("aid").and_then(|v| v.as_str()) {
812 if aid == "req_login" {
813 debug!("记录登录请求 {:?}", value);
814 *self.req_login.write().unwrap() = Some(value.clone());
815 }
816 }
817
818 self.base.send(&value).await
820 }
821
822 pub fn is_ready(&self) -> bool {
824 self.base.is_ready()
825 }
826
827 pub async fn close(&self) -> Result<()> {
829 self.base.close().await
830 }
831}
832
833