1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use parking_lot::RwLock;
9use tokio::sync::broadcast;
10use tracing::instrument;
11
12use mabi_core::{
13 device::{Device, DeviceInfo, DeviceState, DeviceStatistics},
14 error::{Error, Result},
15 protocol::Protocol,
16 types::{
17 AccessMode, DataPoint, DataPointDef, DataPointId, DataType, ModbusAddress,
18 ModbusRegisterType,
19 },
20 value::Value,
21};
22
23use crate::config::ModbusDeviceConfig;
24use crate::context::{DeviceContext, SharedAddressSpace};
25use crate::error::ModbusResult;
26use crate::profile::{DatastoreKind, PointProfile, UnitProfile};
27
28pub struct ModbusDevice {
30 info: DeviceInfo,
32
33 context: Arc<DeviceContext>,
35
36 point_defs: HashMap<String, DataPointDef>,
38
39 point_addresses: HashMap<String, ModbusAddress>,
41
42 stats: RwLock<DeviceStatistics>,
44
45 event_tx: broadcast::Sender<DataPoint>,
47
48 response_delay: Duration,
50
51 start_time: RwLock<Option<Instant>>,
53}
54
55impl ModbusDevice {
56 pub fn new(config: ModbusDeviceConfig) -> Self {
58 let context = Arc::new(
59 DeviceContext::new(
60 config.unit_id,
61 config.name.clone(),
62 DatastoreKind::dense_from_counts(
63 config.coils,
64 config.discrete_inputs,
65 config.holding_registers,
66 config.input_registers,
67 )
68 .build_address_space(),
69 crate::types::WordOrder::default(),
70 )
71 .with_response_delay(Duration::from_millis(config.response_delay_ms)),
72 );
73
74 let info = DeviceInfo::new(
75 format!("modbus-{}", config.unit_id),
76 &config.name,
77 Protocol::ModbusTcp,
78 )
79 .with_metadata("unit_id", config.unit_id.to_string())
80 .with_tags(config.tags);
81
82 let (event_tx, _) = broadcast::channel(1000);
83
84 Self {
85 info,
86 context,
87 point_defs: HashMap::new(),
88 point_addresses: HashMap::new(),
89 stats: RwLock::new(DeviceStatistics::default()),
90 event_tx,
91 response_delay: Duration::from_millis(config.response_delay_ms),
92 start_time: RwLock::new(None),
93 }
94 }
95
96 pub fn from_profile(profile: &UnitProfile) -> ModbusResult<Self> {
98 let context = Arc::new(
99 DeviceContext::new(
100 profile.unit_id,
101 profile.name.clone(),
102 profile.datastore.build_address_space(),
103 profile.word_order,
104 )
105 .with_broadcast(profile.broadcast_enabled)
106 .with_response_delay(Duration::from_millis(profile.response_delay_ms)),
107 );
108
109 let info = DeviceInfo::new(
110 format!("modbus-{}", profile.unit_id),
111 &profile.name,
112 Protocol::ModbusTcp,
113 )
114 .with_metadata("unit_id", profile.unit_id.to_string())
115 .with_tags(profile.tags.clone());
116
117 let (event_tx, _) = broadcast::channel(1000);
118 let mut device = Self {
119 info,
120 context,
121 point_defs: HashMap::new(),
122 point_addresses: HashMap::new(),
123 stats: RwLock::new(DeviceStatistics::default()),
124 event_tx,
125 response_delay: Duration::from_millis(profile.response_delay_ms),
126 start_time: RwLock::new(None),
127 };
128
129 for point in &profile.points {
130 device.apply_point_profile(point);
131 }
132
133 Ok(device)
134 }
135
136 pub fn unit_id(&self) -> u8 {
138 self.context.unit_id()
139 }
140
141 pub fn context(&self) -> &Arc<DeviceContext> {
143 &self.context
144 }
145
146 pub fn address_space(&self) -> SharedAddressSpace {
148 self.context.address_space()
149 }
150
151 pub fn registers(&self) -> SharedAddressSpace {
153 self.address_space()
154 }
155
156 pub fn add_point(&mut self, def: DataPointDef, address: ModbusAddress) {
158 self.point_addresses.insert(def.id.clone(), address);
159 self.point_defs.insert(def.id.clone(), def);
160 self.info.point_count = self.point_defs.len();
161 }
162
163 fn apply_point_profile(&mut self, profile: &PointProfile) {
164 match profile.register_type {
165 ModbusRegisterType::HoldingRegister => self.add_holding_register(
166 profile.id.clone(),
167 profile.name.clone(),
168 profile.address,
169 profile.data_type,
170 ),
171 ModbusRegisterType::InputRegister => self.add_input_register(
172 profile.id.clone(),
173 profile.name.clone(),
174 profile.address,
175 profile.data_type,
176 ),
177 ModbusRegisterType::Coil => {
178 self.add_coil(profile.id.clone(), profile.name.clone(), profile.address)
179 }
180 ModbusRegisterType::DiscreteInput => {
181 self.add_discrete_input(profile.id.clone(), profile.name.clone(), profile.address)
182 }
183 }
184 }
185
186 pub fn add_holding_register(
188 &mut self,
189 id: impl Into<String>,
190 name: impl Into<String>,
191 address: u16,
192 data_type: DataType,
193 ) {
194 let id = id.into();
195 let register_count = match data_type {
196 DataType::Bool
197 | DataType::Int8
198 | DataType::UInt8
199 | DataType::Int16
200 | DataType::UInt16 => 1,
201 DataType::Int32 | DataType::UInt32 | DataType::Float32 => 2,
202 DataType::Int64 | DataType::UInt64 | DataType::Float64 | DataType::DateTime => 4,
203 _ => 1,
204 };
205
206 let def = DataPointDef::new(&id, name, data_type).with_access(AccessMode::ReadWrite);
207
208 let addr = ModbusAddress {
209 register_type: ModbusRegisterType::HoldingRegister,
210 address,
211 count: register_count,
212 };
213
214 self.add_point(def, addr);
215 }
216
217 pub fn add_input_register(
219 &mut self,
220 id: impl Into<String>,
221 name: impl Into<String>,
222 address: u16,
223 data_type: DataType,
224 ) {
225 let id = id.into();
226 let register_count = match data_type {
227 DataType::Bool
228 | DataType::Int8
229 | DataType::UInt8
230 | DataType::Int16
231 | DataType::UInt16 => 1,
232 DataType::Int32 | DataType::UInt32 | DataType::Float32 => 2,
233 DataType::Int64 | DataType::UInt64 | DataType::Float64 | DataType::DateTime => 4,
234 _ => 1,
235 };
236
237 let def = DataPointDef::new(&id, name, data_type).with_access(AccessMode::ReadOnly);
238
239 let addr = ModbusAddress {
240 register_type: ModbusRegisterType::InputRegister,
241 address,
242 count: register_count,
243 };
244
245 self.add_point(def, addr);
246 }
247
248 pub fn add_coil(&mut self, id: impl Into<String>, name: impl Into<String>, address: u16) {
250 let id = id.into();
251 let def = DataPointDef::new(&id, name, DataType::Bool).with_access(AccessMode::ReadWrite);
252
253 let addr = ModbusAddress {
254 register_type: ModbusRegisterType::Coil,
255 address,
256 count: 1,
257 };
258
259 self.add_point(def, addr);
260 }
261
262 pub fn add_discrete_input(
264 &mut self,
265 id: impl Into<String>,
266 name: impl Into<String>,
267 address: u16,
268 ) {
269 let id = id.into();
270 let def = DataPointDef::new(&id, name, DataType::Bool).with_access(AccessMode::ReadOnly);
271
272 let addr = ModbusAddress {
273 register_type: ModbusRegisterType::DiscreteInput,
274 address,
275 count: 1,
276 };
277
278 self.add_point(def, addr);
279 }
280
281 fn read_value(&self, point_id: &str) -> Result<Value> {
283 let def = self
284 .point_defs
285 .get(point_id)
286 .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
287
288 let addr = self
289 .point_addresses
290 .get(point_id)
291 .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
292
293 match addr.register_type {
294 ModbusRegisterType::Coil => {
295 let values = self.address_space().read_coils(addr.address, 1)?;
296 Ok(Value::Bool(values[0]))
297 }
298 ModbusRegisterType::DiscreteInput => {
299 let values = self.address_space().read_discrete_inputs(addr.address, 1)?;
300 Ok(Value::Bool(values[0]))
301 }
302 ModbusRegisterType::HoldingRegister => {
303 self.read_register_value(&def.data_type, addr.address, addr.count, true)
304 }
305 ModbusRegisterType::InputRegister => {
306 self.read_register_value(&def.data_type, addr.address, addr.count, false)
307 }
308 }
309 }
310
311 fn read_register_value(
312 &self,
313 data_type: &DataType,
314 address: u16,
315 count: u16,
316 is_holding: bool,
317 ) -> Result<Value> {
318 let registers = if is_holding {
319 self.address_space()
320 .read_holding_registers(address, count)?
321 } else {
322 self.address_space().read_input_registers(address, count)?
323 };
324
325 let value = match data_type {
326 DataType::Bool => Value::Bool(registers[0] != 0),
327 DataType::Int16 => Value::I16(registers[0] as i16),
328 DataType::UInt16 => Value::U16(registers[0]),
329 DataType::Int32 if registers.len() >= 2 => {
330 let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
331 Value::I32(i32::from_be_bytes([
332 bytes[0][0],
333 bytes[0][1],
334 bytes[1][0],
335 bytes[1][1],
336 ]))
337 }
338 DataType::UInt32 if registers.len() >= 2 => {
339 let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
340 Value::U32(u32::from_be_bytes([
341 bytes[0][0],
342 bytes[0][1],
343 bytes[1][0],
344 bytes[1][1],
345 ]))
346 }
347 DataType::Float32 if registers.len() >= 2 => {
348 let bytes = [registers[0].to_be_bytes(), registers[1].to_be_bytes()];
349 Value::F32(f32::from_be_bytes([
350 bytes[0][0],
351 bytes[0][1],
352 bytes[1][0],
353 bytes[1][1],
354 ]))
355 }
356 DataType::Float64 if registers.len() >= 4 => {
357 let bytes: Vec<u8> = registers[..4]
358 .iter()
359 .flat_map(|r| r.to_be_bytes())
360 .collect();
361 Value::F64(f64::from_be_bytes([
362 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
363 ]))
364 }
365 _ => Value::U16(registers[0]),
366 };
367
368 Ok(value)
369 }
370
371 fn write_value(&self, point_id: &str, value: Value) -> Result<()> {
373 let def = self
374 .point_defs
375 .get(point_id)
376 .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
377
378 if !def.access.is_writable() {
379 return Err(Error::NotSupported(format!(
380 "Point {} is read-only",
381 point_id
382 )));
383 }
384
385 let addr = self
386 .point_addresses
387 .get(point_id)
388 .ok_or_else(|| Error::point_not_found(&self.info.id, point_id))?;
389
390 match addr.register_type {
391 ModbusRegisterType::Coil => {
392 let bool_value = value.as_bool().ok_or_else(|| Error::TypeMismatch {
393 expected: "bool".to_string(),
394 actual: value.type_name().to_string(),
395 })?;
396 self.address_space().write_coil(addr.address, bool_value)?;
397 }
398 ModbusRegisterType::HoldingRegister => {
399 let registers = value.to_registers();
400 if registers.is_empty() {
401 return Err(Error::InvalidValue {
402 point_id: point_id.to_string(),
403 reason: "Cannot convert value to registers".to_string(),
404 });
405 }
406 self.address_space()
407 .write_holding_registers(addr.address, ®isters)?;
408 }
409 ModbusRegisterType::DiscreteInput | ModbusRegisterType::InputRegister => {
410 return Err(Error::NotSupported(format!(
411 "Cannot write to {} register type",
412 match addr.register_type {
413 ModbusRegisterType::DiscreteInput => "discrete input",
414 ModbusRegisterType::InputRegister => "input",
415 _ => unreachable!(),
416 }
417 )));
418 }
419 }
420
421 Ok(())
422 }
423
424 pub async fn read_point(&self, point_id: &str) -> Result<DataPoint> {
426 if !self.response_delay.is_zero() {
427 tokio::time::sleep(self.response_delay).await;
428 }
429
430 let value = self.read_value(point_id)?;
431 self.stats.write().record_read();
432
433 let id = DataPointId::new(&self.info.id, point_id);
434 Ok(DataPoint::new(id, value))
435 }
436
437 pub async fn write_point(&self, point_id: &str, value: Value) -> Result<()> {
439 if !self.response_delay.is_zero() {
440 tokio::time::sleep(self.response_delay).await;
441 }
442
443 self.write_value(point_id, value.clone())?;
444 self.stats.write().record_write();
445
446 let id = DataPointId::new(&self.info.id, point_id);
447 let _ = self.event_tx.send(DataPoint::new(id, value));
448
449 Ok(())
450 }
451
452 pub fn point_definitions_owned(&self) -> Vec<DataPointDef> {
454 self.point_defs.values().cloned().collect()
455 }
456}
457
458#[async_trait]
459impl Device for ModbusDevice {
460 fn info(&self) -> &DeviceInfo {
461 &self.info
462 }
463
464 async fn initialize(&mut self) -> Result<()> {
465 self.info.state = DeviceState::Initializing;
466 self.info.state = DeviceState::Online;
468 Ok(())
469 }
470
471 async fn start(&mut self) -> Result<()> {
472 *self.start_time.write() = Some(Instant::now());
473 self.info.state = DeviceState::Online;
474 Ok(())
475 }
476
477 async fn stop(&mut self) -> Result<()> {
478 self.info.state = DeviceState::Offline;
479 if let Some(start) = *self.start_time.read() {
480 self.stats.write().uptime_secs = start.elapsed().as_secs();
481 }
482 Ok(())
483 }
484
485 async fn tick(&mut self) -> Result<()> {
486 let start = Instant::now();
487 self.stats
489 .write()
490 .record_tick(start.elapsed().as_micros() as u64);
491 Ok(())
492 }
493
494 fn point_definitions(&self) -> Vec<&DataPointDef> {
495 self.point_defs.values().collect()
496 }
497
498 fn point_definition(&self, point_id: &str) -> Option<&DataPointDef> {
499 self.point_defs.get(point_id)
500 }
501
502 #[instrument(skip(self))]
503 async fn read(&self, point_id: &str) -> Result<DataPoint> {
504 self.read_point(point_id).await
505 }
506
507 #[instrument(skip(self, value))]
508 async fn write(&mut self, point_id: &str, value: Value) -> Result<()> {
509 self.write_point(point_id, value).await
510 }
511
512 fn subscribe(&self) -> Option<broadcast::Receiver<DataPoint>> {
513 Some(self.event_tx.subscribe())
514 }
515
516 fn statistics(&self) -> DeviceStatistics {
517 self.stats.read().clone()
518 }
519}
520
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the device used by every test: unit 1, default config.
    fn test_device() -> ModbusDevice {
        ModbusDevice::new(ModbusDeviceConfig::new(1, "Test Device"))
    }

    #[tokio::test]
    async fn test_modbus_device_creation() {
        let device = test_device();

        assert_eq!(device.unit_id(), 1);
        assert_eq!(device.info().protocol, Protocol::ModbusTcp);
    }

    #[tokio::test]
    async fn test_modbus_device_points() {
        let mut device = test_device();

        device.add_holding_register("temp", "Temperature", 0, DataType::Float32);
        device.add_coil("relay1", "Relay 1", 0);

        assert_eq!(device.point_definitions().len(), 2);
        assert!(device.point_definition("temp").is_some());
    }

    #[tokio::test]
    async fn test_modbus_device_read_write() {
        let mut device = test_device();

        device.add_holding_register("value", "Value", 0, DataType::UInt16);
        device.initialize().await.unwrap();

        device.write("value", Value::U16(12345)).await.unwrap();

        let dp = device.read("value").await.unwrap();
        assert_eq!(dp.value.as_u16(), Some(12345));
    }

    /// Input registers are read-only; writes must be rejected before
    /// touching the address space.
    #[tokio::test]
    async fn test_modbus_device_rejects_write_to_read_only_point() {
        let mut device = test_device();

        device.add_input_register("sensor", "Sensor", 0, DataType::UInt16);
        device.initialize().await.unwrap();

        assert!(device.write("sensor", Value::U16(1)).await.is_err());
    }

    /// Unknown point ids surface as errors, not panics.
    #[tokio::test]
    async fn test_modbus_device_unknown_point() {
        let mut device = test_device();
        device.initialize().await.unwrap();

        assert!(device.point_definition("missing").is_none());
        assert!(device.read("missing").await.is_err());
        assert!(device.write("missing", Value::U16(1)).await.is_err());
    }
}