1use std::{
2 fmt::Display,
3 io::{Read, Write},
4};
5
/// Raw imports from the host-provided `rustls_client` WebAssembly module.
///
/// Buffers are passed as an `i32` address plus an `i32` byte length (the wrappers in this
/// file cast slice pointers with `as i32`, which presumes a 32-bit wasm target). The
/// wrappers treat a negative return value as an error code decoded into `TlsError`;
/// NOTE(review): confirm that convention against the host implementation.
mod rustls_client {
    #[link(wasm_import_module = "rustls_client")]
    extern "C" {
        /// Returns a handle to the host's default client configuration.
        /// NOTE(review): `ClientConfig::default` stores this value without checking it for
        /// a negative error code.
        pub fn default_config() -> i32;
        /// Creates a codec from configuration handle `config` for the server name stored
        /// at `server_ptr`/`server_len`; returns a codec handle, or a negative error code.
        pub fn new_codec(config: i32, server_ptr: i32, server_len: i32) -> i32;
        /// Positive while the codec's handshake is still in progress.
        pub fn codec_is_handshaking(codec_id: i32) -> i32;
        /// Bit flags: bit 0 = the codec wants to read TLS data, bit 1 = it wants to write.
        pub fn codec_wants(codec_id: i32) -> i32;
        /// Releases the codec; the wrapper's `Drop` impl ignores the return value.
        pub fn delete_codec(codec_id: i32) -> i32;
        /// Requests that a close_notify alert be sent; negative on error.
        pub fn send_close_notify(codec_id: i32) -> i32;
        /// Processes received TLS data and writes a `TlsIoState` to `io_state_ptr`;
        /// negative on error.
        pub fn process_new_packets(codec_id: i32, io_state_ptr: i32) -> i32;
        /// Copies outgoing TLS bytes into the `buf_len`-byte buffer at `buf_ptr`;
        /// returns the number of bytes produced, or a negative error code.
        pub fn write_tls(codec_id: i32, buf_ptr: i32, buf_len: i32) -> i32;
        /// Submits `buf_len` bytes at `buf_ptr` (presumably plaintext to encrypt);
        /// returns a non-negative count, or a negative error code.
        pub fn write_raw(codec_id: i32, buf_ptr: i32, buf_len: i32) -> i32;
        /// Feeds `buf_len` received TLS bytes at `buf_ptr` to the codec; returns the
        /// number of bytes consumed, or a negative error code.
        pub fn read_tls(codec_id: i32, buf_ptr: i32, buf_len: i32) -> i32;
        /// Copies bytes (presumably decrypted plaintext) into the `buf_len`-byte buffer at
        /// `buf_ptr`; returns the byte count, or a negative error code.
        pub fn read_raw(codec_id: i32, buf_ptr: i32, buf_len: i32) -> i32;
    }
}
22
/// I/O state reported by the host, written in place by
/// `rustls_client::process_new_packets` through a raw pointer.
///
/// `#[repr(C)]` pins the field order and layout the host writes into; do not reorder,
/// resize, or add fields without changing the host side too. The field names match
/// rustls's `IoState`; presumably the host copies that struct here — confirm.
#[derive(Debug, Clone)]
#[repr(C)]
pub struct TlsIoState {
    /// TLS bytes waiting to be written to the transport.
    pub tls_bytes_to_write: u32,
    /// Plaintext bytes available to read from the codec.
    pub plaintext_bytes_to_read: u32,
    /// Whether the peer has closed its side of the connection.
    // NOTE(review): the host must store exactly 0 or 1 in this byte; any other value in a
    // Rust `bool` is undefined behavior. Confirm the host-side representation.
    pub peer_has_closed: bool,
}
30
/// Errors reported by the host `rustls_client` module.
///
/// Host functions signal failure with a negative return code; the `From<i32>` impl below
/// decodes those codes into variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TlsError {
    ParamError,
    InappropriateMessage,
    InappropriateHandshakeMessage,
    CorruptMessage,
    CorruptMessagePayload,
    NoCertificatesPresented,
    UnsupportedNameType,
    DecryptError,
    EncryptError,
    PeerIncompatibleError,
    PeerMisbehavedError,
    AlertReceived,
    InvalidCertificateEncoding,
    InvalidCertificateSignatureType,
    InvalidCertificateSignature,
    InvalidCertificateData,
    InvalidSct,
    General,
    FailedToGetCurrentTime,
    FailedToGetRandomBytes,
    HandshakeNotComplete,
    PeerSentOversizedRecord,
    NoApplicationProtocol,
    BadMaxFragmentSize,
    IOWouldBlock,
    IO,
}

// Implementing `From` (rather than `Into`) also provides `i32: Into<TlsError>` through the
// standard blanket impl, so existing `code.into()` call sites keep working.
impl From<i32> for TlsError {
    /// Decodes a negative host return code into a [`TlsError`].
    ///
    /// Codes `-1` through `-26` map to the variants in declaration order. Any other value,
    /// including non-negative ones, falls back to [`TlsError::ParamError`].
    fn from(code: i32) -> Self {
        match code {
            -1 => TlsError::ParamError,
            -2 => TlsError::InappropriateMessage,
            -3 => TlsError::InappropriateHandshakeMessage,
            -4 => TlsError::CorruptMessage,
            -5 => TlsError::CorruptMessagePayload,
            -6 => TlsError::NoCertificatesPresented,
            -7 => TlsError::UnsupportedNameType,
            -8 => TlsError::DecryptError,
            -9 => TlsError::EncryptError,
            -10 => TlsError::PeerIncompatibleError,
            -11 => TlsError::PeerMisbehavedError,
            -12 => TlsError::AlertReceived,
            -13 => TlsError::InvalidCertificateEncoding,
            -14 => TlsError::InvalidCertificateSignatureType,
            -15 => TlsError::InvalidCertificateSignature,
            -16 => TlsError::InvalidCertificateData,
            -17 => TlsError::InvalidSct,
            -18 => TlsError::General,
            -19 => TlsError::FailedToGetCurrentTime,
            -20 => TlsError::FailedToGetRandomBytes,
            -21 => TlsError::HandshakeNotComplete,
            -22 => TlsError::PeerSentOversizedRecord,
            -23 => TlsError::NoApplicationProtocol,
            -24 => TlsError::BadMaxFragmentSize,
            -25 => TlsError::IOWouldBlock,
            -26 => TlsError::IO,
            _ => TlsError::ParamError,
        }
    }
}

impl Display for TlsError {
    /// Formats the error as its variant name (the same text `Debug` produces).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self)
    }
}

impl std::error::Error for TlsError {}

impl From<TlsError> for std::io::Error {
    /// Converts a [`TlsError`] into an I/O error.
    ///
    /// [`TlsError::IOWouldBlock`] becomes a bare `ErrorKind::WouldBlock` error. Every other
    /// variant becomes an `ErrorKind::InvalidInput` error that wraps the `TlsError`.
    fn from(value: TlsError) -> Self {
        match value {
            TlsError::IOWouldBlock => std::io::ErrorKind::WouldBlock.into(),
            other => std::io::Error::new(std::io::ErrorKind::InvalidInput, other),
        }
    }
}
113pub struct ClientConfig {
114 id: i32,
115}
116
117impl ClientConfig {
118 pub fn new_codec<S: AsRef<str>>(&self, server_name: S) -> Result<TlsClientCodec, TlsError> {
119 unsafe {
120 let server_name = server_name.as_ref();
121 let server_ptr = server_name.as_ptr();
122 let server_len = server_name.len();
123 let id = rustls_client::new_codec(self.id, server_ptr as i32, server_len as i32);
124 if id < 0 {
125 return Err(id.into());
126 }
127 Ok(TlsClientCodec {
128 id,
129 read_buf: VecBuffer::new(vec![0; 1024 * 4]),
130 write_buf: VecBuffer::new(vec![0; 1024 * 4]),
131 })
132 }
133 }
134}
135
136impl Default for ClientConfig {
137 fn default() -> Self {
138 let id = unsafe { rustls_client::default_config() };
139 Self { id }
140 }
141}
142
/// A fixed-capacity byte staging buffer.
///
/// `buf[used..filled]` holds bytes that have been filled but not yet consumed, and
/// `buf[filled..]` is free space for new data.
#[derive(Debug)]
pub struct VecBuffer {
    buf: Vec<u8>,
    /// Offset of the first unconsumed byte.
    pub used: usize,
    /// Offset one past the last filled byte.
    pub filled: usize,
}

impl VecBuffer {
    /// Wraps `buf` as an empty staging buffer. Its length is the capacity, and its current
    /// contents are treated as free space.
    pub fn new(buf: Vec<u8>) -> Self {
        Self {
            buf,
            used: 0,
            filled: 0,
        }
    }

    /// Performs a single `read` from `rd` into the free space and records the bytes read.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `rd.read`; nothing is recorded in that case.
    pub fn from_reader<R: Read>(&mut self, rd: &mut R) -> std::io::Result<()> {
        let n = rd.read(self.mut_rest_buf())?;
        self.filled += n;
        Ok(())
    }

    /// Returns the free space after the filled region.
    ///
    /// Any consumed prefix is compacted away first, so space freed by earlier partial
    /// consumption is reclaimed. Without this, repeated partial consumption can leave no
    /// free space at all, and a read into the empty slice would report 0 bytes.
    pub fn mut_rest_buf(&mut self) -> &mut [u8] {
        self.compact();
        &mut self.buf[self.filled..]
    }

    /// Returns the bytes that have been filled but not yet consumed.
    pub fn get_available_buf(&self) -> &[u8] {
        &self.buf[self.used..self.filled]
    }

    /// Offers the available bytes to `f` and marks the `n` bytes it reports as consumed.
    ///
    /// The buffer is reset once everything has been consumed. Returns `n`.
    ///
    /// # Errors
    ///
    /// Propagates `f`'s error; nothing is consumed in that case.
    pub fn write_to(
        &mut self,
        f: &mut dyn FnMut(&[u8]) -> std::io::Result<usize>,
    ) -> std::io::Result<usize> {
        let n = f(self.get_available_buf())?;
        self.used += n;
        self.clear();
        Ok(n)
    }

    /// Lets `f` fill the free space and records the `n` bytes it reports. Returns `n`.
    ///
    /// # Errors
    ///
    /// Propagates `f`'s error; nothing is recorded in that case.
    pub fn read_from(
        &mut self,
        f: &mut dyn FnMut(&mut [u8]) -> std::io::Result<usize>,
    ) -> std::io::Result<usize> {
        let n = f(self.mut_rest_buf())?;
        self.filled += n;
        Ok(n)
    }

    /// Resets both offsets to zero once every filled byte has been consumed; otherwise
    /// leaves the buffer untouched.
    pub fn clear(&mut self) {
        if self.used == self.filled {
            self.used = 0;
            self.filled = 0;
        }
    }

    /// Moves the unconsumed bytes to the start of the buffer so all free space is
    /// contiguous at the end.
    fn compact(&mut self) {
        if self.used > 0 {
            self.buf.copy_within(self.used..self.filled, 0);
            self.filled -= self.used;
            self.used = 0;
        }
    }
}
198
/// A TLS client session backed by a host-side codec handle.
///
/// `read_buf` stages bytes read from the transport before they are fed to the codec.
/// `write_buf` stages TLS bytes produced by the codec before they are written to the
/// transport.
#[derive(Debug)]
pub struct TlsClientCodec {
    /// Host codec handle returned by `rustls_client::new_codec`; released in `Drop`.
    id: i32,
    /// Transport -> codec staging buffer (filled by `read_tls_from_io`).
    pub read_buf: VecBuffer,
    /// Codec -> transport staging buffer (drained by `write_tls_to_io`).
    pub write_buf: VecBuffer,
}
205
/// Decoded flags from `rustls_client::codec_wants`: which I/O directions the codec needs.
#[derive(Debug)]
pub struct WantsResult {
    /// Bit 0 of the host flags: the codec wants more TLS data read from the transport.
    pub wants_read: bool,
    /// Bit 1 of the host flags: the codec wants to write TLS data to the transport.
    pub wants_write: bool,
}
211
212impl TlsClientCodec {
213 pub fn is_handshaking(&self) -> bool {
214 unsafe { rustls_client::codec_is_handshaking(self.id) > 0 }
215 }
216
217 pub fn wants(&self) -> WantsResult {
219 unsafe {
220 let i = rustls_client::codec_wants(self.id);
221 WantsResult {
222 wants_read: i & 0b01 > 0,
223 wants_write: i & 0b010 > 0,
224 }
225 }
226 }
227
228 pub fn send_close_notify(&mut self) -> Result<(), TlsError> {
229 unsafe {
230 let e = rustls_client::send_close_notify(self.id);
231 if e < 0 {
232 Err(e.into())
233 } else {
234 Ok(())
235 }
236 }
237 }
238
239 pub fn process_new_packets(&mut self) -> Result<TlsIoState, TlsError> {
240 unsafe {
241 let mut io_state = TlsIoState {
242 tls_bytes_to_write: 0,
243 plaintext_bytes_to_read: 0,
244 peer_has_closed: false,
245 };
246 let e = rustls_client::process_new_packets(
247 self.id,
248 (&mut io_state) as *mut _ as usize as i32,
249 );
250 if e < 0 {
251 Err(e.into())
252 } else {
253 Ok(io_state)
254 }
255 }
256 }
257
258 pub fn write_tls(&mut self, tls_buf: &mut [u8]) -> Result<usize, TlsError> {
259 unsafe {
260 let e =
261 rustls_client::write_tls(self.id, tls_buf.as_ptr() as i32, tls_buf.len() as i32);
262 if e < 0 {
263 Err(e.into())
264 } else {
265 Ok(e as usize)
266 }
267 }
268 }
269
270 pub fn write_raw(&mut self, raw_buf: &[u8]) -> Result<usize, TlsError> {
271 unsafe {
272 let e =
273 rustls_client::write_raw(self.id, raw_buf.as_ptr() as i32, raw_buf.len() as i32);
274 if e < 0 {
275 Err(e.into())
276 } else {
277 Ok(e as usize)
278 }
279 }
280 }
281
282 pub fn read_tls(&mut self, tls_buf: &[u8]) -> Result<usize, TlsError> {
283 unsafe {
284 let e = rustls_client::read_tls(self.id, tls_buf.as_ptr() as i32, tls_buf.len() as i32);
285 if e < 0 {
286 Err(e.into())
287 } else {
288 Ok(e as usize)
289 }
290 }
291 }
292
293 pub fn read_raw(&mut self, raw_buf: &mut [u8]) -> Result<usize, TlsError> {
294 unsafe {
295 let e = rustls_client::read_raw(self.id, raw_buf.as_ptr() as i32, raw_buf.len() as i32);
296 if e < 0 {
297 Err(e.into())
298 } else {
299 Ok(e as usize)
300 }
301 }
302 }
303}
304
305impl TlsClientCodec {
306 pub fn read_tls_from_io<R: Read>(&mut self, io: &mut R) -> std::io::Result<usize> {
307 self.read_buf.from_reader(io)?;
308 let id = self.id;
309 self.read_buf.write_to(&mut |tls_buf| unsafe {
310 let e = rustls_client::read_tls(id, tls_buf.as_ptr() as i32, tls_buf.len() as i32);
311 if e < 0 {
312 let tls_err: TlsError = e.into();
313 Err(tls_err.into())
314 } else {
315 Ok(e as usize)
316 }
317 })
318 }
319
320 pub fn write_tls_to_io<W: Write>(&mut self, io: &mut W) -> std::io::Result<usize> {
321 let id = self.id;
322 self.write_buf.read_from(&mut |tls_buf| unsafe {
323 let e = rustls_client::write_tls(id, tls_buf.as_ptr() as i32, tls_buf.len() as i32);
324 if e < 0 {
325 let tls_err: TlsError = e.into();
326 Err(tls_err.into())
327 } else {
328 Ok(e as usize)
329 }
330 })?;
331 self.write_buf.write_to(&mut |buf| io.write(buf))
332 }
333
334 pub fn poll_write_tls_to_io(
335 &mut self,
336 f: &mut dyn FnMut(&[u8]) -> std::task::Poll<std::io::Result<usize>>,
337 ) -> std::task::Poll<std::io::Result<usize>> {
338 let n = self.write_buf.read_from(&mut |tls_buf| unsafe {
339 let e =
340 rustls_client::write_tls(self.id, tls_buf.as_ptr() as i32, tls_buf.len() as i32);
341 if e < 0 {
342 let tls_err: TlsError = e.into();
343 Err(tls_err.into())
344 } else {
345 Ok(e as usize)
346 }
347 });
348 if let Err(e) = n {
349 return std::task::Poll::Ready(Err(e));
350 }
351 let n = f(self.write_buf.get_available_buf());
352 if let std::task::Poll::Ready(Ok(n)) = &n {
353 self.write_buf.used += *n;
354 self.write_buf.clear();
355 }
356 n
357 }
358
359 pub fn poll_read_tls_from_io(
360 &mut self,
361 f: &mut dyn FnMut(&mut [u8]) -> std::task::Poll<std::io::Result<usize>>,
362 ) -> std::task::Poll<std::io::Result<usize>> {
363 let n = f(self.read_buf.mut_rest_buf());
364 if let std::task::Poll::Ready(Ok(n)) = &n {
365 self.read_buf.filled += *n;
366 let r = self.read_buf.write_to(&mut |tls_buf| unsafe {
367 let e =
368 rustls_client::read_tls(self.id, tls_buf.as_ptr() as i32, tls_buf.len() as i32);
369 if e < 0 {
370 let tls_err: TlsError = e.into();
371 Err(tls_err.into())
372 } else {
373 Ok(e as usize)
374 }
375 });
376 std::task::Poll::Ready(r)
377 } else {
378 n
379 }
380 }
381}
382
383impl Drop for TlsClientCodec {
384 fn drop(&mut self) {
385 unsafe { rustls_client::delete_codec(self.id) };
386 }
387}
388
389pub fn complete_io<T>(codec: &mut TlsClientCodec, io: &mut T) -> std::io::Result<(usize, usize)>
390where
391 T: std::io::Read + std::io::Write,
392{
393 let until_handshaked = codec.is_handshaking();
394 let mut eof = false;
395 let mut wrlen = 0;
396 let mut rdlen = 0;
397 let mut buf = [0u8; 1024 * 4];
398
399 loop {
400 while codec.wants().wants_write {
401 let n = codec.write_tls_to_io(io)?;
402 wrlen += n;
403 }
404
405 if !until_handshaked && wrlen > 0 {
406 return Ok((rdlen, wrlen));
407 }
408
409 if !eof && codec.wants().wants_read {
410 match codec.read_tls_from_io(io) {
411 Ok(0) => {
412 eof = true;
413 }
414 Ok(n) => {
415 rdlen += n;
416 }
417 Err(err) => return Err(err.into()),
418 };
419 }
420
421 match codec.process_new_packets() {
422 Ok(_) => {}
423 Err(e) => {
424 let n = codec.write_tls(&mut buf)?;
428 let _ignored = io.write_all(&buf[0..n]);
429
430 return Err(e.into());
431 }
432 };
433
434 match (eof, until_handshaked, codec.is_handshaking()) {
435 (_, true, false) => return Ok((rdlen, wrlen)),
436 (_, false, _) => return Ok((rdlen, wrlen)),
437 (true, true, true) => {
438 return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))
439 }
440 (..) => {}
441 }
442 }
443}