1use crate::comm::{Command, Frame, Packet, RawDataHeader};
2use crate::context::RpdoContext;
3use crate::error::Error;
4use crate::host::SyncHost;
5use crate::Result;
6use binrw::prelude::*;
7use std::io::{Cursor, Read, Write};
8use std::mem;
9use std::net::{SocketAddr, ToSocketAddrs, UdpSocket};
10use std::time::Duration;
11
/// Largest datagram size, in bytes, handled by [`UdpStream`]: it sizes the
/// receive buffer, is the default MTU, and is the upper bound accepted by
/// [`UdpStream::try_with_mtu`].
const MAX_UDP_PACKET_SIZE: usize = 16384;

/// Default `zero_copy_after` threshold for [`SimpleClient`] and
/// [`SimpleServerProcessor`]: payloads longer than this many bytes are written
/// to the stream directly instead of first being copied into the internal
/// buffer.
const DEFAULT_ZERO_COPY_AFTER: usize = 32768;
15
/// Adapter that exposes a [`UdpSocket`] through the [`Read`] and [`Write`]
/// traits.
///
/// Bytes passed to `write` are accumulated and sent as one datagram on
/// `flush`; a received datagram is buffered and handed out across successive
/// `read` calls.
pub struct UdpStream {
    /// The bound UDP socket.
    socket: UdpSocket,
    /// Destination of outgoing datagrams. Set explicitly via `set_peer`, and
    /// overwritten with the sender address of every received datagram.
    peer: Option<SocketAddr>,
    /// Bytes of the most recently received datagram not yet returned by `read`.
    read_buffer: Vec<u8>,
    /// Bytes written since the last `flush`.
    write_buffer: Vec<u8>,
    /// Maximum size, in bytes, of a datagram that `flush` will send.
    mtu: usize,
    /// Read timeout last applied to the socket through this type.
    read_timeout: Option<Duration>,
    /// Write timeout last applied to the socket through this type.
    write_timeout: Option<Duration>,
}
26
27impl UdpStream {
28 pub fn create(bind: impl ToSocketAddrs) -> Result<Self> {
30 let socket = UdpSocket::bind(bind)?;
31 Ok(Self {
32 socket,
33 peer: None,
34 read_buffer: Vec::new(),
35 write_buffer: Vec::new(),
36 mtu: MAX_UDP_PACKET_SIZE,
37 read_timeout: None,
38 write_timeout: None,
39 })
40 }
41
42 pub fn with_read_timeout(mut self, timeout: Duration) -> Result<Self> {
44 self.socket.set_read_timeout(Some(timeout))?;
45 self.read_timeout = Some(timeout);
46 Ok(self)
47 }
48
49 pub fn with_write_timeout(mut self, timeout: Duration) -> Result<Self> {
51 self.socket.set_write_timeout(Some(timeout))?;
52 self.write_timeout = Some(timeout);
53 Ok(self)
54 }
55
56 pub fn with_timeouts(
58 mut self,
59 read_timeout: Duration,
60 write_timeout: Duration,
61 ) -> Result<Self> {
62 self.socket.set_read_timeout(Some(read_timeout))?;
63 self.socket.set_write_timeout(Some(write_timeout))?;
64 self.read_timeout = Some(read_timeout);
65 self.write_timeout = Some(write_timeout);
66 Ok(self)
67 }
68
69 pub fn set_read_timeout(&mut self, timeout: Option<Duration>) -> Result<()> {
71 self.socket.set_read_timeout(timeout)?;
72 self.read_timeout = timeout;
73 Ok(())
74 }
75
76 pub fn set_write_timeout(&mut self, timeout: Option<Duration>) -> Result<()> {
78 self.socket.set_write_timeout(timeout)?;
79 self.write_timeout = timeout;
80 Ok(())
81 }
82
83 pub fn read_timeout(&self) -> Option<Duration> {
85 self.read_timeout
86 }
87
88 pub fn write_timeout(&self) -> Option<Duration> {
90 self.write_timeout
91 }
92
93 pub fn try_with_mtu(mut self, max_packet_size: usize) -> Result<Self> {
95 if max_packet_size > MAX_UDP_PACKET_SIZE {
96 return Err(Error::Io(std::io::Error::new(
97 std::io::ErrorKind::InvalidInput,
98 "MTU too large",
99 )));
100 }
101 self.mtu = max_packet_size;
102 Ok(self)
103 }
104 pub fn set_peer(&mut self, peer: impl ToSocketAddrs) -> Result<()> {
106 let peer = peer
107 .to_socket_addrs()?
108 .next()
109 .ok_or(Error::Io(std::io::Error::new(
110 std::io::ErrorKind::InvalidInput,
111 "Invalid peer address",
112 )))?;
113 self.peer = Some(peer);
114 Ok(())
115 }
116}
117
118impl Read for UdpStream {
119 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
120 if self.read_buffer.is_empty() {
121 let mut buf = [0; MAX_UDP_PACKET_SIZE];
123 let (size, addr) = self.socket.recv_from(&mut buf)?;
124 self.read_buffer.extend_from_slice(&buf[..size]);
125 self.peer = Some(addr);
126 }
127 let size = std::cmp::min(buf.len(), self.read_buffer.len());
128 buf[..size].copy_from_slice(&self.read_buffer[..size]);
129 self.read_buffer.drain(..size);
130 Ok(size)
131 }
132}
133
134impl Write for UdpStream {
135 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
136 self.write_buffer.extend_from_slice(buf);
137 Ok(buf.len())
138 }
139
140 fn flush(&mut self) -> std::io::Result<()> {
141 let data = mem::take(&mut self.write_buffer);
142 let Some(peer) = self.peer else {
143 return Err(std::io::Error::new(
144 std::io::ErrorKind::NotConnected,
145 "No peer address",
146 ));
147 };
148 if data.len() > self.mtu {
149 return Err(std::io::Error::new(
150 std::io::ErrorKind::InvalidInput,
151 "Data too large",
152 ));
153 }
154 self.socket.send_to(&data, peer)?;
155 Ok(())
156 }
157}
158
/// Blocking request/reply client that exchanges packets with one target over
/// any `Read + Write` stream.
pub struct SimpleClient<S>
where
    S: Read + Write,
{
    /// Id assigned to the next outgoing request.
    request_id: u32,
    /// Transport that requests are written to and replies are read from.
    stream: S,
    /// Value placed in the `target` field of every outgoing frame.
    target_id: u32,
    /// Scratch buffer reused for assembling small requests and for receiving
    /// reply payloads.
    data_buf: Vec<u8>,
    /// Request payloads longer than this many bytes are written to the stream
    /// directly instead of being copied into `data_buf` first.
    zero_copy_after: usize,
    /// Whether the stream is flushed after a request sent through `data_buf`
    /// (requests written directly are always flushed).
    always_flush: bool,
}
171
172impl<S> SimpleClient<S>
173where
174 S: Read + Write,
175{
176 pub fn new(stream: S, target_id: u32) -> Self {
178 Self {
179 request_id: 0,
180 stream,
181 target_id,
182 data_buf: Vec::new(),
183 zero_copy_after: DEFAULT_ZERO_COPY_AFTER,
184 always_flush: true,
185 }
186 }
187 pub fn with_zero_copy_after(mut self, zero_copy_after: usize) -> Self {
189 self.zero_copy_after = zero_copy_after;
190 self
191 }
192 pub fn with_always_flush(mut self, always_flush: bool) -> Self {
194 self.always_flush = always_flush;
195 self
196 }
197 pub fn ping(&mut self) -> Result<()> {
199 self.communicate(Command::Ping, &[], true)?;
200 Ok(())
201 }
202 pub fn read_register(&mut self, register: u32, offset: u32, size: u32) -> Result<Vec<u8>> {
204 let raw_data_header = RawDataHeader {
205 register,
206 offset,
207 size,
208 };
209 let mut buf = Cursor::new(Vec::new());
210 raw_data_header.write(&mut buf)?;
211 let Some(v) = self.communicate(Command::ReadSharedContext, buf.get_ref(), true)? else {
212 return Err(Error::InvalidReply);
213 };
214 Ok(v)
215 }
216 pub fn write_register(&mut self, register: u32, offset: u32, data: &[u8]) -> Result<()> {
218 let raw_data_header = RawDataHeader {
219 register,
220 offset,
221 size: u32::try_from(data.len())?,
222 };
223 let mut buf = Cursor::new(Vec::new());
224 raw_data_header.write(&mut buf)?;
225 buf.write_all(data)?;
226 self.communicate(Command::WriteSharedContext, buf.get_ref(), true)?;
227 Ok(())
228 }
229 pub fn communicate(
231 &mut self,
232 command: Command,
233 data: &[u8],
234 wait_reply: bool,
235 ) -> Result<Option<Vec<u8>>> {
236 let request_id = self.request_id;
237 self.request_id += 1;
238 let frame = Frame {
239 source: 0,
240 target: self.target_id,
241 id: request_id,
242 in_reply_to: 0,
243 command,
244 };
245 let packet = Packet::new(frame, data.len());
246 if data.len() > self.zero_copy_after {
247 packet.write_to(&mut self.stream)?;
248 self.stream.write_all(data)?;
249 self.stream.flush()?;
250 } else {
251 self.data_buf.reserve(packet.size_full());
252 self.data_buf.clear();
253 packet.write_to(&mut Cursor::new(&mut self.data_buf))?;
254 self.data_buf.extend(data);
255 self.stream.write_all(&self.data_buf)?;
256 if self.always_flush {
257 self.stream.flush()?;
258 }
259 }
260 if !wait_reply {
261 return Ok(None);
262 }
263 let packet = Packet::read_from(&mut self.stream)?;
264 let data_len = packet.data_len();
265 self.data_buf.resize(data_len, 0);
266 self.stream.read_exact(&mut self.data_buf)?;
267 let frame = packet.frame();
268 if frame.target != 0 || frame.in_reply_to != request_id {
269 return Err(Error::InvalidReply);
270 }
271 Ok(Some(self.data_buf.clone()))
272 }
273}
274
/// Blocking server-side loop body: reads one packet at a time from a
/// `Read + Write` stream, hands it to a [`SyncHost`] and writes back the reply
/// the host produces, if any.
pub struct SimpleServerProcessor<CTX, HOST, S>
where
    CTX: RpdoContext,
    HOST: SyncHost<Context = CTX>,
    S: Read + Write,
{
    /// Host that processes incoming frames.
    host: HOST,
    /// Transport that requests are read from and replies are written to.
    stream: S,
    /// Scratch buffer reused for incoming payloads and for assembling small
    /// replies.
    data_buf: Vec<u8>,
    /// Reply payloads longer than this many bytes are written to the stream
    /// directly instead of being copied into `data_buf` first.
    zero_copy_after: usize,
    /// Whether the stream is flushed after a reply sent through `data_buf`
    /// (replies written directly are always flushed).
    always_flush: bool,
}
288
289impl<CTX, HOST, S> SimpleServerProcessor<CTX, HOST, S>
290where
291 CTX: RpdoContext,
292 HOST: SyncHost<Context = CTX>,
293 S: Read + Write,
294{
295 pub fn new(host: HOST, stream: S) -> Self
297 where
298 HOST: SyncHost,
299 {
300 Self {
301 host,
302 stream,
303 data_buf: Vec::new(),
304 zero_copy_after: DEFAULT_ZERO_COPY_AFTER,
305 always_flush: true,
306 }
307 }
308
309 pub fn with_zero_copy_after(mut self, zero_copy_after: usize) -> Self {
311 self.zero_copy_after = zero_copy_after;
312 self
313 }
314
315 pub fn with_always_flush(mut self, always_flush: bool) -> Self {
317 self.always_flush = always_flush;
318 self
319 }
320
321 pub fn process_next(&mut self) -> Result<()> {
323 let packet = Packet::read_from(&mut self.stream)?;
324 self.data_buf.resize(packet.data_len(), 0);
325 self.stream.read_exact(&mut self.data_buf)?;
326 let frame = packet.frame();
327 if let Some((reply, data)) = self.host.process_frame(frame, &self.data_buf)? {
328 let packet = Packet::new(reply, data.len());
329 if data.len() > self.zero_copy_after {
330 packet.write_to(&mut self.stream)?;
331 self.stream.write_all(&data)?;
332 self.stream.flush()?;
333 } else {
334 self.data_buf.reserve(packet.size_full());
335 self.data_buf.clear();
336 packet.write_to(&mut Cursor::new(&mut self.data_buf))?;
337 self.data_buf.extend(data);
338 self.stream.write_all(&self.data_buf)?;
339 if self.always_flush {
340 self.stream.flush()?;
341 }
342 }
343 }
344 Ok(())
345 }
346}