// endpoint_libs/libs/ws/server.rs

1use eyre::{bail, eyre, ContextCompat, Result};
2use itertools::Itertools;
3use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, HashSet};
5use std::fs::File;
6use std::net::SocketAddr;
7use std::path::PathBuf;
8use std::pin::Pin;
9use std::sync::atomic::AtomicU32;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
13use tokio::sync::mpsc;
14use tokio::task::LocalSet;
15use tokio_tungstenite::tungstenite::Error as WsError;
16use tokio_tungstenite::tungstenite::Message;
17use tokio_tungstenite::WebSocketStream;
18use tracing::*;
19
20use crate::libs::error_code::ErrorCode;
21use crate::libs::handler::{RequestHandler, RequestHandlerErased};
22use crate::libs::listener::{ConnectionListener, TcpListener, TlsListener};
23use crate::libs::toolbox::{ArcToolbox, RequestContext, Toolbox, TOOLBOX};
24use crate::libs::utils::{get_conn_id, get_log_id};
25use crate::libs::ws::client::WsRequest;
26use crate::libs::ws::{VerifyProtocol, WsClientSession, WsConnection};
27use crate::model::EndpointSchema;
28
29use super::{AuthController, ConnectionId, SimpleAuthController, WebsocketStates, WsEndpoint};
30
31pub struct WebsocketServer {
32    pub auth_controller: Arc<dyn AuthController>,
33    pub handlers: HashMap<u32, WsEndpoint>,
34    pub allowed_roles: HashMap<u32, Option<HashSet<u32>>>,
35    pub message_receiver: Option<mpsc::Receiver<ConnectionId>>,
36    pub toolbox: ArcToolbox,
37    pub config: WsServerConfig,
38}
39
/// Stream adapter that first replays bytes already read from `stream`
/// (held in `buffer`) and then continues reading from `stream` itself.
/// Writes are passed straight through to `stream`.
struct BufferedStream<S> {
    /// Bytes consumed from `stream` ahead of time that must be replayed first.
    buffer: Box<[u8]>,
    /// The wrapped transport.
    stream: S,
    /// Index of the next byte of `buffer` to replay (`buffer.len()` once drained).
    pos: usize,
}
46
47impl<S: AsyncRead + Unpin> AsyncRead for BufferedStream<S> {
48    fn poll_read(
49        mut self: Pin<&mut Self>,
50        cx: &mut Context<'_>,
51        buf: &mut ReadBuf<'_>,
52    ) -> Poll<tokio::io::Result<()>> {
53        if self.pos < self.buffer.len() {
54            // Serve from owned buffer
55            let remaining = self.buffer.len() - self.pos;
56            let len = std::cmp::min(remaining, buf.remaining());
57            buf.put_slice(&self.buffer[self.pos..self.pos + len]);
58            self.pos += len;
59            Poll::Ready(Ok(()))
60        } else {
61            // Delegate to underlying stream
62            Pin::new(&mut self.stream).poll_read(cx, buf)
63        }
64    }
65}
66
67impl<S: AsyncWrite + Unpin> AsyncWrite for BufferedStream<S> {
68    fn poll_write(
69        mut self: Pin<&mut Self>,
70        cx: &mut Context<'_>,
71        buf: &[u8],
72    ) -> Poll<Result<usize, std::io::Error>> {
73        Pin::new(&mut self.stream).poll_write(cx, buf)
74    }
75
76    fn poll_flush(
77        mut self: Pin<&mut Self>,
78        cx: &mut Context<'_>,
79    ) -> Poll<Result<(), std::io::Error>> {
80        Pin::new(&mut self.stream).poll_flush(cx)
81    }
82
83    fn poll_shutdown(
84        mut self: Pin<&mut Self>,
85        cx: &mut Context<'_>,
86    ) -> Poll<Result<(), std::io::Error>> {
87        Pin::new(&mut self.stream).poll_shutdown(cx)
88    }
89}
90
91impl WebsocketServer {
92    pub fn new(config: WsServerConfig) -> Self {
93        Self {
94            auth_controller: Arc::new(SimpleAuthController),
95            allowed_roles: HashMap::new(),
96            handlers: Default::default(),
97            message_receiver: None,
98            toolbox: Toolbox::new(),
99            config,
100        }
101    }
102    pub fn set_auth_controller(&mut self, controller: impl AuthController + 'static) {
103        self.auth_controller = Arc::new(controller);
104    }
105    pub fn add_handler<T: RequestHandler + 'static>(&mut self, handler: T) {
106        let schema = serde_json::from_str(T::Request::SCHEMA).expect("Invalid schema");
107        let roles: Option<&[u32]> = T::Request::ROLES;
108        check_handler::<T>(&schema).expect("Invalid handler");
109        self.add_handler_erased(schema, roles, Arc::new(handler))
110    }
111    pub fn add_handler_erased(
112        &mut self,
113        schema: EndpointSchema,
114        roles: Option<&[u32]>,
115        handler: Arc<dyn RequestHandlerErased>,
116    ) {
117        let roles_set = roles.map(|roles| roles.iter().cloned().collect::<HashSet<u32>>());
118
119        let _old_roles = self.allowed_roles.insert(schema.code, roles_set);
120
121        let old = self
122            .handlers
123            .insert(schema.code, WsEndpoint { schema, handler });
124        if let Some(old) = old {
125            panic!(
126                "Overwriting handler for endpoint {} {}",
127                old.schema.code, old.schema.name
128            );
129        }
130    }
131    async fn handle_ws_handshake_and_connection<
132        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
133    >(
134        self: Arc<Self>,
135        addr: SocketAddr,
136        states: Arc<WebsocketStates>,
137        mut stream: S,
138    ) -> Result<()> {
139        let mut buffer = vec![0u8; 1024];
140        let n = stream.read(&mut buffer).await?; // Use AsyncReadExt::read
141
142        tracing::debug!("Raw request bytes: {:?}", &buffer[..n]);
143
144        let stream = BufferedStream {
145            buffer: buffer[..n].to_vec().into_boxed_slice(),
146            stream,
147            pos: 0,
148        };
149
150        let (tx, mut rx) = mpsc::channel(1);
151        let hs = tokio_tungstenite::accept_hdr_async(
152            stream,
153            VerifyProtocol {
154                addr,
155                tx,
156                allow_cors_domains: &self.config.allow_cors_urls,
157            },
158        )
159        .await;
160
161        // TODO remove below after tracing log issue
162        tracing::warn!("handle new WS connection");
163
164        let stream = wrap_ws_error(hs)?;
165        let conn = Arc::new(WsConnection {
166            connection_id: get_conn_id(),
167            user_id: Default::default(),
168            role: AtomicU32::new(0),
169            address: addr,
170            log_id: get_log_id(),
171        });
172        debug!(?addr, "New connection handshaken {:?}", conn);
173        let headers = rx
174            .recv()
175            .await
176            .ok_or_else(|| eyre!("Failed to receive ws headers"))?;
177
178        let (tx, rx) = mpsc::channel(100);
179        let conn = Arc::clone(&conn);
180        states.insert(conn.connection_id, tx, conn.clone());
181
182        let auth_result = Arc::clone(&self.auth_controller)
183            .auth(&self.toolbox, headers, Arc::clone(&conn))
184            .await;
185        let raw_ctx = RequestContext::from_conn(&conn);
186        if let Err(err) = auth_result {
187            self.toolbox.send_request_error(
188                &raw_ctx,
189                ErrorCode::new(100400), // BadRequest
190                err.to_string(),
191            );
192            return Err(err);
193        }
194        self.handle_session_connection(conn, states, stream, rx)
195            .await;
196
197        Ok(())
198    }
199
200    pub async fn handle_session_connection<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
201        self: Arc<Self>,
202        conn: Arc<WsConnection>,
203        states: Arc<WebsocketStates>,
204        stream: WebSocketStream<S>,
205        rx: mpsc::Receiver<Message>,
206    ) {
207        let addr = conn.address;
208        let context = RequestContext::from_conn(&conn);
209
210        let session = WsClientSession::new(conn, stream, rx, self);
211        session.run().await;
212
213        states.remove(context.connection_id);
214        debug!(?addr, "Connection closed");
215    }
216
217    pub async fn listen(self) -> Result<()> {
218        info!("Listening on {}", self.config.address);
219
220        // Resolve the address and get the socket address
221        let addr = tokio::net::lookup_host(&self.config.address)
222            .await?
223            .next()
224            .with_context(|| format!("Failed to lookup host to bind: {}", self.config.address))?;
225
226        let listener = TcpListener::bind(addr).await?;
227        if self.config.insecure {
228            self.listen_impl(Arc::new(listener)).await
229        } else if self.config.pub_certs.is_some() && self.config.priv_key.is_some() {
230            // Proceed with binding the listener for secure mode
231            let listener = TlsListener::bind(
232                listener,
233                self.config.pub_certs.clone().unwrap(),
234                self.config.priv_key.clone().unwrap(),
235            )
236            .await?;
237            self.listen_impl(Arc::new(listener)).await
238        } else {
239            bail!("pub_certs and priv_key should be set")
240        }
241    }
242
243    async fn listen_impl<T: ConnectionListener + 'static>(self, listener: Arc<T>) -> Result<()> {
244        let states = Arc::new(WebsocketStates::new());
245        self.toolbox
246            .set_ws_states(states.clone_states(), self.config.header_only);
247        let this = Arc::new(self);
248        let local_set = LocalSet::new();
249        let (mut sigterm, mut sigint) = crate::libs::signal::init_signals()?;
250        local_set
251            .run_until(async {
252                loop {
253                    tokio::select! {
254                        _ = crate::libs::signal::wait_for_signals(&mut sigterm, &mut sigint) => break,
255                        accepted = listener.accept() => {
256                            let (stream, addr) = match accepted {
257                                Ok(x) => x,
258                                Err(err) => {
259                                    error!("Error while accepting stream: {:?}", err);
260                                    continue;
261                                }
262                            };
263                            let listener = Arc::clone(&listener);
264                            let this = Arc::clone(&this);
265                            let states = Arc::clone(&states);
266                            local_set.spawn_local(async move {
267                                let stream = match listener.handshake(stream).await {
268                                    Ok(channel) => {
269                                        info!("Accepted stream from {}", addr);
270                                        channel
271                                    }
272                                    Err(err) => {
273                                        error!("Error while handshaking stream: {:?}", err);
274                                        return;
275                                    }
276                                };
277
278                                let future = TOOLBOX.scope(this.toolbox.clone(), this.handle_ws_handshake_and_connection(addr, states, stream));
279                                if let Err(err) = future.await {
280                                    error!("Error while handling connection: {:?}", err);
281                                }
282                            });
283                        }
284                    }
285                }
286                Ok(())
287            })
288            .await
289    }
290
291    pub fn dump_schemas(&self) -> Result<()> {
292        let _ = std::fs::create_dir_all("docs");
293        let file = format!("docs/{}_alive_endpoints.json", self.config.name);
294        let available_schemas: Vec<String> = self
295            .handlers
296            .values()
297            .map(|x| x.schema.name.clone())
298            .sorted()
299            .collect();
300        info!(
301            "Dumping {} endpoint names to {}",
302            available_schemas.len(),
303            file
304        );
305        serde_json::to_writer_pretty(File::create(file)?, &available_schemas)?;
306        Ok(())
307    }
308}
309
310pub fn wrap_ws_error<T>(err: Result<T, WsError>) -> Result<T> {
311    err.map_err(|x| eyre!(x))
312}
313
314pub fn check_name(cat: &str, be_name: &str, should_name: &str) -> Result<()> {
315    if !be_name.contains(should_name) {
316        bail!("{} name should be {} but got {}", cat, should_name, be_name);
317    } else {
318        Ok(())
319    }
320}
321
322pub fn check_handler<T: RequestHandler + 'static>(schema: &EndpointSchema) -> Result<()> {
323    let handler_name = std::any::type_name::<T>();
324    let should_handler_name = format!("Method{}", schema.name);
325    check_name("Method", handler_name, &should_handler_name)?;
326    let request_name = std::any::type_name::<T::Request>();
327    let should_req_name = format!("{}Request", schema.name);
328    check_name("Request", request_name, &should_req_name)?;
329
330    Ok(())
331}
332
333#[derive(Debug, Clone, Serialize, Deserialize, Default)]
334pub struct WsServerConfig {
335    #[serde(default)]
336    pub name: String,
337    pub address: String,
338    #[serde(default)]
339    pub pub_certs: Option<Vec<PathBuf>>,
340    #[serde(default)]
341    pub priv_key: Option<PathBuf>,
342    #[serde(default)]
343    pub insecure: bool,
344    #[serde(default)]
345    pub debug: bool,
346    #[serde(skip)]
347    pub header_only: bool,
348    #[serde(skip)]
349    pub allow_cors_urls: Arc<Option<Vec<String>>>,
350}