1use std::{io, net::SocketAddr, time::Duration};
2
3use socket2::Type as SockType;
4
5use crate::SocketFamily;
6
/// Kind of socket to open for UDP traffic.
///
/// `Dgram` corresponds to a datagram socket and `Raw` to a raw socket.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum UdpSocketType {
    /// Ordinary datagram socket.
    Dgram,
    /// Raw socket.
    Raw,
}
13
14impl UdpSocketType {
15 pub fn is_dgram(&self) -> bool {
17 matches!(self, UdpSocketType::Dgram)
18 }
19
20 pub fn is_raw(&self) -> bool {
22 matches!(self, UdpSocketType::Raw)
23 }
24
25 pub(crate) fn to_sock_type(&self) -> SockType {
27 match self {
28 UdpSocketType::Dgram => SockType::DGRAM,
29 UdpSocketType::Raw => SockType::RAW,
30 }
31 }
32}
33
34#[derive(Debug, Clone)]
36pub struct UdpConfig {
37 pub socket_family: SocketFamily,
39 pub socket_type: UdpSocketType,
41 pub bind_addr: Option<SocketAddr>,
43 pub reuseaddr: Option<bool>,
45 pub reuseport: Option<bool>,
47 pub broadcast: Option<bool>,
49 pub ttl: Option<u32>,
51 pub hoplimit: Option<u32>,
53 pub read_timeout: Option<Duration>,
55 pub write_timeout: Option<Duration>,
57 pub recv_buffer_size: Option<usize>,
59 pub send_buffer_size: Option<usize>,
61 pub tos: Option<u32>,
63 pub tclass_v6: Option<u32>,
65 pub recv_pktinfo: Option<bool>,
67 pub only_v6: Option<bool>,
69 pub bind_device: Option<String>,
71}
72
73impl Default for UdpConfig {
74 fn default() -> Self {
75 Self {
76 socket_family: SocketFamily::IPV4,
77 socket_type: UdpSocketType::Dgram,
78 bind_addr: None,
79 reuseaddr: None,
80 reuseport: None,
81 broadcast: None,
82 ttl: None,
83 hoplimit: None,
84 read_timeout: None,
85 write_timeout: None,
86 recv_buffer_size: None,
87 send_buffer_size: None,
88 tos: None,
89 tclass_v6: None,
90 recv_pktinfo: None,
91 only_v6: None,
92 bind_device: None,
93 }
94 }
95}
96
97impl UdpConfig {
98 pub fn new() -> Self {
100 Self::default()
101 }
102
103 pub fn new_with_family(socket_family: SocketFamily) -> Self {
105 Self {
106 socket_family,
107 ..Self::default()
108 }
109 }
110
111 pub fn with_socket_family(mut self, socket_family: SocketFamily) -> Self {
113 self.socket_family = socket_family;
114 self
115 }
116
117 pub fn with_bind_addr(mut self, addr: SocketAddr) -> Self {
119 self.bind_addr = Some(addr);
120 self
121 }
122
123 pub fn with_bind(self, addr: SocketAddr) -> Self {
125 self.with_bind_addr(addr)
126 }
127
128 pub fn with_reuseaddr(mut self, on: bool) -> Self {
130 self.reuseaddr = Some(on);
131 self
132 }
133
134 pub fn with_reuseport(mut self, on: bool) -> Self {
136 self.reuseport = Some(on);
137 self
138 }
139
140 pub fn with_broadcast(mut self, on: bool) -> Self {
142 self.broadcast = Some(on);
143 self
144 }
145
146 pub fn with_ttl(mut self, ttl: u32) -> Self {
148 self.ttl = Some(ttl);
149 self
150 }
151
152 pub fn with_hoplimit(mut self, hops: u32) -> Self {
154 self.hoplimit = Some(hops);
155 self
156 }
157
158 pub fn with_hop_limit(self, hops: u32) -> Self {
160 self.with_hoplimit(hops)
161 }
162
163 pub fn with_read_timeout(mut self, timeout: Duration) -> Self {
165 self.read_timeout = Some(timeout);
166 self
167 }
168
169 pub fn with_write_timeout(mut self, timeout: Duration) -> Self {
171 self.write_timeout = Some(timeout);
172 self
173 }
174
175 pub fn with_recv_buffer_size(mut self, size: usize) -> Self {
177 self.recv_buffer_size = Some(size);
178 self
179 }
180
181 pub fn with_send_buffer_size(mut self, size: usize) -> Self {
183 self.send_buffer_size = Some(size);
184 self
185 }
186
187 pub fn with_tos(mut self, tos: u32) -> Self {
189 self.tos = Some(tos);
190 self
191 }
192
193 pub fn with_tclass_v6(mut self, tclass: u32) -> Self {
195 self.tclass_v6 = Some(tclass);
196 self
197 }
198
199 pub fn with_recv_pktinfo(mut self, on: bool) -> Self {
201 self.recv_pktinfo = Some(on);
202 self
203 }
204
205 pub fn with_only_v6(mut self, only_v6: bool) -> Self {
207 self.only_v6 = Some(only_v6);
208 self
209 }
210
211 pub fn with_bind_device(mut self, iface: impl Into<String>) -> Self {
213 self.bind_device = Some(iface.into());
214 self
215 }
216
217 pub fn validate(&self) -> io::Result<()> {
219 if let Some(addr) = self.bind_addr {
220 let addr_family = crate::SocketFamily::from_socket_addr(&addr);
221 if addr_family != self.socket_family {
222 return Err(io::Error::new(
223 io::ErrorKind::InvalidInput,
224 "bind_addr family does not match socket_family",
225 ));
226 }
227 }
228
229 if self.socket_family.is_v4() {
230 if self.hoplimit.is_some() {
231 return Err(io::Error::new(
232 io::ErrorKind::InvalidInput,
233 "hoplimit is only supported for IPv6 UDP sockets",
234 ));
235 }
236 if self.tclass_v6.is_some() {
237 return Err(io::Error::new(
238 io::ErrorKind::InvalidInput,
239 "tclass_v6 is only supported for IPv6 UDP sockets",
240 ));
241 }
242 if self.only_v6.is_some() {
243 return Err(io::Error::new(
244 io::ErrorKind::InvalidInput,
245 "only_v6 is only supported for IPv6 UDP sockets",
246 ));
247 }
248 }
249
250 if self.socket_family.is_v6() {
251 if self.ttl.is_some() {
252 return Err(io::Error::new(
253 io::ErrorKind::InvalidInput,
254 "ttl is only supported for IPv4 UDP sockets",
255 ));
256 }
257 if self.broadcast.is_some() {
258 return Err(io::Error::new(
259 io::ErrorKind::InvalidInput,
260 "broadcast is only supported for IPv4 UDP sockets",
261 ));
262 }
263 }
264
265 if matches!(self.read_timeout, Some(timeout) if timeout.is_zero()) {
266 return Err(io::Error::new(
267 io::ErrorKind::InvalidInput,
268 "read_timeout must be greater than zero",
269 ));
270 }
271
272 if matches!(self.write_timeout, Some(timeout) if timeout.is_zero()) {
273 return Err(io::Error::new(
274 io::ErrorKind::InvalidInput,
275 "write_timeout must be greater than zero",
276 ));
277 }
278
279 if matches!(self.recv_buffer_size, Some(0)) {
280 return Err(io::Error::new(
281 io::ErrorKind::InvalidInput,
282 "recv_buffer_size must be greater than zero",
283 ));
284 }
285
286 if matches!(self.send_buffer_size, Some(0)) {
287 return Err(io::Error::new(
288 io::ErrorKind::InvalidInput,
289 "send_buffer_size must be greater than zero",
290 ));
291 }
292
293 if matches!(self.bind_device.as_deref(), Some("")) {
294 return Err(io::Error::new(
295 io::ErrorKind::InvalidInput,
296 "bind_device must not be empty",
297 ));
298 }
299
300 Ok(())
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    // Every field of the default config, including the family, socket type,
    // hop limit and timeouts that the original test did not cover.
    #[test]
    fn udp_config_default_values() {
        let cfg = UdpConfig::default();
        assert_eq!(cfg.socket_family, SocketFamily::IPV4);
        assert_eq!(cfg.socket_type, UdpSocketType::Dgram);
        assert!(cfg.bind_addr.is_none());
        assert!(cfg.reuseaddr.is_none());
        assert!(cfg.reuseport.is_none());
        assert!(cfg.broadcast.is_none());
        assert!(cfg.ttl.is_none());
        assert!(cfg.hoplimit.is_none());
        assert!(cfg.read_timeout.is_none());
        assert!(cfg.write_timeout.is_none());
        assert!(cfg.recv_buffer_size.is_none());
        assert!(cfg.send_buffer_size.is_none());
        assert!(cfg.tos.is_none());
        assert!(cfg.tclass_v6.is_none());
        assert!(cfg.recv_pktinfo.is_none());
        assert!(cfg.only_v6.is_none());
        assert!(cfg.bind_device.is_none());
    }

    #[test]
    fn udp_socket_type_predicates() {
        assert!(UdpSocketType::Dgram.is_dgram());
        assert!(!UdpSocketType::Dgram.is_raw());
        assert!(UdpSocketType::Raw.is_raw());
        assert!(!UdpSocketType::Raw.is_dgram());
    }

    #[test]
    fn udp_config_default_validates() {
        assert!(UdpConfig::default().validate().is_ok());
    }

    #[test]
    fn udp_config_with_family_builder() {
        let cfg =
            UdpConfig::new_with_family(SocketFamily::IPV6).with_bind("[::1]:0".parse().unwrap());
        assert_eq!(cfg.socket_family, SocketFamily::IPV6);
        assert!(cfg.bind_addr.is_some());
    }

    #[test]
    fn udp_config_validate_rejects_ipv6_broadcast() {
        let cfg = UdpConfig::new_with_family(SocketFamily::IPV6).with_broadcast(true);
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn udp_config_validate_rejects_bind_family_mismatch() {
        let cfg =
            UdpConfig::new_with_family(SocketFamily::IPV4).with_bind("[::1]:0".parse().unwrap());
        let err = cfg.validate().unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn udp_config_validate_rejects_ipv6_only_options_on_ipv4() {
        let base = UdpConfig::new_with_family(SocketFamily::IPV4);
        assert!(base.clone().with_hoplimit(1).validate().is_err());
        assert!(base.clone().with_tclass_v6(0).validate().is_err());
        assert!(base.with_only_v6(true).validate().is_err());
    }

    #[test]
    fn udp_config_validate_rejects_zero_values() {
        assert!(UdpConfig::new()
            .with_read_timeout(Duration::from_secs(0))
            .validate()
            .is_err());
        assert!(UdpConfig::new()
            .with_write_timeout(Duration::from_secs(0))
            .validate()
            .is_err());
        assert!(UdpConfig::new().with_recv_buffer_size(0).validate().is_err());
        assert!(UdpConfig::new().with_send_buffer_size(0).validate().is_err());
        assert!(UdpConfig::new().with_bind_device("").validate().is_err());
    }
}