// network_protocol/service/daemon.rs

1use tokio::net::{TcpListener, TcpStream};
2use tokio_util::codec::Framed;
3use futures::{StreamExt, SinkExt};
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::{mpsc, Mutex, oneshot};
7use std::net::SocketAddr;
8use tokio::time;
9use bincode;
10use tracing::{info, debug, warn, error, instrument};
11
12use crate::config::ServerConfig;
13
14use crate::utils::timeout::with_timeout_error;
15
16use crate::core::codec::PacketCodec;
17use crate::core::packet::Packet;
18use crate::protocol::message::Message;
19// Import secure handshake functions
20use crate::protocol::handshake::{server_secure_handshake_response, server_secure_handshake_finalize, clear_handshake_data};
21use crate::protocol::dispatcher::Dispatcher;
22use crate::protocol::keepalive::KeepAliveManager;
23use crate::protocol::heartbeat::{build_ping, is_pong};
24use crate::service::secure::SecureConnection;
25use crate::error::{Result, ProtocolError};
26
27/// Start a secure server and listen for connections using default configuration
28#[instrument(skip(addr), fields(address = %addr))]
29pub async fn start(addr: &str) -> Result<()> {
30    // Create a never-resolving shutdown receiver for standard operation
31    let (_, shutdown_rx) = oneshot::channel::<()>();
32    start_with_shutdown(addr, shutdown_rx).await
33}
34
35/// Start a secure server with custom configuration
36#[instrument(skip(config), fields(address = %config.address))]
37pub async fn start_with_config(config: ServerConfig) -> Result<()> {
38    // Create a never-resolving shutdown receiver for standard operation
39    let (_, shutdown_rx) = oneshot::channel::<()>();
40    start_with_config_and_shutdown(config, shutdown_rx).await
41}
42
43/// Start a secure server with shutdown control for testing
44#[instrument(skip(addr, shutdown_rx), fields(address = %addr))]
45pub async fn start_with_shutdown(
46    addr: &str,
47    shutdown_rx: oneshot::Receiver<()>
48) -> Result<()> {
49    // Use default configuration with overridden address
50    let config = ServerConfig {
51        address: addr.to_string(),
52        ..Default::default()
53    };
54    start_with_config_and_shutdown(config, shutdown_rx).await
55}
56
57/// Start a secure server with custom configuration and shutdown control
58#[instrument(skip(config, shutdown_rx), fields(address = %config.address))]
59pub async fn start_with_config_and_shutdown(
60    config: ServerConfig,
61    shutdown_rx: oneshot::Receiver<()>
62) -> Result<()> {
63    let listener = TcpListener::bind(&config.address).await?;
64    info!(address = %config.address, "Server listening");
65
66    // Shared dispatcher
67    let dispatcher = Arc::new(Dispatcher::new());
68    
69    // Register default handlers
70    register_default_handlers(&dispatcher)?;
71    
72    // Track active connections for graceful shutdown
73    let active_connections = Arc::new(Mutex::new(0u32));
74    
75    // Create a shutdown channel for internal use
76    let (internal_shutdown_tx, mut internal_shutdown_rx) = mpsc::channel::<()>(1);
77    
78    // Extract configuration values we need before moving config
79    let shutdown_timeout = config.shutdown_timeout;
80    let heartbeat_interval = config.heartbeat_interval;
81    
82    // Clone a sender for the task
83    let shutdown_tx_clone = internal_shutdown_tx.clone();
84    tokio::spawn(async move {
85        match tokio::signal::ctrl_c().await {
86            Ok(()) => {
87                info!("Shutdown signal received");
88                let _ = shutdown_tx_clone.send(()).await;
89            },
90            Err(err) => {
91                error!(error = %err, "Failed to listen for shutdown signal");
92            },
93        }
94    });
95    
96    // Also set up the oneshot receiver to trigger shutdown
97    let internal_shutdown_tx_clone = internal_shutdown_tx.clone();
98    tokio::spawn(async move {
99        if shutdown_rx.await.is_ok() {
100            info!("External shutdown signal received");
101            let _ = internal_shutdown_tx_clone.send(()).await;
102        }
103    });
104    
105    // Server main loop with graceful shutdown
106    loop {
107        tokio::select! {
108            // Check for shutdown signal
109            _ = internal_shutdown_rx.recv() => {     
110                info!("Shutting down server. Waiting for connections to close...");
111                
112                // Wait for active connections to close (with configured timeout)
113                let timeout = tokio::time::sleep(shutdown_timeout);
114                tokio::pin!(timeout);
115                
116                loop {
117                    tokio::select! {
118                        _ = &mut timeout => {
119                            warn!("Shutdown timeout reached, forcing exit");
120                            break;
121                        }
122                        _ = tokio::time::sleep(Duration::from_millis(500)) => {
123                            let connections = *active_connections.lock().await;
124                            info!(connections = %connections, "Waiting for connections to close");
125                            if connections == 0 {
126                                info!("All connections closed, shutting down");
127                                break;
128                            }
129                        }
130                    }
131                }
132                
133                return Ok(());
134            }
135            
136            // Accept new connections
137            accept_result = listener.accept() => {
138                match accept_result {
139                    Ok((stream, peer)) => {
140                        info!(peer = %peer, "New connection established");
141                        let dispatcher = dispatcher.clone();
142                        let active_connections = active_connections.clone();
143                        // We don't need this clone if we're not using it in this scope
144                        // let _shutdown_tx = shutdown_tx.clone();
145                        
146                        // Increment active connections counter
147                        {
148                            let mut count = active_connections.lock().await;
149                            *count += 1;
150                        }
151                        
152                        // Clone the things we need to move into the task
153                        let active_connections_clone = active_connections.clone();
154                        let config_clone = config.clone();
155                        
156                        tokio::spawn(async move {
157                            handle_connection(stream, peer, dispatcher, active_connections_clone, config_clone, heartbeat_interval).await;
158                        });
159                    }
160                    Err(e) => {
161                        error!(error = %e, "Error accepting connection");
162                    }
163                }
164            }
165        }
166    }
167}
168
169/// Handle a client connection with proper cleanup on exit
170#[instrument(skip(stream, dispatcher, active_connections, config, heartbeat_interval), fields(peer = %peer))]
171async fn handle_connection(
172    stream: tokio::net::TcpStream,
173    peer: std::net::SocketAddr,
174    dispatcher: Arc<Dispatcher>,
175    active_connections: Arc<Mutex<u32>>,
176    config: ServerConfig,
177    heartbeat_interval: Duration,
178) {
179    // Setup the connection with cleanup
180    let result = with_timeout_error(
181        async {
182            process_connection(stream, dispatcher, peer, config.clone(), heartbeat_interval).await
183        },
184        config.connection_timeout
185    ).await;
186    
187    // If there was an error, log it
188    match result {
189        Ok(_) => info!("Connection closed gracefully"),
190        Err(ProtocolError::Timeout) => warn!("Connection timed out"),
191        Err(e) => error!(error = %e, "Connection error"),
192    }
193    
194    // Always decrement active connections on exit
195    {
196        let mut count = active_connections.lock().await;
197        *count -= 1;
198    }
199    
200    info!("Client disconnected");
201}
202
203/// Process a client connection with handshake and secure messages
204#[instrument(skip(stream, dispatcher, peer, config, heartbeat_interval), fields(peer = %peer))]  
205async fn process_connection(
206    stream: TcpStream,
207    dispatcher: Arc<Dispatcher>,
208    peer: SocketAddr,
209    config: ServerConfig,
210    heartbeat_interval: Duration,
211) -> Result<()> {
212    // Create the framed stream for packet codec
213    let mut framed = Framed::new(stream, PacketCodec);
214
215    // --- Expect Secure Handshake Init (with timeout) ---
216    let init = with_timeout_error(
217        async {
218            match framed.next().await {
219                Some(Ok(pkt)) => bincode::deserialize::<Message>(&pkt.payload)
220                    .map_err(|e| ProtocolError::DeserializeError(e.to_string())),
221                Some(Err(e)) => Err(ProtocolError::TransportError(e.to_string())),
222                None => Err(ProtocolError::ConnectionClosed),
223            }
224        },
225        config.connection_timeout
226    ).await?;
227
228    // Extract the client's handshake init data
229    let (client_pub_key, client_timestamp, client_nonce) = match init {
230        Message::SecureHandshakeInit { pub_key, timestamp, nonce } => {
231            (pub_key, timestamp, nonce)
232        },
233        _ => return Err(ProtocolError::HandshakeError("Unexpected message type".to_string())),
234    };
235
236    // --- Send Secure Handshake Response ---
237    let response = server_secure_handshake_response(client_pub_key, client_nonce, client_timestamp)?;
238    
239    let response_bytes = bincode::serialize(&response)
240        .map_err(|e| ProtocolError::SerializeError(e.to_string()))?;
241        
242    framed.send(Packet { version: 1, payload: response_bytes }).await
243        .map_err(|e| ProtocolError::TransportError(e.to_string()))?;
244    
245    // --- Expect Handshake Confirmation (with timeout) ---
246    let confirm = with_timeout_error(
247        async {
248            match framed.next().await {
249                Some(Ok(pkt)) => bincode::deserialize::<Message>(&pkt.payload)
250                    .map_err(|e| ProtocolError::DeserializeError(e.to_string())),
251                Some(Err(e)) => Err(ProtocolError::TransportError(e.to_string())),
252                None => Err(ProtocolError::ConnectionClosed),
253            }
254        },
255        config.connection_timeout
256    ).await?;
257    
258    let nonce_verification = match confirm {
259        Message::SecureHandshakeConfirm { nonce_verification } => nonce_verification,
260        _ => return Err(ProtocolError::HandshakeError("Expected handshake confirmation".to_string())),
261    };
262    
263    // --- Finalize Handshake and Derive Session Key ---
264    let session_key = server_secure_handshake_finalize(nonce_verification)?;
265    
266    // Clear sensitive handshake data from memory
267    let _ = clear_handshake_data();
268    
269    // Create secure connection with derived key
270    let conn = SecureConnection::new(framed, session_key);
271    
272    // Handle the secure message loop
273    handle_secure_connection(conn, dispatcher, peer, heartbeat_interval).await?;
274    
275    Ok(())
276}
277
278/// Register default message handlers
279#[instrument(skip(dispatcher))]
280fn register_default_handlers(dispatcher: &Arc<Dispatcher>) -> Result<()> {
281    // Handle Ping messages
282    dispatcher.register("PING", |_| {
283        debug!("Responding to ping with pong");
284        Ok(Message::Pong)
285    })?;
286    
287    // Handle Echo messages
288    dispatcher.register("ECHO", |msg| {
289        if let Message::Echo(text) = msg {
290            debug!(text = %text, "Echoing message");
291            Ok(Message::Echo(text.clone()))
292        } else {
293            Err(ProtocolError::Custom("Invalid Echo message format".to_string()))
294        }
295    })?;
296    
297    Ok(())
298}
299
300/// Message type for the internal processing channel
301#[derive(Debug)]
302enum ProcessingMessage {
303    /// Regular message to be processed
304    Message(Message),
305    /// Signal to terminate the processing task
306    Terminate,
307}
308
309/// Response from the processing task
310#[derive(Debug)]
311struct ProcessingResult {
312    /// The original message ID or correlation ID
313    original_id: usize,
314    /// The response message to send back
315    response: Option<Message>,
316}
317
318/// Handle a secure connection after handshake with backpressure
319#[instrument(skip(conn, dispatcher, heartbeat_interval), fields(peer = %peer))]
320async fn handle_secure_connection(
321    mut conn: SecureConnection,
322    dispatcher: Arc<Dispatcher>,
323    peer: std::net::SocketAddr,
324    heartbeat_interval: Duration,
325) -> Result<()> {
326    // --- Initialize Keep-Alive Manager with configured interval ---
327    let dead_timeout = heartbeat_interval.mul_f32(4.0); // 4x the heartbeat interval for dead connection detection
328    let mut keep_alive = KeepAliveManager::with_settings(heartbeat_interval, dead_timeout);
329    let mut ping_interval = time::interval(keep_alive.ping_interval());
330    
331    // --- Create bounded channels for backpressure with capacity from config ---
332    // We're using an internal messaging channel, so we can use a reasonable default here
333    let (msg_tx, msg_rx) = mpsc::channel::<ProcessingMessage>(32);
334    let (resp_tx, mut resp_rx) = mpsc::channel::<ProcessingResult>(32);
335    
336    // --- Spawn message processing task ---
337    let dispatcher_clone = dispatcher.clone();
338    let processor_handle = tokio::spawn(async move {
339        process_messages(msg_rx, resp_tx, dispatcher_clone).await
340    });
341    
342    // --- Set up result for final status ---
343    let mut final_result = Ok(());
344    let mut next_msg_id: usize = 0;
345    
346    // --- Secure Message Loop with Backpressure ---
347    'main: loop {
348        tokio::select! {
349            // Check if we need to send a ping
350            _ = ping_interval.tick() => {
351                if keep_alive.should_ping() {
352                    debug!("Sending keep-alive ping");
353                    let ping = build_ping();
354                    if let Err(e) = conn.secure_send(ping).await {
355                        warn!(error = %e, "Failed to send ping");
356                        final_result = Err(e);
357                        break 'main;
358                    }
359                    keep_alive.update_send();
360                }
361                
362                // Check if connection is dead
363                if keep_alive.is_connection_dead() {
364                    warn!(dead_seconds = ?keep_alive.time_since_last_recv().as_secs(), 
365                          "Connection appears dead, closing");
366                    final_result = Err(ProtocolError::ConnectionTimeout);
367                    break 'main;
368                }
369            }
370            
371            // Process any responses from the processing task
372            Some(result) = resp_rx.recv() => {
373                if let Some(response) = result.response {
374                    debug!("Sending response for message {}", result.original_id);
375                    if let Err(e) = conn.secure_send(response).await {
376                        warn!(error = %e, "Failed to send response");
377                        final_result = Err(e);
378                        break 'main;
379                    }
380                    keep_alive.update_send();
381                }
382            }
383            
384            // Try to receive a message with backpressure awareness
385            recv_result = conn.secure_recv::<Message>() => {
386                match recv_result {
387                    Ok(msg) => {
388                        debug!(message = ?msg, "Received message");
389                        keep_alive.update_recv();
390                        
391                        // Check for disconnect message - handle directly without channel
392                        if matches!(msg, Message::Disconnect) {
393                            info!("Received disconnect request");
394                            break 'main;
395                        }
396                        
397                        // Special handling for pong messages - handle directly
398                        if is_pong(&msg) {
399                            debug!("Received pong response");
400                            continue;
401                        }
402                        
403                        // Just increment the ID counter for the next message
404                        next_msg_id = next_msg_id.wrapping_add(1);
405                        
406                        // Apply backpressure if needed
407                        if msg_tx.capacity() == 0 {
408                            debug!("Channel full - applying backpressure");
409                            
410                            // Wait until the channel has capacity before receiving more messages
411                            match msg_tx.reserve().await {
412                                Ok(permit) => {
413                                    // Channel has capacity again, send the message
414                                    permit.send(ProcessingMessage::Message(msg));
415                                },
416                                Err(_) => {
417                                    // Channel was closed, exit the loop
418                                    warn!("Processing channel closed unexpectedly");
419                                    break 'main;
420                                }
421                            }
422                        } else {
423                            // Channel has capacity, send the message
424                            if (msg_tx.send(ProcessingMessage::Message(msg)).await).is_err() {
425                                // Channel was closed, exit the loop
426                                warn!("Failed to send message to processing channel");
427                                break 'main;
428                            }
429                        }
430                    }
431                    Err(ProtocolError::Timeout) => {
432                        // Timeout is expected, just continue the loop
433                        continue;
434                    }
435                    Err(e) => {
436                        final_result = Err(e);
437                        break 'main;
438                    }
439                }
440            }
441        }
442    }
443    
444    // Signal the processor to terminate
445    debug!("Signaling processor to terminate");
446    let _ = msg_tx.send(ProcessingMessage::Terminate).await;
447    
448    // Wait for processor to finish
449    debug!("Waiting for processor to terminate");
450    let _ = processor_handle.await;
451    
452    final_result
453}
454
455/// Process messages from the channel
456#[instrument(skip(rx, resp_tx, dispatcher), level = "debug")]
457async fn process_messages(
458    mut rx: mpsc::Receiver<ProcessingMessage>,
459    resp_tx: mpsc::Sender<ProcessingResult>,
460    dispatcher: Arc<Dispatcher>,
461) {
462    let mut msg_counter: usize = 0;
463    
464    while let Some(proc_msg) = rx.recv().await {
465        match proc_msg {
466            ProcessingMessage::Message(msg) => {
467                let msg_id = msg_counter;
468                msg_counter += 1;
469                
470                debug!(msg_id = msg_id, message = ?msg, "Processing message from channel");
471                
472                let response = match dispatcher.dispatch(&msg) {
473                    Ok(reply) => {
474                        // Successfully dispatched, prepare response
475                        Some(reply)
476                    },
477                    Err(e) => {
478                        // Log dispatch error but continue processing messages
479                        warn!(error = %e, "Error dispatching message");
480                        None
481                    }
482                };
483                
484                // Send response back through the response channel
485                let result = ProcessingResult {
486                    original_id: msg_id,
487                    response,
488                };
489                
490                if (resp_tx.send(result).await).is_err() {
491                    warn!("Failed to send processing result - reader likely disconnected");
492                    break;
493                }
494            },
495            ProcessingMessage::Terminate => {
496                debug!("Processor received terminate signal");
497                break;
498            }
499        }
500    }
501    
502    debug!("Message processor terminated");
503}
504
505/// A server daemon handle that can be controlled externally
506#[derive(Debug)]
507pub struct Daemon {
508    /// Address the server is listening on
509    pub address: String,
510    /// Shutdown signal sender
511    shutdown_tx: Option<oneshot::Sender<()>>,
512}
513
514impl Daemon {
515    /// Create a new daemon handle
516    pub fn new(address: String, shutdown_tx: oneshot::Sender<()>) -> Self {
517        Self {
518            address,
519            shutdown_tx: Some(shutdown_tx),
520        }
521    }
522    
523    /// Run the daemon until completion or shutdown signal
524    pub async fn run(self) -> Result<()> {
525        // This function doesn't actually do anything - the server is started in the start_* functions
526        // This is just a placeholder for API compatibility
527        Ok(())
528    }
529    
530    /// Shutdown the daemon gracefully
531    pub async fn shutdown(&mut self) -> Result<()> {
532        if let Some(tx) = self.shutdown_tx.take() {
533            let _ = tx.send(());
534            Ok(())
535        } else {
536            Err(ProtocolError::Custom("Shutdown already called".to_string()))
537        }
538    }
539    
540    /// Shutdown the daemon with a custom timeout
541    pub async fn shutdown_with_timeout(&mut self, _timeout: Duration) -> Result<()> {
542        // The timeout is handled internally in the server loop
543        self.shutdown().await
544    }
545}
546
547/// Start a server daemon with provided configuration and return a handle to it
548#[instrument(skip(config, _dispatcher), fields(address = %config.address))]
549pub async fn start_daemon_no_signals(config: ServerConfig, _dispatcher: Arc<Dispatcher>) -> Result<Daemon> {
550    // Create a shutdown channel
551    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
552    
553    let address = config.address.clone();
554    
555    // Start the server in a background task
556    tokio::spawn(async move {
557        if let Err(e) = start_with_config_and_shutdown(config, shutdown_rx).await {
558            error!(error = ?e, "Server error");
559        }
560    });
561    
562    // Return a daemon handle
563    Ok(Daemon::new(address, shutdown_tx))
564}
565
566/// Create a new server daemon with configuration and dispatcher
567pub fn new_with_config(config: ServerConfig, _dispatcher: Arc<Dispatcher>) -> Daemon {
568    let (shutdown_tx, _) = oneshot::channel::<()>();
569    Daemon::new(config.address.clone(), shutdown_tx)
570}