1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use parking_lot::RwLock;
9use tokio::sync::broadcast;
10use tracing::instrument;
11
12use mabi_core::{
13 device::{Device, DeviceInfo, DeviceState, DeviceStatistics},
14 error::{Error, Result},
15 protocol::Protocol,
16 types::{
17 AccessMode, DataPoint, DataPointDef, DataPointId, DataType, ModbusAddress,
18 ModbusRegisterType,
19 },
20 value::Value,
21};
22
23use crate::config::ModbusDeviceConfig;
24use crate::register::RegisterStore;
25
26pub struct ModbusDevice {
28 info: DeviceInfo,
30
31 unit_id: u8,
33
34 registers: Arc<RegisterStore>,
36
37 point_defs: HashMap<String, DataPointDef>,
39
40 point_addresses: HashMap<String, ModbusAddress>,
42
43 stats: RwLock<DeviceStatistics>,
45
46 event_tx: broadcast::Sender<DataPoint>,
48
49 response_delay: Duration,
51
52 start_time: RwLock<Option<Instant>>,
54}
55
56impl ModbusDevice {
57 pub fn new(config: ModbusDeviceConfig) -> Self {
59 let registers = Arc::new(RegisterStore::new(
60 config.coils,
61 config.discrete_inputs,
62 config.holding_registers,
63 config.input_registers,
64 ));
65
66 let info = DeviceInfo::new(
67 format!("modbus-{}", config.unit_id),
68 &config.name,
69 Protocol::ModbusTcp,
70 )
71 .with_metadata("unit_id", config.unit_id.to_string())
72 .with_tags(config.tags);
73
74 let (event_tx, _) = broadcast::channel(1000);
75
76 Self {
77 info,
78 unit_id: config.unit_id,
79 registers,
80 point_defs: HashMap::new(),
81 point_addresses: HashMap::new(),
82 stats: RwLock::new(DeviceStatistics::default()),
83 event_tx,
84 response_delay: Duration::from_millis(config.response_delay_ms),
85 start_time: RwLock::new(None),
86 }
87 }
88
89 pub fn unit_id(&self) -> u8 {
91 self.unit_id
92 }
93
94 pub fn registers(&self) -> &Arc<RegisterStore> {
96 &self.registers
97 }
98
99 pub fn add_point(&mut self, def: DataPointDef, address: ModbusAddress) {
101 self.point_addresses.insert(def.id.clone(), address);
102 self.point_defs.insert(def.id.clone(), def);
103 self.info.point_count = self.point_defs.len();
104 }
105
106 pub fn add_holding_register(
108 &mut self,
109 id: impl Into<String>,
110 name: impl Into<String>,
111 address: u16,
112 data_type: DataType,
113 ) {
114 let id = id.into();
115 let register_count = match data_type {
116 DataType::Bool
117 | DataType::Int8
118 | DataType::UInt8
119 | DataType::Int16
120 | DataType::UInt16 => 1,
121 DataType::Int32 | DataType::UInt32 | DataType::Float32 => 2,
122 DataType::Int64 | DataType::UInt64 | DataType::Float64 | DataType::DateTime => 4,
123 _ => 1,
124 };
125
126 let def = DataPointDef::new(&id, name, data_type).with_access(AccessMode::ReadWrite);
127
128 let addr = ModbusAddress {
129 register_type: ModbusRegisterType::HoldingRegister,
130 address,
131 count: register_count,
132 };
133
134 self.add_point(def, addr);
135 }
136
137 pub fn add_input_register(
139 &mut self,
140 id: impl Into<String>,
141 name: impl Into<String>,
142 address: u16,
143 data_type: DataType,
144 ) {
145 let id = id.into();
146 let register_count = match data_type {
147 DataType::Bool
148 | DataType::Int8
149 | DataType::UInt8
150 | DataType::Int16
151 | DataType::UInt16 => 1,
152 DataType::Int32 | DataType::UInt32 | DataType::Float32 => 2,
153 DataType::Int64 | DataType::UInt64 | DataType::Float64 | DataType::DateTime => 4,
154 _ => 1,
155 };
156
157 let def = DataPointDef::new(&id, name, data_type).with_access(AccessMode::ReadOnly);
158
159 let addr = ModbusAddress {
160 register_type: ModbusRegisterType::InputRegister,
161 address,
162 count: register_count,
163 };
164
165 self.add_point(def, addr);
166 }
167
168 pub fn add_coil(&mut self, id: impl Into<String>, name: impl Into<String>, address: u16) {
170 let id = id.into();
171 let def = DataPointDef::new(&id, name, DataType::Bool).with_access(AccessMode::ReadWrite);
172
173 let addr = ModbusAddress {
174 register_type: ModbusRegisterType::Coil,
175 address,
176 count: 1,
177 };
178
179 self.add_point(def, addr);
180 }
181
182 pub fn add_discrete_input(
184 &mut self,
185 id: impl Into<String>,
186 name: impl Into<String>,
187 address: u16,
188 ) {
189 let id = id.into();
190 let def = DataPointDef::new(&id, name, DataType::Bool).with_access(AccessMode::ReadOnly);
191
192 let addr = ModbusAddress {
193 register_type: ModbusRegisterType::DiscreteInput,
194 address,
195 count: 1,
196 };
197
198 self.add_point(def, addr);
199 }
200
201 fn read_value(&self, point_id: &str) -> Result<Value> {
203 let def = self
204 .point_defs
205 .get(point_id)
206 .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
207
208 let addr = self
209 .point_addresses
210 .get(point_id)
211 .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
212
213 match addr.register_type {
214 ModbusRegisterType::Coil => {
215 let values = self.registers.read_coils(addr.address, 1)?;
216 Ok(Value::Bool(values[0]))
217 }
218 ModbusRegisterType::DiscreteInput => {
219 let values = self.registers.read_discrete_inputs(addr.address, 1)?;
220 Ok(Value::Bool(values[0]))
221 }
222 ModbusRegisterType::HoldingRegister => {
223 self.read_register_value(&def.data_type, addr.address, addr.count, true)
224 }
225 ModbusRegisterType::InputRegister => {
226 self.read_register_value(&def.data_type, addr.address, addr.count, false)
227 }
228 }
229 }
230
231 fn read_register_value(
232 &self,
233 data_type: &DataType,
234 address: u16,
235 count: u16,
236 is_holding: bool,
237 ) -> Result<Value> {
238 let registers = if is_holding {
239 self.registers.read_holding_registers(address, count)?
240 } else {
241 self.registers.read_input_registers(address, count)?
242 };
243
244 let value = match data_type {
245 DataType::Bool => Value::Bool(registers[0] != 0),
246 DataType::Int16 => Value::I16(registers[0] as i16),
247 DataType::UInt16 => Value::U16(registers[0]),
248 DataType::Int32 if registers.len() >= 2 => {
249 let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
250 Value::I32(i32::from_be_bytes([
251 bytes[0][0],
252 bytes[0][1],
253 bytes[1][0],
254 bytes[1][1],
255 ]))
256 }
257 DataType::UInt32 if registers.len() >= 2 => {
258 let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
259 Value::U32(u32::from_be_bytes([
260 bytes[0][0],
261 bytes[0][1],
262 bytes[1][0],
263 bytes[1][1],
264 ]))
265 }
266 DataType::Float32 if registers.len() >= 2 => {
267 let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
268 Value::F32(f32::from_be_bytes([
269 bytes[0][0],
270 bytes[0][1],
271 bytes[1][0],
272 bytes[1][1],
273 ]))
274 }
275 DataType::Float64 if registers.len() >= 4 => {
276 let bytes: Vec<u8> = registers[..4]
277 .iter()
278 .flat_map(|r| r.to_be_bytes())
279 .collect();
280 Value::F64(f64::from_be_bytes([
281 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
282 ]))
283 }
284 _ => Value::U16(registers[0]),
285 };
286
287 Ok(value)
288 }
289
290 fn write_value(&self, point_id: &str, value: Value) -> Result<()> {
292 let def = self
293 .point_defs
294 .get(point_id)
295 .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
296
297 if !def.access.is_writable() {
298 return Err(Error::NotSupported(format!(
299 "Point {} is read-only",
300 point_id
301 )));
302 }
303
304 let addr = self
305 .point_addresses
306 .get(point_id)
307 .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
308
309 match addr.register_type {
310 ModbusRegisterType::Coil => {
311 let bool_value = value.as_bool().ok_or_else(|| Error::TypeMismatch {
312 expected: "bool".to_string(),
313 actual: value.type_name().to_string(),
314 })?;
315 self.registers.write_coil(addr.address, bool_value)?;
316 }
317 ModbusRegisterType::HoldingRegister => {
318 let registers = value.to_registers();
319 if registers.is_empty() {
320 return Err(Error::InvalidValue {
321 point_id: point_id.to_string(),
322 reason: "Cannot convert value to registers".to_string(),
323 });
324 }
325 self.registers
326 .write_holding_registers(addr.address, ®isters)?;
327 }
328 ModbusRegisterType::DiscreteInput | ModbusRegisterType::InputRegister => {
329 return Err(Error::NotSupported(format!(
330 "Cannot write to {} register type",
331 match addr.register_type {
332 ModbusRegisterType::DiscreteInput => "discrete input",
333 ModbusRegisterType::InputRegister => "input",
334 _ => unreachable!(),
335 }
336 )));
337 }
338 }
339
340 Ok(())
341 }
342}
343
344#[async_trait]
345impl Device for ModbusDevice {
346 fn info(&self) -> &DeviceInfo {
347 &self.info
348 }
349
350 async fn initialize(&mut self) -> Result<()> {
351 self.info.state = DeviceState::Initializing;
352 self.info.state = DeviceState::Online;
354 Ok(())
355 }
356
357 async fn start(&mut self) -> Result<()> {
358 *self.start_time.write() = Some(Instant::now());
359 self.info.state = DeviceState::Online;
360 Ok(())
361 }
362
363 async fn stop(&mut self) -> Result<()> {
364 self.info.state = DeviceState::Offline;
365 if let Some(start) = *self.start_time.read() {
366 self.stats.write().uptime_secs = start.elapsed().as_secs();
367 }
368 Ok(())
369 }
370
371 async fn tick(&mut self) -> Result<()> {
372 let start = Instant::now();
373 self.stats
375 .write()
376 .record_tick(start.elapsed().as_micros() as u64);
377 Ok(())
378 }
379
380 fn point_definitions(&self) -> Vec<&DataPointDef> {
381 self.point_defs.values().collect()
382 }
383
384 fn point_definition(&self, point_id: &str) -> Option<&DataPointDef> {
385 self.point_defs.get(point_id)
386 }
387
388 #[instrument(skip(self))]
389 async fn read(&self, point_id: &str) -> Result<DataPoint> {
390 if !self.response_delay.is_zero() {
392 tokio::time::sleep(self.response_delay).await;
393 }
394
395 let value = self.read_value(point_id)?;
396 self.stats.write().record_read();
397
398 let id = DataPointId::new(&self.info.id, point_id);
399 Ok(DataPoint::new(id, value))
400 }
401
402 #[instrument(skip(self, value))]
403 async fn write(&mut self, point_id: &str, value: Value) -> Result<()> {
404 if !self.response_delay.is_zero() {
406 tokio::time::sleep(self.response_delay).await;
407 }
408
409 self.write_value(point_id, value.clone())?;
410 self.stats.write().record_write();
411
412 let id = DataPointId::new(&self.info.id, point_id);
414 let _ = self.event_tx.send(DataPoint::new(id, value));
415
416 Ok(())
417 }
418
419 fn subscribe(&self) -> Option<broadcast::Receiver<DataPoint>> {
420 Some(self.event_tx.subscribe())
421 }
422
423 fn statistics(&self) -> DeviceStatistics {
424 self.stats.read().clone()
425 }
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh device with unit id 1 and default configuration.
    fn test_device() -> ModbusDevice {
        ModbusDevice::new(ModbusDeviceConfig::new(1, "Test Device"))
    }

    #[test]
    fn test_modbus_device_creation() {
        let device = test_device();

        assert_eq!(device.unit_id(), 1);
        assert_eq!(device.info().protocol, Protocol::ModbusTcp);
    }

    #[test]
    fn test_modbus_device_points() {
        let mut device = test_device();

        device.add_holding_register("temp", "Temperature", 0, DataType::Float32);
        device.add_coil("relay1", "Relay 1", 0);

        assert_eq!(device.point_definitions().len(), 2);
        assert!(device.point_definition("temp").is_some());
    }

    #[test]
    fn test_point_count_tracks_unique_points() {
        let mut device = test_device();

        device.add_coil("relay1", "Relay 1", 0);
        device.add_discrete_input("switch", "Switch", 0);
        assert_eq!(device.info().point_count, 2);

        // Re-adding an existing id replaces the point instead of duplicating it.
        device.add_coil("relay1", "Relay 1 (renamed)", 1);
        assert_eq!(device.info().point_count, 2);
    }

    #[tokio::test]
    async fn test_modbus_device_read_write() {
        let mut device = test_device();

        device.add_holding_register("value", "Value", 0, DataType::UInt16);
        device.initialize().await.unwrap();

        device.write("value", Value::U16(12345)).await.unwrap();

        let dp = device.read("value").await.unwrap();
        assert_eq!(dp.value.as_u16(), Some(12345));
    }

    #[tokio::test]
    async fn test_write_to_read_only_point_is_rejected() {
        let mut device = test_device();

        device.add_input_register("sensor", "Sensor", 0, DataType::UInt16);
        device.add_discrete_input("switch", "Switch", 0);
        device.initialize().await.unwrap();

        assert!(device.write("sensor", Value::U16(1)).await.is_err());
        assert!(device.write("switch", Value::Bool(true)).await.is_err());
    }

    #[tokio::test]
    async fn test_unknown_point_is_rejected() {
        let mut device = test_device();
        device.initialize().await.unwrap();

        assert!(device.read("missing").await.is_err());
        assert!(device.write("missing", Value::U16(1)).await.is_err());
    }
}