http_req/
tls.rs

1//!secure connection over TLS
2
3use crate::error::Error as HttpError;
4
5#[cfg(feature = "wasmedge_rustls")]
6use std::io;
7#[cfg(not(feature = "wasmedge_rustls"))]
8use std::{
9    fs::File,
10    io::{self, BufReader},
11    path::Path,
12};
13
14#[cfg(feature = "native-tls")]
15use std::io::prelude::*;
16
17#[cfg(feature = "rust-tls")]
18use crate::error::ParseErr;
19
20#[cfg(not(any(
21    feature = "native-tls",
22    feature = "rust-tls",
23    feature = "wasmedge_rustls"
24)))]
25compile_error!("one of the `native-tls` or `rust-tls` features must be enabled");
26
27///wrapper around TLS Stream,
28///depends on selected TLS library
29pub struct Conn<S: io::Read + io::Write> {
30    #[cfg(feature = "native-tls")]
31    stream: native_tls::TlsStream<S>,
32
33    #[cfg(feature = "rust-tls")]
34    stream: rustls::StreamOwned<rustls::ClientSession, S>,
35    #[cfg(feature = "wasmedge_rustls")]
36    stream: wasmedge_rustls_api::stream::StreamOwned<wasmedge_rustls_api::TlsClientCodec, S>,
37}
38
39impl<S: io::Read + io::Write> io::Read for Conn<S> {
40    fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
41        let len = self.stream.read(buf);
42
43        #[cfg(any(feature = "rust-tls", feature = "wasmedge_rustls"))]
44        {
45            // TODO: this api returns ConnectionAborted with a "..CloseNotify.." string.
46            // TODO: we should work out if self.stream.sess exposes enough information
47            // TODO: to not read in this situation, and return EOF directly.
48            // TODO: c.f. the checks in the implementation. connection_at_eof() doesn't
49            // TODO: seem to be exposed. The implementation:
50            // TODO: https://github.com/ctz/rustls/blob/f93c325ce58f2f1e02f09bcae6c48ad3f7bde542/src/session.rs#L789-L792
51            if let Err(ref e) = len {
52                if io::ErrorKind::ConnectionAborted == e.kind() {
53                    return Ok(0);
54                }
55            }
56        }
57
58        len
59    }
60}
61
62impl<S: io::Read + io::Write> io::Write for Conn<S> {
63    fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
64        self.stream.write(buf)
65    }
66    fn flush(&mut self) -> Result<(), io::Error> {
67        self.stream.flush()
68    }
69}
70
/// Client configuration.
///
/// Which data is stored depends on the enabled TLS backend feature.
pub struct Config {
    /// `rust-tls` backend: reference-counted rustls client configuration.
    #[cfg(feature = "rust-tls")]
    client_config: std::sync::Arc<rustls::ClientConfig>,

    /// `wasmedge_rustls` backend: reference-counted client configuration.
    #[cfg(feature = "wasmedge_rustls")]
    client_config: std::sync::Arc<wasmedge_rustls_api::ClientConfig>,

    /// `native-tls` backend: additional root certificates that `connect`
    /// registers with the connector builder.
    #[cfg(feature = "native-tls")]
    extra_root_certs: Vec<native_tls::Certificate>,
}
80
81impl Default for Config {
82    #[cfg(feature = "native-tls")]
83    fn default() -> Self {
84        Config {
85            extra_root_certs: vec![],
86        }
87    }
88
89    #[cfg(feature = "rust-tls")]
90    fn default() -> Self {
91        let mut config = rustls::ClientConfig::new();
92        config
93            .root_store
94            .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
95
96        Config {
97            client_config: std::sync::Arc::new(config),
98        }
99    }
100
101    #[cfg(feature = "wasmedge_rustls")]
102    fn default() -> Self {
103        Config {
104            client_config: std::sync::Arc::new(Default::default()),
105        }
106    }
107}
108
109impl Config {
110    #[cfg(feature = "native-tls")]
111    pub fn add_root_cert_file_pem(&mut self, file_path: &Path) -> Result<&mut Self, HttpError> {
112        let f = File::open(file_path)?;
113        let f = BufReader::new(f);
114        let mut pem_crt = vec![];
115        for line in f.lines() {
116            let line = line?;
117            let is_end_cert = line.contains("-----END");
118            pem_crt.append(&mut line.into_bytes());
119            pem_crt.push(b'\n');
120            if is_end_cert {
121                let crt = native_tls::Certificate::from_pem(&pem_crt)?;
122                self.extra_root_certs.push(crt);
123                pem_crt.clear();
124            }
125        }
126        Ok(self)
127    }
128
129    #[cfg(feature = "native-tls")]
130    pub fn connect<H, S>(&self, hostname: H, stream: S) -> Result<Conn<S>, HttpError>
131    where
132        H: AsRef<str>,
133        S: io::Read + io::Write,
134    {
135        let mut connector_builder = native_tls::TlsConnector::builder();
136        for crt in self.extra_root_certs.iter() {
137            connector_builder.add_root_certificate((*crt).clone());
138        }
139        let connector = connector_builder.build()?;
140        let stream = connector.connect(hostname.as_ref(), stream)?;
141
142        Ok(Conn { stream })
143    }
144
145    #[cfg(feature = "rust-tls")]
146    pub fn add_root_cert_file_pem(&mut self, file_path: &Path) -> Result<&mut Self, HttpError> {
147        let f = File::open(file_path)?;
148        let mut f = BufReader::new(f);
149        let config = std::sync::Arc::make_mut(&mut self.client_config);
150        let _ = config
151            .root_store
152            .add_pem_file(&mut f)
153            .map_err(|_| HttpError::from(ParseErr::Invalid))?;
154        Ok(self)
155    }
156
157    #[cfg(feature = "rust-tls")]
158    pub fn connect<H, S>(&self, hostname: H, stream: S) -> Result<Conn<S>, HttpError>
159    where
160        H: AsRef<str>,
161        S: io::Read + io::Write,
162    {
163        use rustls::{ClientSession, StreamOwned};
164
165        let session = ClientSession::new(
166            &self.client_config,
167            webpki::DNSNameRef::try_from_ascii_str(hostname.as_ref())
168                .map_err(|_| HttpError::Tls)?,
169        );
170        let stream = StreamOwned::new(session, stream);
171
172        Ok(Conn { stream })
173    }
174
175    #[cfg(feature = "wasmedge_rustls")]
176    pub fn connect<H, S>(&self, hostname: H, stream: S) -> Result<Conn<S>, HttpError>
177    where
178        H: AsRef<str>,
179        S: io::Read + io::Write,
180    {
181        use wasmedge_rustls_api::stream::StreamOwned;
182
183        let session = self
184            .client_config
185            .new_codec(hostname)
186            .map_err(|_| HttpError::Tls)?;
187
188        let stream = StreamOwned::new(session, stream);
189
190        Ok(Conn { stream })
191    }
192}