use crate::error::{ModbusError, ModbusResult};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Preallocated entry count for the coil table in `ModbusRegisterBank::with_default_capacity`.
const DEFAULT_COILS_SIZE: usize = 10_000;
/// Preallocated entry count for the discrete-input table in `with_default_capacity`.
const DEFAULT_DISCRETE_INPUTS_SIZE: usize = 10_000;
/// Preallocated entry count for the holding-register table in `with_default_capacity`.
const DEFAULT_HOLDING_REGISTERS_SIZE: usize = 10_000;
/// Preallocated entry count for the input-register table in `with_default_capacity`.
const DEFAULT_INPUT_REGISTERS_SIZE: usize = 10_000;
/// A lock-protected, reference-counted map from a Modbus address to its stored value.
type SharedTable<V> = Arc<RwLock<HashMap<u16, V>>>;

/// In-memory storage for the four Modbus data tables, keyed by `u16` address.
///
/// Each table is an `Arc<RwLock<HashMap<..>>>`, so cloning a bank is cheap and the
/// clone is a second handle onto the *same* tables rather than an independent copy.
#[derive(Debug, Clone)]
pub struct ModbusRegisterBank {
    /// Single-bit read/write values.
    coils: SharedTable<bool>,
    /// Single-bit values that are set through the bank's setter API.
    discrete_inputs: SharedTable<bool>,
    /// 16-bit read/write registers.
    holding_registers: SharedTable<u16>,
    /// 16-bit registers that are set through the bank's setter API.
    input_registers: SharedTable<u16>,
}
impl ModbusRegisterBank {
pub fn new() -> Self {
Self {
coils: Arc::new(RwLock::new(HashMap::new())),
discrete_inputs: Arc::new(RwLock::new(HashMap::new())),
holding_registers: Arc::new(RwLock::new(HashMap::new())),
input_registers: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn with_capacity(
coils_cap: usize,
discrete_inputs_cap: usize,
holding_registers_cap: usize,
input_registers_cap: usize,
) -> Self {
Self {
coils: Arc::new(RwLock::new(HashMap::with_capacity(coils_cap))),
discrete_inputs: Arc::new(RwLock::new(HashMap::with_capacity(discrete_inputs_cap))),
holding_registers: Arc::new(RwLock::new(HashMap::with_capacity(holding_registers_cap))),
input_registers: Arc::new(RwLock::new(HashMap::with_capacity(input_registers_cap))),
}
}
pub fn with_default_capacity() -> Self {
Self::with_capacity(
DEFAULT_COILS_SIZE,
DEFAULT_DISCRETE_INPUTS_SIZE,
DEFAULT_HOLDING_REGISTERS_SIZE,
DEFAULT_INPUT_REGISTERS_SIZE,
)
}
pub fn read_coils(&self, address: u16, quantity: u16) -> ModbusResult<Vec<bool>> {
validate_address_range(address, quantity)?;
let coils = self
.coils
.read()
.map_err(|_| ModbusError::internal("Failed to lock coils"))?;
let mut result = Vec::with_capacity(quantity as usize);
for i in 0..quantity {
let addr = checked_address(address, i)?;
result.push(coils.get(&addr).copied().unwrap_or(false));
}
Ok(result)
}
pub fn read_01(&self, address: u16, quantity: u16) -> ModbusResult<Vec<bool>> {
self.read_coils(address, quantity)
}
pub fn write_05(&self, address: u16, value: bool) -> ModbusResult<()> {
let mut coils = self
.coils
.write()
.map_err(|_| ModbusError::internal("Failed to lock coils"))?;
coils.insert(address, value);
Ok(())
}
pub fn write_0f(&self, address: u16, values: &[bool]) -> ModbusResult<()> {
let quantity = u16::try_from(values.len())
.map_err(|_| ModbusError::invalid_address(address, u16::MAX))?;
validate_address_range(address, quantity)?;
let mut coils = self
.coils
.write()
.map_err(|_| ModbusError::internal("Failed to lock coils"))?;
for (i, &value) in values.iter().enumerate() {
let addr = checked_address(address, i)?;
coils.insert(addr, value);
}
Ok(())
}
pub fn read_discrete_inputs(&self, address: u16, quantity: u16) -> ModbusResult<Vec<bool>> {
validate_address_range(address, quantity)?;
let inputs = self
.discrete_inputs
.read()
.map_err(|_| ModbusError::internal("Failed to lock discrete inputs"))?;
let mut result = Vec::with_capacity(quantity as usize);
for i in 0..quantity {
let addr = checked_address(address, i)?;
result.push(inputs.get(&addr).copied().unwrap_or(false));
}
Ok(result)
}
pub fn read_02(&self, address: u16, quantity: u16) -> ModbusResult<Vec<bool>> {
self.read_discrete_inputs(address, quantity)
}
pub fn read_holding_registers(&self, address: u16, quantity: u16) -> ModbusResult<Vec<u16>> {
validate_address_range(address, quantity)?;
let registers = self
.holding_registers
.read()
.map_err(|_| ModbusError::internal("Failed to lock holding registers"))?;
let mut result = Vec::with_capacity(quantity as usize);
for i in 0..quantity {
let addr = checked_address(address, i)?;
result.push(registers.get(&addr).copied().unwrap_or(0));
}
Ok(result)
}
pub fn read_03(&self, address: u16, quantity: u16) -> ModbusResult<Vec<u16>> {
self.read_holding_registers(address, quantity)
}
pub fn write_06(&self, address: u16, value: u16) -> ModbusResult<()> {
let mut registers = self
.holding_registers
.write()
.map_err(|_| ModbusError::internal("Failed to lock holding registers"))?;
registers.insert(address, value);
Ok(())
}
pub fn write_10(&self, address: u16, values: &[u16]) -> ModbusResult<()> {
let quantity = u16::try_from(values.len())
.map_err(|_| ModbusError::invalid_address(address, u16::MAX))?;
validate_address_range(address, quantity)?;
let mut registers = self
.holding_registers
.write()
.map_err(|_| ModbusError::internal("Failed to lock holding registers"))?;
for (i, &value) in values.iter().enumerate() {
let addr = checked_address(address, i)?;
registers.insert(addr, value);
}
Ok(())
}
pub fn read_input_registers(&self, address: u16, quantity: u16) -> ModbusResult<Vec<u16>> {
validate_address_range(address, quantity)?;
let registers = self
.input_registers
.read()
.map_err(|_| ModbusError::internal("Failed to lock input registers"))?;
let mut result = Vec::with_capacity(quantity as usize);
for i in 0..quantity {
let addr = checked_address(address, i)?;
result.push(registers.get(&addr).copied().unwrap_or(0));
}
Ok(result)
}
pub fn read_04(&self, address: u16, quantity: u16) -> ModbusResult<Vec<u16>> {
self.read_input_registers(address, quantity)
}
pub fn set_input_register(&self, address: u16, value: u16) -> ModbusResult<()> {
let mut registers = self
.input_registers
.write()
.map_err(|_| ModbusError::internal("Failed to lock input registers"))?;
registers.insert(address, value);
Ok(())
}
pub fn set_discrete_input(&self, address: u16, value: bool) -> ModbusResult<()> {
let mut inputs = self
.discrete_inputs
.write()
.map_err(|_| ModbusError::internal("Failed to lock discrete inputs"))?;
inputs.insert(address, value);
Ok(())
}
pub fn get_stats(&self) -> RegisterBankStats {
RegisterBankStats {
coils_count: self.coils.read().map(|coils| coils.len()).unwrap_or(0),
discrete_inputs_count: self
.discrete_inputs
.read()
.map(|inputs| inputs.len())
.unwrap_or(0),
holding_registers_count: self
.holding_registers
.read()
.map(|registers| registers.len())
.unwrap_or(0),
input_registers_count: self
.input_registers
.read()
.map(|registers| registers.len())
.unwrap_or(0),
}
}
}
/// Checks that `quantity` addresses starting at `address` form a non-empty range that
/// fits inside the `u16` address space.
///
/// # Errors
/// `ModbusError::invalid_address(address, quantity)` when `quantity` is zero or when the
/// last address (`address + quantity - 1`) would overflow `u16`.
fn validate_address_range(address: u16, quantity: u16) -> ModbusResult<()> {
    // `checked_sub` yields None for an empty range; `checked_add` yields None on overflow.
    let last_address = quantity
        .checked_sub(1)
        .and_then(|span| address.checked_add(span));
    match last_address {
        Some(_) => Ok(()),
        None => Err(ModbusError::invalid_address(address, quantity)),
    }
}
/// Returns `address + offset`, rejecting offsets that do not fit in `u16` or sums that
/// overflow the address space.
///
/// # Errors
/// `ModbusError::invalid_address(address, u16::MAX)` if `offset` cannot be converted to
/// `u16`; `ModbusError::invalid_address(address, offset + 1)` (saturating) if the sum
/// overflows.
fn checked_address(address: u16, offset: impl TryInto<u16>) -> ModbusResult<u16> {
    let step: u16 = match offset.try_into() {
        Ok(converted) => converted,
        Err(_) => return Err(ModbusError::invalid_address(address, u16::MAX)),
    };
    match address.checked_add(step) {
        Some(target) => Ok(target),
        // Report the quantity that would have been needed to reach this offset.
        None => Err(ModbusError::invalid_address(address, step.saturating_add(1))),
    }
}
impl Default for ModbusRegisterBank {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of how many addresses hold an explicit entry in each table of a
/// `ModbusRegisterBank`, as returned by its `get_stats` method.
///
/// Derives `PartialEq`/`Eq`/`Default` so snapshots can be compared directly and an
/// all-zero snapshot can be built without spelling out every field.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct RegisterBankStats {
    /// Number of coil addresses with a stored value.
    pub coils_count: usize,
    /// Number of discrete-input addresses with a stored value.
    pub discrete_inputs_count: usize,
    /// Number of holding-register addresses with a stored value.
    pub holding_registers_count: usize,
    /// Number of input-register addresses with a stored value.
    pub input_registers_count: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_coil_operations() {
        let bank = ModbusRegisterBank::new();
        bank.write_05(10, true).unwrap();
        let coils = bank.read_01(10, 1).unwrap();
        assert!(coils[0]);
        bank.write_0f(20, &[true, false, true]).unwrap();
        let coils = bank.read_01(20, 3).unwrap();
        assert_eq!(coils, vec![true, false, true]);
    }

    #[test]
    fn test_register_operations() {
        let bank = ModbusRegisterBank::new();
        bank.write_06(5, 42).unwrap();
        let registers = bank.read_03(5, 1).unwrap();
        assert_eq!(registers[0], 42);
        bank.write_10(100, &[100, 200, 300]).unwrap();
        let registers = bank.read_03(100, 3).unwrap();
        assert_eq!(registers, vec![100, 200, 300]);
    }

    #[test]
    fn test_range_overflow_is_rejected() {
        let bank = ModbusRegisterBank::new();
        assert!(bank.read_03(u16::MAX, 2).is_err());
        assert!(bank.write_10(u16::MAX, &[1, 2]).is_err());
        assert!(bank.read_01(10, 0).is_err());
    }

    // Addresses that were never written must read back as false / 0 in every table.
    #[test]
    fn test_unwritten_addresses_read_as_default() {
        let bank = ModbusRegisterBank::new();
        assert_eq!(bank.read_01(0, 3).unwrap(), vec![false; 3]);
        assert_eq!(bank.read_02(0, 2).unwrap(), vec![false; 2]);
        assert_eq!(bank.read_03(0, 2).unwrap(), vec![0, 0]);
        assert_eq!(bank.read_04(0, 2).unwrap(), vec![0, 0]);
    }

    #[test]
    fn test_input_tables() {
        let bank = ModbusRegisterBank::new();
        bank.set_discrete_input(7, true).unwrap();
        bank.set_input_register(7, 1234).unwrap();
        assert_eq!(bank.read_02(6, 2).unwrap(), vec![false, true]);
        assert_eq!(bank.read_04(7, 1).unwrap(), vec![1234]);
    }

    // The very last address is usable; only ranges running past it are rejected.
    #[test]
    fn test_last_address_is_accessible() {
        let bank = ModbusRegisterBank::new();
        bank.write_06(u16::MAX, 9).unwrap();
        assert_eq!(bank.read_03(u16::MAX, 1).unwrap(), vec![9]);
        bank.write_0f(u16::MAX - 1, &[true, true]).unwrap();
        assert_eq!(bank.read_01(u16::MAX - 1, 2).unwrap(), vec![true, true]);
    }

    #[test]
    fn test_empty_multi_write_is_rejected() {
        let bank = ModbusRegisterBank::new();
        assert!(bank.write_0f(0, &[]).is_err());
        assert!(bank.write_10(0, &[]).is_err());
    }

    // Clones share the underlying tables, and stats count explicit entries.
    #[test]
    fn test_clone_shares_state_and_stats() {
        let bank = ModbusRegisterBank::new();
        let handle = bank.clone();
        handle.write_05(1, true).unwrap();
        handle.write_10(0, &[1, 2]).unwrap();
        let stats = bank.get_stats();
        assert_eq!(stats.coils_count, 1);
        assert_eq!(stats.discrete_inputs_count, 0);
        assert_eq!(stats.holding_registers_count, 2);
        assert_eq!(stats.input_registers_count, 0);
    }
}