network_protocol/service/daemon.rs

1use bincode;
2use futures::{SinkExt, StreamExt};
3use std::net::SocketAddr;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::net::{TcpListener, TcpStream};
7use tokio::sync::{mpsc, oneshot, Mutex};
8use tokio::time;
9use tokio_util::codec::Framed;
10use tracing::{debug, error, info, instrument, warn};
11
12use crate::config::ServerConfig;
13
14use crate::utils::timeout::with_timeout_error;
15
16use crate::core::codec::PacketCodec;
17use crate::core::packet::Packet;
18use crate::protocol::message::Message;
19// Import secure handshake functions
20use crate::error::{ProtocolError, Result};
21use crate::protocol::dispatcher::Dispatcher;
22use crate::protocol::handshake::{
23    server_secure_handshake_finalize, server_secure_handshake_response,
24};
25use crate::protocol::heartbeat::{build_ping, is_pong};
26use crate::protocol::keepalive::KeepAliveManager;
27use crate::service::secure::SecureConnection;
28
29/// Start a secure server and listen for connections using default configuration
30#[instrument(skip(addr), fields(address = %addr))]
31pub async fn start(addr: &str) -> Result<()> {
32    // Create a never-resolving shutdown receiver for standard operation
33    let (_, shutdown_rx) = oneshot::channel::<()>();
34    start_with_shutdown(addr, shutdown_rx).await
35}
36
37/// Start a secure server with custom configuration
38#[instrument(skip(config), fields(address = %config.address))]
39pub async fn start_with_config(config: ServerConfig) -> Result<()> {
40    // Create a never-resolving shutdown receiver for standard operation
41    let (_, shutdown_rx) = oneshot::channel::<()>();
42    start_with_config_and_shutdown(config, shutdown_rx).await
43}
44
45/// Start a secure server with shutdown control for testing
46#[instrument(skip(addr, shutdown_rx), fields(address = %addr))]
47pub async fn start_with_shutdown(addr: &str, shutdown_rx: oneshot::Receiver<()>) -> Result<()> {
48    // Use default configuration with overridden address
49    let config = ServerConfig {
50        address: addr.to_string(),
51        ..Default::default()
52    };
53    start_with_config_and_shutdown(config, shutdown_rx).await
54}
55
56/// Start a secure server with custom configuration and shutdown control
57#[instrument(skip(config, shutdown_rx), fields(address = %config.address))]
58pub async fn start_with_config_and_shutdown(
59    config: ServerConfig,
60    shutdown_rx: oneshot::Receiver<()>,
61) -> Result<()> {
62    let listener = TcpListener::bind(&config.address).await?;
63    info!(address = %config.address, "Server listening");
64
65    // Shared dispatcher
66    let dispatcher = Arc::new(Dispatcher::new());
67
68    // Register default handlers
69    register_default_handlers(&dispatcher)?;
70
71    // Track active connections for graceful shutdown
72    let active_connections = Arc::new(Mutex::new(0u32));
73
74    // Create a shutdown channel for internal use
75    let (internal_shutdown_tx, mut internal_shutdown_rx) = mpsc::channel::<()>(1);
76
77    // Extract configuration values we need before moving config
78    let shutdown_timeout = config.shutdown_timeout;
79    let heartbeat_interval = config.heartbeat_interval;
80
81    // Clone a sender for the task
82    let shutdown_tx_clone = internal_shutdown_tx.clone();
83    tokio::spawn(async move {
84        match tokio::signal::ctrl_c().await {
85            Ok(()) => {
86                info!("Shutdown signal received");
87                let _ = shutdown_tx_clone.send(()).await;
88            }
89            Err(err) => {
90                error!(error = %err, "Failed to listen for shutdown signal");
91            }
92        }
93    });
94
95    // Also set up the oneshot receiver to trigger shutdown
96    let internal_shutdown_tx_clone = internal_shutdown_tx.clone();
97    tokio::spawn(async move {
98        if shutdown_rx.await.is_ok() {
99            info!("External shutdown signal received");
100            let _ = internal_shutdown_tx_clone.send(()).await;
101        }
102    });
103
104    // Server main loop with graceful shutdown
105    loop {
106        tokio::select! {
107            // Check for shutdown signal
108            _ = internal_shutdown_rx.recv() => {
109                info!("Shutting down server. Waiting for connections to close...");
110
111                // Wait for active connections to close (with configured timeout)
112                let timeout = tokio::time::sleep(shutdown_timeout);
113                tokio::pin!(timeout);
114
115                loop {
116                    tokio::select! {
117                        _ = &mut timeout => {
118                            warn!("Shutdown timeout reached, forcing exit");
119                            break;
120                        }
121                        _ = tokio::time::sleep(Duration::from_millis(500)) => {
122                            let connections = *active_connections.lock().await;
123                            info!(connections = %connections, "Waiting for connections to close");
124                            if connections == 0 {
125                                info!("All connections closed, shutting down");
126                                break;
127                            }
128                        }
129                    }
130                }
131
132                return Ok(());
133            }
134
135            // Accept new connections
136            accept_result = listener.accept() => {
137                match accept_result {
138                    Ok((stream, peer)) => {
139                        info!(peer = %peer, "New connection established");
140                        let dispatcher = dispatcher.clone();
141                        let active_connections = active_connections.clone();
142                        // We don't need this clone if we're not using it in this scope
143                        // let _shutdown_tx = shutdown_tx.clone();
144
145                        // Increment active connections counter
146                        {
147                            let mut count = active_connections.lock().await;
148                            *count += 1;
149                        }
150
151                        // Clone the things we need to move into the task
152                        let active_connections_clone = active_connections.clone();
153                        let config_clone = config.clone();
154
155                        tokio::spawn(async move {
156                            handle_connection(stream, peer, dispatcher, active_connections_clone, config_clone, heartbeat_interval).await;
157                        });
158                    }
159                    Err(e) => {
160                        error!(error = %e, "Error accepting connection");
161                    }
162                }
163            }
164        }
165    }
166}
167
168/// Handle a client connection with proper cleanup on exit
169#[instrument(skip(stream, dispatcher, active_connections, config, heartbeat_interval), fields(peer = %peer))]
170async fn handle_connection(
171    stream: tokio::net::TcpStream,
172    peer: std::net::SocketAddr,
173    dispatcher: Arc<Dispatcher>,
174    active_connections: Arc<Mutex<u32>>,
175    config: ServerConfig,
176    heartbeat_interval: Duration,
177) {
178    // Setup the connection with cleanup
179    let result = with_timeout_error(
180        async {
181            process_connection(stream, dispatcher, peer, config.clone(), heartbeat_interval).await
182        },
183        config.connection_timeout,
184    )
185    .await;
186
187    // If there was an error, log it
188    match result {
189        Ok(_) => info!("Connection closed gracefully"),
190        Err(ProtocolError::Timeout) => warn!("Connection timed out"),
191        Err(e) => error!(error = %e, "Connection error"),
192    }
193
194    // Always decrement active connections on exit
195    {
196        let mut count = active_connections.lock().await;
197        *count -= 1;
198    }
199
200    info!("Client disconnected");
201}
202
203/// Process a client connection with handshake and secure messages
204#[instrument(skip(stream, dispatcher, peer, config, heartbeat_interval), fields(peer = %peer))]
205async fn process_connection(
206    stream: TcpStream,
207    dispatcher: Arc<Dispatcher>,
208    peer: SocketAddr,
209    config: ServerConfig,
210    heartbeat_interval: Duration,
211) -> Result<()> {
212    // Create the framed stream for packet codec
213    let mut framed = Framed::new(stream, PacketCodec);
214
215    // --- Expect Secure Handshake Init (with timeout) ---
216    let init = with_timeout_error(
217        async {
218            match framed.next().await {
219                Some(Ok(pkt)) => bincode::deserialize::<Message>(&pkt.payload)
220                    .map_err(|e| ProtocolError::DeserializeError(e.to_string())),
221                Some(Err(e)) => Err(ProtocolError::TransportError(e.to_string())),
222                None => Err(ProtocolError::ConnectionClosed),
223            }
224        },
225        config.connection_timeout,
226    )
227    .await?;
228
229    // Extract the client's handshake init data
230    let (client_pub_key, client_timestamp, client_nonce) = match init {
231        Message::SecureHandshakeInit {
232            pub_key,
233            timestamp,
234            nonce,
235        } => (pub_key, timestamp, nonce),
236        _ => {
237            return Err(ProtocolError::HandshakeError(
238                "Unexpected message type".to_string(),
239            ))
240        }
241    };
242
243    // --- Send Secure Handshake Response ---
244    let (server_state, response) =
245        server_secure_handshake_response(client_pub_key, client_nonce, client_timestamp)?;
246
247    let response_bytes =
248        bincode::serialize(&response).map_err(|e| ProtocolError::SerializeError(e.to_string()))?;
249
250    framed
251        .send(Packet {
252            version: 1,
253            payload: response_bytes,
254        })
255        .await
256        .map_err(|e| ProtocolError::TransportError(e.to_string()))?;
257
258    // --- Expect Handshake Confirmation (with timeout) ---
259    let confirm = with_timeout_error(
260        async {
261            match framed.next().await {
262                Some(Ok(pkt)) => bincode::deserialize::<Message>(&pkt.payload)
263                    .map_err(|e| ProtocolError::DeserializeError(e.to_string())),
264                Some(Err(e)) => Err(ProtocolError::TransportError(e.to_string())),
265                None => Err(ProtocolError::ConnectionClosed),
266            }
267        },
268        config.connection_timeout,
269    )
270    .await?;
271
272    let nonce_verification = match confirm {
273        Message::SecureHandshakeConfirm { nonce_verification } => nonce_verification,
274        _ => {
275            return Err(ProtocolError::HandshakeError(
276                "Expected handshake confirmation".to_string(),
277            ))
278        }
279    };
280
281    // --- Finalize Handshake and Derive Session Key ---
282    let session_key = server_secure_handshake_finalize(server_state, nonce_verification)?;
283
284    // Session state is zeroized automatically on drop
285
286    // Create secure connection with derived key
287    let conn = SecureConnection::new(framed, session_key);
288
289    // Handle the secure message loop
290    handle_secure_connection(conn, dispatcher, peer, heartbeat_interval).await?;
291
292    Ok(())
293}
294
295/// Register default message handlers
296#[instrument(skip(dispatcher))]
297fn register_default_handlers(dispatcher: &Arc<Dispatcher>) -> Result<()> {
298    // Handle Ping messages
299    dispatcher.register("PING", |_| {
300        debug!("Responding to ping with pong");
301        Ok(Message::Pong)
302    })?;
303
304    // Handle Echo messages
305    dispatcher.register("ECHO", |msg| {
306        if let Message::Echo(text) = msg {
307            debug!(text = %text, "Echoing message");
308            Ok(Message::Echo(text.clone()))
309        } else {
310            Err(ProtocolError::Custom(
311                "Invalid Echo message format".to_string(),
312            ))
313        }
314    })?;
315
316    Ok(())
317}
318
319/// Message type for the internal processing channel
320#[derive(Debug)]
321enum ProcessingMessage {
322    /// Regular message to be processed
323    Message(Message),
324    /// Signal to terminate the processing task
325    Terminate,
326}
327
328/// Response from the processing task
329#[derive(Debug)]
330struct ProcessingResult {
331    /// The original message ID or correlation ID
332    original_id: usize,
333    /// The response message to send back
334    response: Option<Message>,
335}
336
337/// Handle a secure connection after handshake with backpressure
338#[instrument(skip(conn, dispatcher, heartbeat_interval), fields(peer = %peer))]
339async fn handle_secure_connection(
340    mut conn: SecureConnection,
341    dispatcher: Arc<Dispatcher>,
342    peer: std::net::SocketAddr,
343    heartbeat_interval: Duration,
344) -> Result<()> {
345    // --- Initialize Keep-Alive Manager with configured interval ---
346    let dead_timeout = heartbeat_interval.mul_f32(4.0); // 4x the heartbeat interval for dead connection detection
347    let mut keep_alive = KeepAliveManager::with_settings(heartbeat_interval, dead_timeout);
348    let mut ping_interval = time::interval(keep_alive.ping_interval());
349
350    // --- Create bounded channels for backpressure with capacity from config ---
351    // We're using an internal messaging channel, so we can use a reasonable default here
352    let (msg_tx, msg_rx) = mpsc::channel::<ProcessingMessage>(32);
353    let (resp_tx, mut resp_rx) = mpsc::channel::<ProcessingResult>(32);
354
355    // --- Spawn message processing task ---
356    let dispatcher_clone = dispatcher.clone();
357    let processor_handle =
358        tokio::spawn(async move { process_messages(msg_rx, resp_tx, dispatcher_clone).await });
359
360    // --- Set up result for final status ---
361    let mut final_result = Ok(());
362    let mut next_msg_id: usize = 0;
363
364    // --- Secure Message Loop with Backpressure ---
365    'main: loop {
366        tokio::select! {
367            // Check if we need to send a ping
368            _ = ping_interval.tick() => {
369                if keep_alive.should_ping() {
370                    debug!("Sending keep-alive ping");
371                    let ping = build_ping();
372                    if let Err(e) = conn.secure_send(ping).await {
373                        warn!(error = %e, "Failed to send ping");
374                        final_result = Err(e);
375                        break 'main;
376                    }
377                    keep_alive.update_send();
378                }
379
380                // Check if connection is dead
381                if keep_alive.is_connection_dead() {
382                    warn!(dead_seconds = ?keep_alive.time_since_last_recv().as_secs(),
383                          "Connection appears dead, closing");
384                    final_result = Err(ProtocolError::ConnectionTimeout);
385                    break 'main;
386                }
387            }
388
389            // Process any responses from the processing task
390            Some(result) = resp_rx.recv() => {
391                if let Some(response) = result.response {
392                    debug!("Sending response for message {}", result.original_id);
393                    if let Err(e) = conn.secure_send(response).await {
394                        warn!(error = %e, "Failed to send response");
395                        final_result = Err(e);
396                        break 'main;
397                    }
398                    keep_alive.update_send();
399                }
400            }
401
402            // Try to receive a message with backpressure awareness
403            recv_result = conn.secure_recv::<Message>() => {
404                match recv_result {
405                    Ok(msg) => {
406                        debug!(message = ?msg, "Received message");
407                        keep_alive.update_recv();
408
409                        // Check for disconnect message - handle directly without channel
410                        if matches!(msg, Message::Disconnect) {
411                            info!("Received disconnect request");
412                            break 'main;
413                        }
414
415                        // Special handling for pong messages - handle directly
416                        if is_pong(&msg) {
417                            debug!("Received pong response");
418                            continue;
419                        }
420
421                        // Just increment the ID counter for the next message
422                        next_msg_id = next_msg_id.wrapping_add(1);
423
424                        // Apply backpressure if needed
425                        if msg_tx.capacity() == 0 {
426                            debug!("Channel full - applying backpressure");
427
428                            // Wait until the channel has capacity before receiving more messages
429                            match msg_tx.reserve().await {
430                                Ok(permit) => {
431                                    // Channel has capacity again, send the message
432                                    permit.send(ProcessingMessage::Message(msg));
433                                },
434                                Err(_) => {
435                                    // Channel was closed, exit the loop
436                                    warn!("Processing channel closed unexpectedly");
437                                    break 'main;
438                                }
439                            }
440                        } else {
441                            // Channel has capacity, send the message
442                            if (msg_tx.send(ProcessingMessage::Message(msg)).await).is_err() {
443                                // Channel was closed, exit the loop
444                                warn!("Failed to send message to processing channel");
445                                break 'main;
446                            }
447                        }
448                    }
449                    Err(ProtocolError::Timeout) => {
450                        // Timeout is expected, just continue the loop
451                        continue;
452                    }
453                    Err(e) => {
454                        final_result = Err(e);
455                        break 'main;
456                    }
457                }
458            }
459        }
460    }
461
462    // Signal the processor to terminate
463    debug!("Signaling processor to terminate");
464    let _ = msg_tx.send(ProcessingMessage::Terminate).await;
465
466    // Wait for processor to finish
467    debug!("Waiting for processor to terminate");
468    let _ = processor_handle.await;
469
470    final_result
471}
472
473/// Process messages from the channel
474#[instrument(skip(rx, resp_tx, dispatcher), level = "debug")]
475async fn process_messages(
476    mut rx: mpsc::Receiver<ProcessingMessage>,
477    resp_tx: mpsc::Sender<ProcessingResult>,
478    dispatcher: Arc<Dispatcher>,
479) {
480    let mut msg_counter: usize = 0;
481
482    while let Some(proc_msg) = rx.recv().await {
483        match proc_msg {
484            ProcessingMessage::Message(msg) => {
485                let msg_id = msg_counter;
486                msg_counter += 1;
487
488                debug!(msg_id = msg_id, message = ?msg, "Processing message from channel");
489
490                let response = match dispatcher.dispatch(&msg) {
491                    Ok(reply) => {
492                        // Successfully dispatched, prepare response
493                        Some(reply)
494                    }
495                    Err(e) => {
496                        // Log dispatch error but continue processing messages
497                        warn!(error = %e, "Error dispatching message");
498                        None
499                    }
500                };
501
502                // Send response back through the response channel
503                let result = ProcessingResult {
504                    original_id: msg_id,
505                    response,
506                };
507
508                if (resp_tx.send(result).await).is_err() {
509                    warn!("Failed to send processing result - reader likely disconnected");
510                    break;
511                }
512            }
513            ProcessingMessage::Terminate => {
514                debug!("Processor received terminate signal");
515                break;
516            }
517        }
518    }
519
520    debug!("Message processor terminated");
521}
522
523/// A server daemon handle that can be controlled externally
524#[derive(Debug)]
525pub struct Daemon {
526    /// Address the server is listening on
527    pub address: String,
528    /// Shutdown signal sender
529    shutdown_tx: Option<oneshot::Sender<()>>,
530}
531
532impl Daemon {
533    /// Create a new daemon handle
534    pub fn new(address: String, shutdown_tx: oneshot::Sender<()>) -> Self {
535        Self {
536            address,
537            shutdown_tx: Some(shutdown_tx),
538        }
539    }
540
541    /// Run the daemon until completion or shutdown signal
542    pub async fn run(self) -> Result<()> {
543        // This function doesn't actually do anything - the server is started in the start_* functions
544        // This is just a placeholder for API compatibility
545        Ok(())
546    }
547
548    /// Shutdown the daemon gracefully
549    pub async fn shutdown(&mut self) -> Result<()> {
550        if let Some(tx) = self.shutdown_tx.take() {
551            let _ = tx.send(());
552            Ok(())
553        } else {
554            Err(ProtocolError::Custom("Shutdown already called".to_string()))
555        }
556    }
557
558    /// Shutdown the daemon with a custom timeout
559    pub async fn shutdown_with_timeout(&mut self, _timeout: Duration) -> Result<()> {
560        // The timeout is handled internally in the server loop
561        self.shutdown().await
562    }
563}
564
565/// Start a server daemon with provided configuration and return a handle to it
566#[instrument(skip(config, _dispatcher), fields(address = %config.address))]
567pub async fn start_daemon_no_signals(
568    config: ServerConfig,
569    _dispatcher: Arc<Dispatcher>,
570) -> Result<Daemon> {
571    // Create a shutdown channel
572    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
573
574    let address = config.address.clone();
575
576    // Start the server in a background task
577    tokio::spawn(async move {
578        if let Err(e) = start_with_config_and_shutdown(config, shutdown_rx).await {
579            error!(error = ?e, "Server error");
580        }
581    });
582
583    // Return a daemon handle
584    Ok(Daemon::new(address, shutdown_tx))
585}
586
587/// Create a new server daemon with configuration and dispatcher
588pub fn new_with_config(config: ServerConfig, _dispatcher: Arc<Dispatcher>) -> Daemon {
589    let (shutdown_tx, _) = oneshot::channel::<()>();
590    Daemon::new(config.address.clone(), shutdown_tx)
591}