Skip to main content

onebot_api/communication/
ws_reverse.rs

1use super::utils::*;
2use crate::error::{ServiceStartError, ServiceStartResult};
3use async_trait::async_trait;
4use axum::Router;
5use axum::body::Body;
6use axum::extract::ws::{Message, WebSocket};
7use axum::extract::{State, WebSocketUpgrade};
8use axum::response::Response;
9use axum::routing::any;
10use futures::stream::{SplitSink, SplitStream};
11use futures::{SinkExt, StreamExt};
12use http::{HeaderMap, StatusCode};
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, Ordering};
15use tokio::net::{TcpListener, ToSocketAddrs};
16use tokio::select;
17use tokio::sync::broadcast;
18
19pub struct WsReverseService<T: ToSocketAddrs + Clone + Send + Sync> {
20	api_receiver: Option<InternalAPIReceiver>,
21	event_sender: Option<InternalEventSender>,
22	close_signal_sender: broadcast::Sender<()>,
23	access_token: Option<String>,
24	addr: T,
25	is_running: Arc<AtomicBool>,
26}
27
28impl<T: ToSocketAddrs + Clone + Send + Sync> Drop for WsReverseService<T> {
29	fn drop(&mut self) {
30		self.uninstall();
31	}
32}
33
34impl<T: ToSocketAddrs + Clone + Send + Sync> WsReverseService<T> {
35	pub fn new(addr: T, access_token: Option<String>) -> Self {
36		let (close_signal_sender, _) = broadcast::channel(1);
37		Self {
38			api_receiver: None,
39			event_sender: None,
40			close_signal_sender,
41			access_token,
42			addr,
43			is_running: Arc::new(AtomicBool::new(false)),
44		}
45	}
46}
47
48struct AppState {
49	access_token: Option<String>,
50	api_receiver: InternalAPIReceiver,
51	event_sender: InternalEventSender,
52	close_signal_sender: broadcast::Sender<()>,
53	connected: Arc<AtomicBool>,
54}
55
56async fn send_processor(
57	mut send_side: SplitSink<WebSocket, Message>,
58	api_receiver: InternalAPIReceiver,
59	mut close_signal: broadcast::Receiver<()>,
60	mut connection_close_signal: broadcast::Receiver<()>,
61) -> anyhow::Result<()> {
62	loop {
63		select! {
64			_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
65			_ = connection_close_signal.recv() => return Err(anyhow::anyhow!("close")),
66			Ok(data) = api_receiver.recv_async() => {
67				let str = serde_json::to_string(&data);
68				if str.is_err() {
69					continue
70				}
71				let _ = send_side.send(Message::Text(str?.into())).await;
72			}
73		}
74	}
75}
76
77async fn read_processor(
78	mut read_side: SplitStream<WebSocket>,
79	event_sender: InternalEventSender,
80	mut close_signal: broadcast::Receiver<()>,
81	connection_close_signal_sender: broadcast::Sender<()>,
82	connected: Arc<AtomicBool>,
83) -> anyhow::Result<()> {
84	loop {
85		select! {
86			_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
87			Some(Ok(msg)) = read_side.next() => {
88				match msg {
89					Message::Text(data) => {
90						let str = data.as_str();
91						let event = serde_json::from_str::<DeserializedEvent>(str);
92						if event.is_err() {
93							continue
94						}
95						let _ = event_sender.send_async(event?).await;
96					},
97					Message::Close(_) => {
98						let _ = connection_close_signal_sender.send(());
99						connected.store(false, Ordering::Relaxed);
100						return Err(anyhow::anyhow!("close"));
101					},
102					_ => ()
103				}
104			}
105		}
106	}
107}
108
109async fn handler(
110	headers: HeaderMap,
111	State(state): State<Arc<AppState>>,
112	ws: WebSocketUpgrade,
113) -> Response {
114	if state.connected.load(Ordering::Relaxed) {
115		return Response::builder()
116			.status(StatusCode::FORBIDDEN)
117			.body(Body::from(""))
118			.unwrap();
119	}
120	if state.access_token.is_some() {
121		let received_token = headers.get("Authorization").map(|v| v.to_str().unwrap());
122		if received_token.is_none() {
123			return Response::builder()
124				.status(StatusCode::UNAUTHORIZED)
125				.body(Body::from(""))
126				.unwrap();
127		}
128		let received_token = received_token.unwrap();
129		let access_token = state.access_token.clone().unwrap();
130		if received_token != "Bearer ".to_string() + &access_token {
131			return Response::builder()
132				.status(StatusCode::FORBIDDEN)
133				.body(Body::from(""))
134				.unwrap();
135		}
136	}
137	ws.on_upgrade(async move |socket: WebSocket| {
138		let (send_side, read_side) = socket.split();
139		let (connection_close_signal_sender, connection_close_signal) = broadcast::channel(1);
140		let api_receiver = state.api_receiver.clone();
141		let event_sender = state.event_sender.clone();
142		state.connected.store(true, Ordering::Relaxed);
143		let send_task = tokio::spawn(send_processor(
144			send_side,
145			api_receiver,
146			state.close_signal_sender.subscribe(),
147			connection_close_signal,
148		));
149		let read_task = tokio::spawn(read_processor(
150			read_side,
151			event_sender,
152			state.close_signal_sender.subscribe(),
153			connection_close_signal_sender,
154			Arc::clone(&state.connected),
155		));
156		let (r1, r2) = futures::try_join!(send_task, read_task).unwrap();
157		r1.and(r2).unwrap();
158	})
159}
160
161#[async_trait]
162impl<T: ToSocketAddrs + Clone + Send + Sync> CommunicationService for WsReverseService<T> {
163	fn install(&mut self, api_receiver: InternalAPIReceiver, event_sender: InternalEventSender) {
164		self.api_receiver = Some(api_receiver);
165		self.event_sender = Some(event_sender);
166	}
167
168	fn uninstall(&mut self) {
169		self.stop();
170		self.api_receiver = None;
171		self.event_sender = None;
172	}
173
174	fn stop(&self) {
175		let _ = self.close_signal_sender.send(());
176		self.is_running.store(false, Ordering::Relaxed);
177	}
178
179	async fn start(&self) -> ServiceStartResult<()> {
180		if self.is_running.load(Ordering::Relaxed) {
181			return Err(ServiceStartError::TaskIsRunning);
182		}
183
184		if self.api_receiver.is_none() && self.event_sender.is_none() {
185			return Err(ServiceStartError::NotInjected);
186		} else if self.event_sender.is_none() {
187			return Err(ServiceStartError::NotInjectedEventSender);
188		} else if self.api_receiver.is_none() {
189			return Err(ServiceStartError::NotInjectedAPIReceiver);
190		}
191
192		let api_receiver = self.api_receiver.clone().unwrap();
193		let event_sender = self.event_sender.clone().unwrap();
194
195		let state = Arc::new(AppState {
196			access_token: self.access_token.clone(),
197			api_receiver,
198			event_sender,
199			close_signal_sender: self.close_signal_sender.clone(),
200			connected: Arc::new(AtomicBool::new(false)),
201		});
202
203		let listener = TcpListener::bind(self.addr.clone()).await?;
204		let router = Router::new().fallback(any(handler)).with_state(state);
205		let mut close_signal = self.close_signal_sender.subscribe();
206
207		self.is_running.store(true, Ordering::Relaxed);
208		tokio::spawn(
209			axum::serve(listener, router)
210				.with_graceful_shutdown(async move {
211					let _ = close_signal.recv().await;
212				})
213				.into_future(),
214		);
215
216		Ok(())
217	}
218}