Skip to main content

onebot_api/communication/
ws.rs

1use super::utils::*;
2use crate::error::{ServiceStartError, ServiceStartResult};
3use async_trait::async_trait;
4use futures::stream::{SplitSink, SplitStream};
5use futures::{SinkExt, StreamExt};
6use reqwest::IntoUrl;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::time::Duration;
10use tokio::net::TcpStream;
11use tokio::select;
12use tokio::sync::broadcast;
13use tokio_tungstenite::tungstenite::Message;
14use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
15use url::Url;
16
17pub struct WsServiceBuilder {
18	url: Url,
19	access_token: Option<String>,
20	auto_reconnect: Option<bool>,
21	reconnect_interval: Option<Duration>,
22	max_reconnect_times: Option<u32>,
23}
24
25impl WsServiceBuilder {
26	pub fn new(url: impl IntoUrl) -> reqwest::Result<Self> {
27		Ok(Self {
28			url: url.into_url()?,
29			access_token: None,
30			auto_reconnect: None,
31			reconnect_interval: None,
32			max_reconnect_times: None,
33		})
34	}
35
36	pub fn build(self) -> reqwest::Result<WsService> {
37		WsService::new(
38			self.url,
39			self.access_token,
40			self.auto_reconnect,
41			self.reconnect_interval,
42			self.max_reconnect_times,
43		)
44	}
45
46	pub fn access_token(mut self, access_token: String) -> Self {
47		self.access_token = Some(access_token);
48		self
49	}
50
51	pub fn auto_reconnect(mut self, auto_reconnect: bool) -> Self {
52		self.auto_reconnect = Some(auto_reconnect);
53		self
54	}
55
56	pub fn reconnect_interval(mut self, reconnect_interval: Duration) -> Self {
57		self.reconnect_interval = Some(reconnect_interval);
58		self
59	}
60
61	pub fn max_reconnect_times(mut self, max_reconnect_times: u32) -> Self {
62		self.max_reconnect_times = Some(max_reconnect_times);
63		self
64	}
65}
66
67#[derive(Clone, Debug)]
68pub struct WsService {
69	url: Url,
70	access_token: Option<String>,
71	api_receiver: Option<InternalAPIReceiver>,
72	event_sender: Option<InternalEventSender>,
73	close_signal_sender: broadcast::Sender<()>,
74	connection_close_signal_sender: broadcast::Sender<()>,
75	auto_reconnect: bool,
76	reconnect_interval: Duration,
77	max_reconnect_times: u32,
78	is_running: Arc<AtomicBool>,
79}
80
81impl Drop for WsService {
82	fn drop(&mut self) {
83		self.uninstall();
84	}
85}
86
87impl WsService {
88	pub fn new(
89		url: impl IntoUrl,
90		access_token: Option<String>,
91		auto_reconnect: Option<bool>,
92		reconnect_interval: Option<Duration>,
93		max_reconnect_times: Option<u32>,
94	) -> reqwest::Result<Self> {
95		let (close_signal_sender, _) = broadcast::channel(1);
96		let (connection_close_signal_sender, _) = broadcast::channel(1);
97		Ok(Self {
98			url: url.into_url()?,
99			access_token,
100			api_receiver: None,
101			event_sender: None,
102			close_signal_sender,
103			connection_close_signal_sender,
104			auto_reconnect: auto_reconnect.unwrap_or(true),
105			reconnect_interval: reconnect_interval.unwrap_or(Duration::from_secs(10)),
106			max_reconnect_times: max_reconnect_times.unwrap_or(5),
107			is_running: Arc::new(AtomicBool::new(false)),
108		})
109	}
110
111	pub fn builder(url: impl IntoUrl) -> reqwest::Result<WsServiceBuilder> {
112		WsServiceBuilder::new(url)
113	}
114}
115
116impl WsService {
117	pub fn get_url(&self) -> Url {
118		let mut url = self.url.clone();
119		if let Some(token) = &self.access_token {
120			let mut query_pairs = url.query_pairs_mut();
121			query_pairs.append_pair("access_token", token);
122		}
123		url
124	}
125
126	pub async fn connect(
127		url: impl ToString,
128	) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Error> {
129		let (stream, _) = tokio_tungstenite::connect_async(url.to_string()).await?;
130		Ok(stream)
131	}
132
133	pub async fn send_processor(
134		mut send_side: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
135		api_receiver: InternalAPIReceiver,
136		mut close_signal: broadcast::Receiver<()>,
137		mut connection_close_signal: broadcast::Receiver<()>,
138	) -> anyhow::Result<()> {
139		loop {
140			select! {
141				_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
142				_ = connection_close_signal.recv() => return Err(anyhow::anyhow!("close")),
143				Ok(data) = api_receiver.recv_async() => {
144					let str = serde_json::to_string(&data);
145					if str.is_err() {
146						continue
147					}
148					let _ = send_side.send(Message::Text(str?.into())).await;
149				}
150			}
151		}
152	}
153
154	pub async fn read_processor(
155		mut read_side: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
156		event_sender: InternalEventSender,
157		mut close_signal: broadcast::Receiver<()>,
158		connection_close_signal_sender: broadcast::Sender<()>,
159	) -> anyhow::Result<()> {
160		loop {
161			select! {
162				_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
163				Some(Ok(msg)) = read_side.next() => {
164					match msg {
165						Message::Text(data) => {
166							let str = data.as_str();
167							let event = serde_json::from_str::<DeserializedEvent>(str);
168							if event.is_err() {
169								continue
170							}
171							let _ = event_sender.send(event?);
172						},
173						Message::Close(_) => {
174							let _ = connection_close_signal_sender.send(());
175							return Err(anyhow::anyhow!("close"));
176						},
177						_ => ()
178					}
179				}
180			}
181		}
182	}
183
184	pub async fn spawn_processor(&self) -> ServiceStartResult<()> {
185		if self.api_receiver.is_none() && self.event_sender.is_none() {
186			return Err(ServiceStartError::NotInjected);
187		} else if self.event_sender.is_none() {
188			return Err(ServiceStartError::NotInjectedEventSender);
189		} else if self.api_receiver.is_none() {
190			return Err(ServiceStartError::NotInjectedAPIReceiver);
191		}
192
193		let api_receiver = self.api_receiver.clone().unwrap();
194		let event_sender = self.event_sender.clone().unwrap();
195
196		let (send_side, read_side) = Self::connect(self.get_url()).await?.split();
197
198		tokio::spawn(Self::read_processor(
199			read_side,
200			event_sender,
201			self.close_signal_sender.subscribe(),
202			self.connection_close_signal_sender.clone(),
203		));
204		tokio::spawn(Self::send_processor(
205			send_side,
206			api_receiver,
207			self.close_signal_sender.subscribe(),
208			self.connection_close_signal_sender.subscribe(),
209		));
210		Ok(())
211	}
212
213	pub async fn reconnect(&self, reconnect_times: u32) -> anyhow::Result<()> {
214		if reconnect_times > self.max_reconnect_times {
215			return Err(anyhow::anyhow!("over max reconnect times"));
216		}
217		tokio::time::sleep(self.reconnect_interval).await;
218		if self.spawn_processor().await.is_err() {
219			Box::pin(self.reconnect(reconnect_times + 1)).await
220		} else {
221			Ok(())
222		}
223	}
224
225	pub async fn reconnect_processor(self) -> anyhow::Result<()> {
226		let mut close_signal = self.close_signal_sender.subscribe();
227		let mut connection_close_signal = self.connection_close_signal_sender.subscribe();
228		loop {
229			select! {
230				_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
231				_ = connection_close_signal.recv() => self.reconnect(1).await?
232			}
233		}
234	}
235}
236
237#[async_trait]
238impl CommunicationService for WsService {
239	fn install(&mut self, api_receiver: InternalAPIReceiver, event_sender: InternalEventSender) {
240		self.api_receiver = Some(api_receiver);
241		self.event_sender = Some(event_sender);
242	}
243
244	fn uninstall(&mut self) {
245		self.stop();
246		self.api_receiver = None;
247		self.event_sender = None;
248	}
249
250	fn stop(&self) {
251		let _ = self.close_signal_sender.send(());
252		self.is_running.store(false, Ordering::Relaxed);
253	}
254
255	async fn start(&self) -> ServiceStartResult<()> {
256		if self.is_running.load(Ordering::Relaxed) {
257			return Err(ServiceStartError::TaskIsRunning);
258		}
259
260		self.spawn_processor().await?;
261		self.is_running.store(true, Ordering::Relaxed);
262		if self.auto_reconnect {
263			tokio::spawn(Self::reconnect_processor(self.clone()));
264		}
265		Ok(())
266	}
267}