use std::collections::HashMap;
use parking_lot::RwLock;
use rusty_modbus_types::{DiagnosticSubFunction, ExceptionCode};
use crate::file_record::{self, MAX_RECORD_NUMBER, MIN_FILE_NUMBER, RECORD_COUNT};
use super::{
DataStore, MAX_DIAGNOSTIC_RESPONSE_DATA_LEN, MAX_FILE_RECORD_REGISTERS, MAX_SERVER_ID_BYTES,
bits::BitTable, pack_registers_be, validate_packed_coils, validate_register_values_be,
};
/// Largest number of entries any single data table may be configured with.
///
/// Modbus addresses are 16-bit, so a table covering every address `0..=u16::MAX`
/// holds `u16::MAX + 1` (65 536) entries.
pub const MAX_TABLE_SIZE: usize = (u16::MAX as usize) + 1;
/// Sizes of the four Modbus data tables allocated by [`InMemoryStore`].
///
/// Each count is the number of addressable entries, starting at address 0. Every count
/// must be at most [`MAX_TABLE_SIZE`]; [`StoreConfig::validate`] enforces this and
/// [`InMemoryStore::try_new`] calls it before allocating anything.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StoreConfig {
    /// Number of coil bits to allocate.
    pub coil_count: usize,
    /// Number of discrete-input bits to allocate.
    pub discrete_input_count: usize,
    /// Number of 16-bit holding registers to allocate.
    pub holding_register_count: usize,
    /// Number of 16-bit input registers to allocate.
    pub input_register_count: usize,
}
impl Default for StoreConfig {
fn default() -> Self {
Self {
coil_count: 65536,
discrete_input_count: 65536,
holding_register_count: 65536,
input_register_count: 65536,
}
}
}
impl StoreConfig {
    /// Checks that every configured table size fits the Modbus address space.
    ///
    /// Tables are checked in this order: coils, discrete inputs, holding registers,
    /// input registers. Only the first offending table is reported.
    ///
    /// # Errors
    ///
    /// Returns [`StoreError::TableTooLarge`] for the first table whose count exceeds
    /// [`MAX_TABLE_SIZE`].
    pub fn validate(&self) -> Result<(), StoreError> {
        let tables = [
            ("coils", self.coil_count),
            ("discrete_inputs", self.discrete_input_count),
            ("holding_registers", self.holding_register_count),
            ("input_registers", self.input_register_count),
        ];
        for (table, count) in tables {
            validate_table_size(table, count)?;
        }
        Ok(())
    }
}
/// Errors reported while configuring an [`InMemoryStore`] or seeding its contents through
/// the `set_*` setup methods.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum StoreError {
    /// A configured table length is larger than [`MAX_TABLE_SIZE`].
    #[error("{table} table size {count} exceeds Modbus address space ({max})")]
    TableTooLarge {
        /// Name of the offending table (e.g. `"coils"`).
        table: &'static str,
        /// The requested number of entries.
        count: usize,
        /// The limit that was exceeded.
        max: usize,
    },
    /// A setup address does not index an existing entry of the configured table
    /// (`address >= len`).
    #[error("{table} address {address} is outside configured table size {len}")]
    AddressOutOfRange {
        /// Name of the table being written.
        table: &'static str,
        /// The rejected address.
        address: u16,
        /// The configured length of that table.
        len: usize,
    },
    /// A file number is below the minimum accepted file number (`MIN_FILE_NUMBER`).
    #[error("file number {file_number} is outside Modbus file range ({minimum}..=65535)")]
    FileNumberOutOfRange {
        /// The rejected file number.
        file_number: u16,
        /// The smallest accepted file number.
        minimum: u16,
    },
    /// A record number is not below `RECORD_COUNT`.
    #[error("file record {record_number} is outside Modbus file record range (0..={maximum})")]
    FileRecordOutOfRange {
        /// The rejected record number.
        record_number: u16,
        /// The largest accepted record number (`MAX_RECORD_NUMBER`).
        maximum: u16,
    },
}
/// In-memory implementation of [`DataStore`].
///
/// Every table sits behind its own `parking_lot::RwLock`, so all accessors and setup
/// methods take `&self`. Each `DataStore` method locks only the table it touches.
pub struct InMemoryStore {
    /// Coil bits; length fixed at construction from `StoreConfig::coil_count`.
    coils: RwLock<BitTable>,
    /// Discrete-input bits; length fixed from `StoreConfig::discrete_input_count`.
    discrete_inputs: RwLock<BitTable>,
    /// Holding registers, one `u16` per address; length fixed from
    /// `StoreConfig::holding_register_count`.
    holding_registers: RwLock<Vec<u16>>,
    /// Input registers, one `u16` per address; length fixed from
    /// `StoreConfig::input_register_count`.
    input_registers: RwLock<Vec<u16>>,
    /// File records keyed by file number. Each vector is indexed by record number and is
    /// zero-extended when a write lands past its current end.
    files: RwLock<HashMap<u16, Vec<u16>>>,
    /// FIFO queue contents keyed by the address given to `set_fifo_queue` /
    /// `read_fifo_queue`.
    fifo_queues: RwLock<HashMap<u16, Vec<u16>>>,
    /// Byte returned by `read_exception_status`; replaced via `set_exception_status`.
    exception_status: RwLock<u8>,
    /// Bytes returned by `report_server_id` / `append_server_id`. `try_new` initialises
    /// this to `b"rusty-modbus\xFF"`.
    // NOTE(review): the trailing 0xFF presumably is the Modbus run-indicator status byte
    // (ON) — confirm.
    server_id: RwLock<Vec<u8>>,
}
impl InMemoryStore {
#[must_use]
pub fn new(config: StoreConfig) -> Self {
Self::try_new(config).expect("StoreConfig should fit the Modbus address space")
}
pub fn try_new(config: StoreConfig) -> Result<Self, StoreError> {
config.validate()?;
Ok(Self {
coils: RwLock::new(BitTable::new(config.coil_count)),
discrete_inputs: RwLock::new(BitTable::new(config.discrete_input_count)),
holding_registers: RwLock::new(vec![0u16; config.holding_register_count]),
input_registers: RwLock::new(vec![0u16; config.input_register_count]),
files: RwLock::new(HashMap::new()),
fifo_queues: RwLock::new(HashMap::new()),
exception_status: RwLock::new(0),
server_id: RwLock::new(b"rusty-modbus\xFF".to_vec()),
})
}
pub fn set_input_register(&self, address: u16, value: u16) -> Result<(), StoreError> {
let mut regs = self.input_registers.write();
let index = check_setup_address("input_registers", address, regs.len())?;
regs[index] = value;
Ok(())
}
pub fn set_discrete_input(&self, address: u16, value: bool) -> Result<(), StoreError> {
let mut inputs = self.discrete_inputs.write();
let index = check_setup_address("discrete_inputs", address, inputs.len())?;
inputs.set(index, value);
Ok(())
}
pub fn set_holding_register(&self, address: u16, value: u16) -> Result<(), StoreError> {
let mut regs = self.holding_registers.write();
let index = check_setup_address("holding_registers", address, regs.len())?;
regs[index] = value;
Ok(())
}
pub fn set_coil(&self, address: u16, value: bool) -> Result<(), StoreError> {
let mut coils = self.coils.write();
let index = check_setup_address("coils", address, coils.len())?;
coils.set(index, value);
Ok(())
}
pub fn set_file_record(
&self,
file_number: u16,
record_number: u16,
value: u16,
) -> Result<(), StoreError> {
check_setup_file_record(file_number, record_number)?;
let mut files = self.files.write();
let file = files.entry(file_number).or_default();
let idx = usize::from(record_number);
if idx >= file.len() {
file.resize(idx + 1, 0);
}
file[idx] = value;
Ok(())
}
pub fn set_fifo_queue(&self, address: u16, values: Vec<u16>) {
self.fifo_queues.write().insert(address, values);
}
pub fn set_exception_status(&self, status: u8) {
*self.exception_status.write() = status;
}
pub fn set_server_id(&self, data: Vec<u8>) {
*self.server_id.write() = data;
}
}
impl std::fmt::Debug for InMemoryStore {
    /// Formats the configured size of each of the four data tables.
    ///
    /// File, FIFO, exception-status and server-id contents are omitted, which is why the
    /// output ends with `..`. Each length is read under its own short-lived read lock, so
    /// no two table locks are held at the same time.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let coils = self.coils.read().len();
        let discrete_inputs = self.discrete_inputs.read().len();
        let holding_registers = self.holding_registers.read().len();
        let input_registers = self.input_registers.read().len();
        f.debug_struct("InMemoryStore")
            .field("coils", &coils)
            .field("discrete_inputs", &discrete_inputs)
            .field("holding_registers", &holding_registers)
            .field("input_registers", &input_registers)
            .finish_non_exhaustive()
    }
}
/// Verifies that the span `address .. address + quantity` fits inside a table of `max`
/// entries.
///
/// # Errors
///
/// Returns [`ExceptionCode::IllegalDataAddress`] when the end of the span overflows
/// `usize` or lies past `max`.
fn check_range(address: u16, quantity: usize, max: usize) -> Result<(), ExceptionCode> {
    match usize::from(address).checked_add(quantity) {
        Some(end) if end <= max => Ok(()),
        _ => Err(ExceptionCode::IllegalDataAddress),
    }
}
/// Accepts `count` when it is at most [`MAX_TABLE_SIZE`].
///
/// # Errors
///
/// Returns [`StoreError::TableTooLarge`] carrying `table`, `count` and the limit when
/// `count` is larger.
fn validate_table_size(table: &'static str, count: usize) -> Result<(), StoreError> {
    if count <= MAX_TABLE_SIZE {
        Ok(())
    } else {
        Err(StoreError::TableTooLarge {
            table,
            count,
            max: MAX_TABLE_SIZE,
        })
    }
}
/// Converts `address` to an index into a table of length `len`.
///
/// # Errors
///
/// Returns [`StoreError::AddressOutOfRange`] when `address` is not below `len`.
fn check_setup_address(table: &'static str, address: u16, len: usize) -> Result<usize, StoreError> {
    match usize::from(address) {
        index if index < len => Ok(index),
        _ => Err(StoreError::AddressOutOfRange {
            table,
            address,
            len,
        }),
    }
}
/// Validates a file/record pair passed to `InMemoryStore::set_file_record`.
///
/// The file number is checked before the record number.
///
/// # Errors
///
/// - [`StoreError::FileNumberOutOfRange`] when `file_number < MIN_FILE_NUMBER`.
/// - [`StoreError::FileRecordOutOfRange`] when `record_number` is not below `RECORD_COUNT`.
fn check_setup_file_record(file_number: u16, record_number: u16) -> Result<(), StoreError> {
    if file_number < MIN_FILE_NUMBER {
        Err(StoreError::FileNumberOutOfRange {
            file_number,
            minimum: MIN_FILE_NUMBER,
        })
    } else if usize::from(record_number) >= RECORD_COUNT {
        Err(StoreError::FileRecordOutOfRange {
            record_number,
            maximum: MAX_RECORD_NUMBER,
        })
    } else {
        Ok(())
    }
}
impl DataStore for InMemoryStore {
async fn read_coils(
&self,
address: u16,
quantity: u16,
buf: &mut [bool],
) -> Result<usize, ExceptionCode> {
let coils = self.coils.read();
coils.read_bits(address, quantity, buf)
}
async fn read_coils_packed(
&self,
address: u16,
quantity: u16,
out: &mut [u8],
) -> Result<usize, ExceptionCode> {
let coils = self.coils.read();
coils.read_packed(address, quantity, out)
}
async fn write_coil(&self, address: u16, value: bool) -> Result<(), ExceptionCode> {
let mut coils = self.coils.write();
check_range(address, 1, coils.len())?;
coils.set(usize::from(address), value);
Ok(())
}
async fn write_coils(&self, address: u16, values: &[bool]) -> Result<(), ExceptionCode> {
let mut coils = self.coils.write();
coils.write_bits(address, values)
}
async fn write_coils_packed(
&self,
address: u16,
quantity: u16,
packed_values: &[u8],
) -> Result<(), ExceptionCode> {
let quantity = validate_packed_coils(quantity, packed_values)?;
let mut coils = self.coils.write();
coils.write_packed(address, quantity, packed_values)
}
async fn read_discrete_inputs(
&self,
address: u16,
quantity: u16,
buf: &mut [bool],
) -> Result<usize, ExceptionCode> {
let inputs = self.discrete_inputs.read();
inputs.read_bits(address, quantity, buf)
}
async fn read_discrete_inputs_packed(
&self,
address: u16,
quantity: u16,
out: &mut [u8],
) -> Result<usize, ExceptionCode> {
let inputs = self.discrete_inputs.read();
inputs.read_packed(address, quantity, out)
}
async fn read_holding_registers(
&self,
address: u16,
quantity: u16,
buf: &mut [u16],
) -> Result<usize, ExceptionCode> {
let regs = self.holding_registers.read();
check_range(address, usize::from(quantity), regs.len())?;
let start = address as usize;
let qty = quantity as usize;
buf[..qty].copy_from_slice(®s[start..start + qty]);
Ok(qty)
}
async fn read_holding_registers_be(
&self,
address: u16,
quantity: u16,
out: &mut [u8],
) -> Result<usize, ExceptionCode> {
let regs = self.holding_registers.read();
check_range(address, usize::from(quantity), regs.len())?;
let start = address as usize;
let qty = quantity as usize;
pack_registers_be(®s[start..start + qty], out)?;
Ok(qty)
}
async fn write_register(&self, address: u16, value: u16) -> Result<(), ExceptionCode> {
let mut regs = self.holding_registers.write();
check_range(address, 1, regs.len())?;
regs[address as usize] = value;
Ok(())
}
async fn write_registers(&self, address: u16, values: &[u16]) -> Result<(), ExceptionCode> {
let mut regs = self.holding_registers.write();
check_range(address, values.len(), regs.len())?;
let start = address as usize;
regs[start..start + values.len()].copy_from_slice(values);
Ok(())
}
async fn write_registers_be(
&self,
address: u16,
quantity: u16,
value_bytes: &[u8],
) -> Result<(), ExceptionCode> {
let quantity = validate_register_values_be(quantity, value_bytes)?;
let mut regs = self.holding_registers.write();
check_range(address, quantity, regs.len())?;
let start = address as usize;
for (slot, chunk) in regs[start..start + quantity]
.iter_mut()
.zip(value_bytes.chunks_exact(2))
{
*slot = u16::from_be_bytes([chunk[0], chunk[1]]);
}
Ok(())
}
async fn read_input_registers(
&self,
address: u16,
quantity: u16,
buf: &mut [u16],
) -> Result<usize, ExceptionCode> {
let regs = self.input_registers.read();
check_range(address, usize::from(quantity), regs.len())?;
let start = address as usize;
let qty = quantity as usize;
buf[..qty].copy_from_slice(®s[start..start + qty]);
Ok(qty)
}
async fn read_input_registers_be(
&self,
address: u16,
quantity: u16,
out: &mut [u8],
) -> Result<usize, ExceptionCode> {
let regs = self.input_registers.read();
check_range(address, usize::from(quantity), regs.len())?;
let start = address as usize;
let qty = quantity as usize;
pack_registers_be(®s[start..start + qty], out)?;
Ok(qty)
}
async fn read_file_record(
&self,
file_number: u16,
record_number: u16,
record_length: u16,
buf: &mut [u16],
) -> Result<usize, ExceptionCode> {
file_record::validate_range(file_number, record_number, usize::from(record_length))?;
let files = self.files.read();
let file = files
.get(&file_number)
.ok_or(ExceptionCode::IllegalDataAddress)?;
let start = usize::from(record_number);
let len = usize::from(record_length);
let end = start
.checked_add(len)
.ok_or(ExceptionCode::IllegalDataAddress)?;
if end > file.len() || len > buf.len() {
return Err(ExceptionCode::IllegalDataAddress);
}
buf[..len].copy_from_slice(&file[start..end]);
Ok(len)
}
async fn read_file_record_be(
&self,
file_number: u16,
record_number: u16,
record_length: u16,
out: &mut [u8],
) -> Result<usize, ExceptionCode> {
let len = usize::from(record_length);
file_record::validate_range(file_number, record_number, len)?;
if len > MAX_FILE_RECORD_REGISTERS {
return Err(ExceptionCode::IllegalDataAddress);
}
let files = self.files.read();
let file = files
.get(&file_number)
.ok_or(ExceptionCode::IllegalDataAddress)?;
let start = usize::from(record_number);
let end = start
.checked_add(len)
.ok_or(ExceptionCode::IllegalDataAddress)?;
if end > file.len() {
return Err(ExceptionCode::IllegalDataAddress);
}
pack_registers_be(&file[start..end], out)?;
Ok(len)
}
async fn write_file_record(
&self,
file_number: u16,
record_number: u16,
values: &[u16],
) -> Result<(), ExceptionCode> {
file_record::validate_range(file_number, record_number, values.len())?;
let mut files = self.files.write();
let file = files.entry(file_number).or_default();
let start = usize::from(record_number);
let end = start
.checked_add(values.len())
.ok_or(ExceptionCode::IllegalDataAddress)?;
if end > file.len() {
file.resize(end, 0);
}
file[start..end].copy_from_slice(values);
Ok(())
}
async fn write_file_record_be(
&self,
file_number: u16,
record_number: u16,
record_length: u16,
value_bytes: &[u8],
) -> Result<(), ExceptionCode> {
let len = usize::from(record_length);
if value_bytes.len() != len * 2 {
return Err(ExceptionCode::IllegalDataValue);
}
file_record::validate_range(file_number, record_number, len)?;
let mut files = self.files.write();
let file = files.entry(file_number).or_default();
let start = usize::from(record_number);
let end = start
.checked_add(len)
.ok_or(ExceptionCode::IllegalDataAddress)?;
if end > file.len() {
file.resize(end, 0);
}
for (slot, chunk) in file[start..end].iter_mut().zip(value_bytes.chunks_exact(2)) {
*slot = u16::from_be_bytes([chunk[0], chunk[1]]);
}
Ok(())
}
async fn read_fifo_queue(&self, address: u16) -> Result<Vec<u16>, ExceptionCode> {
self.fifo_queues
.read()
.get(&address)
.cloned()
.ok_or(ExceptionCode::IllegalDataAddress)
}
async fn read_fifo_queue_be(
&self,
address: u16,
out: &mut [u8],
) -> Result<usize, ExceptionCode> {
let queues = self.fifo_queues.read();
let values = queues
.get(&address)
.ok_or(ExceptionCode::IllegalDataAddress)?;
if values.len() > usize::from(rusty_modbus_types::MAX_FIFO_VALUES) {
return Err(ExceptionCode::IllegalDataValue);
}
pack_registers_be(values, out)?;
Ok(values.len())
}
async fn read_exception_status(&self) -> Result<u8, ExceptionCode> {
Ok(*self.exception_status.read())
}
async fn report_server_id(&self) -> Result<Vec<u8>, ExceptionCode> {
Ok(self.server_id.read().clone())
}
async fn append_server_id(&self, out: &mut Vec<u8>) -> Result<usize, ExceptionCode> {
let server_id = self.server_id.read();
if server_id.len() > MAX_SERVER_ID_BYTES {
return Err(ExceptionCode::ServerDeviceFailure);
}
out.extend_from_slice(&server_id);
Ok(server_id.len())
}
async fn append_diagnostic_response(
&self,
sub_function: DiagnosticSubFunction,
data: &[u8],
out: &mut Vec<u8>,
) -> Result<Option<usize>, ExceptionCode> {
match sub_function {
DiagnosticSubFunction::ReturnQueryData => {
if data.len() > MAX_DIAGNOSTIC_RESPONSE_DATA_LEN {
return Err(ExceptionCode::ServerDeviceFailure);
}
out.extend_from_slice(data);
Ok(Some(data.len()))
}
_ => Err(ExceptionCode::IllegalFunction),
}
}
}