onebot_api/communication/ws.rs

1use super::ws_utils::*;
2use anyhow::anyhow;
3use async_trait::async_trait;
4use flume::{Receiver, Sender};
5use futures::stream::{SplitSink, SplitStream};
6use futures::{SinkExt, StreamExt};
7use reqwest::IntoUrl;
8use tokio::net::TcpStream;
9use tokio::sync::broadcast;
10use tokio_tungstenite::tungstenite::Message;
11use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
12use url::Url;
13
14pub struct WsService {
15	close_sender: broadcast::Sender<()>,
16	api_receiver: Option<Receiver<String>>,
17	msg_sender: Option<Sender<String>>,
18	url: Url,
19}
20
21impl Drop for WsService {
22	fn drop(&mut self) {
23		let _ = self.close_sender.send(());
24	}
25}
26
27impl WsService {
28	pub fn new(url: impl IntoUrl) -> reqwest::Result<Self> {
29		let (close_sender, _) = broadcast::channel(4);
30		Ok(Self {
31			close_sender,
32			api_receiver: None,
33			msg_sender: None,
34			url: url.into_url()?,
35		})
36	}
37
38	pub fn new_with_token(url: impl IntoUrl, token: Option<String>) -> reqwest::Result<Self> {
39		if let Some(token) = token {
40			let mut url = url.into_url()?;
41			url.set_query(Some(&format!("access_token={}", token)));
42			Self::new(url)
43		} else {
44			Self::new(url)
45		}
46	}
47}
48
49impl WsService {
50	async fn connect(&self) -> anyhow::Result<()> {
51		if self.msg_sender.is_none() || self.api_receiver.is_none() {
52			return Err(anyhow!("msg_sender or api_receiver is none"));
53		}
54		let api_receiver = self.api_receiver.clone().unwrap();
55		let msg_sender = self.msg_sender.clone().unwrap();
56		let (ws_stream, _) = connect_async(self.url.as_str()).await?;
57		let (write_half, read_half) = ws_stream.split();
58		let write_half_close_receiver = self.close_sender.subscribe();
59		let read_half_close_receiver = self.close_sender.subscribe();
60		tokio::spawn(async move {
61			Self::ws_stream_writer(write_half_close_receiver, api_receiver, write_half).await
62		});
63		tokio::spawn(async move {
64			Self::ws_stream_reader(read_half_close_receiver, msg_sender, read_half).await
65		});
66		Ok(())
67	}
68
69	async fn ws_stream_writer(
70		mut close_receiver: broadcast::Receiver<()>,
71		api_receiver: Receiver<String>,
72		mut write_half: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
73	) {
74		loop {
75			tokio::select! {
76				msg = api_receiver.recv_async() => {
77					if let Ok(msg) = msg {
78						let _ = write_half.send(Message::Text(msg.into())).await;
79					}
80				}
81				_ = close_receiver.recv() => {
82					return
83				}
84			}
85		}
86	}
87
88	async fn ws_stream_reader(
89		mut close_receiver: broadcast::Receiver<()>,
90		msg_sender: Sender<String>,
91		mut read_half: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
92	) {
93		loop {
94			tokio::select! {
95				msg = read_half.next() => {
96					if let Some(Ok(Message::Text(data))) = msg {
97						let _ = msg_sender.send_async(data.to_string()).await;
98					}
99				}
100				_ = close_receiver.recv() => {
101					return
102				}
103			}
104		}
105	}
106}
107
108#[async_trait]
109impl WebSocketService for WsService {
110	fn register_api_receiver(&mut self, api_receiver: Receiver<String>) {
111		self.api_receiver = Some(api_receiver)
112	}
113
114	fn register_msg_sender(&mut self, msg_sender: Sender<String>) {
115		self.msg_sender = Some(msg_sender)
116	}
117
118	async fn start(&self) -> anyhow::Result<()> {
119		self.connect().await
120	}
121}