Skip to main content

mabi_modbus/
device.rs

1//! Modbus device implementation.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use parking_lot::RwLock;
9use tokio::sync::broadcast;
10use tracing::instrument;
11
12use mabi_core::{
13    device::{Device, DeviceInfo, DeviceState, DeviceStatistics},
14    error::{Error, Result},
15    protocol::Protocol,
16    types::{
17        AccessMode, DataPoint, DataPointDef, DataPointId, DataType, ModbusAddress,
18        ModbusRegisterType,
19    },
20    value::Value,
21};
22
23use crate::config::ModbusDeviceConfig;
24use crate::context::{DeviceContext, SharedAddressSpace};
25use crate::error::ModbusResult;
26use crate::profile::{DatastoreKind, PointProfile, UnitProfile};
27
28/// Modbus device implementation.
29pub struct ModbusDevice {
30    /// Device info.
31    info: DeviceInfo,
32
33    /// Shared low-level device context.
34    context: Arc<DeviceContext>,
35
36    /// Data point definitions.
37    point_defs: HashMap<String, DataPointDef>,
38
39    /// Point ID to address mapping.
40    point_addresses: HashMap<String, ModbusAddress>,
41
42    /// Statistics.
43    stats: RwLock<DeviceStatistics>,
44
45    /// Value change broadcaster.
46    event_tx: broadcast::Sender<DataPoint>,
47
48    /// Response delay for simulation.
49    response_delay: Duration,
50
51    /// Start time.
52    start_time: RwLock<Option<Instant>>,
53}
54
55impl ModbusDevice {
56    /// Create a new Modbus device.
57    pub fn new(config: ModbusDeviceConfig) -> Self {
58        let context = Arc::new(
59            DeviceContext::new(
60                config.unit_id,
61                config.name.clone(),
62                DatastoreKind::dense_from_counts(
63                    config.coils,
64                    config.discrete_inputs,
65                    config.holding_registers,
66                    config.input_registers,
67                )
68                .build_address_space(),
69                crate::types::WordOrder::default(),
70            )
71            .with_response_delay(Duration::from_millis(config.response_delay_ms)),
72        );
73
74        let info = DeviceInfo::new(
75            format!("modbus-{}", config.unit_id),
76            &config.name,
77            Protocol::ModbusTcp,
78        )
79        .with_metadata("unit_id", config.unit_id.to_string())
80        .with_tags(config.tags);
81
82        let (event_tx, _) = broadcast::channel(1000);
83
84        Self {
85            info,
86            context,
87            point_defs: HashMap::new(),
88            point_addresses: HashMap::new(),
89            stats: RwLock::new(DeviceStatistics::default()),
90            event_tx,
91            response_delay: Duration::from_millis(config.response_delay_ms),
92            start_time: RwLock::new(None),
93        }
94    }
95
96    /// Create a device from a profile-defined unit.
97    pub fn from_profile(profile: &UnitProfile) -> ModbusResult<Self> {
98        let context = Arc::new(
99            DeviceContext::new(
100                profile.unit_id,
101                profile.name.clone(),
102                profile.datastore.build_address_space(),
103                profile.word_order,
104            )
105            .with_broadcast(profile.broadcast_enabled)
106            .with_response_delay(Duration::from_millis(profile.response_delay_ms)),
107        );
108
109        let info = DeviceInfo::new(
110            format!("modbus-{}", profile.unit_id),
111            &profile.name,
112            Protocol::ModbusTcp,
113        )
114        .with_metadata("unit_id", profile.unit_id.to_string())
115        .with_tags(profile.tags.clone());
116
117        let (event_tx, _) = broadcast::channel(1000);
118        let mut device = Self {
119            info,
120            context,
121            point_defs: HashMap::new(),
122            point_addresses: HashMap::new(),
123            stats: RwLock::new(DeviceStatistics::default()),
124            event_tx,
125            response_delay: Duration::from_millis(profile.response_delay_ms),
126            start_time: RwLock::new(None),
127        };
128
129        for point in &profile.points {
130            device.apply_point_profile(point);
131        }
132
133        Ok(device)
134    }
135
136    /// Get the unit ID.
137    pub fn unit_id(&self) -> u8 {
138        self.context.unit_id()
139    }
140
141    /// Get the low-level device context.
142    pub fn context(&self) -> &Arc<DeviceContext> {
143        &self.context
144    }
145
146    /// Get the underlying address space for this device.
147    pub fn address_space(&self) -> SharedAddressSpace {
148        self.context.address_space()
149    }
150
151    /// Backward-compatible accessor for the unit register space.
152    pub fn registers(&self) -> SharedAddressSpace {
153        self.address_space()
154    }
155
156    /// Add a data point definition.
157    pub fn add_point(&mut self, def: DataPointDef, address: ModbusAddress) {
158        self.point_addresses.insert(def.id.clone(), address);
159        self.point_defs.insert(def.id.clone(), def);
160        self.info.point_count = self.point_defs.len();
161    }
162
163    fn apply_point_profile(&mut self, profile: &PointProfile) {
164        match profile.register_type {
165            ModbusRegisterType::HoldingRegister => self.add_holding_register(
166                profile.id.clone(),
167                profile.name.clone(),
168                profile.address,
169                profile.data_type,
170            ),
171            ModbusRegisterType::InputRegister => self.add_input_register(
172                profile.id.clone(),
173                profile.name.clone(),
174                profile.address,
175                profile.data_type,
176            ),
177            ModbusRegisterType::Coil => {
178                self.add_coil(profile.id.clone(), profile.name.clone(), profile.address)
179            }
180            ModbusRegisterType::DiscreteInput => {
181                self.add_discrete_input(profile.id.clone(), profile.name.clone(), profile.address)
182            }
183        }
184    }
185
186    /// Add a holding register point.
187    pub fn add_holding_register(
188        &mut self,
189        id: impl Into<String>,
190        name: impl Into<String>,
191        address: u16,
192        data_type: DataType,
193    ) {
194        let id = id.into();
195        let register_count = match data_type {
196            DataType::Bool
197            | DataType::Int8
198            | DataType::UInt8
199            | DataType::Int16
200            | DataType::UInt16 => 1,
201            DataType::Int32 | DataType::UInt32 | DataType::Float32 => 2,
202            DataType::Int64 | DataType::UInt64 | DataType::Float64 | DataType::DateTime => 4,
203            _ => 1,
204        };
205
206        let def = DataPointDef::new(&id, name, data_type).with_access(AccessMode::ReadWrite);
207
208        let addr = ModbusAddress {
209            register_type: ModbusRegisterType::HoldingRegister,
210            address,
211            count: register_count,
212        };
213
214        self.add_point(def, addr);
215    }
216
217    /// Add an input register point.
218    pub fn add_input_register(
219        &mut self,
220        id: impl Into<String>,
221        name: impl Into<String>,
222        address: u16,
223        data_type: DataType,
224    ) {
225        let id = id.into();
226        let register_count = match data_type {
227            DataType::Bool
228            | DataType::Int8
229            | DataType::UInt8
230            | DataType::Int16
231            | DataType::UInt16 => 1,
232            DataType::Int32 | DataType::UInt32 | DataType::Float32 => 2,
233            DataType::Int64 | DataType::UInt64 | DataType::Float64 | DataType::DateTime => 4,
234            _ => 1,
235        };
236
237        let def = DataPointDef::new(&id, name, data_type).with_access(AccessMode::ReadOnly);
238
239        let addr = ModbusAddress {
240            register_type: ModbusRegisterType::InputRegister,
241            address,
242            count: register_count,
243        };
244
245        self.add_point(def, addr);
246    }
247
248    /// Add a coil point.
249    pub fn add_coil(&mut self, id: impl Into<String>, name: impl Into<String>, address: u16) {
250        let id = id.into();
251        let def = DataPointDef::new(&id, name, DataType::Bool).with_access(AccessMode::ReadWrite);
252
253        let addr = ModbusAddress {
254            register_type: ModbusRegisterType::Coil,
255            address,
256            count: 1,
257        };
258
259        self.add_point(def, addr);
260    }
261
262    /// Add a discrete input point.
263    pub fn add_discrete_input(
264        &mut self,
265        id: impl Into<String>,
266        name: impl Into<String>,
267        address: u16,
268    ) {
269        let id = id.into();
270        let def = DataPointDef::new(&id, name, DataType::Bool).with_access(AccessMode::ReadOnly);
271
272        let addr = ModbusAddress {
273            register_type: ModbusRegisterType::DiscreteInput,
274            address,
275            count: 1,
276        };
277
278        self.add_point(def, addr);
279    }
280
281    /// Read value from registers.
282    fn read_value(&self, point_id: &str) -> Result<Value> {
283        let def = self
284            .point_defs
285            .get(point_id)
286            .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
287
288        let addr = self
289            .point_addresses
290            .get(point_id)
291            .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
292
293        match addr.register_type {
294            ModbusRegisterType::Coil => {
295                let values = self.address_space().read_coils(addr.address, 1)?;
296                Ok(Value::Bool(values[0]))
297            }
298            ModbusRegisterType::DiscreteInput => {
299                let values = self.address_space().read_discrete_inputs(addr.address, 1)?;
300                Ok(Value::Bool(values[0]))
301            }
302            ModbusRegisterType::HoldingRegister => {
303                self.read_register_value(&def.data_type, addr.address, addr.count, true)
304            }
305            ModbusRegisterType::InputRegister => {
306                self.read_register_value(&def.data_type, addr.address, addr.count, false)
307            }
308        }
309    }
310
311    fn read_register_value(
312        &self,
313        data_type: &DataType,
314        address: u16,
315        count: u16,
316        is_holding: bool,
317    ) -> Result<Value> {
318        let registers = if is_holding {
319            self.address_space()
320                .read_holding_registers(address, count)?
321        } else {
322            self.address_space().read_input_registers(address, count)?
323        };
324
325        let value = match data_type {
326            DataType::Bool => Value::Bool(registers[0] != 0),
327            DataType::Int16 => Value::I16(registers[0] as i16),
328            DataType::UInt16 => Value::U16(registers[0]),
329            DataType::Int32 if registers.len() >= 2 => {
330                let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
331                Value::I32(i32::from_be_bytes([
332                    bytes[0][0],
333                    bytes[0][1],
334                    bytes[1][0],
335                    bytes[1][1],
336                ]))
337            }
338            DataType::UInt32 if registers.len() >= 2 => {
339                let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
340                Value::U32(u32::from_be_bytes([
341                    bytes[0][0],
342                    bytes[0][1],
343                    bytes[1][0],
344                    bytes[1][1],
345                ]))
346            }
347            DataType::Float32 if registers.len() >= 2 => {
348                let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
349                Value::F32(f32::from_be_bytes([
350                    bytes[0][0],
351                    bytes[0][1],
352                    bytes[1][0],
353                    bytes[1][1],
354                ]))
355            }
356            DataType::Float64 if registers.len() >= 4 => {
357                let bytes: Vec<u8> = registers[..4]
358                    .iter()
359                    .flat_map(|r| r.to_be_bytes())
360                    .collect();
361                Value::F64(f64::from_be_bytes([
362                    bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
363                ]))
364            }
365            _ => Value::U16(registers[0]),
366        };
367
368        Ok(value)
369    }
370
371    /// Write value to registers.
372    fn write_value(&self, point_id: &str, value: Value) -> Result<()> {
373        let def = self
374            .point_defs
375            .get(point_id)
376            .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
377
378        if !def.access.is_writable() {
379            return Err(Error::NotSupported(format!(
380                "Point {} is read-only",
381                point_id
382            )));
383        }
384
385        let addr = self
386            .point_addresses
387            .get(point_id)
388            .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
389
390        match addr.register_type {
391            ModbusRegisterType::Coil => {
392                let bool_value = value.as_bool().ok_or_else(|| Error::TypeMismatch {
393                    expected: "bool".to_string(),
394                    actual: value.type_name().to_string(),
395                })?;
396                self.address_space().write_coil(addr.address, bool_value)?;
397            }
398            ModbusRegisterType::HoldingRegister => {
399                let registers = value.to_registers();
400                if registers.is_empty() {
401                    return Err(Error::InvalidValue {
402                        point_id: point_id.to_string(),
403                        reason: "Cannot convert value to registers".to_string(),
404                    });
405                }
406                self.address_space()
407                    .write_holding_registers(addr.address, &registers)?;
408            }
409            ModbusRegisterType::DiscreteInput | ModbusRegisterType::InputRegister => {
410                return Err(Error::NotSupported(format!(
411                    "Cannot write to {} register type",
412                    match addr.register_type {
413                        ModbusRegisterType::DiscreteInput => "discrete input",
414                        ModbusRegisterType::InputRegister => "input",
415                        _ => unreachable!(),
416                    }
417                )));
418            }
419        }
420
421        Ok(())
422    }
423
424    /// Shared read surface for runtime-backed device ports.
425    pub async fn read_point(&self, point_id: &str) -> Result<DataPoint> {
426        if !self.response_delay.is_zero() {
427            tokio::time::sleep(self.response_delay).await;
428        }
429
430        let value = self.read_value(point_id)?;
431        self.stats.write().record_read();
432
433        let id = DataPointId::new(&self.info.id, point_id);
434        Ok(DataPoint::new(id, value))
435    }
436
437    /// Shared write surface for runtime-backed device ports.
438    pub async fn write_point(&self, point_id: &str, value: Value) -> Result<()> {
439        if !self.response_delay.is_zero() {
440            tokio::time::sleep(self.response_delay).await;
441        }
442
443        self.write_value(point_id, value.clone())?;
444        self.stats.write().record_write();
445
446        let id = DataPointId::new(&self.info.id, point_id);
447        let _ = self.event_tx.send(DataPoint::new(id, value));
448
449        Ok(())
450    }
451
452    /// Returns cloned point definitions for control-plane catalog queries.
453    pub fn point_definitions_owned(&self) -> Vec<DataPointDef> {
454        self.point_defs.values().cloned().collect()
455    }
456}
457
458#[async_trait]
459impl Device for ModbusDevice {
460    fn info(&self) -> &DeviceInfo {
461        &self.info
462    }
463
464    async fn initialize(&mut self) -> Result<()> {
465        self.info.state = DeviceState::Initializing;
466        // Initialize registers to default values if needed
467        self.info.state = DeviceState::Online;
468        Ok(())
469    }
470
471    async fn start(&mut self) -> Result<()> {
472        *self.start_time.write() = Some(Instant::now());
473        self.info.state = DeviceState::Online;
474        Ok(())
475    }
476
477    async fn stop(&mut self) -> Result<()> {
478        self.info.state = DeviceState::Offline;
479        if let Some(start) = *self.start_time.read() {
480            self.stats.write().uptime_secs = start.elapsed().as_secs();
481        }
482        Ok(())
483    }
484
485    async fn tick(&mut self) -> Result<()> {
486        let start = Instant::now();
487        // Simulation updates can be done here
488        self.stats
489            .write()
490            .record_tick(start.elapsed().as_micros() as u64);
491        Ok(())
492    }
493
494    fn point_definitions(&self) -> Vec<&DataPointDef> {
495        self.point_defs.values().collect()
496    }
497
498    fn point_definition(&self, point_id: &str) -> Option<&DataPointDef> {
499        self.point_defs.get(point_id)
500    }
501
502    #[instrument(skip(self))]
503    async fn read(&self, point_id: &str) -> Result<DataPoint> {
504        self.read_point(point_id).await
505    }
506
507    #[instrument(skip(self, value))]
508    async fn write(&mut self, point_id: &str, value: Value) -> Result<()> {
509        self.write_point(point_id, value).await
510    }
511
512    fn subscribe(&self) -> Option<broadcast::Receiver<DataPoint>> {
513        Some(self.event_tx.subscribe())
514    }
515
516    fn statistics(&self) -> DeviceStatistics {
517        self.stats.read().clone()
518    }
519}
520
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a default-configured device for unit 1.
    fn test_device() -> ModbusDevice {
        ModbusDevice::new(ModbusDeviceConfig::new(1, "Test Device"))
    }

    #[tokio::test]
    async fn test_modbus_device_creation() {
        let device = test_device();

        assert_eq!(device.unit_id(), 1);
        assert_eq!(device.info().protocol, Protocol::ModbusTcp);
    }

    #[tokio::test]
    async fn test_modbus_device_points() {
        let mut device = test_device();

        device.add_holding_register("temp", "Temperature", 0, DataType::Float32);
        device.add_coil("relay1", "Relay 1", 0);

        assert_eq!(device.point_definitions().len(), 2);
        assert_eq!(device.info().point_count, 2);
        assert!(device.point_definition("temp").is_some());
    }

    #[tokio::test]
    async fn test_modbus_device_read_write() {
        let mut device = test_device();

        device.add_holding_register("value", "Value", 0, DataType::UInt16);
        device.initialize().await.unwrap();

        // Write
        device.write("value", Value::U16(12345)).await.unwrap();

        // Read
        let dp = device.read("value").await.unwrap();
        assert_eq!(dp.value.as_u16(), Some(12345));
    }

    #[tokio::test]
    async fn test_write_to_read_only_points_is_rejected() {
        let mut device = test_device();

        device.add_input_register("ir", "Input", 0, DataType::UInt16);
        device.add_discrete_input("di", "Discrete", 0);

        assert!(device.write("ir", Value::U16(1)).await.is_err());
        assert!(device.write("di", Value::Bool(true)).await.is_err());
    }

    #[tokio::test]
    async fn test_unknown_point_is_reported() {
        let mut device = test_device();

        assert!(device.read("missing").await.is_err());
        assert!(device.write("missing", Value::U16(1)).await.is_err());
    }
}