network_protocol/service/
tls_daemon.rs1use tokio_rustls::server::TlsStream;
2use tokio::net::{TcpListener, TcpStream};
3use tokio_util::codec::Framed;
4use futures::{StreamExt, SinkExt};
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::{mpsc, Mutex};
8use tracing::{debug, info, warn, error, instrument};
9
10use crate::core::codec::PacketCodec;
11use crate::core::packet::Packet;
12use crate::protocol::message::Message;
13use crate::protocol::dispatcher::Dispatcher;
14use crate::transport::tls::TlsServerConfig;
16use crate::error::Result;
17
18#[instrument(skip(tls_config))]
20pub async fn start(addr: &str, tls_config: TlsServerConfig) -> Result<()> {
21 let (_, shutdown_rx) = mpsc::channel::<()>(1);
23
24 start_with_shutdown(addr, tls_config, shutdown_rx).await
26}
27
28#[instrument(skip(tls_config, shutdown_rx))]
30pub async fn start_with_shutdown(addr: &str, tls_config: TlsServerConfig, mut shutdown_rx: mpsc::Receiver<()>) -> Result<()> {
31 let config = tls_config.load_server_config()?;
32 let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
33
34 let listener = TcpListener::bind(addr).await?;
35 info!(address=%addr, "TLS daemon listening");
36
37 let dispatcher = Arc::new({
39 let d = Dispatcher::new();
40 let _ = d.register("PING", |_| Ok(Message::Pong));
41 let _ = d.register("ECHO", |msg| Ok(msg.clone()));
42 d
43 });
44
45 let active_connections = Arc::new(Mutex::new(0u32));
47
48 tokio::spawn(async move {
50 if let Ok(()) = tokio::signal::ctrl_c().await {
51 info!("Received shutdown signal, initiating graceful shutdown");
52 }
53 });
54
55 loop {
57 tokio::select! {
58 _ = shutdown_rx.recv() => {
60 info!("Shutting down server. Waiting for connections to close...");
61
62 let timeout = tokio::time::sleep(Duration::from_secs(10));
64 tokio::pin!(timeout);
65
66 loop {
67 tokio::select! {
68 _ = &mut timeout => {
69 warn!("Shutdown timeout reached, forcing exit");
70 break;
71 }
72 _ = tokio::time::sleep(Duration::from_millis(500)) => {
73 let connections = *active_connections.lock().await;
74 debug!(connections, "Waiting for connections to close");
75 if connections == 0 {
76 info!("All connections closed, shutting down");
77 break;
78 }
79 }
80 }
81 }
82
83 return Ok(());
84 }
85
86 accept_result = listener.accept() => {
88 match accept_result {
89 Ok((stream, peer)) => {
90 info!(%peer, "New connection accepted");
91 let dispatcher = dispatcher.clone();
92 let acceptor = acceptor.clone();
93 let active_connections = active_connections.clone();
94
95 {
97 let mut count = active_connections.lock().await;
98 *count += 1;
99 }
100
101 tokio::spawn(async move {
102 match acceptor.accept(stream).await {
103 Ok(tls_stream) => {
104 if let Err(e) = handle_tls_connection(tls_stream, dispatcher, peer, active_connections).await {
105 error!(%peer, error=%e, "Connection error");
106 }
107 },
108 Err(e) => {
109 error!(%peer, error=%e, "TLS handshake failed");
110 let mut count = active_connections.lock().await;
112 *count -= 1;
113 }
114 }
115 });
116 }
117 Err(e) => {
118 error!(error=%e, "Error accepting connection");
119 }
120 }
121 }
122 }
123 }
124}
125
126#[instrument(skip(tls_stream, dispatcher, active_connections), fields(peer=%peer))]
128async fn handle_tls_connection(
129 tls_stream: TlsStream<TcpStream>,
130 dispatcher: Arc<Dispatcher>,
131 peer: std::net::SocketAddr,
132 active_connections: Arc<Mutex<u32>>
133) -> Result<()> {
134 let mut framed = Framed::new(tls_stream, PacketCodec);
135
136 info!("TLS connection established");
137
138 loop {
143 let packet = match framed.next().await {
144 Some(Ok(pkt)) => pkt,
145 Some(Err(e)) => {
146 error!(error=%e, "Protocol error");
147 break;
148 },
149 None => break,
150 };
151
152 let msg = match bincode::deserialize::<Message>(&packet.payload) {
154 Ok(m) => m,
155 Err(e) => {
156 error!(error=%e, "Deserialization error");
157 continue;
158 }
159 };
160
161 debug!(message=?msg, "Received message");
162
163 match dispatcher.dispatch(&msg) {
165 Ok(reply) => {
166 let reply_bytes = match bincode::serialize(&reply) {
167 Ok(bytes) => bytes,
168 Err(e) => {
169 error!(error=%e, "Serialization error");
170 continue;
171 }
172 };
173
174 let reply_packet = Packet {
175 version: packet.version,
176 payload: reply_bytes,
177 };
178
179 if let Err(e) = framed.send(reply_packet).await {
180 error!(error=%e, "Send error");
181 break;
182 }
183 },
184 Err(e) => {
185 error!(error=%e, "Dispatch error");
186 break;
187 }
188 }
189 }
190
191 info!("Connection closed");
192
193 {
195 let mut count = active_connections.lock().await;
196 *count -= 1;
197 }
198
199 Ok(())
200}