// mabi_modbus/server.rs

1//! Modbus TCP server implementation.
2
3use std::net::SocketAddr;
4use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
5use std::sync::Arc;
6use std::time::Instant;
7
8use bytes::BytesMut;
9use dashmap::DashMap;
10use parking_lot::RwLock;
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::{TcpListener, TcpStream};
13use tokio::sync::Semaphore;
14use tracing::{debug, error, info, instrument, warn};
15
16use crate::config::ModbusServerConfig;
17use crate::device::ModbusDevice;
18use crate::error::{ModbusError, ModbusResult};
19use crate::register::RegisterStore;
20
21/// Modbus function codes.
22#[repr(u8)]
23#[derive(Debug, Clone, Copy, PartialEq, Eq)]
24pub enum FunctionCode {
25    ReadCoils = 0x01,
26    ReadDiscreteInputs = 0x02,
27    ReadHoldingRegisters = 0x03,
28    ReadInputRegisters = 0x04,
29    WriteSingleCoil = 0x05,
30    WriteSingleRegister = 0x06,
31    WriteMultipleCoils = 0x0F,
32    WriteMultipleRegisters = 0x10,
33    ReadWriteMultipleRegisters = 0x17,
34}
35
36impl TryFrom<u8> for FunctionCode {
37    type Error = ModbusError;
38
39    fn try_from(value: u8) -> Result<Self, Self::Error> {
40        match value {
41            0x01 => Ok(Self::ReadCoils),
42            0x02 => Ok(Self::ReadDiscreteInputs),
43            0x03 => Ok(Self::ReadHoldingRegisters),
44            0x04 => Ok(Self::ReadInputRegisters),
45            0x05 => Ok(Self::WriteSingleCoil),
46            0x06 => Ok(Self::WriteSingleRegister),
47            0x0F => Ok(Self::WriteMultipleCoils),
48            0x10 => Ok(Self::WriteMultipleRegisters),
49            0x17 => Ok(Self::ReadWriteMultipleRegisters),
50            _ => Err(ModbusError::InvalidFunction(value)),
51        }
52    }
53}
54
55/// Modbus TCP server.
56pub struct ModbusTcpServer {
57    /// Configuration.
58    config: ModbusServerConfig,
59
60    /// Devices by unit ID.
61    devices: DashMap<u8, Arc<ModbusDevice>>,
62
63    /// Shared register store (for unit ID 0 / broadcast).
64    shared_registers: Arc<RegisterStore>,
65
66    /// Connection semaphore.
67    connection_semaphore: Arc<Semaphore>,
68
69    /// Shutdown flag.
70    shutdown: Arc<AtomicBool>,
71
72    /// Active connections count.
73    active_connections: AtomicU64,
74
75    /// Total requests processed.
76    total_requests: AtomicU64,
77
78    /// Start time.
79    start_time: RwLock<Option<Instant>>,
80}
81
82impl ModbusTcpServer {
83    /// Create a new Modbus TCP server.
84    pub fn new(config: ModbusServerConfig) -> Self {
85        Self {
86            connection_semaphore: Arc::new(Semaphore::new(config.max_connections)),
87            config,
88            devices: DashMap::new(),
89            shared_registers: Arc::new(RegisterStore::with_defaults()),
90            shutdown: Arc::new(AtomicBool::new(false)),
91            active_connections: AtomicU64::new(0),
92            total_requests: AtomicU64::new(0),
93            start_time: RwLock::new(None),
94        }
95    }
96
97    /// Add a device to the server.
98    pub fn add_device(&self, device: ModbusDevice) {
99        let unit_id = device.unit_id();
100        self.devices.insert(unit_id, Arc::new(device));
101    }
102
103    /// Get device by unit ID.
104    pub fn device(&self, unit_id: u8) -> Option<Arc<ModbusDevice>> {
105        self.devices.get(&unit_id).map(|d| d.clone())
106    }
107
108    /// Get shared register store.
109    pub fn shared_registers(&self) -> &Arc<RegisterStore> {
110        &self.shared_registers
111    }
112
113    /// Get active connections count.
114    pub fn active_connections(&self) -> u64 {
115        self.active_connections.load(Ordering::Relaxed)
116    }
117
118    /// Get total requests processed.
119    pub fn total_requests(&self) -> u64 {
120        self.total_requests.load(Ordering::Relaxed)
121    }
122
123    /// Check if shutdown is requested.
124    pub fn is_shutdown(&self) -> bool {
125        self.shutdown.load(Ordering::SeqCst)
126    }
127
128    /// Request shutdown.
129    pub fn shutdown(&self) {
130        self.shutdown.store(true, Ordering::SeqCst);
131    }
132
133    /// Run the server.
134    #[instrument(skip(self))]
135    pub async fn run(&self) -> ModbusResult<()> {
136        let listener = TcpListener::bind(self.config.bind_address).await?;
137        info!(address = %self.config.bind_address, "Modbus TCP server started");
138
139        *self.start_time.write() = Some(Instant::now());
140
141        while !self.is_shutdown() {
142            tokio::select! {
143                result = listener.accept() => {
144                    match result {
145                        Ok((stream, addr)) => {
146                            if let Ok(permit) = self.connection_semaphore.clone().try_acquire_owned() {
147                                self.active_connections.fetch_add(1, Ordering::Relaxed);
148
149                                let devices = self.devices.clone();
150                                let shared_registers = self.shared_registers.clone();
151                                let shutdown = self.shutdown.clone();
152                                let timeout = self.config.timeout();
153
154                                tokio::spawn(async move {
155                                    if let Err(e) = handle_connection(
156                                        stream,
157                                        addr,
158                                        devices,
159                                        shared_registers,
160                                        shutdown,
161                                        timeout,
162                                    ).await {
163                                        debug!(error = %e, "Connection error");
164                                    }
165                                    drop(permit);
166                                });
167                            } else {
168                                warn!("Max connections reached, rejecting connection from {}", addr);
169                            }
170                        }
171                        Err(e) => {
172                            error!(error = %e, "Accept error");
173                        }
174                    }
175                }
176                _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
177                    // Check shutdown flag periodically
178                }
179            }
180        }
181
182        info!("Modbus TCP server stopped");
183        Ok(())
184    }
185}
186
187/// Handle a single connection.
188async fn handle_connection(
189    mut stream: TcpStream,
190    addr: SocketAddr,
191    devices: DashMap<u8, Arc<ModbusDevice>>,
192    shared_registers: Arc<RegisterStore>,
193    shutdown: Arc<AtomicBool>,
194    timeout: std::time::Duration,
195) -> ModbusResult<()> {
196    debug!(addr = %addr, "New connection");
197
198    let mut buffer = BytesMut::with_capacity(256);
199
200    loop {
201        if shutdown.load(Ordering::SeqCst) {
202            break;
203        }
204
205        // Read MBAP header (7 bytes) + PDU
206        buffer.clear();
207        buffer.resize(256, 0);
208
209        let read_result = tokio::time::timeout(timeout, stream.read(&mut buffer)).await;
210
211        match read_result {
212            Ok(Ok(0)) => {
213                debug!(addr = %addr, "Connection closed");
214                break;
215            }
216            Ok(Ok(n)) => {
217                buffer.truncate(n);
218
219                if n < 8 {
220                    warn!(addr = %addr, "Packet too short: {} bytes", n);
221                    continue;
222                }
223
224                // Parse MBAP header
225                let transaction_id = u16::from_be_bytes([buffer[0], buffer[1]]);
226                let protocol_id = u16::from_be_bytes([buffer[2], buffer[3]]);
227                let _length = u16::from_be_bytes([buffer[4], buffer[5]]);
228                let unit_id = buffer[6];
229
230                if protocol_id != 0 {
231                    warn!(addr = %addr, "Invalid protocol ID: {}", protocol_id);
232                    continue;
233                }
234
235                // Get registers for this unit ID
236                let registers = if let Some(device) = devices.get(&unit_id) {
237                    device.registers().clone()
238                } else if unit_id == 0 {
239                    // Broadcast / shared
240                    shared_registers.clone()
241                } else {
242                    // Send exception response
243                    let response = build_exception_response(
244                        transaction_id,
245                        unit_id,
246                        buffer[7],
247                        0x0B, // Gateway Target Device Failed to Respond
248                    );
249                    stream.write_all(&response).await?;
250                    continue;
251                };
252
253                // Process request
254                let pdu = &buffer[7..];
255                let response = process_request(transaction_id, unit_id, pdu, &registers)?;
256
257                stream.write_all(&response).await?;
258            }
259            Ok(Err(e)) => {
260                debug!(addr = %addr, error = %e, "Read error");
261                break;
262            }
263            Err(_) => {
264                debug!(addr = %addr, "Connection timeout");
265                break;
266            }
267        }
268    }
269
270    Ok(())
271}
272
273/// Process a Modbus request.
274fn process_request(
275    transaction_id: u16,
276    unit_id: u8,
277    pdu: &[u8],
278    registers: &RegisterStore,
279) -> ModbusResult<Vec<u8>> {
280    if pdu.is_empty() {
281        return Err(ModbusError::InvalidData("Empty PDU".into()));
282    }
283
284    let function_code = FunctionCode::try_from(pdu[0])?;
285
286    let response_pdu = match function_code {
287        FunctionCode::ReadCoils => {
288            let address = u16::from_be_bytes([pdu[1], pdu[2]]);
289            let quantity = u16::from_be_bytes([pdu[3], pdu[4]]);
290            read_coils(registers, address, quantity)?
291        }
292        FunctionCode::ReadDiscreteInputs => {
293            let address = u16::from_be_bytes([pdu[1], pdu[2]]);
294            let quantity = u16::from_be_bytes([pdu[3], pdu[4]]);
295            read_discrete_inputs(registers, address, quantity)?
296        }
297        FunctionCode::ReadHoldingRegisters => {
298            let address = u16::from_be_bytes([pdu[1], pdu[2]]);
299            let quantity = u16::from_be_bytes([pdu[3], pdu[4]]);
300            read_holding_registers(registers, address, quantity)?
301        }
302        FunctionCode::ReadInputRegisters => {
303            let address = u16::from_be_bytes([pdu[1], pdu[2]]);
304            let quantity = u16::from_be_bytes([pdu[3], pdu[4]]);
305            read_input_registers(registers, address, quantity)?
306        }
307        FunctionCode::WriteSingleCoil => {
308            let address = u16::from_be_bytes([pdu[1], pdu[2]]);
309            let value = u16::from_be_bytes([pdu[3], pdu[4]]);
310            write_single_coil(registers, address, value)?
311        }
312        FunctionCode::WriteSingleRegister => {
313            let address = u16::from_be_bytes([pdu[1], pdu[2]]);
314            let value = u16::from_be_bytes([pdu[3], pdu[4]]);
315            write_single_register(registers, address, value)?
316        }
317        FunctionCode::WriteMultipleCoils => {
318            let address = u16::from_be_bytes([pdu[1], pdu[2]]);
319            let quantity = u16::from_be_bytes([pdu[3], pdu[4]]);
320            let byte_count = pdu[5] as usize;
321            let data = &pdu[6..6 + byte_count];
322            write_multiple_coils(registers, address, quantity, data)?
323        }
324        FunctionCode::WriteMultipleRegisters => {
325            let address = u16::from_be_bytes([pdu[1], pdu[2]]);
326            let quantity = u16::from_be_bytes([pdu[3], pdu[4]]);
327            let byte_count = pdu[5] as usize;
328            let data = &pdu[6..6 + byte_count];
329            write_multiple_registers(registers, address, quantity, data)?
330        }
331        FunctionCode::ReadWriteMultipleRegisters => {
332            let read_address = u16::from_be_bytes([pdu[1], pdu[2]]);
333            let read_quantity = u16::from_be_bytes([pdu[3], pdu[4]]);
334            let write_address = u16::from_be_bytes([pdu[5], pdu[6]]);
335            let write_quantity = u16::from_be_bytes([pdu[7], pdu[8]]);
336            let byte_count = pdu[9] as usize;
337            let data = &pdu[10..10 + byte_count];
338            read_write_multiple_registers(
339                registers,
340                read_address,
341                read_quantity,
342                write_address,
343                write_quantity,
344                data,
345            )?
346        }
347    };
348
349    // Build MBAP + PDU response
350    Ok(build_response(transaction_id, unit_id, &response_pdu))
351}
352
/// Frame `pdu` as a Modbus TCP ADU: the 7-byte MBAP header followed by the PDU.
///
/// The header holds `transaction_id`, protocol ID 0, and a length field equal
/// to `pdu.len() + 1` (the unit ID byte plus the PDU), followed by `unit_id`.
fn build_response(transaction_id: u16, unit_id: u8, pdu: &[u8]) -> Vec<u8> {
    let [tid_hi, tid_lo] = transaction_id.to_be_bytes();
    // The MBAP length covers the unit identifier byte as well as the PDU.
    let [len_hi, len_lo] = ((pdu.len() + 1) as u16).to_be_bytes();
    let mbap = [tid_hi, tid_lo, 0x00, 0x00, len_hi, len_lo, unit_id];
    [&mbap[..], pdu].concat()
}
366
367/// Build exception response.
368fn build_exception_response(
369    transaction_id: u16,
370    unit_id: u8,
371    function_code: u8,
372    exception_code: u8,
373) -> Vec<u8> {
374    let pdu = vec![function_code | 0x80, exception_code];
375    build_response(transaction_id, unit_id, &pdu)
376}
377
378// Function handlers
379
380fn read_coils(registers: &RegisterStore, address: u16, quantity: u16) -> ModbusResult<Vec<u8>> {
381    let coils = registers.read_coils(address, quantity)?;
382    let byte_count = quantity.div_ceil(8);
383
384    let mut response = vec![0x01, byte_count as u8];
385    let mut bytes = vec![0u8; byte_count as usize];
386
387    for (i, &coil) in coils.iter().enumerate() {
388        if coil {
389            bytes[i / 8] |= 1 << (i % 8);
390        }
391    }
392
393    response.extend_from_slice(&bytes);
394    Ok(response)
395}
396
397fn read_discrete_inputs(
398    registers: &RegisterStore,
399    address: u16,
400    quantity: u16,
401) -> ModbusResult<Vec<u8>> {
402    let inputs = registers.read_discrete_inputs(address, quantity)?;
403    let byte_count = quantity.div_ceil(8);
404
405    let mut response = vec![0x02, byte_count as u8];
406    let mut bytes = vec![0u8; byte_count as usize];
407
408    for (i, &input) in inputs.iter().enumerate() {
409        if input {
410            bytes[i / 8] |= 1 << (i % 8);
411        }
412    }
413
414    response.extend_from_slice(&bytes);
415    Ok(response)
416}
417
418fn read_holding_registers(
419    registers: &RegisterStore,
420    address: u16,
421    quantity: u16,
422) -> ModbusResult<Vec<u8>> {
423    let values = registers.read_holding_registers(address, quantity)?;
424    let byte_count = (quantity * 2) as u8;
425
426    let mut response = vec![0x03, byte_count];
427    for value in values {
428        response.extend_from_slice(&value.to_be_bytes());
429    }
430
431    Ok(response)
432}
433
434fn read_input_registers(
435    registers: &RegisterStore,
436    address: u16,
437    quantity: u16,
438) -> ModbusResult<Vec<u8>> {
439    let values = registers.read_input_registers(address, quantity)?;
440    let byte_count = (quantity * 2) as u8;
441
442    let mut response = vec![0x04, byte_count];
443    for value in values {
444        response.extend_from_slice(&value.to_be_bytes());
445    }
446
447    Ok(response)
448}
449
450fn write_single_coil(registers: &RegisterStore, address: u16, value: u16) -> ModbusResult<Vec<u8>> {
451    let coil_value = value == 0xFF00;
452    registers.write_coil(address, coil_value)?;
453
454    let mut response = vec![0x05];
455    response.extend_from_slice(&address.to_be_bytes());
456    response.extend_from_slice(&value.to_be_bytes());
457
458    Ok(response)
459}
460
461fn write_single_register(
462    registers: &RegisterStore,
463    address: u16,
464    value: u16,
465) -> ModbusResult<Vec<u8>> {
466    registers.write_holding_register(address, value)?;
467
468    let mut response = vec![0x06];
469    response.extend_from_slice(&address.to_be_bytes());
470    response.extend_from_slice(&value.to_be_bytes());
471
472    Ok(response)
473}
474
475fn write_multiple_coils(
476    registers: &RegisterStore,
477    address: u16,
478    quantity: u16,
479    data: &[u8],
480) -> ModbusResult<Vec<u8>> {
481    let mut coils = Vec::with_capacity(quantity as usize);
482    for i in 0..quantity as usize {
483        coils.push((data[i / 8] & (1 << (i % 8))) != 0);
484    }
485
486    registers.write_coils(address, &coils)?;
487
488    let mut response = vec![0x0F];
489    response.extend_from_slice(&address.to_be_bytes());
490    response.extend_from_slice(&quantity.to_be_bytes());
491
492    Ok(response)
493}
494
495fn write_multiple_registers(
496    registers: &RegisterStore,
497    address: u16,
498    quantity: u16,
499    data: &[u8],
500) -> ModbusResult<Vec<u8>> {
501    let mut values = Vec::with_capacity(quantity as usize);
502    for i in 0..quantity as usize {
503        values.push(u16::from_be_bytes([data[i * 2], data[i * 2 + 1]]));
504    }
505
506    registers.write_holding_registers(address, &values)?;
507
508    let mut response = vec![0x10];
509    response.extend_from_slice(&address.to_be_bytes());
510    response.extend_from_slice(&quantity.to_be_bytes());
511
512    Ok(response)
513}
514
515fn read_write_multiple_registers(
516    registers: &RegisterStore,
517    read_address: u16,
518    read_quantity: u16,
519    write_address: u16,
520    write_quantity: u16,
521    data: &[u8],
522) -> ModbusResult<Vec<u8>> {
523    // Write first
524    let mut write_values = Vec::with_capacity(write_quantity as usize);
525    for i in 0..write_quantity as usize {
526        write_values.push(u16::from_be_bytes([data[i * 2], data[i * 2 + 1]]));
527    }
528    registers.write_holding_registers(write_address, &write_values)?;
529
530    // Then read
531    let read_values = registers.read_holding_registers(read_address, read_quantity)?;
532    let byte_count = (read_quantity * 2) as u8;
533
534    let mut response = vec![0x17, byte_count];
535    for value in read_values {
536        response.extend_from_slice(&value.to_be_bytes());
537    }
538
539    Ok(response)
540}
541
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_function_code_conversion() {
        assert_eq!(
            FunctionCode::try_from(0x01).unwrap(),
            FunctionCode::ReadCoils
        );
        assert_eq!(
            FunctionCode::try_from(0x03).unwrap(),
            FunctionCode::ReadHoldingRegisters
        );
        assert_eq!(
            FunctionCode::try_from(0x0F).unwrap(),
            FunctionCode::WriteMultipleCoils
        );
        assert_eq!(
            FunctionCode::try_from(0x17).unwrap(),
            FunctionCode::ReadWriteMultipleRegisters
        );
        assert!(FunctionCode::try_from(0x00).is_err());
        assert!(FunctionCode::try_from(0xFF).is_err());
    }

    #[test]
    fn test_build_response_mbap_header() {
        // Length field = unit ID (1) + PDU (4) = 5; protocol ID is always 0.
        assert_eq!(
            build_response(0x1234, 0x11, &[0x03, 0x02, 0x00, 0x2A]),
            vec![0x12, 0x34, 0x00, 0x00, 0x00, 0x05, 0x11, 0x03, 0x02, 0x00, 0x2A]
        );
    }

    #[test]
    fn test_build_exception_response() {
        // Function code gets the 0x80 exception bit, followed by the exception code.
        assert_eq!(
            build_exception_response(7, 1, 0x03, 0x02),
            vec![0x00, 0x07, 0x00, 0x00, 0x00, 0x03, 0x01, 0x83, 0x02]
        );
    }

    #[test]
    fn test_process_read_holding_registers() {
        let registers = RegisterStore::with_defaults();
        registers
            .write_holding_registers(0, &[100, 200, 300])
            .unwrap();

        let pdu = [0x03, 0x00, 0x00, 0x00, 0x03]; // Read 3 registers from address 0
        let response = process_request(1, 1, &pdu, &registers).unwrap();

        // MBAP (7) + Function (1) + Byte count (1) + Data (6) = 15 bytes
        assert_eq!(response.len(), 15);
        assert_eq!(response[7], 0x03); // Function code
        assert_eq!(response[8], 6); // Byte count
        // MBAP length = unit ID + 8-byte PDU = 9; values 100, 200, 300 big-endian.
        assert_eq!(
            response,
            vec![
                0x00, 0x01, 0x00, 0x00, 0x00, 0x09, 0x01, 0x03, 0x06, 0x00, 0x64, 0x00, 0xC8,
                0x01, 0x2C,
            ]
        );
    }

    #[test]
    fn test_process_write_single_register() {
        let registers = RegisterStore::with_defaults();

        let pdu = [0x06, 0x00, 0x0A, 0x12, 0x34]; // Write 0x1234 to address 10
        let response = process_request(1, 1, &pdu, &registers).unwrap();

        // Verify response echoes the request
        assert_eq!(
            response,
            vec![0x00, 0x01, 0x00, 0x00, 0x00, 0x06, 0x01, 0x06, 0x00, 0x0A, 0x12, 0x34]
        );

        // Verify the register was written
        let values = registers.read_holding_registers(10, 1).unwrap();
        assert_eq!(values[0], 0x1234);
    }

    #[test]
    fn test_process_write_then_read_coils() {
        let registers = RegisterStore::with_defaults();

        // Write 10 coils starting at address 0 with packed data 0xCD 0x01.
        let write = [0x0F, 0x00, 0x00, 0x00, 0x0A, 0x02, 0xCD, 0x01];
        let response = process_request(2, 1, &write, &registers).unwrap();
        assert_eq!(
            response,
            vec![0x00, 0x02, 0x00, 0x00, 0x00, 0x06, 0x01, 0x0F, 0x00, 0x00, 0x00, 0x0A]
        );

        // Reading them back yields the same LSB-first packed bytes.
        let read = [0x01, 0x00, 0x00, 0x00, 0x0A];
        let response = process_request(3, 1, &read, &registers).unwrap();
        assert_eq!(
            response,
            vec![0x00, 0x03, 0x00, 0x00, 0x00, 0x05, 0x01, 0x01, 0x02, 0xCD, 0x01]
        );
    }

    #[test]
    fn test_process_rejects_empty_and_unsupported_pdus() {
        let registers = RegisterStore::with_defaults();
        assert!(process_request(1, 1, &[], &registers).is_err());
        assert!(process_request(1, 1, &[0x2B], &registers).is_err());
    }
}