1use futures_util::{SinkExt, StreamExt};
75use serde_json::Value;
76use std::collections::HashMap;
77use std::hash::Hash;
78use std::sync::Arc;
79use std::time::Duration;
80use tokio::sync::Mutex;
81use tokio::time::sleep;
82use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
83use tokio_util::sync::CancellationToken;
84use url::Url;
85
86use crate::utils::error::Error;
87
88#[derive(Debug, thiserror::Error)]
90pub enum WebSocketError {
91 #[error("连接失败: {0}")]
92 ConnectionFailed(String),
93 #[error("其他错误: {0}")]
94 Other(String),
95}
96
/// Connection-level event payload handed to lifecycle listeners.
#[derive(Clone, Debug)]
pub enum WsBaseEvent {
    /// The connection has been established.
    Open,
    /// The connection was closed; carries the close reason when one is known.
    Close(Option<String>),
    /// A transport error occurred; carries the error rendered as text.
    Error(String),
}
104
/// Key under which connection-level listeners are registered.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum WsEventType {
    /// Listeners interested in [`WsBaseEvent::Open`].
    Open,
    /// Listeners interested in [`WsBaseEvent::Close`].
    Close,
    /// Listeners interested in [`WsBaseEvent::Error`].
    Error,
    /// Catch-all key: the client passes it as the `all_event` of every
    /// emission, so listeners registered here see every lifecycle event.
    All,
}
113
114pub type EventListener = Arc<dyn Fn(WsBaseEvent) + Send + Sync + 'static>;
116pub type TypedListener<D> = Arc<dyn Fn(D) + Send + Sync + 'static>;
117pub type WsLogHook = Arc<dyn Fn(&str) + Send + Sync + 'static>;
118
/// Back-off settings used when re-establishing a connection.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Upper bound on connection attempts. `WsConnection::reconnect` treats a
    /// value of `0` as `1`.
    pub max_attempts: u32,
    /// Delay before the second attempt; also the lower bound for later delays.
    pub initial_delay: Duration,
    /// Upper bound applied to every delay computed after the first one.
    pub max_delay: Duration,
    /// Multiplier applied to the previous delay to obtain the next one.
    pub backoff_factor: f64,
}

impl Default for RetryPolicy {
    /// Three attempts, starting at 400 ms and doubling up to an 8 s cap.
    fn default() -> Self {
        let initial_delay = Duration::from_millis(400);
        let max_delay = Duration::from_secs(8);

        Self {
            max_attempts: 3,
            initial_delay,
            max_delay,
            backoff_factor: 2.0,
        }
    }
}
142
143#[derive(Clone)]
145pub struct EventBus<E, D>
146where
147 E: Eq + Hash + Clone,
148{
149 listeners: Arc<Mutex<HashMap<E, Vec<TypedListener<D>>>>>,
150}
151
152impl<E, D> Default for EventBus<E, D>
153where
154 E: Eq + Hash + Clone,
155{
156 fn default() -> Self {
157 Self {
158 listeners: Arc::new(Mutex::new(HashMap::new())),
159 }
160 }
161}
162
163impl<E, D> EventBus<E, D>
164where
165 E: Eq + Hash + Clone + Send + Sync + 'static,
166 D: Clone + Send + 'static,
167{
168 pub fn new() -> Self {
169 Self::default()
170 }
171
172 pub async fn add_listener<F>(&self, event: E, listener: F)
173 where
174 F: Fn(D) + Send + Sync + 'static,
175 {
176 let mut listeners = self.listeners.lock().await;
177 listeners
178 .entry(event)
179 .or_insert_with(Vec::new)
180 .push(Arc::new(listener));
181 }
182
183 pub async fn remove_listener(&self, event: Option<E>) {
184 let mut listeners = self.listeners.lock().await;
185 match event {
186 Some(e) => {
187 listeners.remove(&e);
188 }
189 None => {
190 listeners.clear();
191 }
192 }
193 }
194
195 pub async fn emit(&self, event: &E, data: D, all_event: Option<&E>) {
196 let event_listeners: Vec<TypedListener<D>> = {
197 let listeners_guard = self.listeners.lock().await;
198 listeners_guard.get(event).cloned().unwrap_or_default()
199 };
200
201 for listener in event_listeners {
202 let data = data.clone();
203 tokio::spawn(async move { listener(data) });
204 }
205
206 if let Some(all) = all_event {
207 if all == event {
208 return;
209 }
210
211 let all_listeners: Vec<TypedListener<D>> = {
212 let listeners_guard = self.listeners.lock().await;
213 listeners_guard.get(all).cloned().unwrap_or_default()
214 };
215
216 for listener in all_listeners {
217 let data = data.clone();
218 tokio::spawn(async move { listener(data) });
219 }
220 }
221 }
222}
223
/// Receiver for text frames read from a WebSocket connection.
///
/// Implementations must be `Send + Sync` because the handler is moved into the
/// connection's background task.
pub trait MessageHandler: Send + Sync {
    /// Called once for every text message received, with the message body.
    fn handle_message(&self, msg: String);
}
229
230pub struct WebSocketClient {
232 listeners: EventBus<WsEventType, WsBaseEvent>,
233 cancel_token: CancellationToken,
234 outbound_tx: tokio::sync::mpsc::UnboundedSender<Message>,
235 _handle: tokio::task::JoinHandle<()>,
236}
237
238pub fn build_ws_url(
240 domain: &str,
241 path: &str,
242 params: &[(&str, String)],
243) -> Result<String, WebSocketError> {
244 let mut url = Url::parse(&format!(
245 "wss://{}/{}",
246 domain,
247 path.trim_start_matches('/')
248 ))
249 .map_err(|e| WebSocketError::Other(format!("invalid ws url: {}", e)))?;
250
251 {
252 let mut query = url.query_pairs_mut();
253 for (k, v) in params {
254 query.append_pair(k, v);
255 }
256 }
257
258 Ok(url.to_string())
259}
260
261impl WebSocketClient {
262 pub async fn connect<H>(url: &str, message_handler: H) -> Result<Self, WebSocketError>
264 where
265 H: MessageHandler + 'static,
266 {
267 let listeners = EventBus::<WsEventType, WsBaseEvent>::new();
268 let cancel_token = CancellationToken::new();
269
270 let (ws_stream, _) = connect_async(url)
271 .await
272 .map_err(|e| WebSocketError::ConnectionFailed(e.to_string()))?;
273
274 let (mut write, mut read) = ws_stream.split();
275 let (outbound_tx, mut outbound_rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
276
277 let listeners_clone = listeners.clone();
278 let cancel = cancel_token.clone();
279
280 let handle = tokio::spawn(async move {
282 tokio::select! {
283 _ = cancel.cancelled() => {}
284 _ = async {
285 listeners_clone
287 .emit(&WsEventType::Open, WsBaseEvent::Open, Some(&WsEventType::All))
288 .await;
289
290 loop {
291 tokio::select! {
292 _ = cancel.cancelled() => {
293 break;
294 }
295 outbound = outbound_rx.recv() => {
296 match outbound {
297 Some(msg) => {
298 if let Err(e) = write.send(msg).await {
299 listeners_clone
300 .emit(
301 &WsEventType::Error,
302 WsBaseEvent::Error(e.to_string()),
303 Some(&WsEventType::All),
304 )
305 .await;
306 break;
307 }
308 }
309 None => break,
310 }
311 }
312 incoming = read.next() => {
313 match incoming {
314 Some(Ok(Message::Text(text))) => {
315 message_handler.handle_message(text.to_string());
316 }
317 Some(Ok(Message::Close(frame))) => {
318 let reason = frame.map(|f| f.reason.to_string());
319 listeners_clone
320 .emit(
321 &WsEventType::Close,
322 WsBaseEvent::Close(reason),
323 Some(&WsEventType::All),
324 )
325 .await;
326 break;
327 }
328 Some(Err(e)) => {
329 listeners_clone
330 .emit(
331 &WsEventType::Error,
332 WsBaseEvent::Error(e.to_string()),
333 Some(&WsEventType::All),
334 )
335 .await;
336 break;
337 }
338 _ => {}
339 }
340 }
341 }
342 }
343 } => {}
344 }
345 });
346
347 Ok(Self {
348 listeners,
349 cancel_token,
350 outbound_tx,
351 _handle: handle,
352 })
353 }
354
355 pub async fn add_listener<F>(&self, event: WsEventType, listener: F)
357 where
358 F: Fn(WsBaseEvent) + Send + Sync + 'static,
359 {
360 self.listeners.add_listener(event, listener).await;
361 }
362
363 pub async fn on_open<F>(&self, listener: F)
365 where
366 F: Fn() + Send + Sync + 'static,
367 {
368 self.add_listener(WsEventType::Open, move |_| listener())
369 .await;
370 }
371
372 pub async fn on_close<F>(&self, listener: F)
374 where
375 F: Fn(Option<String>) + Send + Sync + 'static,
376 {
377 self.add_listener(WsEventType::Close, move |event| {
378 if let WsBaseEvent::Close(reason) = event {
379 listener(reason);
380 }
381 })
382 .await;
383 }
384
385 pub async fn on_error<F>(&self, listener: F)
387 where
388 F: Fn(String) + Send + Sync + 'static,
389 {
390 self.add_listener(WsEventType::Error, move |event| {
391 if let WsBaseEvent::Error(error) = event {
392 listener(error);
393 }
394 })
395 .await;
396 }
397
398 pub async fn remove_listener(&self, event: Option<WsEventType>) {
400 self.listeners.remove_listener(event).await;
401 }
402
403 pub fn disconnect(&self) {
405 self.cancel_token.cancel();
406 }
407
408 pub fn send_text(&self, text: &str) -> Result<(), WebSocketError> {
410 self.outbound_tx
411 .send(Message::Text(text.to_string().into()))
412 .map_err(|e| WebSocketError::Other(format!("send message failed: {}", e)))
413 }
414}
415
416impl Drop for WebSocketClient {
417 fn drop(&mut self) {
418 self.cancel_token.cancel();
419 }
420}
421
422#[derive(Default)]
424pub struct WsConnection {
425 client: Option<WebSocketClient>,
426 retry_policy: RetryPolicy,
427 log_hook: Option<WsLogHook>,
428}
429
430impl WsConnection {
431 pub fn new() -> Self {
432 Self {
433 client: None,
434 retry_policy: RetryPolicy::default(),
435 log_hook: None,
436 }
437 }
438
439 pub fn is_connected(&self) -> bool {
440 self.client.is_some()
441 }
442
443 pub fn set_retry_policy(&mut self, policy: RetryPolicy) {
444 self.retry_policy = policy;
445 }
446
447 pub fn set_log_hook<F>(&mut self, hook: F)
448 where
449 F: Fn(&str) + Send + Sync + 'static,
450 {
451 self.log_hook = Some(Arc::new(hook));
452 }
453
454 pub fn set_log_hook_arc(&mut self, hook: WsLogHook) {
455 self.log_hook = Some(hook);
456 }
457
458 fn log(&self, message: &str) {
459 if let Some(hook) = &self.log_hook {
460 hook(message);
461 }
462 }
463
464 pub async fn connect<H>(
465 &mut self,
466 reload: bool,
467 url: &str,
468 message_handler: H,
469 ) -> Result<(), WebSocketError>
470 where
471 H: MessageHandler + 'static,
472 {
473 if self.client.is_some() {
474 if !reload {
475 return Ok(());
476 }
477 self.disconnect();
478 }
479
480 let ws = WebSocketClient::connect(url, message_handler).await?;
481 self.client = Some(ws);
482 Ok(())
483 }
484
485 pub async fn reconnect<H>(&mut self, url: &str, message_handler: H) -> Result<(), WebSocketError>
486 where
487 H: MessageHandler + Clone + 'static,
488 {
489 self.disconnect();
490
491 let attempts = self.retry_policy.max_attempts.max(1);
492 let mut delay = self.retry_policy.initial_delay;
493 let mut last_err: Option<WebSocketError> = None;
494
495 for attempt in 1..=attempts {
496 match WebSocketClient::connect(url, message_handler.clone()).await {
497 Ok(ws) => {
498 self.client = Some(ws);
499 self.log(&format!(
500 "WebSocket reconnected on attempt {}/{}",
501 attempt, attempts
502 ));
503 return Ok(());
504 }
505 Err(err) => {
506 last_err = Some(err);
507 if attempt >= attempts {
508 break;
509 }
510
511 self.log(&format!(
512 "WebSocket reconnect attempt {}/{} failed, retrying in {:?}",
513 attempt, attempts, delay
514 ));
515 sleep(delay).await;
516
517 let next = (delay.as_secs_f64() * self.retry_policy.backoff_factor)
518 .max(self.retry_policy.initial_delay.as_secs_f64());
519 delay = Duration::from_secs_f64(next.min(self.retry_policy.max_delay.as_secs_f64()));
520 }
521 }
522 }
523
524 Err(last_err.unwrap_or_else(|| WebSocketError::Other("reconnect failed".to_string())))
525 }
526
527 pub fn disconnect(&mut self) {
528 if let Some(ws) = self.client.take() {
529 ws.disconnect();
530 }
531 }
532
533 pub fn send_text(&self, text: &str) -> Result<(), WebSocketError> {
534 match &self.client {
535 Some(ws) => ws.send_text(text),
536 None => Err(WebSocketError::Other(
537 "websocket is not connected".to_string(),
538 )),
539 }
540 }
541}
542
543#[derive(Clone)]
545pub struct ParsedMessageHandler<E, D>
546where
547 E: Eq + Hash + Clone + Send + Sync + 'static,
548 D: Clone + Send + 'static,
549{
550 emitter: EventBus<E, D>,
551 log_hook: Option<WsLogHook>,
552 parser: fn(&Value) -> Result<(E, D), Error>,
553 all_event: Option<E>,
554 error_context: &'static str,
555}
556
557impl<E, D> ParsedMessageHandler<E, D>
558where
559 E: Eq + Hash + Clone + Send + Sync + 'static,
560 D: Clone + Send + 'static,
561{
562 pub fn new(
563 parser: fn(&Value) -> Result<(E, D), Error>,
564 all_event: Option<E>,
565 error_context: &'static str,
566 ) -> Self {
567 Self {
568 emitter: EventBus::new(),
569 log_hook: None,
570 parser,
571 all_event,
572 error_context,
573 }
574 }
575
576 pub fn get_emitter(&self) -> EventBus<E, D> {
577 self.emitter.clone()
578 }
579
580 pub fn set_log_hook_arc(&mut self, hook: WsLogHook) {
581 self.log_hook = Some(hook);
582 }
583}
584
585impl<E, D> MessageHandler for ParsedMessageHandler<E, D>
586where
587 E: Eq + Hash + Clone + Send + Sync + 'static,
588 D: Clone + Send + 'static,
589{
590 fn handle_message(&self, text: String) {
591 if let Ok(json) = serde_json::from_str::<Value>(&text) {
592 let emitter = self.get_emitter();
593 let log_hook = self.log_hook.clone();
594 let parser = self.parser;
595 let all_event = self.all_event.clone();
596 let context = self.error_context;
597
598 tokio::spawn(async move {
599 match parser(&json) {
600 Ok((event_type, event)) => {
601 emitter.emit(&event_type, event, all_event.as_ref()).await;
602 }
603 Err(e) => {
604 if let Some(hook) = log_hook {
605 hook(&format!("Failed to parse {} message: {}", context, e));
606 }
607 }
608 }
609 });
610 }
611 }
612}
613
#[cfg(test)]
mod tests {
    use super::{EventBus, RetryPolicy, WsEventType, build_ws_url};
    use tokio::sync::mpsc;
    use tokio::time::{Duration, timeout};

    /// The default policy is three attempts, 400 ms doubling up to 8 s.
    #[test]
    fn retry_policy_defaults_are_reasonable() {
        let RetryPolicy {
            max_attempts,
            initial_delay,
            max_delay,
            backoff_factor,
        } = RetryPolicy::default();

        assert_eq!(max_attempts, 3);
        assert_eq!(initial_delay, Duration::from_millis(400));
        assert_eq!(max_delay, Duration::from_secs(8));
        assert!((backoff_factor - 2.0).abs() < f64::EPSILON);
    }

    /// An emission reaches both the listener of the concrete event and the
    /// listener registered under the catch-all key.
    #[tokio::test]
    async fn event_bus_emits_target_and_all() {
        let bus = EventBus::<WsEventType, String>::new();
        let (tx, mut rx) = mpsc::unbounded_channel::<String>();

        let open_tx = tx.clone();
        bus.add_listener(WsEventType::Open, move |msg| {
            let _ = open_tx.send(format!("open:{msg}"));
        })
        .await;

        let all_tx = tx.clone();
        bus.add_listener(WsEventType::All, move |msg| {
            let _ = all_tx.send(format!("all:{msg}"));
        })
        .await;

        bus.emit(&WsEventType::Open, "hello".to_string(), Some(&WsEventType::All))
            .await;

        // Listeners run in separate tasks, so arrival order is not fixed.
        let first = timeout(Duration::from_secs(1), rx.recv())
            .await
            .expect("first recv timeout")
            .expect("first message missing");
        let second = timeout(Duration::from_secs(1), rx.recv())
            .await
            .expect("second recv timeout")
            .expect("second message missing");

        let received = vec![first, second];
        assert!(received.contains(&"open:hello".to_string()));
        assert!(received.contains(&"all:hello".to_string()));
    }

    /// Query values are form-encoded: space becomes `+`, and `+` and `/` are
    /// percent-escaped.
    #[test]
    fn build_ws_url_encodes_query_params() {
        let params = [
            ("apiKey", "token a+b".to_string()),
            ("toUser", "alice/bob".to_string()),
        ];

        let url = build_ws_url("fishpi.cn", "chat-channel", &params)
            .expect("url build should succeed");

        assert!(url.starts_with("wss://fishpi.cn/chat-channel?"));
        for expected in ["apiKey=token+a%2Bb", "toUser=alice%2Fbob"] {
            assert!(url.contains(expected));
        }
    }
}
679}