network_protocol/service/
tls_daemon.rs1use futures::{SinkExt, StreamExt};
2use std::sync::Arc;
3use std::time::Duration;
4use tokio::net::{TcpListener, TcpStream};
5use tokio::sync::{mpsc, Mutex};
6use tokio_rustls::server::TlsStream;
7use tokio_util::codec::Framed;
8use tracing::{debug, error, info, instrument, warn};
9
10use crate::core::codec::PacketCodec;
11use crate::core::packet::Packet;
12use crate::protocol::dispatcher::Dispatcher;
13use crate::protocol::message::Message;
14use crate::error::Result;
16use crate::transport::tls::TlsServerConfig;
17
18#[instrument(skip(tls_config))]
20pub async fn start(addr: &str, tls_config: TlsServerConfig) -> Result<()> {
21 let (_, shutdown_rx) = mpsc::channel::<()>(1);
23
24 start_with_shutdown(addr, tls_config, shutdown_rx).await
26}
27
28#[instrument(skip(tls_config, shutdown_rx))]
30pub async fn start_with_shutdown(
31 addr: &str,
32 tls_config: TlsServerConfig,
33 mut shutdown_rx: mpsc::Receiver<()>,
34) -> Result<()> {
35 let config = tls_config.load_server_config()?;
36 let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
37
38 let listener = TcpListener::bind(addr).await?;
39 info!(address=%addr, "TLS daemon listening");
40
41 let dispatcher = Arc::new({
43 let d = Dispatcher::new();
44 let _ = d.register("PING", |_| Ok(Message::Pong));
45 let _ = d.register("ECHO", |msg| Ok(msg.clone()));
46 d
47 });
48
49 let active_connections = Arc::new(Mutex::new(0u32));
51
52 tokio::spawn(async move {
54 if let Ok(()) = tokio::signal::ctrl_c().await {
55 info!("Received shutdown signal, initiating graceful shutdown");
56 }
57 });
58
59 loop {
61 tokio::select! {
62 _ = shutdown_rx.recv() => {
64 info!("Shutting down server. Waiting for connections to close...");
65
66 let timeout = tokio::time::sleep(Duration::from_secs(10));
68 tokio::pin!(timeout);
69
70 loop {
71 tokio::select! {
72 _ = &mut timeout => {
73 warn!("Shutdown timeout reached, forcing exit");
74 break;
75 }
76 _ = tokio::time::sleep(Duration::from_millis(500)) => {
77 let connections = *active_connections.lock().await;
78 debug!(connections, "Waiting for connections to close");
79 if connections == 0 {
80 info!("All connections closed, shutting down");
81 break;
82 }
83 }
84 }
85 }
86
87 return Ok(());
88 }
89
90 accept_result = listener.accept() => {
92 match accept_result {
93 Ok((stream, peer)) => {
94 info!(%peer, "New connection accepted");
95 let dispatcher = dispatcher.clone();
96 let acceptor = acceptor.clone();
97 let active_connections = active_connections.clone();
98
99 {
101 let mut count = active_connections.lock().await;
102 *count += 1;
103 }
104
105 tokio::spawn(async move {
106 match acceptor.accept(stream).await {
107 Ok(tls_stream) => {
108 if let Err(e) = handle_tls_connection(tls_stream, dispatcher, peer, active_connections).await {
109 error!(%peer, error=%e, "Connection error");
110 }
111 },
112 Err(e) => {
113 error!(%peer, error=%e, "TLS handshake failed");
114 let mut count = active_connections.lock().await;
116 *count -= 1;
117 }
118 }
119 });
120 }
121 Err(e) => {
122 error!(error=%e, "Error accepting connection");
123 }
124 }
125 }
126 }
127 }
128}
129
130#[instrument(skip(tls_stream, dispatcher, active_connections), fields(peer=%peer))]
132async fn handle_tls_connection(
133 tls_stream: TlsStream<TcpStream>,
134 dispatcher: Arc<Dispatcher>,
135 peer: std::net::SocketAddr,
136 active_connections: Arc<Mutex<u32>>,
137) -> Result<()> {
138 let mut framed = Framed::new(tls_stream, PacketCodec);
139
140 info!("TLS connection established");
141
142 loop {
147 let packet = match framed.next().await {
148 Some(Ok(pkt)) => pkt,
149 Some(Err(e)) => {
150 error!(error=%e, "Protocol error");
151 break;
152 }
153 None => break,
154 };
155
156 let msg = match bincode::deserialize::<Message>(&packet.payload) {
158 Ok(m) => m,
159 Err(e) => {
160 error!(error=%e, "Deserialization error");
161 continue;
162 }
163 };
164
165 debug!(message=?msg, "Received message");
166
167 match dispatcher.dispatch(&msg) {
169 Ok(reply) => {
170 let reply_bytes = match bincode::serialize(&reply) {
171 Ok(bytes) => bytes,
172 Err(e) => {
173 error!(error=%e, "Serialization error");
174 continue;
175 }
176 };
177
178 let reply_packet = Packet {
179 version: packet.version,
180 payload: reply_bytes,
181 };
182
183 if let Err(e) = framed.send(reply_packet).await {
184 error!(error=%e, "Send error");
185 break;
186 }
187 }
188 Err(e) => {
189 error!(error=%e, "Dispatch error");
190 break;
191 }
192 }
193 }
194
195 info!("Connection closed");
196
197 {
199 let mut count = active_connections.lock().await;
200 *count -= 1;
201 }
202
203 Ok(())
204}