1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use parking_lot::RwLock;
9use tokio::sync::broadcast;
10use tracing::instrument;
11
12use mabi_core::{
13 device::{Device, DeviceInfo, DeviceState, DeviceStatistics},
14 error::{Error, Result},
15 protocol::Protocol,
16 types::{
17 AccessMode, DataPoint, DataPointDef, DataPointId, DataType, ModbusAddress,
18 ModbusRegisterType,
19 },
20 value::Value,
21};
22
23use crate::config::ModbusDeviceConfig;
24use crate::register::RegisterStore;
25
26pub struct ModbusDevice {
28 info: DeviceInfo,
30
31 unit_id: u8,
33
34 registers: Arc<RegisterStore>,
36
37 point_defs: HashMap<String, DataPointDef>,
39
40 point_addresses: HashMap<String, ModbusAddress>,
42
43 stats: RwLock<DeviceStatistics>,
45
46 event_tx: broadcast::Sender<DataPoint>,
48
49 response_delay: Duration,
51
52 start_time: RwLock<Option<Instant>>,
54}
55
56impl ModbusDevice {
57 pub fn new(config: ModbusDeviceConfig) -> Self {
59 let registers = Arc::new(RegisterStore::new(
60 config.coils,
61 config.discrete_inputs,
62 config.holding_registers,
63 config.input_registers,
64 ));
65
66 let info = DeviceInfo::new(
67 format!("modbus-{}", config.unit_id),
68 &config.name,
69 Protocol::ModbusTcp,
70 )
71 .with_metadata("unit_id", config.unit_id.to_string());
72
73 let (event_tx, _) = broadcast::channel(1000);
74
75 Self {
76 info,
77 unit_id: config.unit_id,
78 registers,
79 point_defs: HashMap::new(),
80 point_addresses: HashMap::new(),
81 stats: RwLock::new(DeviceStatistics::default()),
82 event_tx,
83 response_delay: Duration::from_millis(config.response_delay_ms),
84 start_time: RwLock::new(None),
85 }
86 }
87
88 pub fn unit_id(&self) -> u8 {
90 self.unit_id
91 }
92
93 pub fn registers(&self) -> &Arc<RegisterStore> {
95 &self.registers
96 }
97
98 pub fn add_point(&mut self, def: DataPointDef, address: ModbusAddress) {
100 self.point_addresses.insert(def.id.clone(), address);
101 self.point_defs.insert(def.id.clone(), def);
102 self.info.point_count = self.point_defs.len();
103 }
104
105 pub fn add_holding_register(
107 &mut self,
108 id: impl Into<String>,
109 name: impl Into<String>,
110 address: u16,
111 data_type: DataType,
112 ) {
113 let id = id.into();
114 let register_count = match data_type {
115 DataType::Bool
116 | DataType::Int8
117 | DataType::UInt8
118 | DataType::Int16
119 | DataType::UInt16 => 1,
120 DataType::Int32 | DataType::UInt32 | DataType::Float32 => 2,
121 DataType::Int64 | DataType::UInt64 | DataType::Float64 | DataType::DateTime => 4,
122 _ => 1,
123 };
124
125 let def = DataPointDef::new(&id, name, data_type).with_access(AccessMode::ReadWrite);
126
127 let addr = ModbusAddress {
128 register_type: ModbusRegisterType::HoldingRegister,
129 address,
130 count: register_count,
131 };
132
133 self.add_point(def, addr);
134 }
135
136 pub fn add_input_register(
138 &mut self,
139 id: impl Into<String>,
140 name: impl Into<String>,
141 address: u16,
142 data_type: DataType,
143 ) {
144 let id = id.into();
145 let register_count = match data_type {
146 DataType::Bool
147 | DataType::Int8
148 | DataType::UInt8
149 | DataType::Int16
150 | DataType::UInt16 => 1,
151 DataType::Int32 | DataType::UInt32 | DataType::Float32 => 2,
152 DataType::Int64 | DataType::UInt64 | DataType::Float64 | DataType::DateTime => 4,
153 _ => 1,
154 };
155
156 let def = DataPointDef::new(&id, name, data_type).with_access(AccessMode::ReadOnly);
157
158 let addr = ModbusAddress {
159 register_type: ModbusRegisterType::InputRegister,
160 address,
161 count: register_count,
162 };
163
164 self.add_point(def, addr);
165 }
166
167 pub fn add_coil(&mut self, id: impl Into<String>, name: impl Into<String>, address: u16) {
169 let id = id.into();
170 let def = DataPointDef::new(&id, name, DataType::Bool).with_access(AccessMode::ReadWrite);
171
172 let addr = ModbusAddress {
173 register_type: ModbusRegisterType::Coil,
174 address,
175 count: 1,
176 };
177
178 self.add_point(def, addr);
179 }
180
181 pub fn add_discrete_input(
183 &mut self,
184 id: impl Into<String>,
185 name: impl Into<String>,
186 address: u16,
187 ) {
188 let id = id.into();
189 let def = DataPointDef::new(&id, name, DataType::Bool).with_access(AccessMode::ReadOnly);
190
191 let addr = ModbusAddress {
192 register_type: ModbusRegisterType::DiscreteInput,
193 address,
194 count: 1,
195 };
196
197 self.add_point(def, addr);
198 }
199
200 fn read_value(&self, point_id: &str) -> Result<Value> {
202 let def = self
203 .point_defs
204 .get(point_id)
205 .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
206
207 let addr = self
208 .point_addresses
209 .get(point_id)
210 .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
211
212 match addr.register_type {
213 ModbusRegisterType::Coil => {
214 let values = self.registers.read_coils(addr.address, 1)?;
215 Ok(Value::Bool(values[0]))
216 }
217 ModbusRegisterType::DiscreteInput => {
218 let values = self.registers.read_discrete_inputs(addr.address, 1)?;
219 Ok(Value::Bool(values[0]))
220 }
221 ModbusRegisterType::HoldingRegister => {
222 self.read_register_value(&def.data_type, addr.address, addr.count, true)
223 }
224 ModbusRegisterType::InputRegister => {
225 self.read_register_value(&def.data_type, addr.address, addr.count, false)
226 }
227 }
228 }
229
230 fn read_register_value(
231 &self,
232 data_type: &DataType,
233 address: u16,
234 count: u16,
235 is_holding: bool,
236 ) -> Result<Value> {
237 let registers = if is_holding {
238 self.registers.read_holding_registers(address, count)?
239 } else {
240 self.registers.read_input_registers(address, count)?
241 };
242
243 let value = match data_type {
244 DataType::Bool => Value::Bool(registers[0] != 0),
245 DataType::Int16 => Value::I16(registers[0] as i16),
246 DataType::UInt16 => Value::U16(registers[0]),
247 DataType::Int32 if registers.len() >= 2 => {
248 let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
249 Value::I32(i32::from_be_bytes([
250 bytes[0][0],
251 bytes[0][1],
252 bytes[1][0],
253 bytes[1][1],
254 ]))
255 }
256 DataType::UInt32 if registers.len() >= 2 => {
257 let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
258 Value::U32(u32::from_be_bytes([
259 bytes[0][0],
260 bytes[0][1],
261 bytes[1][0],
262 bytes[1][1],
263 ]))
264 }
265 DataType::Float32 if registers.len() >= 2 => {
266 let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
267 Value::F32(f32::from_be_bytes([
268 bytes[0][0],
269 bytes[0][1],
270 bytes[1][0],
271 bytes[1][1],
272 ]))
273 }
274 DataType::Float64 if registers.len() >= 4 => {
275 let bytes: Vec<u8> = registers[..4]
276 .iter()
277 .flat_map(|r| r.to_be_bytes())
278 .collect();
279 Value::F64(f64::from_be_bytes([
280 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
281 ]))
282 }
283 _ => Value::U16(registers[0]),
284 };
285
286 Ok(value)
287 }
288
289 fn write_value(&self, point_id: &str, value: Value) -> Result<()> {
291 let def = self
292 .point_defs
293 .get(point_id)
294 .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
295
296 if !def.access.is_writable() {
297 return Err(Error::NotSupported(format!(
298 "Point {} is read-only",
299 point_id
300 )));
301 }
302
303 let addr = self
304 .point_addresses
305 .get(point_id)
306 .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
307
308 match addr.register_type {
309 ModbusRegisterType::Coil => {
310 let bool_value = value.as_bool().ok_or_else(|| Error::TypeMismatch {
311 expected: "bool".to_string(),
312 actual: value.type_name().to_string(),
313 })?;
314 self.registers.write_coil(addr.address, bool_value)?;
315 }
316 ModbusRegisterType::HoldingRegister => {
317 let registers = value.to_registers();
318 if registers.is_empty() {
319 return Err(Error::InvalidValue {
320 point_id: point_id.to_string(),
321 reason: "Cannot convert value to registers".to_string(),
322 });
323 }
324 self.registers
325 .write_holding_registers(addr.address, ®isters)?;
326 }
327 ModbusRegisterType::DiscreteInput | ModbusRegisterType::InputRegister => {
328 return Err(Error::NotSupported(format!(
329 "Cannot write to {} register type",
330 match addr.register_type {
331 ModbusRegisterType::DiscreteInput => "discrete input",
332 ModbusRegisterType::InputRegister => "input",
333 _ => unreachable!(),
334 }
335 )));
336 }
337 }
338
339 Ok(())
340 }
341}
342
343#[async_trait]
344impl Device for ModbusDevice {
345 fn info(&self) -> &DeviceInfo {
346 &self.info
347 }
348
349 async fn initialize(&mut self) -> Result<()> {
350 self.info.state = DeviceState::Initializing;
351 self.info.state = DeviceState::Online;
353 Ok(())
354 }
355
356 async fn start(&mut self) -> Result<()> {
357 *self.start_time.write() = Some(Instant::now());
358 self.info.state = DeviceState::Online;
359 Ok(())
360 }
361
362 async fn stop(&mut self) -> Result<()> {
363 self.info.state = DeviceState::Offline;
364 if let Some(start) = *self.start_time.read() {
365 self.stats.write().uptime_secs = start.elapsed().as_secs();
366 }
367 Ok(())
368 }
369
370 async fn tick(&mut self) -> Result<()> {
371 let start = Instant::now();
372 self.stats
374 .write()
375 .record_tick(start.elapsed().as_micros() as u64);
376 Ok(())
377 }
378
379 fn point_definitions(&self) -> Vec<&DataPointDef> {
380 self.point_defs.values().collect()
381 }
382
383 fn point_definition(&self, point_id: &str) -> Option<&DataPointDef> {
384 self.point_defs.get(point_id)
385 }
386
387 #[instrument(skip(self))]
388 async fn read(&self, point_id: &str) -> Result<DataPoint> {
389 if !self.response_delay.is_zero() {
391 tokio::time::sleep(self.response_delay).await;
392 }
393
394 let value = self.read_value(point_id)?;
395 self.stats.write().record_read();
396
397 let id = DataPointId::new(&self.info.id, point_id);
398 Ok(DataPoint::new(id, value))
399 }
400
401 #[instrument(skip(self, value))]
402 async fn write(&mut self, point_id: &str, value: Value) -> Result<()> {
403 if !self.response_delay.is_zero() {
405 tokio::time::sleep(self.response_delay).await;
406 }
407
408 self.write_value(point_id, value.clone())?;
409 self.stats.write().record_write();
410
411 let id = DataPointId::new(&self.info.id, point_id);
413 let _ = self.event_tx.send(DataPoint::new(id, value));
414
415 Ok(())
416 }
417
418 fn subscribe(&self) -> Option<broadcast::Receiver<DataPoint>> {
419 Some(self.event_tx.subscribe())
420 }
421
422 fn statistics(&self) -> DeviceStatistics {
423 self.stats.read().clone()
424 }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh device with unit id 1 and no points.
    fn test_device() -> ModbusDevice {
        ModbusDevice::new(ModbusDeviceConfig::new(1, "Test Device"))
    }

    #[tokio::test]
    async fn test_modbus_device_creation() {
        let device = test_device();

        assert_eq!(device.unit_id(), 1);
        assert_eq!(device.info().protocol, Protocol::ModbusTcp);
    }

    #[tokio::test]
    async fn test_modbus_device_points() {
        let mut device = test_device();

        device.add_holding_register("temp", "Temperature", 0, DataType::Float32);
        device.add_coil("relay1", "Relay 1", 0);

        assert_eq!(device.point_definitions().len(), 2);
        assert_eq!(device.info().point_count, 2);
        assert!(device.point_definition("temp").is_some());
    }

    #[tokio::test]
    async fn test_modbus_device_read_write() {
        let mut device = test_device();

        device.add_holding_register("value", "Value", 0, DataType::UInt16);
        device.initialize().await.unwrap();

        device.write("value", Value::U16(12345)).await.unwrap();

        let dp = device.read("value").await.unwrap();
        assert_eq!(dp.value.as_u16(), Some(12345));
    }

    #[tokio::test]
    async fn test_write_is_broadcast_to_subscribers() {
        let mut device = test_device();
        device.add_holding_register("value", "Value", 0, DataType::UInt16);
        let mut rx = device.subscribe().expect("device always supports subscriptions");

        device.write("value", Value::U16(7)).await.unwrap();

        let event = rx.try_recv().expect("write should publish an event");
        assert_eq!(event.value.as_u16(), Some(7));
    }

    #[tokio::test]
    async fn test_unknown_point_is_rejected() {
        let mut device = test_device();

        assert!(device.read("missing").await.is_err());
        assert!(device.write("missing", Value::U16(1)).await.is_err());
    }

    #[tokio::test]
    async fn test_read_only_point_rejects_writes() {
        let mut device = test_device();
        device.add_input_register("sensor", "Sensor", 0, DataType::UInt16);

        assert!(device.write("sensor", Value::U16(1)).await.is_err());
    }
}