Skip to main content

mabi_modbus/
device.rs

1//! Modbus device implementation.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use parking_lot::RwLock;
9use tokio::sync::broadcast;
10use tracing::instrument;
11
12use mabi_core::{
13    device::{Device, DeviceInfo, DeviceState, DeviceStatistics},
14    error::{Error, Result},
15    protocol::Protocol,
16    types::{
17        AccessMode, DataPoint, DataPointDef, DataPointId, DataType, ModbusAddress,
18        ModbusRegisterType,
19    },
20    value::Value,
21};
22
23use crate::config::ModbusDeviceConfig;
24use crate::register::RegisterStore;
25
26/// Modbus device implementation.
27pub struct ModbusDevice {
28    /// Device info.
29    info: DeviceInfo,
30
31    /// Unit ID.
32    unit_id: u8,
33
34    /// Register storage.
35    registers: Arc<RegisterStore>,
36
37    /// Data point definitions.
38    point_defs: HashMap<String, DataPointDef>,
39
40    /// Point ID to address mapping.
41    point_addresses: HashMap<String, ModbusAddress>,
42
43    /// Statistics.
44    stats: RwLock<DeviceStatistics>,
45
46    /// Value change broadcaster.
47    event_tx: broadcast::Sender<DataPoint>,
48
49    /// Response delay for simulation.
50    response_delay: Duration,
51
52    /// Start time.
53    start_time: RwLock<Option<Instant>>,
54}
55
56impl ModbusDevice {
57    /// Create a new Modbus device.
58    pub fn new(config: ModbusDeviceConfig) -> Self {
59        let registers = Arc::new(RegisterStore::new(
60            config.coils,
61            config.discrete_inputs,
62            config.holding_registers,
63            config.input_registers,
64        ));
65
66        let info = DeviceInfo::new(
67            format!("modbus-{}", config.unit_id),
68            &config.name,
69            Protocol::ModbusTcp,
70        )
71        .with_metadata("unit_id", config.unit_id.to_string());
72
73        let (event_tx, _) = broadcast::channel(1000);
74
75        Self {
76            info,
77            unit_id: config.unit_id,
78            registers,
79            point_defs: HashMap::new(),
80            point_addresses: HashMap::new(),
81            stats: RwLock::new(DeviceStatistics::default()),
82            event_tx,
83            response_delay: Duration::from_millis(config.response_delay_ms),
84            start_time: RwLock::new(None),
85        }
86    }
87
88    /// Get the unit ID.
89    pub fn unit_id(&self) -> u8 {
90        self.unit_id
91    }
92
93    /// Get the register store.
94    pub fn registers(&self) -> &Arc<RegisterStore> {
95        &self.registers
96    }
97
98    /// Add a data point definition.
99    pub fn add_point(&mut self, def: DataPointDef, address: ModbusAddress) {
100        self.point_addresses.insert(def.id.clone(), address);
101        self.point_defs.insert(def.id.clone(), def);
102        self.info.point_count = self.point_defs.len();
103    }
104
105    /// Add a holding register point.
106    pub fn add_holding_register(
107        &mut self,
108        id: impl Into<String>,
109        name: impl Into<String>,
110        address: u16,
111        data_type: DataType,
112    ) {
113        let id = id.into();
114        let register_count = match data_type {
115            DataType::Bool
116            | DataType::Int8
117            | DataType::UInt8
118            | DataType::Int16
119            | DataType::UInt16 => 1,
120            DataType::Int32 | DataType::UInt32 | DataType::Float32 => 2,
121            DataType::Int64 | DataType::UInt64 | DataType::Float64 | DataType::DateTime => 4,
122            _ => 1,
123        };
124
125        let def = DataPointDef::new(&id, name, data_type).with_access(AccessMode::ReadWrite);
126
127        let addr = ModbusAddress {
128            register_type: ModbusRegisterType::HoldingRegister,
129            address,
130            count: register_count,
131        };
132
133        self.add_point(def, addr);
134    }
135
136    /// Add an input register point.
137    pub fn add_input_register(
138        &mut self,
139        id: impl Into<String>,
140        name: impl Into<String>,
141        address: u16,
142        data_type: DataType,
143    ) {
144        let id = id.into();
145        let register_count = match data_type {
146            DataType::Bool
147            | DataType::Int8
148            | DataType::UInt8
149            | DataType::Int16
150            | DataType::UInt16 => 1,
151            DataType::Int32 | DataType::UInt32 | DataType::Float32 => 2,
152            DataType::Int64 | DataType::UInt64 | DataType::Float64 | DataType::DateTime => 4,
153            _ => 1,
154        };
155
156        let def = DataPointDef::new(&id, name, data_type).with_access(AccessMode::ReadOnly);
157
158        let addr = ModbusAddress {
159            register_type: ModbusRegisterType::InputRegister,
160            address,
161            count: register_count,
162        };
163
164        self.add_point(def, addr);
165    }
166
167    /// Add a coil point.
168    pub fn add_coil(&mut self, id: impl Into<String>, name: impl Into<String>, address: u16) {
169        let id = id.into();
170        let def = DataPointDef::new(&id, name, DataType::Bool).with_access(AccessMode::ReadWrite);
171
172        let addr = ModbusAddress {
173            register_type: ModbusRegisterType::Coil,
174            address,
175            count: 1,
176        };
177
178        self.add_point(def, addr);
179    }
180
181    /// Add a discrete input point.
182    pub fn add_discrete_input(
183        &mut self,
184        id: impl Into<String>,
185        name: impl Into<String>,
186        address: u16,
187    ) {
188        let id = id.into();
189        let def = DataPointDef::new(&id, name, DataType::Bool).with_access(AccessMode::ReadOnly);
190
191        let addr = ModbusAddress {
192            register_type: ModbusRegisterType::DiscreteInput,
193            address,
194            count: 1,
195        };
196
197        self.add_point(def, addr);
198    }
199
200    /// Read value from registers.
201    fn read_value(&self, point_id: &str) -> Result<Value> {
202        let def = self
203            .point_defs
204            .get(point_id)
205            .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
206
207        let addr = self
208            .point_addresses
209            .get(point_id)
210            .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
211
212        match addr.register_type {
213            ModbusRegisterType::Coil => {
214                let values = self.registers.read_coils(addr.address, 1)?;
215                Ok(Value::Bool(values[0]))
216            }
217            ModbusRegisterType::DiscreteInput => {
218                let values = self.registers.read_discrete_inputs(addr.address, 1)?;
219                Ok(Value::Bool(values[0]))
220            }
221            ModbusRegisterType::HoldingRegister => {
222                self.read_register_value(&def.data_type, addr.address, addr.count, true)
223            }
224            ModbusRegisterType::InputRegister => {
225                self.read_register_value(&def.data_type, addr.address, addr.count, false)
226            }
227        }
228    }
229
230    fn read_register_value(
231        &self,
232        data_type: &DataType,
233        address: u16,
234        count: u16,
235        is_holding: bool,
236    ) -> Result<Value> {
237        let registers = if is_holding {
238            self.registers.read_holding_registers(address, count)?
239        } else {
240            self.registers.read_input_registers(address, count)?
241        };
242
243        let value = match data_type {
244            DataType::Bool => Value::Bool(registers[0] != 0),
245            DataType::Int16 => Value::I16(registers[0] as i16),
246            DataType::UInt16 => Value::U16(registers[0]),
247            DataType::Int32 if registers.len() >= 2 => {
248                let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
249                Value::I32(i32::from_be_bytes([
250                    bytes[0][0],
251                    bytes[0][1],
252                    bytes[1][0],
253                    bytes[1][1],
254                ]))
255            }
256            DataType::UInt32 if registers.len() >= 2 => {
257                let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
258                Value::U32(u32::from_be_bytes([
259                    bytes[0][0],
260                    bytes[0][1],
261                    bytes[1][0],
262                    bytes[1][1],
263                ]))
264            }
265            DataType::Float32 if registers.len() >= 2 => {
266                let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
267                Value::F32(f32::from_be_bytes([
268                    bytes[0][0],
269                    bytes[0][1],
270                    bytes[1][0],
271                    bytes[1][1],
272                ]))
273            }
274            DataType::Float64 if registers.len() >= 4 => {
275                let bytes: Vec<u8> = registers[..4]
276                    .iter()
277                    .flat_map(|r| r.to_be_bytes())
278                    .collect();
279                Value::F64(f64::from_be_bytes([
280                    bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
281                ]))
282            }
283            _ => Value::U16(registers[0]),
284        };
285
286        Ok(value)
287    }
288
289    /// Write value to registers.
290    fn write_value(&self, point_id: &str, value: Value) -> Result<()> {
291        let def = self
292            .point_defs
293            .get(point_id)
294            .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
295
296        if !def.access.is_writable() {
297            return Err(Error::NotSupported(format!(
298                "Point {} is read-only",
299                point_id
300            )));
301        }
302
303        let addr = self
304            .point_addresses
305            .get(point_id)
306            .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
307
308        match addr.register_type {
309            ModbusRegisterType::Coil => {
310                let bool_value = value.as_bool().ok_or_else(|| Error::TypeMismatch {
311                    expected: "bool".to_string(),
312                    actual: value.type_name().to_string(),
313                })?;
314                self.registers.write_coil(addr.address, bool_value)?;
315            }
316            ModbusRegisterType::HoldingRegister => {
317                let registers = value.to_registers();
318                if registers.is_empty() {
319                    return Err(Error::InvalidValue {
320                        point_id: point_id.to_string(),
321                        reason: "Cannot convert value to registers".to_string(),
322                    });
323                }
324                self.registers
325                    .write_holding_registers(addr.address, &registers)?;
326            }
327            ModbusRegisterType::DiscreteInput | ModbusRegisterType::InputRegister => {
328                return Err(Error::NotSupported(format!(
329                    "Cannot write to {} register type",
330                    match addr.register_type {
331                        ModbusRegisterType::DiscreteInput => "discrete input",
332                        ModbusRegisterType::InputRegister => "input",
333                        _ => unreachable!(),
334                    }
335                )));
336            }
337        }
338
339        Ok(())
340    }
341}
342
343#[async_trait]
344impl Device for ModbusDevice {
345    fn info(&self) -> &DeviceInfo {
346        &self.info
347    }
348
349    async fn initialize(&mut self) -> Result<()> {
350        self.info.state = DeviceState::Initializing;
351        // Initialize registers to default values if needed
352        self.info.state = DeviceState::Online;
353        Ok(())
354    }
355
356    async fn start(&mut self) -> Result<()> {
357        *self.start_time.write() = Some(Instant::now());
358        self.info.state = DeviceState::Online;
359        Ok(())
360    }
361
362    async fn stop(&mut self) -> Result<()> {
363        self.info.state = DeviceState::Offline;
364        if let Some(start) = *self.start_time.read() {
365            self.stats.write().uptime_secs = start.elapsed().as_secs();
366        }
367        Ok(())
368    }
369
370    async fn tick(&mut self) -> Result<()> {
371        let start = Instant::now();
372        // Simulation updates can be done here
373        self.stats
374            .write()
375            .record_tick(start.elapsed().as_micros() as u64);
376        Ok(())
377    }
378
379    fn point_definitions(&self) -> Vec<&DataPointDef> {
380        self.point_defs.values().collect()
381    }
382
383    fn point_definition(&self, point_id: &str) -> Option<&DataPointDef> {
384        self.point_defs.get(point_id)
385    }
386
387    #[instrument(skip(self))]
388    async fn read(&self, point_id: &str) -> Result<DataPoint> {
389        // Simulate response delay
390        if !self.response_delay.is_zero() {
391            tokio::time::sleep(self.response_delay).await;
392        }
393
394        let value = self.read_value(point_id)?;
395        self.stats.write().record_read();
396
397        let id = DataPointId::new(&self.info.id, point_id);
398        Ok(DataPoint::new(id, value))
399    }
400
401    #[instrument(skip(self, value))]
402    async fn write(&mut self, point_id: &str, value: Value) -> Result<()> {
403        // Simulate response delay
404        if !self.response_delay.is_zero() {
405            tokio::time::sleep(self.response_delay).await;
406        }
407
408        self.write_value(point_id, value.clone())?;
409        self.stats.write().record_write();
410
411        // Broadcast value change
412        let id = DataPointId::new(&self.info.id, point_id);
413        let _ = self.event_tx.send(DataPoint::new(id, value));
414
415        Ok(())
416    }
417
418    fn subscribe(&self) -> Option<broadcast::Receiver<DataPoint>> {
419        Some(self.event_tx.subscribe())
420    }
421
422    fn statistics(&self) -> DeviceStatistics {
423        self.stats.read().clone()
424    }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a device with unit ID 1 and the default configuration.
    fn test_device() -> ModbusDevice {
        ModbusDevice::new(ModbusDeviceConfig::new(1, "Test Device"))
    }

    #[tokio::test]
    async fn test_modbus_device_creation() {
        let device = test_device();

        assert_eq!(device.unit_id(), 1);
        assert_eq!(device.info().protocol, Protocol::ModbusTcp);
    }

    #[tokio::test]
    async fn test_modbus_device_points() {
        let mut device = test_device();

        device.add_holding_register("temp", "Temperature", 0, DataType::Float32);
        device.add_coil("relay1", "Relay 1", 0);

        assert_eq!(device.point_definitions().len(), 2);
        assert!(device.point_definition("temp").is_some());
        assert_eq!(device.info().point_count, 2);

        // Re-adding an existing ID replaces it rather than adding a new point.
        device.add_holding_register("temp", "Temperature", 10, DataType::Float32);
        assert_eq!(device.point_definitions().len(), 2);
        assert_eq!(device.info().point_count, 2);
    }

    #[tokio::test]
    async fn test_modbus_device_read_write() {
        let mut device = test_device();

        device.add_holding_register("value", "Value", 0, DataType::UInt16);
        device.initialize().await.unwrap();

        // Write
        device.write("value", Value::U16(12345)).await.unwrap();

        // Read
        let dp = device.read("value").await.unwrap();
        assert_eq!(dp.value.as_u16(), Some(12345));
    }

    #[tokio::test]
    async fn test_modbus_device_coil_read_write() {
        let mut device = test_device();

        device.add_coil("relay1", "Relay 1", 0);
        device.initialize().await.unwrap();

        device.write("relay1", Value::Bool(true)).await.unwrap();
        let dp = device.read("relay1").await.unwrap();
        assert_eq!(dp.value.as_bool(), Some(true));
    }

    #[tokio::test]
    async fn test_modbus_device_rejects_writes_to_read_only_points() {
        let mut device = test_device();

        device.add_input_register("sensor", "Sensor", 0, DataType::UInt16);
        device.add_discrete_input("switch", "Switch", 0);
        device.initialize().await.unwrap();

        assert!(device.write("sensor", Value::U16(1)).await.is_err());
        assert!(device.write("switch", Value::Bool(true)).await.is_err());
    }

    #[tokio::test]
    async fn test_modbus_device_unknown_point() {
        let mut device = test_device();
        device.initialize().await.unwrap();

        assert!(device.read("missing").await.is_err());
        assert!(device.write("missing", Value::U16(1)).await.is_err());
    }
}
467}