Skip to main content

onebot_api/communication/
ws_reverse.rs

1use super::utils::*;
2use async_trait::async_trait;
3use axum::Router;
4use axum::body::Body;
5use axum::extract::ws::{Message, WebSocket};
6use axum::extract::{State, WebSocketUpgrade};
7use axum::response::Response;
8use axum::routing::any;
9use futures::stream::{SplitSink, SplitStream};
10use futures::{SinkExt, StreamExt};
11use http::{HeaderMap, StatusCode};
12use std::sync::Arc;
13use std::sync::atomic::{AtomicBool, Ordering};
14use tokio::net::{TcpListener, ToSocketAddrs};
15use tokio::select;
16use tokio::sync::broadcast;
17
18pub struct WsReverseService<T: ToSocketAddrs + Clone + Send + Sync> {
19	api_receiver: Option<APIReceiver>,
20	event_sender: Option<EventSender>,
21	close_signal_sender: broadcast::Sender<()>,
22	access_token: Option<String>,
23	addr: T,
24}
25
26impl<T: ToSocketAddrs + Clone + Send + Sync> Drop for WsReverseService<T> {
27	fn drop(&mut self) {
28		let _ = self.close_signal_sender.send(());
29	}
30}
31
32impl<T: ToSocketAddrs + Clone + Send + Sync> WsReverseService<T> {
33	pub fn new(addr: T, access_token: Option<String>) -> Self {
34		let (close_signal_sender, _) = broadcast::channel(1);
35		Self {
36			api_receiver: None,
37			event_sender: None,
38			close_signal_sender,
39			access_token,
40			addr,
41		}
42	}
43}
44
45struct AppState {
46	access_token: Option<String>,
47	api_receiver: APIReceiver,
48	event_sender: EventSender,
49	close_signal_sender: broadcast::Sender<()>,
50	connected: Arc<AtomicBool>,
51}
52
53async fn send_processor(
54	mut send_side: SplitSink<WebSocket, Message>,
55	api_receiver: APIReceiver,
56	mut close_signal: broadcast::Receiver<()>,
57	mut connection_close_signal: broadcast::Receiver<()>,
58) -> anyhow::Result<()> {
59	loop {
60		select! {
61			_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
62			_ = connection_close_signal.recv() => return Err(anyhow::anyhow!("close")),
63			Ok(data) = api_receiver.recv_async() => {
64				let str = serde_json::to_string(&data);
65				if str.is_err() {
66					continue
67				}
68				let _ = send_side.send(Message::Text(str?.into())).await;
69			}
70		}
71	}
72}
73
74async fn read_processor(
75	mut read_side: SplitStream<WebSocket>,
76	event_sender: EventSender,
77	mut close_signal: broadcast::Receiver<()>,
78	connection_close_signal_sender: broadcast::Sender<()>,
79	connected: Arc<AtomicBool>,
80) -> anyhow::Result<()> {
81	loop {
82		select! {
83			_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
84			Some(Ok(msg)) = read_side.next() => {
85				match msg {
86					Message::Text(data) => {
87						let str = data.as_str();
88						let event = serde_json::from_str::<Event>(str);
89						if event.is_err() {
90							continue
91						}
92						let event = Arc::new(event?);
93						let _ = event_sender.send(event);
94					},
95					Message::Close(_) => {
96						let _ = connection_close_signal_sender.send(());
97						connected.store(false, Ordering::Relaxed);
98						return Err(anyhow::anyhow!("close"));
99					},
100					_ => ()
101				}
102			}
103		}
104	}
105}
106
107async fn handler(
108	headers: HeaderMap,
109	State(state): State<Arc<AppState>>,
110	ws: WebSocketUpgrade,
111) -> Response {
112	if state.connected.load(Ordering::Relaxed) {
113		return Response::builder()
114			.status(StatusCode::FORBIDDEN)
115			.body(Body::from(""))
116			.unwrap();
117	}
118	if state.access_token.is_some() {
119		let received_token = headers.get("Authorization").map(|v| v.to_str().unwrap());
120		if received_token.is_none() {
121			return Response::builder()
122				.status(StatusCode::UNAUTHORIZED)
123				.body(Body::from(""))
124				.unwrap();
125		}
126		let received_token = received_token.unwrap();
127		let access_token = state.access_token.clone().unwrap();
128		if received_token != "Bearer ".to_string() + &access_token {
129			return Response::builder()
130				.status(StatusCode::FORBIDDEN)
131				.body(Body::from(""))
132				.unwrap();
133		}
134	}
135	ws.on_upgrade(async move |socket: WebSocket| {
136		let (send_side, read_side) = socket.split();
137		let (connection_close_signal_sender, connection_close_signal) = broadcast::channel(1);
138		let api_receiver = state.api_receiver.clone();
139		let event_sender = state.event_sender.clone();
140		state.connected.store(true, Ordering::Relaxed);
141		let send_task = tokio::spawn(send_processor(
142			send_side,
143			api_receiver,
144			state.close_signal_sender.subscribe(),
145			connection_close_signal,
146		));
147		let read_task = tokio::spawn(read_processor(
148			read_side,
149			event_sender,
150			state.close_signal_sender.subscribe(),
151			connection_close_signal_sender,
152			Arc::clone(&state.connected),
153		));
154		let (r1, r2) = futures::try_join!(send_task, read_task).unwrap();
155		r1.and(r2).unwrap();
156	})
157}
158
159#[async_trait]
160impl<T: ToSocketAddrs + Clone + Send + Sync> CommunicationService for WsReverseService<T> {
161	fn inject(&mut self, api_receiver: APIReceiver, event_sender: EventSender) {
162		self.api_receiver = Some(api_receiver);
163		self.event_sender = Some(event_sender);
164	}
165
166	async fn start_service(&self) -> anyhow::Result<()> {
167		if self.api_receiver.is_none() || self.event_sender.is_none() {
168			return Err(anyhow::anyhow!("api receiver or event sender is none"));
169		}
170
171		let api_receiver = self.api_receiver.clone().unwrap();
172		let event_sender = self.event_sender.clone().unwrap();
173
174		let state = Arc::new(AppState {
175			access_token: self.access_token.clone(),
176			api_receiver,
177			event_sender,
178			close_signal_sender: self.close_signal_sender.clone(),
179			connected: Arc::new(AtomicBool::new(false)),
180		});
181
182		let listener = TcpListener::bind(self.addr.clone()).await?;
183		let router = Router::new().fallback(any(handler)).with_state(state);
184		let mut close_signal = self.close_signal_sender.subscribe();
185
186		tokio::spawn(
187			axum::serve(listener, router)
188				.with_graceful_shutdown(async move {
189					let _ = close_signal.recv().await;
190				})
191				.into_future(),
192		);
193
194		Ok(())
195	}
196}