1use std::net::SocketAddr;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::Arc;
14use std::time::Duration;
15
16use dashmap::DashMap;
17use futures::{SinkExt, StreamExt};
18use tokio::net::{TcpListener, TcpStream};
19use tokio::sync::{broadcast, Semaphore};
20use tokio_util::codec::Framed;
21use tracing::{debug, error, info, instrument, warn};
22
23use crate::config::ModbusServerConfig;
24use crate::device::ModbusDevice;
25use crate::error::ModbusResult;
26use crate::fault_injection::connection_disruption::{
27 ConnectionDisruptionConfig, ConnectionDisruptionState, DisruptionAction,
28};
29use crate::fault_injection::{FaultAction, FaultPipeline, ModbusFaultContext};
30use crate::handler::{build_exception_pdu, ExceptionCode, HandlerContext, HandlerRegistry};
31use crate::register::RegisterStore;
32
33use super::codec::{MbapCodec, MbapFrame};
34use super::connection::ConnectionPool;
35use super::metrics::{LatencyTimer, ServerMetrics};
36
/// Runtime configuration for [`ModbusTcpServerV2`].
#[derive(Debug, Clone)]
pub struct ServerConfigV2 {
    /// Socket address the TCP listener binds to.
    pub bind_address: SocketAddr,

    /// Maximum number of concurrently served connections; additional clients
    /// are rejected when this limit is reached.
    pub max_connections: usize,

    /// How long a connection may wait for its next request frame before it
    /// is closed.
    pub connection_timeout: Duration,

    /// Timeout wrapped around the processing of a single request.
    pub request_timeout: Duration,

    /// TCP keepalive interval, or `None` to disable it.
    /// NOTE(review): not read by the connection-handling code in this module;
    /// confirm whether it is applied anywhere.
    pub tcp_keepalive: Option<Duration>,

    /// Whether `TCP_NODELAY` is set on accepted sockets.
    pub tcp_nodelay: bool,

    /// How long shutdown waits for active connections to close before giving up.
    pub shutdown_timeout: Duration,
}
61
62impl Default for ServerConfigV2 {
63 fn default() -> Self {
64 Self {
65 bind_address: "0.0.0.0:502".parse().unwrap(),
66 max_connections: 10_000,
67 connection_timeout: Duration::from_secs(300),
68 request_timeout: Duration::from_secs(5),
69 tcp_keepalive: Some(Duration::from_secs(60)),
70 tcp_nodelay: true,
71 shutdown_timeout: Duration::from_secs(30),
72 }
73 }
74}
75
76impl From<ModbusServerConfig> for ServerConfigV2 {
77 fn from(config: ModbusServerConfig) -> Self {
78 Self {
79 bind_address: config.bind_address,
80 max_connections: config.max_connections,
81 connection_timeout: Duration::from_secs(config.timeout_secs),
82 request_timeout: Duration::from_secs(5),
83 tcp_keepalive: if config.keep_alive {
84 Some(Duration::from_secs(60))
85 } else {
86 None
87 },
88 tcp_nodelay: config.tcp_nodelay,
89 shutdown_timeout: Duration::from_secs(30),
90 }
91 }
92}
93
/// Lifecycle notifications broadcast to receivers obtained from
/// [`ModbusTcpServerV2::subscribe`].
#[derive(Debug, Clone)]
pub enum ServerEvent {
    /// The listener is bound to `address` and the accept loop is starting.
    Started { address: SocketAddr },

    /// The accept loop has exited and graceful shutdown has finished.
    Stopped,

    /// An error occurred, described by `message`.
    /// NOTE(review): not emitted by any code in this module.
    Error { message: String },
}
104
105pub struct ModbusTcpServerV2 {
115 config: ServerConfigV2,
117
118 handlers: Arc<HandlerRegistry>,
120
121 devices: DashMap<u8, Arc<ModbusDevice>>,
123
124 default_registers: Arc<RegisterStore>,
126
127 connections: Arc<ConnectionPool>,
129
130 metrics: Arc<ServerMetrics>,
132
133 connection_semaphore: Arc<Semaphore>,
135
136 shutdown: Arc<AtomicBool>,
138
139 shutdown_tx: broadcast::Sender<()>,
141
142 event_tx: broadcast::Sender<ServerEvent>,
144
145 fault_pipeline: Option<Arc<FaultPipeline>>,
147
148 connection_disruption: Option<Arc<ConnectionDisruptionConfig>>,
150}
151
152impl ModbusTcpServerV2 {
153 pub fn new(config: ServerConfigV2) -> Self {
155 let (shutdown_tx, _) = broadcast::channel(1);
156 let (event_tx, _) = broadcast::channel(64);
157
158 Self {
159 connection_semaphore: Arc::new(Semaphore::new(config.max_connections)),
160 connections: Arc::new(ConnectionPool::new(config.max_connections)),
161 config,
162 handlers: Arc::new(HandlerRegistry::with_defaults()),
163 devices: DashMap::new(),
164 default_registers: Arc::new(RegisterStore::with_defaults()),
165 metrics: Arc::new(ServerMetrics::new()),
166 shutdown: Arc::new(AtomicBool::new(false)),
167 shutdown_tx,
168 event_tx,
169 fault_pipeline: None,
170 connection_disruption: None,
171 }
172 }
173
174 pub fn from_config(config: ModbusServerConfig) -> Self {
176 Self::new(config.into())
177 }
178
179 pub fn with_handlers(mut self, handlers: HandlerRegistry) -> Self {
181 self.handlers = Arc::new(handlers);
182 self
183 }
184
185 pub fn with_fault_pipeline(mut self, pipeline: FaultPipeline) -> Self {
187 self.fault_pipeline = Some(Arc::new(pipeline));
188 self
189 }
190
191 pub fn with_connection_disruption(mut self, config: ConnectionDisruptionConfig) -> Self {
193 self.connection_disruption = Some(Arc::new(config));
194 self
195 }
196
197 pub fn with_default_registers(mut self, registers: RegisterStore) -> Self {
199 self.default_registers = Arc::new(registers);
200 self
201 }
202
203 pub fn add_device(&self, device: ModbusDevice) {
205 let unit_id = device.unit_id();
206 self.devices.insert(unit_id, Arc::new(device));
207 }
208
209 pub fn remove_device(&self, unit_id: u8) -> Option<Arc<ModbusDevice>> {
211 self.devices.remove(&unit_id).map(|(_, d)| d)
212 }
213
214 pub fn device(&self, unit_id: u8) -> Option<Arc<ModbusDevice>> {
216 self.devices.get(&unit_id).map(|d| d.clone())
217 }
218
219 pub fn default_registers(&self) -> &Arc<RegisterStore> {
221 &self.default_registers
222 }
223
224 pub fn metrics(&self) -> &Arc<ServerMetrics> {
226 &self.metrics
227 }
228
229 pub fn connections(&self) -> &Arc<ConnectionPool> {
231 &self.connections
232 }
233
234 pub fn subscribe(&self) -> broadcast::Receiver<ServerEvent> {
236 self.event_tx.subscribe()
237 }
238
239 pub fn is_shutdown(&self) -> bool {
241 self.shutdown.load(Ordering::SeqCst)
242 }
243
244 pub fn shutdown(&self) {
246 if !self.shutdown.swap(true, Ordering::SeqCst) {
247 info!("Shutdown requested");
248 let _ = self.shutdown_tx.send(());
249 }
250 }
251
252 #[instrument(skip(self))]
254 pub async fn run(&self) -> ModbusResult<()> {
255 let listener = TcpListener::bind(self.config.bind_address).await?;
256 info!(address = %self.config.bind_address, "Modbus TCP server started");
257
258 let _ = self.event_tx.send(ServerEvent::Started {
259 address: self.config.bind_address,
260 });
261
262 let mut shutdown_rx = self.shutdown_tx.subscribe();
263
264 loop {
265 tokio::select! {
266 result = listener.accept() => {
268 match result {
269 Ok((stream, peer_addr)) => {
270 self.handle_new_connection(stream, peer_addr).await;
271 }
272 Err(e) => {
273 error!(error = %e, "Failed to accept connection");
274 self.metrics.record_error();
275 }
276 }
277 }
278
279 _ = shutdown_rx.recv() => {
281 info!("Shutdown signal received");
282 break;
283 }
284 }
285 }
286
287 self.graceful_shutdown().await;
289
290 let _ = self.event_tx.send(ServerEvent::Stopped);
291 info!("Modbus TCP server stopped");
292
293 Ok(())
294 }
295
296 async fn handle_new_connection(&self, stream: TcpStream, peer_addr: SocketAddr) {
298 let permit = match self.connection_semaphore.clone().try_acquire_owned() {
300 Ok(permit) => permit,
301 Err(_) => {
302 warn!(peer = %peer_addr, "Max connections reached, rejecting");
303 self.metrics.record_connection_rejected();
304 return;
305 }
306 };
307
308 let connection_id = match self.connections.try_register(peer_addr) {
310 Some(id) => id,
311 None => {
312 warn!(peer = %peer_addr, "Connection pool full, rejecting");
313 self.metrics.record_connection_rejected();
314 return;
315 }
316 };
317
318 self.metrics.record_connection();
319
320 let handlers = self.handlers.clone();
322 let devices = self.devices.clone();
323 let default_registers = self.default_registers.clone();
324 let connections = self.connections.clone();
325 let metrics = self.metrics.clone();
326 let shutdown = self.shutdown.clone();
327 let config = self.config.clone();
328 let fault_pipeline = self.fault_pipeline.clone();
329 let connection_disruption = self.connection_disruption.clone();
330
331 tokio::spawn(async move {
332 let result = handle_connection(
333 stream,
334 peer_addr,
335 connection_id,
336 handlers,
337 devices,
338 default_registers,
339 connections.clone(),
340 metrics.clone(),
341 shutdown,
342 config,
343 fault_pipeline,
344 connection_disruption,
345 )
346 .await;
347
348 if let Err(e) = result {
349 debug!(peer = %peer_addr, error = %e, "Connection handler error");
350 }
351
352 connections.unregister(connection_id);
353 metrics.record_disconnection();
354 drop(permit);
355 });
356 }
357
358 async fn graceful_shutdown(&self) {
360 info!("Starting graceful shutdown");
361
362 let start = std::time::Instant::now();
364 loop {
365 let active = self.connections.active_count();
366 if active == 0 {
367 info!("All connections closed");
368 break;
369 }
370
371 if start.elapsed() > self.config.shutdown_timeout {
372 warn!(
373 active_connections = active,
374 "Shutdown timeout reached, forcing close"
375 );
376 break;
377 }
378
379 debug!(active_connections = active, "Waiting for connections to close");
380 tokio::time::sleep(Duration::from_millis(100)).await;
381 }
382 }
383}
384
385async fn handle_connection(
387 stream: TcpStream,
388 peer_addr: SocketAddr,
389 connection_id: u64,
390 handlers: Arc<HandlerRegistry>,
391 devices: DashMap<u8, Arc<ModbusDevice>>,
392 default_registers: Arc<RegisterStore>,
393 connections: Arc<ConnectionPool>,
394 metrics: Arc<ServerMetrics>,
395 shutdown: Arc<AtomicBool>,
396 config: ServerConfigV2,
397 fault_pipeline: Option<Arc<FaultPipeline>>,
398 connection_disruption: Option<Arc<ConnectionDisruptionConfig>>,
399) -> ModbusResult<()> {
400 debug!(peer = %peer_addr, connection_id, "Connection established");
401
402 if config.tcp_nodelay {
404 stream.set_nodelay(true)?;
405 }
406
407 let mut framed = Framed::new(stream, MbapCodec::new());
409 let mut request_number: u64 = 0;
410
411 let disruption_state = connection_disruption
413 .as_ref()
414 .map(|cfg| ConnectionDisruptionState::new((**cfg).clone()));
415
416 loop {
417 if shutdown.load(Ordering::SeqCst) {
419 debug!(peer = %peer_addr, "Shutdown requested, closing connection");
420 break;
421 }
422
423 let read_result = tokio::time::timeout(config.connection_timeout, framed.next()).await;
425
426 let frame = match read_result {
427 Ok(Some(Ok(frame))) => frame,
428 Ok(Some(Err(e))) => {
429 debug!(peer = %peer_addr, error = %e, "Frame decode error");
430 metrics.record_frame_error();
431 continue;
432 }
433 Ok(None) => {
434 debug!(peer = %peer_addr, "Connection closed by client");
435 break;
436 }
437 Err(_) => {
438 debug!(peer = %peer_addr, "Connection timeout");
439 metrics.record_timeout();
440 break;
441 }
442 };
443
444 let timer = LatencyTimer::start();
446 let request_bytes = frame.frame_size() as u64;
447 let unit_id = frame.header.unit_id;
448 let function_code = frame.function_code().unwrap_or(0);
449
450 metrics.record_request(function_code);
451
452 let registers = if let Some(device) = devices.get(&unit_id) {
454 device.registers().clone()
455 } else if unit_id == 0 {
456 default_registers.clone()
457 } else {
458 let exception_pdu = build_exception_pdu(
460 function_code,
461 ExceptionCode::GatewayTargetDeviceFailedToRespond,
462 );
463
464 let response = MbapFrame::response(&frame, exception_pdu);
465 let response_bytes = response.frame_size() as u64;
466
467 if let Err(e) = framed.send(response).await {
468 warn!(peer = %peer_addr, error = %e, "Failed to send exception response");
469 break;
470 }
471
472 let latency = timer.elapsed_us();
473 metrics.record_exception(latency, request_bytes, response_bytes);
474 connections.record_request(
475 connection_id,
476 unit_id,
477 function_code,
478 false,
479 latency,
480 request_bytes,
481 response_bytes,
482 );
483
484 continue;
485 };
486
487 let ctx = HandlerContext::new(unit_id, registers, frame.header.transaction_id);
489
490 let process_result = tokio::time::timeout(
492 config.request_timeout,
493 async {
494 handlers.dispatch(&frame.pdu, &ctx)
495 }
496 ).await;
497
498 let response_pdu = match process_result {
499 Ok(Ok(pdu)) => pdu,
500 Ok(Err(exception_code)) => {
501 build_exception_pdu(function_code, exception_code)
502 }
503 Err(_) => {
504 warn!(peer = %peer_addr, "Request processing timeout");
505 metrics.record_timeout();
506 build_exception_pdu(function_code, ExceptionCode::SlaveDeviceBusy)
507 }
508 };
509
510 request_number += 1;
512 let fault_action = if let Some(ref pipeline) = fault_pipeline {
513 let fault_ctx = ModbusFaultContext::tcp(
514 unit_id,
515 function_code,
516 &frame.pdu,
517 &response_pdu,
518 frame.header.transaction_id,
519 request_number,
520 );
521 pipeline.apply(&fault_ctx)
522 } else {
523 None
524 };
525
526 match fault_action {
527 Some(FaultAction::DropResponse) => {
528 debug!(peer = %peer_addr, unit_id, fc = function_code, "Fault: dropping response");
530 let latency = timer.elapsed_us();
531 metrics.record_success(latency, request_bytes, 0);
532 continue;
533 }
534 Some(FaultAction::DelayThenSend { delay, response: fault_pdu }) => {
535 tokio::time::sleep(delay).await;
536 let is_exception = fault_pdu.first().map(|&fc| fc & 0x80 != 0).unwrap_or(false);
537 let response = MbapFrame::response(&frame, fault_pdu);
538 let response_bytes = response.frame_size() as u64;
539 if let Err(e) = framed.send(response).await {
540 warn!(peer = %peer_addr, error = %e, "Failed to send delayed response");
541 break;
542 }
543 let latency = timer.elapsed_us();
544 if is_exception {
545 metrics.record_exception(latency, request_bytes, response_bytes);
546 } else {
547 metrics.record_success(latency, request_bytes, response_bytes);
548 }
549 connections.record_request(connection_id, unit_id, function_code, !is_exception, latency, request_bytes, response_bytes);
550 }
551 Some(FaultAction::OverrideTransactionId { transaction_id, response: fault_pdu }) => {
552 let is_exception = fault_pdu.first().map(|&fc| fc & 0x80 != 0).unwrap_or(false);
553 let mut response = MbapFrame::response(&frame, fault_pdu);
554 response.header.transaction_id = transaction_id;
555 let response_bytes = response.frame_size() as u64;
556 if let Err(e) = framed.send(response).await {
557 warn!(peer = %peer_addr, error = %e, "Failed to send response with overridden TID");
558 break;
559 }
560 let latency = timer.elapsed_us();
561 if is_exception {
562 metrics.record_exception(latency, request_bytes, response_bytes);
563 } else {
564 metrics.record_success(latency, request_bytes, response_bytes);
565 }
566 connections.record_request(connection_id, unit_id, function_code, !is_exception, latency, request_bytes, response_bytes);
567 }
568 Some(FaultAction::SendRawBytes(raw_bytes)) => {
569 use tokio::io::AsyncWriteExt;
571 let inner = framed.get_mut();
572 let response_bytes = raw_bytes.len() as u64;
573 if let Err(e) = inner.write_all(&raw_bytes).await {
574 warn!(peer = %peer_addr, error = %e, "Failed to send raw bytes");
575 break;
576 }
577 let _ = inner.flush().await;
578 let latency = timer.elapsed_us();
579 metrics.record_success(latency, request_bytes, response_bytes);
580 connections.record_request(connection_id, unit_id, function_code, true, latency, request_bytes, response_bytes);
581 }
582 Some(FaultAction::SendResponse(fault_pdu)) => {
583 let is_exception = fault_pdu.first().map(|&fc| fc & 0x80 != 0).unwrap_or(false);
584 let response = MbapFrame::response(&frame, fault_pdu);
585 let response_bytes = response.frame_size() as u64;
586 if let Err(e) = framed.send(response).await {
587 warn!(peer = %peer_addr, error = %e, "Failed to send faulted response");
588 break;
589 }
590 let latency = timer.elapsed_us();
591 if is_exception {
592 metrics.record_exception(latency, request_bytes, response_bytes);
593 } else {
594 metrics.record_success(latency, request_bytes, response_bytes);
595 }
596 connections.record_request(connection_id, unit_id, function_code, !is_exception, latency, request_bytes, response_bytes);
597 }
598 Some(FaultAction::SendPartial { bytes }) => {
599 use tokio::io::AsyncWriteExt;
601 let inner = framed.get_mut();
602 let response_bytes = bytes.len() as u64;
603 if let Err(e) = inner.write_all(&bytes).await {
604 warn!(peer = %peer_addr, error = %e, "Failed to send partial bytes");
605 break;
606 }
607 let _ = inner.flush().await;
608 let latency = timer.elapsed_us();
609 metrics.record_success(latency, request_bytes, response_bytes);
610 connections.record_request(connection_id, unit_id, function_code, true, latency, request_bytes, response_bytes);
611 }
612 None => {
613 let is_exception = response_pdu.first().map(|&fc| fc & 0x80 != 0).unwrap_or(false);
615 let response = MbapFrame::response(&frame, response_pdu);
616 let response_bytes = response.frame_size() as u64;
617
618 if let Err(e) = framed.send(response).await {
619 warn!(peer = %peer_addr, error = %e, "Failed to send response");
620 break;
621 }
622
623 let latency = timer.elapsed_us();
625 if is_exception {
626 metrics.record_exception(latency, request_bytes, response_bytes);
627 } else {
628 metrics.record_success(latency, request_bytes, response_bytes);
629 }
630
631 connections.record_request(
632 connection_id,
633 unit_id,
634 function_code,
635 !is_exception,
636 latency,
637 request_bytes,
638 response_bytes,
639 );
640 }
641 }
642
643 if let Some(ref state) = disruption_state {
645 match state.record_request() {
646 DisruptionAction::None => {}
647 DisruptionAction::Disconnect { close_delay, use_rst: _ } => {
648 debug!(peer = %peer_addr, "Connection disruption: disconnect");
649 if let Some(delay) = close_delay {
650 tokio::time::sleep(delay).await;
651 }
652 break;
657 }
658 DisruptionAction::DropMidFrame { close_delay, use_rst: _ } => {
659 debug!(peer = %peer_addr, "Connection disruption: drop mid-frame");
660 if let Some(delay) = close_delay {
661 tokio::time::sleep(delay).await;
662 }
663 break;
664 }
665 DisruptionAction::RstAfterPartial { byte_count, close_delay, use_rst: _ } => {
666 debug!(peer = %peer_addr, byte_count, "Connection disruption: RST after partial");
667 use tokio::io::AsyncWriteExt;
669 let garbage: Vec<u8> = (0..byte_count).map(|i| i as u8).collect();
670 let inner = framed.get_mut();
671 let _ = inner.write_all(&garbage).await;
672 let _ = inner.flush().await;
673 if let Some(delay) = close_delay {
674 tokio::time::sleep(delay).await;
675 }
676 break;
677 }
678 DisruptionAction::HoldOpen { duration } => {
679 debug!(peer = %peer_addr, ?duration, "Connection disruption: hold open");
680 state.set_holding_open(true);
681 tokio::time::sleep(duration).await;
682 state.set_holding_open(false);
683 break;
684 }
685 }
686 }
687 }
688
689 Ok(())
690}
691
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ModbusDeviceConfig;

    /// A freshly built server is not shut down and has no open connections.
    #[tokio::test]
    async fn test_server_creation() {
        let server = ModbusTcpServerV2::new(ServerConfigV2::default());

        assert!(!server.is_shutdown());
        assert_eq!(server.connections().active_count(), 0);
    }

    /// Devices can be registered, looked up by unit id, and removed again.
    #[tokio::test]
    async fn test_device_management() {
        let server = ModbusTcpServerV2::new(ServerConfigV2::default());
        server.add_device(ModbusDevice::new(ModbusDeviceConfig::new(5, "Test")));

        assert!(server.device(5).is_some(), "unit 5 should be registered");
        assert!(server.device(10).is_none(), "unit 10 was never added");

        assert!(server.remove_device(5).is_some());
        assert!(server.device(5).is_none(), "unit 5 should be gone after removal");
    }

    /// `shutdown` sets the flag, and calling it again keeps it set.
    #[tokio::test]
    async fn test_shutdown_flag() {
        let server = ModbusTcpServerV2::new(ServerConfigV2::default());
        assert!(!server.is_shutdown());

        for _ in 0..2 {
            server.shutdown();
            assert!(server.is_shutdown());
        }
    }
}