rpdo/
io.rs

1use crate::comm::{Command, Frame, Packet, RawDataHeader};
2use crate::context::RpdoContext;
3use crate::error::Error;
4use crate::host::SyncHost;
5use crate::Result;
6use binrw::prelude::*;
7use std::io::{Cursor, Read, Write};
8use std::mem;
9use std::net::{SocketAddr, ToSocketAddrs, UdpSocket};
10use std::time::Duration;
11
/// Size in bytes of the receive buffer used by `UdpStream`, and the upper bound (and default)
/// for its MTU.
const MAX_UDP_PACKET_SIZE: usize = 16 * 1024;

/// Default payload size in bytes above which the client/server write the packet header and
/// payload to the stream separately instead of assembling them in a scratch buffer.
const DEFAULT_ZERO_COPY_AFTER: usize = 32 * 1024;
15
/// Adapts a [`UdpSocket`] to the byte-stream `Read`/`Write` traits.
///
/// Reads are served from the last received datagram until it is fully consumed; writes are
/// accumulated in memory and sent as a single datagram on `flush`.
pub struct UdpStream {
    /// Bound socket through which all traffic flows.
    socket: UdpSocket,
    /// Destination of outgoing datagrams; replaced by the sender of each received datagram.
    peer: Option<SocketAddr>,
    /// Largest datagram `flush` is allowed to send.
    mtu: usize,
    /// Read timeout as last applied to the socket through this type.
    read_timeout: Option<Duration>,
    /// Write timeout as last applied to the socket through this type.
    write_timeout: Option<Duration>,
    /// Not yet consumed bytes of the last received datagram.
    read_buffer: Vec<u8>,
    /// Bytes written since the last `flush`.
    write_buffer: Vec<u8>,
}
26
27impl UdpStream {
28    /// Create a new UDP stream
29    pub fn create(bind: impl ToSocketAddrs) -> Result<Self> {
30        let socket = UdpSocket::bind(bind)?;
31        Ok(Self {
32            socket,
33            peer: None,
34            read_buffer: Vec::new(),
35            write_buffer: Vec::new(),
36            mtu: MAX_UDP_PACKET_SIZE,
37            read_timeout: None,
38            write_timeout: None,
39        })
40    }
41
42    /// Set read timeout
43    pub fn with_read_timeout(mut self, timeout: Duration) -> Result<Self> {
44        self.socket.set_read_timeout(Some(timeout))?;
45        self.read_timeout = Some(timeout);
46        Ok(self)
47    }
48
49    /// Set write timeout
50    pub fn with_write_timeout(mut self, timeout: Duration) -> Result<Self> {
51        self.socket.set_write_timeout(Some(timeout))?;
52        self.write_timeout = Some(timeout);
53        Ok(self)
54    }
55
56    /// Set both timeouts
57    pub fn with_timeouts(
58        mut self,
59        read_timeout: Duration,
60        write_timeout: Duration,
61    ) -> Result<Self> {
62        self.socket.set_read_timeout(Some(read_timeout))?;
63        self.socket.set_write_timeout(Some(write_timeout))?;
64        self.read_timeout = Some(read_timeout);
65        self.write_timeout = Some(write_timeout);
66        Ok(self)
67    }
68
69    /// Set read timeout after construction
70    pub fn set_read_timeout(&mut self, timeout: Option<Duration>) -> Result<()> {
71        self.socket.set_read_timeout(timeout)?;
72        self.read_timeout = timeout;
73        Ok(())
74    }
75
76    /// Set write timeout after construction
77    pub fn set_write_timeout(&mut self, timeout: Option<Duration>) -> Result<()> {
78        self.socket.set_write_timeout(timeout)?;
79        self.write_timeout = timeout;
80        Ok(())
81    }
82
83    /// Get current read timeout
84    pub fn read_timeout(&self) -> Option<Duration> {
85        self.read_timeout
86    }
87
88    /// Get current write timeout
89    pub fn write_timeout(&self) -> Option<Duration> {
90        self.write_timeout
91    }
92
93    /// Set the maximum packet size
94    pub fn try_with_mtu(mut self, max_packet_size: usize) -> Result<Self> {
95        if max_packet_size > MAX_UDP_PACKET_SIZE {
96            return Err(Error::Io(std::io::Error::new(
97                std::io::ErrorKind::InvalidInput,
98                "MTU too large",
99            )));
100        }
101        self.mtu = max_packet_size;
102        Ok(self)
103    }
104    /// Set the peer address
105    pub fn set_peer(&mut self, peer: impl ToSocketAddrs) -> Result<()> {
106        let peer = peer
107            .to_socket_addrs()?
108            .next()
109            .ok_or(Error::Io(std::io::Error::new(
110                std::io::ErrorKind::InvalidInput,
111                "Invalid peer address",
112            )))?;
113        self.peer = Some(peer);
114        Ok(())
115    }
116}
117
118impl Read for UdpStream {
119    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
120        if self.read_buffer.is_empty() {
121            // must be read in a single packet
122            let mut buf = [0; MAX_UDP_PACKET_SIZE];
123            let (size, addr) = self.socket.recv_from(&mut buf)?;
124            self.read_buffer.extend_from_slice(&buf[..size]);
125            self.peer = Some(addr);
126        }
127        let size = std::cmp::min(buf.len(), self.read_buffer.len());
128        buf[..size].copy_from_slice(&self.read_buffer[..size]);
129        self.read_buffer.drain(..size);
130        Ok(size)
131    }
132}
133
134impl Write for UdpStream {
135    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
136        self.write_buffer.extend_from_slice(buf);
137        Ok(buf.len())
138    }
139
140    fn flush(&mut self) -> std::io::Result<()> {
141        let data = mem::take(&mut self.write_buffer);
142        let Some(peer) = self.peer else {
143            return Err(std::io::Error::new(
144                std::io::ErrorKind::NotConnected,
145                "No peer address",
146            ));
147        };
148        if data.len() > self.mtu {
149            return Err(std::io::Error::new(
150                std::io::ErrorKind::InvalidInput,
151                "Data too large",
152            ));
153        }
154        self.socket.send_to(&data, peer)?;
155        Ok(())
156    }
157}
158
/// A simple request/reply client talking to a single target over a byte stream.
///
/// Requests are sent one at a time; when asked, the client waits for the matching reply.
pub struct SimpleClient<S: Read + Write> {
    /// Id given to the next outgoing request.
    request_id: u32,
    /// Transport carrying requests and replies.
    stream: S,
    /// Id of the target every request is addressed to.
    target_id: u32,
    /// Scratch buffer reused to assemble requests and read replies.
    data_buf: Vec<u8>,
    /// Payloads larger than this are written separately from their header.
    zero_copy_after: usize,
    /// Whether to flush the stream after requests assembled in `data_buf`.
    always_flush: bool,
}
171
172impl<S> SimpleClient<S>
173where
174    S: Read + Write,
175{
176    /// Create a new client
177    pub fn new(stream: S, target_id: u32) -> Self {
178        Self {
179            request_id: 0,
180            stream,
181            target_id,
182            data_buf: Vec::new(),
183            zero_copy_after: DEFAULT_ZERO_COPY_AFTER,
184            always_flush: true,
185        }
186    }
187    /// If the data size is larger than this value, it will be sent in a separate write
188    pub fn with_zero_copy_after(mut self, zero_copy_after: usize) -> Self {
189        self.zero_copy_after = zero_copy_after;
190        self
191    }
192    /// Always flush after writing
193    pub fn with_always_flush(mut self, always_flush: bool) -> Self {
194        self.always_flush = always_flush;
195        self
196    }
197    /// Ping the target
198    pub fn ping(&mut self) -> Result<()> {
199        self.communicate(Command::Ping, &[], true)?;
200        Ok(())
201    }
202    /// Read a register
203    pub fn read_register(&mut self, register: u32, offset: u32, size: u32) -> Result<Vec<u8>> {
204        let raw_data_header = RawDataHeader {
205            register,
206            offset,
207            size,
208        };
209        let mut buf = Cursor::new(Vec::new());
210        raw_data_header.write(&mut buf)?;
211        let Some(v) = self.communicate(Command::ReadSharedContext, buf.get_ref(), true)? else {
212            return Err(Error::InvalidReply);
213        };
214        Ok(v)
215    }
216    /// Write a register
217    pub fn write_register(&mut self, register: u32, offset: u32, data: &[u8]) -> Result<()> {
218        let raw_data_header = RawDataHeader {
219            register,
220            offset,
221            size: u32::try_from(data.len())?,
222        };
223        let mut buf = Cursor::new(Vec::new());
224        raw_data_header.write(&mut buf)?;
225        buf.write_all(data)?;
226        self.communicate(Command::WriteSharedContext, buf.get_ref(), true)?;
227        Ok(())
228    }
229    /// Communicate with the target
230    pub fn communicate(
231        &mut self,
232        command: Command,
233        data: &[u8],
234        wait_reply: bool,
235    ) -> Result<Option<Vec<u8>>> {
236        let request_id = self.request_id;
237        self.request_id += 1;
238        let frame = Frame {
239            source: 0,
240            target: self.target_id,
241            id: request_id,
242            in_reply_to: 0,
243            command,
244        };
245        let packet = Packet::new(frame, data.len());
246        if data.len() > self.zero_copy_after {
247            packet.write_to(&mut self.stream)?;
248            self.stream.write_all(data)?;
249            self.stream.flush()?;
250        } else {
251            self.data_buf.reserve(packet.size_full());
252            self.data_buf.clear();
253            packet.write_to(&mut Cursor::new(&mut self.data_buf))?;
254            self.data_buf.extend(data);
255            self.stream.write_all(&self.data_buf)?;
256            if self.always_flush {
257                self.stream.flush()?;
258            }
259        }
260        if !wait_reply {
261            return Ok(None);
262        }
263        let packet = Packet::read_from(&mut self.stream)?;
264        let data_len = packet.data_len();
265        self.data_buf.resize(data_len, 0);
266        self.stream.read_exact(&mut self.data_buf)?;
267        let frame = packet.frame();
268        if frame.target != 0 || frame.in_reply_to != request_id {
269            return Err(Error::InvalidReply);
270        }
271        Ok(Some(self.data_buf.clone()))
272    }
273}
274
275/// A simple server processor
276pub struct SimpleServerProcessor<CTX, HOST, S>
277where
278    CTX: RpdoContext,
279    HOST: SyncHost<Context = CTX>,
280    S: Read + Write,
281{
282    host: HOST,
283    stream: S,
284    data_buf: Vec<u8>,
285    zero_copy_after: usize,
286    always_flush: bool,
287}
288
289impl<CTX, HOST, S> SimpleServerProcessor<CTX, HOST, S>
290where
291    CTX: RpdoContext,
292    HOST: SyncHost<Context = CTX>,
293    S: Read + Write,
294{
295    /// Create a new server processor
296    pub fn new(host: HOST, stream: S) -> Self
297    where
298        HOST: SyncHost,
299    {
300        Self {
301            host,
302            stream,
303            data_buf: Vec::new(),
304            zero_copy_after: DEFAULT_ZERO_COPY_AFTER,
305            always_flush: true,
306        }
307    }
308
309    /// If the data size is larger than this value, it will be sent in a separate write
310    pub fn with_zero_copy_after(mut self, zero_copy_after: usize) -> Self {
311        self.zero_copy_after = zero_copy_after;
312        self
313    }
314
315    /// Always flush after writing
316    pub fn with_always_flush(mut self, always_flush: bool) -> Self {
317        self.always_flush = always_flush;
318        self
319    }
320
321    /// Process the next packet
322    pub fn process_next(&mut self) -> Result<()> {
323        let packet = Packet::read_from(&mut self.stream)?;
324        self.data_buf.resize(packet.data_len(), 0);
325        self.stream.read_exact(&mut self.data_buf)?;
326        let frame = packet.frame();
327        if let Some((reply, data)) = self.host.process_frame(frame, &self.data_buf)? {
328            let packet = Packet::new(reply, data.len());
329            if data.len() > self.zero_copy_after {
330                packet.write_to(&mut self.stream)?;
331                self.stream.write_all(&data)?;
332                self.stream.flush()?;
333            } else {
334                self.data_buf.reserve(packet.size_full());
335                self.data_buf.clear();
336                packet.write_to(&mut Cursor::new(&mut self.data_buf))?;
337                self.data_buf.extend(data);
338                self.stream.write_all(&self.data_buf)?;
339                if self.always_flush {
340                    self.stream.flush()?;
341                }
342            }
343        }
344        Ok(())
345    }
346}