ntex_server/net/
socket.rs

1use std::{fmt, io, net};
2
3use ntex_io::Io;
4use ntex_service::cfg::SharedCfg;
5
6use super::Token;
7
8#[derive(Debug)]
9pub enum Stream {
10    Tcp(net::TcpStream),
11    #[cfg(unix)]
12    Uds(std::os::unix::net::UnixStream),
13}
14
15impl Stream {
16    pub(crate) fn convert(self, cfg: SharedCfg) -> Result<Io, io::Error> {
17        match self {
18            Stream::Tcp(stream) => ntex_net::from_tcp_stream(stream, cfg),
19            #[cfg(unix)]
20            Stream::Uds(stream) => ntex_net::from_unix_stream(stream, cfg),
21        }
22    }
23}
24
25#[derive(Debug)]
26pub struct Connection {
27    pub(crate) io: Stream,
28    pub(crate) token: Token,
29}
30
/// A bound server socket that accepts incoming connections.
pub enum Listener {
    /// A TCP listener.
    Tcp(net::TcpListener),
    /// A Unix domain socket listener.
    #[cfg(unix)]
    Uds(std::os::unix::net::UnixListener),
}

impl fmt::Debug for Listener {
    /// Delegates to the inner listener's `Debug` implementation.
    ///
    /// The formatter is forwarded instead of being re-wrapped with `write!`,
    /// so flags such as `{:#?}` (pretty printing) reach the inner impl.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Listener::Tcp(lst) => fmt::Debug::fmt(lst, f),
            #[cfg(unix)]
            Listener::Uds(lst) => fmt::Debug::fmt(lst, f),
        }
    }
}

impl fmt::Display for Listener {
    /// Writes the address the listener is bound to.
    ///
    /// TCP addresses use their `Display` form. Unix socket addresses have no
    /// `Display` impl, so their `Debug` form is used. If the local address
    /// cannot be queried, a placeholder containing the OS error is written
    /// instead of panicking: `Display` is routinely called from logging code
    /// and must not bring the thread down.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Listener::Tcp(lst) => match lst.local_addr() {
                Ok(addr) => fmt::Display::fmt(&addr, f),
                Err(err) => write!(f, "<unknown address: {err}>"),
            },
            #[cfg(unix)]
            Listener::Uds(lst) => match lst.local_addr() {
                Ok(addr) => fmt::Debug::fmt(&addr, f),
                Err(err) => write!(f, "<unknown address: {err}>"),
            },
        }
    }
}
58
/// Local address of a `Listener`.
pub(crate) enum SocketAddr {
    /// Address of a TCP listener.
    Tcp(net::SocketAddr),
    /// Address of a Unix domain socket listener.
    #[cfg(unix)]
    Uds(std::os::unix::net::SocketAddr),
}

impl fmt::Display for SocketAddr {
    /// TCP addresses use their `Display` form; Unix socket addresses, which
    /// have no `Display` impl, fall back to their `Debug` form.
    ///
    /// The formatter is passed through unchanged (rather than re-wrapped with
    /// `write!`), so flags such as width and alignment reach the inner impl.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SocketAddr::Tcp(addr) => fmt::Display::fmt(addr, f),
            #[cfg(unix)]
            SocketAddr::Uds(addr) => fmt::Debug::fmt(addr, f),
        }
    }
}

impl fmt::Debug for SocketAddr {
    /// Delegates to the inner address's `Debug` impl, forwarding all
    /// formatter flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SocketAddr::Tcp(addr) => fmt::Debug::fmt(addr, f),
            #[cfg(unix)]
            SocketAddr::Uds(addr) => fmt::Debug::fmt(addr, f),
        }
    }
}
84
85impl Listener {
86    pub(super) fn from_tcp(lst: net::TcpListener) -> Self {
87        let _ = lst.set_nonblocking(true);
88        Listener::Tcp(lst)
89    }
90
91    #[cfg(unix)]
92    pub(super) fn from_uds(lst: std::os::unix::net::UnixListener) -> Self {
93        let _ = lst.set_nonblocking(true);
94        Listener::Uds(lst)
95    }
96
97    pub(crate) fn local_addr(&self) -> SocketAddr {
98        match self {
99            Listener::Tcp(lst) => SocketAddr::Tcp(lst.local_addr().unwrap()),
100            #[cfg(unix)]
101            Listener::Uds(lst) => SocketAddr::Uds(lst.local_addr().unwrap()),
102        }
103    }
104
105    pub(crate) fn accept(&self) -> io::Result<Option<Stream>> {
106        match *self {
107            Listener::Tcp(ref lst) => {
108                lst.accept().map(|(stream, _)| Some(Stream::Tcp(stream)))
109            }
110            #[cfg(unix)]
111            Listener::Uds(ref lst) => {
112                lst.accept().map(|(stream, _)| Some(Stream::Uds(stream)))
113            }
114        }
115    }
116
117    pub(crate) fn remove_source(&self) {
118        match *self {
119            Listener::Tcp(_) => (),
120            #[cfg(unix)]
121            Listener::Uds(ref lst) => {
122                // cleanup file path
123                if let Ok(addr) = lst.local_addr() {
124                    if let Some(path) = addr.as_pathname() {
125                        let _ = std::fs::remove_file(path);
126                    }
127                }
128            }
129        }
130    }
131}
132
133#[cfg(unix)]
134mod listener_impl {
135    use super::*;
136    use std::os::fd::{AsFd, BorrowedFd};
137    use std::os::unix::io::{AsRawFd, RawFd};
138
139    impl AsFd for Listener {
140        fn as_fd(&self) -> BorrowedFd<'_> {
141            match *self {
142                Listener::Tcp(ref lst) => lst.as_fd(),
143                Listener::Uds(ref lst) => lst.as_fd(),
144            }
145        }
146    }
147
148    impl AsRawFd for Listener {
149        fn as_raw_fd(&self) -> RawFd {
150            match *self {
151                Listener::Tcp(ref lst) => lst.as_raw_fd(),
152                Listener::Uds(ref lst) => lst.as_raw_fd(),
153            }
154        }
155    }
156}
157
#[cfg(windows)]
mod listener_impl {
    use super::*;
    use std::os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, RawSocket};

    /// Exposes the wrapped socket as a borrowed Windows socket handle.
    impl AsSocket for Listener {
        fn as_socket(&self) -> BorrowedSocket<'_> {
            match self {
                Self::Tcp(tcp) => tcp.as_socket(),
            }
        }
    }

    /// Exposes the wrapped socket as a raw Windows socket handle.
    impl AsRawSocket for Listener {
        fn as_raw_socket(&self) -> RawSocket {
            // The borrowed handle refers to the very same socket.
            self.as_socket().as_raw_socket()
        }
    }
}
179
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn socket_addr() {
        use socket2::{Domain, SockAddr, Socket, Type};

        let addr = SocketAddr::Tcp("127.0.0.1:8080".parse().unwrap());
        assert!(format!("{addr:?}").contains("127.0.0.1:8080"));
        assert_eq!(format!("{addr}"), "127.0.0.1:8080");

        let addr: net::SocketAddr = "127.0.0.1:0".parse().unwrap();
        let socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap();
        socket.set_reuse_address(true).unwrap();
        socket.bind(&SockAddr::from(addr)).unwrap();
        let lst = Listener::Tcp(net::TcpListener::from(socket));
        assert!(format!("{lst:?}").contains("TcpListener"));
        assert!(format!("{lst}").contains("127.0.0.1"));
    }

    #[test]
    fn tcp_listener() {
        let std_lst = net::TcpListener::bind("127.0.0.1:0").unwrap();
        let expected = std_lst.local_addr().unwrap();
        let lst = Listener::from_tcp(std_lst);
        assert_eq!(lst.local_addr().to_string(), expected.to_string());

        // `from_tcp` switched the socket to non-blocking mode: with nothing
        // pending, accept must fail with WouldBlock instead of hanging.
        let err = lst
            .accept()
            .err()
            .expect("accept unexpectedly returned a connection");
        assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
    }

    #[test]
    #[cfg(unix)]
    fn uds() {
        use std::os::unix::net::UnixListener;

        // Per-process path so concurrent test runs cannot clobber each other.
        let path =
            std::env::temp_dir().join(format!("ntex-server-{}.sock", std::process::id()));
        let path_str = path.to_str().expect("temp dir path must be valid UTF-8");

        let _ = std::fs::remove_file(&path);
        if let Ok(lst) = UnixListener::bind(&path) {
            let addr = lst.local_addr().expect("Couldn't get local address");
            let a = SocketAddr::Uds(addr);
            assert!(format!("{a:?}").contains(path_str));
            assert!(format!("{a}").contains(path_str));

            let lst = Listener::Uds(lst);
            assert!(format!("{lst:?}").contains(path_str));
            assert!(format!("{lst}").contains(path_str));

            // remove_source must delete the socket file created by bind.
            lst.remove_source();
            assert!(!path.exists());
        }
    }
}