Skip to main content

nex_socket/tcp/
config.rs

1use socket2::Type as SockType;
2use std::io;
3use std::net::SocketAddr;
4use std::time::Duration;
5
6use crate::SocketFamily;
7
/// TCP socket type, either STREAM or RAW.
///
/// Derives `Hash` in addition to `Eq` so the type can be used as a key in hashed
/// collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TcpSocketType {
    /// A stream socket; maps to `socket2::Type::STREAM`.
    Stream,
    /// A raw socket; maps to `socket2::Type::RAW`. Creating one requires
    /// administrator privileges.
    Raw,
}
14
15impl TcpSocketType {
16    /// Returns true if the socket type is STREAM.
17    pub fn is_stream(&self) -> bool {
18        matches!(self, TcpSocketType::Stream)
19    }
20
21    /// Returns true if the socket type is RAW.
22    pub fn is_raw(&self) -> bool {
23        matches!(self, TcpSocketType::Raw)
24    }
25
26    /// Converts the TCP socket type to a `socket2::Type`.
27    pub(crate) fn to_sock_type(&self) -> SockType {
28        match self {
29            TcpSocketType::Stream => SockType::STREAM,
30            TcpSocketType::Raw => SockType::RAW,
31        }
32    }
33}
34
35/// Configuration options for a TCP socket.
36#[derive(Debug, Clone)]
37pub struct TcpConfig {
38    /// The socket family, either IPv4 or IPv6.
39    pub socket_family: SocketFamily,
40    /// The type of TCP socket, either STREAM or RAW.
41    pub socket_type: TcpSocketType,
42    /// Optional address to bind the socket to.
43    pub bind_addr: Option<SocketAddr>,
44    /// Whether the socket should be non-blocking.
45    pub nonblocking: bool,
46    /// Whether to allow address reuse.
47    pub reuseaddr: Option<bool>,
48    /// Whether to allow port reuse (`SO_REUSEPORT`) where supported.
49    pub reuseport: Option<bool>,
50    /// Whether to disable Nagle's algorithm (TCP_NODELAY).
51    pub nodelay: Option<bool>,
52    /// Optional linger duration for the socket.
53    pub linger: Option<Duration>,
54    /// Optional Time-To-Live (TTL) for the socket.
55    pub ttl: Option<u32>,
56    /// Optional Hop Limit for the socket (IPv6).
57    pub hoplimit: Option<u32>,
58    /// Optional read timeout for the socket.
59    pub read_timeout: Option<Duration>,
60    /// Optional write timeout for the socket.
61    pub write_timeout: Option<Duration>,
62    /// Optional receive buffer size in bytes.
63    pub recv_buffer_size: Option<usize>,
64    /// Optional send buffer size in bytes.
65    pub send_buffer_size: Option<usize>,
66    /// Optional IPv4 TOS / DSCP field value.
67    pub tos: Option<u32>,
68    /// Optional IPv6 traffic class value (`IPV6_TCLASS`) where supported.
69    pub tclass_v6: Option<u32>,
70    /// Whether to force IPv6-only behavior on dual-stack sockets.
71    pub only_v6: Option<bool>,
72    /// Optional device to bind the socket to.
73    pub bind_device: Option<String>,
74    /// Whether to enable TCP keepalive.
75    pub keepalive: Option<bool>,
76}
77
78impl TcpConfig {
79    /// Create a STREAM socket for the specified family.
80    pub fn new(socket_family: SocketFamily) -> Self {
81        match socket_family {
82            SocketFamily::IPV4 => Self::v4_stream(),
83            SocketFamily::IPV6 => Self::v6_stream(),
84        }
85    }
86
87    /// Create a STREAM socket for IPv4.
88    pub fn v4_stream() -> Self {
89        Self {
90            socket_family: SocketFamily::IPV4,
91            socket_type: TcpSocketType::Stream,
92            bind_addr: None,
93            nonblocking: false,
94            reuseaddr: None,
95            reuseport: None,
96            nodelay: None,
97            linger: None,
98            ttl: None,
99            hoplimit: None,
100            read_timeout: None,
101            write_timeout: None,
102            recv_buffer_size: None,
103            send_buffer_size: None,
104            tos: None,
105            tclass_v6: None,
106            only_v6: None,
107            bind_device: None,
108            keepalive: None,
109        }
110    }
111
112    /// Create a RAW socket. Requires administrator privileges.
113    pub fn raw_v4() -> Self {
114        Self {
115            socket_family: SocketFamily::IPV4,
116            socket_type: TcpSocketType::Raw,
117            ..Self::v4_stream()
118        }
119    }
120
121    /// Create a STREAM socket for IPv6.
122    pub fn v6_stream() -> Self {
123        Self {
124            socket_family: SocketFamily::IPV6,
125            socket_type: TcpSocketType::Stream,
126            ..Self::v4_stream()
127        }
128    }
129
130    /// Create a RAW socket for IPv6. Requires administrator privileges.
131    pub fn raw_v6() -> Self {
132        Self {
133            socket_family: SocketFamily::IPV6,
134            socket_type: TcpSocketType::Raw,
135            ..Self::v4_stream()
136        }
137    }
138
139    // --- chainable modifiers ---
140
141    pub fn with_bind(mut self, addr: SocketAddr) -> Self {
142        self.bind_addr = Some(addr);
143        self
144    }
145
146    pub fn with_bind_addr(self, addr: SocketAddr) -> Self {
147        self.with_bind(addr)
148    }
149
150    pub fn with_nonblocking(mut self, flag: bool) -> Self {
151        self.nonblocking = flag;
152        self
153    }
154
155    pub fn with_reuseaddr(mut self, flag: bool) -> Self {
156        self.reuseaddr = Some(flag);
157        self
158    }
159
160    pub fn with_reuseport(mut self, flag: bool) -> Self {
161        self.reuseport = Some(flag);
162        self
163    }
164
165    pub fn with_nodelay(mut self, flag: bool) -> Self {
166        self.nodelay = Some(flag);
167        self
168    }
169
170    pub fn with_linger(mut self, dur: Duration) -> Self {
171        self.linger = Some(dur);
172        self
173    }
174
175    pub fn with_ttl(mut self, ttl: u32) -> Self {
176        self.ttl = Some(ttl);
177        self
178    }
179
180    pub fn with_hoplimit(mut self, hops: u32) -> Self {
181        self.hoplimit = Some(hops);
182        self
183    }
184
185    pub fn with_hop_limit(self, hops: u32) -> Self {
186        self.with_hoplimit(hops)
187    }
188
189    pub fn with_keepalive(mut self, on: bool) -> Self {
190        self.keepalive = Some(on);
191        self
192    }
193
194    pub fn with_read_timeout(mut self, timeout: Duration) -> Self {
195        self.read_timeout = Some(timeout);
196        self
197    }
198
199    pub fn with_write_timeout(mut self, timeout: Duration) -> Self {
200        self.write_timeout = Some(timeout);
201        self
202    }
203
204    pub fn with_recv_buffer_size(mut self, size: usize) -> Self {
205        self.recv_buffer_size = Some(size);
206        self
207    }
208
209    pub fn with_send_buffer_size(mut self, size: usize) -> Self {
210        self.send_buffer_size = Some(size);
211        self
212    }
213
214    pub fn with_tos(mut self, tos: u32) -> Self {
215        self.tos = Some(tos);
216        self
217    }
218
219    pub fn with_tclass_v6(mut self, tclass: u32) -> Self {
220        self.tclass_v6 = Some(tclass);
221        self
222    }
223
224    pub fn with_only_v6(mut self, only_v6: bool) -> Self {
225        self.only_v6 = Some(only_v6);
226        self
227    }
228
229    pub fn with_bind_device(mut self, iface: impl Into<String>) -> Self {
230        self.bind_device = Some(iface.into());
231        self
232    }
233
234    /// Validate the configuration before socket creation.
235    pub fn validate(&self) -> io::Result<()> {
236        if let Some(addr) = self.bind_addr {
237            let addr_family = crate::SocketFamily::from_socket_addr(&addr);
238            if addr_family != self.socket_family {
239                return Err(io::Error::new(
240                    io::ErrorKind::InvalidInput,
241                    "bind_addr family does not match socket_family",
242                ));
243            }
244        }
245
246        if self.socket_family.is_v4() {
247            if self.hoplimit.is_some() {
248                return Err(io::Error::new(
249                    io::ErrorKind::InvalidInput,
250                    "hoplimit is only supported for IPv6 TCP sockets",
251                ));
252            }
253            if self.tclass_v6.is_some() {
254                return Err(io::Error::new(
255                    io::ErrorKind::InvalidInput,
256                    "tclass_v6 is only supported for IPv6 TCP sockets",
257                ));
258            }
259            if self.only_v6.is_some() {
260                return Err(io::Error::new(
261                    io::ErrorKind::InvalidInput,
262                    "only_v6 is only supported for IPv6 TCP sockets",
263                ));
264            }
265        }
266
267        if self.socket_family.is_v6() && self.ttl.is_some() {
268            return Err(io::Error::new(
269                io::ErrorKind::InvalidInput,
270                "ttl is only supported for IPv4 TCP sockets",
271            ));
272        }
273
274        if matches!(self.read_timeout, Some(timeout) if timeout.is_zero()) {
275            return Err(io::Error::new(
276                io::ErrorKind::InvalidInput,
277                "read_timeout must be greater than zero",
278            ));
279        }
280
281        if matches!(self.write_timeout, Some(timeout) if timeout.is_zero()) {
282            return Err(io::Error::new(
283                io::ErrorKind::InvalidInput,
284                "write_timeout must be greater than zero",
285            ));
286        }
287
288        if matches!(self.recv_buffer_size, Some(0)) {
289            return Err(io::Error::new(
290                io::ErrorKind::InvalidInput,
291                "recv_buffer_size must be greater than zero",
292            ));
293        }
294
295        if matches!(self.send_buffer_size, Some(0)) {
296            return Err(io::Error::new(
297                io::ErrorKind::InvalidInput,
298                "send_buffer_size must be greater than zero",
299            ));
300        }
301
302        if matches!(self.bind_device.as_deref(), Some("")) {
303            return Err(io::Error::new(
304                io::ErrorKind::InvalidInput,
305                "bind_device must not be empty",
306            ));
307        }
308
309        Ok(())
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tcp_config_builders() {
        // Builders only record values; they do not validate. This particular mix
        // (`tclass_v6` on an IPv4 config) would be rejected by `validate`.
        let addr: SocketAddr = "127.0.0.1:80".parse().unwrap();
        let cfg = TcpConfig::new(SocketFamily::IPV4)
            .with_bind_addr(addr)
            .with_nonblocking(true)
            .with_reuseaddr(true)
            .with_reuseport(true)
            .with_nodelay(true)
            .with_ttl(10)
            .with_recv_buffer_size(8192)
            .with_send_buffer_size(8192)
            .with_tos(0x10)
            .with_tclass_v6(0x20);

        assert_eq!(cfg.socket_family, SocketFamily::IPV4);
        assert_eq!(cfg.socket_type, TcpSocketType::Stream);
        assert_eq!(cfg.bind_addr, Some(addr));
        assert!(cfg.nonblocking);
        assert_eq!(cfg.reuseaddr, Some(true));
        assert_eq!(cfg.reuseport, Some(true));
        assert_eq!(cfg.nodelay, Some(true));
        assert_eq!(cfg.ttl, Some(10));
        assert_eq!(cfg.recv_buffer_size, Some(8192));
        assert_eq!(cfg.send_buffer_size, Some(8192));
        assert_eq!(cfg.tos, Some(0x10));
        assert_eq!(cfg.tclass_v6, Some(0x20));
    }

    #[test]
    fn new_with_ipv6_family_creates_v6_stream() {
        let cfg = TcpConfig::new(SocketFamily::IPV6);
        assert_eq!(cfg.socket_family, SocketFamily::IPV6);
        assert_eq!(cfg.socket_type, TcpSocketType::Stream);
    }

    #[test]
    fn raw_constructors_set_family_and_raw_type() {
        let v4 = TcpConfig::raw_v4();
        assert_eq!(v4.socket_family, SocketFamily::IPV4);
        assert_eq!(v4.socket_type, TcpSocketType::Raw);
        assert!(v4.socket_type.is_raw());

        let v6 = TcpConfig::raw_v6();
        assert_eq!(v6.socket_family, SocketFamily::IPV6);
        assert_eq!(v6.socket_type, TcpSocketType::Raw);
    }

    #[test]
    fn tcp_config_validate_accepts_defaults() {
        let defaults = vec![
            TcpConfig::v4_stream(),
            TcpConfig::v6_stream(),
            TcpConfig::raw_v4(),
            TcpConfig::raw_v6(),
        ];
        for cfg in defaults {
            assert!(cfg.validate().is_ok(), "{:?}", cfg);
        }
    }

    #[test]
    fn tcp_config_validate_accepts_matching_bind_family() {
        let cfg = TcpConfig::v6_stream().with_bind("[::1]:0".parse().unwrap());
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn tcp_config_validate_rejects_family_mismatch() {
        let cfg = TcpConfig::v4_stream().with_bind("[::1]:0".parse().unwrap());
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn tcp_config_validate_rejects_family_specific_options() {
        let cases = vec![
            TcpConfig::v4_stream().with_hoplimit(64),
            TcpConfig::v4_stream().with_tclass_v6(0x20),
            TcpConfig::v4_stream().with_only_v6(true),
            TcpConfig::v6_stream().with_ttl(64),
        ];
        for cfg in cases {
            let err = cfg.validate().unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
        }
    }

    #[test]
    fn tcp_config_validate_rejects_zero_values_and_empty_device() {
        let cases = vec![
            TcpConfig::v4_stream().with_read_timeout(Duration::ZERO),
            TcpConfig::v4_stream().with_write_timeout(Duration::ZERO),
            TcpConfig::v4_stream().with_recv_buffer_size(0),
            TcpConfig::v4_stream().with_send_buffer_size(0),
            TcpConfig::v4_stream().with_bind_device(""),
        ];
        for cfg in cases {
            let err = cfg.validate().unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
        }
    }
}
354    fn tcp_config_validate_rejects_family_mismatch() {
355        let cfg = TcpConfig::v4_stream().with_bind("[::1]:0".parse().unwrap());
356        assert!(cfg.validate().is_err());
357    }
358}