1use std::collections::HashMap;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6use socket2::Domain;
7use socket2::Protocol;
8use socket2::Type;
9
10static CONNS: std::sync::OnceLock<std::sync::Mutex<Connections>> =
13 std::sync::OnceLock::new();
14
15#[derive(Default)]
17struct Connections {
18 tcp: HashMap<SocketAddr, Arc<TcpConnection>>,
19}
20
21pub struct TcpConnection {
24 #[cfg(unix)]
26 sock: std::os::fd::OwnedFd,
27 #[cfg(not(unix))]
28 sock: std::os::windows::io::OwnedSocket,
29 key: SocketAddr,
30}
31
32impl TcpConnection {
33 pub fn start(key: SocketAddr) -> std::io::Result<Self> {
35 let listener = bind_socket_and_listen(key, false)?;
36 let sock = listener.into();
37
38 Ok(Self { sock, key })
39 }
40
41 fn listener(&self) -> std::io::Result<tokio::net::TcpListener> {
42 let listener = std::net::TcpListener::from(self.sock.try_clone()?);
43 let listener = tokio::net::TcpListener::from_std(listener)?;
44 Ok(listener)
45 }
46}
47
48pub struct TcpListener {
50 listener: Option<tokio::net::TcpListener>,
51 conn: Option<Arc<TcpConnection>>,
52}
53
/// `true` when compiling for Linux or Android. `TcpListener::bind` uses it to
/// decide whether a `reuse_port` request goes through the shared,
/// load-balanced socket path.
const REUSE_PORT_LOAD_BALANCES: bool =
    cfg!(target_os = "linux") || cfg!(target_os = "android");
57
58impl TcpListener {
59 pub fn bind(
79 socket_addr: SocketAddr,
80 reuse_port: bool,
81 ) -> std::io::Result<Self> {
82 if REUSE_PORT_LOAD_BALANCES && reuse_port {
83 Self::bind_load_balanced(socket_addr)
84 } else {
85 Self::bind_direct(socket_addr, reuse_port)
86 }
87 }
88
89 pub fn bind_direct(
92 socket_addr: SocketAddr,
93 reuse_port: bool,
94 ) -> std::io::Result<Self> {
95 let listener = bind_socket_and_listen(socket_addr, reuse_port)?;
97 Ok(Self {
98 listener: Some(tokio::net::TcpListener::from_std(listener)?),
99 conn: None,
100 })
101 }
102
103 pub fn bind_load_balanced(socket_addr: SocketAddr) -> std::io::Result<Self> {
105 let tcp = &mut CONNS.get_or_init(Default::default).lock().unwrap().tcp;
106 if let Some(conn) = tcp.get(&socket_addr) {
107 let listener = Some(conn.listener()?);
108 return Ok(Self {
109 listener,
110 conn: Some(conn.clone()),
111 });
112 }
113 let conn = Arc::new(TcpConnection::start(socket_addr)?);
114 let listener = Some(conn.listener()?);
115 tcp.insert(socket_addr, conn.clone());
116 Ok(Self {
117 listener,
118 conn: Some(conn),
119 })
120 }
121
122 pub async fn accept(
123 &self,
124 ) -> std::io::Result<(tokio::net::TcpStream, SocketAddr)> {
125 let (tcp, addr) = self.listener.as_ref().unwrap().accept().await?;
126 Ok((tcp, addr))
127 }
128
129 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
130 self.listener.as_ref().unwrap().local_addr()
131 }
132}
133
134impl Drop for TcpListener {
135 fn drop(&mut self) {
136 if let Some(conn) = self.conn.take() {
138 let mut tcp = CONNS.get().unwrap().lock().unwrap();
139 if Arc::strong_count(&conn) == 2 {
140 tcp.tcp.remove(&conn.key);
141 debug_assert_eq!(Arc::strong_count(&conn), 1);
143 drop(conn);
144 }
145 }
146 }
147}
148
149#[allow(unused_variables)]
151fn bind_socket_and_listen(
152 socket_addr: SocketAddr,
153 reuse_port: bool,
154) -> Result<std::net::TcpListener, std::io::Error> {
155 let socket = if socket_addr.is_ipv4() {
156 socket2::Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?
157 } else {
158 socket2::Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP))?
159 };
160 #[cfg(not(windows))]
161 if REUSE_PORT_LOAD_BALANCES && reuse_port {
162 socket.set_reuse_port(true)?;
163 }
164 #[cfg(not(windows))]
165 socket.set_reuse_address(true)?;
171 socket.set_nonblocking(true)?;
172 socket.bind(&socket_addr.into())?;
173 socket.listen(511)?;
175 let listener = socket.into();
176 Ok(listener)
177}