Skip to main content

onebot_api/communication/ws.rs

1use super::utils::*;
2use crate::error::{ServiceStartError, ServiceStartResult};
3use async_trait::async_trait;
4use futures::stream::{SplitSink, SplitStream};
5use futures::{SinkExt, StreamExt};
6use std::sync::Arc;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::time::Duration;
9use tokio::net::TcpStream;
10use tokio::select;
11use tokio::sync::broadcast;
12use tokio_tungstenite::tungstenite::Message;
13use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
14use url::Url;
15
16pub struct WsServiceBuilder {
17	url: Url,
18	access_token: Option<String>,
19	auto_reconnect: Option<bool>,
20	reconnect_interval: Option<Duration>,
21	max_reconnect_times: Option<u32>,
22}
23
24impl WsServiceBuilder {
25	pub fn new(url: &str) -> Result<Self, url::ParseError> {
26		let url = Url::parse(url)?;
27		Ok(Self {
28			url,
29			access_token: None,
30			auto_reconnect: None,
31			reconnect_interval: None,
32			max_reconnect_times: None,
33		})
34	}
35
36	pub fn new_with_url(url: Url) -> Self {
37		Self {
38			url,
39			access_token: None,
40			auto_reconnect: None,
41			reconnect_interval: None,
42			max_reconnect_times: None,
43		}
44	}
45
46	pub fn build(self) -> WsService {
47		WsService::new_with_options(
48			self.url,
49			self.access_token,
50			self.auto_reconnect,
51			self.reconnect_interval,
52			self.max_reconnect_times,
53		)
54	}
55
56	pub fn access_token(mut self, access_token: String) -> Self {
57		self.access_token = Some(access_token);
58		self
59	}
60
61	pub fn auto_reconnect(mut self, auto_reconnect: bool) -> Self {
62		self.auto_reconnect = Some(auto_reconnect);
63		self
64	}
65
66	pub fn reconnect_interval(mut self, reconnect_interval: Duration) -> Self {
67		self.reconnect_interval = Some(reconnect_interval);
68		self
69	}
70
71	pub fn max_reconnect_times(mut self, max_reconnect_times: u32) -> Self {
72		self.max_reconnect_times = Some(max_reconnect_times);
73		self
74	}
75}
76
77#[derive(Clone, Debug)]
78pub struct WsService {
79	url: Url,
80	access_token: Option<String>,
81	api_receiver: Option<InternalAPIReceiver>,
82	event_sender: Option<InternalEventSender>,
83	close_signal_sender: broadcast::Sender<()>,
84	connection_close_signal_sender: broadcast::Sender<()>,
85	auto_reconnect: bool,
86	reconnect_interval: Duration,
87	max_reconnect_times: u32,
88	is_running: Arc<AtomicBool>,
89}
90
91impl Drop for WsService {
92	fn drop(&mut self) {
93		self.uninstall();
94	}
95}
96
97impl WsService {
98	pub fn new(url: Url, access_token: Option<String>) -> Self {
99		Self::new_with_options(url, access_token, None, None, None)
100	}
101
102	pub fn new_with_options(
103		url: Url,
104		access_token: Option<String>,
105		auto_reconnect: Option<bool>,
106		reconnect_interval: Option<Duration>,
107		max_reconnect_times: Option<u32>,
108	) -> Self {
109		let (close_signal_sender, _) = broadcast::channel(1);
110		let (connection_close_signal_sender, _) = broadcast::channel(1);
111		Self {
112			url,
113			access_token,
114			api_receiver: None,
115			event_sender: None,
116			close_signal_sender,
117			connection_close_signal_sender,
118			auto_reconnect: auto_reconnect.unwrap_or(true),
119			reconnect_interval: reconnect_interval.unwrap_or(Duration::from_secs(10)),
120			max_reconnect_times: max_reconnect_times.unwrap_or(5),
121			is_running: Arc::new(AtomicBool::new(false)),
122		}
123	}
124
125	pub fn builder(url: &str) -> Result<WsServiceBuilder, url::ParseError> {
126		WsServiceBuilder::new(url)
127	}
128}
129
130impl WsService {
131	pub fn get_url(&self) -> Url {
132		let mut url = self.url.clone();
133		if let Some(token) = &self.access_token {
134			let mut query_pairs = url.query_pairs_mut();
135			query_pairs.append_pair("access_token", token);
136		}
137		url
138	}
139
140	pub async fn connect(
141		url: impl ToString,
142	) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Error> {
143		let (stream, _) = tokio_tungstenite::connect_async(url.to_string()).await?;
144		Ok(stream)
145	}
146
147	pub async fn send_processor(
148		mut send_side: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
149		api_receiver: InternalAPIReceiver,
150		mut close_signal: broadcast::Receiver<()>,
151		mut connection_close_signal: broadcast::Receiver<()>,
152	) -> anyhow::Result<()> {
153		loop {
154			select! {
155				_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
156				_ = connection_close_signal.recv() => return Err(anyhow::anyhow!("close")),
157				Ok(data) = api_receiver.recv_async() => {
158					let str = serde_json::to_string(&data);
159					if str.is_err() {
160						continue
161					}
162					let _ = send_side.send(Message::Text(str?.into())).await;
163				}
164			}
165		}
166	}
167
168	pub async fn read_processor(
169		mut read_side: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
170		event_sender: InternalEventSender,
171		mut close_signal: broadcast::Receiver<()>,
172		connection_close_signal_sender: broadcast::Sender<()>,
173	) -> anyhow::Result<()> {
174		loop {
175			select! {
176				_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
177				Some(Ok(msg)) = read_side.next() => {
178					match msg {
179						Message::Text(data) => {
180							let str = data.as_str();
181							let event = serde_json::from_str::<DeserializedEvent>(str);
182							if event.is_err() {
183								continue
184							}
185							let _ = event_sender.send(event?);
186						},
187						Message::Close(_) => {
188							let _ = connection_close_signal_sender.send(());
189							return Err(anyhow::anyhow!("close"));
190						},
191						_ => ()
192					}
193				}
194			}
195		}
196	}
197
198	pub async fn spawn_processor(&self) -> ServiceStartResult<()> {
199		if self.api_receiver.is_none() && self.event_sender.is_none() {
200			return Err(ServiceStartError::NotInjected);
201		} else if self.event_sender.is_none() {
202			return Err(ServiceStartError::NotInjectedEventSender);
203		} else if self.api_receiver.is_none() {
204			return Err(ServiceStartError::NotInjectedAPIReceiver);
205		}
206
207		let api_receiver = self.api_receiver.clone().unwrap();
208		let event_sender = self.event_sender.clone().unwrap();
209
210		let (send_side, read_side) = Self::connect(self.get_url()).await?.split();
211
212		tokio::spawn(Self::read_processor(
213			read_side,
214			event_sender,
215			self.close_signal_sender.subscribe(),
216			self.connection_close_signal_sender.clone(),
217		));
218		tokio::spawn(Self::send_processor(
219			send_side,
220			api_receiver,
221			self.close_signal_sender.subscribe(),
222			self.connection_close_signal_sender.subscribe(),
223		));
224		Ok(())
225	}
226
227	pub async fn reconnect(&self, reconnect_times: u32) -> anyhow::Result<()> {
228		if reconnect_times > self.max_reconnect_times {
229			return Err(anyhow::anyhow!("over max reconnect times"));
230		}
231		tokio::time::sleep(self.reconnect_interval).await;
232		if self.spawn_processor().await.is_err() {
233			Box::pin(self.reconnect(reconnect_times + 1)).await
234		} else {
235			Ok(())
236		}
237	}
238
239	pub async fn reconnect_processor(self) -> anyhow::Result<()> {
240		let mut close_signal = self.close_signal_sender.subscribe();
241		let mut connection_close_signal = self.connection_close_signal_sender.subscribe();
242		loop {
243			select! {
244				_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
245				_ = connection_close_signal.recv() => self.reconnect(1).await?
246			}
247		}
248	}
249}
250
251#[async_trait]
252impl CommunicationService for WsService {
253	fn install(&mut self, api_receiver: InternalAPIReceiver, event_sender: InternalEventSender) {
254		self.api_receiver = Some(api_receiver);
255		self.event_sender = Some(event_sender);
256	}
257
258	fn uninstall(&mut self) {
259		self.stop();
260		self.api_receiver = None;
261		self.event_sender = None;
262	}
263
264	fn stop(&self) {
265		let _ = self.close_signal_sender.send(());
266		self.is_running.store(false, Ordering::Relaxed);
267	}
268
269	async fn start(&self) -> ServiceStartResult<()> {
270		if self.is_running.load(Ordering::Relaxed) {
271			return Err(ServiceStartError::TaskIsRunning);
272		}
273
274		self.spawn_processor().await?;
275		self.is_running.store(true, Ordering::Relaxed);
276		if self.auto_reconnect {
277			tokio::spawn(Self::reconnect_processor(self.clone()));
278		}
279		Ok(())
280	}
281}