onebot_api/communication/
ws.rs1use super::utils::*;
2use crate::error::{ServiceStartError, ServiceStartResult};
3use async_trait::async_trait;
4use futures::stream::{SplitSink, SplitStream};
5use futures::{SinkExt, StreamExt};
6use reqwest::IntoUrl;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::time::Duration;
10use tokio::net::TcpStream;
11use tokio::select;
12use tokio::sync::broadcast;
13use tokio_tungstenite::tungstenite::Message;
14use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
15use url::Url;
16
17pub struct WsServiceBuilder {
18 url: Url,
19 access_token: Option<String>,
20 auto_reconnect: Option<bool>,
21 reconnect_interval: Option<Duration>,
22 max_reconnect_times: Option<u32>,
23}
24
25impl WsServiceBuilder {
26 pub fn new(url: impl IntoUrl) -> reqwest::Result<Self> {
27 Ok(Self {
28 url: url.into_url()?,
29 access_token: None,
30 auto_reconnect: None,
31 reconnect_interval: None,
32 max_reconnect_times: None,
33 })
34 }
35
36 pub fn build(self) -> reqwest::Result<WsService> {
37 WsService::new(
38 self.url,
39 self.access_token,
40 self.auto_reconnect,
41 self.reconnect_interval,
42 self.max_reconnect_times,
43 )
44 }
45
46 pub fn access_token(mut self, access_token: String) -> Self {
47 self.access_token = Some(access_token);
48 self
49 }
50
51 pub fn auto_reconnect(mut self, auto_reconnect: bool) -> Self {
52 self.auto_reconnect = Some(auto_reconnect);
53 self
54 }
55
56 pub fn reconnect_interval(mut self, reconnect_interval: Duration) -> Self {
57 self.reconnect_interval = Some(reconnect_interval);
58 self
59 }
60
61 pub fn max_reconnect_times(mut self, max_reconnect_times: u32) -> Self {
62 self.max_reconnect_times = Some(max_reconnect_times);
63 self
64 }
65}
66
67#[derive(Clone, Debug)]
68pub struct WsService {
69 url: Url,
70 access_token: Option<String>,
71 api_receiver: Option<InternalAPIReceiver>,
72 event_sender: Option<InternalEventSender>,
73 close_signal_sender: broadcast::Sender<()>,
74 connection_close_signal_sender: broadcast::Sender<()>,
75 auto_reconnect: bool,
76 reconnect_interval: Duration,
77 max_reconnect_times: u32,
78 is_running: Arc<AtomicBool>,
79}
80
81impl Drop for WsService {
82 fn drop(&mut self) {
83 self.uninstall();
84 }
85}
86
87impl WsService {
88 pub fn new(
89 url: impl IntoUrl,
90 access_token: Option<String>,
91 auto_reconnect: Option<bool>,
92 reconnect_interval: Option<Duration>,
93 max_reconnect_times: Option<u32>,
94 ) -> reqwest::Result<Self> {
95 let (close_signal_sender, _) = broadcast::channel(1);
96 let (connection_close_signal_sender, _) = broadcast::channel(1);
97 Ok(Self {
98 url: url.into_url()?,
99 access_token,
100 api_receiver: None,
101 event_sender: None,
102 close_signal_sender,
103 connection_close_signal_sender,
104 auto_reconnect: auto_reconnect.unwrap_or(true),
105 reconnect_interval: reconnect_interval.unwrap_or(Duration::from_secs(10)),
106 max_reconnect_times: max_reconnect_times.unwrap_or(5),
107 is_running: Arc::new(AtomicBool::new(false)),
108 })
109 }
110
111 pub fn builder(url: impl IntoUrl) -> reqwest::Result<WsServiceBuilder> {
112 WsServiceBuilder::new(url)
113 }
114}
115
116impl WsService {
117 pub fn get_url(&self) -> Url {
118 let mut url = self.url.clone();
119 if let Some(token) = &self.access_token {
120 let mut query_pairs = url.query_pairs_mut();
121 query_pairs.append_pair("access_token", token);
122 }
123 url
124 }
125
126 pub async fn connect(
127 url: impl ToString,
128 ) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Error> {
129 let (stream, _) = tokio_tungstenite::connect_async(url.to_string()).await?;
130 Ok(stream)
131 }
132
133 pub async fn send_processor(
134 mut send_side: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
135 api_receiver: InternalAPIReceiver,
136 mut close_signal: broadcast::Receiver<()>,
137 mut connection_close_signal: broadcast::Receiver<()>,
138 ) -> anyhow::Result<()> {
139 loop {
140 select! {
141 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
142 _ = connection_close_signal.recv() => return Err(anyhow::anyhow!("close")),
143 Ok(data) = api_receiver.recv_async() => {
144 let str = serde_json::to_string(&data);
145 if str.is_err() {
146 continue
147 }
148 let _ = send_side.send(Message::Text(str?.into())).await;
149 }
150 }
151 }
152 }
153
154 pub async fn read_processor(
155 mut read_side: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
156 event_sender: InternalEventSender,
157 mut close_signal: broadcast::Receiver<()>,
158 connection_close_signal_sender: broadcast::Sender<()>,
159 ) -> anyhow::Result<()> {
160 loop {
161 select! {
162 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
163 Some(Ok(msg)) = read_side.next() => {
164 match msg {
165 Message::Text(data) => {
166 let str = data.as_str();
167 let event = serde_json::from_str::<DeserializedEvent>(str);
168 if event.is_err() {
169 continue
170 }
171 let _ = event_sender.send(event?);
172 },
173 Message::Close(_) => {
174 let _ = connection_close_signal_sender.send(());
175 return Err(anyhow::anyhow!("close"));
176 },
177 _ => ()
178 }
179 }
180 }
181 }
182 }
183
184 pub async fn spawn_processor(&self) -> ServiceStartResult<()> {
185 if self.api_receiver.is_none() && self.event_sender.is_none() {
186 return Err(ServiceStartError::NotInjected);
187 } else if self.event_sender.is_none() {
188 return Err(ServiceStartError::NotInjectedEventSender);
189 } else if self.api_receiver.is_none() {
190 return Err(ServiceStartError::NotInjectedAPIReceiver);
191 }
192
193 let api_receiver = self.api_receiver.clone().unwrap();
194 let event_sender = self.event_sender.clone().unwrap();
195
196 let (send_side, read_side) = Self::connect(self.get_url()).await?.split();
197
198 tokio::spawn(Self::read_processor(
199 read_side,
200 event_sender,
201 self.close_signal_sender.subscribe(),
202 self.connection_close_signal_sender.clone(),
203 ));
204 tokio::spawn(Self::send_processor(
205 send_side,
206 api_receiver,
207 self.close_signal_sender.subscribe(),
208 self.connection_close_signal_sender.subscribe(),
209 ));
210 Ok(())
211 }
212
213 pub async fn reconnect(&self, reconnect_times: u32) -> anyhow::Result<()> {
214 if reconnect_times > self.max_reconnect_times {
215 return Err(anyhow::anyhow!("over max reconnect times"));
216 }
217 tokio::time::sleep(self.reconnect_interval).await;
218 if self.spawn_processor().await.is_err() {
219 Box::pin(self.reconnect(reconnect_times + 1)).await
220 } else {
221 Ok(())
222 }
223 }
224
225 pub async fn reconnect_processor(self) -> anyhow::Result<()> {
226 let mut close_signal = self.close_signal_sender.subscribe();
227 let mut connection_close_signal = self.connection_close_signal_sender.subscribe();
228 loop {
229 select! {
230 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
231 _ = connection_close_signal.recv() => self.reconnect(1).await?
232 }
233 }
234 }
235}
236
237#[async_trait]
238impl CommunicationService for WsService {
239 fn install(&mut self, api_receiver: InternalAPIReceiver, event_sender: InternalEventSender) {
240 self.api_receiver = Some(api_receiver);
241 self.event_sender = Some(event_sender);
242 }
243
244 fn uninstall(&mut self) {
245 self.stop();
246 self.api_receiver = None;
247 self.event_sender = None;
248 }
249
250 fn stop(&self) {
251 let _ = self.close_signal_sender.send(());
252 self.is_running.store(false, Ordering::Relaxed);
253 }
254
255 async fn start(&self) -> ServiceStartResult<()> {
256 if self.is_running.load(Ordering::Relaxed) {
257 return Err(ServiceStartError::TaskIsRunning);
258 }
259
260 self.spawn_processor().await?;
261 self.is_running.store(true, Ordering::Relaxed);
262 if self.auto_reconnect {
263 tokio::spawn(Self::reconnect_processor(self.clone()));
264 }
265 Ok(())
266 }
267}