modbus_proxy_rs/
lib.rs

1#[macro_use]
2extern crate log;
3
4use futures::future::join_all;
5use serde::Deserialize;
6use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
7use tokio::net::{TcpListener, TcpStream};
8use tokio::sync::{mpsc, oneshot};
9
// Replace the default global allocator with jemalloc, but only when building for a
// 64-bit musl target; every other target keeps Rust's default allocator.
// NOTE(review): presumably chosen because musl's built-in malloc is slow — confirm.
#[cfg(all(target_env = "musl", target_pointer_width = "64"))]
#[global_allocator]
static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
14
15type Frame = Vec<u8>;
16type ReplySender = oneshot::Sender<Frame>;
17
18#[derive(Debug)]
19enum Message {
20    Connection,
21    Disconnection,
22    Packet(Frame, ReplySender),
23}
24
25type ChannelRx = mpsc::Receiver<Message>;
26type ChannelTx = mpsc::Sender<Message>;
27
28type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
29type Result<T> = std::result::Result<T, Error>;
30
31type TcpReader = BufReader<tokio::net::tcp::OwnedReadHalf>;
32type TcpWriter = BufWriter<tokio::net::tcp::OwnedWriteHalf>;
33
34fn frame_size(frame: &[u8]) -> Result<usize> {
35    Ok(u16::from_be_bytes(frame[4..6].try_into()?) as usize)
36}
37
38fn split_connection(stream: TcpStream) -> (TcpReader, TcpWriter) {
39    let (reader, writer) = stream.into_split();
40    (BufReader::new(reader), BufWriter::new(writer))
41}
42
43async fn create_connection(url: &str) -> Result<(TcpReader, TcpWriter)> {
44    let stream = TcpStream::connect(url).await?;
45    stream.set_nodelay(true)?;
46    Ok(split_connection(stream))
47}
48
49async fn read_frame(stream: &mut TcpReader) -> Result<Frame> {
50    let mut buf = vec![0u8; 6];
51    // Read header
52    stream.read_exact(&mut buf).await?;
53    // calculate payload size
54    let total_size = 6 + frame_size(&buf)?;
55    buf.resize(total_size, 0);
56    stream.read_exact(&mut buf[6..total_size]).await?;
57    Ok(buf)
58}
59
60#[derive(Debug, Deserialize)]
61struct Listen {
62    bind: String,
63}
64
65#[derive(Debug, Deserialize)]
66struct Modbus {
67    url: String,
68}
69
70struct Device {
71    url: String,
72    stream: Option<(TcpReader, TcpWriter)>,
73}
74
75impl Device {
76    pub fn new(url: &str) -> Device {
77        Device {
78            url: url.to_string(),
79            stream: None,
80        }
81    }
82
83    async fn connect(&mut self) -> Result<()> {
84        match create_connection(&self.url).await {
85            Ok(connection) => {
86                info!("modbus connection to {} sucessfull", self.url);
87                self.stream = Some(connection);
88                Ok(())
89            }
90            Err(error) => {
91                self.stream = None;
92                info!("modbus connection to {} error: {} ", self.url, error);
93                Err(error)
94            }
95        }
96    }
97
98    fn disconnect(&mut self) {
99        self.stream = None;
100    }
101
102    fn is_connected(&self) -> bool {
103        self.stream.is_some()
104    }
105
106    async fn raw_write_read(&mut self, frame: &Frame) -> Result<Frame> {
107        let (reader, writer) = self.stream.as_mut().ok_or("no modbus connection")?;
108        writer.write_all(&frame).await?;
109        writer.flush().await?;
110        read_frame(reader).await
111    }
112
113    async fn write_read(&mut self, frame: &Frame) -> Result<Frame> {
114        if self.is_connected() {
115            let result = self.raw_write_read(&frame).await;
116            match result {
117                Ok(reply) => Ok(reply),
118                Err(error) => {
119                    warn!("modbus error: {}. Retrying...", error);
120                    self.connect().await?;
121                    self.raw_write_read(&frame).await
122                }
123            }
124        } else {
125            self.connect().await?;
126            self.raw_write_read(&frame).await
127        }
128    }
129
130    async fn handle_packet(&mut self, frame: Frame, channel: ReplySender) -> Result<()> {
131        info!("modbus request {}: {} bytes", self.url, frame.len());
132        debug!("modbus request {}: {:?}", self.url, &frame[..]);
133        let reply = self.write_read(&frame).await?;
134        info!("modbus reply {}: {} bytes", self.url, reply.len());
135        debug!("modbus reply {}: {:?}", self.url, &reply[..]);
136        channel
137            .send(reply)
138            .or_else(|error| Err(format!("error sending reply to client: {:?}", error).into()))
139    }
140
141    async fn run(&mut self, channel: &mut ChannelRx) {
142        let mut nb_clients = 0;
143
144        while let Some(message) = channel.recv().await {
145            match message {
146                Message::Connection => {
147                    nb_clients += 1;
148                    info!("new client connection (active = {})", nb_clients);
149                }
150                Message::Disconnection => {
151                    nb_clients -= 1;
152                    info!("client disconnection (active = {})", nb_clients);
153                    if nb_clients == 0 {
154                        info!("disconnecting from modbus at {} (no clients)", self.url);
155                        self.disconnect();
156                    }
157                }
158                Message::Packet(frame, channel) => {
159                    if let Err(_) = self.handle_packet(frame, channel).await {
160                        self.disconnect();
161                    }
162                }
163            }
164        }
165    }
166
167    async fn launch(url: &str, channel: &mut ChannelRx) {
168        let mut modbus = Self::new(url);
169        modbus.run(channel).await;
170    }
171}
172
173#[derive(Debug, Deserialize)]
174struct Bridge {
175    listen: Listen,
176    modbus: Modbus,
177}
178
179impl Bridge {
180    pub async fn run(&mut self) {
181        let listener = TcpListener::bind(&self.listen.bind).await.unwrap();
182        let modbus_url = self.modbus.url.clone();
183        let (tx, mut rx) = mpsc::channel::<Message>(32);
184        tokio::spawn(async move {
185            Device::launch(&modbus_url, &mut rx).await;
186        });
187        info!(
188            "Ready to accept requests on {} to {}",
189            &self.listen.bind, &self.modbus.url
190        );
191        loop {
192            let (client, _) = listener.accept().await.unwrap();
193            let tx = tx.clone();
194            tokio::spawn(async move {
195                if let Err(err) = Self::handle_client(client, tx).await {
196                    error!("Client error: {:?}", err);
197                }
198            });
199        }
200    }
201
202    async fn handle_client(client: TcpStream, channel: ChannelTx) -> Result<()> {
203        client.set_nodelay(true)?;
204        channel.send(Message::Connection).await?;
205        let (mut reader, mut writer) = split_connection(client);
206        while let Ok(buf) = read_frame(&mut reader).await {
207            let (tx, rx) = oneshot::channel();
208            channel.send(Message::Packet(buf, tx)).await?;
209            writer.write_all(&rx.await?).await?;
210            writer.flush().await?;
211        }
212        channel.send(Message::Disconnection).await?;
213        Ok(())
214    }
215}
216
217#[derive(Debug, Deserialize)]
218pub struct Server {
219    devices: Vec<Bridge>,
220}
221
222impl Server {
223    pub fn new(config_file: &str) -> std::result::Result<Self, config::ConfigError> {
224        let mut cfg = config::Config::new();
225        cfg.merge(config::File::with_name(config_file))?;
226        cfg.try_into()
227    }
228
229    pub async fn run(self) {
230        let mut tasks = vec![];
231        for mut bridge in self.devices {
232            tasks.push(tokio::spawn(async move { bridge.run().await }));
233        }
234        join_all(tasks).await;
235    }
236
237    pub async fn launch(config_file: &str) -> std::result::Result<(), config::ConfigError> {
238        Ok(Self::new(config_file)?.run().await)
239    }
240}