1use crate::udp::UdpConfig;
2use socket2::{Domain, Protocol, Socket, Type as SockType};
3use std::io;
4use std::net::IpAddr;
5use std::net::{SocketAddr, UdpSocket as StdUdpSocket};
6
/// UDP socket wrapper around a [`socket2::Socket`].
///
/// The constructors in this file (`from_config`, `new` and the `*_dgram` /
/// `raw_*` shortcuts) put the socket in blocking mode; `from_socket` wraps an
/// existing socket as-is. The inner socket is reachable through `socket()` and
/// `into_socket()`.
#[derive(Debug)]
pub struct UdpSocket {
    // Underlying OS socket; every option getter/setter on this type delegates to it.
    socket: Socket,
}
12
/// Metadata describing one datagram received by `UdpSocket::recv_msg`.
///
/// All fields are plain values, so the type is `Copy` and `Hash` in addition
/// to the comparison traits, which lets callers store it in sets/maps or pass
/// it around without cloning.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct UdpRecvMeta {
    /// Number of bytes reported as received by `recvmsg`.
    pub bytes_read: usize,
    /// Address of the peer that sent the datagram.
    pub source_addr: SocketAddr,
    /// Destination IP taken from an IPv4/IPv6 packet-info control message;
    /// `None` when no such message was delivered (presumably because packet
    /// info reception was not enabled on the socket — confirm with caller).
    pub destination_addr: Option<IpAddr>,
    /// Index of the interface reported by the same packet-info control
    /// message; `None` under the same conditions as `destination_addr`.
    pub interface_index: Option<u32>,
}
25
/// Optional packet-info metadata attached by `UdpSocket::send_msg`.
///
/// The default value (both fields `None`) makes `send_msg` send without any
/// ancillary data. The type is plain data, so it is `Copy` and `Hash` as well.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct UdpSendMeta {
    /// Source IP to place in the packet-info control message; must be the
    /// same address family as the send target.
    pub source_addr: Option<IpAddr>,
    /// Outgoing interface index to place in the packet-info control message.
    pub interface_index: Option<u32>,
}
34
35impl UdpSocket {
36 pub fn from_config(config: &UdpConfig) -> io::Result<Self> {
38 config.validate()?;
39
40 let socket = Socket::new(
41 config.socket_family.to_domain(),
42 config.socket_type.to_sock_type(),
43 Some(Protocol::UDP),
44 )?;
45
46 socket.set_nonblocking(false)?;
47
48 if let Some(flag) = config.reuseaddr {
50 socket.set_reuse_address(flag)?;
51 }
52 #[cfg(any(
53 target_os = "android",
54 target_os = "dragonfly",
55 target_os = "freebsd",
56 target_os = "fuchsia",
57 target_os = "ios",
58 target_os = "linux",
59 target_os = "macos",
60 target_os = "netbsd",
61 target_os = "openbsd",
62 target_os = "tvos",
63 target_os = "visionos",
64 target_os = "watchos"
65 ))]
66 if let Some(flag) = config.reuseport {
67 socket.set_reuse_port(flag)?;
68 }
69 if let Some(flag) = config.broadcast {
70 socket.set_broadcast(flag)?;
71 }
72 if let Some(ttl) = config.ttl {
73 socket.set_ttl(ttl)?;
74 }
75 if let Some(hoplimit) = config.hoplimit {
76 socket.set_unicast_hops_v6(hoplimit)?;
77 }
78 if let Some(timeout) = config.read_timeout {
79 socket.set_read_timeout(Some(timeout))?;
80 }
81 if let Some(timeout) = config.write_timeout {
82 socket.set_write_timeout(Some(timeout))?;
83 }
84 if let Some(size) = config.recv_buffer_size {
85 socket.set_recv_buffer_size(size)?;
86 }
87 if let Some(size) = config.send_buffer_size {
88 socket.set_send_buffer_size(size)?;
89 }
90 if let Some(tos) = config.tos {
91 socket.set_tos(tos)?;
92 }
93 #[cfg(any(
94 target_os = "android",
95 target_os = "dragonfly",
96 target_os = "freebsd",
97 target_os = "fuchsia",
98 target_os = "ios",
99 target_os = "linux",
100 target_os = "macos",
101 target_os = "netbsd",
102 target_os = "openbsd",
103 target_os = "tvos",
104 target_os = "visionos",
105 target_os = "watchos"
106 ))]
107 if let Some(tclass) = config.tclass_v6 {
108 socket.set_tclass_v6(tclass)?;
109 }
110 if let Some(only_v6) = config.only_v6 {
111 socket.set_only_v6(only_v6)?;
112 }
113 if let Some(on) = config.recv_pktinfo {
114 crate::udp::set_recv_pktinfo(&socket, config.socket_family, on)?;
115 }
116
117 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
119 if let Some(iface) = &config.bind_device {
120 socket.bind_device(Some(iface.as_bytes()))?;
121 }
122
123 if let Some(addr) = config.bind_addr {
125 socket.bind(&addr.into())?;
126 }
127
128 Ok(Self { socket })
129 }
130
131 pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
133 let socket = Socket::new(domain, sock_type, Some(Protocol::UDP))?;
134 socket.set_nonblocking(false)?;
135 Ok(Self { socket })
136 }
137
138 pub fn v4_dgram() -> io::Result<Self> {
140 Self::new(Domain::IPV4, SockType::DGRAM)
141 }
142
143 pub fn v6_dgram() -> io::Result<Self> {
145 Self::new(Domain::IPV6, SockType::DGRAM)
146 }
147
148 pub fn raw_v4() -> io::Result<Self> {
150 Self::new(Domain::IPV4, SockType::RAW)
151 }
152
153 pub fn raw_v6() -> io::Result<Self> {
155 Self::new(Domain::IPV6, SockType::RAW)
156 }
157
158 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
160 self.socket.send_to(buf, &target.into())
161 }
162
163 #[cfg(unix)]
168 pub fn send_msg(
169 &self,
170 buf: &[u8],
171 target: SocketAddr,
172 meta: Option<&UdpSendMeta>,
173 ) -> io::Result<usize> {
174 use nix::sys::socket::{ControlMessage, MsgFlags, SockaddrIn, SockaddrIn6, sendmsg};
175 use std::io::IoSlice;
176 use std::os::fd::AsRawFd;
177
178 let iov = [IoSlice::new(buf)];
179 let raw_fd = self.socket.as_raw_fd();
180
181 match target {
182 SocketAddr::V4(addr) => {
183 let sockaddr = SockaddrIn::from(addr);
184 #[cfg(any(
185 target_os = "android",
186 target_os = "linux",
187 target_os = "netbsd",
188 target_vendor = "apple"
189 ))]
190 {
191 if let Some(meta) = meta {
192 if meta.source_addr.is_some() || meta.interface_index.is_some() {
193 if let Some(src) = meta.source_addr {
194 if !src.is_ipv4() {
195 return Err(io::Error::new(
196 io::ErrorKind::InvalidInput,
197 "source_addr family does not match target",
198 ));
199 }
200 }
201 let mut pktinfo: libc::in_pktinfo = unsafe { std::mem::zeroed() };
202 if let Some(src) = meta.source_addr.and_then(|ip| match ip {
203 IpAddr::V4(v4) => Some(v4),
204 IpAddr::V6(_) => None,
205 }) {
206 pktinfo.ipi_spec_dst.s_addr = u32::from_ne_bytes(src.octets());
207 }
208 if let Some(ifindex) = meta.interface_index {
209 pktinfo.ipi_ifindex = ifindex.try_into().map_err(|_| {
210 io::Error::new(
211 io::ErrorKind::InvalidInput,
212 "interface_index is out of range for this platform",
213 )
214 })?;
215 }
216 let cmsgs = [ControlMessage::Ipv4PacketInfo(&pktinfo)];
217 return sendmsg(
218 raw_fd,
219 &iov,
220 &cmsgs,
221 MsgFlags::empty(),
222 Some(&sockaddr),
223 )
224 .map_err(|e| io::Error::from_raw_os_error(e as i32));
225 }
226 }
227 }
228 if let Some(meta) = meta {
229 if meta.source_addr.is_some() || meta.interface_index.is_some() {
230 return Err(io::Error::new(
231 io::ErrorKind::Unsupported,
232 "send_msg packet-info metadata is not supported on this platform",
233 ));
234 }
235 }
236 sendmsg(raw_fd, &iov, &[], MsgFlags::empty(), Some(&sockaddr))
237 .map_err(|e| io::Error::from_raw_os_error(e as i32))
238 }
239 SocketAddr::V6(addr) => {
240 let sockaddr = SockaddrIn6::from(addr);
241 #[cfg(any(
242 target_os = "android",
243 target_os = "freebsd",
244 target_os = "linux",
245 target_os = "netbsd",
246 target_vendor = "apple"
247 ))]
248 {
249 if let Some(meta) = meta {
250 if meta.source_addr.is_some() || meta.interface_index.is_some() {
251 if let Some(src) = meta.source_addr {
252 if !src.is_ipv6() {
253 return Err(io::Error::new(
254 io::ErrorKind::InvalidInput,
255 "source_addr family does not match target",
256 ));
257 }
258 }
259 let mut pktinfo: libc::in6_pktinfo = unsafe { std::mem::zeroed() };
260 if let Some(src) = meta.source_addr.and_then(|ip| match ip {
261 IpAddr::V4(_) => None,
262 IpAddr::V6(v6) => Some(v6),
263 }) {
264 pktinfo.ipi6_addr.s6_addr = src.octets();
265 }
266 if let Some(ifindex) = meta.interface_index {
267 pktinfo.ipi6_ifindex = ifindex.try_into().map_err(|_| {
268 io::Error::new(
269 io::ErrorKind::InvalidInput,
270 "interface_index is out of range for this platform",
271 )
272 })?;
273 }
274 let cmsgs = [ControlMessage::Ipv6PacketInfo(&pktinfo)];
275 return sendmsg(
276 raw_fd,
277 &iov,
278 &cmsgs,
279 MsgFlags::empty(),
280 Some(&sockaddr),
281 )
282 .map_err(|e| io::Error::from_raw_os_error(e as i32));
283 }
284 }
285 }
286 if let Some(meta) = meta {
287 if meta.source_addr.is_some() || meta.interface_index.is_some() {
288 return Err(io::Error::new(
289 io::ErrorKind::Unsupported,
290 "send_msg packet-info metadata is not supported on this platform",
291 ));
292 }
293 }
294 sendmsg(raw_fd, &iov, &[], MsgFlags::empty(), Some(&sockaddr))
295 .map_err(|e| io::Error::from_raw_os_error(e as i32))
296 }
297 }
298 }
299
300 #[cfg(not(unix))]
302 pub fn send_msg(
303 &self,
304 _buf: &[u8],
305 _target: SocketAddr,
306 _meta: Option<&UdpSendMeta>,
307 ) -> io::Result<usize> {
308 Err(io::Error::new(
309 io::ErrorKind::Unsupported,
310 "send_msg is only supported on Unix",
311 ))
312 }
313
314 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
316 let buf_maybe = unsafe {
318 std::slice::from_raw_parts_mut(
319 buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
320 buf.len(),
321 )
322 };
323
324 let (n, addr) = self.socket.recv_from(buf_maybe)?;
325 let addr = addr
326 .as_socket()
327 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
328
329 Ok((n, addr))
330 }
331
332 #[cfg(unix)]
338 pub fn recv_msg(&self, buf: &mut [u8]) -> io::Result<UdpRecvMeta> {
339 use nix::sys::socket::{ControlMessageOwned, MsgFlags, SockaddrStorage, recvmsg};
340 use std::io::IoSliceMut;
341 use std::os::fd::AsRawFd;
342
343 let mut iov = [IoSliceMut::new(buf)];
344 #[cfg(any(
345 target_os = "android",
346 target_os = "fuchsia",
347 target_os = "linux",
348 target_vendor = "apple",
349 target_os = "netbsd"
350 ))]
351 let mut cmsgspace = nix::cmsg_space!(libc::in_pktinfo, libc::in6_pktinfo);
352 #[cfg(all(
353 not(any(
354 target_os = "android",
355 target_os = "fuchsia",
356 target_os = "linux",
357 target_vendor = "apple",
358 target_os = "netbsd"
359 )),
360 any(target_os = "freebsd", target_os = "openbsd")
361 ))]
362 let mut cmsgspace = nix::cmsg_space!(libc::in6_pktinfo);
363 #[cfg(all(
364 not(any(
365 target_os = "android",
366 target_os = "fuchsia",
367 target_os = "linux",
368 target_vendor = "apple",
369 target_os = "netbsd"
370 )),
371 not(any(target_os = "freebsd", target_os = "openbsd"))
372 ))]
373 let mut cmsgspace = nix::cmsg_space!(libc::c_int);
374 let msg = recvmsg::<SockaddrStorage>(
375 self.socket.as_raw_fd(),
376 &mut iov,
377 Some(&mut cmsgspace),
378 MsgFlags::empty(),
379 )
380 .map_err(|e| io::Error::from_raw_os_error(e as i32))?;
381
382 let source_addr = msg
383 .address
384 .and_then(|addr: SockaddrStorage| {
385 if let Some(v4) = addr.as_sockaddr_in() {
386 return Some(SocketAddr::from(*v4));
387 }
388 if let Some(v6) = addr.as_sockaddr_in6() {
389 return Some(SocketAddr::from(*v6));
390 }
391 None
392 })
393 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid source address"))?;
394
395 let mut destination_addr = None;
396 let mut interface_index = None;
397
398 if let Ok(cmsgs) = msg.cmsgs() {
399 for cmsg in cmsgs {
400 match cmsg {
401 #[cfg(any(
402 target_os = "android",
403 target_os = "fuchsia",
404 target_os = "linux",
405 target_vendor = "apple",
406 target_os = "netbsd"
407 ))]
408 ControlMessageOwned::Ipv4PacketInfo(info) => {
409 destination_addr = Some(IpAddr::V4(std::net::Ipv4Addr::from(
410 info.ipi_addr.s_addr.to_ne_bytes(),
411 )));
412 interface_index = Some(info.ipi_ifindex.try_into().map_err(|_| {
413 io::Error::new(
414 io::ErrorKind::InvalidData,
415 "received invalid interface index",
416 )
417 })?);
418 }
419 #[cfg(any(
420 target_os = "android",
421 target_os = "freebsd",
422 target_os = "linux",
423 target_os = "macos",
424 target_os = "ios",
425 target_os = "tvos",
426 target_os = "visionos",
427 target_os = "watchos",
428 target_os = "netbsd",
429 target_os = "openbsd"
430 ))]
431 ControlMessageOwned::Ipv6PacketInfo(info) => {
432 destination_addr =
433 Some(IpAddr::V6(std::net::Ipv6Addr::from(info.ipi6_addr.s6_addr)));
434 interface_index = Some(info.ipi6_ifindex.try_into().map_err(|_| {
435 io::Error::new(
436 io::ErrorKind::InvalidData,
437 "received invalid interface index",
438 )
439 })?);
440 }
441 _ => {}
442 }
443 }
444 }
445
446 Ok(UdpRecvMeta {
447 bytes_read: msg.bytes,
448 source_addr,
449 destination_addr,
450 interface_index,
451 })
452 }
453
454 #[cfg(not(unix))]
456 pub fn recv_msg(&self, _buf: &mut [u8]) -> io::Result<UdpRecvMeta> {
457 Err(io::Error::new(
458 io::ErrorKind::Unsupported,
459 "recv_msg is only supported on Unix",
460 ))
461 }
462
463 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
464 self.socket.set_ttl(ttl)
465 }
466
467 pub fn ttl(&self) -> io::Result<u32> {
468 self.socket.ttl()
469 }
470
471 pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
472 self.socket.set_unicast_hops_v6(hops)
473 }
474
475 pub fn hoplimit(&self) -> io::Result<u32> {
476 self.socket.unicast_hops_v6()
477 }
478
479 pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
480 self.socket.set_reuse_address(on)
481 }
482
483 pub fn reuseaddr(&self) -> io::Result<bool> {
484 self.socket.reuse_address()
485 }
486
487 #[cfg(any(
488 target_os = "android",
489 target_os = "dragonfly",
490 target_os = "freebsd",
491 target_os = "fuchsia",
492 target_os = "ios",
493 target_os = "linux",
494 target_os = "macos",
495 target_os = "netbsd",
496 target_os = "openbsd",
497 target_os = "tvos",
498 target_os = "visionos",
499 target_os = "watchos"
500 ))]
501 pub fn set_reuseport(&self, on: bool) -> io::Result<()> {
502 self.socket.set_reuse_port(on)
503 }
504
505 #[cfg(any(
506 target_os = "android",
507 target_os = "dragonfly",
508 target_os = "freebsd",
509 target_os = "fuchsia",
510 target_os = "ios",
511 target_os = "linux",
512 target_os = "macos",
513 target_os = "netbsd",
514 target_os = "openbsd",
515 target_os = "tvos",
516 target_os = "visionos",
517 target_os = "watchos"
518 ))]
519 pub fn reuseport(&self) -> io::Result<bool> {
520 self.socket.reuse_port()
521 }
522
523 pub fn set_broadcast(&self, on: bool) -> io::Result<()> {
524 self.socket.set_broadcast(on)
525 }
526
527 pub fn broadcast(&self) -> io::Result<bool> {
528 self.socket.broadcast()
529 }
530
531 pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
532 self.socket.set_recv_buffer_size(size)
533 }
534
535 pub fn recv_buffer_size(&self) -> io::Result<usize> {
536 self.socket.recv_buffer_size()
537 }
538
539 pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
540 self.socket.set_send_buffer_size(size)
541 }
542
543 pub fn send_buffer_size(&self) -> io::Result<usize> {
544 self.socket.send_buffer_size()
545 }
546
547 pub fn set_tos(&self, tos: u32) -> io::Result<()> {
548 self.socket.set_tos(tos)
549 }
550
551 pub fn tos(&self) -> io::Result<u32> {
552 self.socket.tos()
553 }
554
555 #[cfg(any(
556 target_os = "android",
557 target_os = "dragonfly",
558 target_os = "freebsd",
559 target_os = "fuchsia",
560 target_os = "ios",
561 target_os = "linux",
562 target_os = "macos",
563 target_os = "netbsd",
564 target_os = "openbsd",
565 target_os = "tvos",
566 target_os = "visionos",
567 target_os = "watchos"
568 ))]
569 pub fn set_tclass_v6(&self, tclass: u32) -> io::Result<()> {
570 self.socket.set_tclass_v6(tclass)
571 }
572
573 #[cfg(any(
574 target_os = "android",
575 target_os = "dragonfly",
576 target_os = "freebsd",
577 target_os = "fuchsia",
578 target_os = "ios",
579 target_os = "linux",
580 target_os = "macos",
581 target_os = "netbsd",
582 target_os = "openbsd",
583 target_os = "tvos",
584 target_os = "visionos",
585 target_os = "watchos"
586 ))]
587 pub fn tclass_v6(&self) -> io::Result<u32> {
588 self.socket.tclass_v6()
589 }
590
591 pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
592 self.socket.set_only_v6(only_v6)
593 }
594
595 pub fn only_v6(&self) -> io::Result<bool> {
596 self.socket.only_v6()
597 }
598
599 pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
600 self.socket.set_keepalive(on)
601 }
602
603 pub fn keepalive(&self) -> io::Result<bool> {
604 self.socket.keepalive()
605 }
606
607 pub fn set_recv_pktinfo_v4(&self, on: bool) -> io::Result<()> {
609 crate::udp::set_recv_pktinfo_v4(&self.socket, on)
610 }
611
612 pub fn set_recv_pktinfo_v6(&self, on: bool) -> io::Result<()> {
614 crate::udp::set_recv_pktinfo_v6(&self.socket, on)
615 }
616
617 pub fn recv_pktinfo_v4(&self) -> io::Result<bool> {
619 crate::udp::recv_pktinfo_v4(&self.socket)
620 }
621
622 pub fn recv_pktinfo_v6(&self) -> io::Result<bool> {
624 crate::udp::recv_pktinfo_v6(&self.socket)
625 }
626
627 pub fn local_addr(&self) -> io::Result<SocketAddr> {
629 self.socket
630 .local_addr()?
631 .as_socket()
632 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "failed to retrieve local address"))
633 }
634
635 pub fn to_std(self) -> io::Result<StdUdpSocket> {
637 Ok(self.socket.into())
638 }
639
640 pub fn from_socket(socket: Socket) -> Self {
642 Self { socket }
643 }
644
645 pub fn socket(&self) -> &Socket {
647 &self.socket
648 }
649
650 pub fn into_socket(self) -> Socket {
652 self.socket
653 }
654
655 #[cfg(unix)]
656 pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
657 use std::os::fd::AsRawFd;
658 self.socket.as_raw_fd()
659 }
660
661 #[cfg(windows)]
662 pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
663 use std::os::windows::io::AsRawSocket;
664 self.socket.as_raw_socket()
665 }
666}
667
#[cfg(test)]
mod tests {
    use super::*;

    /// Binding a fresh IPv4 datagram socket to the wildcard address with an
    /// ephemeral port must report an IPv4 local address.
    #[test]
    fn create_v4_socket() {
        let wildcard: SocketAddr = "0.0.0.0:0".parse().unwrap();
        let udp = UdpSocket::v4_dgram().expect("create socket");
        udp.socket().bind(&wildcard.into()).expect("bind");

        let bound = udp.local_addr().expect("addr");
        assert!(matches!(bound, SocketAddr::V4(_)));
    }
}