1use eyre::{bail, eyre, ContextCompat, Result};
2use itertools::Itertools;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fs::File;
6use std::net::SocketAddr;
7use std::path::PathBuf;
8use std::pin::Pin;
9use std::sync::atomic::AtomicU32;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
13use tokio::sync::mpsc;
14use tokio::task::LocalSet;
15use tokio_tungstenite::tungstenite::Error as WsError;
16use tokio_tungstenite::tungstenite::Message;
17use tokio_tungstenite::WebSocketStream;
18use tracing::*;
19
20use crate::libs::error_code::ErrorCode;
21use crate::libs::handler::{RequestHandler, RequestHandlerErased};
22use crate::libs::listener::{ConnectionListener, TcpListener, TlsListener};
23use crate::libs::toolbox::{ArcToolbox, RequestContext, Toolbox, TOOLBOX};
24use crate::libs::utils::{get_conn_id, get_log_id};
25use crate::libs::ws::client::WsRequest;
26use crate::libs::ws::{VerifyProtocol, WsClientSession, WsConnection};
27use crate::model::EndpointSchema;
28
29use super::{AuthController, ConnectionId, SimpleAuthController, WebsocketStates, WsEndpoint};
30
/// Websocket server: owns the registered endpoint handlers, the controller used to
/// authenticate each new connection, the shared toolbox and the listener configuration.
pub struct WebsocketServer {
    /// Authenticates every connection after the websocket handshake succeeds.
    pub auth_controller: Arc<dyn AuthController>,
    /// Registered endpoints keyed by their schema's `code`.
    pub handlers: HashMap<u32, WsEndpoint>,
    /// NOTE(review): initialised to `None` by `new` and not read in this file — confirm usage.
    pub message_receiver: Option<mpsc::Receiver<ConnectionId>>,
    /// Toolbox installed (via `TOOLBOX.scope`) around every connection task.
    pub toolbox: ArcToolbox,
    /// Bind address, TLS material and related settings.
    pub config: WsServerConfig,
}
38
/// Stream adapter that replays bytes which were already read from `stream` before handing
/// reads over to `stream` itself. Writes bypass the replay buffer entirely.
struct BufferedStream<S> {
    /// Index of the next byte in `buffer` that has not been handed out yet.
    pos: usize,
    /// Bytes consumed from `stream` before this wrapper was built.
    buffer: Box<[u8]>,
    /// The underlying transport.
    stream: S,
}
45
46impl<S: AsyncRead + Unpin> AsyncRead for BufferedStream<S> {
47 fn poll_read(
48 mut self: Pin<&mut Self>,
49 cx: &mut Context<'_>,
50 buf: &mut ReadBuf<'_>,
51 ) -> Poll<tokio::io::Result<()>> {
52 if self.pos < self.buffer.len() {
53 let remaining = self.buffer.len() - self.pos;
55 let len = std::cmp::min(remaining, buf.remaining());
56 buf.put_slice(&self.buffer[self.pos..self.pos + len]);
57 self.pos += len;
58 Poll::Ready(Ok(()))
59 } else {
60 Pin::new(&mut self.stream).poll_read(cx, buf)
62 }
63 }
64}
65
66impl<S: AsyncWrite + Unpin> AsyncWrite for BufferedStream<S> {
67 fn poll_write(
68 mut self: Pin<&mut Self>,
69 cx: &mut Context<'_>,
70 buf: &[u8],
71 ) -> Poll<Result<usize, std::io::Error>> {
72 Pin::new(&mut self.stream).poll_write(cx, buf)
73 }
74
75 fn poll_flush(
76 mut self: Pin<&mut Self>,
77 cx: &mut Context<'_>,
78 ) -> Poll<Result<(), std::io::Error>> {
79 Pin::new(&mut self.stream).poll_flush(cx)
80 }
81
82 fn poll_shutdown(
83 mut self: Pin<&mut Self>,
84 cx: &mut Context<'_>,
85 ) -> Poll<Result<(), std::io::Error>> {
86 Pin::new(&mut self.stream).poll_shutdown(cx)
87 }
88}
89
90impl WebsocketServer {
91 pub fn new(config: WsServerConfig) -> Self {
92 Self {
93 auth_controller: Arc::new(SimpleAuthController),
94 handlers: Default::default(),
95 message_receiver: None,
96 toolbox: Toolbox::new(),
97 config,
98 }
99 }
100 pub fn set_auth_controller(&mut self, controller: impl AuthController + 'static) {
101 self.auth_controller = Arc::new(controller);
102 }
103 pub fn add_handler<T: RequestHandler + 'static>(&mut self, handler: T) {
104 let schema = serde_json::from_str(T::Request::SCHEMA).expect("Invalid schema");
105 check_handler::<T>(&schema).expect("Invalid handler");
106 self.add_handler_erased(schema, Arc::new(handler))
107 }
108 pub fn add_handler_erased(
109 &mut self,
110 schema: EndpointSchema,
111 handler: Arc<dyn RequestHandlerErased>,
112 ) {
113 let old = self
114 .handlers
115 .insert(schema.code, WsEndpoint { schema, handler });
116 if let Some(old) = old {
117 panic!(
118 "Overwriting handler for endpoint {} {}",
119 old.schema.code, old.schema.name
120 );
121 }
122 }
123 async fn handle_ws_handshake_and_connection<
124 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
125 >(
126 self: Arc<Self>,
127 addr: SocketAddr,
128 states: Arc<WebsocketStates>,
129 mut stream: S,
130 ) -> Result<()> {
131 let mut buffer = vec![0u8; 1024];
132 let n = stream.read(&mut buffer).await?; tracing::debug!("Raw request bytes: {:?}", &buffer[..n]);
135
136 let stream = BufferedStream {
137 buffer: buffer[..n].to_vec().into_boxed_slice(),
138 stream,
139 pos: 0,
140 };
141
142 let (tx, mut rx) = mpsc::channel(1);
143 let hs = tokio_tungstenite::accept_hdr_async(
144 stream,
145 VerifyProtocol {
146 addr,
147 tx,
148 allow_cors_domains: &self.config.allow_cors_urls,
149 },
150 )
151 .await;
152
153 tracing::warn!("handle new WS connection");
155
156 let stream = wrap_ws_error(hs)?;
157 let conn = Arc::new(WsConnection {
158 connection_id: get_conn_id(),
159 user_id: Default::default(),
160 role: AtomicU32::new(0),
161 address: addr,
162 log_id: get_log_id(),
163 });
164 debug!(?addr, "New connection handshaken {:?}", conn);
165 let headers = rx
166 .recv()
167 .await
168 .ok_or_else(|| eyre!("Failed to receive ws headers"))?;
169
170 let (tx, rx) = mpsc::channel(100);
171 let conn = Arc::clone(&conn);
172 states.insert(conn.connection_id, tx, conn.clone());
173
174 let auth_result = Arc::clone(&self.auth_controller)
175 .auth(&self.toolbox, headers, Arc::clone(&conn))
176 .await;
177 let raw_ctx = RequestContext::from_conn(&conn);
178 if let Err(err) = auth_result {
179 self.toolbox.send_request_error(
180 &raw_ctx,
181 ErrorCode::new(100400), err.to_string(),
183 );
184 return Err(err);
185 }
186 self.handle_session_connection(conn, states, stream, rx)
187 .await;
188
189 Ok(())
190 }
191
192 pub async fn handle_session_connection<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
193 self: Arc<Self>,
194 conn: Arc<WsConnection>,
195 states: Arc<WebsocketStates>,
196 stream: WebSocketStream<S>,
197 rx: mpsc::Receiver<Message>,
198 ) {
199 let addr = conn.address;
200 let context = RequestContext::from_conn(&conn);
201
202 let session = WsClientSession::new(conn, stream, rx, self);
203 session.run().await;
204
205 states.remove(context.connection_id);
206 debug!(?addr, "Connection closed");
207 }
208
209 pub async fn listen(self) -> Result<()> {
210 info!("Listening on {}", self.config.address);
211
212 let addr = tokio::net::lookup_host(&self.config.address)
214 .await?
215 .next()
216 .with_context(|| format!("Failed to lookup host to bind: {}", self.config.address))?;
217
218 let listener = TcpListener::bind(addr).await?;
219 if self.config.insecure {
220 self.listen_impl(Arc::new(listener)).await
221 } else if self.config.pub_certs.is_some() && self.config.priv_key.is_some() {
222 let listener = TlsListener::bind(
224 listener,
225 self.config.pub_certs.clone().unwrap(),
226 self.config.priv_key.clone().unwrap(),
227 )
228 .await?;
229 self.listen_impl(Arc::new(listener)).await
230 } else {
231 bail!("pub_certs and priv_key should be set")
232 }
233 }
234
235 async fn listen_impl<T: ConnectionListener + 'static>(self, listener: Arc<T>) -> Result<()> {
236 let states = Arc::new(WebsocketStates::new());
237 self.toolbox
238 .set_ws_states(states.clone_states(), self.config.header_only);
239 let this = Arc::new(self);
240 let local_set = LocalSet::new();
241 let (mut sigterm, mut sigint) = crate::libs::signal::init_signals()?;
242 local_set
243 .run_until(async {
244 loop {
245 tokio::select! {
246 _ = crate::libs::signal::wait_for_signals(&mut sigterm, &mut sigint) => break,
247 accepted = listener.accept() => {
248 let (stream, addr) = match accepted {
249 Ok(x) => x,
250 Err(err) => {
251 error!("Error while accepting stream: {:?}", err);
252 continue;
253 }
254 };
255 let listener = Arc::clone(&listener);
256 let this = Arc::clone(&this);
257 let states = Arc::clone(&states);
258 local_set.spawn_local(async move {
259 let stream = match listener.handshake(stream).await {
260 Ok(channel) => {
261 info!("Accepted stream from {}", addr);
262 channel
263 }
264 Err(err) => {
265 error!("Error while handshaking stream: {:?}", err);
266 return;
267 }
268 };
269
270 let future = TOOLBOX.scope(this.toolbox.clone(), this.handle_ws_handshake_and_connection(addr, states, stream));
271 if let Err(err) = future.await {
272 error!("Error while handling connection: {:?}", err);
273 }
274 });
275 }
276 }
277 }
278 Ok(())
279 })
280 .await
281 }
282
283 pub fn dump_schemas(&self) -> Result<()> {
284 let _ = std::fs::create_dir_all("docs");
285 let file = format!("docs/{}_alive_endpoints.json", self.config.name);
286 let available_schemas: Vec<String> = self
287 .handlers
288 .values()
289 .map(|x| x.schema.name.clone())
290 .sorted()
291 .collect();
292 info!(
293 "Dumping {} endpoint names to {}",
294 available_schemas.len(),
295 file
296 );
297 serde_json::to_writer_pretty(File::create(file)?, &available_schemas)?;
298 Ok(())
299 }
300}
301
302pub fn wrap_ws_error<T>(err: Result<T, WsError>) -> Result<T> {
303 err.map_err(|x| eyre!(x))
304}
305
306pub fn check_name(cat: &str, be_name: &str, should_name: &str) -> Result<()> {
307 if !be_name.contains(should_name) {
308 bail!("{} name should be {} but got {}", cat, should_name, be_name);
309 } else {
310 Ok(())
311 }
312}
313
314pub fn check_handler<T: RequestHandler + 'static>(schema: &EndpointSchema) -> Result<()> {
315 let handler_name = std::any::type_name::<T>();
316 let should_handler_name = format!("Method{}", schema.name);
317 check_name("Method", handler_name, &should_handler_name)?;
318 let request_name = std::any::type_name::<T::Request>();
319 let should_req_name = format!("{}Request", schema.name);
320 check_name("Request", request_name, &should_req_name)?;
321
322 Ok(())
323}
324
/// Configuration for `WebsocketServer`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct WsServerConfig {
    /// Server name; used to build the file name written by `WebsocketServer::dump_schemas`.
    #[serde(default)]
    pub name: String,
    /// Address to bind; resolved with `tokio::net::lookup_host`, so presumably `host:port`.
    pub address: String,
    /// TLS certificate files; required together with `priv_key` unless `insecure` is set.
    #[serde(default)]
    pub pub_certs: Option<Vec<PathBuf>>,
    /// TLS private key file; required together with `pub_certs` unless `insecure` is set.
    #[serde(default)]
    pub priv_key: Option<PathBuf>,
    /// Serve plain TCP instead of TLS.
    #[serde(default)]
    pub insecure: bool,
    /// NOTE(review): not read anywhere in this file — confirm where it is consumed.
    #[serde(default)]
    pub debug: bool,
    /// Passed to `Toolbox::set_ws_states` when listening; skipped by serde (set in code).
    #[serde(skip)]
    pub header_only: bool,
    /// Passed to `VerifyProtocol` as `allow_cors_domains` during the websocket handshake;
    /// skipped by serde (set in code).
    #[serde(skip)]
    pub allow_cors_urls: Arc<Option<Vec<String>>>,
}