onebot_api/communication/
ws_reverse.rs1use super::utils::*;
2use crate::error::{ServiceStartError, ServiceStartResult};
3use async_trait::async_trait;
4use axum::Router;
5use axum::body::Body;
6use axum::extract::ws::{Message, WebSocket};
7use axum::extract::{State, WebSocketUpgrade};
8use axum::response::Response;
9use axum::routing::any;
10use futures::stream::{SplitSink, SplitStream};
11use futures::{SinkExt, StreamExt};
12use http::{HeaderMap, StatusCode};
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, Ordering};
15use tokio::net::{TcpListener, ToSocketAddrs};
16use tokio::select;
17use tokio::sync::broadcast;
18
19pub struct WsReverseService<T: ToSocketAddrs + Clone + Send + Sync> {
20 api_receiver: Option<InternalAPIReceiver>,
21 event_sender: Option<InternalEventSender>,
22 close_signal_sender: broadcast::Sender<()>,
23 access_token: Option<String>,
24 addr: T,
25 is_running: Arc<AtomicBool>,
26}
27
28impl<T: ToSocketAddrs + Clone + Send + Sync> Drop for WsReverseService<T> {
29 fn drop(&mut self) {
30 self.uninstall();
31 }
32}
33
34impl<T: ToSocketAddrs + Clone + Send + Sync> WsReverseService<T> {
35 pub fn new(addr: T, access_token: Option<String>) -> Self {
36 let (close_signal_sender, _) = broadcast::channel(1);
37 Self {
38 api_receiver: None,
39 event_sender: None,
40 close_signal_sender,
41 access_token,
42 addr,
43 is_running: Arc::new(AtomicBool::new(false)),
44 }
45 }
46}
47
48struct AppState {
49 access_token: Option<String>,
50 api_receiver: InternalAPIReceiver,
51 event_sender: InternalEventSender,
52 close_signal_sender: broadcast::Sender<()>,
53 connected: Arc<AtomicBool>,
54}
55
56async fn send_processor(
57 mut send_side: SplitSink<WebSocket, Message>,
58 api_receiver: InternalAPIReceiver,
59 mut close_signal: broadcast::Receiver<()>,
60 mut connection_close_signal: broadcast::Receiver<()>,
61) -> anyhow::Result<()> {
62 loop {
63 select! {
64 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
65 _ = connection_close_signal.recv() => return Err(anyhow::anyhow!("close")),
66 Ok(data) = api_receiver.recv_async() => {
67 let str = serde_json::to_string(&data);
68 if str.is_err() {
69 continue
70 }
71 let _ = send_side.send(Message::Text(str?.into())).await;
72 }
73 }
74 }
75}
76
77async fn read_processor(
78 mut read_side: SplitStream<WebSocket>,
79 event_sender: InternalEventSender,
80 mut close_signal: broadcast::Receiver<()>,
81 connection_close_signal_sender: broadcast::Sender<()>,
82 connected: Arc<AtomicBool>,
83) -> anyhow::Result<()> {
84 loop {
85 select! {
86 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
87 Some(Ok(msg)) = read_side.next() => {
88 match msg {
89 Message::Text(data) => {
90 let str = data.as_str();
91 let event = serde_json::from_str::<DeserializedEvent>(str);
92 if event.is_err() {
93 continue
94 }
95 let _ = event_sender.send_async(event?).await;
96 },
97 Message::Close(_) => {
98 let _ = connection_close_signal_sender.send(());
99 connected.store(false, Ordering::Relaxed);
100 return Err(anyhow::anyhow!("close"));
101 },
102 _ => ()
103 }
104 }
105 }
106 }
107}
108
109async fn handler(
110 headers: HeaderMap,
111 State(state): State<Arc<AppState>>,
112 ws: WebSocketUpgrade,
113) -> Response {
114 if state.connected.load(Ordering::Relaxed) {
115 return Response::builder()
116 .status(StatusCode::FORBIDDEN)
117 .body(Body::from(""))
118 .unwrap();
119 }
120 if state.access_token.is_some() {
121 let received_token = headers.get("Authorization").map(|v| v.to_str().unwrap());
122 if received_token.is_none() {
123 return Response::builder()
124 .status(StatusCode::UNAUTHORIZED)
125 .body(Body::from(""))
126 .unwrap();
127 }
128 let received_token = received_token.unwrap();
129 let access_token = state.access_token.clone().unwrap();
130 if received_token != "Bearer ".to_string() + &access_token {
131 return Response::builder()
132 .status(StatusCode::FORBIDDEN)
133 .body(Body::from(""))
134 .unwrap();
135 }
136 }
137 ws.on_upgrade(async move |socket: WebSocket| {
138 let (send_side, read_side) = socket.split();
139 let (connection_close_signal_sender, connection_close_signal) = broadcast::channel(1);
140 let api_receiver = state.api_receiver.clone();
141 let event_sender = state.event_sender.clone();
142 state.connected.store(true, Ordering::Relaxed);
143 let send_task = tokio::spawn(send_processor(
144 send_side,
145 api_receiver,
146 state.close_signal_sender.subscribe(),
147 connection_close_signal,
148 ));
149 let read_task = tokio::spawn(read_processor(
150 read_side,
151 event_sender,
152 state.close_signal_sender.subscribe(),
153 connection_close_signal_sender,
154 Arc::clone(&state.connected),
155 ));
156 let (r1, r2) = futures::try_join!(send_task, read_task).unwrap();
157 r1.and(r2).unwrap();
158 })
159}
160
161#[async_trait]
162impl<T: ToSocketAddrs + Clone + Send + Sync> CommunicationService for WsReverseService<T> {
163 fn install(&mut self, api_receiver: InternalAPIReceiver, event_sender: InternalEventSender) {
164 self.api_receiver = Some(api_receiver);
165 self.event_sender = Some(event_sender);
166 }
167
168 fn uninstall(&mut self) {
169 self.stop();
170 self.api_receiver = None;
171 self.event_sender = None;
172 }
173
174 fn stop(&self) {
175 let _ = self.close_signal_sender.send(());
176 self.is_running.store(false, Ordering::Relaxed);
177 }
178
179 async fn start(&self) -> ServiceStartResult<()> {
180 if self.is_running.load(Ordering::Relaxed) {
181 return Err(ServiceStartError::TaskIsRunning);
182 }
183
184 if self.api_receiver.is_none() && self.event_sender.is_none() {
185 return Err(ServiceStartError::NotInjected);
186 } else if self.event_sender.is_none() {
187 return Err(ServiceStartError::NotInjectedEventSender);
188 } else if self.api_receiver.is_none() {
189 return Err(ServiceStartError::NotInjectedAPIReceiver);
190 }
191
192 let api_receiver = self.api_receiver.clone().unwrap();
193 let event_sender = self.event_sender.clone().unwrap();
194
195 let state = Arc::new(AppState {
196 access_token: self.access_token.clone(),
197 api_receiver,
198 event_sender,
199 close_signal_sender: self.close_signal_sender.clone(),
200 connected: Arc::new(AtomicBool::new(false)),
201 });
202
203 let listener = TcpListener::bind(self.addr.clone()).await?;
204 let router = Router::new().fallback(any(handler)).with_state(state);
205 let mut close_signal = self.close_signal_sender.subscribe();
206
207 self.is_running.store(true, Ordering::Relaxed);
208 tokio::spawn(
209 axum::serve(listener, router)
210 .with_graceful_shutdown(async move {
211 let _ = close_signal.recv().await;
212 })
213 .into_future(),
214 );
215
216 Ok(())
217 }
218}