network_protocol/service/
daemon.rs1use bincode;
2use futures::{SinkExt, StreamExt};
3use std::net::SocketAddr;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::net::{TcpListener, TcpStream};
7use tokio::sync::{mpsc, oneshot, Mutex};
8use tokio::time;
9use tokio_util::codec::Framed;
10use tracing::{debug, error, info, instrument, warn};
11
12use crate::config::ServerConfig;
13
14use crate::utils::timeout::with_timeout_error;
15
16use crate::core::codec::PacketCodec;
17use crate::core::packet::Packet;
18use crate::protocol::message::Message;
19use crate::error::{ProtocolError, Result};
21use crate::protocol::dispatcher::Dispatcher;
22use crate::protocol::handshake::{
23 server_secure_handshake_finalize, server_secure_handshake_response,
24};
25use crate::protocol::heartbeat::{build_ping, is_pong};
26use crate::protocol::keepalive::KeepAliveManager;
27use crate::service::secure::SecureConnection;
28use crate::utils::replay_cache::ReplayCache;
29
30#[instrument(skip(addr), fields(address = %addr))]
32pub async fn start(addr: &str) -> Result<()> {
33 let (_, shutdown_rx) = oneshot::channel::<()>();
35 start_with_shutdown(addr, shutdown_rx).await
36}
37
38#[instrument(skip(config), fields(address = %config.address))]
40pub async fn start_with_config(config: ServerConfig) -> Result<()> {
41 let (_, shutdown_rx) = oneshot::channel::<()>();
43 start_with_config_and_shutdown(config, shutdown_rx).await
44}
45
46#[instrument(skip(addr, shutdown_rx), fields(address = %addr))]
48pub async fn start_with_shutdown(addr: &str, shutdown_rx: oneshot::Receiver<()>) -> Result<()> {
49 let config = ServerConfig {
51 address: addr.to_string(),
52 ..Default::default()
53 };
54 start_with_config_and_shutdown(config, shutdown_rx).await
55}
56
57#[instrument(skip(config, shutdown_rx), fields(address = %config.address))]
59pub async fn start_with_config_and_shutdown(
60 config: ServerConfig,
61 shutdown_rx: oneshot::Receiver<()>,
62) -> Result<()> {
63 let listener = TcpListener::bind(&config.address).await?;
64 info!(address = %config.address, "Server listening");
65
66 let dispatcher = Arc::new(Dispatcher::new());
68
69 register_default_handlers(&dispatcher)?;
71
72 let active_connections = Arc::new(Mutex::new(0u32));
74
75 let (internal_shutdown_tx, mut internal_shutdown_rx) = mpsc::channel::<()>(1);
77
78 let shutdown_timeout = config.shutdown_timeout;
80 let heartbeat_interval = config.heartbeat_interval;
81
82 let shutdown_tx_clone = internal_shutdown_tx.clone();
84 tokio::spawn(async move {
85 match tokio::signal::ctrl_c().await {
86 Ok(()) => {
87 info!("Shutdown signal received");
88 let _ = shutdown_tx_clone.send(()).await;
89 }
90 Err(err) => {
91 error!(error = %err, "Failed to listen for shutdown signal");
92 }
93 }
94 });
95
96 let internal_shutdown_tx_clone = internal_shutdown_tx.clone();
98 tokio::spawn(async move {
99 if shutdown_rx.await.is_ok() {
100 info!("External shutdown signal received");
101 let _ = internal_shutdown_tx_clone.send(()).await;
102 }
103 });
104
105 loop {
107 tokio::select! {
108 _ = internal_shutdown_rx.recv() => {
110 info!("Shutting down server. Waiting for connections to close...");
111
112 let timeout = tokio::time::sleep(shutdown_timeout);
114 tokio::pin!(timeout);
115
116 loop {
117 tokio::select! {
118 _ = &mut timeout => {
119 warn!("Shutdown timeout reached, forcing exit");
120 break;
121 }
122 _ = tokio::time::sleep(Duration::from_millis(500)) => {
123 let connections = *active_connections.lock().await;
124 info!(connections = %connections, "Waiting for connections to close");
125 if connections == 0 {
126 info!("All connections closed, shutting down");
127 break;
128 }
129 }
130 }
131 }
132
133 return Ok(());
134 }
135
136 accept_result = listener.accept() => {
138 match accept_result {
139 Ok((stream, peer)) => {
140 info!(peer = %peer, "New connection established");
141 let dispatcher = dispatcher.clone();
142 let active_connections = active_connections.clone();
143 {
148 let mut count = active_connections.lock().await;
149 *count += 1;
150 }
151
152 let active_connections_clone = active_connections.clone();
154 let config_clone = config.clone();
155
156 tokio::spawn(async move {
157 handle_connection(stream, peer, dispatcher, active_connections_clone, config_clone, heartbeat_interval).await;
158 });
159 }
160 Err(e) => {
161 error!(error = %e, "Error accepting connection");
162 }
163 }
164 }
165 }
166 }
167}
168
169#[instrument(skip(stream, dispatcher, active_connections, config, heartbeat_interval), fields(peer = %peer))]
171async fn handle_connection(
172 stream: tokio::net::TcpStream,
173 peer: std::net::SocketAddr,
174 dispatcher: Arc<Dispatcher>,
175 active_connections: Arc<Mutex<u32>>,
176 config: ServerConfig,
177 heartbeat_interval: Duration,
178) {
179 let result = with_timeout_error(
181 async {
182 process_connection(stream, dispatcher, peer, config.clone(), heartbeat_interval).await
183 },
184 config.connection_timeout,
185 )
186 .await;
187
188 match result {
190 Ok(_) => info!("Connection closed gracefully"),
191 Err(ProtocolError::Timeout) => warn!("Connection timed out"),
192 Err(e) => error!(error = %e, "Connection error"),
193 }
194
195 {
197 let mut count = active_connections.lock().await;
198 *count -= 1;
199 }
200
201 info!("Client disconnected");
202}
203
204#[instrument(skip(stream, dispatcher, peer, config, heartbeat_interval), fields(peer = %peer))]
206async fn process_connection(
207 stream: TcpStream,
208 dispatcher: Arc<Dispatcher>,
209 peer: SocketAddr,
210 config: ServerConfig,
211 heartbeat_interval: Duration,
212) -> Result<()> {
213 let mut framed = Framed::new(stream, PacketCodec);
215
216 let init = with_timeout_error(
218 async {
219 match framed.next().await {
220 Some(Ok(pkt)) => bincode::deserialize::<Message>(&pkt.payload)
221 .map_err(|e| ProtocolError::DeserializeError(e.to_string())),
222 Some(Err(e)) => Err(ProtocolError::TransportError(e.to_string())),
223 None => Err(ProtocolError::ConnectionClosed),
224 }
225 },
226 config.connection_timeout,
227 )
228 .await?;
229
230 let (client_pub_key, client_timestamp, client_nonce) = match init {
232 Message::SecureHandshakeInit {
233 pub_key,
234 timestamp,
235 nonce,
236 } => (pub_key, timestamp, nonce),
237 _ => {
238 return Err(ProtocolError::HandshakeError(
239 "Unexpected message type".to_string(),
240 ))
241 }
242 };
243
244 let mut replay_cache = ReplayCache::new();
246 let (server_state, response) = server_secure_handshake_response(
247 client_pub_key,
248 client_nonce,
249 client_timestamp,
250 &peer.to_string(),
251 &mut replay_cache,
252 )?;
253
254 let response_bytes =
255 bincode::serialize(&response).map_err(|e| ProtocolError::SerializeError(e.to_string()))?;
256
257 framed
258 .send(Packet {
259 version: 1,
260 payload: response_bytes,
261 })
262 .await
263 .map_err(|e| ProtocolError::TransportError(e.to_string()))?;
264
265 let confirm = with_timeout_error(
267 async {
268 match framed.next().await {
269 Some(Ok(pkt)) => bincode::deserialize::<Message>(&pkt.payload)
270 .map_err(|e| ProtocolError::DeserializeError(e.to_string())),
271 Some(Err(e)) => Err(ProtocolError::TransportError(e.to_string())),
272 None => Err(ProtocolError::ConnectionClosed),
273 }
274 },
275 config.connection_timeout,
276 )
277 .await?;
278
279 let nonce_verification = match confirm {
280 Message::SecureHandshakeConfirm { nonce_verification } => nonce_verification,
281 _ => {
282 return Err(ProtocolError::HandshakeError(
283 "Expected handshake confirmation".to_string(),
284 ))
285 }
286 };
287
288 let session_key = server_secure_handshake_finalize(server_state, nonce_verification)?;
290
291 let conn = SecureConnection::new(framed, session_key);
295
296 handle_secure_connection(conn, dispatcher, peer, heartbeat_interval).await?;
298
299 Ok(())
300}
301
302#[instrument(skip(dispatcher))]
304fn register_default_handlers(dispatcher: &Arc<Dispatcher>) -> Result<()> {
305 dispatcher.register("PING", |_| {
307 debug!("Responding to ping with pong");
308 Ok(Message::Pong)
309 })?;
310
311 dispatcher.register("ECHO", |msg| {
313 if let Message::Echo(text) = msg {
314 debug!(text = %text, "Echoing message");
315 Ok(Message::Echo(text.clone()))
316 } else {
317 Err(ProtocolError::Custom(
318 "Invalid Echo message format".to_string(),
319 ))
320 }
321 })?;
322
323 Ok(())
324}
325
/// Item sent from the connection loop to the message-processing task.
#[derive(Debug)]
enum ProcessingMessage {
    /// A received message that should be dispatched.
    Message(Message),
    /// Tells the processing task to stop its loop.
    Terminate,
}
334
/// Outcome of dispatching one message, sent back to the connection loop.
#[derive(Debug)]
struct ProcessingResult {
    /// Sequence number assigned by the processing task (starts at 0); used
    /// only for logging.
    original_id: usize,
    /// Reply to send to the peer, or `None` when dispatch failed.
    response: Option<Message>,
}
343
344#[instrument(skip(conn, dispatcher, heartbeat_interval), fields(peer = %peer))]
346async fn handle_secure_connection(
347 mut conn: SecureConnection,
348 dispatcher: Arc<Dispatcher>,
349 peer: std::net::SocketAddr,
350 heartbeat_interval: Duration,
351) -> Result<()> {
352 let dead_timeout = heartbeat_interval.mul_f32(4.0); let mut keep_alive = KeepAliveManager::with_settings(heartbeat_interval, dead_timeout);
355 let mut ping_interval = time::interval(keep_alive.ping_interval());
356
357 let (msg_tx, msg_rx) = mpsc::channel::<ProcessingMessage>(32);
360 let (resp_tx, mut resp_rx) = mpsc::channel::<ProcessingResult>(32);
361
362 let dispatcher_clone = dispatcher.clone();
364 let processor_handle =
365 tokio::spawn(async move { process_messages(msg_rx, resp_tx, dispatcher_clone).await });
366
367 let mut final_result = Ok(());
369 let mut next_msg_id: usize = 0;
370
371 'main: loop {
373 tokio::select! {
374 _ = ping_interval.tick() => {
376 if keep_alive.should_ping() {
377 debug!("Sending keep-alive ping");
378 let ping = build_ping();
379 if let Err(e) = conn.secure_send(ping).await {
380 warn!(error = %e, "Failed to send ping");
381 final_result = Err(e);
382 break 'main;
383 }
384 keep_alive.update_send();
385 }
386
387 if keep_alive.is_connection_dead() {
389 warn!(dead_seconds = ?keep_alive.time_since_last_recv().as_secs(),
390 "Connection appears dead, closing");
391 final_result = Err(ProtocolError::ConnectionTimeout);
392 break 'main;
393 }
394 }
395
396 Some(result) = resp_rx.recv() => {
398 if let Some(response) = result.response {
399 debug!("Sending response for message {}", result.original_id);
400 if let Err(e) = conn.secure_send(response).await {
401 warn!(error = %e, "Failed to send response");
402 final_result = Err(e);
403 break 'main;
404 }
405 keep_alive.update_send();
406 }
407 }
408
409 recv_result = conn.secure_recv::<Message>() => {
411 match recv_result {
412 Ok(msg) => {
413 debug!(message = ?msg, "Received message");
414 keep_alive.update_recv();
415
416 if matches!(msg, Message::Disconnect) {
418 info!("Received disconnect request");
419 break 'main;
420 }
421
422 if is_pong(&msg) {
424 debug!("Received pong response");
425 continue;
426 }
427
428 next_msg_id = next_msg_id.wrapping_add(1);
430
431 if msg_tx.capacity() == 0 {
433 debug!("Channel full - applying backpressure");
434
435 match msg_tx.reserve().await {
437 Ok(permit) => {
438 permit.send(ProcessingMessage::Message(msg));
440 },
441 Err(_) => {
442 warn!("Processing channel closed unexpectedly");
444 break 'main;
445 }
446 }
447 } else {
448 if (msg_tx.send(ProcessingMessage::Message(msg)).await).is_err() {
450 warn!("Failed to send message to processing channel");
452 break 'main;
453 }
454 }
455 }
456 Err(ProtocolError::Timeout) => {
457 continue;
459 }
460 Err(e) => {
461 final_result = Err(e);
462 break 'main;
463 }
464 }
465 }
466 }
467 }
468
469 debug!("Signaling processor to terminate");
471 let _ = msg_tx.send(ProcessingMessage::Terminate).await;
472
473 debug!("Waiting for processor to terminate");
475 let _ = processor_handle.await;
476
477 final_result
478}
479
480#[instrument(skip(rx, resp_tx, dispatcher), level = "debug")]
482async fn process_messages(
483 mut rx: mpsc::Receiver<ProcessingMessage>,
484 resp_tx: mpsc::Sender<ProcessingResult>,
485 dispatcher: Arc<Dispatcher>,
486) {
487 let mut msg_counter: usize = 0;
488
489 while let Some(proc_msg) = rx.recv().await {
490 match proc_msg {
491 ProcessingMessage::Message(msg) => {
492 let msg_id = msg_counter;
493 msg_counter += 1;
494
495 debug!(msg_id = msg_id, message = ?msg, "Processing message from channel");
496
497 let response = match dispatcher.dispatch(&msg) {
498 Ok(reply) => {
499 Some(reply)
501 }
502 Err(e) => {
503 warn!(error = %e, "Error dispatching message");
505 None
506 }
507 };
508
509 let result = ProcessingResult {
511 original_id: msg_id,
512 response,
513 };
514
515 if (resp_tx.send(result).await).is_err() {
516 warn!("Failed to send processing result - reader likely disconnected");
517 break;
518 }
519 }
520 ProcessingMessage::Terminate => {
521 debug!("Processor received terminate signal");
522 break;
523 }
524 }
525 }
526
527 debug!("Message processor terminated");
528}
529
/// Handle to a server, holding its address and a one-shot shutdown trigger.
#[derive(Debug)]
pub struct Daemon {
    /// Address string this handle was created with.
    pub address: String,
    /// Shutdown sender; taken (set to `None`) by the first `shutdown` call.
    shutdown_tx: Option<oneshot::Sender<()>>,
}
538
539impl Daemon {
540 pub fn new(address: String, shutdown_tx: oneshot::Sender<()>) -> Self {
542 Self {
543 address,
544 shutdown_tx: Some(shutdown_tx),
545 }
546 }
547
548 pub async fn run(self) -> Result<()> {
550 Ok(())
553 }
554
555 pub async fn shutdown(&mut self) -> Result<()> {
557 if let Some(tx) = self.shutdown_tx.take() {
558 let _ = tx.send(());
559 Ok(())
560 } else {
561 Err(ProtocolError::Custom("Shutdown already called".to_string()))
562 }
563 }
564
565 pub async fn shutdown_with_timeout(&mut self, _timeout: Duration) -> Result<()> {
567 self.shutdown().await
569 }
570}
571
572#[instrument(skip(config, _dispatcher), fields(address = %config.address))]
574pub async fn start_daemon_no_signals(
575 config: ServerConfig,
576 _dispatcher: Arc<Dispatcher>,
577) -> Result<Daemon> {
578 let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
580
581 let address = config.address.clone();
582
583 tokio::spawn(async move {
585 if let Err(e) = start_with_config_and_shutdown(config, shutdown_rx).await {
586 error!(error = ?e, "Server error");
587 }
588 });
589
590 Ok(Daemon::new(address, shutdown_tx))
592}
593
594pub fn new_with_config(config: ServerConfig, _dispatcher: Arc<Dispatcher>) -> Daemon {
596 let (shutdown_tx, _) = oneshot::channel::<()>();
597 Daemon::new(config.address.clone(), shutdown_tx)
598}