1use eyre::{bail, eyre, ContextCompat, Result};
2use itertools::Itertools;
3use parking_lot::RwLock;
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, HashSet};
6use std::fs::File;
7use std::net::SocketAddr;
8use std::path::PathBuf;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
13use tokio::sync::mpsc;
14use tokio::task::LocalSet;
15use tokio_tungstenite::tungstenite::Error as WsError;
16use tokio_tungstenite::tungstenite::Message;
17use tokio_tungstenite::WebSocketStream;
18use tracing::*;
19
20use crate::libs::error_code::ErrorCode;
21use crate::libs::handler::{RequestHandler, RequestHandlerErased};
22use crate::libs::listener::{ConnectionListener, TcpListener, TlsListener};
23use crate::libs::toolbox::{ArcToolbox, RequestContext, Toolbox, TOOLBOX};
24use crate::libs::utils::{get_conn_id, get_log_id};
25use crate::libs::ws::client::WsRequest;
26use crate::libs::ws::{VerifyProtocol, WsClientSession, WsConnection};
27use crate::model::EndpointSchema;
28
29use super::{AuthController, ConnectionId, SimpleAuthController, WebsocketStates, WsEndpoint};
30
/// WebSocket RPC server: owns the endpoint registry, the authentication hook and the
/// shared toolbox, and accepts connections according to [`WsServerConfig`].
pub struct WebsocketServer {
    /// Runs after the WebSocket handshake for every connection; an error aborts the connection.
    pub auth_controller: Arc<dyn AuthController>,
    /// Registered endpoints keyed by `EndpointSchema::code`.
    pub handlers: HashMap<u32, WsEndpoint>,
    /// Role ids registered for each endpoint (from `Request::ROLES`), keyed by endpoint code.
    /// NOTE(review): enforcement is not visible in this file — presumably done per request.
    pub allowed_roles: HashMap<u32, HashSet<u32>>,
    /// NOTE(review): initialised to `None` and not otherwise used in this file — confirm purpose.
    pub message_receiver: Option<mpsc::Receiver<ConnectionId>>,
    /// Toolbox shared by all connections; passed to `TOOLBOX.scope` for each connection task.
    pub toolbox: ArcToolbox,
    /// Listener configuration (address, TLS material, CORS list, ...).
    pub config: WsServerConfig,
}
39
/// Stream adapter that replays bytes already read from `stream` before reading from it again.
///
/// Reads are served from `buffer[pos..]` until it is exhausted and are then delegated to
/// `stream`; writes always go directly to `stream`.
struct BufferedStream<S> {
    /// Bytes read ahead (before the handshake) that still have to be replayed.
    buffer: Box<[u8]>,
    /// The underlying connection.
    stream: S,
    /// Index of the next byte in `buffer` to hand out.
    pos: usize,
}
46
47impl<S: AsyncRead + Unpin> AsyncRead for BufferedStream<S> {
48 fn poll_read(
49 mut self: Pin<&mut Self>,
50 cx: &mut Context<'_>,
51 buf: &mut ReadBuf<'_>,
52 ) -> Poll<tokio::io::Result<()>> {
53 if self.pos < self.buffer.len() {
54 let remaining = self.buffer.len() - self.pos;
56 let len = std::cmp::min(remaining, buf.remaining());
57 buf.put_slice(&self.buffer[self.pos..self.pos + len]);
58 self.pos += len;
59 Poll::Ready(Ok(()))
60 } else {
61 Pin::new(&mut self.stream).poll_read(cx, buf)
63 }
64 }
65}
66
67impl<S: AsyncWrite + Unpin> AsyncWrite for BufferedStream<S> {
68 fn poll_write(
69 mut self: Pin<&mut Self>,
70 cx: &mut Context<'_>,
71 buf: &[u8],
72 ) -> Poll<Result<usize, std::io::Error>> {
73 Pin::new(&mut self.stream).poll_write(cx, buf)
74 }
75
76 fn poll_flush(
77 mut self: Pin<&mut Self>,
78 cx: &mut Context<'_>,
79 ) -> Poll<Result<(), std::io::Error>> {
80 Pin::new(&mut self.stream).poll_flush(cx)
81 }
82
83 fn poll_shutdown(
84 mut self: Pin<&mut Self>,
85 cx: &mut Context<'_>,
86 ) -> Poll<Result<(), std::io::Error>> {
87 Pin::new(&mut self.stream).poll_shutdown(cx)
88 }
89}
90
91impl WebsocketServer {
92 pub fn new(config: WsServerConfig) -> Self {
93 Self {
94 auth_controller: Arc::new(SimpleAuthController),
95 allowed_roles: HashMap::new(),
96 handlers: Default::default(),
97 message_receiver: None,
98 toolbox: Toolbox::new(),
99 config,
100 }
101 }
102 pub fn set_auth_controller(&mut self, controller: impl AuthController + 'static) {
103 self.auth_controller = Arc::new(controller);
104 }
105 pub fn add_handler<T: RequestHandler + 'static>(&mut self, handler: T) {
106 let schema = serde_json::from_str(T::Request::SCHEMA).expect("Invalid schema");
107 let roles: &[u32] = T::Request::ROLES;
108 check_handler::<T>(&schema).expect("Invalid handler");
109 self.add_handler_erased(schema, roles, Arc::new(handler))
110 }
111 pub fn add_handler_erased(
112 &mut self,
113 schema: EndpointSchema,
114 roles: &[u32],
115 handler: Arc<dyn RequestHandlerErased>,
116 ) {
117 let roles_set = roles.iter().cloned().collect::<HashSet<u32>>();
118
119 let _old_roles = self.allowed_roles.insert(schema.code, roles_set);
120
121 let old = self
122 .handlers
123 .insert(schema.code, WsEndpoint { schema, handler });
124 if let Some(old) = old {
125 panic!(
126 "Overwriting handler for endpoint {} {}",
127 old.schema.code, old.schema.name
128 );
129 }
130 }
131 async fn handle_ws_handshake_and_connection<
132 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
133 >(
134 self: Arc<Self>,
135 addr: SocketAddr,
136 states: Arc<WebsocketStates>,
137 mut stream: S,
138 ) -> Result<()> {
139 let mut buffer = vec![0u8; 1024];
140 let n = stream.read(&mut buffer).await?; tracing::debug!("Raw request bytes: {:?}", &buffer[..n]);
143
144 let stream = BufferedStream {
145 buffer: buffer[..n].to_vec().into_boxed_slice(),
146 stream,
147 pos: 0,
148 };
149
150 let (tx, mut rx) = mpsc::channel(1);
151 let hs = tokio_tungstenite::accept_hdr_async(
152 stream,
153 VerifyProtocol {
154 addr,
155 tx,
156 allow_cors_domains: &self.config.allow_cors_urls,
157 },
158 )
159 .await;
160
161 tracing::warn!("handle new WS connection");
163
164 let stream = wrap_ws_error(hs)?;
165 let conn = Arc::new(WsConnection {
166 connection_id: get_conn_id(),
167 user_id: Default::default(),
168 roles: Arc::new(RwLock::new(Arc::new(Vec::new()))),
169 address: addr,
170 log_id: get_log_id(),
171 });
172 debug!(?addr, "New connection handshaken {:?}", conn);
173 let headers = rx
174 .recv()
175 .await
176 .ok_or_else(|| eyre!("Failed to receive ws headers"))?;
177
178 let (tx, rx) = mpsc::channel(100);
179 let conn = Arc::clone(&conn);
180 states.insert(conn.connection_id, tx, conn.clone());
181
182 let auth_result = Arc::clone(&self.auth_controller)
183 .auth(&self.toolbox, headers, Arc::clone(&conn))
184 .await;
185 let raw_ctx = RequestContext::from_conn(&conn);
186 if let Err(err) = auth_result {
187 self.toolbox
188 .send_request_error(&raw_ctx, ErrorCode::BAD_REQUEST, err.to_string());
189 return Err(err);
190 }
191 self.handle_session_connection(conn, states, stream, rx)
192 .await;
193
194 Ok(())
195 }
196
197 pub async fn handle_session_connection<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
198 self: Arc<Self>,
199 conn: Arc<WsConnection>,
200 states: Arc<WebsocketStates>,
201 stream: WebSocketStream<S>,
202 rx: mpsc::Receiver<Message>,
203 ) {
204 let addr = conn.address;
205 let context = RequestContext::from_conn(&conn);
206
207 let session = WsClientSession::new(conn, stream, rx, self);
208 session.run().await;
209
210 states.remove(context.connection_id);
211 debug!(?addr, "Connection closed");
212 }
213
214 pub async fn listen(self) -> Result<()> {
215 info!("Listening on {}", self.config.address);
216
217 let addr = tokio::net::lookup_host(&self.config.address)
219 .await?
220 .next()
221 .with_context(|| format!("Failed to lookup host to bind: {}", self.config.address))?;
222
223 let listener = TcpListener::bind(addr).await?;
224 if self.config.insecure {
225 self.listen_impl(Arc::new(listener)).await
226 } else if self.config.pub_certs.is_some() && self.config.priv_key.is_some() {
227 let listener = TlsListener::bind(
229 listener,
230 self.config.pub_certs.clone().unwrap(),
231 self.config.priv_key.clone().unwrap(),
232 )
233 .await?;
234 self.listen_impl(Arc::new(listener)).await
235 } else {
236 bail!("pub_certs and priv_key should be set")
237 }
238 }
239
240 async fn listen_impl<T: ConnectionListener + 'static>(self, listener: Arc<T>) -> Result<()> {
241 let states = Arc::new(WebsocketStates::new());
242 self.toolbox
243 .set_ws_states(states.clone_states(), self.config.header_only);
244 let this = Arc::new(self);
245 let local_set = LocalSet::new();
246 let (mut sigterm, mut sigint) = crate::libs::signal::init_signals()?;
247 local_set
248 .run_until(async {
249 loop {
250 tokio::select! {
251 _ = crate::libs::signal::wait_for_signals(&mut sigterm, &mut sigint) => break,
252 accepted = listener.accept() => {
253 let (stream, addr) = match accepted {
254 Ok(x) => x,
255 Err(err) => {
256 error!("Error while accepting stream: {:?}", err);
257 continue;
258 }
259 };
260 let listener = Arc::clone(&listener);
261 let this = Arc::clone(&this);
262 let states = Arc::clone(&states);
263 local_set.spawn_local(async move {
264 let stream = match listener.handshake(stream).await {
265 Ok(channel) => {
266 info!("Accepted stream from {}", addr);
267 channel
268 }
269 Err(err) => {
270 error!("Error while handshaking stream: {:?}", err);
271 return;
272 }
273 };
274
275 let future = TOOLBOX.scope(this.toolbox.clone(), this.handle_ws_handshake_and_connection(addr, states, stream));
276 if let Err(err) = future.await {
277 error!("Error while handling connection: {:?}", err);
278 }
279 });
280 }
281 }
282 }
283 Ok(())
284 })
285 .await
286 }
287
288 pub fn dump_schemas(&self) -> Result<()> {
289 let _ = std::fs::create_dir_all("docs");
290 let file = format!("docs/{}_alive_endpoints.json", self.config.name);
291 let available_schemas: Vec<String> = self
292 .handlers
293 .values()
294 .map(|x| x.schema.name.clone())
295 .sorted()
296 .collect();
297 info!(
298 "Dumping {} endpoint names to {}",
299 available_schemas.len(),
300 file
301 );
302 serde_json::to_writer_pretty(File::create(file)?, &available_schemas)?;
303 Ok(())
304 }
305}
306
307pub fn wrap_ws_error<T>(err: Result<T, WsError>) -> Result<T> {
308 err.map_err(|x| eyre!(x))
309}
310
311pub fn check_name(cat: &str, be_name: &str, should_name: &str) -> Result<()> {
312 if !be_name.contains(should_name) {
313 bail!("{} name should be {} but got {}", cat, should_name, be_name);
314 } else {
315 Ok(())
316 }
317}
318
319pub fn check_handler<T: RequestHandler + 'static>(schema: &EndpointSchema) -> Result<()> {
320 let handler_name = std::any::type_name::<T>();
321 let should_handler_name = format!("Method{}", schema.name);
322 check_name("Method", handler_name, &should_handler_name)?;
323 let request_name = std::any::type_name::<T::Request>();
324 let should_req_name = format!("{}Request", schema.name);
325 check_name("Request", request_name, &should_req_name)?;
326
327 Ok(())
328}
329
/// Configuration for [`WebsocketServer`].
///
/// Fields marked `#[serde(skip)]` are never read from or written to the serialized form;
/// on deserialization they take their `Default` value and must be set in code.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct WsServerConfig {
    /// Server name; used in the `docs/<name>_alive_endpoints.json` path written by
    /// `dump_schemas`. Defaults to an empty string.
    #[serde(default)]
    pub name: String,
    /// Address to bind, resolved with `tokio::net::lookup_host` (e.g. `"0.0.0.0:8443"`);
    /// the first resolved address is used.
    pub address: String,
    /// Certificate files for TLS; required together with `priv_key` unless `insecure` is set.
    #[serde(default)]
    pub pub_certs: Option<Vec<PathBuf>>,
    /// Private key file for TLS; required together with `pub_certs` unless `insecure` is set.
    #[serde(default)]
    pub priv_key: Option<PathBuf>,
    /// When `true`, `listen` serves plain TCP without TLS.
    #[serde(default)]
    pub insecure: bool,
    /// NOTE(review): not read anywhere in this file — confirm where it is consumed.
    #[serde(default)]
    pub debug: bool,
    /// Forwarded to `Toolbox::set_ws_states` when the server starts listening.
    /// NOTE(review): exact semantics live in the toolbox — confirm there.
    #[serde(skip)]
    pub header_only: bool,
    /// Passed to the handshake verifier as `VerifyProtocol::allow_cors_domains`.
    /// Presumably `None` disables the origin check — confirm in `VerifyProtocol`.
    #[serde(skip)]
    pub allow_cors_urls: Arc<Option<Vec<String>>>,
}