onebot_api/communication/
ws_reverse.rs1use super::utils::*;
2use crate::error::{ServiceStartError, ServiceStartResult};
3use async_trait::async_trait;
4use axum::Router;
5use axum::body::Body;
6use axum::extract::ws::{Message, WebSocket};
7use axum::extract::{State, WebSocketUpgrade};
8use axum::response::Response;
9use axum::routing::any;
10use futures::stream::{SplitSink, SplitStream};
11use futures::{SinkExt, StreamExt};
12use http::{HeaderMap, StatusCode};
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, Ordering};
15use tokio::net::{TcpListener, ToSocketAddrs};
16use tokio::select;
17use tokio::sync::broadcast;
18
19pub struct WsReverseService<T: ToSocketAddrs + Clone + Send + Sync> {
20 api_receiver: Option<APIReceiver>,
21 event_sender: Option<EventSender>,
22 close_signal_sender: broadcast::Sender<()>,
23 access_token: Option<String>,
24 addr: T,
25}
26
27impl<T: ToSocketAddrs + Clone + Send + Sync> Drop for WsReverseService<T> {
28 fn drop(&mut self) {
29 let _ = self.close_signal_sender.send(());
30 }
31}
32
33impl<T: ToSocketAddrs + Clone + Send + Sync> WsReverseService<T> {
34 pub fn new(addr: T, access_token: Option<String>) -> Self {
35 let (close_signal_sender, _) = broadcast::channel(1);
36 Self {
37 api_receiver: None,
38 event_sender: None,
39 close_signal_sender,
40 access_token,
41 addr,
42 }
43 }
44}
45
46struct AppState {
47 access_token: Option<String>,
48 api_receiver: APIReceiver,
49 event_sender: EventSender,
50 close_signal_sender: broadcast::Sender<()>,
51 connected: Arc<AtomicBool>,
52}
53
54async fn send_processor(
55 mut send_side: SplitSink<WebSocket, Message>,
56 api_receiver: APIReceiver,
57 mut close_signal: broadcast::Receiver<()>,
58 mut connection_close_signal: broadcast::Receiver<()>,
59) -> anyhow::Result<()> {
60 loop {
61 select! {
62 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
63 _ = connection_close_signal.recv() => return Err(anyhow::anyhow!("close")),
64 Ok(data) = api_receiver.recv_async() => {
65 let str = serde_json::to_string(&data);
66 if str.is_err() {
67 continue
68 }
69 let _ = send_side.send(Message::Text(str?.into())).await;
70 }
71 }
72 }
73}
74
75async fn read_processor(
76 mut read_side: SplitStream<WebSocket>,
77 event_sender: EventSender,
78 mut close_signal: broadcast::Receiver<()>,
79 connection_close_signal_sender: broadcast::Sender<()>,
80 connected: Arc<AtomicBool>,
81) -> anyhow::Result<()> {
82 loop {
83 select! {
84 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
85 Some(Ok(msg)) = read_side.next() => {
86 match msg {
87 Message::Text(data) => {
88 let str = data.as_str();
89 let event = serde_json::from_str::<Event>(str);
90 if event.is_err() {
91 continue
92 }
93 let event = Arc::new(event?);
94 let _ = event_sender.send(event);
95 },
96 Message::Close(_) => {
97 let _ = connection_close_signal_sender.send(());
98 connected.store(false, Ordering::Relaxed);
99 return Err(anyhow::anyhow!("close"));
100 },
101 _ => ()
102 }
103 }
104 }
105 }
106}
107
108async fn handler(
109 headers: HeaderMap,
110 State(state): State<Arc<AppState>>,
111 ws: WebSocketUpgrade,
112) -> Response {
113 if state.connected.load(Ordering::Relaxed) {
114 return Response::builder()
115 .status(StatusCode::FORBIDDEN)
116 .body(Body::from(""))
117 .unwrap();
118 }
119 if state.access_token.is_some() {
120 let received_token = headers.get("Authorization").map(|v| v.to_str().unwrap());
121 if received_token.is_none() {
122 return Response::builder()
123 .status(StatusCode::UNAUTHORIZED)
124 .body(Body::from(""))
125 .unwrap();
126 }
127 let received_token = received_token.unwrap();
128 let access_token = state.access_token.clone().unwrap();
129 if received_token != "Bearer ".to_string() + &access_token {
130 return Response::builder()
131 .status(StatusCode::FORBIDDEN)
132 .body(Body::from(""))
133 .unwrap();
134 }
135 }
136 ws.on_upgrade(async move |socket: WebSocket| {
137 let (send_side, read_side) = socket.split();
138 let (connection_close_signal_sender, connection_close_signal) = broadcast::channel(1);
139 let api_receiver = state.api_receiver.clone();
140 let event_sender = state.event_sender.clone();
141 state.connected.store(true, Ordering::Relaxed);
142 let send_task = tokio::spawn(send_processor(
143 send_side,
144 api_receiver,
145 state.close_signal_sender.subscribe(),
146 connection_close_signal,
147 ));
148 let read_task = tokio::spawn(read_processor(
149 read_side,
150 event_sender,
151 state.close_signal_sender.subscribe(),
152 connection_close_signal_sender,
153 Arc::clone(&state.connected),
154 ));
155 let (r1, r2) = futures::try_join!(send_task, read_task).unwrap();
156 r1.and(r2).unwrap();
157 })
158}
159
160#[async_trait]
161impl<T: ToSocketAddrs + Clone + Send + Sync> CommunicationService for WsReverseService<T> {
162 fn inject(&mut self, api_receiver: APIReceiver, event_sender: EventSender) {
163 self.api_receiver = Some(api_receiver);
164 self.event_sender = Some(event_sender);
165 }
166
167 async fn start_service(&self) -> ServiceStartResult<()> {
168 if self.api_receiver.is_none() && self.event_sender.is_none() {
169 return Err(ServiceStartError::NotInjected);
170 } else if self.event_sender.is_none() {
171 return Err(ServiceStartError::NotInjectedEventSender);
172 } else if self.api_receiver.is_none() {
173 return Err(ServiceStartError::NotInjectedAPIReceiver);
174 }
175
176 let api_receiver = self.api_receiver.clone().unwrap();
177 let event_sender = self.event_sender.clone().unwrap();
178
179 let state = Arc::new(AppState {
180 access_token: self.access_token.clone(),
181 api_receiver,
182 event_sender,
183 close_signal_sender: self.close_signal_sender.clone(),
184 connected: Arc::new(AtomicBool::new(false)),
185 });
186
187 let listener = TcpListener::bind(self.addr.clone()).await?;
188 let router = Router::new().fallback(any(handler)).with_state(state);
189 let mut close_signal = self.close_signal_sender.subscribe();
190
191 tokio::spawn(
192 axum::serve(listener, router)
193 .with_graceful_shutdown(async move {
194 let _ = close_signal.recv().await;
195 })
196 .into_future(),
197 );
198
199 Ok(())
200 }
201}