onebot_api/communication/
ws.rs1use super::utils::*;
2use crate::error::{ServiceStartError, ServiceStartResult};
3use async_trait::async_trait;
4use futures::stream::{SplitSink, SplitStream};
5use futures::{SinkExt, StreamExt};
6use std::sync::Arc;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::time::Duration;
9use tokio::net::TcpStream;
10use tokio::select;
11use tokio::sync::broadcast;
12use tokio_tungstenite::tungstenite::Message;
13use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
14use url::Url;
15
16pub struct WsServiceBuilder {
17 url: Url,
18 access_token: Option<String>,
19 auto_reconnect: Option<bool>,
20 reconnect_interval: Option<Duration>,
21 max_reconnect_times: Option<u32>,
22}
23
24impl WsServiceBuilder {
25 pub fn new(url: &str) -> Result<Self, url::ParseError> {
26 let url = Url::parse(url)?;
27 Ok(Self {
28 url,
29 access_token: None,
30 auto_reconnect: None,
31 reconnect_interval: None,
32 max_reconnect_times: None,
33 })
34 }
35
36 pub fn new_with_url(url: Url) -> Self {
37 Self {
38 url,
39 access_token: None,
40 auto_reconnect: None,
41 reconnect_interval: None,
42 max_reconnect_times: None,
43 }
44 }
45
46 pub fn build(self) -> WsService {
47 WsService::new_with_options(
48 self.url,
49 self.access_token,
50 self.auto_reconnect,
51 self.reconnect_interval,
52 self.max_reconnect_times,
53 )
54 }
55
56 pub fn access_token(mut self, access_token: String) -> Self {
57 self.access_token = Some(access_token);
58 self
59 }
60
61 pub fn auto_reconnect(mut self, auto_reconnect: bool) -> Self {
62 self.auto_reconnect = Some(auto_reconnect);
63 self
64 }
65
66 pub fn reconnect_interval(mut self, reconnect_interval: Duration) -> Self {
67 self.reconnect_interval = Some(reconnect_interval);
68 self
69 }
70
71 pub fn max_reconnect_times(mut self, max_reconnect_times: u32) -> Self {
72 self.max_reconnect_times = Some(max_reconnect_times);
73 self
74 }
75}
76
/// WebSocket-backed [`CommunicationService`]. It dials `url`, writes API
/// requests taken from the installed receiver to the socket, and forwards
/// incoming events to the installed sender.
///
/// NOTE(review): this type derives `Clone` and also implements `Drop`, and the
/// `Drop` impl stops the service. Clones share the close-signal channel and the
/// `is_running` flag. Dropping any clone (including the one `start` moves into
/// the reconnect task) therefore stops every clone. Confirm this is intended.
#[derive(Clone, Debug)]
pub struct WsService {
    /// Base endpoint. `get_url` builds the URL that is actually dialed.
    url: Url,
    /// When present, `get_url` appends it as the `access_token` query parameter.
    access_token: Option<String>,
    /// Set by `install`. The send task writes requests read from it to the socket.
    api_receiver: Option<InternalAPIReceiver>,
    /// Set by `install`. The read task pushes deserialized events into it.
    event_sender: Option<InternalEventSender>,
    /// Broadcast by `stop`. Every processor task subscribes to it and exits on it.
    close_signal_sender: broadcast::Sender<()>,
    /// Fired by the read task when the connection closes. It ends the send task
    /// and wakes the reconnect loop.
    connection_close_signal_sender: broadcast::Sender<()>,
    /// Whether `start` spawns the reconnect loop (defaults to `true`).
    auto_reconnect: bool,
    /// Delay before each reconnect attempt (defaults to 10 seconds).
    reconnect_interval: Duration,
    /// Highest reconnect attempt number tried before giving up (defaults to 5).
    max_reconnect_times: u32,
    /// Shared by all clones. `start` checks it to refuse a second start.
    is_running: Arc<AtomicBool>,
}
90
91impl Drop for WsService {
92 fn drop(&mut self) {
93 self.uninstall();
94 }
95}
96
97impl WsService {
98 pub fn new(url: Url, access_token: Option<String>) -> Self {
99 Self::new_with_options(url, access_token, None, None, None)
100 }
101
102 pub fn new_with_options(
103 url: Url,
104 access_token: Option<String>,
105 auto_reconnect: Option<bool>,
106 reconnect_interval: Option<Duration>,
107 max_reconnect_times: Option<u32>,
108 ) -> Self {
109 let (close_signal_sender, _) = broadcast::channel(1);
110 let (connection_close_signal_sender, _) = broadcast::channel(1);
111 Self {
112 url,
113 access_token,
114 api_receiver: None,
115 event_sender: None,
116 close_signal_sender,
117 connection_close_signal_sender,
118 auto_reconnect: auto_reconnect.unwrap_or(true),
119 reconnect_interval: reconnect_interval.unwrap_or(Duration::from_secs(10)),
120 max_reconnect_times: max_reconnect_times.unwrap_or(5),
121 is_running: Arc::new(AtomicBool::new(false)),
122 }
123 }
124
125 pub fn builder(url: &str) -> Result<WsServiceBuilder, url::ParseError> {
126 WsServiceBuilder::new(url)
127 }
128}
129
130impl WsService {
131 pub fn get_url(&self) -> Url {
132 let mut url = self.url.clone();
133 if let Some(token) = &self.access_token {
134 let mut query_pairs = url.query_pairs_mut();
135 query_pairs.append_pair("access_token", token);
136 }
137 url
138 }
139
140 pub async fn connect(
141 url: impl ToString,
142 ) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Error> {
143 let (stream, _) = tokio_tungstenite::connect_async(url.to_string()).await?;
144 Ok(stream)
145 }
146
147 pub async fn send_processor(
148 mut send_side: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
149 api_receiver: InternalAPIReceiver,
150 mut close_signal: broadcast::Receiver<()>,
151 mut connection_close_signal: broadcast::Receiver<()>,
152 ) -> anyhow::Result<()> {
153 loop {
154 select! {
155 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
156 _ = connection_close_signal.recv() => return Err(anyhow::anyhow!("close")),
157 Ok(data) = api_receiver.recv_async() => {
158 let str = serde_json::to_string(&data);
159 if str.is_err() {
160 continue
161 }
162 let _ = send_side.send(Message::Text(str?.into())).await;
163 }
164 }
165 }
166 }
167
168 pub async fn read_processor(
169 mut read_side: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
170 event_sender: InternalEventSender,
171 mut close_signal: broadcast::Receiver<()>,
172 connection_close_signal_sender: broadcast::Sender<()>,
173 ) -> anyhow::Result<()> {
174 loop {
175 select! {
176 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
177 Some(Ok(msg)) = read_side.next() => {
178 match msg {
179 Message::Text(data) => {
180 let str = data.as_str();
181 let event = serde_json::from_str::<DeserializedEvent>(str);
182 if event.is_err() {
183 continue
184 }
185 let _ = event_sender.send(event?);
186 },
187 Message::Close(_) => {
188 let _ = connection_close_signal_sender.send(());
189 return Err(anyhow::anyhow!("close"));
190 },
191 _ => ()
192 }
193 }
194 }
195 }
196 }
197
198 pub async fn spawn_processor(&self) -> ServiceStartResult<()> {
199 if self.api_receiver.is_none() && self.event_sender.is_none() {
200 return Err(ServiceStartError::NotInjected);
201 } else if self.event_sender.is_none() {
202 return Err(ServiceStartError::NotInjectedEventSender);
203 } else if self.api_receiver.is_none() {
204 return Err(ServiceStartError::NotInjectedAPIReceiver);
205 }
206
207 let api_receiver = self.api_receiver.clone().unwrap();
208 let event_sender = self.event_sender.clone().unwrap();
209
210 let (send_side, read_side) = Self::connect(self.get_url()).await?.split();
211
212 tokio::spawn(Self::read_processor(
213 read_side,
214 event_sender,
215 self.close_signal_sender.subscribe(),
216 self.connection_close_signal_sender.clone(),
217 ));
218 tokio::spawn(Self::send_processor(
219 send_side,
220 api_receiver,
221 self.close_signal_sender.subscribe(),
222 self.connection_close_signal_sender.subscribe(),
223 ));
224 Ok(())
225 }
226
227 pub async fn reconnect(&self, reconnect_times: u32) -> anyhow::Result<()> {
228 if reconnect_times > self.max_reconnect_times {
229 return Err(anyhow::anyhow!("over max reconnect times"));
230 }
231 tokio::time::sleep(self.reconnect_interval).await;
232 if self.spawn_processor().await.is_err() {
233 Box::pin(self.reconnect(reconnect_times + 1)).await
234 } else {
235 Ok(())
236 }
237 }
238
239 pub async fn reconnect_processor(self) -> anyhow::Result<()> {
240 let mut close_signal = self.close_signal_sender.subscribe();
241 let mut connection_close_signal = self.connection_close_signal_sender.subscribe();
242 loop {
243 select! {
244 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
245 _ = connection_close_signal.recv() => self.reconnect(1).await?
246 }
247 }
248 }
249}
250
251#[async_trait]
252impl CommunicationService for WsService {
253 fn install(&mut self, api_receiver: InternalAPIReceiver, event_sender: InternalEventSender) {
254 self.api_receiver = Some(api_receiver);
255 self.event_sender = Some(event_sender);
256 }
257
258 fn uninstall(&mut self) {
259 self.stop();
260 self.api_receiver = None;
261 self.event_sender = None;
262 }
263
264 fn stop(&self) {
265 let _ = self.close_signal_sender.send(());
266 self.is_running.store(false, Ordering::Relaxed);
267 }
268
269 async fn start(&self) -> ServiceStartResult<()> {
270 if self.is_running.load(Ordering::Relaxed) {
271 return Err(ServiceStartError::TaskIsRunning);
272 }
273
274 self.spawn_processor().await?;
275 self.is_running.store(true, Ordering::Relaxed);
276 if self.auto_reconnect {
277 tokio::spawn(Self::reconnect_processor(self.clone()));
278 }
279 Ok(())
280 }
281}