network_protocol/service/
daemon.rs1use tokio::net::{TcpListener, TcpStream};
2use tokio_util::codec::Framed;
3use futures::{StreamExt, SinkExt};
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::{mpsc, Mutex, oneshot};
7use std::net::SocketAddr;
8use tokio::time;
9use bincode;
10use tracing::{info, debug, warn, error, instrument};
11
12use crate::config::ServerConfig;
13
14use crate::utils::timeout::with_timeout_error;
15
16use crate::core::codec::PacketCodec;
17use crate::core::packet::Packet;
18use crate::protocol::message::Message;
19use crate::protocol::handshake::{server_secure_handshake_response, server_secure_handshake_finalize, clear_handshake_data};
21use crate::protocol::dispatcher::Dispatcher;
22use crate::protocol::keepalive::KeepAliveManager;
23use crate::protocol::heartbeat::{build_ping, is_pong};
24use crate::service::secure::SecureConnection;
25use crate::error::{Result, ProtocolError};
26
27#[instrument(skip(addr), fields(address = %addr))]
29pub async fn start(addr: &str) -> Result<()> {
30 let (_, shutdown_rx) = oneshot::channel::<()>();
32 start_with_shutdown(addr, shutdown_rx).await
33}
34
35#[instrument(skip(config), fields(address = %config.address))]
37pub async fn start_with_config(config: ServerConfig) -> Result<()> {
38 let (_, shutdown_rx) = oneshot::channel::<()>();
40 start_with_config_and_shutdown(config, shutdown_rx).await
41}
42
43#[instrument(skip(addr, shutdown_rx), fields(address = %addr))]
45pub async fn start_with_shutdown(
46 addr: &str,
47 shutdown_rx: oneshot::Receiver<()>
48) -> Result<()> {
49 let config = ServerConfig {
51 address: addr.to_string(),
52 ..Default::default()
53 };
54 start_with_config_and_shutdown(config, shutdown_rx).await
55}
56
57#[instrument(skip(config, shutdown_rx), fields(address = %config.address))]
59pub async fn start_with_config_and_shutdown(
60 config: ServerConfig,
61 shutdown_rx: oneshot::Receiver<()>
62) -> Result<()> {
63 let listener = TcpListener::bind(&config.address).await?;
64 info!(address = %config.address, "Server listening");
65
66 let dispatcher = Arc::new(Dispatcher::new());
68
69 register_default_handlers(&dispatcher)?;
71
72 let active_connections = Arc::new(Mutex::new(0u32));
74
75 let (internal_shutdown_tx, mut internal_shutdown_rx) = mpsc::channel::<()>(1);
77
78 let shutdown_timeout = config.shutdown_timeout;
80 let heartbeat_interval = config.heartbeat_interval;
81
82 let shutdown_tx_clone = internal_shutdown_tx.clone();
84 tokio::spawn(async move {
85 match tokio::signal::ctrl_c().await {
86 Ok(()) => {
87 info!("Shutdown signal received");
88 let _ = shutdown_tx_clone.send(()).await;
89 },
90 Err(err) => {
91 error!(error = %err, "Failed to listen for shutdown signal");
92 },
93 }
94 });
95
96 let internal_shutdown_tx_clone = internal_shutdown_tx.clone();
98 tokio::spawn(async move {
99 if shutdown_rx.await.is_ok() {
100 info!("External shutdown signal received");
101 let _ = internal_shutdown_tx_clone.send(()).await;
102 }
103 });
104
105 loop {
107 tokio::select! {
108 _ = internal_shutdown_rx.recv() => {
110 info!("Shutting down server. Waiting for connections to close...");
111
112 let timeout = tokio::time::sleep(shutdown_timeout);
114 tokio::pin!(timeout);
115
116 loop {
117 tokio::select! {
118 _ = &mut timeout => {
119 warn!("Shutdown timeout reached, forcing exit");
120 break;
121 }
122 _ = tokio::time::sleep(Duration::from_millis(500)) => {
123 let connections = *active_connections.lock().await;
124 info!(connections = %connections, "Waiting for connections to close");
125 if connections == 0 {
126 info!("All connections closed, shutting down");
127 break;
128 }
129 }
130 }
131 }
132
133 return Ok(());
134 }
135
136 accept_result = listener.accept() => {
138 match accept_result {
139 Ok((stream, peer)) => {
140 info!(peer = %peer, "New connection established");
141 let dispatcher = dispatcher.clone();
142 let active_connections = active_connections.clone();
143 {
148 let mut count = active_connections.lock().await;
149 *count += 1;
150 }
151
152 let active_connections_clone = active_connections.clone();
154 let config_clone = config.clone();
155
156 tokio::spawn(async move {
157 handle_connection(stream, peer, dispatcher, active_connections_clone, config_clone, heartbeat_interval).await;
158 });
159 }
160 Err(e) => {
161 error!(error = %e, "Error accepting connection");
162 }
163 }
164 }
165 }
166 }
167}
168
169#[instrument(skip(stream, dispatcher, active_connections, config, heartbeat_interval), fields(peer = %peer))]
171async fn handle_connection(
172 stream: tokio::net::TcpStream,
173 peer: std::net::SocketAddr,
174 dispatcher: Arc<Dispatcher>,
175 active_connections: Arc<Mutex<u32>>,
176 config: ServerConfig,
177 heartbeat_interval: Duration,
178) {
179 let result = with_timeout_error(
181 async {
182 process_connection(stream, dispatcher, peer, config.clone(), heartbeat_interval).await
183 },
184 config.connection_timeout
185 ).await;
186
187 match result {
189 Ok(_) => info!("Connection closed gracefully"),
190 Err(ProtocolError::Timeout) => warn!("Connection timed out"),
191 Err(e) => error!(error = %e, "Connection error"),
192 }
193
194 {
196 let mut count = active_connections.lock().await;
197 *count -= 1;
198 }
199
200 info!("Client disconnected");
201}
202
203#[instrument(skip(stream, dispatcher, peer, config, heartbeat_interval), fields(peer = %peer))]
205async fn process_connection(
206 stream: TcpStream,
207 dispatcher: Arc<Dispatcher>,
208 peer: SocketAddr,
209 config: ServerConfig,
210 heartbeat_interval: Duration,
211) -> Result<()> {
212 let mut framed = Framed::new(stream, PacketCodec);
214
215 let init = with_timeout_error(
217 async {
218 match framed.next().await {
219 Some(Ok(pkt)) => bincode::deserialize::<Message>(&pkt.payload)
220 .map_err(|e| ProtocolError::DeserializeError(e.to_string())),
221 Some(Err(e)) => Err(ProtocolError::TransportError(e.to_string())),
222 None => Err(ProtocolError::ConnectionClosed),
223 }
224 },
225 config.connection_timeout
226 ).await?;
227
228 let (client_pub_key, client_timestamp, client_nonce) = match init {
230 Message::SecureHandshakeInit { pub_key, timestamp, nonce } => {
231 (pub_key, timestamp, nonce)
232 },
233 _ => return Err(ProtocolError::HandshakeError("Unexpected message type".to_string())),
234 };
235
236 let response = server_secure_handshake_response(client_pub_key, client_nonce, client_timestamp)?;
238
239 let response_bytes = bincode::serialize(&response)
240 .map_err(|e| ProtocolError::SerializeError(e.to_string()))?;
241
242 framed.send(Packet { version: 1, payload: response_bytes }).await
243 .map_err(|e| ProtocolError::TransportError(e.to_string()))?;
244
245 let confirm = with_timeout_error(
247 async {
248 match framed.next().await {
249 Some(Ok(pkt)) => bincode::deserialize::<Message>(&pkt.payload)
250 .map_err(|e| ProtocolError::DeserializeError(e.to_string())),
251 Some(Err(e)) => Err(ProtocolError::TransportError(e.to_string())),
252 None => Err(ProtocolError::ConnectionClosed),
253 }
254 },
255 config.connection_timeout
256 ).await?;
257
258 let nonce_verification = match confirm {
259 Message::SecureHandshakeConfirm { nonce_verification } => nonce_verification,
260 _ => return Err(ProtocolError::HandshakeError("Expected handshake confirmation".to_string())),
261 };
262
263 let session_key = server_secure_handshake_finalize(nonce_verification)?;
265
266 let _ = clear_handshake_data();
268
269 let conn = SecureConnection::new(framed, session_key);
271
272 handle_secure_connection(conn, dispatcher, peer, heartbeat_interval).await?;
274
275 Ok(())
276}
277
278#[instrument(skip(dispatcher))]
280fn register_default_handlers(dispatcher: &Arc<Dispatcher>) -> Result<()> {
281 dispatcher.register("PING", |_| {
283 debug!("Responding to ping with pong");
284 Ok(Message::Pong)
285 })?;
286
287 dispatcher.register("ECHO", |msg| {
289 if let Message::Echo(text) = msg {
290 debug!(text = %text, "Echoing message");
291 Ok(Message::Echo(text.clone()))
292 } else {
293 Err(ProtocolError::Custom("Invalid Echo message format".to_string()))
294 }
295 })?;
296
297 Ok(())
298}
299
300#[derive(Debug)]
302enum ProcessingMessage {
303 Message(Message),
305 Terminate,
307}
308
309#[derive(Debug)]
311struct ProcessingResult {
312 original_id: usize,
314 response: Option<Message>,
316}
317
318#[instrument(skip(conn, dispatcher, heartbeat_interval), fields(peer = %peer))]
320async fn handle_secure_connection(
321 mut conn: SecureConnection,
322 dispatcher: Arc<Dispatcher>,
323 peer: std::net::SocketAddr,
324 heartbeat_interval: Duration,
325) -> Result<()> {
326 let dead_timeout = heartbeat_interval.mul_f32(4.0); let mut keep_alive = KeepAliveManager::with_settings(heartbeat_interval, dead_timeout);
329 let mut ping_interval = time::interval(keep_alive.ping_interval());
330
331 let (msg_tx, msg_rx) = mpsc::channel::<ProcessingMessage>(32);
334 let (resp_tx, mut resp_rx) = mpsc::channel::<ProcessingResult>(32);
335
336 let dispatcher_clone = dispatcher.clone();
338 let processor_handle = tokio::spawn(async move {
339 process_messages(msg_rx, resp_tx, dispatcher_clone).await
340 });
341
342 let mut final_result = Ok(());
344 let mut next_msg_id: usize = 0;
345
346 'main: loop {
348 tokio::select! {
349 _ = ping_interval.tick() => {
351 if keep_alive.should_ping() {
352 debug!("Sending keep-alive ping");
353 let ping = build_ping();
354 if let Err(e) = conn.secure_send(ping).await {
355 warn!(error = %e, "Failed to send ping");
356 final_result = Err(e);
357 break 'main;
358 }
359 keep_alive.update_send();
360 }
361
362 if keep_alive.is_connection_dead() {
364 warn!(dead_seconds = ?keep_alive.time_since_last_recv().as_secs(),
365 "Connection appears dead, closing");
366 final_result = Err(ProtocolError::ConnectionTimeout);
367 break 'main;
368 }
369 }
370
371 Some(result) = resp_rx.recv() => {
373 if let Some(response) = result.response {
374 debug!("Sending response for message {}", result.original_id);
375 if let Err(e) = conn.secure_send(response).await {
376 warn!(error = %e, "Failed to send response");
377 final_result = Err(e);
378 break 'main;
379 }
380 keep_alive.update_send();
381 }
382 }
383
384 recv_result = conn.secure_recv::<Message>() => {
386 match recv_result {
387 Ok(msg) => {
388 debug!(message = ?msg, "Received message");
389 keep_alive.update_recv();
390
391 if matches!(msg, Message::Disconnect) {
393 info!("Received disconnect request");
394 break 'main;
395 }
396
397 if is_pong(&msg) {
399 debug!("Received pong response");
400 continue;
401 }
402
403 next_msg_id = next_msg_id.wrapping_add(1);
405
406 if msg_tx.capacity() == 0 {
408 debug!("Channel full - applying backpressure");
409
410 match msg_tx.reserve().await {
412 Ok(permit) => {
413 permit.send(ProcessingMessage::Message(msg));
415 },
416 Err(_) => {
417 warn!("Processing channel closed unexpectedly");
419 break 'main;
420 }
421 }
422 } else {
423 if (msg_tx.send(ProcessingMessage::Message(msg)).await).is_err() {
425 warn!("Failed to send message to processing channel");
427 break 'main;
428 }
429 }
430 }
431 Err(ProtocolError::Timeout) => {
432 continue;
434 }
435 Err(e) => {
436 final_result = Err(e);
437 break 'main;
438 }
439 }
440 }
441 }
442 }
443
444 debug!("Signaling processor to terminate");
446 let _ = msg_tx.send(ProcessingMessage::Terminate).await;
447
448 debug!("Waiting for processor to terminate");
450 let _ = processor_handle.await;
451
452 final_result
453}
454
455#[instrument(skip(rx, resp_tx, dispatcher), level = "debug")]
457async fn process_messages(
458 mut rx: mpsc::Receiver<ProcessingMessage>,
459 resp_tx: mpsc::Sender<ProcessingResult>,
460 dispatcher: Arc<Dispatcher>,
461) {
462 let mut msg_counter: usize = 0;
463
464 while let Some(proc_msg) = rx.recv().await {
465 match proc_msg {
466 ProcessingMessage::Message(msg) => {
467 let msg_id = msg_counter;
468 msg_counter += 1;
469
470 debug!(msg_id = msg_id, message = ?msg, "Processing message from channel");
471
472 let response = match dispatcher.dispatch(&msg) {
473 Ok(reply) => {
474 Some(reply)
476 },
477 Err(e) => {
478 warn!(error = %e, "Error dispatching message");
480 None
481 }
482 };
483
484 let result = ProcessingResult {
486 original_id: msg_id,
487 response,
488 };
489
490 if (resp_tx.send(result).await).is_err() {
491 warn!("Failed to send processing result - reader likely disconnected");
492 break;
493 }
494 },
495 ProcessingMessage::Terminate => {
496 debug!("Processor received terminate signal");
497 break;
498 }
499 }
500 }
501
502 debug!("Message processor terminated");
503}
504
505#[derive(Debug)]
507pub struct Daemon {
508 pub address: String,
510 shutdown_tx: Option<oneshot::Sender<()>>,
512}
513
514impl Daemon {
515 pub fn new(address: String, shutdown_tx: oneshot::Sender<()>) -> Self {
517 Self {
518 address,
519 shutdown_tx: Some(shutdown_tx),
520 }
521 }
522
523 pub async fn run(self) -> Result<()> {
525 Ok(())
528 }
529
530 pub async fn shutdown(&mut self) -> Result<()> {
532 if let Some(tx) = self.shutdown_tx.take() {
533 let _ = tx.send(());
534 Ok(())
535 } else {
536 Err(ProtocolError::Custom("Shutdown already called".to_string()))
537 }
538 }
539
540 pub async fn shutdown_with_timeout(&mut self, _timeout: Duration) -> Result<()> {
542 self.shutdown().await
544 }
545}
546
547#[instrument(skip(config, _dispatcher), fields(address = %config.address))]
549pub async fn start_daemon_no_signals(config: ServerConfig, _dispatcher: Arc<Dispatcher>) -> Result<Daemon> {
550 let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
552
553 let address = config.address.clone();
554
555 tokio::spawn(async move {
557 if let Err(e) = start_with_config_and_shutdown(config, shutdown_rx).await {
558 error!(error = ?e, "Server error");
559 }
560 });
561
562 Ok(Daemon::new(address, shutdown_tx))
564}
565
566pub fn new_with_config(config: ServerConfig, _dispatcher: Arc<Dispatcher>) -> Daemon {
568 let (shutdown_tx, _) = oneshot::channel::<()>();
569 Daemon::new(config.address.clone(), shutdown_tx)
570}