Skip to main content

mabi_modbus/
device.rs

1//! Modbus device implementation.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use parking_lot::RwLock;
9use tokio::sync::broadcast;
10use tracing::instrument;
11
12use mabi_core::{
13    device::{Device, DeviceInfo, DeviceState, DeviceStatistics},
14    error::{Error, Result},
15    protocol::Protocol,
16    types::{
17        AccessMode, DataPoint, DataPointDef, DataPointId, DataType, ModbusAddress,
18        ModbusRegisterType,
19    },
20    value::Value,
21};
22
23use crate::config::ModbusDeviceConfig;
24use crate::register::RegisterStore;
25
26/// Modbus device implementation.
27pub struct ModbusDevice {
28    /// Device info.
29    info: DeviceInfo,
30
31    /// Unit ID.
32    unit_id: u8,
33
34    /// Register storage.
35    registers: Arc<RegisterStore>,
36
37    /// Data point definitions.
38    point_defs: HashMap<String, DataPointDef>,
39
40    /// Point ID to address mapping.
41    point_addresses: HashMap<String, ModbusAddress>,
42
43    /// Statistics.
44    stats: RwLock<DeviceStatistics>,
45
46    /// Value change broadcaster.
47    event_tx: broadcast::Sender<DataPoint>,
48
49    /// Response delay for simulation.
50    response_delay: Duration,
51
52    /// Start time.
53    start_time: RwLock<Option<Instant>>,
54}
55
56impl ModbusDevice {
57    /// Create a new Modbus device.
58    pub fn new(config: ModbusDeviceConfig) -> Self {
59        let registers = Arc::new(RegisterStore::new(
60            config.coils,
61            config.discrete_inputs,
62            config.holding_registers,
63            config.input_registers,
64        ));
65
66        let info = DeviceInfo::new(
67            format!("modbus-{}", config.unit_id),
68            &config.name,
69            Protocol::ModbusTcp,
70        )
71        .with_metadata("unit_id", config.unit_id.to_string())
72        .with_tags(config.tags);
73
74        let (event_tx, _) = broadcast::channel(1000);
75
76        Self {
77            info,
78            unit_id: config.unit_id,
79            registers,
80            point_defs: HashMap::new(),
81            point_addresses: HashMap::new(),
82            stats: RwLock::new(DeviceStatistics::default()),
83            event_tx,
84            response_delay: Duration::from_millis(config.response_delay_ms),
85            start_time: RwLock::new(None),
86        }
87    }
88
89    /// Get the unit ID.
90    pub fn unit_id(&self) -> u8 {
91        self.unit_id
92    }
93
94    /// Get the register store.
95    pub fn registers(&self) -> &Arc<RegisterStore> {
96        &self.registers
97    }
98
99    /// Add a data point definition.
100    pub fn add_point(&mut self, def: DataPointDef, address: ModbusAddress) {
101        self.point_addresses.insert(def.id.clone(), address);
102        self.point_defs.insert(def.id.clone(), def);
103        self.info.point_count = self.point_defs.len();
104    }
105
106    /// Add a holding register point.
107    pub fn add_holding_register(
108        &mut self,
109        id: impl Into<String>,
110        name: impl Into<String>,
111        address: u16,
112        data_type: DataType,
113    ) {
114        let id = id.into();
115        let register_count = match data_type {
116            DataType::Bool
117            | DataType::Int8
118            | DataType::UInt8
119            | DataType::Int16
120            | DataType::UInt16 => 1,
121            DataType::Int32 | DataType::UInt32 | DataType::Float32 => 2,
122            DataType::Int64 | DataType::UInt64 | DataType::Float64 | DataType::DateTime => 4,
123            _ => 1,
124        };
125
126        let def = DataPointDef::new(&id, name, data_type).with_access(AccessMode::ReadWrite);
127
128        let addr = ModbusAddress {
129            register_type: ModbusRegisterType::HoldingRegister,
130            address,
131            count: register_count,
132        };
133
134        self.add_point(def, addr);
135    }
136
137    /// Add an input register point.
138    pub fn add_input_register(
139        &mut self,
140        id: impl Into<String>,
141        name: impl Into<String>,
142        address: u16,
143        data_type: DataType,
144    ) {
145        let id = id.into();
146        let register_count = match data_type {
147            DataType::Bool
148            | DataType::Int8
149            | DataType::UInt8
150            | DataType::Int16
151            | DataType::UInt16 => 1,
152            DataType::Int32 | DataType::UInt32 | DataType::Float32 => 2,
153            DataType::Int64 | DataType::UInt64 | DataType::Float64 | DataType::DateTime => 4,
154            _ => 1,
155        };
156
157        let def = DataPointDef::new(&id, name, data_type).with_access(AccessMode::ReadOnly);
158
159        let addr = ModbusAddress {
160            register_type: ModbusRegisterType::InputRegister,
161            address,
162            count: register_count,
163        };
164
165        self.add_point(def, addr);
166    }
167
168    /// Add a coil point.
169    pub fn add_coil(&mut self, id: impl Into<String>, name: impl Into<String>, address: u16) {
170        let id = id.into();
171        let def = DataPointDef::new(&id, name, DataType::Bool).with_access(AccessMode::ReadWrite);
172
173        let addr = ModbusAddress {
174            register_type: ModbusRegisterType::Coil,
175            address,
176            count: 1,
177        };
178
179        self.add_point(def, addr);
180    }
181
182    /// Add a discrete input point.
183    pub fn add_discrete_input(
184        &mut self,
185        id: impl Into<String>,
186        name: impl Into<String>,
187        address: u16,
188    ) {
189        let id = id.into();
190        let def = DataPointDef::new(&id, name, DataType::Bool).with_access(AccessMode::ReadOnly);
191
192        let addr = ModbusAddress {
193            register_type: ModbusRegisterType::DiscreteInput,
194            address,
195            count: 1,
196        };
197
198        self.add_point(def, addr);
199    }
200
201    /// Read value from registers.
202    fn read_value(&self, point_id: &str) -> Result<Value> {
203        let def = self
204            .point_defs
205            .get(point_id)
206            .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
207
208        let addr = self
209            .point_addresses
210            .get(point_id)
211            .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
212
213        match addr.register_type {
214            ModbusRegisterType::Coil => {
215                let values = self.registers.read_coils(addr.address, 1)?;
216                Ok(Value::Bool(values[0]))
217            }
218            ModbusRegisterType::DiscreteInput => {
219                let values = self.registers.read_discrete_inputs(addr.address, 1)?;
220                Ok(Value::Bool(values[0]))
221            }
222            ModbusRegisterType::HoldingRegister => {
223                self.read_register_value(&def.data_type, addr.address, addr.count, true)
224            }
225            ModbusRegisterType::InputRegister => {
226                self.read_register_value(&def.data_type, addr.address, addr.count, false)
227            }
228        }
229    }
230
231    fn read_register_value(
232        &self,
233        data_type: &DataType,
234        address: u16,
235        count: u16,
236        is_holding: bool,
237    ) -> Result<Value> {
238        let registers = if is_holding {
239            self.registers.read_holding_registers(address, count)?
240        } else {
241            self.registers.read_input_registers(address, count)?
242        };
243
244        let value = match data_type {
245            DataType::Bool => Value::Bool(registers[0] != 0),
246            DataType::Int16 => Value::I16(registers[0] as i16),
247            DataType::UInt16 => Value::U16(registers[0]),
248            DataType::Int32 if registers.len() >= 2 => {
249                let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
250                Value::I32(i32::from_be_bytes([
251                    bytes[0][0],
252                    bytes[0][1],
253                    bytes[1][0],
254                    bytes[1][1],
255                ]))
256            }
257            DataType::UInt32 if registers.len() >= 2 => {
258                let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
259                Value::U32(u32::from_be_bytes([
260                    bytes[0][0],
261                    bytes[0][1],
262                    bytes[1][0],
263                    bytes[1][1],
264                ]))
265            }
266            DataType::Float32 if registers.len() >= 2 => {
267                let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
268                Value::F32(f32::from_be_bytes([
269                    bytes[0][0],
270                    bytes[0][1],
271                    bytes[1][0],
272                    bytes[1][1],
273                ]))
274            }
275            DataType::Float64 if registers.len() >= 4 => {
276                let bytes: Vec<u8> = registers[..4]
277                    .iter()
278                    .flat_map(|r| r.to_be_bytes())
279                    .collect();
280                Value::F64(f64::from_be_bytes([
281                    bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
282                ]))
283            }
284            _ => Value::U16(registers[0]),
285        };
286
287        Ok(value)
288    }
289
290    /// Write value to registers.
291    fn write_value(&self, point_id: &str, value: Value) -> Result<()> {
292        let def = self
293            .point_defs
294            .get(point_id)
295            .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
296
297        if !def.access.is_writable() {
298            return Err(Error::NotSupported(format!(
299                "Point {} is read-only",
300                point_id
301            )));
302        }
303
304        let addr = self
305            .point_addresses
306            .get(point_id)
307            .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
308
309        match addr.register_type {
310            ModbusRegisterType::Coil => {
311                let bool_value = value.as_bool().ok_or_else(|| Error::TypeMismatch {
312                    expected: "bool".to_string(),
313                    actual: value.type_name().to_string(),
314                })?;
315                self.registers.write_coil(addr.address, bool_value)?;
316            }
317            ModbusRegisterType::HoldingRegister => {
318                let registers = value.to_registers();
319                if registers.is_empty() {
320                    return Err(Error::InvalidValue {
321                        point_id: point_id.to_string(),
322                        reason: "Cannot convert value to registers".to_string(),
323                    });
324                }
325                self.registers
326                    .write_holding_registers(addr.address, &registers)?;
327            }
328            ModbusRegisterType::DiscreteInput | ModbusRegisterType::InputRegister => {
329                return Err(Error::NotSupported(format!(
330                    "Cannot write to {} register type",
331                    match addr.register_type {
332                        ModbusRegisterType::DiscreteInput => "discrete input",
333                        ModbusRegisterType::InputRegister => "input",
334                        _ => unreachable!(),
335                    }
336                )));
337            }
338        }
339
340        Ok(())
341    }
342}
343
344#[async_trait]
345impl Device for ModbusDevice {
346    fn info(&self) -> &DeviceInfo {
347        &self.info
348    }
349
350    async fn initialize(&mut self) -> Result<()> {
351        self.info.state = DeviceState::Initializing;
352        // Initialize registers to default values if needed
353        self.info.state = DeviceState::Online;
354        Ok(())
355    }
356
357    async fn start(&mut self) -> Result<()> {
358        *self.start_time.write() = Some(Instant::now());
359        self.info.state = DeviceState::Online;
360        Ok(())
361    }
362
363    async fn stop(&mut self) -> Result<()> {
364        self.info.state = DeviceState::Offline;
365        if let Some(start) = *self.start_time.read() {
366            self.stats.write().uptime_secs = start.elapsed().as_secs();
367        }
368        Ok(())
369    }
370
371    async fn tick(&mut self) -> Result<()> {
372        let start = Instant::now();
373        // Simulation updates can be done here
374        self.stats
375            .write()
376            .record_tick(start.elapsed().as_micros() as u64);
377        Ok(())
378    }
379
380    fn point_definitions(&self) -> Vec<&DataPointDef> {
381        self.point_defs.values().collect()
382    }
383
384    fn point_definition(&self, point_id: &str) -> Option<&DataPointDef> {
385        self.point_defs.get(point_id)
386    }
387
388    #[instrument(skip(self))]
389    async fn read(&self, point_id: &str) -> Result<DataPoint> {
390        // Simulate response delay
391        if !self.response_delay.is_zero() {
392            tokio::time::sleep(self.response_delay).await;
393        }
394
395        let value = self.read_value(point_id)?;
396        self.stats.write().record_read();
397
398        let id = DataPointId::new(&self.info.id, point_id);
399        Ok(DataPoint::new(id, value))
400    }
401
402    #[instrument(skip(self, value))]
403    async fn write(&mut self, point_id: &str, value: Value) -> Result<()> {
404        // Simulate response delay
405        if !self.response_delay.is_zero() {
406            tokio::time::sleep(self.response_delay).await;
407        }
408
409        self.write_value(point_id, value.clone())?;
410        self.stats.write().record_write();
411
412        // Broadcast value change
413        let id = DataPointId::new(&self.info.id, point_id);
414        let _ = self.event_tx.send(DataPoint::new(id, value));
415
416        Ok(())
417    }
418
419    fn subscribe(&self) -> Option<broadcast::Receiver<DataPoint>> {
420        Some(self.event_tx.subscribe())
421    }
422
423    fn statistics(&self) -> DeviceStatistics {
424        self.stats.read().clone()
425    }
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a fresh unit-1 device with no data points.
    fn new_device() -> ModbusDevice {
        ModbusDevice::new(ModbusDeviceConfig::new(1, "Test Device"))
    }

    #[tokio::test]
    async fn test_modbus_device_creation() {
        let device = new_device();

        assert_eq!(device.unit_id(), 1);
        assert_eq!(device.info().protocol, Protocol::ModbusTcp);
    }

    #[tokio::test]
    async fn test_modbus_device_points() {
        let mut device = new_device();

        device.add_holding_register("temp", "Temperature", 0, DataType::Float32);
        device.add_coil("relay1", "Relay 1", 0);

        assert_eq!(device.point_definitions().len(), 2);
        assert!(device.point_definition("temp").is_some());
    }

    #[tokio::test]
    async fn test_modbus_device_read_write() {
        let mut device = new_device();

        device.add_holding_register("value", "Value", 0, DataType::UInt16);
        device.initialize().await.unwrap();

        device.write("value", Value::U16(12345)).await.unwrap();

        let dp = device.read("value").await.unwrap();
        assert_eq!(dp.value.as_u16(), Some(12345));
    }

    #[tokio::test]
    async fn test_read_unknown_point_fails() {
        let device = new_device();

        assert!(device.read("missing").await.is_err());
    }

    #[tokio::test]
    async fn test_write_to_read_only_points_fails() {
        let mut device = new_device();

        device.add_input_register("ir", "Input", 0, DataType::UInt16);
        device.add_discrete_input("di", "Discrete", 0);

        // Both helpers register their points as read-only, so the access check rejects them.
        assert!(device.write("ir", Value::U16(1)).await.is_err());
        assert!(device.write("di", Value::Bool(true)).await.is_err());
    }

    #[tokio::test]
    async fn test_write_broadcasts_change() {
        let mut device = new_device();
        device.add_holding_register("value", "Value", 0, DataType::UInt16);

        let mut rx = device
            .subscribe()
            .expect("ModbusDevice always offers a subscription");

        device.write("value", Value::U16(7)).await.unwrap();

        let event = rx.try_recv().expect("a successful write publishes a change");
        assert_eq!(event.value.as_u16(), Some(7));
    }
}