1use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5use std::time::Instant;
6
7use dashmap::DashMap;
8use tracing::{debug, instrument, trace, warn};
9
10use crate::error::{ModbusError, ModbusResult};
11use crate::registers::SparseRegisterStore;
12use crate::types::{RegisterConverter, WordOrder};
13
14use super::config::{BroadcastMode, UnitConfig, UnitManagerConfig};
15
/// Runtime state for one unit registered in a [`MultiUnitManager`]: its
/// configuration, register storage, word-order converter and access counters.
#[derive(Debug)]
pub struct UnitInfo {
    /// ID this unit is registered under in the manager.
    unit_id: u8,

    /// Configuration supplied when the unit was added; `MultiUnitManager::update_unit`
    /// can modify it afterwards.
    config: UnitConfig,

    /// Register storage for this unit. Held in an `Arc`, presumably so callers can
    /// clone a handle that outlives the map guard — TODO confirm intent.
    registers: Arc<SparseRegisterStore>,

    /// Converter built from the unit's effective word order.
    converter: RegisterConverter,

    /// Instant at which this `UnitInfo` was constructed.
    created_at: Instant,

    /// Number of read requests accepted for this unit (Relaxed atomic counter).
    read_count: AtomicU64,

    /// Number of write requests accepted for this unit (Relaxed atomic counter).
    write_count: AtomicU64,

    /// Number of requests counted as errors for this unit (Relaxed atomic counter).
    error_count: AtomicU64,
}
43
44impl UnitInfo {
45 fn new(
47 unit_id: u8,
48 config: UnitConfig,
49 default_word_order: WordOrder,
50 default_register_config: &crate::registers::RegisterStoreConfig,
51 ) -> Self {
52 let word_order = config.effective_word_order(default_word_order);
53 let register_config = config
54 .register_config
55 .clone()
56 .unwrap_or_else(|| default_register_config.clone());
57
58 Self {
59 unit_id,
60 config,
61 registers: Arc::new(SparseRegisterStore::new(register_config)),
62 converter: RegisterConverter::new(word_order),
63 created_at: Instant::now(),
64 read_count: AtomicU64::new(0),
65 write_count: AtomicU64::new(0),
66 error_count: AtomicU64::new(0),
67 }
68 }
69
70 #[inline]
72 pub fn unit_id(&self) -> u8 {
73 self.unit_id
74 }
75
76 #[inline]
78 pub fn config(&self) -> &UnitConfig {
79 &self.config
80 }
81
82 #[inline]
84 pub fn name(&self) -> &str {
85 &self.config.name
86 }
87
88 #[inline]
90 pub fn registers(&self) -> &Arc<SparseRegisterStore> {
91 &self.registers
92 }
93
94 #[inline]
96 pub fn converter(&self) -> &RegisterConverter {
97 &self.converter
98 }
99
100 #[inline]
102 pub fn word_order(&self) -> WordOrder {
103 self.converter.word_order()
104 }
105
106 #[inline]
108 pub fn is_enabled(&self) -> bool {
109 self.config.enabled
110 }
111
112 #[inline]
114 pub fn broadcast_enabled(&self) -> bool {
115 self.config.broadcast_enabled
116 }
117
118 #[inline]
120 pub fn response_delay_us(&self) -> u64 {
121 self.config.response_delay_us
122 }
123
124 #[inline]
126 pub fn created_at(&self) -> Instant {
127 self.created_at
128 }
129
130 #[inline]
132 pub fn uptime(&self) -> std::time::Duration {
133 self.created_at.elapsed()
134 }
135
136 #[inline]
138 pub fn read_count(&self) -> u64 {
139 self.read_count.load(Ordering::Relaxed)
140 }
141
142 #[inline]
144 pub fn write_count(&self) -> u64 {
145 self.write_count.load(Ordering::Relaxed)
146 }
147
148 #[inline]
150 pub fn error_count(&self) -> u64 {
151 self.error_count.load(Ordering::Relaxed)
152 }
153
154 pub(crate) fn record_read(&self) {
156 self.read_count.fetch_add(1, Ordering::Relaxed);
157 }
158
159 pub(crate) fn record_write(&self) {
161 self.write_count.fetch_add(1, Ordering::Relaxed);
162 }
163
164 pub(crate) fn record_error(&self) {
166 self.error_count.fetch_add(1, Ordering::Relaxed);
167 }
168}
169
170pub struct MultiUnitManager {
198 config: UnitManagerConfig,
200
201 units: DashMap<u8, UnitInfo>,
203
204 total_requests: AtomicU64,
206
207 broadcast_count: AtomicU64,
209}
210
211impl MultiUnitManager {
212 pub fn new(config: UnitManagerConfig) -> Self {
214 Self {
215 config,
216 units: DashMap::new(),
217 total_requests: AtomicU64::new(0),
218 broadcast_count: AtomicU64::new(0),
219 }
220 }
221
222 pub fn with_defaults() -> Self {
224 Self::new(UnitManagerConfig::default())
225 }
226
227 pub fn config(&self) -> &UnitManagerConfig {
229 &self.config
230 }
231
232 pub fn unit_count(&self) -> usize {
234 self.units.len()
235 }
236
237 pub fn unit_ids(&self) -> Vec<u8> {
239 self.units.iter().map(|entry| *entry.key()).collect()
240 }
241
242 pub fn has_unit(&self, unit_id: u8) -> bool {
244 self.units.contains_key(&unit_id)
245 }
246
247 #[instrument(skip(self, config), fields(unit_id = unit_id, name = %config.name))]
265 pub fn add_unit(&self, unit_id: u8, config: UnitConfig) -> ModbusResult<()> {
266 if unit_id == 0 {
268 return Err(ModbusError::InvalidUnitId {
269 unit_id: 0,
270 reason: "Unit ID 0 is reserved for broadcast".to_string(),
271 });
272 }
273
274 if self.units.len() >= self.config.max_units {
276 return Err(ModbusError::UnitLimitReached {
277 max: self.config.max_units,
278 });
279 }
280
281 if self.units.contains_key(&unit_id) {
283 return Err(ModbusError::UnitAlreadyExists { unit_id });
284 }
285
286 let unit_info = UnitInfo::new(
287 unit_id,
288 config,
289 self.config.default_word_order,
290 &self.config.default_register_config,
291 );
292
293 self.units.insert(unit_id, unit_info);
294 debug!(unit_id, "Unit added");
295
296 Ok(())
297 }
298
299 #[instrument(skip(self))]
305 pub fn remove_unit(&self, unit_id: u8) -> Option<UnitInfo> {
306 let removed = self.units.remove(&unit_id).map(|(_, info)| info);
307 if removed.is_some() {
308 debug!(unit_id, "Unit removed");
309 }
310 removed
311 }
312
313 pub fn get_unit(&self, unit_id: u8) -> Option<dashmap::mapref::one::Ref<'_, u8, UnitInfo>> {
318 if unit_id == 0 {
320 return None;
321 }
322
323 if let Some(unit) = self.units.get(&unit_id) {
325 return Some(unit);
326 }
327
328 if self.config.auto_create_units {
330 let config = UnitConfig::new(format!("Auto-created Unit {}", unit_id));
331 if self.add_unit(unit_id, config).is_ok() {
332 return self.units.get(&unit_id);
333 }
334 }
335
336 None
337 }
338
339 pub fn get_unit_mut(
341 &self,
342 unit_id: u8,
343 ) -> Option<dashmap::mapref::one::RefMut<'_, u8, UnitInfo>> {
344 if unit_id == 0 {
345 return None;
346 }
347
348 self.units.get_mut(&unit_id)
349 }
350
351 #[instrument(skip(self, update_fn))]
353 pub fn update_unit<F>(&self, unit_id: u8, update_fn: F) -> ModbusResult<()>
354 where
355 F: FnOnce(&mut UnitConfig),
356 {
357 if let Some(mut unit) = self.units.get_mut(&unit_id) {
358 update_fn(&mut unit.config);
359 debug!(unit_id, "Unit configuration updated");
360 Ok(())
361 } else {
362 Err(ModbusError::UnitNotFound { unit_id })
363 }
364 }
365
366 #[instrument(skip(self), level = "trace")]
372 pub fn read_holding_registers(
373 &self,
374 unit_id: u8,
375 address: u16,
376 quantity: u16,
377 ) -> ModbusResult<Vec<u16>> {
378 self.total_requests.fetch_add(1, Ordering::Relaxed);
379
380 let unit = self
381 .get_unit(unit_id)
382 .ok_or(ModbusError::UnitNotFound { unit_id })?;
383
384 if !unit.is_enabled() {
385 unit.record_error();
386 return Err(ModbusError::UnitDisabled { unit_id });
387 }
388
389 unit.record_read();
390 unit.registers().read_holding_registers(address, quantity)
391 }
392
393 #[instrument(skip(self), level = "trace")]
395 pub fn read_input_registers(
396 &self,
397 unit_id: u8,
398 address: u16,
399 quantity: u16,
400 ) -> ModbusResult<Vec<u16>> {
401 self.total_requests.fetch_add(1, Ordering::Relaxed);
402
403 let unit = self
404 .get_unit(unit_id)
405 .ok_or(ModbusError::UnitNotFound { unit_id })?;
406
407 if !unit.is_enabled() {
408 unit.record_error();
409 return Err(ModbusError::UnitDisabled { unit_id });
410 }
411
412 unit.record_read();
413 unit.registers().read_input_registers(address, quantity)
414 }
415
416 #[instrument(skip(self), level = "trace")]
418 pub fn read_coils(
419 &self,
420 unit_id: u8,
421 address: u16,
422 quantity: u16,
423 ) -> ModbusResult<Vec<bool>> {
424 self.total_requests.fetch_add(1, Ordering::Relaxed);
425
426 let unit = self
427 .get_unit(unit_id)
428 .ok_or(ModbusError::UnitNotFound { unit_id })?;
429
430 if !unit.is_enabled() {
431 unit.record_error();
432 return Err(ModbusError::UnitDisabled { unit_id });
433 }
434
435 unit.record_read();
436 unit.registers().read_coils(address, quantity)
437 }
438
439 #[instrument(skip(self), level = "trace")]
441 pub fn read_discrete_inputs(
442 &self,
443 unit_id: u8,
444 address: u16,
445 quantity: u16,
446 ) -> ModbusResult<Vec<bool>> {
447 self.total_requests.fetch_add(1, Ordering::Relaxed);
448
449 let unit = self
450 .get_unit(unit_id)
451 .ok_or(ModbusError::UnitNotFound { unit_id })?;
452
453 if !unit.is_enabled() {
454 unit.record_error();
455 return Err(ModbusError::UnitDisabled { unit_id });
456 }
457
458 unit.record_read();
459 unit.registers().read_discrete_inputs(address, quantity)
460 }
461
462 #[instrument(skip(self), level = "trace")]
464 pub fn write_holding_register(
465 &self,
466 unit_id: u8,
467 address: u16,
468 value: u16,
469 ) -> ModbusResult<()> {
470 self.total_requests.fetch_add(1, Ordering::Relaxed);
471
472 if unit_id == 0 {
474 return self.broadcast_write_holding_register(address, value);
475 }
476
477 let unit = self
478 .get_unit(unit_id)
479 .ok_or(ModbusError::UnitNotFound { unit_id })?;
480
481 if !unit.is_enabled() {
482 unit.record_error();
483 return Err(ModbusError::UnitDisabled { unit_id });
484 }
485
486 unit.record_write();
487 unit.registers().write_holding_register(address, value)
488 }
489
490 #[instrument(skip(self, values), level = "trace")]
492 pub fn write_holding_registers(
493 &self,
494 unit_id: u8,
495 address: u16,
496 values: &[u16],
497 ) -> ModbusResult<()> {
498 self.total_requests.fetch_add(1, Ordering::Relaxed);
499
500 if unit_id == 0 {
502 return self.broadcast_write_holding_registers(address, values);
503 }
504
505 let unit = self
506 .get_unit(unit_id)
507 .ok_or(ModbusError::UnitNotFound { unit_id })?;
508
509 if !unit.is_enabled() {
510 unit.record_error();
511 return Err(ModbusError::UnitDisabled { unit_id });
512 }
513
514 unit.record_write();
515 unit.registers().write_holding_registers(address, values)
516 }
517
518 #[instrument(skip(self), level = "trace")]
520 pub fn write_coil(&self, unit_id: u8, address: u16, value: bool) -> ModbusResult<()> {
521 self.total_requests.fetch_add(1, Ordering::Relaxed);
522
523 if unit_id == 0 {
525 return self.broadcast_write_coil(address, value);
526 }
527
528 let unit = self
529 .get_unit(unit_id)
530 .ok_or(ModbusError::UnitNotFound { unit_id })?;
531
532 if !unit.is_enabled() {
533 unit.record_error();
534 return Err(ModbusError::UnitDisabled { unit_id });
535 }
536
537 unit.record_write();
538 unit.registers().write_coil(address, value)
539 }
540
541 #[instrument(skip(self, values), level = "trace")]
543 pub fn write_coils(
544 &self,
545 unit_id: u8,
546 address: u16,
547 values: &[bool],
548 ) -> ModbusResult<()> {
549 self.total_requests.fetch_add(1, Ordering::Relaxed);
550
551 if unit_id == 0 {
553 return self.broadcast_write_coils(address, values);
554 }
555
556 let unit = self
557 .get_unit(unit_id)
558 .ok_or(ModbusError::UnitNotFound { unit_id })?;
559
560 if !unit.is_enabled() {
561 unit.record_error();
562 return Err(ModbusError::UnitDisabled { unit_id });
563 }
564
565 unit.record_write();
566 unit.registers().write_coils(address, values)
567 }
568
569 #[instrument(skip(self), level = "debug")]
575 pub fn broadcast_write_holding_register(&self, address: u16, value: u16) -> ModbusResult<()> {
576 self.broadcast_count.fetch_add(1, Ordering::Relaxed);
577 let units = self.get_broadcast_targets();
578
579 trace!(address, value, unit_count = units.len(), "Broadcasting write holding register");
580
581 for unit_id in units {
582 if let Some(unit) = self.units.get(&unit_id) {
583 if unit.is_enabled() && unit.broadcast_enabled() {
584 let _ = unit.registers().write_holding_register(address, value);
585 unit.record_write();
586 }
587 }
588 }
589
590 Ok(())
591 }
592
593 #[instrument(skip(self, values), level = "debug")]
595 pub fn broadcast_write_holding_registers(
596 &self,
597 address: u16,
598 values: &[u16],
599 ) -> ModbusResult<()> {
600 self.broadcast_count.fetch_add(1, Ordering::Relaxed);
601 let units = self.get_broadcast_targets();
602
603 trace!(
604 address,
605 count = values.len(),
606 unit_count = units.len(),
607 "Broadcasting write multiple holding registers"
608 );
609
610 for unit_id in units {
611 if let Some(unit) = self.units.get(&unit_id) {
612 if unit.is_enabled() && unit.broadcast_enabled() {
613 let _ = unit.registers().write_holding_registers(address, values);
614 unit.record_write();
615 }
616 }
617 }
618
619 Ok(())
620 }
621
622 #[instrument(skip(self), level = "debug")]
624 pub fn broadcast_write_coil(&self, address: u16, value: bool) -> ModbusResult<()> {
625 self.broadcast_count.fetch_add(1, Ordering::Relaxed);
626 let units = self.get_broadcast_targets();
627
628 trace!(address, value, unit_count = units.len(), "Broadcasting write coil");
629
630 for unit_id in units {
631 if let Some(unit) = self.units.get(&unit_id) {
632 if unit.is_enabled() && unit.broadcast_enabled() {
633 let _ = unit.registers().write_coil(address, value);
634 unit.record_write();
635 }
636 }
637 }
638
639 Ok(())
640 }
641
642 #[instrument(skip(self, values), level = "debug")]
644 pub fn broadcast_write_coils(&self, address: u16, values: &[bool]) -> ModbusResult<()> {
645 self.broadcast_count.fetch_add(1, Ordering::Relaxed);
646 let units = self.get_broadcast_targets();
647
648 trace!(
649 address,
650 count = values.len(),
651 unit_count = units.len(),
652 "Broadcasting write multiple coils"
653 );
654
655 for unit_id in units {
656 if let Some(unit) = self.units.get(&unit_id) {
657 if unit.is_enabled() && unit.broadcast_enabled() {
658 let _ = unit.registers().write_coils(address, values);
659 unit.record_write();
660 }
661 }
662 }
663
664 Ok(())
665 }
666
667 fn get_broadcast_targets(&self) -> Vec<u8> {
669 match &self.config.broadcast_mode {
670 BroadcastMode::WriteAll => self.unit_ids(),
671 BroadcastMode::Disabled => vec![],
672 BroadcastMode::SelectiveList(units) => units.clone(),
673 BroadcastMode::EchoToUnit(id) => vec![*id],
674 }
675 }
676
677 pub fn total_requests(&self) -> u64 {
683 self.total_requests.load(Ordering::Relaxed)
684 }
685
686 pub fn broadcast_count(&self) -> u64 {
688 self.broadcast_count.load(Ordering::Relaxed)
689 }
690
691 pub fn unit_statistics(&self) -> Vec<UnitStatistics> {
693 self.units
694 .iter()
695 .map(|entry| UnitStatistics {
696 unit_id: *entry.key(),
697 name: entry.value().config.name.clone(),
698 enabled: entry.value().is_enabled(),
699 read_count: entry.value().read_count(),
700 write_count: entry.value().write_count(),
701 error_count: entry.value().error_count(),
702 register_count: entry.value().registers().entry_count(),
703 })
704 .collect()
705 }
706
707 pub fn reset_statistics(&self) {
709 self.total_requests.store(0, Ordering::Relaxed);
710 self.broadcast_count.store(0, Ordering::Relaxed);
711
712 for entry in self.units.iter() {
713 entry.value().read_count.store(0, Ordering::Relaxed);
714 entry.value().write_count.store(0, Ordering::Relaxed);
715 entry.value().error_count.store(0, Ordering::Relaxed);
716 }
717 }
718}
719
720impl Default for MultiUnitManager {
721 fn default() -> Self {
722 Self::with_defaults()
723 }
724}
725
726impl std::fmt::Debug for MultiUnitManager {
727 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
728 f.debug_struct("MultiUnitManager")
729 .field("unit_count", &self.unit_count())
730 .field("total_requests", &self.total_requests())
731 .field("broadcast_count", &self.broadcast_count())
732 .finish()
733 }
734}
735
/// Point-in-time statistics for one unit, as produced by
/// `MultiUnitManager::unit_statistics`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnitStatistics {
    /// ID the unit is registered under.
    pub unit_id: u8,
    /// Name from the unit's configuration.
    pub name: String,
    /// Whether the unit was enabled when the statistics were taken.
    pub enabled: bool,
    /// Reads recorded for the unit.
    pub read_count: u64,
    /// Writes recorded for the unit.
    pub write_count: u64,
    /// Errors recorded for the unit.
    pub error_count: u64,
    /// Value of the register store's `entry_count()`. Presumably the number of
    /// populated sparse entries — TODO confirm.
    pub register_count: usize,
}
747
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_manager() {
        let manager = MultiUnitManager::with_defaults();
        assert_eq!(manager.unit_count(), 0);
    }

    #[test]
    fn test_add_and_get_unit() {
        let manager = MultiUnitManager::with_defaults();

        manager
            .add_unit(1, UnitConfig::new("Test Unit"))
            .unwrap();

        assert!(manager.has_unit(1));
        assert!(!manager.has_unit(2));

        let unit = manager.get_unit(1).unwrap();
        assert_eq!(unit.name(), "Test Unit");
    }

    #[test]
    fn test_cannot_add_unit_zero() {
        let manager = MultiUnitManager::with_defaults();

        let result = manager.add_unit(0, UnitConfig::new("Broadcast"));
        assert!(matches!(result, Err(ModbusError::InvalidUnitId { .. })));
        assert_eq!(manager.unit_count(), 0);
    }

    #[test]
    fn test_get_unit_zero_is_none() {
        let manager = MultiUnitManager::with_defaults();
        assert!(manager.get_unit(0).is_none());
        assert!(manager.get_unit_mut(0).is_none());
    }

    #[test]
    fn test_cannot_add_duplicate_unit() {
        let manager = MultiUnitManager::with_defaults();

        manager.add_unit(1, UnitConfig::new("Unit 1")).unwrap();
        let result = manager.add_unit(1, UnitConfig::new("Unit 1 Again"));
        assert!(matches!(result, Err(ModbusError::UnitAlreadyExists { .. })));

        // The rejected add must not replace the existing unit.
        let unit = manager.get_unit(1).unwrap();
        assert_eq!(unit.name(), "Unit 1");
    }

    #[test]
    fn test_unit_limit() {
        let mut config = UnitManagerConfig::default();
        config.max_units = 1;
        let manager = MultiUnitManager::new(config);

        manager.add_unit(1, UnitConfig::new("Unit 1")).unwrap();
        let result = manager.add_unit(2, UnitConfig::new("Unit 2"));
        assert!(matches!(result, Err(ModbusError::UnitLimitReached { .. })));
        assert_eq!(manager.unit_count(), 1);
    }

    #[test]
    fn test_remove_unit() {
        let manager = MultiUnitManager::with_defaults();

        manager.add_unit(1, UnitConfig::new("Unit 1")).unwrap();
        assert!(manager.has_unit(1));

        let removed = manager.remove_unit(1);
        assert!(removed.is_some());
        assert!(!manager.has_unit(1));
    }

    #[test]
    fn test_remove_missing_unit_returns_none() {
        let manager = MultiUnitManager::with_defaults();
        assert!(manager.remove_unit(9).is_none());
    }

    #[test]
    fn test_update_missing_unit_fails() {
        let manager = MultiUnitManager::with_defaults();
        let result = manager.update_unit(42, |_| {});
        assert!(matches!(result, Err(ModbusError::UnitNotFound { .. })));
    }

    #[test]
    fn test_read_write_operations() {
        let manager = MultiUnitManager::with_defaults();
        manager.add_unit(1, UnitConfig::new("Unit 1")).unwrap();

        manager.write_holding_register(1, 0, 12345).unwrap();

        let values = manager.read_holding_registers(1, 0, 1).unwrap();
        assert_eq!(values[0], 12345);
    }

    #[test]
    fn test_disabled_unit_rejects_requests() {
        let manager = MultiUnitManager::with_defaults();
        manager.add_unit(1, UnitConfig::new("Unit 1")).unwrap();
        manager
            .update_unit(1, |config| config.enabled = false)
            .unwrap();

        let read = manager.read_holding_registers(1, 0, 1);
        assert!(matches!(read, Err(ModbusError::UnitDisabled { .. })));
        let write = manager.write_holding_register(1, 0, 1);
        assert!(matches!(write, Err(ModbusError::UnitDisabled { .. })));

        let stats = manager.unit_statistics();
        assert_eq!(stats.len(), 1);
        assert!(!stats[0].enabled);
        assert_eq!(stats[0].read_count, 0);
        assert_eq!(stats[0].write_count, 0);
        assert_eq!(stats[0].error_count, 2);
    }

    #[test]
    fn test_broadcast_write() {
        let manager = MultiUnitManager::with_defaults();
        manager.add_unit(1, UnitConfig::new("Unit 1")).unwrap();
        manager.add_unit(2, UnitConfig::new("Unit 2")).unwrap();
        manager.add_unit(3, UnitConfig::new("Unit 3")).unwrap();

        manager.write_holding_register(0, 100, 999).unwrap();
        assert_eq!(manager.broadcast_count(), 1);

        for unit_id in 1..=3 {
            let values = manager.read_holding_registers(unit_id, 100, 1).unwrap();
            assert_eq!(values[0], 999, "unit {} missed the broadcast", unit_id);
        }
    }

    #[test]
    fn test_broadcast_skips_units_with_broadcast_disabled() {
        let manager = MultiUnitManager::with_defaults();
        manager.add_unit(1, UnitConfig::new("Unit 1")).unwrap();
        manager
            .add_unit(2, UnitConfig::new("Unit 2").with_broadcast(false))
            .unwrap();

        manager.broadcast_write_holding_register(100, 888).unwrap();

        let v1 = manager.read_holding_registers(1, 100, 1).unwrap();
        assert_eq!(v1[0], 888);

        let v2 = manager.read_holding_registers(2, 100, 1).unwrap();
        assert_ne!(v2[0], 888);
    }

    #[test]
    fn test_auto_create_units() {
        let config = UnitManagerConfig::default().with_auto_create(true);
        let manager = MultiUnitManager::new(config);

        assert!(!manager.has_unit(5));

        let unit = manager.get_unit(5);
        assert!(unit.is_some());
        assert!(manager.has_unit(5));
    }

    #[test]
    fn test_statistics() {
        let manager = MultiUnitManager::with_defaults();
        manager.add_unit(1, UnitConfig::new("Unit 1")).unwrap();

        manager.write_holding_register(1, 0, 100).unwrap();
        manager.read_holding_registers(1, 0, 1).unwrap();
        manager.read_holding_registers(1, 0, 1).unwrap();

        let stats = manager.unit_statistics();
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].read_count, 2);
        assert_eq!(stats[0].write_count, 1);
        assert_eq!(manager.total_requests(), 3);
    }

    #[test]
    fn test_reset_statistics() {
        let manager = MultiUnitManager::with_defaults();
        manager.add_unit(1, UnitConfig::new("Unit 1")).unwrap();

        manager.write_holding_register(1, 0, 7).unwrap();
        manager.read_holding_registers(1, 0, 1).unwrap();
        assert_eq!(manager.total_requests(), 2);

        manager.reset_statistics();

        assert_eq!(manager.total_requests(), 0);
        assert_eq!(manager.broadcast_count(), 0);
        let stats = manager.unit_statistics();
        assert_eq!(stats[0].read_count, 0);
        assert_eq!(stats[0].write_count, 0);
        assert_eq!(stats[0].error_count, 0);
    }

    #[test]
    fn test_different_word_orders() {
        let manager = MultiUnitManager::with_defaults();

        manager
            .add_unit(1, UnitConfig::with_word_order("BE Unit", WordOrder::BigEndian))
            .unwrap();
        manager
            .add_unit(
                2,
                UnitConfig::with_word_order("LE Unit", WordOrder::LittleEndian),
            )
            .unwrap();

        let unit1 = manager.get_unit(1).unwrap();
        let unit2 = manager.get_unit(2).unwrap();

        assert_eq!(unit1.word_order(), WordOrder::BigEndian);
        assert_eq!(unit2.word_order(), WordOrder::LittleEndian);
    }
}
905}