Skip to main content

onebot_api/communication/ws.rs

1use super::utils::*;
2use crate::error::{ServiceStartError, ServiceStartResult};
3use async_trait::async_trait;
4use futures::stream::{SplitSink, SplitStream};
5use futures::{SinkExt, StreamExt};
6use reqwest::IntoUrl;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::time::Duration;
10use tokio::net::TcpStream;
11use tokio::select;
12use tokio::sync::broadcast;
13use tokio_tungstenite::tungstenite::Message;
14use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
15use url::Url;
16
17pub struct WsServiceBuilder {
18	url: Url,
19	access_token: Option<String>,
20	auto_reconnect: Option<bool>,
21	reconnect_interval: Option<Duration>,
22	max_reconnect_times: Option<u32>,
23}
24
25impl WsServiceBuilder {
26	pub fn new(url: impl IntoUrl) -> reqwest::Result<Self> {
27		Ok(Self {
28			url: url.into_url()?,
29			access_token: None,
30			auto_reconnect: None,
31			reconnect_interval: None,
32			max_reconnect_times: None,
33		})
34	}
35
36	pub fn build(self) -> reqwest::Result<WsService> {
37		WsService::new_with_options(
38			self.url,
39			self.access_token,
40			self.auto_reconnect,
41			self.reconnect_interval,
42			self.max_reconnect_times,
43		)
44	}
45
46	pub fn access_token(mut self, access_token: String) -> Self {
47		self.access_token = Some(access_token);
48		self
49	}
50
51	pub fn auto_reconnect(mut self, auto_reconnect: bool) -> Self {
52		self.auto_reconnect = Some(auto_reconnect);
53		self
54	}
55
56	pub fn reconnect_interval(mut self, reconnect_interval: Duration) -> Self {
57		self.reconnect_interval = Some(reconnect_interval);
58		self
59	}
60
61	pub fn max_reconnect_times(mut self, max_reconnect_times: u32) -> Self {
62		self.max_reconnect_times = Some(max_reconnect_times);
63		self
64	}
65}
66
67#[derive(Clone, Debug)]
68pub struct WsService {
69	url: Url,
70	access_token: Option<String>,
71	api_receiver: Option<InternalAPIReceiver>,
72	event_sender: Option<InternalEventSender>,
73	close_signal_sender: broadcast::Sender<()>,
74	connection_close_signal_sender: broadcast::Sender<()>,
75	auto_reconnect: bool,
76	reconnect_interval: Duration,
77	max_reconnect_times: u32,
78	is_running: Arc<AtomicBool>,
79}
80
81impl Drop for WsService {
82	fn drop(&mut self) {
83		self.uninstall();
84	}
85}
86
87impl WsService {
88	pub fn new(url: impl IntoUrl, access_token: Option<String>) -> reqwest::Result<Self> {
89		Self::new_with_options(url, access_token, None, None, None)
90	}
91
92	pub fn new_with_options(
93		url: impl IntoUrl,
94		access_token: Option<String>,
95		auto_reconnect: Option<bool>,
96		reconnect_interval: Option<Duration>,
97		max_reconnect_times: Option<u32>,
98	) -> reqwest::Result<Self> {
99		let (close_signal_sender, _) = broadcast::channel(1);
100		let (connection_close_signal_sender, _) = broadcast::channel(1);
101		Ok(Self {
102			url: url.into_url()?,
103			access_token,
104			api_receiver: None,
105			event_sender: None,
106			close_signal_sender,
107			connection_close_signal_sender,
108			auto_reconnect: auto_reconnect.unwrap_or(true),
109			reconnect_interval: reconnect_interval.unwrap_or(Duration::from_secs(10)),
110			max_reconnect_times: max_reconnect_times.unwrap_or(5),
111			is_running: Arc::new(AtomicBool::new(false)),
112		})
113	}
114
115	pub fn builder(url: impl IntoUrl) -> reqwest::Result<WsServiceBuilder> {
116		WsServiceBuilder::new(url)
117	}
118}
119
120impl WsService {
121	pub fn get_url(&self) -> Url {
122		let mut url = self.url.clone();
123		if let Some(token) = &self.access_token {
124			let mut query_pairs = url.query_pairs_mut();
125			query_pairs.append_pair("access_token", token);
126		}
127		url
128	}
129
130	pub async fn connect(
131		url: impl ToString,
132	) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Error> {
133		let (stream, _) = tokio_tungstenite::connect_async(url.to_string()).await?;
134		Ok(stream)
135	}
136
137	pub async fn send_processor(
138		mut send_side: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
139		api_receiver: InternalAPIReceiver,
140		mut close_signal: broadcast::Receiver<()>,
141		mut connection_close_signal: broadcast::Receiver<()>,
142	) -> anyhow::Result<()> {
143		loop {
144			select! {
145				_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
146				_ = connection_close_signal.recv() => return Err(anyhow::anyhow!("close")),
147				Ok(data) = api_receiver.recv_async() => {
148					let str = serde_json::to_string(&data);
149					if str.is_err() {
150						continue
151					}
152					let _ = send_side.send(Message::Text(str?.into())).await;
153				}
154			}
155		}
156	}
157
158	pub async fn read_processor(
159		mut read_side: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
160		event_sender: InternalEventSender,
161		mut close_signal: broadcast::Receiver<()>,
162		connection_close_signal_sender: broadcast::Sender<()>,
163	) -> anyhow::Result<()> {
164		loop {
165			select! {
166				_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
167				Some(Ok(msg)) = read_side.next() => {
168					match msg {
169						Message::Text(data) => {
170							let str = data.as_str();
171							let event = serde_json::from_str::<DeserializedEvent>(str);
172							if event.is_err() {
173								continue
174							}
175							let _ = event_sender.send(event?);
176						},
177						Message::Close(_) => {
178							let _ = connection_close_signal_sender.send(());
179							return Err(anyhow::anyhow!("close"));
180						},
181						_ => ()
182					}
183				}
184			}
185		}
186	}
187
188	pub async fn spawn_processor(&self) -> ServiceStartResult<()> {
189		if self.api_receiver.is_none() && self.event_sender.is_none() {
190			return Err(ServiceStartError::NotInjected);
191		} else if self.event_sender.is_none() {
192			return Err(ServiceStartError::NotInjectedEventSender);
193		} else if self.api_receiver.is_none() {
194			return Err(ServiceStartError::NotInjectedAPIReceiver);
195		}
196
197		let api_receiver = self.api_receiver.clone().unwrap();
198		let event_sender = self.event_sender.clone().unwrap();
199
200		let (send_side, read_side) = Self::connect(self.get_url()).await?.split();
201
202		tokio::spawn(Self::read_processor(
203			read_side,
204			event_sender,
205			self.close_signal_sender.subscribe(),
206			self.connection_close_signal_sender.clone(),
207		));
208		tokio::spawn(Self::send_processor(
209			send_side,
210			api_receiver,
211			self.close_signal_sender.subscribe(),
212			self.connection_close_signal_sender.subscribe(),
213		));
214		Ok(())
215	}
216
217	pub async fn reconnect(&self, reconnect_times: u32) -> anyhow::Result<()> {
218		if reconnect_times > self.max_reconnect_times {
219			return Err(anyhow::anyhow!("over max reconnect times"));
220		}
221		tokio::time::sleep(self.reconnect_interval).await;
222		if self.spawn_processor().await.is_err() {
223			Box::pin(self.reconnect(reconnect_times + 1)).await
224		} else {
225			Ok(())
226		}
227	}
228
229	pub async fn reconnect_processor(self) -> anyhow::Result<()> {
230		let mut close_signal = self.close_signal_sender.subscribe();
231		let mut connection_close_signal = self.connection_close_signal_sender.subscribe();
232		loop {
233			select! {
234				_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
235				_ = connection_close_signal.recv() => self.reconnect(1).await?
236			}
237		}
238	}
239}
240
241#[async_trait]
242impl CommunicationService for WsService {
243	fn install(&mut self, api_receiver: InternalAPIReceiver, event_sender: InternalEventSender) {
244		self.api_receiver = Some(api_receiver);
245		self.event_sender = Some(event_sender);
246	}
247
248	fn uninstall(&mut self) {
249		self.stop();
250		self.api_receiver = None;
251		self.event_sender = None;
252	}
253
254	fn stop(&self) {
255		let _ = self.close_signal_sender.send(());
256		self.is_running.store(false, Ordering::Relaxed);
257	}
258
259	async fn start(&self) -> ServiceStartResult<()> {
260		if self.is_running.load(Ordering::Relaxed) {
261			return Err(ServiceStartError::TaskIsRunning);
262		}
263
264		self.spawn_processor().await?;
265		self.is_running.store(true, Ordering::Relaxed);
266		if self.auto_reconnect {
267			tokio::spawn(Self::reconnect_processor(self.clone()));
268		}
269		Ok(())
270	}
271}