// network_protocol/service/daemon.rs

1use bincode;
2use futures::{SinkExt, StreamExt};
3use std::net::SocketAddr;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::net::{TcpListener, TcpStream};
7use tokio::sync::{mpsc, oneshot, Mutex};
8use tokio::time;
9use tokio_util::codec::Framed;
10use tracing::{debug, error, info, instrument, warn};
11
12use crate::config::ServerConfig;
13
14use crate::utils::timeout::with_timeout_error;
15
16use crate::core::codec::PacketCodec;
17use crate::core::packet::Packet;
18use crate::protocol::message::Message;
19// Import secure handshake functions
20use crate::error::{ProtocolError, Result};
21use crate::protocol::dispatcher::Dispatcher;
22use crate::protocol::handshake::{
23    server_secure_handshake_finalize, server_secure_handshake_response,
24};
25use crate::protocol::heartbeat::{build_ping, is_pong};
26use crate::protocol::keepalive::KeepAliveManager;
27use crate::service::secure::SecureConnection;
28use crate::utils::replay_cache::ReplayCache;
29
30/// Start a secure server and listen for connections using default configuration
31#[instrument(skip(addr), fields(address = %addr))]
32pub async fn start(addr: &str) -> Result<()> {
33    // Create a never-resolving shutdown receiver for standard operation
34    let (_, shutdown_rx) = oneshot::channel::<()>();
35    start_with_shutdown(addr, shutdown_rx).await
36}
37
38/// Start a secure server with custom configuration
39#[instrument(skip(config), fields(address = %config.address))]
40pub async fn start_with_config(config: ServerConfig) -> Result<()> {
41    // Create a never-resolving shutdown receiver for standard operation
42    let (_, shutdown_rx) = oneshot::channel::<()>();
43    start_with_config_and_shutdown(config, shutdown_rx).await
44}
45
46/// Start a secure server with shutdown control for testing
47#[instrument(skip(addr, shutdown_rx), fields(address = %addr))]
48pub async fn start_with_shutdown(addr: &str, shutdown_rx: oneshot::Receiver<()>) -> Result<()> {
49    // Use default configuration with overridden address
50    let config = ServerConfig {
51        address: addr.to_string(),
52        ..Default::default()
53    };
54    start_with_config_and_shutdown(config, shutdown_rx).await
55}
56
57/// Start a secure server with custom configuration and shutdown control
58#[instrument(skip(config, shutdown_rx), fields(address = %config.address))]
59pub async fn start_with_config_and_shutdown(
60    config: ServerConfig,
61    shutdown_rx: oneshot::Receiver<()>,
62) -> Result<()> {
63    let listener = TcpListener::bind(&config.address).await?;
64    info!(address = %config.address, "Server listening");
65
66    // Shared dispatcher
67    let dispatcher = Arc::new(Dispatcher::new());
68
69    // Register default handlers
70    register_default_handlers(&dispatcher)?;
71
72    // Track active connections for graceful shutdown
73    let active_connections = Arc::new(Mutex::new(0u32));
74
75    // Create a shutdown channel for internal use
76    let (internal_shutdown_tx, mut internal_shutdown_rx) = mpsc::channel::<()>(1);
77
78    // Extract configuration values we need before moving config
79    let shutdown_timeout = config.shutdown_timeout;
80    let heartbeat_interval = config.heartbeat_interval;
81
82    // Clone a sender for the task
83    let shutdown_tx_clone = internal_shutdown_tx.clone();
84    tokio::spawn(async move {
85        match tokio::signal::ctrl_c().await {
86            Ok(()) => {
87                info!("Shutdown signal received");
88                let _ = shutdown_tx_clone.send(()).await;
89            }
90            Err(err) => {
91                error!(error = %err, "Failed to listen for shutdown signal");
92            }
93        }
94    });
95
96    // Also set up the oneshot receiver to trigger shutdown
97    let internal_shutdown_tx_clone = internal_shutdown_tx.clone();
98    tokio::spawn(async move {
99        if shutdown_rx.await.is_ok() {
100            info!("External shutdown signal received");
101            let _ = internal_shutdown_tx_clone.send(()).await;
102        }
103    });
104
105    // Server main loop with graceful shutdown
106    loop {
107        tokio::select! {
108            // Check for shutdown signal
109            _ = internal_shutdown_rx.recv() => {
110                info!("Shutting down server. Waiting for connections to close...");
111
112                // Wait for active connections to close (with configured timeout)
113                let timeout = tokio::time::sleep(shutdown_timeout);
114                tokio::pin!(timeout);
115
116                loop {
117                    tokio::select! {
118                        _ = &mut timeout => {
119                            warn!("Shutdown timeout reached, forcing exit");
120                            break;
121                        }
122                        _ = tokio::time::sleep(Duration::from_millis(500)) => {
123                            let connections = *active_connections.lock().await;
124                            info!(connections = %connections, "Waiting for connections to close");
125                            if connections == 0 {
126                                info!("All connections closed, shutting down");
127                                break;
128                            }
129                        }
130                    }
131                }
132
133                return Ok(());
134            }
135
136            // Accept new connections
137            accept_result = listener.accept() => {
138                match accept_result {
139                    Ok((stream, peer)) => {
140                        info!(peer = %peer, "New connection established");
141                        let dispatcher = dispatcher.clone();
142                        let active_connections = active_connections.clone();
143                        // We don't need this clone if we're not using it in this scope
144                        // let _shutdown_tx = shutdown_tx.clone();
145
146                        // Increment active connections counter
147                        {
148                            let mut count = active_connections.lock().await;
149                            *count += 1;
150                        }
151
152                        // Clone the things we need to move into the task
153                        let active_connections_clone = active_connections.clone();
154                        let config_clone = config.clone();
155
156                        tokio::spawn(async move {
157                            handle_connection(stream, peer, dispatcher, active_connections_clone, config_clone, heartbeat_interval).await;
158                        });
159                    }
160                    Err(e) => {
161                        error!(error = %e, "Error accepting connection");
162                    }
163                }
164            }
165        }
166    }
167}
168
169/// Handle a client connection with proper cleanup on exit
170#[instrument(skip(stream, dispatcher, active_connections, config, heartbeat_interval), fields(peer = %peer))]
171async fn handle_connection(
172    stream: tokio::net::TcpStream,
173    peer: std::net::SocketAddr,
174    dispatcher: Arc<Dispatcher>,
175    active_connections: Arc<Mutex<u32>>,
176    config: ServerConfig,
177    heartbeat_interval: Duration,
178) {
179    // Setup the connection with cleanup
180    let result = with_timeout_error(
181        async {
182            process_connection(stream, dispatcher, peer, config.clone(), heartbeat_interval).await
183        },
184        config.connection_timeout,
185    )
186    .await;
187
188    // If there was an error, log it
189    match result {
190        Ok(_) => info!("Connection closed gracefully"),
191        Err(ProtocolError::Timeout) => warn!("Connection timed out"),
192        Err(e) => error!(error = %e, "Connection error"),
193    }
194
195    // Always decrement active connections on exit
196    {
197        let mut count = active_connections.lock().await;
198        *count -= 1;
199    }
200
201    info!("Client disconnected");
202}
203
204/// Process a client connection with handshake and secure messages
205#[instrument(skip(stream, dispatcher, peer, config, heartbeat_interval), fields(peer = %peer))]
206async fn process_connection(
207    stream: TcpStream,
208    dispatcher: Arc<Dispatcher>,
209    peer: SocketAddr,
210    config: ServerConfig,
211    heartbeat_interval: Duration,
212) -> Result<()> {
213    // Create the framed stream for packet codec
214    let mut framed = Framed::new(stream, PacketCodec);
215
216    // --- Expect Secure Handshake Init (with timeout) ---
217    let init = with_timeout_error(
218        async {
219            match framed.next().await {
220                Some(Ok(pkt)) => bincode::deserialize::<Message>(&pkt.payload)
221                    .map_err(|e| ProtocolError::DeserializeError(e.to_string())),
222                Some(Err(e)) => Err(ProtocolError::TransportError(e.to_string())),
223                None => Err(ProtocolError::ConnectionClosed),
224            }
225        },
226        config.connection_timeout,
227    )
228    .await?;
229
230    // Extract the client's handshake init data
231    let (client_pub_key, client_timestamp, client_nonce) = match init {
232        Message::SecureHandshakeInit {
233            pub_key,
234            timestamp,
235            nonce,
236        } => (pub_key, timestamp, nonce),
237        _ => {
238            return Err(ProtocolError::HandshakeError(
239                "Unexpected message type".to_string(),
240            ))
241        }
242    };
243
244    // --- Send Secure Handshake Response ---
245    let mut replay_cache = ReplayCache::new();
246    let (server_state, response) = server_secure_handshake_response(
247        client_pub_key,
248        client_nonce,
249        client_timestamp,
250        &peer.to_string(),
251        &mut replay_cache,
252    )?;
253
254    let response_bytes =
255        bincode::serialize(&response).map_err(|e| ProtocolError::SerializeError(e.to_string()))?;
256
257    framed
258        .send(Packet {
259            version: 1,
260            payload: response_bytes,
261        })
262        .await
263        .map_err(|e| ProtocolError::TransportError(e.to_string()))?;
264
265    // --- Expect Handshake Confirmation (with timeout) ---
266    let confirm = with_timeout_error(
267        async {
268            match framed.next().await {
269                Some(Ok(pkt)) => bincode::deserialize::<Message>(&pkt.payload)
270                    .map_err(|e| ProtocolError::DeserializeError(e.to_string())),
271                Some(Err(e)) => Err(ProtocolError::TransportError(e.to_string())),
272                None => Err(ProtocolError::ConnectionClosed),
273            }
274        },
275        config.connection_timeout,
276    )
277    .await?;
278
279    let nonce_verification = match confirm {
280        Message::SecureHandshakeConfirm { nonce_verification } => nonce_verification,
281        _ => {
282            return Err(ProtocolError::HandshakeError(
283                "Expected handshake confirmation".to_string(),
284            ))
285        }
286    };
287
288    // --- Finalize Handshake and Derive Session Key ---
289    let session_key = server_secure_handshake_finalize(server_state, nonce_verification)?;
290
291    // Session state is zeroized automatically on drop
292
293    // Create secure connection with derived key
294    let conn = SecureConnection::new(framed, session_key);
295
296    // Handle the secure message loop
297    handle_secure_connection(conn, dispatcher, peer, heartbeat_interval).await?;
298
299    Ok(())
300}
301
302/// Register default message handlers
303#[instrument(skip(dispatcher))]
304fn register_default_handlers(dispatcher: &Arc<Dispatcher>) -> Result<()> {
305    // Handle Ping messages
306    dispatcher.register("PING", |_| {
307        debug!("Responding to ping with pong");
308        Ok(Message::Pong)
309    })?;
310
311    // Handle Echo messages
312    dispatcher.register("ECHO", |msg| {
313        if let Message::Echo(text) = msg {
314            debug!(text = %text, "Echoing message");
315            Ok(Message::Echo(text.clone()))
316        } else {
317            Err(ProtocolError::Custom(
318                "Invalid Echo message format".to_string(),
319            ))
320        }
321    })?;
322
323    Ok(())
324}
325
326/// Message type for the internal processing channel
327#[derive(Debug)]
328enum ProcessingMessage {
329    /// Regular message to be processed
330    Message(Message),
331    /// Signal to terminate the processing task
332    Terminate,
333}
334
335/// Response from the processing task
336#[derive(Debug)]
337struct ProcessingResult {
338    /// The original message ID or correlation ID
339    original_id: usize,
340    /// The response message to send back
341    response: Option<Message>,
342}
343
344/// Handle a secure connection after handshake with backpressure
345#[instrument(skip(conn, dispatcher, heartbeat_interval), fields(peer = %peer))]
346async fn handle_secure_connection(
347    mut conn: SecureConnection,
348    dispatcher: Arc<Dispatcher>,
349    peer: std::net::SocketAddr,
350    heartbeat_interval: Duration,
351) -> Result<()> {
352    // --- Initialize Keep-Alive Manager with configured interval ---
353    let dead_timeout = heartbeat_interval.mul_f32(4.0); // 4x the heartbeat interval for dead connection detection
354    let mut keep_alive = KeepAliveManager::with_settings(heartbeat_interval, dead_timeout);
355    let mut ping_interval = time::interval(keep_alive.ping_interval());
356
357    // --- Create bounded channels for backpressure with capacity from config ---
358    // We're using an internal messaging channel, so we can use a reasonable default here
359    let (msg_tx, msg_rx) = mpsc::channel::<ProcessingMessage>(32);
360    let (resp_tx, mut resp_rx) = mpsc::channel::<ProcessingResult>(32);
361
362    // --- Spawn message processing task ---
363    let dispatcher_clone = dispatcher.clone();
364    let processor_handle =
365        tokio::spawn(async move { process_messages(msg_rx, resp_tx, dispatcher_clone).await });
366
367    // --- Set up result for final status ---
368    let mut final_result = Ok(());
369    let mut next_msg_id: usize = 0;
370
371    // --- Secure Message Loop with Backpressure ---
372    'main: loop {
373        tokio::select! {
374            // Check if we need to send a ping
375            _ = ping_interval.tick() => {
376                if keep_alive.should_ping() {
377                    debug!("Sending keep-alive ping");
378                    let ping = build_ping();
379                    if let Err(e) = conn.secure_send(ping).await {
380                        warn!(error = %e, "Failed to send ping");
381                        final_result = Err(e);
382                        break 'main;
383                    }
384                    keep_alive.update_send();
385                }
386
387                // Check if connection is dead
388                if keep_alive.is_connection_dead() {
389                    warn!(dead_seconds = ?keep_alive.time_since_last_recv().as_secs(),
390                          "Connection appears dead, closing");
391                    final_result = Err(ProtocolError::ConnectionTimeout);
392                    break 'main;
393                }
394            }
395
396            // Process any responses from the processing task
397            Some(result) = resp_rx.recv() => {
398                if let Some(response) = result.response {
399                    debug!("Sending response for message {}", result.original_id);
400                    if let Err(e) = conn.secure_send(response).await {
401                        warn!(error = %e, "Failed to send response");
402                        final_result = Err(e);
403                        break 'main;
404                    }
405                    keep_alive.update_send();
406                }
407            }
408
409            // Try to receive a message with backpressure awareness
410            recv_result = conn.secure_recv::<Message>() => {
411                match recv_result {
412                    Ok(msg) => {
413                        debug!(message = ?msg, "Received message");
414                        keep_alive.update_recv();
415
416                        // Check for disconnect message - handle directly without channel
417                        if matches!(msg, Message::Disconnect) {
418                            info!("Received disconnect request");
419                            break 'main;
420                        }
421
422                        // Special handling for pong messages - handle directly
423                        if is_pong(&msg) {
424                            debug!("Received pong response");
425                            continue;
426                        }
427
428                        // Just increment the ID counter for the next message
429                        next_msg_id = next_msg_id.wrapping_add(1);
430
431                        // Apply backpressure if needed
432                        if msg_tx.capacity() == 0 {
433                            debug!("Channel full - applying backpressure");
434
435                            // Wait until the channel has capacity before receiving more messages
436                            match msg_tx.reserve().await {
437                                Ok(permit) => {
438                                    // Channel has capacity again, send the message
439                                    permit.send(ProcessingMessage::Message(msg));
440                                },
441                                Err(_) => {
442                                    // Channel was closed, exit the loop
443                                    warn!("Processing channel closed unexpectedly");
444                                    break 'main;
445                                }
446                            }
447                        } else {
448                            // Channel has capacity, send the message
449                            if (msg_tx.send(ProcessingMessage::Message(msg)).await).is_err() {
450                                // Channel was closed, exit the loop
451                                warn!("Failed to send message to processing channel");
452                                break 'main;
453                            }
454                        }
455                    }
456                    Err(ProtocolError::Timeout) => {
457                        // Timeout is expected, just continue the loop
458                        continue;
459                    }
460                    Err(e) => {
461                        final_result = Err(e);
462                        break 'main;
463                    }
464                }
465            }
466        }
467    }
468
469    // Signal the processor to terminate
470    debug!("Signaling processor to terminate");
471    let _ = msg_tx.send(ProcessingMessage::Terminate).await;
472
473    // Wait for processor to finish
474    debug!("Waiting for processor to terminate");
475    let _ = processor_handle.await;
476
477    final_result
478}
479
480/// Process messages from the channel
481#[instrument(skip(rx, resp_tx, dispatcher), level = "debug")]
482async fn process_messages(
483    mut rx: mpsc::Receiver<ProcessingMessage>,
484    resp_tx: mpsc::Sender<ProcessingResult>,
485    dispatcher: Arc<Dispatcher>,
486) {
487    let mut msg_counter: usize = 0;
488
489    while let Some(proc_msg) = rx.recv().await {
490        match proc_msg {
491            ProcessingMessage::Message(msg) => {
492                let msg_id = msg_counter;
493                msg_counter += 1;
494
495                debug!(msg_id = msg_id, message = ?msg, "Processing message from channel");
496
497                let response = match dispatcher.dispatch(&msg) {
498                    Ok(reply) => {
499                        // Successfully dispatched, prepare response
500                        Some(reply)
501                    }
502                    Err(e) => {
503                        // Log dispatch error but continue processing messages
504                        warn!(error = %e, "Error dispatching message");
505                        None
506                    }
507                };
508
509                // Send response back through the response channel
510                let result = ProcessingResult {
511                    original_id: msg_id,
512                    response,
513                };
514
515                if (resp_tx.send(result).await).is_err() {
516                    warn!("Failed to send processing result - reader likely disconnected");
517                    break;
518                }
519            }
520            ProcessingMessage::Terminate => {
521                debug!("Processor received terminate signal");
522                break;
523            }
524        }
525    }
526
527    debug!("Message processor terminated");
528}
529
530/// A server daemon handle that can be controlled externally
531#[derive(Debug)]
532pub struct Daemon {
533    /// Address the server is listening on
534    pub address: String,
535    /// Shutdown signal sender
536    shutdown_tx: Option<oneshot::Sender<()>>,
537}
538
539impl Daemon {
540    /// Create a new daemon handle
541    pub fn new(address: String, shutdown_tx: oneshot::Sender<()>) -> Self {
542        Self {
543            address,
544            shutdown_tx: Some(shutdown_tx),
545        }
546    }
547
548    /// Run the daemon until completion or shutdown signal
549    pub async fn run(self) -> Result<()> {
550        // This function doesn't actually do anything - the server is started in the start_* functions
551        // This is just a placeholder for API compatibility
552        Ok(())
553    }
554
555    /// Shutdown the daemon gracefully
556    pub async fn shutdown(&mut self) -> Result<()> {
557        if let Some(tx) = self.shutdown_tx.take() {
558            let _ = tx.send(());
559            Ok(())
560        } else {
561            Err(ProtocolError::Custom("Shutdown already called".to_string()))
562        }
563    }
564
565    /// Shutdown the daemon with a custom timeout
566    pub async fn shutdown_with_timeout(&mut self, _timeout: Duration) -> Result<()> {
567        // The timeout is handled internally in the server loop
568        self.shutdown().await
569    }
570}
571
572/// Start a server daemon with provided configuration and return a handle to it
573#[instrument(skip(config, _dispatcher), fields(address = %config.address))]
574pub async fn start_daemon_no_signals(
575    config: ServerConfig,
576    _dispatcher: Arc<Dispatcher>,
577) -> Result<Daemon> {
578    // Create a shutdown channel
579    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
580
581    let address = config.address.clone();
582
583    // Start the server in a background task
584    tokio::spawn(async move {
585        if let Err(e) = start_with_config_and_shutdown(config, shutdown_rx).await {
586            error!(error = ?e, "Server error");
587        }
588    });
589
590    // Return a daemon handle
591    Ok(Daemon::new(address, shutdown_tx))
592}
593
594/// Create a new server daemon with configuration and dispatcher
595pub fn new_with_config(config: ServerConfig, _dispatcher: Arc<Dispatcher>) -> Daemon {
596    let (shutdown_tx, _) = oneshot::channel::<()>();
597    Daemon::new(config.address.clone(), shutdown_tx)
598}