ntex_server/net/
socket.rs

1use std::{fmt, io, net};
2
3use ntex_net::{self as rt, Io};
4
5use super::Token;
6
/// A connected standard-library socket that has not been converted to `Io` yet.
///
/// The Unix-domain variant only exists when compiling for a Unix target.
#[derive(Debug)]
pub enum Stream {
    /// A TCP stream.
    Tcp(std::net::TcpStream),
    /// A Unix-domain stream.
    #[cfg(unix)]
    Uds(std::os::unix::net::UnixStream),
}
13
14impl TryFrom<Stream> for Io {
15    type Error = io::Error;
16
17    fn try_from(sock: Stream) -> Result<Self, Self::Error> {
18        match sock {
19            Stream::Tcp(stream) => rt::from_tcp_stream(stream),
20            #[cfg(unix)]
21            Stream::Uds(stream) => rt::from_unix_stream(stream),
22        }
23    }
24}
25
26#[derive(Debug)]
27pub struct Connection {
28    pub(crate) io: Stream,
29    pub(crate) token: Token,
30}
31
/// A standard-library listening socket.
///
/// The Unix-domain variant only exists when compiling for a Unix target.
pub enum Listener {
    /// A TCP listener.
    Tcp(std::net::TcpListener),
    /// A Unix-domain listener.
    #[cfg(unix)]
    Uds(std::os::unix::net::UnixListener),
}
37
38impl fmt::Debug for Listener {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        match *self {
41            Listener::Tcp(ref lst) => write!(f, "{:?}", lst),
42            #[cfg(unix)]
43            Listener::Uds(ref lst) => write!(f, "{:?}", lst),
44        }
45    }
46}
47
48impl fmt::Display for Listener {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        match *self {
51            Listener::Tcp(ref lst) => write!(f, "{}", lst.local_addr().ok().unwrap()),
52            #[cfg(unix)]
53            Listener::Uds(ref lst) => {
54                write!(f, "{:?}", lst.local_addr().ok().unwrap())
55            }
56        }
57    }
58}
59
/// A socket address, either TCP or (on Unix targets) Unix-domain.
pub(crate) enum SocketAddr {
    /// An IP socket address.
    Tcp(std::net::SocketAddr),
    /// A Unix-domain socket address.
    #[cfg(unix)]
    Uds(std::os::unix::net::SocketAddr),
}
65
66impl fmt::Display for SocketAddr {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        match *self {
69            SocketAddr::Tcp(ref addr) => write!(f, "{}", addr),
70            #[cfg(unix)]
71            SocketAddr::Uds(ref addr) => write!(f, "{:?}", addr),
72        }
73    }
74}
75
76impl fmt::Debug for SocketAddr {
77    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78        match *self {
79            SocketAddr::Tcp(ref addr) => write!(f, "{:?}", addr),
80            #[cfg(unix)]
81            SocketAddr::Uds(ref addr) => write!(f, "{:?}", addr),
82        }
83    }
84}
85
86impl Listener {
87    pub(super) fn from_tcp(lst: net::TcpListener) -> Self {
88        let _ = lst.set_nonblocking(true);
89        Listener::Tcp(lst)
90    }
91
92    #[cfg(unix)]
93    pub(super) fn from_uds(lst: std::os::unix::net::UnixListener) -> Self {
94        let _ = lst.set_nonblocking(true);
95        Listener::Uds(lst)
96    }
97
98    pub(crate) fn local_addr(&self) -> SocketAddr {
99        match self {
100            Listener::Tcp(lst) => SocketAddr::Tcp(lst.local_addr().unwrap()),
101            #[cfg(unix)]
102            Listener::Uds(lst) => SocketAddr::Uds(lst.local_addr().unwrap()),
103        }
104    }
105
106    pub(crate) fn accept(&self) -> io::Result<Option<Stream>> {
107        match *self {
108            Listener::Tcp(ref lst) => {
109                lst.accept().map(|(stream, _)| Some(Stream::Tcp(stream)))
110            }
111            #[cfg(unix)]
112            Listener::Uds(ref lst) => {
113                lst.accept().map(|(stream, _)| Some(Stream::Uds(stream)))
114            }
115        }
116    }
117
118    pub(crate) fn remove_source(&self) {
119        match *self {
120            Listener::Tcp(_) => (),
121            #[cfg(unix)]
122            Listener::Uds(ref lst) => {
123                // cleanup file path
124                if let Ok(addr) = lst.local_addr() {
125                    if let Some(path) = addr.as_pathname() {
126                        let _ = std::fs::remove_file(path);
127                    }
128                }
129            }
130        }
131    }
132}
133
134#[cfg(unix)]
135mod listener_impl {
136    use super::*;
137    use std::os::fd::{AsFd, BorrowedFd};
138    use std::os::unix::io::{AsRawFd, RawFd};
139
140    impl AsFd for Listener {
141        fn as_fd(&self) -> BorrowedFd<'_> {
142            match *self {
143                Listener::Tcp(ref lst) => lst.as_fd(),
144                Listener::Uds(ref lst) => lst.as_fd(),
145            }
146        }
147    }
148
149    impl AsRawFd for Listener {
150        fn as_raw_fd(&self) -> RawFd {
151            match *self {
152                Listener::Tcp(ref lst) => lst.as_raw_fd(),
153                Listener::Uds(ref lst) => lst.as_raw_fd(),
154            }
155        }
156    }
157}
158
/// Socket-handle access for [`Listener`] on Windows targets.
#[cfg(windows)]
mod listener_impl {
    use super::*;
    use std::os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, RawSocket};

    impl AsSocket for Listener {
        /// Borrows the socket handle of the wrapped TCP listener.
        fn as_socket(&self) -> BorrowedSocket<'_> {
            // Only the TCP variant exists on Windows, so the pattern is irrefutable.
            let Listener::Tcp(tcp) = self;
            tcp.as_socket()
        }
    }

    impl AsRawSocket for Listener {
        /// Returns the raw socket handle of the wrapped TCP listener.
        fn as_raw_socket(&self) -> RawSocket {
            // Only the TCP variant exists on Windows, so the pattern is irrefutable.
            let Listener::Tcp(tcp) = self;
            tcp.as_raw_socket()
        }
    }
}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the `Debug`/`Display` output of a TCP `SocketAddr` and of a
    /// TCP `Listener` bound to an ephemeral loopback port.
    #[test]
    fn socket_addr() {
        use socket2::{Domain, SockAddr, Socket, Type};

        let addr = SocketAddr::Tcp("127.0.0.1:8080".parse().unwrap());
        assert!(format!("{:?}", addr).contains("127.0.0.1:8080"));
        assert_eq!(format!("{}", addr), "127.0.0.1:8080");

        let addr: net::SocketAddr = "127.0.0.1:0".parse().unwrap();
        let socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap();
        socket.set_reuse_address(true).unwrap();
        socket.bind(&SockAddr::from(addr)).unwrap();
        let lst = Listener::Tcp(net::TcpListener::from(socket));
        assert!(format!("{:?}", lst).contains("TcpListener"));
        assert!(format!("{}", lst).contains("127.0.0.1"));
    }

    /// Checks the `Debug`/`Display` output of Unix-domain addresses and
    /// listeners, then verifies that `remove_source` deletes the socket file
    /// so the test leaves nothing behind.
    #[test]
    #[cfg(unix)]
    fn uds() {
        use std::os::unix::net::UnixListener;

        const SOCK_PATH: &str = "/tmp/sock.xxxxx";

        // A stale file from an earlier run would make `bind` fail.
        let _ = std::fs::remove_file(SOCK_PATH);
        // Binding may be impossible in restricted environments; skip then.
        if let Ok(lst) = UnixListener::bind(SOCK_PATH) {
            let addr = lst.local_addr().expect("Couldn't get local address");
            let a = SocketAddr::Uds(addr);
            assert!(format!("{:?}", a).contains(SOCK_PATH));
            assert!(format!("{}", a).contains(SOCK_PATH));

            let lst = Listener::Uds(lst);
            assert!(format!("{:?}", lst).contains(SOCK_PATH));
            assert!(format!("{}", lst).contains(SOCK_PATH));

            // Clean up the socket file and check that the cleanup worked.
            lst.remove_source();
            assert!(!std::path::Path::new(SOCK_PATH).exists());
        }
    }
}