atomic_websocket/helpers/
internal_server.rs

1//! WebSocket server implementation for atomic_websocket.
2//!
3//! This module provides the server-side functionality for accepting and managing
4//! WebSocket connections, including client tracking, message routing, and automatic
5//! ping/pong handling.
6
7use std::{net::SocketAddr, sync::Arc, time::Duration};
8
9use tokio::{
10    self,
11    net::{TcpListener, TcpStream},
12    sync::{mpsc::Receiver, RwLock},
13    time::{Instant, MissedTickBehavior},
14};
15use tokio_tungstenite::{tungstenite::Error, WebSocketStream};
16
17use crate::{
18    helpers::{
19        client_sender::ClientSendersTrait,
20        common::{get_data_schema, make_disconnect_message, make_pong_message},
21    },
22    log_debug, log_error,
23    schema::{Category, Ping},
24};
25use bebop::Record;
26use futures_util::{stream::SplitStream, SinkExt, StreamExt};
27use tokio::sync::mpsc::{self, Sender};
28use tokio_tungstenite::{
29    accept_async,
30    tungstenite::{self, Message},
31};
32
33use super::{client_sender::ClientSenders, types::RwClientSenders};
34
/// WebSocket server implementation for accepting and managing client connections.
///
/// An instance is created with `AtomicServer::new`, which binds the listener and
/// spawns the accept loop and the periodic client checker. The struct itself only
/// holds the handle to the shared client collection used by those tasks.
pub struct AtomicServer {
    /// Shared collection of connected clients.
    ///
    /// The server options passed to `AtomicServer::new` are stored on this
    /// collection, and clones of this handle are given to the background tasks.
    pub client_senders: RwClientSenders,
}
42
/// Configuration options for the WebSocket server.
///
/// Controls how the server treats ping messages received from clients.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerOptions {
    /// Whether to automatically respond to ping messages with pongs.
    pub use_ping: bool,

    /// Category ID to use when proxying ping messages instead of responding
    /// directly.
    ///
    /// Only positive values are honored when the first client message is
    /// processed; the default of `-1` (or any value `<= 0`) disables ping
    /// proxying.
    pub proxy_ping: i16,
}

impl Default for ServerOptions {
    /// Returns the default configuration: automatic pong replies enabled and
    /// ping proxying disabled.
    fn default() -> Self {
        Self {
            use_ping: true,
            proxy_ping: -1,
        }
    }
}
64
65impl AtomicServer {
66    /// Creates a new WebSocket server instance.
67    ///
68    /// Binds to the specified address and starts accepting connections.
69    /// Also spawns background tasks for connection handling and client tracking.
70    ///
71    /// # Arguments
72    ///
73    /// * `addr` - Address to bind the server to (e.g., "0.0.0.0:9000")
74    /// * `option` - Server configuration options
75    /// * `client_senders` - Optional existing ClientSenders instance
76    ///
77    /// # Returns
78    ///
79    /// A new AtomicServer instance
80    ///
81    /// # Panics
82    ///
83    /// Panics if the server cannot bind to the specified address
84    pub async fn new(
85        addr: &str,
86        option: ServerOptions,
87        client_senders: Option<RwClientSenders>,
88    ) -> Self {
89        let listener = TcpListener::bind(&addr).await.expect("Can't listen");
90        let client_senders = match client_senders {
91            Some(client_senders) => client_senders,
92            None => Arc::new(RwLock::new(ClientSenders::new())),
93        };
94        client_senders.write().await.options = option;
95        tokio::spawn(handle_accept(listener, client_senders.clone()));
96
97        tokio::spawn(loop_client_checker(client_senders.clone()));
98        Self { client_senders }
99    }
100
101    /// Gets a receiver for incoming messages from clients.
102    ///
103    /// # Returns
104    ///
105    /// A channel receiver for message data along with client identifiers
106    pub async fn get_handle_message_receiver(&self) -> Receiver<(Vec<u8>, String)> {
107        self.client_senders.get_handle_message_receiver().await
108    }
109}
110
111/// Periodically checks for and removes inactive clients.
112///
113/// # Arguments
114///
115/// * `server_sender` - Shared client senders collection
116pub async fn loop_client_checker(server_sender: RwClientSenders) {
117    let mut interval = tokio::time::interval_at(
118        Instant::now() + Duration::from_secs(15),
119        Duration::from_secs(15),
120    );
121    interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
122
123    loop {
124        interval.tick().await;
125        server_sender.write().await.check_client_send_time();
126        log_debug!("loop client cheker finish");
127    }
128}
129
130/// Handles accepting new WebSocket connections.
131///
132/// Listens for incoming TCP connections and spawns a new task for each one.
133///
134/// # Arguments
135///
136/// * `listener` - TCP listener for accepting connections
137/// * `client_senders` - Shared client senders collection
138pub async fn handle_accept(listener: TcpListener, client_senders: RwClientSenders) {
139    loop {
140        match listener.accept().await {
141            Ok((stream, _)) => {
142                let peer = stream
143                    .peer_addr()
144                    .expect("connected streams should have a peer address");
145                log_debug!("Peer address: {}", peer);
146                tokio::spawn(accept_connection(client_senders.clone(), peer, stream));
147            }
148            Err(e) => {
149                log_error!("Error accepting connection: {:?}", e);
150            }
151        }
152    }
153}
154
155/// Handles the WebSocket upgrade process for a new connection.
156///
157/// # Arguments
158///
159/// * `client_senders` - Shared client senders collection
160/// * `peer` - Socket address of the connecting client
161/// * `stream` - TCP stream for the connection
162pub async fn accept_connection(
163    client_senders: RwClientSenders,
164    peer: SocketAddr,
165    stream: TcpStream,
166) {
167    if let Err(e) = handle_connection(client_senders, peer, stream).await {
168        match e {
169            Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (),
170            err => log_error!("Error processing connection: {}", err),
171        }
172    }
173}
174
175/// Handles an established WebSocket connection.
176///
177/// Sets up bidirectional message handling for the client connection.
178///
179/// # Arguments
180///
181/// * `client_senders` - Shared client senders collection
182/// * `peer` - Socket address of the client
183/// * `stream` - TCP stream for the connection
184/// * `option` - Server configuration options
185///
186/// # Returns
187///
188/// A Result indicating whether the connection handling completed successfully
189pub async fn handle_connection(
190    client_senders: RwClientSenders,
191    peer: SocketAddr,
192    stream: TcpStream,
193) -> tungstenite::Result<()> {
194    match accept_async(stream).await {
195        Ok(ws_stream) => {
196            log_debug!("New WebSocket connection: {}", peer);
197            let (mut ostream, mut istream) = ws_stream.split();
198
199            let (sx, mut rx) = mpsc::channel(8);
200            tokio::spawn(async move {
201                let use_ping = client_senders.read().await.options.use_ping;
202                let id =
203                    get_id_from_first_message(&mut istream, client_senders.clone(), sx.clone())
204                        .await;
205
206                match id {
207                    Some(id) => {
208                        // Handle incoming messages
209                        while let Some(Ok(message)) = istream.next().await {
210                            let value = message.into_data();
211                            let data = match get_data_schema(&value) {
212                                Ok(data) => data,
213                                Err(e) => {
214                                    log_error!("Error getting data schema: {:?}", e);
215                                    continue;
216                                }
217                            };
218
219                            // Handle ping messages
220                            if data.category == Category::Ping as u16 && use_ping {
221                                if let Ok(data) = Ping::deserialize(&data.datas) {
222                                    client_senders
223                                        .send(data.peer.into(), make_pong_message())
224                                        .await;
225                                    continue;
226                                }
227                            }
228
229                            // Handle disconnect messages
230                            if data.category == Category::Disconnect as u16 {
231                                let _ = sx.send(make_disconnect_message(&peer.to_string())).await;
232                                break;
233                            }
234
235                            // Forward other messages to application handler
236                            client_senders.send_handle_message(data, &id).await;
237                        }
238                    }
239                    None => {
240                        let _ = sx.send(make_disconnect_message(&peer.to_string())).await;
241                    }
242                }
243            });
244
245            // Handle outgoing messages
246            while let Some(message) = rx.recv().await {
247                ostream.send(message.clone()).await?;
248                let data = message.into_data();
249                let data = match get_data_schema(&data) {
250                    Ok(data) => data,
251                    Err(e) => {
252                        log_error!("Error getting data schema: {:?}", e);
253                        rx.close();
254                        break;
255                    }
256                };
257                log_debug!("Server sending message: {:?}", data);
258                if data.category == Category::Disconnect as u16 {
259                    rx.close();
260                    break;
261                }
262            }
263            log_debug!("client: {} disconnected", peer);
264            ostream.flush().await?;
265        }
266        Err(e) => {
267            log_debug!("Error accepting WebSocket connection: {:?}", e);
268        }
269    }
270
271    Ok(())
272}
273
274/// Extracts client ID from the first message and sets up the connection.
275///
276/// WebSocket clients are expected to send a Ping message as their first
277/// communication, containing their client identifier.
278///
279/// # Arguments
280///
281/// * `istream` - Stream of incoming WebSocket messages
282/// * `client_senders` - Shared client senders collection
283/// * `sx` - Sender for outgoing messages to this client
284/// * `options` - Server configuration options
285///
286/// # Returns
287///
288/// Some(client_id) if identification was successful, None otherwise
289async fn get_id_from_first_message(
290    istream: &mut SplitStream<WebSocketStream<TcpStream>>,
291    client_senders: RwClientSenders,
292    sx: Sender<Message>,
293) -> Option<String> {
294    let mut _id: Option<String> = None;
295    if let Some(Ok(message)) = istream.next().await {
296        log_debug!("receive first message from client: {:?}", message);
297        let value = message.into_data();
298        let mut data = match get_data_schema(&value) {
299            Ok(data) => data,
300            Err(e) => {
301                log_error!("Error getting data schema: {:?}", e);
302                return None;
303            }
304        };
305        let options = client_senders.read().await.options.clone();
306
307        // Check if the first message is a ping
308        if data.category == Category::Ping as u16 {
309            log_debug!("receive ping from client: {:?}", data);
310            if let Ok(ping) = Ping::deserialize(&data.datas) {
311                _id = Some(ping.peer.into());
312                client_senders.add(&_id.as_ref().unwrap(), sx).await;
313
314                // Either respond with a pong or proxy the ping
315                if options.use_ping {
316                    client_senders
317                        .send(&_id.as_ref().unwrap(), make_pong_message())
318                        .await;
319                } else {
320                    // Optionally change the category when proxying
321                    if options.proxy_ping > 0 {
322                        data.category = options.proxy_ping as u16;
323                    }
324                    client_senders
325                        .send_handle_message(data, &_id.as_ref().unwrap())
326                        .await;
327                }
328            }
329        } else if options.proxy_ping > 0 && data.category == options.proxy_ping as u16 {
330            if let Ok(ping) = Ping::deserialize(&data.datas) {
331                _id = Some(ping.peer.into());
332                client_senders.add(&_id.as_ref().unwrap(), sx).await;
333
334                // Optionally change the category when proxying
335                data.category = options.proxy_ping as u16;
336                client_senders
337                    .send_handle_message(data, &_id.as_ref().unwrap())
338                    .await;
339            }
340        }
341    }
342    _id
343}