1use crate::error::{ModbusError, ModbusResult};
7use crate::protocol::crc::{append_crc, verify_crc};
8use crate::protocol::frame::FunctionCode;
9use bytes::{BufMut, BytesMut};
10use std::time::Duration;
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio_serial::{DataBits, FlowControl, Parity, SerialPortBuilderExt, SerialStream, StopBits};
13
/// Baud rate suggested for callers that have no device-specific setting.
///
/// The constructors of [`ModbusRtuClient`] take the baud rate explicitly and do
/// not fall back to this value on their own.
pub const DEFAULT_BAUD_RATE: u32 = 9_600;

/// Response timeout installed by the [`ModbusRtuClient`] constructors (one second).
pub const DEFAULT_RTU_TIMEOUT: Duration = Duration::from_secs(1);

/// Minimum bus silence, in milliseconds, kept before each request is written.
///
/// Roughly 3.5 character times at 9600 baud, assuming 11-bit characters.
const INTER_FRAME_DELAY_MS: u64 = 4;
22
23pub struct ModbusRtuClient {
41 stream: SerialStream,
43
44 unit_id: u8,
46
47 timeout: Duration,
49}
50
51impl ModbusRtuClient {
52 pub fn open(port: &str, baud_rate: u32, unit_id: u8) -> ModbusResult<Self> {
64 let stream = tokio_serial::new(port, baud_rate)
65 .data_bits(DataBits::Eight)
66 .parity(Parity::None)
67 .stop_bits(StopBits::One)
68 .flow_control(FlowControl::None)
69 .open_native_async()
70 .map_err(|e| ModbusError::Io(std::io::Error::new(std::io::ErrorKind::Other, e)))?;
71
72 Ok(Self {
73 stream,
74 unit_id,
75 timeout: DEFAULT_RTU_TIMEOUT,
76 })
77 }
78
79 pub fn open_with_settings(
90 port: &str,
91 baud_rate: u32,
92 unit_id: u8,
93 data_bits: DataBits,
94 parity: Parity,
95 stop_bits: StopBits,
96 ) -> ModbusResult<Self> {
97 let stream = tokio_serial::new(port, baud_rate)
98 .data_bits(data_bits)
99 .parity(parity)
100 .stop_bits(stop_bits)
101 .flow_control(FlowControl::None)
102 .open_native_async()
103 .map_err(|e| ModbusError::Io(std::io::Error::new(std::io::ErrorKind::Other, e)))?;
104
105 Ok(Self {
106 stream,
107 unit_id,
108 timeout: DEFAULT_RTU_TIMEOUT,
109 })
110 }
111
112 pub fn set_timeout(&mut self, timeout: Duration) {
114 self.timeout = timeout;
115 }
116
117 pub async fn read_holding_registers(
128 &mut self,
129 start_address: u16,
130 count: u16,
131 ) -> ModbusResult<Vec<u16>> {
132 if count > 125 {
133 return Err(ModbusError::Io(std::io::Error::new(
134 std::io::ErrorKind::InvalidInput,
135 "Cannot read more than 125 registers",
136 )));
137 }
138
139 let request = self.build_request(FunctionCode::ReadHoldingRegisters, start_address, count);
140
141 let response = self.send_request(&request).await?;
142 self.extract_registers(&response)
143 }
144
145 pub async fn read_input_registers(
156 &mut self,
157 start_address: u16,
158 count: u16,
159 ) -> ModbusResult<Vec<u16>> {
160 if count > 125 {
161 return Err(ModbusError::Io(std::io::Error::new(
162 std::io::ErrorKind::InvalidInput,
163 "Cannot read more than 125 registers",
164 )));
165 }
166
167 let request = self.build_request(FunctionCode::ReadInputRegisters, start_address, count);
168
169 let response = self.send_request(&request).await?;
170 self.extract_registers(&response)
171 }
172
173 pub async fn write_single_register(&mut self, address: u16, value: u16) -> ModbusResult<()> {
180 let request = self.build_request(FunctionCode::WriteSingleRegister, address, value);
181
182 let _response = self.send_request(&request).await?;
183 Ok(())
184 }
185
186 pub async fn read_coils(&mut self, start_address: u16, count: u16) -> ModbusResult<Vec<bool>> {
197 if count > 2000 {
198 return Err(ModbusError::Io(std::io::Error::new(
199 std::io::ErrorKind::InvalidInput,
200 "Cannot read more than 2000 coils",
201 )));
202 }
203
204 let request = self.build_request(FunctionCode::ReadCoils, start_address, count);
205 let response = self.send_request(&request).await?;
206 self.extract_coils(&response, count as usize)
207 }
208
209 pub async fn read_discrete_inputs(
220 &mut self,
221 start_address: u16,
222 count: u16,
223 ) -> ModbusResult<Vec<bool>> {
224 if count > 2000 {
225 return Err(ModbusError::Io(std::io::Error::new(
226 std::io::ErrorKind::InvalidInput,
227 "Cannot read more than 2000 inputs",
228 )));
229 }
230
231 let request = self.build_request(FunctionCode::ReadDiscreteInputs, start_address, count);
232 let response = self.send_request(&request).await?;
233 self.extract_coils(&response, count as usize)
234 }
235
236 fn build_request(&self, function_code: FunctionCode, param1: u16, param2: u16) -> Vec<u8> {
238 let mut bytes = BytesMut::with_capacity(8);
239
240 bytes.put_u8(self.unit_id);
242
243 bytes.put_u8(function_code.as_u8());
245
246 bytes.put_u16(param1);
248 bytes.put_u16(param2);
249
250 let mut request = bytes.to_vec();
252 append_crc(&mut request);
253
254 request
255 }
256
257 async fn send_request(&mut self, request: &[u8]) -> ModbusResult<Vec<u8>> {
259 tokio::time::sleep(Duration::from_millis(INTER_FRAME_DELAY_MS)).await;
261
262 self.stream.write_all(request).await?;
264 self.stream.flush().await?;
265
266 let response = tokio::time::timeout(self.timeout, self.read_response())
268 .await
269 .map_err(|_| ModbusError::Timeout(self.timeout))??;
270
271 if !verify_crc(&response) {
273 return Err(ModbusError::CrcError {
274 expected: 0, actual: 0,
276 });
277 }
278
279 if response.len() >= 3 && (response[1] & 0x80) != 0 {
281 let function = response[1] & 0x7F;
282 let exception_code = response[2];
283 return Err(ModbusError::ModbusException {
284 code: exception_code,
285 function,
286 });
287 }
288
289 Ok(response)
290 }
291
292 async fn read_response(&mut self) -> ModbusResult<Vec<u8>> {
294 let mut buffer = Vec::with_capacity(256);
295 let mut temp = [0u8; 256];
296
297 let n = self.stream.read(&mut temp).await?;
299 if n < 2 {
300 return Err(ModbusError::Io(std::io::Error::new(
301 std::io::ErrorKind::UnexpectedEof,
302 "Response too short",
303 )));
304 }
305 buffer.extend_from_slice(&temp[..n]);
306
307 let function_code = buffer[1] & 0x7F;
309 let expected_len = if (buffer[1] & 0x80) != 0 {
310 5
312 } else {
313 match function_code {
314 0x01..=0x04 => {
315 if buffer.len() < 3 {
317 let n = self.stream.read(&mut temp).await?;
319 buffer.extend_from_slice(&temp[..n]);
320 }
321 if buffer.len() < 3 {
322 return Err(ModbusError::Io(std::io::Error::new(
323 std::io::ErrorKind::UnexpectedEof,
324 "Missing byte count",
325 )));
326 }
327 3 + buffer[2] as usize + 2
329 }
330 0x05 | 0x06 => {
331 8
333 }
334 0x0F | 0x10 => {
335 8
337 }
338 _ => {
339 buffer.len() + 2
341 }
342 }
343 };
344
345 while buffer.len() < expected_len {
347 let n = self.stream.read(&mut temp).await?;
348 if n == 0 {
349 break;
350 }
351 buffer.extend_from_slice(&temp[..n]);
352 }
353
354 Ok(buffer)
355 }
356
357 fn extract_registers(&self, response: &[u8]) -> ModbusResult<Vec<u16>> {
359 if response.len() < 5 {
361 return Err(ModbusError::Io(std::io::Error::new(
362 std::io::ErrorKind::UnexpectedEof,
363 "Response too short",
364 )));
365 }
366
367 let byte_count = response[2] as usize;
368 if response.len() < 3 + byte_count + 2 {
369 return Err(ModbusError::Io(std::io::Error::new(
370 std::io::ErrorKind::UnexpectedEof,
371 "Incomplete register data",
372 )));
373 }
374
375 let mut registers = Vec::with_capacity(byte_count / 2);
376 let data = &response[3..3 + byte_count];
377
378 for chunk in data.chunks_exact(2) {
379 let value = u16::from_be_bytes([chunk[0], chunk[1]]);
380 registers.push(value);
381 }
382
383 Ok(registers)
384 }
385
386 fn extract_coils(&self, response: &[u8], count: usize) -> ModbusResult<Vec<bool>> {
388 if response.len() < 5 {
390 return Err(ModbusError::Io(std::io::Error::new(
391 std::io::ErrorKind::UnexpectedEof,
392 "Response too short",
393 )));
394 }
395
396 let byte_count = response[2] as usize;
397 if response.len() < 3 + byte_count + 2 {
398 return Err(ModbusError::Io(std::io::Error::new(
399 std::io::ErrorKind::UnexpectedEof,
400 "Incomplete coil data",
401 )));
402 }
403
404 let mut coils = Vec::with_capacity(count);
405 let data = &response[3..3 + byte_count];
406
407 for (byte_idx, &byte) in data.iter().enumerate() {
408 for bit_idx in 0..8 {
409 let coil_idx = byte_idx * 8 + bit_idx;
410 if coil_idx >= count {
411 break;
412 }
413 coils.push((byte >> bit_idx) & 1 == 1);
414 }
415 }
416
417 Ok(coils)
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// CRC for a read-holding-registers frame (unit 1, start 0, quantity 10).
    ///
    /// Checked against the CRC-16/MODBUS value `C5 CD`, transmitted low byte first.
    #[test]
    fn test_build_request_read_registers() {
        let mut request = vec![0x01, 0x03, 0x00, 0x00, 0x00, 0x0A];
        append_crc(&mut request);

        assert_eq!(request.len(), 8);
        assert!(verify_crc(&request));
        assert_eq!(&request[6..], &[0xC5, 0xCD]);
    }

    /// CRC for a write-single-register frame (unit 1, address 1, value 100).
    ///
    /// Checked against the CRC-16/MODBUS value `D9 E1`, transmitted low byte first.
    #[test]
    fn test_build_request_write_register() {
        let mut request = vec![0x01, 0x06, 0x00, 0x01, 0x00, 0x64];
        append_crc(&mut request);

        assert_eq!(request.len(), 8);
        assert!(verify_crc(&request));
        assert_eq!(&request[6..], &[0xD9, 0xE1]);
    }

    /// A single flipped bit anywhere in the frame must fail CRC verification.
    #[test]
    fn test_verify_crc_rejects_corrupted_frame() {
        let mut frame = vec![0x01, 0x03, 0x00, 0x00, 0x00, 0x0A];
        append_crc(&mut frame);
        frame[3] ^= 0x01;

        assert!(!verify_crc(&frame));
    }
}