endpoint_libs/libs/ws/
server.rs

1use eyre::{bail, eyre, ContextCompat, Result};
2use itertools::Itertools;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fs::File;
6use std::net::SocketAddr;
7use std::path::PathBuf;
8use std::pin::Pin;
9use std::sync::atomic::AtomicU32;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
13use tokio::sync::mpsc;
14use tokio::task::LocalSet;
15use tokio_tungstenite::tungstenite::Error as WsError;
16use tokio_tungstenite::tungstenite::Message;
17use tokio_tungstenite::WebSocketStream;
18use tracing::*;
19
20use crate::libs::error_code::ErrorCode;
21use crate::libs::handler::{RequestHandler, RequestHandlerErased};
22use crate::libs::listener::{ConnectionListener, TcpListener, TlsListener};
23use crate::libs::toolbox::{ArcToolbox, RequestContext, Toolbox, TOOLBOX};
24use crate::libs::utils::{get_conn_id, get_log_id};
25use crate::libs::ws::client::WsRequest;
26use crate::libs::ws::{VerifyProtocol, WsClientSession, WsConnection};
27use crate::model::EndpointSchema;
28
29use super::{AuthController, ConnectionId, SimpleAuthController, WebsocketStates, WsEndpoint};
30
31pub struct WebsocketServer {
32    pub auth_controller: Arc<dyn AuthController>,
33    pub handlers: HashMap<u32, WsEndpoint>,
34    pub message_receiver: Option<mpsc::Receiver<ConnectionId>>,
35    pub toolbox: ArcToolbox,
36    pub config: WsServerConfig,
37}
38
/// Stream adapter that pairs bytes already consumed from a stream with the stream itself, so
/// the pair can be used as if nothing had been read yet.
struct BufferedStream<S> {
    /// Bytes that were read from `stream` before it was wrapped.
    buffer: Box<[u8]>,
    /// The wrapped transport.
    stream: S,
    /// Index of the next byte of `buffer` that has not been handed out yet.
    pos: usize,
}
45
46impl<S: AsyncRead + Unpin> AsyncRead for BufferedStream<S> {
47    fn poll_read(
48        mut self: Pin<&mut Self>,
49        cx: &mut Context<'_>,
50        buf: &mut ReadBuf<'_>,
51    ) -> Poll<tokio::io::Result<()>> {
52        if self.pos < self.buffer.len() {
53            // Serve from owned buffer
54            let remaining = self.buffer.len() - self.pos;
55            let len = std::cmp::min(remaining, buf.remaining());
56            buf.put_slice(&self.buffer[self.pos..self.pos + len]);
57            self.pos += len;
58            Poll::Ready(Ok(()))
59        } else {
60            // Delegate to underlying stream
61            Pin::new(&mut self.stream).poll_read(cx, buf)
62        }
63    }
64}
65
66impl<S: AsyncWrite + Unpin> AsyncWrite for BufferedStream<S> {
67    fn poll_write(
68        mut self: Pin<&mut Self>,
69        cx: &mut Context<'_>,
70        buf: &[u8],
71    ) -> Poll<Result<usize, std::io::Error>> {
72        Pin::new(&mut self.stream).poll_write(cx, buf)
73    }
74
75    fn poll_flush(
76        mut self: Pin<&mut Self>,
77        cx: &mut Context<'_>,
78    ) -> Poll<Result<(), std::io::Error>> {
79        Pin::new(&mut self.stream).poll_flush(cx)
80    }
81
82    fn poll_shutdown(
83        mut self: Pin<&mut Self>,
84        cx: &mut Context<'_>,
85    ) -> Poll<Result<(), std::io::Error>> {
86        Pin::new(&mut self.stream).poll_shutdown(cx)
87    }
88}
89
90impl WebsocketServer {
91    pub fn new(config: WsServerConfig) -> Self {
92        Self {
93            auth_controller: Arc::new(SimpleAuthController),
94            handlers: Default::default(),
95            message_receiver: None,
96            toolbox: Toolbox::new(),
97            config,
98        }
99    }
100    pub fn set_auth_controller(&mut self, controller: impl AuthController + 'static) {
101        self.auth_controller = Arc::new(controller);
102    }
103    pub fn add_handler<T: RequestHandler + 'static>(&mut self, handler: T) {
104        let schema = serde_json::from_str(T::Request::SCHEMA).expect("Invalid schema");
105        check_handler::<T>(&schema).expect("Invalid handler");
106        self.add_handler_erased(schema, Arc::new(handler))
107    }
108    pub fn add_handler_erased(
109        &mut self,
110        schema: EndpointSchema,
111        handler: Arc<dyn RequestHandlerErased>,
112    ) {
113        let old = self
114            .handlers
115            .insert(schema.code, WsEndpoint { schema, handler });
116        if let Some(old) = old {
117            panic!(
118                "Overwriting handler for endpoint {} {}",
119                old.schema.code, old.schema.name
120            );
121        }
122    }
123    async fn handle_ws_handshake_and_connection<
124        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
125    >(
126        self: Arc<Self>,
127        addr: SocketAddr,
128        states: Arc<WebsocketStates>,
129        mut stream: S,
130    ) -> Result<()> {
131        let mut buffer = vec![0u8; 1024];
132        let n = stream.read(&mut buffer).await?; // Use AsyncReadExt::read
133
134        tracing::debug!("Raw request bytes: {:?}", &buffer[..n]);
135
136        let stream = BufferedStream {
137            buffer: buffer[..n].to_vec().into_boxed_slice(),
138            stream,
139            pos: 0,
140        };
141
142        let (tx, mut rx) = mpsc::channel(1);
143        let hs = tokio_tungstenite::accept_hdr_async(
144            stream,
145            VerifyProtocol {
146                addr,
147                tx,
148                allow_cors_domains: &self.config.allow_cors_urls,
149            },
150        )
151        .await;
152
153        // TODO remove below after tracing log issue
154        tracing::warn!("handle new WS connection");
155
156        let stream = wrap_ws_error(hs)?;
157        let conn = Arc::new(WsConnection {
158            connection_id: get_conn_id(),
159            user_id: Default::default(),
160            role: AtomicU32::new(0),
161            address: addr,
162            log_id: get_log_id(),
163        });
164        debug!(?addr, "New connection handshaken {:?}", conn);
165        let headers = rx
166            .recv()
167            .await
168            .ok_or_else(|| eyre!("Failed to receive ws headers"))?;
169
170        let (tx, rx) = mpsc::channel(100);
171        let conn = Arc::clone(&conn);
172        states.insert(conn.connection_id, tx, conn.clone());
173
174        let auth_result = Arc::clone(&self.auth_controller)
175            .auth(&self.toolbox, headers, Arc::clone(&conn))
176            .await;
177        let raw_ctx = RequestContext::from_conn(&conn);
178        if let Err(err) = auth_result {
179            self.toolbox.send_request_error(
180                &raw_ctx,
181                ErrorCode::new(100400), // BadRequest
182                err.to_string(),
183            );
184            return Err(err);
185        }
186        self.handle_session_connection(conn, states, stream, rx)
187            .await;
188
189        Ok(())
190    }
191
192    pub async fn handle_session_connection<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
193        self: Arc<Self>,
194        conn: Arc<WsConnection>,
195        states: Arc<WebsocketStates>,
196        stream: WebSocketStream<S>,
197        rx: mpsc::Receiver<Message>,
198    ) {
199        let addr = conn.address;
200        let context = RequestContext::from_conn(&conn);
201
202        let session = WsClientSession::new(conn, stream, rx, self);
203        session.run().await;
204
205        states.remove(context.connection_id);
206        debug!(?addr, "Connection closed");
207    }
208
209    pub async fn listen(self) -> Result<()> {
210        info!("Listening on {}", self.config.address);
211
212        // Resolve the address and get the socket address
213        let addr = tokio::net::lookup_host(&self.config.address)
214            .await?
215            .next()
216            .with_context(|| format!("Failed to lookup host to bind: {}", self.config.address))?;
217
218        let listener = TcpListener::bind(addr).await?;
219        if self.config.insecure {
220            self.listen_impl(Arc::new(listener)).await
221        } else if self.config.pub_certs.is_some() && self.config.priv_key.is_some() {
222            // Proceed with binding the listener for secure mode
223            let listener = TlsListener::bind(
224                listener,
225                self.config.pub_certs.clone().unwrap(),
226                self.config.priv_key.clone().unwrap(),
227            )
228            .await?;
229            self.listen_impl(Arc::new(listener)).await
230        } else {
231            bail!("pub_certs and priv_key should be set")
232        }
233    }
234
235    async fn listen_impl<T: ConnectionListener + 'static>(self, listener: Arc<T>) -> Result<()> {
236        let states = Arc::new(WebsocketStates::new());
237        self.toolbox
238            .set_ws_states(states.clone_states(), self.config.header_only);
239        let this = Arc::new(self);
240        let local_set = LocalSet::new();
241        let (mut sigterm, mut sigint) = crate::libs::signal::init_signals()?;
242        local_set
243            .run_until(async {
244                loop {
245                    tokio::select! {
246                        _ = crate::libs::signal::wait_for_signals(&mut sigterm, &mut sigint) => break,
247                        accepted = listener.accept() => {
248                            let (stream, addr) = match accepted {
249                                Ok(x) => x,
250                                Err(err) => {
251                                    error!("Error while accepting stream: {:?}", err);
252                                    continue;
253                                }
254                            };
255                            let listener = Arc::clone(&listener);
256                            let this = Arc::clone(&this);
257                            let states = Arc::clone(&states);
258                            local_set.spawn_local(async move {
259                                let stream = match listener.handshake(stream).await {
260                                    Ok(channel) => {
261                                        info!("Accepted stream from {}", addr);
262                                        channel
263                                    }
264                                    Err(err) => {
265                                        error!("Error while handshaking stream: {:?}", err);
266                                        return;
267                                    }
268                                };
269
270                                let future = TOOLBOX.scope(this.toolbox.clone(), this.handle_ws_handshake_and_connection(addr, states, stream));
271                                if let Err(err) = future.await {
272                                    error!("Error while handling connection: {:?}", err);
273                                }
274                            });
275                        }
276                    }
277                }
278                Ok(())
279            })
280            .await
281    }
282
283    pub fn dump_schemas(&self) -> Result<()> {
284        let _ = std::fs::create_dir_all("docs");
285        let file = format!("docs/{}_alive_endpoints.json", self.config.name);
286        let available_schemas: Vec<String> = self
287            .handlers
288            .values()
289            .map(|x| x.schema.name.clone())
290            .sorted()
291            .collect();
292        info!(
293            "Dumping {} endpoint names to {}",
294            available_schemas.len(),
295            file
296        );
297        serde_json::to_writer_pretty(File::create(file)?, &available_schemas)?;
298        Ok(())
299    }
300}
301
302pub fn wrap_ws_error<T>(err: Result<T, WsError>) -> Result<T> {
303    err.map_err(|x| eyre!(x))
304}
305
306pub fn check_name(cat: &str, be_name: &str, should_name: &str) -> Result<()> {
307    if !be_name.contains(should_name) {
308        bail!("{} name should be {} but got {}", cat, should_name, be_name);
309    } else {
310        Ok(())
311    }
312}
313
314pub fn check_handler<T: RequestHandler + 'static>(schema: &EndpointSchema) -> Result<()> {
315    let handler_name = std::any::type_name::<T>();
316    let should_handler_name = format!("Method{}", schema.name);
317    check_name("Method", handler_name, &should_handler_name)?;
318    let request_name = std::any::type_name::<T::Request>();
319    let should_req_name = format!("{}Request", schema.name);
320    check_name("Request", request_name, &should_req_name)?;
321
322    Ok(())
323}
324
325#[derive(Debug, Clone, Serialize, Deserialize, Default)]
326pub struct WsServerConfig {
327    #[serde(default)]
328    pub name: String,
329    pub address: String,
330    #[serde(default)]
331    pub pub_certs: Option<Vec<PathBuf>>,
332    #[serde(default)]
333    pub priv_key: Option<PathBuf>,
334    #[serde(default)]
335    pub insecure: bool,
336    #[serde(default)]
337    pub debug: bool,
338    #[serde(skip)]
339    pub header_only: bool,
340    #[serde(skip)]
341    pub allow_cors_urls: Arc<Option<Vec<String>>>,
342}