1use eyre::{bail, eyre, ContextCompat, Result};
2use itertools::Itertools;
3use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, HashSet};
5use std::fs::File;
6use std::net::SocketAddr;
7use std::path::PathBuf;
8use std::pin::Pin;
9use std::sync::atomic::AtomicU32;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
13use tokio::sync::mpsc;
14use tokio::task::LocalSet;
15use tokio_tungstenite::tungstenite::Error as WsError;
16use tokio_tungstenite::tungstenite::Message;
17use tokio_tungstenite::WebSocketStream;
18use tracing::*;
19
20use crate::libs::error_code::ErrorCode;
21use crate::libs::handler::{RequestHandler, RequestHandlerErased};
22use crate::libs::listener::{ConnectionListener, TcpListener, TlsListener};
23use crate::libs::toolbox::{ArcToolbox, RequestContext, Toolbox, TOOLBOX};
24use crate::libs::utils::{get_conn_id, get_log_id};
25use crate::libs::ws::client::WsRequest;
26use crate::libs::ws::{VerifyProtocol, WsClientSession, WsConnection};
27use crate::model::EndpointSchema;
28
29use super::{AuthController, ConnectionId, SimpleAuthController, WebsocketStates, WsEndpoint};
30
/// WebSocket server: holds the endpoint registry, the authentication hook and the
/// configuration consumed by `WebsocketServer::listen`.
pub struct WebsocketServer {
    /// Authenticates each connection after the WebSocket handshake, before its session runs.
    pub auth_controller: Arc<dyn AuthController>,
    /// Registered endpoints keyed by `EndpointSchema::code`.
    pub handlers: HashMap<u32, WsEndpoint>,
    /// Allowed roles per endpoint code; `None` when the endpoint was registered without a
    /// role list. NOTE(review): enforcement lives outside this file — confirm `None` means
    /// "unrestricted".
    pub allowed_roles: HashMap<u32, Option<HashSet<u32>>>,
    /// Set to `None` by `new` and not read in this file — TODO confirm who consumes it.
    pub message_receiver: Option<mpsc::Receiver<ConnectionId>>,
    /// Shared toolbox; given the connection states when listening and used to scope each
    /// connection task.
    pub toolbox: ArcToolbox,
    /// Server configuration (bind address, TLS files, CORS origins, ...).
    pub config: WsServerConfig,
}
39
/// Stream adapter that first replays bytes already read from `stream`, then continues
/// reading from `stream` itself. Writes are forwarded to `stream` unchanged.
struct BufferedStream<S> {
    /// Bytes consumed from `stream` ahead of time; served to readers before anything else.
    buffer: Box<[u8]>,
    /// The wrapped stream.
    stream: S,
    /// Number of bytes of `buffer` already handed out by `poll_read`.
    pos: usize,
}
46
47impl<S: AsyncRead + Unpin> AsyncRead for BufferedStream<S> {
48 fn poll_read(
49 mut self: Pin<&mut Self>,
50 cx: &mut Context<'_>,
51 buf: &mut ReadBuf<'_>,
52 ) -> Poll<tokio::io::Result<()>> {
53 if self.pos < self.buffer.len() {
54 let remaining = self.buffer.len() - self.pos;
56 let len = std::cmp::min(remaining, buf.remaining());
57 buf.put_slice(&self.buffer[self.pos..self.pos + len]);
58 self.pos += len;
59 Poll::Ready(Ok(()))
60 } else {
61 Pin::new(&mut self.stream).poll_read(cx, buf)
63 }
64 }
65}
66
67impl<S: AsyncWrite + Unpin> AsyncWrite for BufferedStream<S> {
68 fn poll_write(
69 mut self: Pin<&mut Self>,
70 cx: &mut Context<'_>,
71 buf: &[u8],
72 ) -> Poll<Result<usize, std::io::Error>> {
73 Pin::new(&mut self.stream).poll_write(cx, buf)
74 }
75
76 fn poll_flush(
77 mut self: Pin<&mut Self>,
78 cx: &mut Context<'_>,
79 ) -> Poll<Result<(), std::io::Error>> {
80 Pin::new(&mut self.stream).poll_flush(cx)
81 }
82
83 fn poll_shutdown(
84 mut self: Pin<&mut Self>,
85 cx: &mut Context<'_>,
86 ) -> Poll<Result<(), std::io::Error>> {
87 Pin::new(&mut self.stream).poll_shutdown(cx)
88 }
89}
90
91impl WebsocketServer {
92 pub fn new(config: WsServerConfig) -> Self {
93 Self {
94 auth_controller: Arc::new(SimpleAuthController),
95 allowed_roles: HashMap::new(),
96 handlers: Default::default(),
97 message_receiver: None,
98 toolbox: Toolbox::new(),
99 config,
100 }
101 }
102 pub fn set_auth_controller(&mut self, controller: impl AuthController + 'static) {
103 self.auth_controller = Arc::new(controller);
104 }
105 pub fn add_handler<T: RequestHandler + 'static>(&mut self, handler: T) {
106 let schema = serde_json::from_str(T::Request::SCHEMA).expect("Invalid schema");
107 let roles: Option<&[u32]> = T::Request::ROLES;
108 check_handler::<T>(&schema).expect("Invalid handler");
109 self.add_handler_erased(schema, roles, Arc::new(handler))
110 }
111 pub fn add_handler_erased(
112 &mut self,
113 schema: EndpointSchema,
114 roles: Option<&[u32]>,
115 handler: Arc<dyn RequestHandlerErased>,
116 ) {
117 let roles_set = roles.map(|roles| roles.iter().cloned().collect::<HashSet<u32>>());
118
119 let _old_roles = self.allowed_roles.insert(schema.code, roles_set);
120
121 let old = self
122 .handlers
123 .insert(schema.code, WsEndpoint { schema, handler });
124 if let Some(old) = old {
125 panic!(
126 "Overwriting handler for endpoint {} {}",
127 old.schema.code, old.schema.name
128 );
129 }
130 }
131 async fn handle_ws_handshake_and_connection<
132 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
133 >(
134 self: Arc<Self>,
135 addr: SocketAddr,
136 states: Arc<WebsocketStates>,
137 mut stream: S,
138 ) -> Result<()> {
139 let mut buffer = vec![0u8; 1024];
140 let n = stream.read(&mut buffer).await?; tracing::debug!("Raw request bytes: {:?}", &buffer[..n]);
143
144 let stream = BufferedStream {
145 buffer: buffer[..n].to_vec().into_boxed_slice(),
146 stream,
147 pos: 0,
148 };
149
150 let (tx, mut rx) = mpsc::channel(1);
151 let hs = tokio_tungstenite::accept_hdr_async(
152 stream,
153 VerifyProtocol {
154 addr,
155 tx,
156 allow_cors_domains: &self.config.allow_cors_urls,
157 },
158 )
159 .await;
160
161 tracing::warn!("handle new WS connection");
163
164 let stream = wrap_ws_error(hs)?;
165 let conn = Arc::new(WsConnection {
166 connection_id: get_conn_id(),
167 user_id: Default::default(),
168 role: AtomicU32::new(0),
169 address: addr,
170 log_id: get_log_id(),
171 });
172 debug!(?addr, "New connection handshaken {:?}", conn);
173 let headers = rx
174 .recv()
175 .await
176 .ok_or_else(|| eyre!("Failed to receive ws headers"))?;
177
178 let (tx, rx) = mpsc::channel(100);
179 let conn = Arc::clone(&conn);
180 states.insert(conn.connection_id, tx, conn.clone());
181
182 let auth_result = Arc::clone(&self.auth_controller)
183 .auth(&self.toolbox, headers, Arc::clone(&conn))
184 .await;
185 let raw_ctx = RequestContext::from_conn(&conn);
186 if let Err(err) = auth_result {
187 self.toolbox.send_request_error(
188 &raw_ctx,
189 ErrorCode::new(100400), err.to_string(),
191 );
192 return Err(err);
193 }
194 self.handle_session_connection(conn, states, stream, rx)
195 .await;
196
197 Ok(())
198 }
199
200 pub async fn handle_session_connection<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
201 self: Arc<Self>,
202 conn: Arc<WsConnection>,
203 states: Arc<WebsocketStates>,
204 stream: WebSocketStream<S>,
205 rx: mpsc::Receiver<Message>,
206 ) {
207 let addr = conn.address;
208 let context = RequestContext::from_conn(&conn);
209
210 let session = WsClientSession::new(conn, stream, rx, self);
211 session.run().await;
212
213 states.remove(context.connection_id);
214 debug!(?addr, "Connection closed");
215 }
216
217 pub async fn listen(self) -> Result<()> {
218 info!("Listening on {}", self.config.address);
219
220 let addr = tokio::net::lookup_host(&self.config.address)
222 .await?
223 .next()
224 .with_context(|| format!("Failed to lookup host to bind: {}", self.config.address))?;
225
226 let listener = TcpListener::bind(addr).await?;
227 if self.config.insecure {
228 self.listen_impl(Arc::new(listener)).await
229 } else if self.config.pub_certs.is_some() && self.config.priv_key.is_some() {
230 let listener = TlsListener::bind(
232 listener,
233 self.config.pub_certs.clone().unwrap(),
234 self.config.priv_key.clone().unwrap(),
235 )
236 .await?;
237 self.listen_impl(Arc::new(listener)).await
238 } else {
239 bail!("pub_certs and priv_key should be set")
240 }
241 }
242
243 async fn listen_impl<T: ConnectionListener + 'static>(self, listener: Arc<T>) -> Result<()> {
244 let states = Arc::new(WebsocketStates::new());
245 self.toolbox
246 .set_ws_states(states.clone_states(), self.config.header_only);
247 let this = Arc::new(self);
248 let local_set = LocalSet::new();
249 let (mut sigterm, mut sigint) = crate::libs::signal::init_signals()?;
250 local_set
251 .run_until(async {
252 loop {
253 tokio::select! {
254 _ = crate::libs::signal::wait_for_signals(&mut sigterm, &mut sigint) => break,
255 accepted = listener.accept() => {
256 let (stream, addr) = match accepted {
257 Ok(x) => x,
258 Err(err) => {
259 error!("Error while accepting stream: {:?}", err);
260 continue;
261 }
262 };
263 let listener = Arc::clone(&listener);
264 let this = Arc::clone(&this);
265 let states = Arc::clone(&states);
266 local_set.spawn_local(async move {
267 let stream = match listener.handshake(stream).await {
268 Ok(channel) => {
269 info!("Accepted stream from {}", addr);
270 channel
271 }
272 Err(err) => {
273 error!("Error while handshaking stream: {:?}", err);
274 return;
275 }
276 };
277
278 let future = TOOLBOX.scope(this.toolbox.clone(), this.handle_ws_handshake_and_connection(addr, states, stream));
279 if let Err(err) = future.await {
280 error!("Error while handling connection: {:?}", err);
281 }
282 });
283 }
284 }
285 }
286 Ok(())
287 })
288 .await
289 }
290
291 pub fn dump_schemas(&self) -> Result<()> {
292 let _ = std::fs::create_dir_all("docs");
293 let file = format!("docs/{}_alive_endpoints.json", self.config.name);
294 let available_schemas: Vec<String> = self
295 .handlers
296 .values()
297 .map(|x| x.schema.name.clone())
298 .sorted()
299 .collect();
300 info!(
301 "Dumping {} endpoint names to {}",
302 available_schemas.len(),
303 file
304 );
305 serde_json::to_writer_pretty(File::create(file)?, &available_schemas)?;
306 Ok(())
307 }
308}
309
310pub fn wrap_ws_error<T>(err: Result<T, WsError>) -> Result<T> {
311 err.map_err(|x| eyre!(x))
312}
313
314pub fn check_name(cat: &str, be_name: &str, should_name: &str) -> Result<()> {
315 if !be_name.contains(should_name) {
316 bail!("{} name should be {} but got {}", cat, should_name, be_name);
317 } else {
318 Ok(())
319 }
320}
321
322pub fn check_handler<T: RequestHandler + 'static>(schema: &EndpointSchema) -> Result<()> {
323 let handler_name = std::any::type_name::<T>();
324 let should_handler_name = format!("Method{}", schema.name);
325 check_name("Method", handler_name, &should_handler_name)?;
326 let request_name = std::any::type_name::<T::Request>();
327 let should_req_name = format!("{}Request", schema.name);
328 check_name("Request", request_name, &should_req_name)?;
329
330 Ok(())
331}
332
/// Configuration for `WebsocketServer`, deserializable from config files.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct WsServerConfig {
    /// Server name; used in the file name written by `dump_schemas`.
    #[serde(default)]
    pub name: String,
    /// Host/port to resolve and bind; the only field without a serde default.
    pub address: String,
    /// Certificate files for the TLS listener; required unless `insecure` is set.
    #[serde(default)]
    pub pub_certs: Option<Vec<PathBuf>>,
    /// Private key file for the TLS listener; required unless `insecure` is set.
    #[serde(default)]
    pub priv_key: Option<PathBuf>,
    /// Serve plain TCP instead of TLS.
    #[serde(default)]
    pub insecure: bool,
    /// NOTE(review): not read in this file — confirm what it toggles.
    #[serde(default)]
    pub debug: bool,
    /// Not deserialized; passed to the toolbox alongside the connection states.
    /// TODO confirm its exact meaning.
    #[serde(skip)]
    pub header_only: bool,
    /// Not deserialized; origins handed to the handshake verifier as allowed CORS domains.
    #[serde(skip)]
    pub allow_cors_urls: Arc<Option<Vec<String>>>,
}