Skip to main content

pipa/http/
conn.rs

1use std::io::{self, ErrorKind, Read, Write};
2use std::net::{SocketAddr, TcpStream};
3use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
4use std::sync::Arc;
5
/// A client connection: either a bare TCP stream or a TLS session carried
/// over a TCP stream.
#[derive(Debug)]
pub enum Connection {
    /// Unencrypted connection; reads and writes go straight to the socket.
    Plain(TcpStream),
    /// TLS connection.
    Tls {
        /// rustls client session state (handshake progress, record buffers).
        tls: rustls::ClientConnection,
        /// Socket that the TLS records are read from and written to.
        stream: TcpStream,
    },
}
14
/// Outcome of one `Connection::tls_handshake_step` call: which socket
/// readiness the handshake is waiting on, or that it is finished.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IoHint {
    /// The handshake needs the socket to become readable.
    Read,
    /// The handshake needs the socket to become writable.
    Write,
    /// The handshake needs both readability and writability.
    ReadWrite,
    /// Nothing to wait for: the connection is plain TCP or the handshake is done.
    Ready,
}
22
23impl Connection {
24    pub fn connect_nonblocking(host: &str, port: u16) -> Result<TcpStream, String> {
25        let addr = format!("{host}:{port}");
26        let addrs: Vec<SocketAddr> = std::net::ToSocketAddrs::to_socket_addrs(&addr)
27            .map_err(|e| format!("dns resolve failed for {host}:{port}: {e}"))?
28            .collect();
29
30        for addr in addrs {
31            let domain = if addr.is_ipv6() {
32                libc::AF_INET6
33            } else {
34                libc::AF_INET
35            };
36            let sock = unsafe {
37                libc::socket(
38                    domain,
39                    libc::SOCK_STREAM | libc::SOCK_NONBLOCK | libc::SOCK_CLOEXEC,
40                    0,
41                )
42            };
43            if sock < 0 {
44                continue;
45            }
46
47            let (addr_ptr, addr_len) = socket_addr_to_raw(&addr);
48            let ret = unsafe { libc::connect(sock, addr_ptr, addr_len) };
49
50            if ret == 0 {
51                return Ok(unsafe { TcpStream::from_raw_fd(sock) });
52            }
53
54            let errno = unsafe { *libc::__errno_location() };
55            if errno == libc::EINPROGRESS {
56                return Ok(unsafe { TcpStream::from_raw_fd(sock) });
57            }
58
59            unsafe { libc::close(sock) };
60        }
61
62        Err(format!("connect to {host}:{port} failed"))
63    }
64
65    pub fn check_connect(stream: &TcpStream) -> Result<(), String> {
66        let mut err: i32 = 0;
67        let mut err_len: u32 = std::mem::size_of::<i32>() as u32;
68        let ret = unsafe {
69            libc::getsockopt(
70                stream.as_raw_fd(),
71                libc::SOL_SOCKET,
72                libc::SO_ERROR,
73                &mut err as *mut _ as *mut _,
74                &mut err_len,
75            )
76        };
77        if ret < 0 {
78            return Err("getsockopt failed".into());
79        }
80        if err != 0 {
81            return Err(format!("connect failed: errno {err}"));
82        }
83        Ok(())
84    }
85
86    pub fn start_tls(
87        host: &str,
88        stream: TcpStream,
89        extra_roots: &[Vec<u8>],
90    ) -> Result<Self, String> {
91        let mut root_certs =
92            rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
93        for cert_der in extra_roots {
94            root_certs
95                .add(cert_der.clone().into())
96                .map_err(|e| format!("add root cert failed: {e}"))?;
97        }
98        let config = rustls::ClientConfig::builder()
99            .with_root_certificates(root_certs)
100            .with_no_client_auth();
101        let server_name = rustls::pki_types::ServerName::try_from(host)
102            .map_err(|e| format!("invalid server name: {e}"))?
103            .to_owned();
104        let tls_conn = rustls::ClientConnection::new(Arc::new(config), server_name)
105            .map_err(|e| format!("tls init failed: {e}"))?;
106        Ok(Connection::Tls {
107            tls: tls_conn,
108            stream,
109        })
110    }
111
112    pub fn tls_handshake_step(&mut self) -> Result<IoHint, String> {
113        match self {
114            Connection::Plain(_) => Ok(IoHint::Ready),
115            Connection::Tls { tls, stream } => {
116                if !tls.is_handshaking() {
117                    return Ok(IoHint::Ready);
118                }
119
120                let mut need_read = false;
121                let mut need_write = false;
122
123                if tls.wants_read() {
124                    match tls.read_tls(stream) {
125                        Ok(_) => {
126                            tls.process_new_packets()
127                                .map_err(|e| format!("tls process error: {e}"))?;
128                        }
129                        Err(e) if e.kind() == ErrorKind::WouldBlock => {
130                            need_read = true;
131                        }
132                        Err(e) => return Err(format!("tls read error: {e}")),
133                    }
134                }
135
136                if tls.wants_write() {
137                    match tls.write_tls(stream) {
138                        Ok(_) => {}
139                        Err(e) if e.kind() == ErrorKind::WouldBlock => {
140                            need_write = true;
141                        }
142                        Err(e) => return Err(format!("tls write error: {e}")),
143                    }
144                }
145
146                if !tls.is_handshaking() {
147                    Ok(IoHint::Ready)
148                } else {
149                    match (need_read, need_write) {
150                        (true, true) => Ok(IoHint::ReadWrite),
151                        (true, false) => Ok(IoHint::Read),
152                        (false, true) => Ok(IoHint::Write),
153                        (false, false) => {
154                            if tls.wants_read() && tls.wants_write() {
155                                Ok(IoHint::ReadWrite)
156                            } else if tls.wants_read() {
157                                Ok(IoHint::Read)
158                            } else {
159                                Ok(IoHint::Write)
160                            }
161                        }
162                    }
163                }
164            }
165        }
166    }
167
168    pub fn tls_wants_read(&self) -> bool {
169        match self {
170            Connection::Plain(_) => false,
171            Connection::Tls { tls, .. } => tls.wants_read() || tls.is_handshaking(),
172        }
173    }
174
175    pub fn tls_wants_write(&self) -> bool {
176        match self {
177            Connection::Plain(_) => false,
178            Connection::Tls { tls, .. } => tls.wants_write() || tls.is_handshaking(),
179        }
180    }
181
182    pub fn set_nonblocking(&self, nonblocking: bool) -> Result<(), String> {
183        match self {
184            Connection::Plain(stream) => stream
185                .set_nonblocking(nonblocking)
186                .map_err(|e| format!("set_nonblocking failed: {e}")),
187            Connection::Tls { stream, .. } => stream
188                .set_nonblocking(nonblocking)
189                .map_err(|e| format!("set_nonblocking failed: {e}")),
190        }
191    }
192
193    pub fn raw_fd(&self) -> RawFd {
194        match self {
195            Connection::Plain(stream) => stream.as_raw_fd(),
196            Connection::Tls { stream, .. } => stream.as_raw_fd(),
197        }
198    }
199
200    pub fn is_tls(&self) -> bool {
201        matches!(self, Connection::Tls { .. })
202    }
203
204    pub fn set_read_timeout(&self, dur: Option<std::time::Duration>) -> Result<(), String> {
205        let stream = match self {
206            Connection::Plain(s) => s,
207            Connection::Tls { stream, .. } => stream,
208        };
209        stream
210            .set_read_timeout(dur)
211            .map_err(|e| format!("set_read_timeout failed: {e}"))
212    }
213}
214
215impl Read for Connection {
216    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
217        match self {
218            Connection::Plain(stream) => stream.read(buf),
219            Connection::Tls { tls, stream } => loop {
220                match tls.read_tls(stream) {
221                    Ok(0) => {
222                        tls.process_new_packets()
223                            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
224                        return tls.reader().read(buf);
225                    }
226                    Ok(_) => {
227                        tls.process_new_packets()
228                            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
229                        match tls.reader().read(buf) {
230                            Ok(n) => return Ok(n),
231                            Err(e) if e.kind() == ErrorKind::WouldBlock => continue,
232                            Err(e) => return Err(e),
233                        }
234                    }
235                    Err(e) if e.kind() == ErrorKind::WouldBlock => match tls.reader().read(buf) {
236                        Ok(n) => return Ok(n),
237                        Err(e2) if e2.kind() == ErrorKind::WouldBlock => return Err(e),
238                        Err(e2) => return Err(e2),
239                    },
240                    Err(e) => return Err(e),
241                }
242            },
243        }
244    }
245}
246
247impl Write for Connection {
248    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
249        match self {
250            Connection::Plain(stream) => stream.write(buf),
251            Connection::Tls { tls, stream } => {
252                let n = tls.writer().write(buf)?;
253                let _ = tls.write_tls(stream);
254                Ok(n)
255            }
256        }
257    }
258
259    fn flush(&mut self) -> io::Result<()> {
260        match self {
261            Connection::Plain(stream) => stream.flush(),
262            Connection::Tls { stream, .. } => stream.flush(),
263        }
264    }
265}
266
267impl AsRawFd for Connection {
268    fn as_raw_fd(&self) -> RawFd {
269        self.raw_fd()
270    }
271}
272
273fn socket_addr_to_raw(addr: &SocketAddr) -> (*const libc::sockaddr, u32) {
274    match addr {
275        SocketAddr::V4(v4) => {
276            let raw: libc::sockaddr_in = libc::sockaddr_in {
277                sin_family: libc::AF_INET as u16,
278                sin_port: v4.port().to_be(),
279                sin_addr: libc::in_addr {
280                    s_addr: u32::from_ne_bytes(v4.ip().octets()),
281                },
282                sin_zero: [0; 8],
283            };
284            (
285                &raw as *const _ as *const libc::sockaddr,
286                std::mem::size_of::<libc::sockaddr_in>() as u32,
287            )
288        }
289        SocketAddr::V6(v6) => {
290            let raw = libc::sockaddr_in6 {
291                sin6_family: libc::AF_INET6 as u16,
292                sin6_port: v6.port().to_be(),
293                sin6_flowinfo: v6.flowinfo(),
294                sin6_addr: libc::in6_addr {
295                    s6_addr: v6.ip().octets(),
296                },
297                sin6_scope_id: v6.scope_id(),
298            };
299            (
300                &raw as *const _ as *const libc::sockaddr,
301                std::mem::size_of::<libc::sockaddr_in6>() as u32,
302            )
303        }
304    }
305}