network_protocol/service/tls_daemon.rs

1use futures::{SinkExt, StreamExt};
2use std::sync::Arc;
3use std::time::Duration;
4use tokio::net::{TcpListener, TcpStream};
5use tokio::sync::{mpsc, Mutex};
6use tokio_rustls::server::TlsStream;
7use tokio_util::codec::Framed;
8use tracing::{debug, error, info, instrument, warn};
9
10use crate::core::codec::PacketCodec;
11use crate::core::packet::Packet;
12use crate::protocol::dispatcher::Dispatcher;
13use crate::protocol::message::Message;
14// Secure connection not needed since TLS handles encryption
15use crate::error::Result;
16use crate::transport::tls::TlsServerConfig;
17
18/// Start a secure TLS server and listen for connections
19#[instrument(skip(tls_config))]
20pub async fn start(addr: &str, tls_config: TlsServerConfig) -> Result<()> {
21    // Create shutdown channel
22    let (_, shutdown_rx) = mpsc::channel::<()>(1);
23
24    // Start with internal shutdown channel
25    start_with_shutdown(addr, tls_config, shutdown_rx).await
26}
27
28/// Start a secure TLS server with an external shutdown channel
29#[instrument(skip(tls_config, shutdown_rx))]
30pub async fn start_with_shutdown(
31    addr: &str,
32    tls_config: TlsServerConfig,
33    mut shutdown_rx: mpsc::Receiver<()>,
34) -> Result<()> {
35    let config = tls_config.load_server_config()?;
36    let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
37
38    let listener = TcpListener::bind(addr).await?;
39    info!(address=%addr, "TLS daemon listening");
40
41    // 🔁 Shared dispatcher
42    let dispatcher = Arc::new({
43        let d = Dispatcher::new();
44        let _ = d.register("PING", |_| Ok(Message::Pong));
45        let _ = d.register("ECHO", |msg| Ok(msg.clone()));
46        d
47    });
48
49    // Track active connections
50    let active_connections = Arc::new(Mutex::new(0u32));
51
52    // Spawn ctrl-c handler to forward to the provided shutdown channel
53    tokio::spawn(async move {
54        if let Ok(()) = tokio::signal::ctrl_c().await {
55            info!("Received shutdown signal, initiating graceful shutdown");
56        }
57    });
58
59    // Server main loop with graceful shutdown
60    loop {
61        tokio::select! {
62            // Check for shutdown signal
63            _ = shutdown_rx.recv() => {
64                info!("Shutting down server. Waiting for connections to close...");
65
66                // Wait for active connections to close (with timeout)
67                let timeout = tokio::time::sleep(Duration::from_secs(10));
68                tokio::pin!(timeout);
69
70                loop {
71                    tokio::select! {
72                        _ = &mut timeout => {
73                            warn!("Shutdown timeout reached, forcing exit");
74                            break;
75                        }
76                        _ = tokio::time::sleep(Duration::from_millis(500)) => {
77                            let connections = *active_connections.lock().await;
78                            debug!(connections, "Waiting for connections to close");
79                            if connections == 0 {
80                                info!("All connections closed, shutting down");
81                                break;
82                            }
83                        }
84                    }
85                }
86
87                return Ok(());
88            }
89
90            // Accept new connections
91            accept_result = listener.accept() => {
92                match accept_result {
93                    Ok((stream, peer)) => {
94                        info!(%peer, "New connection accepted");
95                        let dispatcher = dispatcher.clone();
96                        let acceptor = acceptor.clone();
97                        let active_connections = active_connections.clone();
98
99                        // Increment active connections counter
100                        {
101                            let mut count = active_connections.lock().await;
102                            *count += 1;
103                        }
104
105                        tokio::spawn(async move {
106                            match acceptor.accept(stream).await {
107                                Ok(tls_stream) => {
108                                    if let Err(e) = handle_tls_connection(tls_stream, dispatcher, peer, active_connections).await {
109                                        error!(%peer, error=%e, "Connection error");
110                                    }
111                                },
112                                Err(e) => {
113                                    error!(%peer, error=%e, "TLS handshake failed");
114                                    // Decrement connections on handshake failure
115                                    let mut count = active_connections.lock().await;
116                                    *count -= 1;
117                                }
118                            }
119                        });
120                    }
121                    Err(e) => {
122                        error!(error=%e, "Error accepting connection");
123                    }
124                }
125            }
126        }
127    }
128}
129
130/// Handle a TLS connection
131#[instrument(skip(tls_stream, dispatcher, active_connections), fields(peer=%peer))]
132async fn handle_tls_connection(
133    tls_stream: TlsStream<TcpStream>,
134    dispatcher: Arc<Dispatcher>,
135    peer: std::net::SocketAddr,
136    active_connections: Arc<Mutex<u32>>,
137) -> Result<()> {
138    let mut framed = Framed::new(tls_stream, PacketCodec);
139
140    info!("TLS connection established");
141
142    // Unlike regular daemon, we don't need a separate handshake
143    // TLS already provides the encryption layer
144
145    // Message loop
146    loop {
147        let packet = match framed.next().await {
148            Some(Ok(pkt)) => pkt,
149            Some(Err(e)) => {
150                error!(error=%e, "Protocol error");
151                break;
152            }
153            None => break,
154        };
155
156        // Deserialize the message
157        let msg = match bincode::deserialize::<Message>(&packet.payload) {
158            Ok(m) => m,
159            Err(e) => {
160                error!(error=%e, "Deserialization error");
161                continue;
162            }
163        };
164
165        debug!(message=?msg, "Received message");
166
167        // Process with dispatcher
168        match dispatcher.dispatch(&msg) {
169            Ok(reply) => {
170                let reply_bytes = match bincode::serialize(&reply) {
171                    Ok(bytes) => bytes,
172                    Err(e) => {
173                        error!(error=%e, "Serialization error");
174                        continue;
175                    }
176                };
177
178                let reply_packet = Packet {
179                    version: packet.version,
180                    payload: reply_bytes,
181                };
182
183                if let Err(e) = framed.send(reply_packet).await {
184                    error!(error=%e, "Send error");
185                    break;
186                }
187            }
188            Err(e) => {
189                error!(error=%e, "Dispatch error");
190                break;
191            }
192        }
193    }
194
195    info!("Connection closed");
196
197    // Decrement connection counter on disconnect
198    {
199        let mut count = active_connections.lock().await;
200        *count -= 1;
201    }
202
203    Ok(())
204}