// ntex_server/net/socket.rs

1use std::{fmt, io, net};
2
3use ntex_net::{self as rt, Io};
4
5use super::Token;
6
/// An accepted socket paired with a `Token` (the `Token` type comes from the
/// parent module).
#[derive(Debug)]
pub struct Connection {
    /// The accepted socket, not yet converted into an `Io`.
    pub(crate) io: Stream,
    /// NOTE(review): presumably identifies the listener that accepted this
    /// connection; confirm against the accept loop in the parent module.
    pub(crate) token: Token,
}
12
/// A listening socket. TCP is available everywhere; Unix domain sockets are
/// available only on Unix platforms.
pub enum Listener {
    /// A std TCP listener.
    Tcp(net::TcpListener),
    /// A std Unix domain socket listener (Unix only).
    #[cfg(unix)]
    Uds(std::os::unix::net::UnixListener),
}
18
19impl fmt::Debug for Listener {
20    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21        match *self {
22            Listener::Tcp(ref lst) => write!(f, "{:?}", lst),
23            #[cfg(unix)]
24            Listener::Uds(ref lst) => write!(f, "{:?}", lst),
25        }
26    }
27}
28
29impl fmt::Display for Listener {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        match *self {
32            Listener::Tcp(ref lst) => write!(f, "{}", lst.local_addr().ok().unwrap()),
33            #[cfg(unix)]
34            Listener::Uds(ref lst) => {
35                write!(f, "{:?}", lst.local_addr().ok().unwrap())
36            }
37        }
38    }
39}
40
/// Local address of a [`Listener`]: a TCP address, or (on Unix) a Unix
/// domain socket address.
pub(crate) enum SocketAddr {
    /// Address of a TCP listener.
    Tcp(net::SocketAddr),
    /// Address of a Unix domain socket listener (Unix only).
    #[cfg(unix)]
    Uds(std::os::unix::net::SocketAddr),
}
46
47impl fmt::Display for SocketAddr {
48    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49        match *self {
50            SocketAddr::Tcp(ref addr) => write!(f, "{}", addr),
51            #[cfg(unix)]
52            SocketAddr::Uds(ref addr) => write!(f, "{:?}", addr),
53        }
54    }
55}
56
57impl fmt::Debug for SocketAddr {
58    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59        match *self {
60            SocketAddr::Tcp(ref addr) => write!(f, "{:?}", addr),
61            #[cfg(unix)]
62            SocketAddr::Uds(ref addr) => write!(f, "{:?}", addr),
63        }
64    }
65}
66
67impl Listener {
68    pub(super) fn from_tcp(lst: net::TcpListener) -> Self {
69        let _ = lst.set_nonblocking(true);
70        Listener::Tcp(lst)
71    }
72
73    #[cfg(unix)]
74    pub(super) fn from_uds(lst: std::os::unix::net::UnixListener) -> Self {
75        let _ = lst.set_nonblocking(true);
76        Listener::Uds(lst)
77    }
78
79    pub(crate) fn local_addr(&self) -> SocketAddr {
80        match self {
81            Listener::Tcp(lst) => SocketAddr::Tcp(lst.local_addr().unwrap()),
82            #[cfg(unix)]
83            Listener::Uds(lst) => SocketAddr::Uds(lst.local_addr().unwrap()),
84        }
85    }
86
87    pub(crate) fn accept(&self) -> io::Result<Option<Stream>> {
88        match *self {
89            Listener::Tcp(ref lst) => {
90                lst.accept().map(|(stream, _)| Some(Stream::Tcp(stream)))
91            }
92            #[cfg(unix)]
93            Listener::Uds(ref lst) => {
94                lst.accept().map(|(stream, _)| Some(Stream::Uds(stream)))
95            }
96        }
97    }
98
99    pub(crate) fn remove_source(&self) {
100        match *self {
101            Listener::Tcp(_) => (),
102            #[cfg(unix)]
103            Listener::Uds(ref lst) => {
104                // cleanup file path
105                if let Ok(addr) = lst.local_addr() {
106                    if let Some(path) = addr.as_pathname() {
107                        let _ = std::fs::remove_file(path);
108                    }
109                }
110            }
111        }
112    }
113}
114
115#[cfg(unix)]
116mod listener_impl {
117    use super::*;
118    use std::os::fd::{AsFd, BorrowedFd};
119    use std::os::unix::io::{AsRawFd, RawFd};
120
121    impl AsFd for Listener {
122        fn as_fd(&self) -> BorrowedFd<'_> {
123            match *self {
124                Listener::Tcp(ref lst) => lst.as_fd(),
125                Listener::Uds(ref lst) => lst.as_fd(),
126            }
127        }
128    }
129
130    impl AsRawFd for Listener {
131        fn as_raw_fd(&self) -> RawFd {
132            match *self {
133                Listener::Tcp(ref lst) => lst.as_raw_fd(),
134                Listener::Uds(ref lst) => lst.as_raw_fd(),
135            }
136        }
137    }
138}
139
#[cfg(windows)]
mod listener_impl {
    use super::*;
    use std::os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, RawSocket};

    impl AsSocket for Listener {
        /// Borrows the socket of the wrapped std TCP listener (the only
        /// variant on Windows).
        fn as_socket(&self) -> BorrowedSocket<'_> {
            let Self::Tcp(inner) = self;
            inner.as_socket()
        }
    }

    impl AsRawSocket for Listener {
        /// Same socket as `as_socket`, as a raw handle.
        fn as_raw_socket(&self) -> RawSocket {
            self.as_socket().as_raw_socket()
        }
    }
}
161
/// A connection returned by `Listener::accept`, which can be converted into
/// an `Io` via `TryFrom`.
#[derive(Debug)]
pub enum Stream {
    /// An accepted TCP connection.
    Tcp(net::TcpStream),
    /// An accepted Unix domain socket connection (Unix only).
    #[cfg(unix)]
    Uds(std::os::unix::net::UnixStream),
}
168
169impl TryFrom<Stream> for Io {
170    type Error = io::Error;
171
172    fn try_from(sock: Stream) -> Result<Self, Self::Error> {
173        match sock {
174            Stream::Tcp(stream) => rt::from_tcp_stream(stream),
175            #[cfg(unix)]
176            Stream::Uds(stream) => rt::from_unix_stream(stream),
177        }
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn socket_addr() {
        use socket2::{Domain, SockAddr, Socket, Type};

        let addr = SocketAddr::Tcp("127.0.0.1:8080".parse().unwrap());
        assert!(format!("{:?}", addr).contains("127.0.0.1:8080"));
        assert_eq!(format!("{}", addr), "127.0.0.1:8080");

        let addr: net::SocketAddr = "127.0.0.1:0".parse().unwrap();
        let socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap();
        socket.set_reuse_address(true).unwrap();
        socket.bind(&SockAddr::from(addr)).unwrap();
        let lst = Listener::Tcp(net::TcpListener::from(socket));
        assert!(format!("{:?}", lst).contains("TcpListener"));
        assert!(format!("{}", lst).contains("127.0.0.1"));
    }

    #[test]
    #[cfg(unix)]
    fn uds() {
        use std::os::unix::net::UnixListener;

        // Use a per-process path so concurrent test runs on the same host
        // don't delete or steal each other's socket file.
        let path = format!("/tmp/sock.ntex-{}", std::process::id());
        let _ = std::fs::remove_file(&path);
        // Binding may be impossible in restricted sandboxes; skip then.
        if let Ok(lst) = UnixListener::bind(&path) {
            let addr = lst.local_addr().expect("Couldn't get local address");
            let a = SocketAddr::Uds(addr);
            assert!(format!("{:?}", a).contains(path.as_str()));
            assert!(format!("{}", a).contains(path.as_str()));

            let lst = Listener::Uds(lst);
            assert!(format!("{:?}", lst).contains(path.as_str()));
            assert!(format!("{}", lst).contains(path.as_str()));

            // remove_source must delete the socket file; this also keeps the
            // test from leaving stale files behind in /tmp.
            lst.remove_source();
            assert!(!std::path::Path::new(&path).exists());
        }
    }
}