// onebot_api/communication/ws.rs

use super::utils::*;
use crate::error::{ServiceStartError, ServiceStartResult};
use async_trait::async_trait;
use futures::stream::{SplitSink, SplitStream};
use futures::{SinkExt, StreamExt};
use reqwest::IntoUrl;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::select;
use tokio::sync::broadcast;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use url::Url;

17#[derive(Clone, Debug)]
18pub struct WsService {
19	url: Url,
20	access_token: Option<String>,
21	api_receiver: Option<APIReceiver>,
22	event_sender: Option<EventSender>,
23	close_signal_sender: broadcast::Sender<()>,
24	connection_close_signal_sender: broadcast::Sender<()>,
25	auto_reconnect: bool,
26	reconnect_interval: Duration,
27	max_reconnect_times: u32,
28}
29
30impl Drop for WsService {
31	fn drop(&mut self) {
32		let _ = self.close_signal_sender.send(());
33	}
34}
35
36impl WsService {
37	pub fn new(
38		url: impl IntoUrl,
39		access_token: Option<String>,
40		auto_reconnect: Option<bool>,
41		reconnect_interval: Option<Duration>,
42		max_reconnect_times: Option<u32>,
43	) -> reqwest::Result<Self> {
44		let (close_signal_sender, _) = broadcast::channel(1);
45		let (connection_close_signal_sender, _) = broadcast::channel(1);
46		Ok(Self {
47			url: url.into_url()?,
48			access_token,
49			api_receiver: None,
50			event_sender: None,
51			close_signal_sender,
52			connection_close_signal_sender,
53			auto_reconnect: auto_reconnect.unwrap_or(true),
54			reconnect_interval: reconnect_interval.unwrap_or(Duration::from_secs(10)),
55			max_reconnect_times: max_reconnect_times.unwrap_or(5),
56		})
57	}
58}
59
60impl WsService {
61	pub fn get_url(&self) -> Url {
62		let mut url = self.url.clone();
63		if let Some(token) = &self.access_token {
64			let mut query_pairs = url.query_pairs_mut();
65			query_pairs.append_pair("access_token", token);
66		}
67		url
68	}
69
70	pub async fn connect(
71		url: impl ToString,
72	) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Error> {
73		let (stream, _) = tokio_tungstenite::connect_async(url.to_string()).await?;
74		Ok(stream)
75	}
76
77	pub async fn send_processor(
78		mut send_side: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
79		api_receiver: APIReceiver,
80		mut close_signal: broadcast::Receiver<()>,
81		mut connection_close_signal: broadcast::Receiver<()>,
82	) -> anyhow::Result<()> {
83		loop {
84			select! {
85				_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
86				_ = connection_close_signal.recv() => return Err(anyhow::anyhow!("close")),
87				Ok(data) = api_receiver.recv_async() => {
88					let str = serde_json::to_string(&data);
89					if str.is_err() {
90						continue
91					}
92					let _ = send_side.send(Message::Text(str?.into())).await;
93				}
94			}
95		}
96	}
97
98	pub async fn read_processor(
99		mut read_side: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
100		event_sender: EventSender,
101		mut close_signal: broadcast::Receiver<()>,
102		connection_close_signal_sender: broadcast::Sender<()>,
103	) -> anyhow::Result<()> {
104		loop {
105			select! {
106				_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
107				Some(Ok(msg)) = read_side.next() => {
108					match msg {
109						Message::Text(data) => {
110							let str = data.as_str();
111							let event = serde_json::from_str::<Event>(str);
112							if event.is_err() {
113								continue
114							}
115							let event = Arc::new(event?);
116							let _ = event_sender.send(event);
117						},
118						Message::Close(_) => {
119							let _ = connection_close_signal_sender.send(());
120							return Err(anyhow::anyhow!("close"));
121						},
122						_ => ()
123					}
124				}
125			}
126		}
127	}
128
129	pub async fn spawn_processor(&self) -> ServiceStartResult<()> {
130		if self.api_receiver.is_none() && self.event_sender.is_none() {
131			return Err(ServiceStartError::NotInjected);
132		} else if self.event_sender.is_none() {
133			return Err(ServiceStartError::NotInjectedEventSender);
134		} else if self.api_receiver.is_none() {
135			return Err(ServiceStartError::NotInjectedAPIReceiver);
136		}
137
138		let api_receiver = self.api_receiver.clone().unwrap();
139		let event_sender = self.event_sender.clone().unwrap();
140
141		let (send_side, read_side) = Self::connect(self.get_url()).await?.split();
142
143		tokio::spawn(Self::read_processor(
144			read_side,
145			event_sender,
146			self.close_signal_sender.subscribe(),
147			self.connection_close_signal_sender.clone(),
148		));
149		tokio::spawn(Self::send_processor(
150			send_side,
151			api_receiver,
152			self.close_signal_sender.subscribe(),
153			self.connection_close_signal_sender.subscribe(),
154		));
155		Ok(())
156	}
157
158	pub async fn reconnect(&self, reconnect_times: u32) -> anyhow::Result<()> {
159		if reconnect_times > self.max_reconnect_times {
160			return Err(anyhow::anyhow!("over max reconnect times"));
161		}
162		tokio::time::sleep(self.reconnect_interval).await;
163		if self.spawn_processor().await.is_err() {
164			Box::pin(self.reconnect(reconnect_times + 1)).await
165		} else {
166			Ok(())
167		}
168	}
169
170	pub async fn reconnect_processor(self) -> anyhow::Result<()> {
171		let mut close_signal = self.close_signal_sender.subscribe();
172		let mut connection_close_signal = self.connection_close_signal_sender.subscribe();
173		loop {
174			select! {
175				_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
176				_ = connection_close_signal.recv() => self.reconnect(1).await?
177			}
178		}
179	}
180}
181
182#[async_trait]
183impl CommunicationService for WsService {
184	fn inject(&mut self, api_receiver: APIReceiver, event_sender: EventSender) {
185		self.api_receiver = Some(api_receiver);
186		self.event_sender = Some(event_sender);
187	}
188
189	async fn start_service(&self) -> ServiceStartResult<()> {
190		self.spawn_processor().await?;
191		if self.auto_reconnect {
192			tokio::spawn(Self::reconnect_processor(self.clone()));
193		}
194		Ok(())
195	}
196}