// modbus_mqtt/modbus/connection.rs

1use super::Word;
2use crate::modbus::{self, register};
3use crate::mqtt::Scopable;
4use crate::Error;
5use rust_decimal::prelude::Zero;
6use serde::Deserialize;
7use tokio::select;
8use tokio::sync::{mpsc, oneshot, watch};
9use tokio_modbus::client::{rtu, tcp, Context as ModbusClient};
10use tracing::{debug, error, warn};
11
12use crate::{mqtt, shutdown::Shutdown};
13
14use super::register::RegisterType;
15
16pub(crate) async fn run(
17    config: Config,
18    mqtt: mqtt::Handle,
19    shutdown: Shutdown,
20) -> crate::Result<Handle> {
21    let (connection_is_ready, mut is_connection_ready) = watch::channel(());
22    let (mut tx, mut rx) = mpsc::channel(32);
23    let handle = Handle { tx: tx.clone() };
24
25    tokio::spawn(async move {
26        // Can unwrap because if MQTT handler is bad, we have nothing to do here.
27        mqtt.publish("connecting").await.unwrap();
28
29        let address_offset = config.address_offset;
30
31        const MAX_WAIT: usize = 35;
32        let mut current_wait = 1;
33        let mut next_wait = 1;
34
35        loop {
36            match config.settings.connect(config.unit).await {
37                Ok(client) => {
38                    // Can unwrap because if MQTT handler is bad, we have nothing to do here.
39                    mqtt.publish("connected").await.unwrap();
40
41                    let mut conn = Connection {
42                        address_offset,
43                        client,
44                        mqtt: mqtt.clone(),
45                        shutdown: shutdown.clone(), // Important, so that we can publish "disconnected" below
46                        rx,
47                        tx,
48                    };
49
50                    let _ = connection_is_ready.send(());
51
52                    let result = conn.run().await;
53
54                    if let Err(error) = result {
55                        error!(?error, "Modbus connection failed");
56                        mqtt.publish("error").await.unwrap();
57                        mqtt.publish_under("last_error", format!("{error:?}"))
58                            .await
59                            .unwrap();
60
61                        // TODO, reset current_wait to 0 if it's been a while since it crashed.
62
63                        tokio::time::sleep(std::time::Duration::from_secs(current_wait as u64))
64                            .await;
65
66                        let Connection { rx: r, tx: t, .. } = conn;
67                        rx = r;
68                        tx = t;
69                        (current_wait, next_wait) =
70                            (next_wait, (current_wait + next_wait).clamp(0, MAX_WAIT));
71                    } else {
72                        // we are shutting down here, so don't care if this fails
73                        let send = mqtt.publish("disconnected").await;
74                        debug!(?config, ?send, "shutting down modbus connection");
75                        break;
76                    }
77                }
78                Err(error) => error!(?error),
79            }
80        }
81    });
82
83    is_connection_ready
84        .changed()
85        .await
86        .map_err(|_| Error::RecvError)?;
87    Ok(handle)
88}
89
90struct Connection {
91    client: ModbusClient,
92    address_offset: i8,
93    mqtt: mqtt::Handle,
94    shutdown: Shutdown,
95    rx: mpsc::Receiver<Command>,
96    tx: mpsc::Sender<Command>,
97}
98
99#[derive(Debug)]
100pub struct Handle {
101    tx: mpsc::Sender<Command>,
102}
103
104impl Handle {
105    pub async fn write_register(&self, address: u16, data: Vec<Word>) -> crate::Result<Vec<Word>> {
106        let (tx, rx) = oneshot::channel();
107        self.tx
108            .send(Command::Write(address, data, tx))
109            .await
110            .map_err(|_| Error::SendError)?;
111        rx.await.map_err(|_| Error::RecvError)?
112    }
113    pub async fn read_input_register(
114        &self,
115        address: u16,
116        quantity: u8,
117    ) -> crate::Result<Vec<Word>> {
118        self.read_register(RegisterType::Input, address, quantity)
119            .await
120    }
121    pub async fn read_holding_register(
122        &self,
123        address: u16,
124        quantity: u8,
125    ) -> crate::Result<Vec<Word>> {
126        self.read_register(RegisterType::Holding, address, quantity)
127            .await
128    }
129
130    async fn read_register(
131        &self,
132        reg_type: RegisterType,
133        address: u16,
134        quantity: u8,
135    ) -> crate::Result<Vec<Word>> {
136        let (tx, rx) = oneshot::channel();
137        self.tx
138            .send(Command::Read(reg_type, address, quantity, tx))
139            .await
140            .map_err(|_| Error::SendError)?;
141        rx.await.map_err(|_| Error::RecvError)?
142    }
143}
144
145type Response = oneshot::Sender<crate::Result<Vec<Word>>>;
146
147#[derive(Debug)]
148enum Command {
149    Read(RegisterType, u16, u8, Response),
150    Write(u16, Vec<Word>, Response),
151}
152
153impl Connection {
154    pub async fn run(&mut self) -> crate::Result<()> {
155        let mut registers_rx = register::subscribe(&self.mqtt).await?;
156
157        loop {
158            select! {
159                Some(cmd) = self.rx.recv() => { self.process_command(cmd).await?; },
160
161                Some(register) = registers_rx.recv() => {
162                    debug!(?register);
163                    let mqtt = self.mqtt.scoped("registers");
164                    let modbus = self.handle();
165                    register::Monitor::new(
166                        register,
167                        mqtt,
168                        modbus,
169                    )
170                    .run()
171                    .await;
172                },
173
174                _ = self.shutdown.recv() => {
175                    return Ok(());
176                }
177            }
178        }
179    }
180
181    fn handle(&self) -> Handle {
182        Handle {
183            tx: self.tx.clone(),
184        }
185    }
186
187    // TODO: if we get a new register definition for an existing register, how do we avoid redundant (and possibly
188    // conflicting) tasks? Should MQTT component only allow one subscriber per topic filter, replacing the old one
189    // when it gets a new subscribe request?
190    // IDEA: Allow providing a subscription ID which _replaces_ any existing subscription with the same ID
191
192    /// Apply address offset to address.
193    ///
194    /// Panics if offset would overflow or underflow the address.
195    fn adjust_address(&self, address: u16) -> u16 {
196        if self.address_offset.is_zero() {
197            return address;
198        }
199
200        // TODO: use `checked_add_signed()` once stabilised: https://doc.rust-lang.org/std/primitive.u16.html#method.checked_add_signed
201        let adjusted_address = if self.address_offset >= 0 {
202            address.checked_add(self.address_offset as u16)
203        } else {
204            address.checked_sub(self.address_offset.unsigned_abs() as u16)
205        };
206
207        if let Some(address) = adjusted_address {
208            address
209        } else {
210            error!(address, offset = self.address_offset,);
211            address
212            // panic!("Address offset would underflow/overflow")
213        }
214    }
215
216    async fn process_command(&mut self, cmd: Command) -> crate::Result<()> {
217        use tokio_modbus::prelude::Reader;
218
219        let (tx, response) = match cmd {
220            Command::Read(RegisterType::Input, address, count, tx) => {
221                let address = self.adjust_address(address);
222                (
223                    tx,
224                    self.client
225                        .read_input_registers(address, count as u16)
226                        .await,
227                )
228            }
229            Command::Read(RegisterType::Holding, address, count, tx) => {
230                let address = self.adjust_address(address);
231                (
232                    tx,
233                    self.client
234                        .read_holding_registers(address, count as u16)
235                        .await,
236                )
237            }
238            Command::Write(address, data, tx) => {
239                let address = self.adjust_address(address);
240                (
241                    tx,
242                    self.client
243                        .read_write_multiple_registers(
244                            address,
245                            data.len() as u16,
246                            address,
247                            &data[..],
248                        )
249                        .await,
250                )
251            }
252        };
253
254        // This might be transient, so don't kill connection. We may be able to discriminate on the error to determine
255        // which errors are transient and which are conclusive.
256        //
257        // Some errors that we have observed:
258        //
259        //     Error { kind: UnexpectedEof, message: "failed to fill whole buffer" }'
260        //     Custom { kind: InvalidData, error: "Invalid data length: 0" }'
261        //     Os { code: 36, kind: Uncategorized, message: "Operation now in progress" }'
262        //     Os { code: 35, kind: WouldBlock, message: "Resource temporarily unavailable" }
263        //
264        if let Err(ref error) = response {
265            use std::io::ErrorKind;
266            match error.kind() {
267                ErrorKind::UnexpectedEof | ErrorKind::InvalidData => {
268                    // THIS happening feels like a bug either in how I am using tokio_modbus or in tokio_modbus. It seems
269                    // like the underlying buffers get all messed up and restarting doesn't always fix it unless I wait a
270                    // few seconds. I might need to get help from someone to figure it out.
271                    error!(?error, "Connection error, may not be recoverable");
272                    return Err(response.unwrap_err().into());
273                }
274                _ => error!(?error),
275            }
276        }
277
278        // This probably just means that the register task died or is no longer monitoring the response.
279        if let Err(response) = tx.send(response.map_err(Into::into)) {
280            warn!(?response, "error sending response");
281        }
282
283        Ok(())
284    }
285}
286
287#[derive(Debug, Deserialize)]
288pub(crate) struct Config {
289    #[serde(flatten)]
290    pub settings: ModbusProto,
291
292    #[serde(
293        alias = "slave",
294        default = "tokio_modbus::slave::Slave::broadcast",
295        with = "Unit"
296    )]
297    pub unit: modbus::Unit,
298
299    #[serde(default)]
300    pub address_offset: i8,
301}
302
303#[derive(Deserialize)]
304#[serde(remote = "tokio_modbus::slave::Slave")]
305pub(crate) struct Unit(crate::modbus::UnitId);
306
307#[derive(Clone, Debug, Deserialize)]
308#[serde(tag = "proto", rename_all = "lowercase")]
309pub(crate) enum ModbusProto {
310    #[cfg(feature = "tcp")]
311    Tcp {
312        host: String,
313
314        #[serde(default = "default_modbus_port")]
315        port: u16,
316    },
317    #[cfg(feature = "rtu")]
318    #[serde(rename_all = "lowercase")]
319    Rtu {
320        // tty: std::path::PathBuf,
321        tty: String,
322        baud_rate: u32,
323
324        #[serde(default = "default_modbus_data_bits")]
325        data_bits: tokio_serial::DataBits, // TODO: allow this to be represented as a number instead of string
326
327        #[serde(default = "default_modbus_stop_bits")]
328        stop_bits: tokio_serial::StopBits, // TODO: allow this to be represented as a number instead of string
329
330        #[serde(default = "default_modbus_flow_control")]
331        flow_control: tokio_serial::FlowControl,
332
333        #[serde(default = "default_modbus_parity")]
334        parity: tokio_serial::Parity,
335    },
336    #[cfg(feature = "winet-s")]
337    #[serde(rename = "winet-s")]
338    SungrowWiNetS { host: String },
339
340    // Predominantly for if the binary is compiled with no default features for some reason.
341    #[serde(other)]
342    Unknown,
343}
344
345impl ModbusProto {
346    // Can we use the "slave context" thing in Modbus to pass the unit later?
347    pub async fn connect(&self, unit: modbus::Unit) -> crate::Result<ModbusClient> {
348        let client = match *self {
349            #[cfg(feature = "winet-s")]
350            ModbusProto::SungrowWiNetS { ref host } => {
351                tokio_modbus_winets::connect_slave(host, unit).await?
352            }
353
354            #[cfg(feature = "tcp")]
355            ModbusProto::Tcp { ref host, port } => {
356                let socket_addr = format!("{}:{}", host, port).parse()?;
357                tcp::connect_slave(socket_addr, unit).await?
358            }
359
360            #[cfg(feature = "rtu")]
361            ModbusProto::Rtu {
362                ref tty,
363                baud_rate,
364                data_bits,
365                stop_bits,
366                flow_control,
367                parity,
368            } => {
369                let builder = tokio_serial::new(tty, baud_rate)
370                    .data_bits(data_bits)
371                    .flow_control(flow_control)
372                    .parity(parity)
373                    .stop_bits(stop_bits);
374                let port = tokio_serial::SerialStream::open(&builder)?;
375                rtu::connect_slave(port, unit).await?
376            }
377
378            ModbusProto::Unknown => {
379                error!("Unrecognised protocol");
380                Err(Error::UnrecognisedModbusProtocol)?
381            }
382        };
383        Ok(client)
384    }
385}
386
/// Port used for a `tcp` connection when `port` is omitted: 502, the standard Modbus/TCP port.
pub(crate) fn default_modbus_port() -> u16 {
    const MODBUS_TCP_PORT: u16 = 502;
    MODBUS_TCP_PORT
}
390
/// Data bits used for an `rtu` connection when `data_bits` is omitted: eight.
#[cfg(feature = "rtu")]
pub(crate) fn default_modbus_data_bits() -> tokio_serial::DataBits {
    use tokio_serial::DataBits;
    DataBits::Eight
}

/// Stop bits used for an `rtu` connection when `stop_bits` is omitted: one.
#[cfg(feature = "rtu")]
pub(crate) fn default_modbus_stop_bits() -> tokio_serial::StopBits {
    use tokio_serial::StopBits;
    StopBits::One
}

/// Flow control used for an `rtu` connection when `flow_control` is omitted: none.
#[cfg(feature = "rtu")]
pub(crate) fn default_modbus_flow_control() -> tokio_serial::FlowControl {
    use tokio_serial::FlowControl;
    FlowControl::None
}

/// Parity used for an `rtu` connection when `parity` is omitted: none.
#[cfg(feature = "rtu")]
pub(crate) fn default_modbus_parity() -> tokio_serial::Parity {
    use tokio_serial::Parity;
    Parity::None
}
410
/// A `tcp` config with only `host` parses, and omitted fields take their defaults.
// Gated because `ModbusProto::Tcp` only exists with the `tcp` feature.
#[cfg(feature = "tcp")]
#[test]
fn parse_minimal_tcp_connect_config() {
    use serde_json::json;
    let result = serde_json::from_value::<Config>(json!({
        "proto": "tcp",
        "host": "1.1.1.1"
    }));

    let connect = result.unwrap();
    assert!(matches!(
        connect.settings,
        ModbusProto::Tcp {
            ref host,
            port: 502
        } if host == "1.1.1.1"
    ));
    assert_eq!(connect.unit.0, tokio_modbus::slave::Slave::broadcast().0);
    assert_eq!(connect.address_offset, 0);
}
428
/// A full `tcp` config parses. Register lists (`input`/`hold`) are not part of `Config` and are
/// ignored, while `unit` and `address_offset` are picked up.
#[cfg(feature = "tcp")]
#[test]
fn parse_full_tcp_connect_config() {
    use serde_json::json;
    let config = serde_json::from_value::<Config>(json!({
        "proto": "tcp",
        "host": "10.10.10.219",
        "unit": 1,
        "address_offset": -1,
        "input": [
            {
                "address": 5017,
                "type": "u32",
                "name": "dc_power",
                "swap_words": false,
                "period": "3s"
            },
            {
                "address": 5008,
                "type": "s16",
                "name": "internal_temperature",
                "period": "1m"
            },
            {
                "address": 13008,
                "type": "s32",
                "name": "load_power",
                "swap_words": false,
                "period": "3s"
            },
            {
                "address": 13010,
                "type": "s32",
                "name": "export_power",
                "swap_words": false,
                "period": "3s"
            },
            {
                "address": 13022,
                "name": "battery_power",
                "period": "3s"
            },
            {
                "address": 13023,
                "name": "battery_level",
                "period": "1m"
            },
            {
                "address": 13024,
                "name": "battery_health",
                "period": "10m"
            }
        ],
        "hold": [
            {
                "address": 13058,
                "name": "max_soc",
                "period": "90s"
            },
            {
                "address": 13059,
                "name": "min_soc",
                "period": "90s"
            }
        ]
    }))
    .unwrap();

    assert!(matches!(
        config.settings,
        ModbusProto::Tcp {
            ref host,
            port: 502
        } if host == "10.10.10.219"
    ));
    assert_eq!(config.unit.0, 1);
    assert_eq!(config.address_offset, -1);
}
496
/// An `rtu` config with only `tty` and `baud_rate` parses, and the serial parameters take their
/// defaults.
// Gated because `ModbusProto::Rtu` and `tokio_serial` only exist with the `rtu` feature.
#[cfg(feature = "rtu")]
#[test]
fn parse_minimal_rtu_connect_config() {
    use serde_json::json;
    let result = serde_json::from_value::<Config>(json!({
        "proto": "rtu",
        "tty": "/dev/ttyUSB0",
        "baud_rate": 9600,
    }));

    let connect = result.unwrap();
    use tokio_serial::*;
    assert!(matches!(
        connect.settings,
        ModbusProto::Rtu {
            ref tty,
            baud_rate: 9600,
            data_bits: DataBits::Eight,
            stop_bits: StopBits::One,
            flow_control: FlowControl::None,
            parity: Parity::None,
            ..
        } if tty == "/dev/ttyUSB0"
    ));
}
521
/// An `rtu` config that sets every serial parameter explicitly parses to those values.
// Gated because `ModbusProto::Rtu` and `tokio_serial` only exist with the `rtu` feature.
#[cfg(feature = "rtu")]
#[test]
fn parse_complete_rtu_connect_config() {
    use serde_json::json;
    let result = serde_json::from_value::<Config>(json!({
        "proto": "rtu",
        "tty": "/dev/ttyUSB0",
        "baud_rate": 12800,

        // TODO: make lowercase words work
        "data_bits": "Seven", // TODO: make 7 work
        "stop_bits": "Two", // TODO: make 2 work
        "flow_control": "Software",
        "parity": "Even",
    }));

    let connect = result.unwrap();
    use tokio_serial::*;
    assert!(matches!(
        connect.settings,
        ModbusProto::Rtu {
            ref tty,
            baud_rate: 12800,
            data_bits: DataBits::Seven,
            stop_bits: StopBits::Two,
            flow_control: FlowControl::Software,
            parity: Parity::Even,
            ..
        } if tty == "/dev/ttyUSB0"
    ));
}