use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use crate::error::{ModbusError, ModbusResult};
/// The four Modbus data tables a request can address.
///
/// Coils and discrete inputs hold single bits; holding and input registers
/// hold 16-bit words (see `is_bit_type`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RegisterType {
    /// Single-bit entries that have Modbus write function codes (read with 0x01).
    Coil,
    /// Single-bit entries with no Modbus write function code (read with 0x02).
    DiscreteInput,
    /// 16-bit words that have Modbus write function codes (read with 0x03).
    HoldingRegister,
    /// 16-bit words with no Modbus write function code (read with 0x04).
    InputRegister,
}
impl RegisterType {
pub fn read_function(&self) -> u8 {
match self {
Self::Coil => 0x01,
Self::DiscreteInput => 0x02,
Self::HoldingRegister => 0x03,
Self::InputRegister => 0x04,
}
}
pub fn write_function(&self) -> Option<u8> {
match self {
Self::Coil => Some(0x05), Self::HoldingRegister => Some(0x06), Self::DiscreteInput | Self::InputRegister => None,
}
}
pub fn write_multiple_function(&self) -> Option<u8> {
match self {
Self::Coil => Some(0x0F), Self::HoldingRegister => Some(0x10), Self::DiscreteInput | Self::InputRegister => None,
}
}
pub fn is_writable(&self) -> bool {
matches!(self, Self::Coil | Self::HoldingRegister)
}
pub fn is_bit_type(&self) -> bool {
matches!(self, Self::Coil | Self::DiscreteInput)
}
pub fn max_read_quantity(&self) -> u16 {
if self.is_bit_type() {
2000 } else {
125 }
}
pub fn max_write_quantity(&self) -> u16 {
if self.is_bit_type() {
1968 } else {
123 }
}
}
/// In-memory backing store for the four Modbus data tables.
///
/// Each table sits behind its own `RwLock` and every method on the store
/// takes `&self`, so tables are read and updated through a shared reference.
/// NOTE(review): presumably this is so one store can be shared by several
/// connections/threads; confirm with callers.
pub struct RegisterStore {
    /// Coil bits; holds `coil_count` entries.
    coils: RwLock<Vec<bool>>,
    /// Discrete-input bits; holds `discrete_input_count` entries.
    discrete_inputs: RwLock<Vec<bool>>,
    /// Holding-register words; holds `holding_register_count` entries.
    holding_registers: RwLock<Vec<u16>>,
    /// Input-register words; holds `input_register_count` entries.
    input_registers: RwLock<Vec<u16>>,
    // The counts below are fixed at construction and mirror the vector
    // lengths. Bounds checks read them without taking any lock.
    /// Number of coils.
    coil_count: u16,
    /// Number of discrete inputs.
    discrete_input_count: u16,
    /// Number of holding registers.
    holding_register_count: u16,
    /// Number of input registers.
    input_register_count: u16,
}
impl RegisterStore {
pub fn new(
coils: u16,
discrete_inputs: u16,
holding_registers: u16,
input_registers: u16,
) -> Self {
Self {
coils: RwLock::new(vec![false; coils as usize]),
discrete_inputs: RwLock::new(vec![false; discrete_inputs as usize]),
holding_registers: RwLock::new(vec![0u16; holding_registers as usize]),
input_registers: RwLock::new(vec![0u16; input_registers as usize]),
coil_count: coils,
discrete_input_count: discrete_inputs,
holding_register_count: holding_registers,
input_register_count: input_registers,
}
}
pub fn with_defaults() -> Self {
Self::new(10000, 10000, 10000, 10000)
}
pub fn count(&self, reg_type: RegisterType) -> u16 {
match reg_type {
RegisterType::Coil => self.coil_count,
RegisterType::DiscreteInput => self.discrete_input_count,
RegisterType::HoldingRegister => self.holding_register_count,
RegisterType::InputRegister => self.input_register_count,
}
}
fn validate(&self, reg_type: RegisterType, address: u16, quantity: u16) -> ModbusResult<()> {
let max = self.count(reg_type);
if address >= max {
return Err(ModbusError::invalid_address(address, max - 1));
}
if quantity == 0 || address + quantity > max {
return Err(ModbusError::invalid_quantity(quantity, max - address));
}
Ok(())
}
pub fn read_coils(&self, address: u16, quantity: u16) -> ModbusResult<Vec<bool>> {
self.validate(RegisterType::Coil, address, quantity)?;
let coils = self.coils.read();
let start = address as usize;
let end = start + quantity as usize;
Ok(coils[start..end].to_vec())
}
pub fn write_coil(&self, address: u16, value: bool) -> ModbusResult<()> {
self.validate(RegisterType::Coil, address, 1)?;
let mut coils = self.coils.write();
coils[address as usize] = value;
Ok(())
}
pub fn write_coils(&self, address: u16, values: &[bool]) -> ModbusResult<()> {
self.validate(RegisterType::Coil, address, values.len() as u16)?;
let mut coils = self.coils.write();
for (i, &value) in values.iter().enumerate() {
coils[address as usize + i] = value;
}
Ok(())
}
pub fn read_discrete_inputs(&self, address: u16, quantity: u16) -> ModbusResult<Vec<bool>> {
self.validate(RegisterType::DiscreteInput, address, quantity)?;
let inputs = self.discrete_inputs.read();
let start = address as usize;
let end = start + quantity as usize;
Ok(inputs[start..end].to_vec())
}
pub fn set_discrete_input(&self, address: u16, value: bool) -> ModbusResult<()> {
self.validate(RegisterType::DiscreteInput, address, 1)?;
let mut inputs = self.discrete_inputs.write();
inputs[address as usize] = value;
Ok(())
}
pub fn read_holding_registers(&self, address: u16, quantity: u16) -> ModbusResult<Vec<u16>> {
self.validate(RegisterType::HoldingRegister, address, quantity)?;
let registers = self.holding_registers.read();
let start = address as usize;
let end = start + quantity as usize;
Ok(registers[start..end].to_vec())
}
pub fn write_holding_register(&self, address: u16, value: u16) -> ModbusResult<()> {
self.validate(RegisterType::HoldingRegister, address, 1)?;
let mut registers = self.holding_registers.write();
registers[address as usize] = value;
Ok(())
}
pub fn write_holding_registers(&self, address: u16, values: &[u16]) -> ModbusResult<()> {
self.validate(RegisterType::HoldingRegister, address, values.len() as u16)?;
let mut registers = self.holding_registers.write();
for (i, &value) in values.iter().enumerate() {
registers[address as usize + i] = value;
}
Ok(())
}
pub fn read_input_registers(&self, address: u16, quantity: u16) -> ModbusResult<Vec<u16>> {
self.validate(RegisterType::InputRegister, address, quantity)?;
let registers = self.input_registers.read();
let start = address as usize;
let end = start + quantity as usize;
Ok(registers[start..end].to_vec())
}
pub fn set_input_register(&self, address: u16, value: u16) -> ModbusResult<()> {
self.validate(RegisterType::InputRegister, address, 1)?;
let mut registers = self.input_registers.write();
registers[address as usize] = value;
Ok(())
}
pub fn set_input_registers(&self, address: u16, values: &[u16]) -> ModbusResult<()> {
self.validate(RegisterType::InputRegister, address, values.len() as u16)?;
let mut registers = self.input_registers.write();
for (i, &value) in values.iter().enumerate() {
registers[address as usize + i] = value;
}
Ok(())
}
pub fn mask_write_holding_register(
&self,
address: u16,
and_mask: u16,
or_mask: u16,
) -> ModbusResult<u16> {
self.validate(RegisterType::HoldingRegister, address, 1)?;
let mut registers = self.holding_registers.write();
let current = registers[address as usize];
let result = (current & and_mask) | (or_mask & !and_mask);
registers[address as usize] = result;
Ok(result)
}
pub fn read_bytes(
&self,
reg_type: RegisterType,
address: u16,
byte_count: usize,
) -> ModbusResult<Vec<u8>> {
let register_count = byte_count.div_ceil(2);
let registers = match reg_type {
RegisterType::HoldingRegister => {
self.read_holding_registers(address, register_count as u16)?
}
RegisterType::InputRegister => {
self.read_input_registers(address, register_count as u16)?
}
_ => return Err(ModbusError::InvalidFunction(reg_type.read_function())),
};
let mut bytes = Vec::with_capacity(byte_count);
for reg in registers {
bytes.extend_from_slice(®.to_be_bytes());
}
bytes.truncate(byte_count);
Ok(bytes)
}
pub fn write_bytes(
&self,
reg_type: RegisterType,
address: u16,
bytes: &[u8],
) -> ModbusResult<()> {
if !reg_type.is_writable() || reg_type.is_bit_type() {
return Err(ModbusError::InvalidFunction(0));
}
let mut registers = Vec::new();
for chunk in bytes.chunks(2) {
let value = if chunk.len() == 2 {
u16::from_be_bytes([chunk[0], chunk[1]])
} else {
u16::from_be_bytes([chunk[0], 0])
};
registers.push(value);
}
self.write_holding_registers(address, ®isters)
}
pub fn reset(&self) {
self.coils.write().fill(false);
self.discrete_inputs.write().fill(false);
self.holding_registers.write().fill(0);
self.input_registers.write().fill(0);
}
}
impl Default for RegisterStore {
fn default() -> Self {
Self::with_defaults()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_register_store_creation() {
        let store = RegisterStore::new(100, 100, 100, 100);
        assert_eq!(store.count(RegisterType::Coil), 100);
        assert_eq!(store.count(RegisterType::HoldingRegister), 100);
    }

    #[test]
    fn test_coil_operations() {
        let store = RegisterStore::with_defaults();
        store.write_coil(0, true).unwrap();
        assert_eq!(store.read_coils(0, 1).unwrap(), vec![true]);
        store.write_coils(10, &[true, false, true]).unwrap();
        assert_eq!(store.read_coils(10, 3).unwrap(), vec![true, false, true]);
    }

    #[test]
    fn test_holding_register_operations() {
        let store = RegisterStore::with_defaults();
        store.write_holding_register(0, 12345).unwrap();
        assert_eq!(store.read_holding_registers(0, 1).unwrap(), vec![12345]);
        store.write_holding_registers(10, &[100, 200, 300]).unwrap();
        assert_eq!(store.read_holding_registers(10, 3).unwrap(), vec![100, 200, 300]);
    }

    #[test]
    fn test_input_tables() {
        let store = RegisterStore::new(4, 4, 4, 4);
        store.set_discrete_input(1, true).unwrap();
        assert_eq!(store.read_discrete_inputs(0, 2).unwrap(), vec![false, true]);
        store.set_input_registers(1, &[5, 6]).unwrap();
        store.set_input_register(3, 7).unwrap();
        assert_eq!(store.read_input_registers(0, 4).unwrap(), vec![0, 5, 6, 7]);
    }

    #[test]
    fn test_invalid_address() {
        let store = RegisterStore::new(100, 100, 100, 100);
        assert!(store.read_coils(100, 1).is_err());
        assert!(store.read_coils(99, 2).is_err());
        // Zero-length requests and writes are rejected as well.
        assert!(store.read_coils(0, 0).is_err());
        assert!(store.write_coils(0, &[]).is_err());
        assert!(store.write_holding_registers(98, &[1, 2, 3]).is_err());
    }

    #[test]
    fn test_mask_write_holding_register() {
        let store = RegisterStore::new(10, 10, 10, 10);
        store.write_holding_register(4, 0x0012).unwrap();
        // (0x12 & 0xF2) | (0x25 & !0xF2) == 0x17
        assert_eq!(store.mask_write_holding_register(4, 0x00F2, 0x0025).unwrap(), 0x0017);
        assert_eq!(store.read_holding_registers(4, 1).unwrap(), vec![0x0017]);
    }

    #[test]
    fn test_float_operations() {
        // Exactly representable value; a PI-like literal such as 3.14159 trips
        // the deny-by-default `clippy::approx_constant` lint.
        let value: f32 = -1234.5;
        store_and_check_f32(value);
    }

    /// Writes `value` as big-endian bytes and checks it reads back bit-for-bit.
    fn store_and_check_f32(value: f32) {
        let store = RegisterStore::with_defaults();
        store
            .write_bytes(RegisterType::HoldingRegister, 0, &value.to_be_bytes())
            .unwrap();
        let read = store
            .read_bytes(RegisterType::HoldingRegister, 0, 4)
            .unwrap();
        let read_value = f32::from_be_bytes([read[0], read[1], read[2], read[3]]);
        // The byte round trip is lossless, so compare bit patterns exactly.
        assert_eq!(read_value.to_bits(), value.to_bits());
    }

    #[test]
    fn test_odd_byte_counts() {
        let store = RegisterStore::new(10, 10, 10, 10);
        store
            .write_bytes(RegisterType::HoldingRegister, 0, &[0xAB, 0xCD, 0xEF])
            .unwrap();
        assert_eq!(store.read_holding_registers(0, 2).unwrap(), vec![0xABCD, 0xEF00]);
        assert_eq!(
            store.read_bytes(RegisterType::HoldingRegister, 0, 3).unwrap(),
            vec![0xAB, 0xCD, 0xEF]
        );
    }

    #[test]
    fn test_byte_access_rejects_unsupported_tables() {
        let store = RegisterStore::new(10, 10, 10, 10);
        assert!(store.read_bytes(RegisterType::Coil, 0, 2).is_err());
        assert!(store.read_bytes(RegisterType::DiscreteInput, 0, 2).is_err());
        assert!(store.write_bytes(RegisterType::Coil, 0, &[1, 2]).is_err());
        assert!(store.write_bytes(RegisterType::InputRegister, 0, &[1, 2]).is_err());
    }

    #[test]
    fn test_reset() {
        let store = RegisterStore::new(4, 4, 4, 4);
        store.write_coil(0, true).unwrap();
        store.set_discrete_input(0, true).unwrap();
        store.write_holding_register(0, 1).unwrap();
        store.set_input_register(0, 1).unwrap();
        store.reset();
        assert_eq!(store.read_coils(0, 4).unwrap(), vec![false; 4]);
        assert_eq!(store.read_discrete_inputs(0, 4).unwrap(), vec![false; 4]);
        assert_eq!(store.read_holding_registers(0, 4).unwrap(), vec![0; 4]);
        assert_eq!(store.read_input_registers(0, 4).unwrap(), vec![0; 4]);
    }
}