Skip to main content

pipa/http/
conn.rs

1use std::io::{self, Read, Write};
2use std::net::TcpStream;
3use std::os::unix::io::{AsRawFd, RawFd};
4use std::sync::mpsc;
5use std::thread;
6use std::time::Duration;
7
/// A client connection carried over either plain TCP or TLS-over-TCP.
///
/// Both variants implement `Read`/`Write` through the impls on this type, so
/// callers can treat them uniformly.
pub enum Connection {
    /// Unencrypted TCP stream.
    Plain(TcpStream),
    /// rustls client session layered on top of a TCP stream.
    Tls {
        /// Stream that owns both the rustls client state and the underlying
        /// socket (reachable as `tls.sock`).
        tls: rustls::StreamOwned<rustls::ClientConnection, TcpStream>,
    },
}
14
15impl Connection {
16    pub fn connect_async(
17        host: String,
18        port: u16,
19        use_tls: bool,
20    ) -> Result<mpsc::Receiver<Result<Self, String>>, String> {
21        Self::connect_async_with_roots(host, port, use_tls, Vec::new())
22    }
23
24    pub fn connect_async_with_roots(
25        host: String,
26        port: u16,
27        use_tls: bool,
28        extra_roots: Vec<Vec<u8>>,
29    ) -> Result<mpsc::Receiver<Result<Self, String>>, String> {
30        let (tx, rx) = mpsc::channel();
31        thread::spawn(move || {
32            let result = Self::connect_blocking(&host, port, use_tls, &extra_roots);
33            let _ = tx.send(result);
34        });
35        Ok(rx)
36    }
37
38    fn connect_blocking(
39        host: &str,
40        port: u16,
41        use_tls: bool,
42        extra_roots: &[Vec<u8>],
43    ) -> Result<Self, String> {
44        let addr = format!("{host}:{port}");
45        let stream =
46            TcpStream::connect(&addr).map_err(|e| format!("connect to {addr} failed: {e}"))?;
47        stream
48            .set_read_timeout(None)
49            .map_err(|e| format!("set_read_timeout failed: {e}"))?;
50        stream
51            .set_write_timeout(None)
52            .map_err(|e| format!("set_write_timeout failed: {e}"))?;
53        stream
54            .set_nonblocking(true)
55            .map_err(|e| format!("set_nonblocking failed: {e}"))?;
56
57        if use_tls {
58            Self::wrap_tls(host, stream, extra_roots)
59        } else {
60            Ok(Connection::Plain(stream))
61        }
62    }
63
64    fn wrap_tls(host: &str, stream: TcpStream, extra_roots: &[Vec<u8>]) -> Result<Self, String> {
65        let mut root_certs =
66            rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
67        for cert_der in extra_roots {
68            root_certs
69                .add(cert_der.clone().into())
70                .map_err(|e| format!("add root cert failed: {e}"))?;
71        }
72        let config = rustls::ClientConfig::builder()
73            .with_root_certificates(root_certs)
74            .with_no_client_auth();
75        let server_name = rustls::pki_types::ServerName::try_from(host)
76            .map_err(|e| format!("invalid server name: {e}"))?
77            .to_owned();
78        let tls_conn = rustls::ClientConnection::new(std::sync::Arc::new(config), server_name)
79            .map_err(|e| format!("tls handshake failed: {e}"))?;
80        let tls = rustls::StreamOwned::new(tls_conn, stream);
81        Ok(Connection::Tls { tls })
82    }
83
84    pub fn set_nonblocking(&self, nonblocking: bool) -> Result<(), String> {
85        match self {
86            Connection::Plain(stream) => stream
87                .set_nonblocking(nonblocking)
88                .map_err(|e| format!("set_nonblocking failed: {e}")),
89            Connection::Tls { tls } => tls
90                .sock
91                .set_nonblocking(nonblocking)
92                .map_err(|e| format!("set_nonblocking failed: {e}")),
93        }
94    }
95
96    pub fn raw_fd(&self) -> RawFd {
97        match self {
98            Connection::Plain(stream) => stream.as_raw_fd(),
99            Connection::Tls { tls } => tls.sock.as_raw_fd(),
100        }
101    }
102
103    pub fn is_tls(&self) -> bool {
104        matches!(self, Connection::Tls { .. })
105    }
106
107    pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<(), String> {
108        match self {
109            Connection::Plain(stream) => stream
110                .set_read_timeout(dur)
111                .map_err(|e| format!("set_read_timeout failed: {e}")),
112            Connection::Tls { tls } => tls
113                .sock
114                .set_read_timeout(dur)
115                .map_err(|e| format!("set_read_timeout failed: {e}")),
116        }
117    }
118
119    pub fn try_clone(&self) -> Result<Self, String> {
120        match self {
121            Connection::Plain(stream) => stream
122                .try_clone()
123                .map(Connection::Plain)
124                .map_err(|e| format!("try_clone failed: {e}")),
125            Connection::Tls { .. } => Err("cannot clone TLS connection".into()),
126        }
127    }
128}
129
130impl Read for Connection {
131    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
132        match self {
133            Connection::Plain(stream) => stream.read(buf),
134            Connection::Tls { tls } => tls.read(buf),
135        }
136    }
137}
138
139impl Write for Connection {
140    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
141        match self {
142            Connection::Plain(stream) => stream.write(buf),
143            Connection::Tls { tls } => tls.write(buf),
144        }
145    }
146
147    fn flush(&mut self) -> io::Result<()> {
148        match self {
149            Connection::Plain(stream) => stream.flush(),
150            Connection::Tls { tls } => tls.sock.flush(),
151        }
152    }
153}
154
155impl AsRawFd for Connection {
156    fn as_raw_fd(&self) -> RawFd {
157        self.raw_fd()
158    }
159}