onebot_api/communication/
ws.rs1use super::utils::*;
2use crate::error::{ServiceStartError, ServiceStartResult};
3use async_trait::async_trait;
4use futures::stream::{SplitSink, SplitStream};
5use futures::{SinkExt, StreamExt};
6use reqwest::IntoUrl;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::time::Duration;
10use tokio::net::TcpStream;
11use tokio::select;
12use tokio::sync::broadcast;
13use tokio_tungstenite::tungstenite::Message;
14use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
15use url::Url;
16
17pub struct WsServiceBuilder {
18 url: Url,
19 access_token: Option<String>,
20 auto_reconnect: Option<bool>,
21 reconnect_interval: Option<Duration>,
22 max_reconnect_times: Option<u32>,
23}
24
25impl WsServiceBuilder {
26 pub fn new(url: impl IntoUrl) -> reqwest::Result<Self> {
27 Ok(Self {
28 url: url.into_url()?,
29 access_token: None,
30 auto_reconnect: None,
31 reconnect_interval: None,
32 max_reconnect_times: None,
33 })
34 }
35
36 pub fn build(self) -> reqwest::Result<WsService> {
37 WsService::new_with_options(
38 self.url,
39 self.access_token,
40 self.auto_reconnect,
41 self.reconnect_interval,
42 self.max_reconnect_times,
43 )
44 }
45
46 pub fn access_token(mut self, access_token: String) -> Self {
47 self.access_token = Some(access_token);
48 self
49 }
50
51 pub fn auto_reconnect(mut self, auto_reconnect: bool) -> Self {
52 self.auto_reconnect = Some(auto_reconnect);
53 self
54 }
55
56 pub fn reconnect_interval(mut self, reconnect_interval: Duration) -> Self {
57 self.reconnect_interval = Some(reconnect_interval);
58 self
59 }
60
61 pub fn max_reconnect_times(mut self, max_reconnect_times: u32) -> Self {
62 self.max_reconnect_times = Some(max_reconnect_times);
63 self
64 }
65}
66
67#[derive(Clone, Debug)]
68pub struct WsService {
69 url: Url,
70 access_token: Option<String>,
71 api_receiver: Option<InternalAPIReceiver>,
72 event_sender: Option<InternalEventSender>,
73 close_signal_sender: broadcast::Sender<()>,
74 connection_close_signal_sender: broadcast::Sender<()>,
75 auto_reconnect: bool,
76 reconnect_interval: Duration,
77 max_reconnect_times: u32,
78 is_running: Arc<AtomicBool>,
79}
80
81impl Drop for WsService {
82 fn drop(&mut self) {
83 self.uninstall();
84 }
85}
86
87impl WsService {
88 pub fn new(url: impl IntoUrl, access_token: Option<String>) -> reqwest::Result<Self> {
89 Self::new_with_options(url, access_token, None, None, None)
90 }
91
92 pub fn new_with_options(
93 url: impl IntoUrl,
94 access_token: Option<String>,
95 auto_reconnect: Option<bool>,
96 reconnect_interval: Option<Duration>,
97 max_reconnect_times: Option<u32>,
98 ) -> reqwest::Result<Self> {
99 let (close_signal_sender, _) = broadcast::channel(1);
100 let (connection_close_signal_sender, _) = broadcast::channel(1);
101 Ok(Self {
102 url: url.into_url()?,
103 access_token,
104 api_receiver: None,
105 event_sender: None,
106 close_signal_sender,
107 connection_close_signal_sender,
108 auto_reconnect: auto_reconnect.unwrap_or(true),
109 reconnect_interval: reconnect_interval.unwrap_or(Duration::from_secs(10)),
110 max_reconnect_times: max_reconnect_times.unwrap_or(5),
111 is_running: Arc::new(AtomicBool::new(false)),
112 })
113 }
114
115 pub fn builder(url: impl IntoUrl) -> reqwest::Result<WsServiceBuilder> {
116 WsServiceBuilder::new(url)
117 }
118}
119
120impl WsService {
121 pub fn get_url(&self) -> Url {
122 let mut url = self.url.clone();
123 if let Some(token) = &self.access_token {
124 let mut query_pairs = url.query_pairs_mut();
125 query_pairs.append_pair("access_token", token);
126 }
127 url
128 }
129
130 pub async fn connect(
131 url: impl ToString,
132 ) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Error> {
133 let (stream, _) = tokio_tungstenite::connect_async(url.to_string()).await?;
134 Ok(stream)
135 }
136
137 pub async fn send_processor(
138 mut send_side: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
139 api_receiver: InternalAPIReceiver,
140 mut close_signal: broadcast::Receiver<()>,
141 mut connection_close_signal: broadcast::Receiver<()>,
142 ) -> anyhow::Result<()> {
143 loop {
144 select! {
145 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
146 _ = connection_close_signal.recv() => return Err(anyhow::anyhow!("close")),
147 Ok(data) = api_receiver.recv_async() => {
148 let str = serde_json::to_string(&data);
149 if str.is_err() {
150 continue
151 }
152 let _ = send_side.send(Message::Text(str?.into())).await;
153 }
154 }
155 }
156 }
157
158 pub async fn read_processor(
159 mut read_side: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
160 event_sender: InternalEventSender,
161 mut close_signal: broadcast::Receiver<()>,
162 connection_close_signal_sender: broadcast::Sender<()>,
163 ) -> anyhow::Result<()> {
164 loop {
165 select! {
166 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
167 Some(Ok(msg)) = read_side.next() => {
168 match msg {
169 Message::Text(data) => {
170 let str = data.as_str();
171 let event = serde_json::from_str::<DeserializedEvent>(str);
172 if event.is_err() {
173 continue
174 }
175 let _ = event_sender.send(event?);
176 },
177 Message::Close(_) => {
178 let _ = connection_close_signal_sender.send(());
179 return Err(anyhow::anyhow!("close"));
180 },
181 _ => ()
182 }
183 }
184 }
185 }
186 }
187
188 pub async fn spawn_processor(&self) -> ServiceStartResult<()> {
189 if self.api_receiver.is_none() && self.event_sender.is_none() {
190 return Err(ServiceStartError::NotInjected);
191 } else if self.event_sender.is_none() {
192 return Err(ServiceStartError::NotInjectedEventSender);
193 } else if self.api_receiver.is_none() {
194 return Err(ServiceStartError::NotInjectedAPIReceiver);
195 }
196
197 let api_receiver = self.api_receiver.clone().unwrap();
198 let event_sender = self.event_sender.clone().unwrap();
199
200 let (send_side, read_side) = Self::connect(self.get_url()).await?.split();
201
202 tokio::spawn(Self::read_processor(
203 read_side,
204 event_sender,
205 self.close_signal_sender.subscribe(),
206 self.connection_close_signal_sender.clone(),
207 ));
208 tokio::spawn(Self::send_processor(
209 send_side,
210 api_receiver,
211 self.close_signal_sender.subscribe(),
212 self.connection_close_signal_sender.subscribe(),
213 ));
214 Ok(())
215 }
216
217 pub async fn reconnect(&self, reconnect_times: u32) -> anyhow::Result<()> {
218 if reconnect_times > self.max_reconnect_times {
219 return Err(anyhow::anyhow!("over max reconnect times"));
220 }
221 tokio::time::sleep(self.reconnect_interval).await;
222 if self.spawn_processor().await.is_err() {
223 Box::pin(self.reconnect(reconnect_times + 1)).await
224 } else {
225 Ok(())
226 }
227 }
228
229 pub async fn reconnect_processor(self) -> anyhow::Result<()> {
230 let mut close_signal = self.close_signal_sender.subscribe();
231 let mut connection_close_signal = self.connection_close_signal_sender.subscribe();
232 loop {
233 select! {
234 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
235 _ = connection_close_signal.recv() => self.reconnect(1).await?
236 }
237 }
238 }
239}
240
241#[async_trait]
242impl CommunicationService for WsService {
243 fn install(&mut self, api_receiver: InternalAPIReceiver, event_sender: InternalEventSender) {
244 self.api_receiver = Some(api_receiver);
245 self.event_sender = Some(event_sender);
246 }
247
248 fn uninstall(&mut self) {
249 self.stop();
250 self.api_receiver = None;
251 self.event_sender = None;
252 }
253
254 fn stop(&self) {
255 let _ = self.close_signal_sender.send(());
256 self.is_running.store(false, Ordering::Relaxed);
257 }
258
259 async fn start(&self) -> ServiceStartResult<()> {
260 if self.is_running.load(Ordering::Relaxed) {
261 return Err(ServiceStartError::TaskIsRunning);
262 }
263
264 self.spawn_processor().await?;
265 self.is_running.store(true, Ordering::Relaxed);
266 if self.auto_reconnect {
267 tokio::spawn(Self::reconnect_processor(self.clone()));
268 }
269 Ok(())
270 }
271}