ws_rs/
server.rs

1use std::fs::File;
2use std::io::BufReader;
3use std::net::SocketAddr;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures_util::{FutureExt, SinkExt, StreamExt};
9use log::{debug, error, info, warn};
10use rustls::pki_types::{CertificateDer, PrivateKeyDer};
11use tokio::net::{TcpListener, TcpStream};
12use tokio::sync::mpsc;
13use tokio_rustls::TlsAcceptor;
14use tokio_tungstenite::tungstenite::Message;
15use tokio_tungstenite::accept_async;
16
/// Message exchanged between the server and its WebSocket clients.
///
/// This is a simplified view of a WebSocket message: it has no variants for
/// ping/pong control frames.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsMessage {
    /// UTF-8 text payload.
    Text(String),
    /// Raw binary payload.
    Binary(Vec<u8>),
    /// Request to close the connection.
    Close,
}
27
28impl From<Message> for WsMessage {
29    fn from(msg: Message) -> Self {
30        match msg {
31            Message::Text(text) => WsMessage::Text(text),
32            Message::Binary(data) => WsMessage::Binary(data),
33            Message::Close(_) => WsMessage::Close,
34            Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => {
35                // These are handled internally by the WebSocket implementation
36                WsMessage::Text("".to_string())
37            }
38        }
39    }
40}
41
42impl From<WsMessage> for Message {
43    fn from(msg: WsMessage) -> Self {
44        match msg {
45            WsMessage::Text(text) => Message::Text(text),
46            WsMessage::Binary(data) => Message::Binary(data),
47            WsMessage::Close => Message::Close(None),
48        }
49    }
50}
51
/// Identifier the server assigns to each WebSocket connection.
///
/// The inner string stays public for backward compatibility; the type also
/// implements [`std::fmt::Display`], so it can be formatted with `{}` directly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ClientId(pub String);

impl std::fmt::Display for ClientId {
    /// Writes the raw identifier string without any decoration.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
55
/// Settings controlling how the WebSocket server listens, authenticates and
/// limits its clients.
#[derive(Debug, Clone)]
pub struct WsServerConfig {
    /// Address the listener binds to, e.g. `127.0.0.1:9000`.
    pub addr: String,
    /// PEM file holding the server certificate(s).
    pub cert_path: PathBuf,
    /// PEM file holding the server private key.
    pub key_path: PathBuf,
    /// PEM file holding the CA certificate(s) used to verify client
    /// certificates when `client_cert_required` is set.
    pub ca_cert_path: PathBuf,
    /// Upper bound on the number of simultaneously connected clients.
    pub max_connections: usize,
    /// Timeout in seconds, applied separately to each connection handshake
    /// (TLS, then WebSocket).
    pub connection_timeout: u64,
    /// When `true`, clients must present a certificate that verifies against
    /// the CA certificate(s) in `ca_cert_path`.
    pub client_cert_required: bool,
}

impl Default for WsServerConfig {
    /// Loopback listener on port 9000 with client certificates required and
    /// all PEM files read from `./crate_cert/`.
    fn default() -> Self {
        let addr = String::from("127.0.0.1:9000");
        Self {
            addr,
            cert_path: "./crate_cert/a_cert.pem".into(),
            key_path: "./crate_cert/a_key.pem".into(),
            ca_cert_path: "./crate_cert/ca_cert.pem".into(),
            max_connections: 1_000,
            connection_timeout: 30,
            client_cert_required: true,
        }
    }
}
88
89/// WebSocket Server handler for processing incoming messages
90///
91/// Implement this trait to handle incoming WebSocket messages.
92pub trait ServerHandler: Send + Sync + 'static {
93    /// Called when a client connects
94    fn on_connect(&self, client_id: ClientId, addr: SocketAddr);
95    
96    /// Called when a client disconnects
97    fn on_disconnect(&self, client_id: ClientId);
98    
99    /// Called when a message is received from a client
100    ///
101    /// Return optional response message
102    fn on_message(&self, client_id: ClientId, message: WsMessage) -> Option<WsMessage>;
103    
104    /// Called when an error occurs
105    fn on_error(&self, client_id: Option<ClientId>, error: String);
106}
107
108/// Client connection manager
109struct ClientConnection {
110    client_id: ClientId,
111    tx: mpsc::Sender<WsMessage>,
112}
113
114/// WebSocket Server
115pub struct WsServer {
116    config: WsServerConfig,
117    handler: Arc<dyn ServerHandler>,
118    tls_acceptor: TlsAcceptor,
119    clients: Arc<tokio::sync::Mutex<Vec<ClientConnection>>>,
120}
121
122impl WsServer {
123    /// Create a new WebSocket server with the provided configuration and handler
124    ///
125    /// # Arguments
126    ///
127    /// * `config` - Server configuration
128    /// * `handler` - Server handler for processing messages
129    ///
130    /// # Returns
131    ///
132    /// * `Result<Self, String>` - New server instance or error
133    pub fn new(config: WsServerConfig, handler: impl ServerHandler) -> Result<Self, String> {
134        let tls_acceptor = Self::create_tls_acceptor(&config)
135            .map_err(|e| format!("Failed to create TLS acceptor: {}", e))?;
136            
137        Ok(Self {
138            config,
139            handler: Arc::new(handler),
140            tls_acceptor,
141            clients: Arc::new(tokio::sync::Mutex::new(Vec::new())),
142        })
143    }
144    
145    /// Start the WebSocket server
146    ///
147    /// # Returns
148    ///
149    /// * `Result<(), String>` - Success or error
150    pub async fn start(&self) -> Result<(), String> {
151        let listener = TcpListener::bind(&self.config.addr)
152            .await
153            .map_err(|e| format!("Failed to bind to address {}: {}", self.config.addr, e))?;
154            
155        info!("WebSocket server started on {}", self.config.addr);
156        
157        loop {
158            match listener.accept().await {
159                Ok((stream, addr)) => {
160                    debug!("New TCP connection from: {}", addr);
161                    
162                    // Clone necessary references for the task
163                    let acceptor = self.tls_acceptor.clone();
164                    let handler = self.handler.clone();
165                    let clients = self.clients.clone();
166                    let connection_timeout = Duration::from_secs(self.config.connection_timeout);
167                    
168                    // Generate a unique client ID
169                    let client_id = ClientId(format!("client-{}", uuid_simple()));
170                    let client_id_clone = client_id.clone();
171                    
172                    tokio::spawn(async move {
173                        if let Err(e) = Self::handle_connection(
174                            stream, 
175                            addr, 
176                            acceptor, 
177                            handler, 
178                            clients, 
179                            client_id_clone, 
180                            connection_timeout
181                        ).await {
182                            error!("Connection error for {}: {}", addr, e);
183                        }
184                    });
185                }
186                Err(e) => {
187                    error!("Failed to accept connection: {}", e);
188                }
189            }
190            
191            // Check if we've reached the maximum number of connections
192            let client_count = self.clients.lock().await.len();
193            if client_count >= self.config.max_connections {
194                warn!("Maximum connections reached: {}", client_count);
195                
196                // Small delay to avoid CPU spinning
197                tokio::time::sleep(Duration::from_millis(100)).await;
198            }
199        }
200    }
201    
202    /// Broadcast a message to all connected clients
203    ///
204    /// # Arguments
205    ///
206    /// * `message` - Message to broadcast
207    ///
208    /// # Returns
209    ///
210    /// * `Result<usize, String>` - Number of clients that received the message or error
211    pub async fn broadcast(&self, message: WsMessage) -> Result<usize, String> {
212        let clients = self.clients.lock().await;
213        let mut sent_count = 0;
214        
215        for client in clients.iter() {
216            if client.tx.send(message.clone()).await.is_ok() {
217                sent_count += 1;
218            }
219        }
220        
221        Ok(sent_count)
222    }
223    
224    /// Send a message to a specific client
225    ///
226    /// # Arguments
227    ///
228    /// * `client_id` - ID of the client to send the message to
229    /// * `message` - Message to send
230    ///
231    /// # Returns
232    ///
233    /// * `Result<(), String>` - Success or error
234    pub async fn send_to_client(&self, client_id: &ClientId, message: WsMessage) -> Result<(), String> {
235        let clients = self.clients.lock().await;
236        
237        for client in clients.iter() {
238            if client.client_id == *client_id {
239                return client.tx.send(message)
240                    .await
241                    .map_err(|_| format!("Failed to send message to client {}", client_id.0));
242            }
243        }
244        
245        Err(format!("Client not found: {}", client_id.0))
246    }
247    
248    /// Get the number of connected clients
249    ///
250    /// # Returns
251    ///
252    /// * `usize` - Number of connected clients
253    pub async fn client_count(&self) -> usize {
254        self.clients.lock().await.len()
255    }
256    
257    /// Get a list of connected client IDs
258    ///
259    /// # Returns
260    ///
261    /// * `Vec<ClientId>` - List of connected client IDs
262    pub async fn client_list(&self) -> Vec<ClientId> {
263        let clients = self.clients.lock().await;
264        clients.iter().map(|c| c.client_id.clone()).collect()
265    }
266    
267    /// Handle a new WebSocket connection
268    async fn handle_connection(
269        stream: TcpStream,
270        addr: SocketAddr,
271        acceptor: TlsAcceptor,
272        handler: Arc<dyn ServerHandler>,
273        clients: Arc<tokio::sync::Mutex<Vec<ClientConnection>>>,
274        client_id: ClientId,
275        connection_timeout: Duration,
276    ) -> Result<(), String> {
277        // Apply timeout to the TLS handshake
278        let tls_handshake = tokio::time::timeout(
279            connection_timeout,
280            acceptor.accept(stream),
281        ).await
282            .map_err(|_| format!("TLS handshake timed out after {} seconds", connection_timeout.as_secs()))?
283            .map_err(|e| format!("TLS handshake failed: {}", e))?;
284            
285        debug!("TLS handshake successful for {}", addr);
286        
287        // Apply timeout to the WebSocket handshake
288        let ws_stream = tokio::time::timeout(
289            connection_timeout,
290            accept_async(tls_handshake),
291        ).await
292            .map_err(|_| format!("WebSocket handshake timed out after {} seconds", connection_timeout.as_secs()))?
293            .map_err(|e| format!("WebSocket handshake failed: {}", e))?;
294            
295        debug!("WebSocket handshake successful for {}", addr);
296        
297        // Create message channels
298        let (tx, mut rx) = mpsc::channel::<WsMessage>(100);
299        
300        // Register the client
301        {
302            let mut clients_lock = clients.lock().await;
303            clients_lock.push(ClientConnection {
304                client_id: client_id.clone(),
305                tx: tx.clone(),
306            });
307            
308            info!("Client connected: {} from {}", client_id.0, addr);
309        }
310        
311        // Notify the handler about the new connection
312        handler.on_connect(client_id.clone(), addr);
313        
314        // Split WebSocket stream
315        let (ws_sender, ws_receiver) = ws_stream.split();
316        
317        // Forward outgoing messages to the WebSocket
318        let mut send_task = {
319            let mut ws_sender = ws_sender;
320            let client_id_for_send = client_id.clone();
321            let handler_for_send = handler.clone();
322            
323            async move {
324                while let Some(msg) = rx.recv().await {
325                    match ws_sender.send(msg.into()).await {
326                        Ok(_) => {
327                            debug!("Message sent to client {}", client_id_for_send.0);
328                        }
329                        Err(e) => {
330                            let error_msg = format!("Failed to send message: {}", e);
331                            handler_for_send.on_error(Some(client_id_for_send.clone()), error_msg);
332                            break;
333                        }
334                    }
335                }
336                
337                // Try to close the connection gracefully
338                let _ = ws_sender.close().await;
339                
340                debug!("Send task completed for client {}", client_id_for_send.0);
341            }.boxed()
342        };
343        
344        // Process incoming messages from the WebSocket
345        let mut receive_task = {
346            let mut ws_receiver = ws_receiver;
347            let handler_for_recv = handler.clone();
348            let client_id_for_recv = client_id.clone();
349            let tx_for_recv = tx.clone();
350            
351            async move {
352                while let Some(result) = ws_receiver.next().await {
353                    match result {
354                        Ok(msg) => {
355                            if msg.is_close() {
356                                debug!("Client {} requested close", client_id_for_recv.0);
357                                break;
358                            }
359                            
360                            // Convert to our message type
361                            let ws_msg = WsMessage::from(msg);
362                            
363                            // Let the handler process the message
364                            if let Some(response) = handler_for_recv.on_message(client_id_for_recv.clone(), ws_msg) {
365                                // Send the response if provided
366                                if tx_for_recv.send(response).await.is_err() {
367                                    break;
368                                }
369                            }
370                        }
371                        Err(e) => {
372                            let error_msg = format!("Error receiving message: {}", e);
373                            handler_for_recv.on_error(Some(client_id_for_recv.clone()), error_msg);
374                            break;
375                        }
376                    }
377                }
378                
379                debug!("Receive task completed for client {}", client_id_for_recv.0);
380            }.boxed()
381        };
382        
383        // Wait for either task to complete
384        tokio::select! {
385            _ = &mut send_task => {},
386            _ = &mut receive_task => {},
387        }
388        
389        // Clean up
390        Self::remove_client(clients, client_id.clone()).await;
391        handler.on_disconnect(client_id.clone());
392        
393        info!("Client disconnected: {} from {}", client_id.0, addr);
394        Ok(())
395    }
396    
397    /// Remove a client from the clients list
398    async fn remove_client(
399        clients: Arc<tokio::sync::Mutex<Vec<ClientConnection>>>,
400        client_id: ClientId,
401    ) {
402        let mut clients_lock = clients.lock().await;
403        if let Some(pos) = clients_lock.iter().position(|c| c.client_id == client_id) {
404            clients_lock.remove(pos);
405        }
406    }
407    
408    /// Create a TLS acceptor from the server configuration
409    fn create_tls_acceptor(config: &WsServerConfig) -> Result<TlsAcceptor, Box<dyn std::error::Error>> {
410        // Load certificates and keys
411        info!("Loading certificates and keys...");
412        let certs = load_certs(&config.cert_path)?;
413        let key = load_private_key(&config.key_path)?;
414        let ca_certs = load_certs(&config.ca_cert_path)?;
415        
416        // Create root certificate store
417        let mut root_cert_store = rustls::RootCertStore::empty();
418        for cert in ca_certs {
419            root_cert_store.add(cert)?;
420        }
421        
422        // Create a server config builder
423        let server_config = if config.client_cert_required {
424            // Create client certificate verifier
425            let client_verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_cert_store))
426                .build()?;
427                
428            // Configure TLS server with client authentication
429            rustls::ServerConfig::builder()
430                .with_client_cert_verifier(client_verifier)
431                .with_single_cert(certs, key)?
432        } else {
433            // Configure TLS server without client authentication
434            rustls::ServerConfig::builder()
435                .with_no_client_auth()
436                .with_single_cert(certs, key)?
437        };
438            
439        // Create TLS acceptor
440        Ok(TlsAcceptor::from(Arc::new(server_config)))
441    }
442}
443
444/// Load certificates from a file
445///
446/// # Arguments
447///
448/// * `path` - Path to the certificate file
449///
450/// # Returns
451///
452/// * `Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error>>` - Loaded certificates or error
453fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error>> {
454    let file = File::open(path)?;
455    let mut reader = BufReader::new(file);
456    let mut certs = Vec::new();
457    
458    for cert_result in rustls_pemfile::certs(&mut reader) {
459        let cert = cert_result?;
460        certs.push(cert);
461    }
462    
463    if certs.is_empty() {
464        return Err(format!("No certificates found in {}", path.display()).into());
465    }
466    
467    Ok(certs)
468}
469
470/// Load private key from a file
471///
472/// # Arguments
473///
474/// * `path` - Path to the private key file
475///
476/// # Returns
477///
478/// * `Result<PrivateKeyDer<'static>, Box<dyn std::error::Error>>` - Loaded private key or error
479fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error>> {
480    let file = File::open(path)?;
481    let mut reader = BufReader::new(file);
482    
483    // Try PKCS8 format first
484    let mut pkcs8_keys = Vec::new();
485    for key_result in rustls_pemfile::pkcs8_private_keys(&mut reader) {
486        pkcs8_keys.push(key_result?);
487    }
488    
489    if !pkcs8_keys.is_empty() {
490        return Ok(PrivateKeyDer::Pkcs8(pkcs8_keys.remove(0)));
491    }
492    
493    // Reset reader position
494    reader = BufReader::new(File::open(path)?);
495    
496    // Try RSA format
497    let mut rsa_keys = Vec::new();
498    for key_result in rustls_pemfile::rsa_private_keys(&mut reader) {
499        rsa_keys.push(key_result?);
500    }
501    
502    if !rsa_keys.is_empty() {
503        return Ok(PrivateKeyDer::Pkcs1(rsa_keys.remove(0)));
504    }
505    
506    // Reset reader position
507    reader = BufReader::new(File::open(path)?);
508    
509    // Try EC format
510    let mut ec_keys = Vec::new();
511    for key_result in rustls_pemfile::ec_private_keys(&mut reader) {
512        ec_keys.push(key_result?);
513    }
514    
515    if !ec_keys.is_empty() {
516        return Ok(PrivateKeyDer::Sec1(ec_keys.remove(0)));
517    }
518    
519    Err(format!("No private keys found in {}", path.display()).into())
520}
521
/// Generate a process-unique, UUID-like string for client IDs.
///
/// The id combines the current wall-clock time with a process-wide sequence
/// number. The timestamp alone is not unique: two connections accepted within
/// the same clock tick, or after the clock steps backwards, would otherwise
/// receive the same id.
fn uuid_simple() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    // Relaxed is sufficient: only the atomicity of the increment (uniqueness)
    // matters, not ordering relative to other memory operations.
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);

    // Nanoseconds are below 1e9 and so fit in 8 hex digits. Zero-padding to
    // that width keeps the seconds/nanoseconds boundary unambiguous.
    format!("{:x}{:08x}-{:x}", now.as_secs(), now.subsec_nanos(), seq)
}