// network_protocol/service/tls_daemon.rs

1use tokio_rustls::server::TlsStream;
2use tokio::net::{TcpListener, TcpStream};
3use tokio_util::codec::Framed;
4use futures::{StreamExt, SinkExt};
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::{mpsc, Mutex};
8use tracing::{debug, info, warn, error, instrument};
9
10use crate::core::codec::PacketCodec;
11use crate::core::packet::Packet;
12use crate::protocol::message::Message;
13use crate::protocol::dispatcher::Dispatcher;
14// Secure connection not needed since TLS handles encryption
15use crate::transport::tls::TlsServerConfig;
16use crate::error::Result;
17
18/// Start a secure TLS server and listen for connections
19#[instrument(skip(tls_config))]  
20pub async fn start(addr: &str, tls_config: TlsServerConfig) -> Result<()> {
21    // Create shutdown channel
22    let (_, shutdown_rx) = mpsc::channel::<()>(1);
23    
24    // Start with internal shutdown channel
25    start_with_shutdown(addr, tls_config, shutdown_rx).await
26}
27
28/// Start a secure TLS server with an external shutdown channel
29#[instrument(skip(tls_config, shutdown_rx))]  
30pub async fn start_with_shutdown(addr: &str, tls_config: TlsServerConfig, mut shutdown_rx: mpsc::Receiver<()>) -> Result<()> {
31    let config = tls_config.load_server_config()?;
32    let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
33    
34    let listener = TcpListener::bind(addr).await?;
35    info!(address=%addr, "TLS daemon listening");
36
37    // 🔁 Shared dispatcher
38    let dispatcher = Arc::new({
39        let d = Dispatcher::new();
40        let _ = d.register("PING", |_| Ok(Message::Pong));
41        let _ = d.register("ECHO", |msg| Ok(msg.clone()));
42        d
43    });
44    
45    // Track active connections
46    let active_connections = Arc::new(Mutex::new(0u32));
47    
48    // Spawn ctrl-c handler to forward to the provided shutdown channel
49    tokio::spawn(async move {
50        if let Ok(()) = tokio::signal::ctrl_c().await {
51            info!("Received shutdown signal, initiating graceful shutdown");
52        }
53    });
54    
55    // Server main loop with graceful shutdown
56    loop {
57        tokio::select! {
58            // Check for shutdown signal
59            _ = shutdown_rx.recv() => {     
60                info!("Shutting down server. Waiting for connections to close...");
61                
62                // Wait for active connections to close (with timeout)
63                let timeout = tokio::time::sleep(Duration::from_secs(10));
64                tokio::pin!(timeout);
65                
66                loop {
67                    tokio::select! {
68                        _ = &mut timeout => {
69                            warn!("Shutdown timeout reached, forcing exit");
70                            break;
71                        }
72                        _ = tokio::time::sleep(Duration::from_millis(500)) => {
73                            let connections = *active_connections.lock().await;
74                            debug!(connections, "Waiting for connections to close");
75                            if connections == 0 {
76                                info!("All connections closed, shutting down");
77                                break;
78                            }
79                        }
80                    }
81                }
82                
83                return Ok(());
84            }
85            
86            // Accept new connections
87            accept_result = listener.accept() => {
88                match accept_result {
89                    Ok((stream, peer)) => {
90                        info!(%peer, "New connection accepted");
91                        let dispatcher = dispatcher.clone();
92                        let acceptor = acceptor.clone();
93                        let active_connections = active_connections.clone();
94                        
95                        // Increment active connections counter
96                        {
97                            let mut count = active_connections.lock().await;
98                            *count += 1;
99                        }
100                        
101                        tokio::spawn(async move {
102                            match acceptor.accept(stream).await {
103                                Ok(tls_stream) => {
104                                    if let Err(e) = handle_tls_connection(tls_stream, dispatcher, peer, active_connections).await {
105                                        error!(%peer, error=%e, "Connection error");
106                                    }
107                                },
108                                Err(e) => {
109                                    error!(%peer, error=%e, "TLS handshake failed");
110                                    // Decrement connections on handshake failure
111                                    let mut count = active_connections.lock().await;
112                                    *count -= 1;
113                                }
114                            }
115                        });
116                    }
117                    Err(e) => {
118                        error!(error=%e, "Error accepting connection");
119                    }
120                }
121            }
122        }
123    }
124}
125
126/// Handle a TLS connection
127#[instrument(skip(tls_stream, dispatcher, active_connections), fields(peer=%peer))]
128async fn handle_tls_connection(
129    tls_stream: TlsStream<TcpStream>,
130    dispatcher: Arc<Dispatcher>,
131    peer: std::net::SocketAddr,
132    active_connections: Arc<Mutex<u32>>
133) -> Result<()> {
134    let mut framed = Framed::new(tls_stream, PacketCodec);
135    
136    info!("TLS connection established");
137    
138    // Unlike regular daemon, we don't need a separate handshake
139    // TLS already provides the encryption layer
140    
141    // Message loop
142    loop {
143        let packet = match framed.next().await {
144            Some(Ok(pkt)) => pkt,
145            Some(Err(e)) => {
146                error!(error=%e, "Protocol error");
147                break;
148            },
149            None => break,
150        };
151        
152        // Deserialize the message
153        let msg = match bincode::deserialize::<Message>(&packet.payload) {
154            Ok(m) => m,
155            Err(e) => {
156                error!(error=%e, "Deserialization error");
157                continue;
158            }
159        };
160        
161        debug!(message=?msg, "Received message");
162        
163        // Process with dispatcher
164        match dispatcher.dispatch(&msg) {
165            Ok(reply) => {
166                let reply_bytes = match bincode::serialize(&reply) {
167                    Ok(bytes) => bytes,
168                    Err(e) => {
169                        error!(error=%e, "Serialization error");
170                        continue;
171                    }
172                };
173                
174                let reply_packet = Packet {
175                    version: packet.version,
176                    payload: reply_bytes,
177                };
178                
179                if let Err(e) = framed.send(reply_packet).await {
180                    error!(error=%e, "Send error");
181                    break;
182                }
183            },
184            Err(e) => {
185                error!(error=%e, "Dispatch error");
186                break;
187            }
188        }
189    }
190    
191    info!("Connection closed");
192    
193    // Decrement connection counter on disconnect
194    {
195        let mut count = active_connections.lock().await;
196        *count -= 1;
197    }
198    
199    Ok(())
200}