onebot_api/communication/
ws_reverse.rs1use super::utils::*;
2use async_trait::async_trait;
3use axum::Router;
4use axum::body::Body;
5use axum::extract::ws::{Message, WebSocket};
6use axum::extract::{State, WebSocketUpgrade};
7use axum::response::Response;
8use axum::routing::any;
9use futures::stream::{SplitSink, SplitStream};
10use futures::{SinkExt, StreamExt};
11use http::{HeaderMap, StatusCode};
12use std::sync::Arc;
13use std::sync::atomic::{AtomicBool, Ordering};
14use tokio::net::{TcpListener, ToSocketAddrs};
15use tokio::select;
16use tokio::sync::broadcast;
17
18pub struct WsReverseService<T: ToSocketAddrs + Clone + Send + Sync> {
19 api_receiver: Option<APIReceiver>,
20 event_sender: Option<EventSender>,
21 close_signal_sender: broadcast::Sender<()>,
22 access_token: Option<String>,
23 addr: T,
24}
25
26impl<T: ToSocketAddrs + Clone + Send + Sync> Drop for WsReverseService<T> {
27 fn drop(&mut self) {
28 let _ = self.close_signal_sender.send(());
29 }
30}
31
32impl<T: ToSocketAddrs + Clone + Send + Sync> WsReverseService<T> {
33 pub fn new(addr: T, access_token: Option<String>) -> Self {
34 let (close_signal_sender, _) = broadcast::channel(1);
35 Self {
36 api_receiver: None,
37 event_sender: None,
38 close_signal_sender,
39 access_token,
40 addr,
41 }
42 }
43}
44
45struct AppState {
46 access_token: Option<String>,
47 api_receiver: APIReceiver,
48 event_sender: EventSender,
49 close_signal_sender: broadcast::Sender<()>,
50 connected: Arc<AtomicBool>,
51}
52
53async fn send_processor(
54 mut send_side: SplitSink<WebSocket, Message>,
55 api_receiver: APIReceiver,
56 mut close_signal: broadcast::Receiver<()>,
57 mut connection_close_signal: broadcast::Receiver<()>,
58) -> anyhow::Result<()> {
59 loop {
60 select! {
61 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
62 _ = connection_close_signal.recv() => return Err(anyhow::anyhow!("close")),
63 Ok(data) = api_receiver.recv_async() => {
64 let str = serde_json::to_string(&data);
65 if str.is_err() {
66 continue
67 }
68 let _ = send_side.send(Message::Text(str?.into())).await;
69 }
70 }
71 }
72}
73
74async fn read_processor(
75 mut read_side: SplitStream<WebSocket>,
76 event_sender: EventSender,
77 mut close_signal: broadcast::Receiver<()>,
78 connection_close_signal_sender: broadcast::Sender<()>,
79 connected: Arc<AtomicBool>,
80) -> anyhow::Result<()> {
81 loop {
82 select! {
83 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
84 Some(Ok(msg)) = read_side.next() => {
85 match msg {
86 Message::Text(data) => {
87 let str = data.as_str();
88 let event = serde_json::from_str::<Event>(str);
89 if event.is_err() {
90 continue
91 }
92 let event = Arc::new(event?);
93 let _ = event_sender.send(event);
94 },
95 Message::Close(_) => {
96 let _ = connection_close_signal_sender.send(());
97 connected.store(false, Ordering::Relaxed);
98 return Err(anyhow::anyhow!("close"));
99 },
100 _ => ()
101 }
102 }
103 }
104 }
105}
106
107async fn handler(
108 headers: HeaderMap,
109 State(state): State<Arc<AppState>>,
110 ws: WebSocketUpgrade,
111) -> Response {
112 if state.connected.load(Ordering::Relaxed) {
113 return Response::builder()
114 .status(StatusCode::FORBIDDEN)
115 .body(Body::from(""))
116 .unwrap();
117 }
118 if state.access_token.is_some() {
119 let received_token = headers.get("Authorization").map(|v| v.to_str().unwrap());
120 if received_token.is_none() {
121 return Response::builder()
122 .status(StatusCode::UNAUTHORIZED)
123 .body(Body::from(""))
124 .unwrap();
125 }
126 let received_token = received_token.unwrap();
127 let access_token = state.access_token.clone().unwrap();
128 if received_token != "Bearer ".to_string() + &access_token {
129 return Response::builder()
130 .status(StatusCode::FORBIDDEN)
131 .body(Body::from(""))
132 .unwrap();
133 }
134 }
135 ws.on_upgrade(async move |socket: WebSocket| {
136 let (send_side, read_side) = socket.split();
137 let (connection_close_signal_sender, connection_close_signal) = broadcast::channel(1);
138 let api_receiver = state.api_receiver.clone();
139 let event_sender = state.event_sender.clone();
140 state.connected.store(true, Ordering::Relaxed);
141 let send_task = tokio::spawn(send_processor(
142 send_side,
143 api_receiver,
144 state.close_signal_sender.subscribe(),
145 connection_close_signal,
146 ));
147 let read_task = tokio::spawn(read_processor(
148 read_side,
149 event_sender,
150 state.close_signal_sender.subscribe(),
151 connection_close_signal_sender,
152 Arc::clone(&state.connected),
153 ));
154 let (r1, r2) = futures::try_join!(send_task, read_task).unwrap();
155 r1.and(r2).unwrap();
156 })
157}
158
159#[async_trait]
160impl<T: ToSocketAddrs + Clone + Send + Sync> CommunicationService for WsReverseService<T> {
161 fn inject(&mut self, api_receiver: APIReceiver, event_sender: EventSender) {
162 self.api_receiver = Some(api_receiver);
163 self.event_sender = Some(event_sender);
164 }
165
166 async fn start_service(&self) -> anyhow::Result<()> {
167 if self.api_receiver.is_none() || self.event_sender.is_none() {
168 return Err(anyhow::anyhow!("api receiver or event sender is none"));
169 }
170
171 let api_receiver = self.api_receiver.clone().unwrap();
172 let event_sender = self.event_sender.clone().unwrap();
173
174 let state = Arc::new(AppState {
175 access_token: self.access_token.clone(),
176 api_receiver,
177 event_sender,
178 close_signal_sender: self.close_signal_sender.clone(),
179 connected: Arc::new(AtomicBool::new(false)),
180 });
181
182 let listener = TcpListener::bind(self.addr.clone()).await?;
183 let router = Router::new().fallback(any(handler)).with_state(state);
184 let mut close_signal = self.close_signal_sender.subscribe();
185
186 tokio::spawn(
187 axum::serve(listener, router)
188 .with_graceful_shutdown(async move {
189 let _ = close_signal.recv().await;
190 })
191 .into_future(),
192 );
193
194 Ok(())
195 }
196}