use std::io::{self};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket};
use std::thread;
9
10use rns_core::transport::types::InterfaceId;
11
12use crate::event::{Event, EventSender};
13use crate::interface::Writer;
14
/// Configuration for a UDP interface.
///
/// The receiving side is enabled only when both `listen_ip` and `listen_port`
/// are set; the sending side only when both `forward_ip` and `forward_port`
/// are set. Either side, both, or neither may be configured.
#[derive(Debug, Clone)]
pub struct UdpConfig {
    /// Human-readable interface name, used as the prefix of log messages.
    pub name: String,
    /// Local address the receiving socket is bound to.
    pub listen_ip: Option<String>,
    /// Local port the receiving socket is bound to.
    pub listen_port: Option<u16>,
    /// Destination address for outgoing frames; may be a broadcast address.
    pub forward_ip: Option<String>,
    /// Destination port for outgoing frames.
    pub forward_port: Option<u16>,
    /// Identifier attached to every event this interface emits
    /// (`InterfaceUp`, `Frame`, `InterfaceDown`).
    pub interface_id: InterfaceId,
}
25
26impl Default for UdpConfig {
27 fn default() -> Self {
28 UdpConfig {
29 name: String::new(),
30 listen_ip: None,
31 listen_port: None,
32 forward_ip: None,
33 forward_port: None,
34 interface_id: InterfaceId(0),
35 }
36 }
37}
38
/// Sending half of a UDP interface: writes each frame as a single datagram
/// to one fixed destination.
struct UdpWriter {
    /// Socket used for sending. Frames go out via `send_to`, so the socket
    /// does not need to be `connect`ed to `target`.
    socket: UdpSocket,
    /// Destination address for every frame.
    target: SocketAddr,
}
44
45impl Writer for UdpWriter {
46 fn send_frame(&mut self, data: &[u8]) -> io::Result<()> {
47 self.socket.send_to(data, self.target)?;
48 Ok(())
49 }
50}
51
52pub fn start(config: UdpConfig, tx: EventSender) -> io::Result<Option<Box<dyn Writer>>> {
55 let id = config.interface_id;
56 let mut writer: Option<Box<dyn Writer>> = None;
57
58 if let (Some(ref fwd_ip), Some(fwd_port)) = (&config.forward_ip, config.forward_port) {
60 let target: SocketAddr = format!("{}:{}", fwd_ip, fwd_port)
61 .parse()
62 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
63
64 let send_socket = UdpSocket::bind("0.0.0.0:0")?;
65 send_socket.set_broadcast(true)?;
66
67 writer = Some(Box::new(UdpWriter {
68 socket: send_socket,
69 target,
70 }));
71 }
72
73 if let (Some(ref bind_ip), Some(bind_port)) = (&config.listen_ip, config.listen_port) {
75 let bind_addr = format!("{}:{}", bind_ip, bind_port);
76 let recv_socket = UdpSocket::bind(&bind_addr)?;
77
78 log::info!("[{}] UDP listening on {}", config.name, bind_addr);
79
80 let _ = tx.send(Event::InterfaceUp(id, None, None));
82
83 let name = config.name.clone();
84 thread::Builder::new()
85 .name(format!("udp-reader-{}", id.0))
86 .spawn(move || {
87 udp_reader_loop(recv_socket, id, name, tx);
88 })?;
89 }
90
91 Ok(writer)
92}
93
94fn udp_reader_loop(socket: UdpSocket, id: InterfaceId, name: String, tx: EventSender) {
96 let mut buf = [0u8; 2048];
97
98 loop {
99 match socket.recv_from(&mut buf) {
100 Ok((n, _src)) => {
101 if tx
102 .send(Event::Frame {
103 interface_id: id,
104 data: buf[..n].to_vec(),
105 })
106 .is_err()
107 {
108 return;
110 }
111 }
112 Err(e) => {
113 log::warn!("[{}] recv error: {}", name, e);
114 let _ = tx.send(Event::InterfaceDown(id));
115 return;
116 }
117 }
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::UdpSocket;
    use std::sync::mpsc;
    use std::time::Duration;

    /// How long to wait for an event or datagram that should arrive promptly
    /// over loopback.
    const TIMEOUT: Duration = Duration::from_secs(2);

    /// Returns a UDP port on 127.0.0.1 that was free at the time of the call.
    ///
    /// Probes with a `UdpSocket`: TCP and UDP have separate port spaces, so a
    /// free TCP port says nothing about UDP availability. The probe socket is
    /// closed before returning, so use this only when the code under test must
    /// do the binding itself; otherwise prefer [`loopback_receiver`].
    fn find_free_port() -> u16 {
        UdpSocket::bind("127.0.0.1:0")
            .unwrap()
            .local_addr()
            .unwrap()
            .port()
    }

    /// Binds a loopback receiver on an OS-assigned port, which avoids any
    /// probe-then-rebind race, and returns it with its port.
    fn loopback_receiver() -> (UdpSocket, u16) {
        let socket = UdpSocket::bind("127.0.0.1:0").unwrap();
        socket.set_read_timeout(Some(TIMEOUT)).unwrap();
        let port = socket.local_addr().unwrap().port();
        (socket, port)
    }

    /// Reads one datagram from `socket` and returns its payload.
    fn recv_datagram(socket: &UdpSocket) -> Vec<u8> {
        let mut buf = [0u8; 256];
        let (n, _) = socket.recv_from(&mut buf).unwrap();
        buf[..n].to_vec()
    }

    /// Asserts that the next event on `rx` is `InterfaceUp` for `expected`.
    fn expect_interface_up(rx: &mpsc::Receiver<Event>, expected: InterfaceId) {
        match rx.recv_timeout(TIMEOUT).unwrap() {
            Event::InterfaceUp(id, ..) => assert_eq!(id, expected),
            other => panic!("expected InterfaceUp, got {:?}", other),
        }
    }

    #[test]
    fn bind_and_receive() {
        let port = find_free_port();
        let (tx, rx) = mpsc::channel();

        let config = UdpConfig {
            name: "test-udp".into(),
            listen_ip: Some("127.0.0.1".into()),
            listen_port: Some(port),
            forward_ip: None,
            forward_port: None,
            interface_id: InterfaceId(10),
        };

        let writer = start(config, tx).unwrap();
        assert!(writer.is_none(), "no forward address means no writer");
        expect_interface_up(&rx, InterfaceId(10));

        let sender = UdpSocket::bind("127.0.0.1:0").unwrap();
        let payload = b"hello udp";
        sender.send_to(payload, ("127.0.0.1", port)).unwrap();

        match rx.recv_timeout(TIMEOUT).unwrap() {
            Event::Frame { interface_id, data } => {
                assert_eq!(interface_id, InterfaceId(10));
                assert_eq!(data, payload);
            }
            other => panic!("expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn send_broadcast() {
        let (receiver, recv_port) = loopback_receiver();
        let (tx, _rx) = mpsc::channel();

        let config = UdpConfig {
            name: "test-udp-send".into(),
            listen_ip: None,
            listen_port: None,
            forward_ip: Some("127.0.0.1".into()),
            forward_port: Some(recv_port),
            interface_id: InterfaceId(11),
        };

        let mut writer = start(config, tx)
            .unwrap()
            .expect("forward config should produce a writer");

        let payload = b"broadcast data";
        writer.send_frame(payload).unwrap();

        assert_eq!(recv_datagram(&receiver), payload);
    }

    #[test]
    fn round_trip() {
        // Bind the forward receiver first so the probed listen port cannot
        // collide with it.
        let (fwd_receiver, forward_port) = loopback_receiver();
        let listen_port = find_free_port();
        let (tx, rx) = mpsc::channel();

        let config = UdpConfig {
            name: "test-udp-rt".into(),
            listen_ip: Some("127.0.0.1".into()),
            listen_port: Some(listen_port),
            forward_ip: Some("127.0.0.1".into()),
            forward_port: Some(forward_port),
            interface_id: InterfaceId(12),
        };

        let mut writer = start(config, tx)
            .unwrap()
            .expect("forward config should produce a writer");
        expect_interface_up(&rx, InterfaceId(12));

        // Inbound: a datagram to the listen port becomes a Frame event.
        let sender = UdpSocket::bind("127.0.0.1:0").unwrap();
        sender.send_to(b"ping", ("127.0.0.1", listen_port)).unwrap();

        match rx.recv_timeout(TIMEOUT).unwrap() {
            Event::Frame { data, .. } => assert_eq!(data, b"ping"),
            other => panic!("expected Frame, got {:?}", other),
        }

        // Outbound: a frame written through the writer reaches the forward port.
        writer.send_frame(b"pong").unwrap();
        assert_eq!(recv_datagram(&fwd_receiver), b"pong");
    }

    #[test]
    fn multiple_datagrams() {
        let port = find_free_port();
        let (tx, rx) = mpsc::channel();

        let config = UdpConfig {
            name: "test-udp-multi".into(),
            listen_ip: Some("127.0.0.1".into()),
            listen_port: Some(port),
            forward_ip: None,
            forward_port: None,
            interface_id: InterfaceId(13),
        };

        let _writer = start(config, tx).unwrap();
        expect_interface_up(&rx, InterfaceId(13));

        let sender = UdpSocket::bind("127.0.0.1:0").unwrap();
        for i in 0..5u8 {
            sender.send_to(&[i], ("127.0.0.1", port)).unwrap();
        }

        for i in 0..5u8 {
            match rx.recv_timeout(TIMEOUT).unwrap() {
                Event::Frame { data, .. } => assert_eq!(data, vec![i]),
                other => panic!("expected Frame, got {:?}", other),
            }
        }
    }

    #[test]
    fn writer_send_to() {
        let (receiver, recv_port) = loopback_receiver();

        let send_socket = UdpSocket::bind("127.0.0.1:0").unwrap();
        send_socket.set_broadcast(true).unwrap();
        let target = SocketAddr::from(([127, 0, 0, 1], recv_port));
        let mut writer = UdpWriter {
            socket: send_socket,
            target,
        };

        let payload = [0xAA, 0xBB, 0xCC];
        writer.send_frame(&payload).unwrap();

        assert_eq!(recv_datagram(&receiver), payload);
    }
}