1use libc::{
18 __errno_location,
19 __u32,
20 c_int,
21 c_void,
22 close as libc_close,
23 connect as libc_connect,
24 in6_addr,
25 in_addr,
26 setsockopt,
27 sockaddr,
28 sockaddr_in,
29 sockaddr_in6,
30 socket,
31 AF_INET,
32 AF_INET6,
33 EINPROGRESS,
34 SOCK_STREAM,
35 SOL_TCP,
36 TCP_QUEUE_SEQ,
38 TCP_REPAIR,
39 TCP_REPAIR_QUEUE,
40};
41
/// Queue selector written to the `TCP_REPAIR_QUEUE` socket option so that the
/// following `TCP_QUEUE_SEQ` write targets the send queue.
///
/// NOTE(review): the value 2 presumably mirrors `TCP_SEND_QUEUE` from the
/// kernel's `linux/tcp.h` — confirm against the target kernel headers.
const TCP_SEND_QUEUE: u32 = 2;
45
46use std::io::{Error, ErrorKind, Result};
47use std::mem::size_of;
48use std::net::SocketAddr;
49use std::net::ToSocketAddrs;
50use std::os::unix::io::FromRawFd;
51
52use log::debug;
53
54#[cfg(feature = "async")]
55use async_io;
56#[cfg(feature = "async")]
57use async_std;
58
/// Address family a caller can restrict connection attempts to.
///
/// Passed as `Some(..)` to the connect functions to skip resolved addresses of
/// the other family; `None` accepts both.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Family {
    /// Only IPv4 addresses are acceptable.
    V4,
    /// Only IPv6 addresses are acceptable.
    V6,
}
64
65fn family_matches(socket_addr: &SocketAddr, family: Option<Family>) -> bool {
66 if let Some(f) = family {
67 if f == Family::V4 && !socket_addr.is_ipv4() {
68 return false;
69 }
70 if f == Family::V6 && !socket_addr.is_ipv6() {
71 return false;
72 }
73 }
74 true
75}
76
77pub fn connect<A: ToSocketAddrs>(
78 sequence_no: u32,
79 addr: A,
80 force_family: Option<Family>,
81) -> Result<std::net::TcpStream> {
82 unsafe {
83 let socket_addrs = addr.to_socket_addrs()?;
85
86 let mut maybe_err = None;
88 for socket_addr in socket_addrs {
89 if !family_matches(&socket_addr, force_family) {
90 debug!("skipping {}, not of requested family", socket_addr);
91 continue;
92 }
93 debug!("Trying to connect to {}", socket_addr);
94
95 let sock = create_socket(family_of(&socket_addr), sequence_no)?;
96
97 match connect_socket(sock, socket_addr, true) {
98 Ok(()) => {
99 debug!("Connected to {}", socket_addr);
100 return Ok(std::net::TcpStream::from_raw_fd(sock));
102 }
103 Err(e) => maybe_err = Some(e),
104 }
105
106 libc_close(sock);
107 }
108
109 if let Some(e) = maybe_err {
110 return Err(e);
111 }
112 }
113 Err(Error::new(
114 ErrorKind::AddrNotAvailable,
115 "No address entries for hostname",
116 ))
117}
118
/// Async counterpart of `connect`: establishes a TCP connection on a socket
/// prepared by `create_socket` with the given `sequence_no` and returns it as
/// an [`async_std::net::TcpStream`].
///
/// Every resolved address is tried in order; addresses that do not satisfy
/// `force_family` are skipped. After the connect call the socket is wrapped in
/// `async_io::Async`, awaited for writability and checked for a pending socket
/// error before it is handed back.
///
/// # Errors
///
/// * Resolution failures, socket-creation failures and failures of
///   `take_error` / `into_inner` are returned immediately.
/// * If every attempt fails, the error of the last attempt is returned.
/// * If no address was attempted at all, an [`ErrorKind::AddrNotAvailable`]
///   error is returned.
///
/// NOTE(review): `connect_socket` is invoked with `blocking = true` on a socket
/// that `create_socket` opens without a non-blocking flag, so the connect call
/// itself appears to block the executor thread — confirm whether that is
/// intended.
#[cfg(feature = "async")]
pub async fn connect_async<A: async_std::net::ToSocketAddrs>(
    sequence_no: u32,
    socket_addr: A,
    force_family: Option<Family>,
) -> Result<async_std::net::TcpStream> {
    let socket_addrs = socket_addr.to_socket_addrs().await?;

    let mut maybe_err = None;
    for socket_addr in socket_addrs {
        if !family_matches(&socket_addr, force_family) {
            debug!("skipping {}, not of requested family", socket_addr);
            continue;
        }
        debug!("Trying to connect to {}", socket_addr);

        // SAFETY: create_socket only issues socket/setsockopt calls on a
        // descriptor it has just created itself.
        let sock = unsafe { create_socket(family_of(&socket_addr), sequence_no) }?;

        // SAFETY: `sock` is the open socket obtained above and is still owned here.
        if let Err(e) = unsafe { connect_socket(sock, socket_addr, true) } {
            maybe_err = Some(e);
            // SAFETY: the descriptor has not been handed to any owner yet, so
            // this is the only place that closes it.
            unsafe {
                libc_close(sock);
            }
            continue;
        }

        // From here on the stream owns the descriptor: dropping it closes the
        // socket, so none of the failure paths below may call libc_close again
        // (doing so would close the same fd number twice and could hit an
        // unrelated descriptor opened in between).
        // SAFETY: `sock` is a valid socket and ownership moves to the stream.
        let std_stream = unsafe { std::net::TcpStream::from_raw_fd(sock) };

        let stream = match async_io::Async::new(std_stream) {
            Ok(s) => s,
            Err(e) => {
                maybe_err = Some(e);
                continue;
            }
        };

        if let Err(e) = stream.writable().await {
            maybe_err = Some(e);
            continue;
        }

        // A writable socket may still carry a pending connection error.
        match stream.get_ref().take_error()? {
            None => {
                debug!("Connected to {}", socket_addr);
                return Ok(stream.into_inner()?.into());
            }
            Some(e) => maybe_err = Some(e),
        }
    }

    if let Some(e) = maybe_err {
        return Err(e);
    }

    Err(Error::new(
        ErrorKind::AddrNotAvailable,
        "No address entries for hostname",
    ))
}
183
184unsafe fn connect_socket(
185 sock: c_int,
186 socket_addr: std::net::SocketAddr,
187 blocking: bool,
188) -> Result<()> {
189 match socket_addr {
190 SocketAddr::V4(v4addr) => {
191 let octets = v4addr.ip().octets();
192 let u32_addr: u32 = (octets[0] as u32)
193 | (octets[1] as u32) << 8
194 | (octets[2] as u32) << 16
195 | (octets[3] as u32) << 24;
196 let saddr = sockaddr_in {
197 sin_family: AF_INET as u16,
198 sin_port: v4addr.port().to_be(),
199 sin_addr: in_addr { s_addr: u32_addr },
200 sin_zero: [0; 8],
201 };
202 let result = libc_connect(
203 sock,
204 &saddr as *const sockaddr_in as *const sockaddr,
205 size_of::<sockaddr_in>() as u32,
206 );
207 if result < 0 && (blocking || (*__errno_location()) != EINPROGRESS) {
208 return Err(Error::last_os_error());
209 }
210 }
211 SocketAddr::V6(v6addr) => {
212 let saddr = sockaddr_in6 {
213 sin6_family: AF_INET6 as u16,
214 sin6_port: v6addr.port().to_be(),
215 sin6_flowinfo: 0,
216 sin6_addr: in6_addr {
217 s6_addr: v6addr.ip().octets(),
218 },
219 sin6_scope_id: 0,
220 };
221 let result = libc_connect(
222 sock,
223 &saddr as *const sockaddr_in6 as *const sockaddr,
224 size_of::<sockaddr_in6>() as u32,
225 );
226 if result < 0 && (blocking || (*__errno_location()) != EINPROGRESS) {
227 return Err(Error::last_os_error());
228 }
229 }
230 }
231
232 Ok(())
233}
234
235unsafe fn sso_tcp_wrapper(sock: c_int, cmd: c_int, data: u32) -> Result<()> {
236 let dataptr = &data as *const __u32 as *const c_void;
237 if setsockopt(sock, SOL_TCP, cmd, dataptr, 4) < 0 {
238 return Err(Error::last_os_error());
239 }
240 Ok(())
241}
242
243unsafe fn create_socket(family: c_int, sequence_no: u32) -> Result<c_int> {
244 let sock: c_int = socket(family, SOCK_STREAM, 0);
246 if sock < 0 {
247 return Err(Error::last_os_error());
248 }
249
250 sso_tcp_wrapper(sock, TCP_REPAIR, 1)?;
252 sso_tcp_wrapper(sock, TCP_REPAIR_QUEUE, TCP_SEND_QUEUE)?;
254 sso_tcp_wrapper(sock, TCP_QUEUE_SEQ, sequence_no)?;
256 sso_tcp_wrapper(sock, TCP_REPAIR, 0)?;
258
259 Ok(sock)
260}
261
262fn family_of(socket_addr: &std::net::SocketAddr) -> c_int {
263 match socket_addr {
264 SocketAddr::V4(_) => AF_INET,
265 SocketAddr::V6(_) => AF_INET6,
266 }
267}