network_protocol/service/
daemon.rs1use bincode;
2use futures::{SinkExt, StreamExt};
3use std::net::SocketAddr;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::net::{TcpListener, TcpStream};
7use tokio::sync::{mpsc, oneshot, Mutex};
8use tokio::time;
9use tokio_util::codec::Framed;
10use tracing::{debug, error, info, instrument, warn};
11
12use crate::config::ServerConfig;
13
14use crate::utils::timeout::with_timeout_error;
15
16use crate::core::codec::PacketCodec;
17use crate::core::packet::Packet;
18use crate::protocol::message::Message;
19use crate::error::{ProtocolError, Result};
21use crate::protocol::dispatcher::Dispatcher;
22use crate::protocol::handshake::{
23 server_secure_handshake_finalize, server_secure_handshake_response,
24};
25use crate::protocol::heartbeat::{build_ping, is_pong};
26use crate::protocol::keepalive::KeepAliveManager;
27use crate::service::secure::SecureConnection;
28
29#[instrument(skip(addr), fields(address = %addr))]
31pub async fn start(addr: &str) -> Result<()> {
32 let (_, shutdown_rx) = oneshot::channel::<()>();
34 start_with_shutdown(addr, shutdown_rx).await
35}
36
37#[instrument(skip(config), fields(address = %config.address))]
39pub async fn start_with_config(config: ServerConfig) -> Result<()> {
40 let (_, shutdown_rx) = oneshot::channel::<()>();
42 start_with_config_and_shutdown(config, shutdown_rx).await
43}
44
45#[instrument(skip(addr, shutdown_rx), fields(address = %addr))]
47pub async fn start_with_shutdown(addr: &str, shutdown_rx: oneshot::Receiver<()>) -> Result<()> {
48 let config = ServerConfig {
50 address: addr.to_string(),
51 ..Default::default()
52 };
53 start_with_config_and_shutdown(config, shutdown_rx).await
54}
55
56#[instrument(skip(config, shutdown_rx), fields(address = %config.address))]
58pub async fn start_with_config_and_shutdown(
59 config: ServerConfig,
60 shutdown_rx: oneshot::Receiver<()>,
61) -> Result<()> {
62 let listener = TcpListener::bind(&config.address).await?;
63 info!(address = %config.address, "Server listening");
64
65 let dispatcher = Arc::new(Dispatcher::new());
67
68 register_default_handlers(&dispatcher)?;
70
71 let active_connections = Arc::new(Mutex::new(0u32));
73
74 let (internal_shutdown_tx, mut internal_shutdown_rx) = mpsc::channel::<()>(1);
76
77 let shutdown_timeout = config.shutdown_timeout;
79 let heartbeat_interval = config.heartbeat_interval;
80
81 let shutdown_tx_clone = internal_shutdown_tx.clone();
83 tokio::spawn(async move {
84 match tokio::signal::ctrl_c().await {
85 Ok(()) => {
86 info!("Shutdown signal received");
87 let _ = shutdown_tx_clone.send(()).await;
88 }
89 Err(err) => {
90 error!(error = %err, "Failed to listen for shutdown signal");
91 }
92 }
93 });
94
95 let internal_shutdown_tx_clone = internal_shutdown_tx.clone();
97 tokio::spawn(async move {
98 if shutdown_rx.await.is_ok() {
99 info!("External shutdown signal received");
100 let _ = internal_shutdown_tx_clone.send(()).await;
101 }
102 });
103
104 loop {
106 tokio::select! {
107 _ = internal_shutdown_rx.recv() => {
109 info!("Shutting down server. Waiting for connections to close...");
110
111 let timeout = tokio::time::sleep(shutdown_timeout);
113 tokio::pin!(timeout);
114
115 loop {
116 tokio::select! {
117 _ = &mut timeout => {
118 warn!("Shutdown timeout reached, forcing exit");
119 break;
120 }
121 _ = tokio::time::sleep(Duration::from_millis(500)) => {
122 let connections = *active_connections.lock().await;
123 info!(connections = %connections, "Waiting for connections to close");
124 if connections == 0 {
125 info!("All connections closed, shutting down");
126 break;
127 }
128 }
129 }
130 }
131
132 return Ok(());
133 }
134
135 accept_result = listener.accept() => {
137 match accept_result {
138 Ok((stream, peer)) => {
139 info!(peer = %peer, "New connection established");
140 let dispatcher = dispatcher.clone();
141 let active_connections = active_connections.clone();
142 {
147 let mut count = active_connections.lock().await;
148 *count += 1;
149 }
150
151 let active_connections_clone = active_connections.clone();
153 let config_clone = config.clone();
154
155 tokio::spawn(async move {
156 handle_connection(stream, peer, dispatcher, active_connections_clone, config_clone, heartbeat_interval).await;
157 });
158 }
159 Err(e) => {
160 error!(error = %e, "Error accepting connection");
161 }
162 }
163 }
164 }
165 }
166}
167
168#[instrument(skip(stream, dispatcher, active_connections, config, heartbeat_interval), fields(peer = %peer))]
170async fn handle_connection(
171 stream: tokio::net::TcpStream,
172 peer: std::net::SocketAddr,
173 dispatcher: Arc<Dispatcher>,
174 active_connections: Arc<Mutex<u32>>,
175 config: ServerConfig,
176 heartbeat_interval: Duration,
177) {
178 let result = with_timeout_error(
180 async {
181 process_connection(stream, dispatcher, peer, config.clone(), heartbeat_interval).await
182 },
183 config.connection_timeout,
184 )
185 .await;
186
187 match result {
189 Ok(_) => info!("Connection closed gracefully"),
190 Err(ProtocolError::Timeout) => warn!("Connection timed out"),
191 Err(e) => error!(error = %e, "Connection error"),
192 }
193
194 {
196 let mut count = active_connections.lock().await;
197 *count -= 1;
198 }
199
200 info!("Client disconnected");
201}
202
203#[instrument(skip(stream, dispatcher, peer, config, heartbeat_interval), fields(peer = %peer))]
205async fn process_connection(
206 stream: TcpStream,
207 dispatcher: Arc<Dispatcher>,
208 peer: SocketAddr,
209 config: ServerConfig,
210 heartbeat_interval: Duration,
211) -> Result<()> {
212 let mut framed = Framed::new(stream, PacketCodec);
214
215 let init = with_timeout_error(
217 async {
218 match framed.next().await {
219 Some(Ok(pkt)) => bincode::deserialize::<Message>(&pkt.payload)
220 .map_err(|e| ProtocolError::DeserializeError(e.to_string())),
221 Some(Err(e)) => Err(ProtocolError::TransportError(e.to_string())),
222 None => Err(ProtocolError::ConnectionClosed),
223 }
224 },
225 config.connection_timeout,
226 )
227 .await?;
228
229 let (client_pub_key, client_timestamp, client_nonce) = match init {
231 Message::SecureHandshakeInit {
232 pub_key,
233 timestamp,
234 nonce,
235 } => (pub_key, timestamp, nonce),
236 _ => {
237 return Err(ProtocolError::HandshakeError(
238 "Unexpected message type".to_string(),
239 ))
240 }
241 };
242
243 let (server_state, response) =
245 server_secure_handshake_response(client_pub_key, client_nonce, client_timestamp)?;
246
247 let response_bytes =
248 bincode::serialize(&response).map_err(|e| ProtocolError::SerializeError(e.to_string()))?;
249
250 framed
251 .send(Packet {
252 version: 1,
253 payload: response_bytes,
254 })
255 .await
256 .map_err(|e| ProtocolError::TransportError(e.to_string()))?;
257
258 let confirm = with_timeout_error(
260 async {
261 match framed.next().await {
262 Some(Ok(pkt)) => bincode::deserialize::<Message>(&pkt.payload)
263 .map_err(|e| ProtocolError::DeserializeError(e.to_string())),
264 Some(Err(e)) => Err(ProtocolError::TransportError(e.to_string())),
265 None => Err(ProtocolError::ConnectionClosed),
266 }
267 },
268 config.connection_timeout,
269 )
270 .await?;
271
272 let nonce_verification = match confirm {
273 Message::SecureHandshakeConfirm { nonce_verification } => nonce_verification,
274 _ => {
275 return Err(ProtocolError::HandshakeError(
276 "Expected handshake confirmation".to_string(),
277 ))
278 }
279 };
280
281 let session_key = server_secure_handshake_finalize(server_state, nonce_verification)?;
283
284 let conn = SecureConnection::new(framed, session_key);
288
289 handle_secure_connection(conn, dispatcher, peer, heartbeat_interval).await?;
291
292 Ok(())
293}
294
295#[instrument(skip(dispatcher))]
297fn register_default_handlers(dispatcher: &Arc<Dispatcher>) -> Result<()> {
298 dispatcher.register("PING", |_| {
300 debug!("Responding to ping with pong");
301 Ok(Message::Pong)
302 })?;
303
304 dispatcher.register("ECHO", |msg| {
306 if let Message::Echo(text) = msg {
307 debug!(text = %text, "Echoing message");
308 Ok(Message::Echo(text.clone()))
309 } else {
310 Err(ProtocolError::Custom(
311 "Invalid Echo message format".to_string(),
312 ))
313 }
314 })?;
315
316 Ok(())
317}
318
319#[derive(Debug)]
321enum ProcessingMessage {
322 Message(Message),
324 Terminate,
326}
327
328#[derive(Debug)]
330struct ProcessingResult {
331 original_id: usize,
333 response: Option<Message>,
335}
336
337#[instrument(skip(conn, dispatcher, heartbeat_interval), fields(peer = %peer))]
339async fn handle_secure_connection(
340 mut conn: SecureConnection,
341 dispatcher: Arc<Dispatcher>,
342 peer: std::net::SocketAddr,
343 heartbeat_interval: Duration,
344) -> Result<()> {
345 let dead_timeout = heartbeat_interval.mul_f32(4.0); let mut keep_alive = KeepAliveManager::with_settings(heartbeat_interval, dead_timeout);
348 let mut ping_interval = time::interval(keep_alive.ping_interval());
349
350 let (msg_tx, msg_rx) = mpsc::channel::<ProcessingMessage>(32);
353 let (resp_tx, mut resp_rx) = mpsc::channel::<ProcessingResult>(32);
354
355 let dispatcher_clone = dispatcher.clone();
357 let processor_handle =
358 tokio::spawn(async move { process_messages(msg_rx, resp_tx, dispatcher_clone).await });
359
360 let mut final_result = Ok(());
362 let mut next_msg_id: usize = 0;
363
364 'main: loop {
366 tokio::select! {
367 _ = ping_interval.tick() => {
369 if keep_alive.should_ping() {
370 debug!("Sending keep-alive ping");
371 let ping = build_ping();
372 if let Err(e) = conn.secure_send(ping).await {
373 warn!(error = %e, "Failed to send ping");
374 final_result = Err(e);
375 break 'main;
376 }
377 keep_alive.update_send();
378 }
379
380 if keep_alive.is_connection_dead() {
382 warn!(dead_seconds = ?keep_alive.time_since_last_recv().as_secs(),
383 "Connection appears dead, closing");
384 final_result = Err(ProtocolError::ConnectionTimeout);
385 break 'main;
386 }
387 }
388
389 Some(result) = resp_rx.recv() => {
391 if let Some(response) = result.response {
392 debug!("Sending response for message {}", result.original_id);
393 if let Err(e) = conn.secure_send(response).await {
394 warn!(error = %e, "Failed to send response");
395 final_result = Err(e);
396 break 'main;
397 }
398 keep_alive.update_send();
399 }
400 }
401
402 recv_result = conn.secure_recv::<Message>() => {
404 match recv_result {
405 Ok(msg) => {
406 debug!(message = ?msg, "Received message");
407 keep_alive.update_recv();
408
409 if matches!(msg, Message::Disconnect) {
411 info!("Received disconnect request");
412 break 'main;
413 }
414
415 if is_pong(&msg) {
417 debug!("Received pong response");
418 continue;
419 }
420
421 next_msg_id = next_msg_id.wrapping_add(1);
423
424 if msg_tx.capacity() == 0 {
426 debug!("Channel full - applying backpressure");
427
428 match msg_tx.reserve().await {
430 Ok(permit) => {
431 permit.send(ProcessingMessage::Message(msg));
433 },
434 Err(_) => {
435 warn!("Processing channel closed unexpectedly");
437 break 'main;
438 }
439 }
440 } else {
441 if (msg_tx.send(ProcessingMessage::Message(msg)).await).is_err() {
443 warn!("Failed to send message to processing channel");
445 break 'main;
446 }
447 }
448 }
449 Err(ProtocolError::Timeout) => {
450 continue;
452 }
453 Err(e) => {
454 final_result = Err(e);
455 break 'main;
456 }
457 }
458 }
459 }
460 }
461
462 debug!("Signaling processor to terminate");
464 let _ = msg_tx.send(ProcessingMessage::Terminate).await;
465
466 debug!("Waiting for processor to terminate");
468 let _ = processor_handle.await;
469
470 final_result
471}
472
473#[instrument(skip(rx, resp_tx, dispatcher), level = "debug")]
475async fn process_messages(
476 mut rx: mpsc::Receiver<ProcessingMessage>,
477 resp_tx: mpsc::Sender<ProcessingResult>,
478 dispatcher: Arc<Dispatcher>,
479) {
480 let mut msg_counter: usize = 0;
481
482 while let Some(proc_msg) = rx.recv().await {
483 match proc_msg {
484 ProcessingMessage::Message(msg) => {
485 let msg_id = msg_counter;
486 msg_counter += 1;
487
488 debug!(msg_id = msg_id, message = ?msg, "Processing message from channel");
489
490 let response = match dispatcher.dispatch(&msg) {
491 Ok(reply) => {
492 Some(reply)
494 }
495 Err(e) => {
496 warn!(error = %e, "Error dispatching message");
498 None
499 }
500 };
501
502 let result = ProcessingResult {
504 original_id: msg_id,
505 response,
506 };
507
508 if (resp_tx.send(result).await).is_err() {
509 warn!("Failed to send processing result - reader likely disconnected");
510 break;
511 }
512 }
513 ProcessingMessage::Terminate => {
514 debug!("Processor received terminate signal");
515 break;
516 }
517 }
518 }
519
520 debug!("Message processor terminated");
521}
522
523#[derive(Debug)]
525pub struct Daemon {
526 pub address: String,
528 shutdown_tx: Option<oneshot::Sender<()>>,
530}
531
532impl Daemon {
533 pub fn new(address: String, shutdown_tx: oneshot::Sender<()>) -> Self {
535 Self {
536 address,
537 shutdown_tx: Some(shutdown_tx),
538 }
539 }
540
541 pub async fn run(self) -> Result<()> {
543 Ok(())
546 }
547
548 pub async fn shutdown(&mut self) -> Result<()> {
550 if let Some(tx) = self.shutdown_tx.take() {
551 let _ = tx.send(());
552 Ok(())
553 } else {
554 Err(ProtocolError::Custom("Shutdown already called".to_string()))
555 }
556 }
557
558 pub async fn shutdown_with_timeout(&mut self, _timeout: Duration) -> Result<()> {
560 self.shutdown().await
562 }
563}
564
565#[instrument(skip(config, _dispatcher), fields(address = %config.address))]
567pub async fn start_daemon_no_signals(
568 config: ServerConfig,
569 _dispatcher: Arc<Dispatcher>,
570) -> Result<Daemon> {
571 let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
573
574 let address = config.address.clone();
575
576 tokio::spawn(async move {
578 if let Err(e) = start_with_config_and_shutdown(config, shutdown_rx).await {
579 error!(error = ?e, "Server error");
580 }
581 });
582
583 Ok(Daemon::new(address, shutdown_tx))
585}
586
587pub fn new_with_config(config: ServerConfig, _dispatcher: Arc<Dispatcher>) -> Daemon {
589 let (shutdown_tx, _) = oneshot::channel::<()>();
590 Daemon::new(config.address.clone(), shutdown_tx)
591}