use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use dashmap::DashMap;
use tracing::{debug, instrument, trace, warn};
use crate::error::{ModbusError, ModbusResult};
use crate::registers::SparseRegisterStore;
use crate::types::{RegisterConverter, WordOrder};
use super::config::{BroadcastMode, UnitConfig, UnitManagerConfig};
/// Runtime state for one unit (slave) held by the multi-unit manager.
///
/// Bundles the unit's configuration with its own register store, a register
/// converter and per-unit request counters. The counters are atomics, so they
/// can be bumped through a shared reference.
#[derive(Debug)]
pub struct UnitInfo {
/// Identifier the unit was registered under.
unit_id: u8,
/// Configuration supplied when the unit was created.
config: UnitConfig,
/// Register storage owned by this unit; handed out as a shared `Arc`.
registers: Arc<SparseRegisterStore>,
/// Converter built from the word order resolved at construction time.
converter: RegisterConverter,
/// Instant at which this `UnitInfo` was constructed.
created_at: Instant,
/// Number of read requests recorded against this unit.
read_count: AtomicU64,
/// Number of write requests recorded against this unit.
write_count: AtomicU64,
/// Number of errors recorded against this unit.
error_count: AtomicU64,
}
impl UnitInfo {
    /// Builds the runtime state for a single unit.
    ///
    /// The word order is resolved through `UnitConfig::effective_word_order`
    /// with `default_word_order` as the fallback argument. The register store
    /// uses the unit's own `register_config` when one is set and a clone of
    /// `default_register_config` otherwise. All counters start at zero and
    /// `created_at` is stamped with the current instant.
    fn new(
        unit_id: u8,
        config: UnitConfig,
        default_word_order: WordOrder,
        default_register_config: &crate::registers::RegisterStoreConfig,
    ) -> Self {
        let resolved_order = config.effective_word_order(default_word_order);
        let store_config = match config.register_config.clone() {
            Some(custom) => custom,
            None => default_register_config.clone(),
        };
        let registers = Arc::new(SparseRegisterStore::new(store_config));
        let converter = RegisterConverter::new(resolved_order);

        Self {
            unit_id,
            config,
            registers,
            converter,
            created_at: Instant::now(),
            read_count: AtomicU64::new(0),
            write_count: AtomicU64::new(0),
            error_count: AtomicU64::new(0),
        }
    }

    /// Identifier this unit was registered under.
    #[inline]
    pub fn unit_id(&self) -> u8 {
        self.unit_id
    }

    /// Full configuration of this unit.
    #[inline]
    pub fn config(&self) -> &UnitConfig {
        &self.config
    }

    /// Name taken from the unit's configuration.
    #[inline]
    pub fn name(&self) -> &str {
        self.config.name.as_str()
    }

    /// Shared handle to the unit's register store.
    #[inline]
    pub fn registers(&self) -> &Arc<SparseRegisterStore> {
        &self.registers
    }

    /// Converter created for this unit's word order.
    #[inline]
    pub fn converter(&self) -> &RegisterConverter {
        &self.converter
    }

    /// Word order reported by the unit's converter.
    #[inline]
    pub fn word_order(&self) -> WordOrder {
        let converter = &self.converter;
        converter.word_order()
    }

    /// Value of the `enabled` flag in the unit's configuration.
    #[inline]
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }

    /// Value of the `broadcast_enabled` flag in the unit's configuration.
    #[inline]
    pub fn broadcast_enabled(&self) -> bool {
        self.config.broadcast_enabled
    }

    /// Configured `response_delay_us` value.
    #[inline]
    pub fn response_delay_us(&self) -> u64 {
        self.config.response_delay_us
    }

    /// Instant at which this unit's state was constructed.
    #[inline]
    pub fn created_at(&self) -> Instant {
        self.created_at
    }

    /// Time elapsed since [`created_at`](Self::created_at).
    #[inline]
    pub fn uptime(&self) -> std::time::Duration {
        let started = self.created_at;
        started.elapsed()
    }

    /// Current read counter (relaxed load).
    #[inline]
    pub fn read_count(&self) -> u64 {
        let counter = &self.read_count;
        counter.load(Ordering::Relaxed)
    }

    /// Current write counter (relaxed load).
    #[inline]
    pub fn write_count(&self) -> u64 {
        let counter = &self.write_count;
        counter.load(Ordering::Relaxed)
    }

    /// Current error counter (relaxed load).
    #[inline]
    pub fn error_count(&self) -> u64 {
        let counter = &self.error_count;
        counter.load(Ordering::Relaxed)
    }

    /// Adds one to the read counter.
    pub(crate) fn record_read(&self) {
        let _ = self.read_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Adds one to the write counter.
    pub(crate) fn record_write(&self) {
        let _ = self.write_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Adds one to the error counter.
    pub(crate) fn record_error(&self) {
        let _ = self.error_count.fetch_add(1, Ordering::Relaxed);
    }
}
/// Registry of units keyed by unit ID, plus manager-wide request counters.
///
/// Units live in a `DashMap`, and every method on the manager takes `&self`.
pub struct MultiUnitManager {
/// Manager-level settings (unit limit, defaults, broadcast mode, auto-create).
config: UnitManagerConfig,
/// Registered units, keyed by unit ID.
units: DashMap<u8, UnitInfo>,
/// Number of read/write requests routed through the manager.
total_requests: AtomicU64,
/// Number of broadcast writes performed.
broadcast_count: AtomicU64,
}
impl MultiUnitManager {
pub fn new(config: UnitManagerConfig) -> Self {
Self {
config,
units: DashMap::new(),
total_requests: AtomicU64::new(0),
broadcast_count: AtomicU64::new(0),
}
}
pub fn with_defaults() -> Self {
Self::new(UnitManagerConfig::default())
}
pub fn config(&self) -> &UnitManagerConfig {
&self.config
}
pub fn unit_count(&self) -> usize {
self.units.len()
}
pub fn unit_ids(&self) -> Vec<u8> {
self.units.iter().map(|entry| *entry.key()).collect()
}
pub fn has_unit(&self, unit_id: u8) -> bool {
self.units.contains_key(&unit_id)
}
#[instrument(skip(self, config), fields(unit_id = unit_id, name = %config.name))]
pub fn add_unit(&self, unit_id: u8, config: UnitConfig) -> ModbusResult<()> {
if unit_id == 0 {
return Err(ModbusError::InvalidUnitId {
unit_id: 0,
reason: "Unit ID 0 is reserved for broadcast".to_string(),
});
}
if self.units.len() >= self.config.max_units {
return Err(ModbusError::UnitLimitReached {
max: self.config.max_units,
});
}
if self.units.contains_key(&unit_id) {
return Err(ModbusError::UnitAlreadyExists { unit_id });
}
let unit_info = UnitInfo::new(
unit_id,
config,
self.config.default_word_order,
&self.config.default_register_config,
);
self.units.insert(unit_id, unit_info);
debug!(unit_id, "Unit added");
Ok(())
}
#[instrument(skip(self))]
pub fn remove_unit(&self, unit_id: u8) -> Option<UnitInfo> {
let removed = self.units.remove(&unit_id).map(|(_, info)| info);
if removed.is_some() {
debug!(unit_id, "Unit removed");
}
removed
}
pub fn get_unit(&self, unit_id: u8) -> Option<dashmap::mapref::one::Ref<'_, u8, UnitInfo>> {
if unit_id == 0 {
return None;
}
if let Some(unit) = self.units.get(&unit_id) {
return Some(unit);
}
if self.config.auto_create_units {
let config = UnitConfig::new(format!("Auto-created Unit {}", unit_id));
if self.add_unit(unit_id, config).is_ok() {
return self.units.get(&unit_id);
}
}
None
}
pub fn get_unit_mut(
&self,
unit_id: u8,
) -> Option<dashmap::mapref::one::RefMut<'_, u8, UnitInfo>> {
if unit_id == 0 {
return None;
}
self.units.get_mut(&unit_id)
}
#[instrument(skip(self, update_fn))]
pub fn update_unit<F>(&self, unit_id: u8, update_fn: F) -> ModbusResult<()>
where
F: FnOnce(&mut UnitConfig),
{
if let Some(mut unit) = self.units.get_mut(&unit_id) {
update_fn(&mut unit.config);
debug!(unit_id, "Unit configuration updated");
Ok(())
} else {
Err(ModbusError::UnitNotFound { unit_id })
}
}
#[instrument(skip(self), level = "trace")]
pub fn read_holding_registers(
&self,
unit_id: u8,
address: u16,
quantity: u16,
) -> ModbusResult<Vec<u16>> {
self.total_requests.fetch_add(1, Ordering::Relaxed);
let unit = self
.get_unit(unit_id)
.ok_or(ModbusError::UnitNotFound { unit_id })?;
if !unit.is_enabled() {
unit.record_error();
return Err(ModbusError::UnitDisabled { unit_id });
}
unit.record_read();
unit.registers().read_holding_registers(address, quantity)
}
#[instrument(skip(self), level = "trace")]
pub fn read_input_registers(
&self,
unit_id: u8,
address: u16,
quantity: u16,
) -> ModbusResult<Vec<u16>> {
self.total_requests.fetch_add(1, Ordering::Relaxed);
let unit = self
.get_unit(unit_id)
.ok_or(ModbusError::UnitNotFound { unit_id })?;
if !unit.is_enabled() {
unit.record_error();
return Err(ModbusError::UnitDisabled { unit_id });
}
unit.record_read();
unit.registers().read_input_registers(address, quantity)
}
#[instrument(skip(self), level = "trace")]
pub fn read_coils(&self, unit_id: u8, address: u16, quantity: u16) -> ModbusResult<Vec<bool>> {
self.total_requests.fetch_add(1, Ordering::Relaxed);
let unit = self
.get_unit(unit_id)
.ok_or(ModbusError::UnitNotFound { unit_id })?;
if !unit.is_enabled() {
unit.record_error();
return Err(ModbusError::UnitDisabled { unit_id });
}
unit.record_read();
unit.registers().read_coils(address, quantity)
}
#[instrument(skip(self), level = "trace")]
pub fn read_discrete_inputs(
&self,
unit_id: u8,
address: u16,
quantity: u16,
) -> ModbusResult<Vec<bool>> {
self.total_requests.fetch_add(1, Ordering::Relaxed);
let unit = self
.get_unit(unit_id)
.ok_or(ModbusError::UnitNotFound { unit_id })?;
if !unit.is_enabled() {
unit.record_error();
return Err(ModbusError::UnitDisabled { unit_id });
}
unit.record_read();
unit.registers().read_discrete_inputs(address, quantity)
}
#[instrument(skip(self), level = "trace")]
pub fn write_holding_register(
&self,
unit_id: u8,
address: u16,
value: u16,
) -> ModbusResult<()> {
self.total_requests.fetch_add(1, Ordering::Relaxed);
if unit_id == 0 {
return self.broadcast_write_holding_register(address, value);
}
let unit = self
.get_unit(unit_id)
.ok_or(ModbusError::UnitNotFound { unit_id })?;
if !unit.is_enabled() {
unit.record_error();
return Err(ModbusError::UnitDisabled { unit_id });
}
unit.record_write();
unit.registers().write_holding_register(address, value)
}
#[instrument(skip(self, values), level = "trace")]
pub fn write_holding_registers(
&self,
unit_id: u8,
address: u16,
values: &[u16],
) -> ModbusResult<()> {
self.total_requests.fetch_add(1, Ordering::Relaxed);
if unit_id == 0 {
return self.broadcast_write_holding_registers(address, values);
}
let unit = self
.get_unit(unit_id)
.ok_or(ModbusError::UnitNotFound { unit_id })?;
if !unit.is_enabled() {
unit.record_error();
return Err(ModbusError::UnitDisabled { unit_id });
}
unit.record_write();
unit.registers().write_holding_registers(address, values)
}
#[instrument(skip(self), level = "trace")]
pub fn write_coil(&self, unit_id: u8, address: u16, value: bool) -> ModbusResult<()> {
self.total_requests.fetch_add(1, Ordering::Relaxed);
if unit_id == 0 {
return self.broadcast_write_coil(address, value);
}
let unit = self
.get_unit(unit_id)
.ok_or(ModbusError::UnitNotFound { unit_id })?;
if !unit.is_enabled() {
unit.record_error();
return Err(ModbusError::UnitDisabled { unit_id });
}
unit.record_write();
unit.registers().write_coil(address, value)
}
#[instrument(skip(self, values), level = "trace")]
pub fn write_coils(&self, unit_id: u8, address: u16, values: &[bool]) -> ModbusResult<()> {
self.total_requests.fetch_add(1, Ordering::Relaxed);
if unit_id == 0 {
return self.broadcast_write_coils(address, values);
}
let unit = self
.get_unit(unit_id)
.ok_or(ModbusError::UnitNotFound { unit_id })?;
if !unit.is_enabled() {
unit.record_error();
return Err(ModbusError::UnitDisabled { unit_id });
}
unit.record_write();
unit.registers().write_coils(address, values)
}
#[instrument(skip(self), level = "debug")]
pub fn broadcast_write_holding_register(&self, address: u16, value: u16) -> ModbusResult<()> {
self.broadcast_count.fetch_add(1, Ordering::Relaxed);
let units = self.get_broadcast_targets();
trace!(
address,
value,
unit_count = units.len(),
"Broadcasting write holding register"
);
for unit_id in units {
if let Some(unit) = self.units.get(&unit_id) {
if unit.is_enabled() && unit.broadcast_enabled() {
let _ = unit.registers().write_holding_register(address, value);
unit.record_write();
}
}
}
Ok(())
}
#[instrument(skip(self, values), level = "debug")]
pub fn broadcast_write_holding_registers(
&self,
address: u16,
values: &[u16],
) -> ModbusResult<()> {
self.broadcast_count.fetch_add(1, Ordering::Relaxed);
let units = self.get_broadcast_targets();
trace!(
address,
count = values.len(),
unit_count = units.len(),
"Broadcasting write multiple holding registers"
);
for unit_id in units {
if let Some(unit) = self.units.get(&unit_id) {
if unit.is_enabled() && unit.broadcast_enabled() {
let _ = unit.registers().write_holding_registers(address, values);
unit.record_write();
}
}
}
Ok(())
}
#[instrument(skip(self), level = "debug")]
pub fn broadcast_write_coil(&self, address: u16, value: bool) -> ModbusResult<()> {
self.broadcast_count.fetch_add(1, Ordering::Relaxed);
let units = self.get_broadcast_targets();
trace!(
address,
value,
unit_count = units.len(),
"Broadcasting write coil"
);
for unit_id in units {
if let Some(unit) = self.units.get(&unit_id) {
if unit.is_enabled() && unit.broadcast_enabled() {
let _ = unit.registers().write_coil(address, value);
unit.record_write();
}
}
}
Ok(())
}
#[instrument(skip(self, values), level = "debug")]
pub fn broadcast_write_coils(&self, address: u16, values: &[bool]) -> ModbusResult<()> {
self.broadcast_count.fetch_add(1, Ordering::Relaxed);
let units = self.get_broadcast_targets();
trace!(
address,
count = values.len(),
unit_count = units.len(),
"Broadcasting write multiple coils"
);
for unit_id in units {
if let Some(unit) = self.units.get(&unit_id) {
if unit.is_enabled() && unit.broadcast_enabled() {
let _ = unit.registers().write_coils(address, values);
unit.record_write();
}
}
}
Ok(())
}
fn get_broadcast_targets(&self) -> Vec<u8> {
match &self.config.broadcast_mode {
BroadcastMode::WriteAll => self.unit_ids(),
BroadcastMode::Disabled => vec![],
BroadcastMode::SelectiveList(units) => units.clone(),
BroadcastMode::EchoToUnit(id) => vec![*id],
}
}
pub fn total_requests(&self) -> u64 {
self.total_requests.load(Ordering::Relaxed)
}
pub fn broadcast_count(&self) -> u64 {
self.broadcast_count.load(Ordering::Relaxed)
}
pub fn unit_statistics(&self) -> Vec<UnitStatistics> {
self.units
.iter()
.map(|entry| UnitStatistics {
unit_id: *entry.key(),
name: entry.value().config.name.clone(),
enabled: entry.value().is_enabled(),
read_count: entry.value().read_count(),
write_count: entry.value().write_count(),
error_count: entry.value().error_count(),
register_count: entry.value().registers().entry_count(),
})
.collect()
}
pub fn reset_statistics(&self) {
self.total_requests.store(0, Ordering::Relaxed);
self.broadcast_count.store(0, Ordering::Relaxed);
for entry in self.units.iter() {
entry.value().read_count.store(0, Ordering::Relaxed);
entry.value().write_count.store(0, Ordering::Relaxed);
entry.value().error_count.store(0, Ordering::Relaxed);
}
}
}
impl Default for MultiUnitManager {
fn default() -> Self {
Self::with_defaults()
}
}
impl std::fmt::Debug for MultiUnitManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MultiUnitManager")
.field("unit_count", &self.unit_count())
.field("total_requests", &self.total_requests())
.field("broadcast_count", &self.broadcast_count())
.finish()
}
}
/// Point-in-time copy of one unit's counters, as produced by
/// `MultiUnitManager::unit_statistics`.
#[derive(Debug, Clone)]
pub struct UnitStatistics {
/// ID the unit is registered under.
pub unit_id: u8,
/// Name from the unit's configuration.
pub name: String,
/// Whether the unit was enabled when the snapshot was taken.
pub enabled: bool,
/// Value of the unit's read counter.
pub read_count: u64,
/// Value of the unit's write counter.
pub write_count: u64,
/// Value of the unit's error counter.
pub error_count: u64,
/// Entry count reported by the unit's register store.
pub register_count: usize,
}
// Unit tests for `MultiUnitManager`: registration, lookup, unicast and
// broadcast writes, auto-creation, statistics and per-unit word order.
#[cfg(test)]
mod tests {
use super::*;
// A freshly created manager holds no units.
#[test]
fn test_create_manager() {
let manager = MultiUnitManager::with_defaults();
assert_eq!(manager.unit_count(), 0);
}
// A registered unit is visible through `has_unit` and `get_unit`, and keeps
// the name from its configuration.
#[test]
fn test_add_and_get_unit() {
let manager = MultiUnitManager::with_defaults();
manager.add_unit(1, UnitConfig::new("Test Unit")).unwrap();
assert!(manager.has_unit(1));
assert!(!manager.has_unit(2));
let unit = manager.get_unit(1).unwrap();
assert_eq!(unit.name(), "Test Unit");
}
// Unit ID 0 is reserved for broadcast and cannot be registered.
#[test]
fn test_cannot_add_unit_zero() {
let manager = MultiUnitManager::with_defaults();
let result = manager.add_unit(0, UnitConfig::new("Broadcast"));
assert!(result.is_err());
}
// Registering the same ID twice is rejected.
#[test]
fn test_cannot_add_duplicate_unit() {
let manager = MultiUnitManager::with_defaults();
manager.add_unit(1, UnitConfig::new("Unit 1")).unwrap();
let result = manager.add_unit(1, UnitConfig::new("Unit 1 Again"));
assert!(result.is_err());
}
// Removing a unit returns its state and makes the ID unknown again.
#[test]
fn test_remove_unit() {
let manager = MultiUnitManager::with_defaults();
manager.add_unit(1, UnitConfig::new("Unit 1")).unwrap();
assert!(manager.has_unit(1));
let removed = manager.remove_unit(1);
assert!(removed.is_some());
assert!(!manager.has_unit(1));
}
// A value written to a holding register can be read back from the same unit.
#[test]
fn test_read_write_operations() {
let manager = MultiUnitManager::with_defaults();
manager.add_unit(1, UnitConfig::new("Unit 1")).unwrap();
manager.write_holding_register(1, 0, 12345).unwrap();
let values = manager.read_holding_registers(1, 0, 1).unwrap();
assert_eq!(values[0], 12345);
}
// Writing to unit ID 0 goes through the broadcast path; with the default
// configuration this test expects every registered unit to receive the value.
#[test]
fn test_broadcast_write() {
let manager = MultiUnitManager::with_defaults();
manager.add_unit(1, UnitConfig::new("Unit 1")).unwrap();
manager.add_unit(2, UnitConfig::new("Unit 2")).unwrap();
manager.add_unit(3, UnitConfig::new("Unit 3")).unwrap();
manager.write_holding_register(0, 100, 999).unwrap();
let v1 = manager.read_holding_registers(1, 100, 1).unwrap();
let v2 = manager.read_holding_registers(2, 100, 1).unwrap();
let v3 = manager.read_holding_registers(3, 100, 1).unwrap();
assert_eq!(v1[0], 999);
assert_eq!(v2[0], 999);
assert_eq!(v3[0], 999);
}
// Unit 2 is built with `with_broadcast(false)`; the test expects it to be
// skipped by a broadcast while still answering a direct read.
// NOTE(review): despite the test name, unit 2 appears to stay enabled and only
// opts out of broadcast reception — confirm against `UnitConfig::with_broadcast`.
#[test]
fn test_broadcast_with_disabled_unit() {
let manager = MultiUnitManager::with_defaults();
manager.add_unit(1, UnitConfig::new("Unit 1")).unwrap();
manager
.add_unit(2, UnitConfig::new("Unit 2").with_broadcast(false))
.unwrap();
manager.broadcast_write_holding_register(100, 888).unwrap();
let v1 = manager.read_holding_registers(1, 100, 1).unwrap();
assert_eq!(v1[0], 888);
let v2 = manager.read_holding_registers(2, 100, 1).unwrap();
assert_ne!(v2[0], 888);
}
// With auto-creation switched on, looking up an unknown ID registers it.
#[test]
fn test_auto_create_units() {
let config = UnitManagerConfig::default().with_auto_create(true);
let manager = MultiUnitManager::new(config);
assert!(!manager.has_unit(5));
let unit = manager.get_unit(5);
assert!(unit.is_some());
assert!(manager.has_unit(5));
}
// One write and two reads show up in the unit's statistics snapshot.
#[test]
fn test_statistics() {
let manager = MultiUnitManager::with_defaults();
manager.add_unit(1, UnitConfig::new("Unit 1")).unwrap();
manager.write_holding_register(1, 0, 100).unwrap();
manager.read_holding_registers(1, 0, 1).unwrap();
manager.read_holding_registers(1, 0, 1).unwrap();
let stats = manager.unit_statistics();
assert_eq!(stats.len(), 1);
assert_eq!(stats[0].read_count, 2);
assert_eq!(stats[0].write_count, 1);
}
// Units created with explicit word orders report those orders independently.
#[test]
fn test_different_word_orders() {
let manager = MultiUnitManager::with_defaults();
manager
.add_unit(
1,
UnitConfig::with_word_order("BE Unit", WordOrder::BigEndian),
)
.unwrap();
manager
.add_unit(
2,
UnitConfig::with_word_order("LE Unit", WordOrder::LittleEndian),
)
.unwrap();
let unit1 = manager.get_unit(1).unwrap();
let unit2 = manager.get_unit(2).unwrap();
assert_eq!(unit1.word_order(), WordOrder::BigEndian);
assert_eq!(unit2.word_order(), WordOrder::LittleEndian);
}
}