1use socket2::Type as SockType;
2use std::io;
3use std::net::SocketAddr;
4use std::time::Duration;
5
6use crate::SocketFamily;
7
/// Kind of socket described by a `TcpConfig`.
///
/// `Hash` is derived alongside `Eq` so the type can be used as a key in
/// hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TcpSocketType {
    /// Stream socket (converted to `socket2::Type::STREAM`).
    Stream,
    /// Raw socket (converted to `socket2::Type::RAW`).
    Raw,
}
14
15impl TcpSocketType {
16 pub fn is_stream(&self) -> bool {
18 matches!(self, TcpSocketType::Stream)
19 }
20
21 pub fn is_raw(&self) -> bool {
23 matches!(self, TcpSocketType::Raw)
24 }
25
26 pub(crate) fn to_sock_type(&self) -> SockType {
28 match self {
29 TcpSocketType::Stream => SockType::STREAM,
30 TcpSocketType::Raw => SockType::RAW,
31 }
32 }
33}
34
35#[derive(Debug, Clone)]
37pub struct TcpConfig {
38 pub socket_family: SocketFamily,
40 pub socket_type: TcpSocketType,
42 pub bind_addr: Option<SocketAddr>,
44 pub nonblocking: bool,
46 pub reuseaddr: Option<bool>,
48 pub reuseport: Option<bool>,
50 pub nodelay: Option<bool>,
52 pub linger: Option<Duration>,
54 pub ttl: Option<u32>,
56 pub hoplimit: Option<u32>,
58 pub read_timeout: Option<Duration>,
60 pub write_timeout: Option<Duration>,
62 pub recv_buffer_size: Option<usize>,
64 pub send_buffer_size: Option<usize>,
66 pub tos: Option<u32>,
68 pub tclass_v6: Option<u32>,
70 pub only_v6: Option<bool>,
72 pub bind_device: Option<String>,
74 pub keepalive: Option<bool>,
76}
77
78impl TcpConfig {
79 pub fn new(socket_family: SocketFamily) -> Self {
81 match socket_family {
82 SocketFamily::IPV4 => Self::v4_stream(),
83 SocketFamily::IPV6 => Self::v6_stream(),
84 }
85 }
86
87 pub fn v4_stream() -> Self {
89 Self {
90 socket_family: SocketFamily::IPV4,
91 socket_type: TcpSocketType::Stream,
92 bind_addr: None,
93 nonblocking: false,
94 reuseaddr: None,
95 reuseport: None,
96 nodelay: None,
97 linger: None,
98 ttl: None,
99 hoplimit: None,
100 read_timeout: None,
101 write_timeout: None,
102 recv_buffer_size: None,
103 send_buffer_size: None,
104 tos: None,
105 tclass_v6: None,
106 only_v6: None,
107 bind_device: None,
108 keepalive: None,
109 }
110 }
111
112 pub fn raw_v4() -> Self {
114 Self {
115 socket_family: SocketFamily::IPV4,
116 socket_type: TcpSocketType::Raw,
117 ..Self::v4_stream()
118 }
119 }
120
121 pub fn v6_stream() -> Self {
123 Self {
124 socket_family: SocketFamily::IPV6,
125 socket_type: TcpSocketType::Stream,
126 ..Self::v4_stream()
127 }
128 }
129
130 pub fn raw_v6() -> Self {
132 Self {
133 socket_family: SocketFamily::IPV6,
134 socket_type: TcpSocketType::Raw,
135 ..Self::v4_stream()
136 }
137 }
138
139 pub fn with_bind(mut self, addr: SocketAddr) -> Self {
142 self.bind_addr = Some(addr);
143 self
144 }
145
146 pub fn with_bind_addr(self, addr: SocketAddr) -> Self {
147 self.with_bind(addr)
148 }
149
150 pub fn with_nonblocking(mut self, flag: bool) -> Self {
151 self.nonblocking = flag;
152 self
153 }
154
155 pub fn with_reuseaddr(mut self, flag: bool) -> Self {
156 self.reuseaddr = Some(flag);
157 self
158 }
159
160 pub fn with_reuseport(mut self, flag: bool) -> Self {
161 self.reuseport = Some(flag);
162 self
163 }
164
165 pub fn with_nodelay(mut self, flag: bool) -> Self {
166 self.nodelay = Some(flag);
167 self
168 }
169
170 pub fn with_linger(mut self, dur: Duration) -> Self {
171 self.linger = Some(dur);
172 self
173 }
174
175 pub fn with_ttl(mut self, ttl: u32) -> Self {
176 self.ttl = Some(ttl);
177 self
178 }
179
180 pub fn with_hoplimit(mut self, hops: u32) -> Self {
181 self.hoplimit = Some(hops);
182 self
183 }
184
185 pub fn with_hop_limit(self, hops: u32) -> Self {
186 self.with_hoplimit(hops)
187 }
188
189 pub fn with_keepalive(mut self, on: bool) -> Self {
190 self.keepalive = Some(on);
191 self
192 }
193
194 pub fn with_read_timeout(mut self, timeout: Duration) -> Self {
195 self.read_timeout = Some(timeout);
196 self
197 }
198
199 pub fn with_write_timeout(mut self, timeout: Duration) -> Self {
200 self.write_timeout = Some(timeout);
201 self
202 }
203
204 pub fn with_recv_buffer_size(mut self, size: usize) -> Self {
205 self.recv_buffer_size = Some(size);
206 self
207 }
208
209 pub fn with_send_buffer_size(mut self, size: usize) -> Self {
210 self.send_buffer_size = Some(size);
211 self
212 }
213
214 pub fn with_tos(mut self, tos: u32) -> Self {
215 self.tos = Some(tos);
216 self
217 }
218
219 pub fn with_tclass_v6(mut self, tclass: u32) -> Self {
220 self.tclass_v6 = Some(tclass);
221 self
222 }
223
224 pub fn with_only_v6(mut self, only_v6: bool) -> Self {
225 self.only_v6 = Some(only_v6);
226 self
227 }
228
229 pub fn with_bind_device(mut self, iface: impl Into<String>) -> Self {
230 self.bind_device = Some(iface.into());
231 self
232 }
233
234 pub fn validate(&self) -> io::Result<()> {
236 if let Some(addr) = self.bind_addr {
237 let addr_family = crate::SocketFamily::from_socket_addr(&addr);
238 if addr_family != self.socket_family {
239 return Err(io::Error::new(
240 io::ErrorKind::InvalidInput,
241 "bind_addr family does not match socket_family",
242 ));
243 }
244 }
245
246 if self.socket_family.is_v4() {
247 if self.hoplimit.is_some() {
248 return Err(io::Error::new(
249 io::ErrorKind::InvalidInput,
250 "hoplimit is only supported for IPv6 TCP sockets",
251 ));
252 }
253 if self.tclass_v6.is_some() {
254 return Err(io::Error::new(
255 io::ErrorKind::InvalidInput,
256 "tclass_v6 is only supported for IPv6 TCP sockets",
257 ));
258 }
259 if self.only_v6.is_some() {
260 return Err(io::Error::new(
261 io::ErrorKind::InvalidInput,
262 "only_v6 is only supported for IPv6 TCP sockets",
263 ));
264 }
265 }
266
267 if self.socket_family.is_v6() && self.ttl.is_some() {
268 return Err(io::Error::new(
269 io::ErrorKind::InvalidInput,
270 "ttl is only supported for IPv4 TCP sockets",
271 ));
272 }
273
274 if matches!(self.read_timeout, Some(timeout) if timeout.is_zero()) {
275 return Err(io::Error::new(
276 io::ErrorKind::InvalidInput,
277 "read_timeout must be greater than zero",
278 ));
279 }
280
281 if matches!(self.write_timeout, Some(timeout) if timeout.is_zero()) {
282 return Err(io::Error::new(
283 io::ErrorKind::InvalidInput,
284 "write_timeout must be greater than zero",
285 ));
286 }
287
288 if matches!(self.recv_buffer_size, Some(0)) {
289 return Err(io::Error::new(
290 io::ErrorKind::InvalidInput,
291 "recv_buffer_size must be greater than zero",
292 ));
293 }
294
295 if matches!(self.send_buffer_size, Some(0)) {
296 return Err(io::Error::new(
297 io::ErrorKind::InvalidInput,
298 "send_buffer_size must be greater than zero",
299 ));
300 }
301
302 if matches!(self.bind_device.as_deref(), Some("")) {
303 return Err(io::Error::new(
304 io::ErrorKind::InvalidInput,
305 "bind_device must not be empty",
306 ));
307 }
308
309 Ok(())
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `cfg` fails validation with `InvalidInput`.
    fn assert_invalid_input(cfg: &TcpConfig) {
        let err = cfg.validate().expect_err("config should be rejected");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn tcp_config_builders() {
        let addr: SocketAddr = "127.0.0.1:80".parse().unwrap();
        let cfg = TcpConfig::new(SocketFamily::IPV4)
            .with_bind_addr(addr)
            .with_nonblocking(true)
            .with_reuseaddr(true)
            .with_reuseport(true)
            .with_nodelay(true)
            .with_ttl(10)
            .with_recv_buffer_size(8192)
            .with_send_buffer_size(8192)
            .with_tos(0x10)
            .with_tclass_v6(0x20);

        assert_eq!(cfg.socket_family, SocketFamily::IPV4);
        assert_eq!(cfg.socket_type, TcpSocketType::Stream);
        assert_eq!(cfg.bind_addr, Some(addr));
        assert!(cfg.nonblocking);
        assert_eq!(cfg.reuseaddr, Some(true));
        assert_eq!(cfg.reuseport, Some(true));
        assert_eq!(cfg.nodelay, Some(true));
        assert_eq!(cfg.ttl, Some(10));
        assert_eq!(cfg.recv_buffer_size, Some(8192));
        assert_eq!(cfg.send_buffer_size, Some(8192));
        assert_eq!(cfg.tos, Some(0x10));
        assert_eq!(cfg.tclass_v6, Some(0x20));
    }

    #[test]
    fn new_with_ipv6_family_creates_v6_stream() {
        let cfg = TcpConfig::new(SocketFamily::IPV6);
        assert_eq!(cfg.socket_family, SocketFamily::IPV6);
        assert_eq!(cfg.socket_type, TcpSocketType::Stream);
    }

    #[test]
    fn raw_constructors_set_family_and_type() {
        let v4 = TcpConfig::raw_v4();
        assert_eq!(v4.socket_family, SocketFamily::IPV4);
        assert_eq!(v4.socket_type, TcpSocketType::Raw);
        assert!(v4.socket_type.is_raw());

        let v6 = TcpConfig::raw_v6();
        assert_eq!(v6.socket_family, SocketFamily::IPV6);
        assert_eq!(v6.socket_type, TcpSocketType::Raw);
        assert!(!v6.socket_type.is_stream());
    }

    #[test]
    fn tcp_config_validate_rejects_family_mismatch() {
        let cfg = TcpConfig::v4_stream().with_bind("[::1]:0".parse().unwrap());
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_accepts_consistent_configs() {
        let v4 = TcpConfig::v4_stream()
            .with_bind("127.0.0.1:0".parse().unwrap())
            .with_ttl(64)
            .with_tos(0x10)
            .with_read_timeout(Duration::from_millis(1))
            .with_write_timeout(Duration::from_millis(1))
            .with_recv_buffer_size(1)
            .with_send_buffer_size(1)
            .with_bind_device("eth0");
        assert!(v4.validate().is_ok());

        let v6 = TcpConfig::v6_stream()
            .with_bind("[::1]:0".parse().unwrap())
            .with_hop_limit(64)
            .with_tclass_v6(0x20)
            .with_only_v6(true);
        assert!(v6.validate().is_ok());
    }

    #[test]
    fn validate_rejects_options_for_wrong_family() {
        assert_invalid_input(&TcpConfig::v4_stream().with_hoplimit(1));
        assert_invalid_input(&TcpConfig::v4_stream().with_tclass_v6(0));
        assert_invalid_input(&TcpConfig::v4_stream().with_only_v6(true));
        assert_invalid_input(&TcpConfig::v6_stream().with_ttl(1));
    }

    #[test]
    fn validate_rejects_zero_and_empty_values() {
        let zero = Duration::from_secs(0);
        assert_invalid_input(&TcpConfig::v4_stream().with_read_timeout(zero));
        assert_invalid_input(&TcpConfig::v4_stream().with_write_timeout(zero));
        assert_invalid_input(&TcpConfig::v4_stream().with_recv_buffer_size(0));
        assert_invalid_input(&TcpConfig::v4_stream().with_send_buffer_size(0));
        assert_invalid_input(&TcpConfig::v4_stream().with_bind_device(""));
    }
}