onebot_api/communication/
ws.rs1use super::utils::*;
2use crate::error::{ServiceStartError, ServiceStartResult};
3use async_trait::async_trait;
4use futures::stream::{SplitSink, SplitStream};
5use futures::{SinkExt, StreamExt};
6use reqwest::IntoUrl;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::net::TcpStream;
11use tokio::select;
12use tokio::sync::broadcast;
13use tokio_tungstenite::tungstenite::Message;
14use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
15use url::Url;
16
/// Communication service that talks to a OneBot endpoint over a WebSocket
/// client connection.
///
/// Outgoing API requests are taken from `api_receiver` and written as JSON
/// text frames; incoming text frames are parsed as `Event`s and published on
/// `event_sender`.
///
/// NOTE(review): this type derives `Clone` and also implements `Drop` by
/// sending on `close_signal_sender`. Clones share the same broadcast channels,
/// so dropping *any* clone signals every task started by the service to stop.
#[derive(Clone, Debug)]
pub struct WsService {
    // Base endpoint; `get_url` appends `access_token` as a query parameter.
    url: Url,
    // Optional token sent as the `access_token` query parameter.
    access_token: Option<String>,
    // Source of API requests; `None` until `inject` is called.
    api_receiver: Option<APIReceiver>,
    // Sink for parsed events; `None` until `inject` is called.
    event_sender: Option<EventSender>,
    // Service-wide shutdown broadcast, sent from `Drop`.
    close_signal_sender: broadcast::Sender<()>,
    // Per-connection "connection closed" broadcast, sent by `read_processor`
    // when the connection is closed; it stops the writer and drives reconnects.
    connection_close_signal_sender: broadcast::Sender<()>,
    // Whether `start_service` also spawns the reconnect task (default `true`).
    auto_reconnect: bool,
    // Delay before each reconnect attempt (default 10 seconds).
    reconnect_interval: Duration,
    // Highest attempt number `reconnect` will try (default 5).
    max_reconnect_times: u32,
}
28}
29
30impl Drop for WsService {
31 fn drop(&mut self) {
32 let _ = self.close_signal_sender.send(());
33 }
34}
35
36impl WsService {
37 pub fn new(
38 url: impl IntoUrl,
39 access_token: Option<String>,
40 auto_reconnect: Option<bool>,
41 reconnect_interval: Option<Duration>,
42 max_reconnect_times: Option<u32>,
43 ) -> reqwest::Result<Self> {
44 let (close_signal_sender, _) = broadcast::channel(1);
45 let (connection_close_signal_sender, _) = broadcast::channel(1);
46 Ok(Self {
47 url: url.into_url()?,
48 access_token,
49 api_receiver: None,
50 event_sender: None,
51 close_signal_sender,
52 connection_close_signal_sender,
53 auto_reconnect: auto_reconnect.unwrap_or(true),
54 reconnect_interval: reconnect_interval.unwrap_or(Duration::from_secs(10)),
55 max_reconnect_times: max_reconnect_times.unwrap_or(5),
56 })
57 }
58}
59
60impl WsService {
61 pub fn get_url(&self) -> Url {
62 let mut url = self.url.clone();
63 if let Some(token) = &self.access_token {
64 let mut query_pairs = url.query_pairs_mut();
65 query_pairs.append_pair("access_token", token);
66 }
67 url
68 }
69
70 pub async fn connect(
71 url: impl ToString,
72 ) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Error> {
73 let (stream, _) = tokio_tungstenite::connect_async(url.to_string()).await?;
74 Ok(stream)
75 }
76
77 pub async fn send_processor(
78 mut send_side: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
79 api_receiver: APIReceiver,
80 mut close_signal: broadcast::Receiver<()>,
81 mut connection_close_signal: broadcast::Receiver<()>,
82 ) -> anyhow::Result<()> {
83 loop {
84 select! {
85 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
86 _ = connection_close_signal.recv() => return Err(anyhow::anyhow!("close")),
87 Ok(data) = api_receiver.recv_async() => {
88 let str = serde_json::to_string(&data);
89 if str.is_err() {
90 continue
91 }
92 let _ = send_side.send(Message::Text(str?.into())).await;
93 }
94 }
95 }
96 }
97
98 pub async fn read_processor(
99 mut read_side: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
100 event_sender: EventSender,
101 mut close_signal: broadcast::Receiver<()>,
102 connection_close_signal_sender: broadcast::Sender<()>,
103 ) -> anyhow::Result<()> {
104 loop {
105 select! {
106 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
107 Some(Ok(msg)) = read_side.next() => {
108 match msg {
109 Message::Text(data) => {
110 let str = data.as_str();
111 let event = serde_json::from_str::<Event>(str);
112 if event.is_err() {
113 continue
114 }
115 let event = Arc::new(event?);
116 let _ = event_sender.send(event);
117 },
118 Message::Close(_) => {
119 let _ = connection_close_signal_sender.send(());
120 return Err(anyhow::anyhow!("close"));
121 },
122 _ => ()
123 }
124 }
125 }
126 }
127 }
128
129 pub async fn spawn_processor(&self) -> ServiceStartResult<()> {
130 if self.api_receiver.is_none() && self.event_sender.is_none() {
131 return Err(ServiceStartError::NotInjected);
132 } else if self.event_sender.is_none() {
133 return Err(ServiceStartError::NotInjectedEventSender);
134 } else if self.api_receiver.is_none() {
135 return Err(ServiceStartError::NotInjectedAPIReceiver);
136 }
137
138 let api_receiver = self.api_receiver.clone().unwrap();
139 let event_sender = self.event_sender.clone().unwrap();
140
141 let (send_side, read_side) = Self::connect(self.get_url()).await?.split();
142
143 tokio::spawn(Self::read_processor(
144 read_side,
145 event_sender,
146 self.close_signal_sender.subscribe(),
147 self.connection_close_signal_sender.clone(),
148 ));
149 tokio::spawn(Self::send_processor(
150 send_side,
151 api_receiver,
152 self.close_signal_sender.subscribe(),
153 self.connection_close_signal_sender.subscribe(),
154 ));
155 Ok(())
156 }
157
158 pub async fn reconnect(&self, reconnect_times: u32) -> anyhow::Result<()> {
159 if reconnect_times > self.max_reconnect_times {
160 return Err(anyhow::anyhow!("over max reconnect times"));
161 }
162 tokio::time::sleep(self.reconnect_interval).await;
163 if self.spawn_processor().await.is_err() {
164 Box::pin(self.reconnect(reconnect_times + 1)).await
165 } else {
166 Ok(())
167 }
168 }
169
170 pub async fn reconnect_processor(self) -> anyhow::Result<()> {
171 let mut close_signal = self.close_signal_sender.subscribe();
172 let mut connection_close_signal = self.connection_close_signal_sender.subscribe();
173 loop {
174 select! {
175 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
176 _ = connection_close_signal.recv() => self.reconnect(1).await?
177 }
178 }
179 }
180}
181
182#[async_trait]
183impl CommunicationService for WsService {
184 fn inject(&mut self, api_receiver: APIReceiver, event_sender: EventSender) {
185 self.api_receiver = Some(api_receiver);
186 self.event_sender = Some(event_sender);
187 }
188
189 async fn start_service(&self) -> ServiceStartResult<()> {
190 self.spawn_processor().await?;
191 if self.auto_reconnect {
192 tokio::spawn(Self::reconnect_processor(self.clone()));
193 }
194 Ok(())
195 }
196}