1use crate::error::Error as HttpError;
4
5#[cfg(feature = "wasmedge_rustls")]
6use std::io;
7#[cfg(not(feature = "wasmedge_rustls"))]
8use std::{
9 fs::File,
10 io::{self, BufReader},
11 path::Path,
12};
13
14#[cfg(feature = "native-tls")]
15use std::io::prelude::*;
16
17#[cfg(feature = "rust-tls")]
18use crate::error::ParseErr;
19
20#[cfg(not(any(
21 feature = "native-tls",
22 feature = "rust-tls",
23 feature = "wasmedge_rustls"
24)))]
25compile_error!("one of the `native-tls` or `rust-tls` features must be enabled");
26
27pub struct Conn<S: io::Read + io::Write> {
30 #[cfg(feature = "native-tls")]
31 stream: native_tls::TlsStream<S>,
32
33 #[cfg(feature = "rust-tls")]
34 stream: rustls::StreamOwned<rustls::ClientSession, S>,
35 #[cfg(feature = "wasmedge_rustls")]
36 stream: wasmedge_rustls_api::stream::StreamOwned<wasmedge_rustls_api::TlsClientCodec, S>,
37}
38
39impl<S: io::Read + io::Write> io::Read for Conn<S> {
40 fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
41 let len = self.stream.read(buf);
42
43 #[cfg(any(feature = "rust-tls", feature = "wasmedge_rustls"))]
44 {
45 if let Err(ref e) = len {
52 if io::ErrorKind::ConnectionAborted == e.kind() {
53 return Ok(0);
54 }
55 }
56 }
57
58 len
59 }
60}
61
62impl<S: io::Read + io::Write> io::Write for Conn<S> {
63 fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
64 self.stream.write(buf)
65 }
66 fn flush(&mut self) -> Result<(), io::Error> {
67 self.stream.flush()
68 }
69}
70
/// TLS client configuration used by [`Config::connect`] to wrap a transport in a [`Conn`].
///
/// The stored settings depend on the backend feature that is enabled.
pub struct Config {
    /// `wasmedge_rustls`: shared client configuration from which per-connection codecs are made.
    #[cfg(feature = "wasmedge_rustls")]
    client_config: std::sync::Arc<wasmedge_rustls_api::ClientConfig>,

    /// `rust-tls`: shared rustls client configuration (including its root store).
    #[cfg(feature = "rust-tls")]
    client_config: std::sync::Arc<rustls::ClientConfig>,

    /// `native-tls`: additional root certificates registered on every connector that is built.
    #[cfg(feature = "native-tls")]
    extra_root_certs: Vec<native_tls::Certificate>,
}
80
81impl Default for Config {
82 #[cfg(feature = "native-tls")]
83 fn default() -> Self {
84 Config {
85 extra_root_certs: vec![],
86 }
87 }
88
89 #[cfg(feature = "rust-tls")]
90 fn default() -> Self {
91 let mut config = rustls::ClientConfig::new();
92 config
93 .root_store
94 .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
95
96 Config {
97 client_config: std::sync::Arc::new(config),
98 }
99 }
100
101 #[cfg(feature = "wasmedge_rustls")]
102 fn default() -> Self {
103 Config {
104 client_config: std::sync::Arc::new(Default::default()),
105 }
106 }
107}
108
109impl Config {
110 #[cfg(feature = "native-tls")]
111 pub fn add_root_cert_file_pem(&mut self, file_path: &Path) -> Result<&mut Self, HttpError> {
112 let f = File::open(file_path)?;
113 let f = BufReader::new(f);
114 let mut pem_crt = vec![];
115 for line in f.lines() {
116 let line = line?;
117 let is_end_cert = line.contains("-----END");
118 pem_crt.append(&mut line.into_bytes());
119 pem_crt.push(b'\n');
120 if is_end_cert {
121 let crt = native_tls::Certificate::from_pem(&pem_crt)?;
122 self.extra_root_certs.push(crt);
123 pem_crt.clear();
124 }
125 }
126 Ok(self)
127 }
128
129 #[cfg(feature = "native-tls")]
130 pub fn connect<H, S>(&self, hostname: H, stream: S) -> Result<Conn<S>, HttpError>
131 where
132 H: AsRef<str>,
133 S: io::Read + io::Write,
134 {
135 let mut connector_builder = native_tls::TlsConnector::builder();
136 for crt in self.extra_root_certs.iter() {
137 connector_builder.add_root_certificate((*crt).clone());
138 }
139 let connector = connector_builder.build()?;
140 let stream = connector.connect(hostname.as_ref(), stream)?;
141
142 Ok(Conn { stream })
143 }
144
145 #[cfg(feature = "rust-tls")]
146 pub fn add_root_cert_file_pem(&mut self, file_path: &Path) -> Result<&mut Self, HttpError> {
147 let f = File::open(file_path)?;
148 let mut f = BufReader::new(f);
149 let config = std::sync::Arc::make_mut(&mut self.client_config);
150 let _ = config
151 .root_store
152 .add_pem_file(&mut f)
153 .map_err(|_| HttpError::from(ParseErr::Invalid))?;
154 Ok(self)
155 }
156
157 #[cfg(feature = "rust-tls")]
158 pub fn connect<H, S>(&self, hostname: H, stream: S) -> Result<Conn<S>, HttpError>
159 where
160 H: AsRef<str>,
161 S: io::Read + io::Write,
162 {
163 use rustls::{ClientSession, StreamOwned};
164
165 let session = ClientSession::new(
166 &self.client_config,
167 webpki::DNSNameRef::try_from_ascii_str(hostname.as_ref())
168 .map_err(|_| HttpError::Tls)?,
169 );
170 let stream = StreamOwned::new(session, stream);
171
172 Ok(Conn { stream })
173 }
174
175 #[cfg(feature = "wasmedge_rustls")]
176 pub fn connect<H, S>(&self, hostname: H, stream: S) -> Result<Conn<S>, HttpError>
177 where
178 H: AsRef<str>,
179 S: io::Read + io::Write,
180 {
181 use wasmedge_rustls_api::stream::StreamOwned;
182
183 let session = self
184 .client_config
185 .new_codec(hostname)
186 .map_err(|_| HttpError::Tls)?;
187
188 let stream = StreamOwned::new(session, stream);
189
190 Ok(Conn { stream })
191 }
192}