Skip to main content

fishpi_sdk/api/
ws.rs

//! WebSocket API 模块
//!
//! 这个模块提供了基础的 WebSocket 客户端功能,包括连接、监听事件、处理消息等。
//! 主要结构体是 `WebSocketClient`,用于管理 WebSocket 连接和事件监听。
//! 事件通过 `WsBaseEvent` 枚举表示,支持连接、断开和错误事件。
//!
//! # 主要组件
//!
//! - [`WebSocketClient`] - WebSocket 客户端结构体,负责连接和管理监听器。
//! - [`MessageHandler`] - WebSocket 消息处理器 trait,用于处理接收到的文本消息。
//! - [`WsBaseEvent`] - WebSocket 基础事件枚举,包装连接、断开和错误事件。
//! - [`WsEventType`] - 注册监听器时使用的事件类型(`Open`、`Close`、`Error`、`All`)。
//! - [`EventListener`] - 事件监听器类型别名,定义监听器函数的签名。
//! - [`EventBus`] - 通用事件总线,按事件类型把数据分发给监听器。
//! - [`WsConnection`] - 连接生命周期管理器,支持按 [`RetryPolicy`] 重连。
//! - [`ParsedMessageHandler`] - 先把文本消息解析为 JSON、再按事件类型分发的通用处理器。
//! - [`WebSocketError`] - WebSocket 错误类型,用于连接和操作错误。
//!
//! # 方法列表
//!
//! - [`WebSocketClient::connect`] - 创建并连接 WebSocket。
//! - [`WebSocketClient::add_listener`] - 添加事件监听器。
//! - [`WebSocketClient::on_open`] - 监听连接成功事件。
//! - [`WebSocketClient::on_close`] - 监听连接断开事件。
//! - [`WebSocketClient::on_error`] - 监听连接错误事件。
//! - [`WebSocketClient::remove_listener`] - 移除事件监听器。
//! - [`WebSocketClient::send_text`] - 发送文本消息。
//! - [`WebSocketClient::disconnect`] - 断开连接。
//! - [`build_ws_url`] - 构造带查询参数(自动编码)的 `wss://` URL。
//!
//! # 示例
//!
//! ```rust,no_run
//! use fishpi_sdk::api::ws::{MessageHandler, WebSocketClient};
//!
//! struct MyHandler;
//!
//! impl MessageHandler for MyHandler {
//!     fn handle_message(&self, msg: String) {
//!         println!("Received: {}", msg);
//!     }
//! }
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let handler = MyHandler;
//!
//!     // 连接 WebSocket
//!     let ws = WebSocketClient::connect("ws://example.com", handler).await?;
//!
//!     // 添加事件监听器
//!     ws.on_open(|| {
//!         println!("Connected!");
//!     }).await;
//!
//!     ws.on_close(|reason| {
//!         println!("Disconnected: {:?}", reason);
//!     }).await;
//!
//!     ws.on_error(|error| {
//!         println!("Error: {}", error);
//!     }).await;
//!
//!     // 断开连接
//!     ws.disconnect();
//!
//!     Ok(())
//! }
//! ```
//!
//! # 注意事项
//!
//! - 连接前需要有效的 WebSocket URL。
//! - 监听器函数必须是 `Send + Sync + 'static`,以支持异步环境。
//! - `MessageHandler` 只会收到文本消息;二进制、Ping/Pong 等其他类型的消息会被忽略。
//! - 调用 `disconnect` 或丢弃客户端后,后台任务会停止,连接资源随之释放。
//! - 事件监听器按 [`WsEventType`] 注册:`Open`、`Close`、`Error`,以及接收全部事件的 `All`。
//! - `connect` 内部启动的后台任务会立即发出 `Open` 事件,因此在 `connect` 返回后才注册的
//!   `on_open` 监听器可能收不到该事件。
//! - 连接失败或其他操作错误通过 [`WebSocketError`] 返回。
73
74use futures_util::{SinkExt, StreamExt};
75use serde_json::Value;
76use std::collections::HashMap;
77use std::hash::Hash;
78use std::sync::Arc;
79use std::time::Duration;
80use tokio::sync::Mutex;
81use tokio::time::sleep;
82use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
83use tokio_util::sync::CancellationToken;
84use url::Url;
85
86use crate::utils::error::Error;
87
88/// WebSocket 错误类型
89#[derive(Debug, thiserror::Error)]
90pub enum WebSocketError {
91    #[error("连接失败: {0}")]
92    ConnectionFailed(String),
93    #[error("其他错误: {0}")]
94    Other(String),
95}
96
97/// 基础 WebSocket 事件
98#[derive(Debug, Clone)]
99pub enum WsBaseEvent {
100    Open,
101    Close(Option<String>),
102    Error(String),
103}
104
105/// WebSocket 事件类型枚举
106#[derive(Debug, Clone, PartialEq, Eq, Hash)]
107pub enum WsEventType {
108    Open,
109    Close,
110    Error,
111    All,
112}
113
114/// 事件监听器类型
115pub type EventListener = Arc<dyn Fn(WsBaseEvent) + Send + Sync + 'static>;
116pub type TypedListener<D> = Arc<dyn Fn(D) + Send + Sync + 'static>;
117pub type WsLogHook = Arc<dyn Fn(&str) + Send + Sync + 'static>;
118
119/// 自动重连策略
120#[derive(Clone, Debug)]
121pub struct RetryPolicy {
122    /// 最大尝试次数(包含首次)
123    pub max_attempts: u32,
124    /// 首次重试延迟
125    pub initial_delay: Duration,
126    /// 最大重试延迟
127    pub max_delay: Duration,
128    /// 退避倍率(每次失败后 delay *= backoff_factor)
129    pub backoff_factor: f64,
130}
131
132impl Default for RetryPolicy {
133    fn default() -> Self {
134        Self {
135            max_attempts: 3,
136            initial_delay: Duration::from_millis(400),
137            max_delay: Duration::from_secs(8),
138            backoff_factor: 2.0,
139        }
140    }
141}
142
143/// 通用事件总线
144#[derive(Clone)]
145pub struct EventBus<E, D>
146where
147    E: Eq + Hash + Clone,
148{
149    listeners: Arc<Mutex<HashMap<E, Vec<TypedListener<D>>>>>,
150}
151
152impl<E, D> Default for EventBus<E, D>
153where
154    E: Eq + Hash + Clone,
155{
156    fn default() -> Self {
157        Self {
158            listeners: Arc::new(Mutex::new(HashMap::new())),
159        }
160    }
161}
162
163impl<E, D> EventBus<E, D>
164where
165    E: Eq + Hash + Clone + Send + Sync + 'static,
166    D: Clone + Send + 'static,
167{
168    pub fn new() -> Self {
169        Self::default()
170    }
171
172    pub async fn add_listener<F>(&self, event: E, listener: F)
173    where
174        F: Fn(D) + Send + Sync + 'static,
175    {
176        let mut listeners = self.listeners.lock().await;
177        listeners
178            .entry(event)
179            .or_insert_with(Vec::new)
180            .push(Arc::new(listener));
181    }
182
183    pub async fn remove_listener(&self, event: Option<E>) {
184        let mut listeners = self.listeners.lock().await;
185        match event {
186            Some(e) => {
187                listeners.remove(&e);
188            }
189            None => {
190                listeners.clear();
191            }
192        }
193    }
194
195    pub async fn emit(&self, event: &E, data: D, all_event: Option<&E>) {
196        let event_listeners: Vec<TypedListener<D>> = {
197            let listeners_guard = self.listeners.lock().await;
198            listeners_guard.get(event).cloned().unwrap_or_default()
199        };
200
201        for listener in event_listeners {
202            let data = data.clone();
203            tokio::spawn(async move { listener(data) });
204        }
205
206        if let Some(all) = all_event {
207            if all == event {
208                return;
209            }
210
211            let all_listeners: Vec<TypedListener<D>> = {
212                let listeners_guard = self.listeners.lock().await;
213                listeners_guard.get(all).cloned().unwrap_or_default()
214            };
215
216            for listener in all_listeners {
217                let data = data.clone();
218                tokio::spawn(async move { listener(data) });
219            }
220        }
221    }
222}
223
224/// WebSocket 消息处理器 trait
225pub trait MessageHandler: Send + Sync {
226    /// 处理接收到的文本消息
227    fn handle_message(&self, msg: String);
228}
229
230/// 基础 WebSocket 客户端
231pub struct WebSocketClient {
232    listeners: EventBus<WsEventType, WsBaseEvent>,
233    cancel_token: CancellationToken,
234    outbound_tx: tokio::sync::mpsc::UnboundedSender<Message>,
235    _handle: tokio::task::JoinHandle<()>,
236}
237
238/// 构造带查询参数的 WebSocket URL,自动进行 query 编码
239pub fn build_ws_url(
240    domain: &str,
241    path: &str,
242    params: &[(&str, String)],
243) -> Result<String, WebSocketError> {
244    let mut url = Url::parse(&format!(
245        "wss://{}/{}",
246        domain,
247        path.trim_start_matches('/')
248    ))
249    .map_err(|e| WebSocketError::Other(format!("invalid ws url: {}", e)))?;
250
251    {
252        let mut query = url.query_pairs_mut();
253        for (k, v) in params {
254            query.append_pair(k, v);
255        }
256    }
257
258    Ok(url.to_string())
259}
260
261impl WebSocketClient {
262    /// 创建并连接 WebSocket
263    pub async fn connect<H>(url: &str, message_handler: H) -> Result<Self, WebSocketError>
264    where
265        H: MessageHandler + 'static,
266    {
267        let listeners = EventBus::<WsEventType, WsBaseEvent>::new();
268        let cancel_token = CancellationToken::new();
269
270        let (ws_stream, _) = connect_async(url)
271            .await
272            .map_err(|e| WebSocketError::ConnectionFailed(e.to_string()))?;
273
274        let (mut write, mut read) = ws_stream.split();
275        let (outbound_tx, mut outbound_rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
276
277        let listeners_clone = listeners.clone();
278        let cancel = cancel_token.clone();
279
280        // 主接收任务
281        let handle = tokio::spawn(async move {
282            tokio::select! {
283                _ = cancel.cancelled() => {}
284                _ = async {
285                    // 发送 Open 事件
286                    listeners_clone
287                        .emit(&WsEventType::Open, WsBaseEvent::Open, Some(&WsEventType::All))
288                        .await;
289
290                    loop {
291                        tokio::select! {
292                            _ = cancel.cancelled() => {
293                                break;
294                            }
295                            outbound = outbound_rx.recv() => {
296                                match outbound {
297                                    Some(msg) => {
298                                        if let Err(e) = write.send(msg).await {
299                                            listeners_clone
300                                                .emit(
301                                                    &WsEventType::Error,
302                                                    WsBaseEvent::Error(e.to_string()),
303                                                    Some(&WsEventType::All),
304                                                )
305                                                .await;
306                                            break;
307                                        }
308                                    }
309                                    None => break,
310                                }
311                            }
312                            incoming = read.next() => {
313                                match incoming {
314                                    Some(Ok(Message::Text(text))) => {
315                                        message_handler.handle_message(text.to_string());
316                                    }
317                                    Some(Ok(Message::Close(frame))) => {
318                                        let reason = frame.map(|f| f.reason.to_string());
319                                        listeners_clone
320                                            .emit(
321                                                &WsEventType::Close,
322                                                WsBaseEvent::Close(reason),
323                                                Some(&WsEventType::All),
324                                            )
325                                            .await;
326                                        break;
327                                    }
328                                    Some(Err(e)) => {
329                                        listeners_clone
330                                            .emit(
331                                                &WsEventType::Error,
332                                                WsBaseEvent::Error(e.to_string()),
333                                                Some(&WsEventType::All),
334                                            )
335                                            .await;
336                                        break;
337                                    }
338                                    _ => {}
339                                }
340                            }
341                        }
342                    }
343                } => {}
344            }
345        });
346
347        Ok(Self {
348            listeners,
349            cancel_token,
350            outbound_tx,
351            _handle: handle,
352        })
353    }
354
355    /// 添加事件监听器
356    pub async fn add_listener<F>(&self, event: WsEventType, listener: F)
357    where
358        F: Fn(WsBaseEvent) + Send + Sync + 'static,
359    {
360        self.listeners.add_listener(event, listener).await;
361    }
362
363    /// 监听 open 事件
364    pub async fn on_open<F>(&self, listener: F)
365    where
366        F: Fn() + Send + Sync + 'static,
367    {
368        self.add_listener(WsEventType::Open, move |_| listener())
369            .await;
370    }
371
372    /// 监听 close 事件
373    pub async fn on_close<F>(&self, listener: F)
374    where
375        F: Fn(Option<String>) + Send + Sync + 'static,
376    {
377        self.add_listener(WsEventType::Close, move |event| {
378            if let WsBaseEvent::Close(reason) = event {
379                listener(reason);
380            }
381        })
382        .await;
383    }
384
385    /// 监听 error 事件
386    pub async fn on_error<F>(&self, listener: F)
387    where
388        F: Fn(String) + Send + Sync + 'static,
389    {
390        self.add_listener(WsEventType::Error, move |event| {
391            if let WsBaseEvent::Error(error) = event {
392                listener(error);
393            }
394        })
395        .await;
396    }
397
398    /// 移除监听器
399    pub async fn remove_listener(&self, event: Option<WsEventType>) {
400        self.listeners.remove_listener(event).await;
401    }
402
403    /// 断开连接
404    pub fn disconnect(&self) {
405        self.cancel_token.cancel();
406    }
407
408    /// 发送文本消息
409    pub fn send_text(&self, text: &str) -> Result<(), WebSocketError> {
410        self.outbound_tx
411            .send(Message::Text(text.to_string().into()))
412            .map_err(|e| WebSocketError::Other(format!("send message failed: {}", e)))
413    }
414}
415
416impl Drop for WebSocketClient {
417    fn drop(&mut self) {
418        self.cancel_token.cancel();
419    }
420}
421
422/// 通用 WebSocket 连接生命周期管理器
423#[derive(Default)]
424pub struct WsConnection {
425    client: Option<WebSocketClient>,
426    retry_policy: RetryPolicy,
427    log_hook: Option<WsLogHook>,
428}
429
430impl WsConnection {
431    pub fn new() -> Self {
432        Self {
433            client: None,
434            retry_policy: RetryPolicy::default(),
435            log_hook: None,
436        }
437    }
438
439    pub fn is_connected(&self) -> bool {
440        self.client.is_some()
441    }
442
443    pub fn set_retry_policy(&mut self, policy: RetryPolicy) {
444        self.retry_policy = policy;
445    }
446
447    pub fn set_log_hook<F>(&mut self, hook: F)
448    where
449        F: Fn(&str) + Send + Sync + 'static,
450    {
451        self.log_hook = Some(Arc::new(hook));
452    }
453
454    pub fn set_log_hook_arc(&mut self, hook: WsLogHook) {
455        self.log_hook = Some(hook);
456    }
457
458    fn log(&self, message: &str) {
459        if let Some(hook) = &self.log_hook {
460            hook(message);
461        }
462    }
463
464    pub async fn connect<H>(
465        &mut self,
466        reload: bool,
467        url: &str,
468        message_handler: H,
469    ) -> Result<(), WebSocketError>
470    where
471        H: MessageHandler + 'static,
472    {
473        if self.client.is_some() {
474            if !reload {
475                return Ok(());
476            }
477            self.disconnect();
478        }
479
480        let ws = WebSocketClient::connect(url, message_handler).await?;
481        self.client = Some(ws);
482        Ok(())
483    }
484
485    pub async fn reconnect<H>(&mut self, url: &str, message_handler: H) -> Result<(), WebSocketError>
486    where
487        H: MessageHandler + Clone + 'static,
488    {
489        self.disconnect();
490
491        let attempts = self.retry_policy.max_attempts.max(1);
492        let mut delay = self.retry_policy.initial_delay;
493        let mut last_err: Option<WebSocketError> = None;
494
495        for attempt in 1..=attempts {
496            match WebSocketClient::connect(url, message_handler.clone()).await {
497                Ok(ws) => {
498                    self.client = Some(ws);
499                    self.log(&format!(
500                        "WebSocket reconnected on attempt {}/{}",
501                        attempt, attempts
502                    ));
503                    return Ok(());
504                }
505                Err(err) => {
506                    last_err = Some(err);
507                    if attempt >= attempts {
508                        break;
509                    }
510
511                    self.log(&format!(
512                        "WebSocket reconnect attempt {}/{} failed, retrying in {:?}",
513                        attempt, attempts, delay
514                    ));
515                    sleep(delay).await;
516
517                    let next = (delay.as_secs_f64() * self.retry_policy.backoff_factor)
518                        .max(self.retry_policy.initial_delay.as_secs_f64());
519                    delay = Duration::from_secs_f64(next.min(self.retry_policy.max_delay.as_secs_f64()));
520                }
521            }
522        }
523
524        Err(last_err.unwrap_or_else(|| WebSocketError::Other("reconnect failed".to_string())))
525    }
526
527    pub fn disconnect(&mut self) {
528        if let Some(ws) = self.client.take() {
529            ws.disconnect();
530        }
531    }
532
533    pub fn send_text(&self, text: &str) -> Result<(), WebSocketError> {
534        match &self.client {
535            Some(ws) => ws.send_text(text),
536            None => Err(WebSocketError::Other(
537                "websocket is not connected".to_string(),
538            )),
539        }
540    }
541}
542
543/// 通用“解析后分发”消息处理器
544#[derive(Clone)]
545pub struct ParsedMessageHandler<E, D>
546where
547    E: Eq + Hash + Clone + Send + Sync + 'static,
548    D: Clone + Send + 'static,
549{
550    emitter: EventBus<E, D>,
551    log_hook: Option<WsLogHook>,
552    parser: fn(&Value) -> Result<(E, D), Error>,
553    all_event: Option<E>,
554    error_context: &'static str,
555}
556
557impl<E, D> ParsedMessageHandler<E, D>
558where
559    E: Eq + Hash + Clone + Send + Sync + 'static,
560    D: Clone + Send + 'static,
561{
562    pub fn new(
563        parser: fn(&Value) -> Result<(E, D), Error>,
564        all_event: Option<E>,
565        error_context: &'static str,
566    ) -> Self {
567        Self {
568            emitter: EventBus::new(),
569            log_hook: None,
570            parser,
571            all_event,
572            error_context,
573        }
574    }
575
576    pub fn get_emitter(&self) -> EventBus<E, D> {
577        self.emitter.clone()
578    }
579
580    pub fn set_log_hook_arc(&mut self, hook: WsLogHook) {
581        self.log_hook = Some(hook);
582    }
583}
584
585impl<E, D> MessageHandler for ParsedMessageHandler<E, D>
586where
587    E: Eq + Hash + Clone + Send + Sync + 'static,
588    D: Clone + Send + 'static,
589{
590    fn handle_message(&self, text: String) {
591        if let Ok(json) = serde_json::from_str::<Value>(&text) {
592            let emitter = self.get_emitter();
593            let log_hook = self.log_hook.clone();
594            let parser = self.parser;
595            let all_event = self.all_event.clone();
596            let context = self.error_context;
597
598            tokio::spawn(async move {
599                match parser(&json) {
600                    Ok((event_type, event)) => {
601                        emitter.emit(&event_type, event, all_event.as_ref()).await;
602                    }
603                    Err(e) => {
604                        if let Some(hook) = log_hook {
605                            hook(&format!("Failed to parse {} message: {}", context, e));
606                        }
607                    }
608                }
609            });
610        }
611    }
612}
613
614#[cfg(test)]
615mod tests {
616    use super::{EventBus, RetryPolicy, WsEventType, build_ws_url};
617    use tokio::sync::mpsc;
618    use tokio::time::{Duration, timeout};
619
620    #[test]
621    fn retry_policy_defaults_are_reasonable() {
622        let p = RetryPolicy::default();
623        assert_eq!(p.max_attempts, 3);
624        assert_eq!(p.initial_delay, Duration::from_millis(400));
625        assert_eq!(p.max_delay, Duration::from_secs(8));
626        assert!((p.backoff_factor - 2.0).abs() < f64::EPSILON);
627    }
628
629    #[tokio::test]
630    async fn event_bus_emits_target_and_all() {
631        let bus = EventBus::<WsEventType, String>::new();
632        let (tx, mut rx) = mpsc::unbounded_channel::<String>();
633
634        let tx1 = tx.clone();
635        bus.add_listener(WsEventType::Open, move |msg| {
636            let _ = tx1.send(format!("open:{msg}"));
637        })
638        .await;
639
640        let tx2 = tx.clone();
641        bus.add_listener(WsEventType::All, move |msg| {
642            let _ = tx2.send(format!("all:{msg}"));
643        })
644        .await;
645
646        bus.emit(&WsEventType::Open, "hello".to_string(), Some(&WsEventType::All))
647            .await;
648
649        let first = timeout(Duration::from_secs(1), rx.recv())
650            .await
651            .expect("first recv timeout")
652            .expect("first message missing");
653        let second = timeout(Duration::from_secs(1), rx.recv())
654            .await
655            .expect("second recv timeout")
656            .expect("second message missing");
657
658        let got = [first, second];
659        assert!(got.iter().any(|s| s == "open:hello"));
660        assert!(got.iter().any(|s| s == "all:hello"));
661    }
662
663    #[test]
664    fn build_ws_url_encodes_query_params() {
665        let url = build_ws_url(
666            "fishpi.cn",
667            "chat-channel",
668            &[
669                ("apiKey", "token a+b".to_string()),
670                ("toUser", "alice/bob".to_string()),
671            ],
672        )
673        .expect("url build should succeed");
674
675        assert!(url.starts_with("wss://fishpi.cn/chat-channel?"));
676        assert!(url.contains("apiKey=token+a%2Bb"));
677        assert!(url.contains("toUser=alice%2Fbob"));
678    }
679}