Skip to main content

onebot_api/communication/
ws_reverse.rs

1use super::utils::*;
2use crate::error::{ServiceStartError, ServiceStartResult};
3use async_trait::async_trait;
4use axum::Router;
5use axum::body::Body;
6use axum::extract::ws::{Message, WebSocket};
7use axum::extract::{State, WebSocketUpgrade};
8use axum::response::Response;
9use axum::routing::any;
10use futures::stream::{SplitSink, SplitStream};
11use futures::{SinkExt, StreamExt};
12use http::{HeaderMap, StatusCode};
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, Ordering};
15use tokio::net::{TcpListener, ToSocketAddrs};
16use tokio::select;
17use tokio::sync::broadcast;
18
19pub struct WsReverseService<T: ToSocketAddrs + Clone + Send + Sync> {
20	api_receiver: Option<APIReceiver>,
21	event_sender: Option<EventSender>,
22	close_signal_sender: broadcast::Sender<()>,
23	access_token: Option<String>,
24	addr: T,
25}
26
27impl<T: ToSocketAddrs + Clone + Send + Sync> Drop for WsReverseService<T> {
28	fn drop(&mut self) {
29		let _ = self.close_signal_sender.send(());
30	}
31}
32
33impl<T: ToSocketAddrs + Clone + Send + Sync> WsReverseService<T> {
34	pub fn new(addr: T, access_token: Option<String>) -> Self {
35		let (close_signal_sender, _) = broadcast::channel(1);
36		Self {
37			api_receiver: None,
38			event_sender: None,
39			close_signal_sender,
40			access_token,
41			addr,
42		}
43	}
44}
45
46struct AppState {
47	access_token: Option<String>,
48	api_receiver: APIReceiver,
49	event_sender: EventSender,
50	close_signal_sender: broadcast::Sender<()>,
51	connected: Arc<AtomicBool>,
52}
53
54async fn send_processor(
55	mut send_side: SplitSink<WebSocket, Message>,
56	api_receiver: APIReceiver,
57	mut close_signal: broadcast::Receiver<()>,
58	mut connection_close_signal: broadcast::Receiver<()>,
59) -> anyhow::Result<()> {
60	loop {
61		select! {
62			_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
63			_ = connection_close_signal.recv() => return Err(anyhow::anyhow!("close")),
64			Ok(data) = api_receiver.recv_async() => {
65				let str = serde_json::to_string(&data);
66				if str.is_err() {
67					continue
68				}
69				let _ = send_side.send(Message::Text(str?.into())).await;
70			}
71		}
72	}
73}
74
75async fn read_processor(
76	mut read_side: SplitStream<WebSocket>,
77	event_sender: EventSender,
78	mut close_signal: broadcast::Receiver<()>,
79	connection_close_signal_sender: broadcast::Sender<()>,
80	connected: Arc<AtomicBool>,
81) -> anyhow::Result<()> {
82	loop {
83		select! {
84			_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
85			Some(Ok(msg)) = read_side.next() => {
86				match msg {
87					Message::Text(data) => {
88						let str = data.as_str();
89						let event = serde_json::from_str::<Event>(str);
90						if event.is_err() {
91							continue
92						}
93						let event = Arc::new(event?);
94						let _ = event_sender.send(event);
95					},
96					Message::Close(_) => {
97						let _ = connection_close_signal_sender.send(());
98						connected.store(false, Ordering::Relaxed);
99						return Err(anyhow::anyhow!("close"));
100					},
101					_ => ()
102				}
103			}
104		}
105	}
106}
107
108async fn handler(
109	headers: HeaderMap,
110	State(state): State<Arc<AppState>>,
111	ws: WebSocketUpgrade,
112) -> Response {
113	if state.connected.load(Ordering::Relaxed) {
114		return Response::builder()
115			.status(StatusCode::FORBIDDEN)
116			.body(Body::from(""))
117			.unwrap();
118	}
119	if state.access_token.is_some() {
120		let received_token = headers.get("Authorization").map(|v| v.to_str().unwrap());
121		if received_token.is_none() {
122			return Response::builder()
123				.status(StatusCode::UNAUTHORIZED)
124				.body(Body::from(""))
125				.unwrap();
126		}
127		let received_token = received_token.unwrap();
128		let access_token = state.access_token.clone().unwrap();
129		if received_token != "Bearer ".to_string() + &access_token {
130			return Response::builder()
131				.status(StatusCode::FORBIDDEN)
132				.body(Body::from(""))
133				.unwrap();
134		}
135	}
136	ws.on_upgrade(async move |socket: WebSocket| {
137		let (send_side, read_side) = socket.split();
138		let (connection_close_signal_sender, connection_close_signal) = broadcast::channel(1);
139		let api_receiver = state.api_receiver.clone();
140		let event_sender = state.event_sender.clone();
141		state.connected.store(true, Ordering::Relaxed);
142		let send_task = tokio::spawn(send_processor(
143			send_side,
144			api_receiver,
145			state.close_signal_sender.subscribe(),
146			connection_close_signal,
147		));
148		let read_task = tokio::spawn(read_processor(
149			read_side,
150			event_sender,
151			state.close_signal_sender.subscribe(),
152			connection_close_signal_sender,
153			Arc::clone(&state.connected),
154		));
155		let (r1, r2) = futures::try_join!(send_task, read_task).unwrap();
156		r1.and(r2).unwrap();
157	})
158}
159
160#[async_trait]
161impl<T: ToSocketAddrs + Clone + Send + Sync> CommunicationService for WsReverseService<T> {
162	fn inject(&mut self, api_receiver: APIReceiver, event_sender: EventSender) {
163		self.api_receiver = Some(api_receiver);
164		self.event_sender = Some(event_sender);
165	}
166
167	async fn start_service(&self) -> ServiceStartResult<()> {
168		if self.api_receiver.is_none() && self.event_sender.is_none() {
169			return Err(ServiceStartError::NotInjected);
170		} else if self.event_sender.is_none() {
171			return Err(ServiceStartError::NotInjectedEventSender);
172		} else if self.api_receiver.is_none() {
173			return Err(ServiceStartError::NotInjectedAPIReceiver);
174		}
175
176		let api_receiver = self.api_receiver.clone().unwrap();
177		let event_sender = self.event_sender.clone().unwrap();
178
179		let state = Arc::new(AppState {
180			access_token: self.access_token.clone(),
181			api_receiver,
182			event_sender,
183			close_signal_sender: self.close_signal_sender.clone(),
184			connected: Arc::new(AtomicBool::new(false)),
185		});
186
187		let listener = TcpListener::bind(self.addr.clone()).await?;
188		let router = Router::new().fallback(any(handler)).with_state(state);
189		let mut close_signal = self.close_signal_sender.subscribe();
190
191		tokio::spawn(
192			axum::serve(listener, router)
193				.with_graceful_shutdown(async move {
194					let _ = close_signal.recv().await;
195				})
196				.into_future(),
197		);
198
199		Ok(())
200	}
201}