1use std::{
2    fmt::Display,
3    io::{Read, Write},
4};
5
/// Raw imports from the host's `rustls_client` WebAssembly module.
///
/// Conventions used by the safe wrappers in this file: handles and lengths are
/// passed as `i32`, buffer pointers are linear-memory addresses cast to `i32`,
/// and a negative return value is an error code (mapped to `TlsError`).
mod rustls_client {
    #[link(wasm_import_module = "rustls_client")]
    extern "C" {
        // Configuration and codec lifecycle.
        pub fn default_config() -> i32;
        pub fn new_codec(config: i32, server_ptr: i32, server_len: i32) -> i32;
        pub fn delete_codec(codec_id: i32) -> i32;

        // Session state queries.
        pub fn codec_is_handshaking(codec_id: i32) -> i32;
        pub fn codec_wants(codec_id: i32) -> i32;

        // Session control; `io_state_ptr` points at a `TlsIoState`.
        pub fn send_close_notify(codec_id: i32) -> i32;
        pub fn process_new_packets(codec_id: i32, io_state_ptr: i32) -> i32;

        // TLS-framed bytes exchanged with the peer.
        pub fn read_tls(codec_id: i32, buf_ptr: i32, buf_len: i32) -> i32;
        pub fn write_tls(codec_id: i32, buf_ptr: i32, buf_len: i32) -> i32;

        // Raw (non-TLS-framed) bytes exchanged with the application.
        pub fn read_raw(codec_id: i32, buf_ptr: i32, buf_len: i32) -> i32;
        pub fn write_raw(codec_id: i32, buf_ptr: i32, buf_len: i32) -> i32;
    }
}
22
/// I/O status block that the host fills in during `process_new_packets`.
///
/// `#[repr(C)]` pins the field order and layout because a pointer to this
/// struct crosses the FFI boundary — do not reorder or retype fields.
/// NOTE(review): the host must store only 0 or 1 into `peer_has_closed`;
/// any other byte value is an invalid `bool` in Rust.
#[derive(Debug, Clone)]
#[repr(C)]
pub struct TlsIoState {
    /// TLS bytes the host reports as pending output (per field name — confirm with host).
    pub tls_bytes_to_write: u32,
    /// Plaintext bytes the host reports as ready to read (per field name — confirm with host).
    pub plaintext_bytes_to_read: u32,
    /// Whether the host reports the peer as having closed the connection.
    pub peer_has_closed: bool,
}
30
/// Errors reported by the host `rustls_client` module.
///
/// The host signals failure with a negative status code; `TlsError::from(code)`
/// maps each code to its variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TlsError {
    ParamError,
    InappropriateMessage,
    InappropriateHandshakeMessage,
    CorruptMessage,
    CorruptMessagePayload,
    NoCertificatesPresented,
    UnsupportedNameType,
    DecryptError,
    EncryptError,
    PeerIncompatibleError,
    PeerMisbehavedError,
    AlertReceived,
    InvalidCertificateEncoding,
    InvalidCertificateSignatureType,
    InvalidCertificateSignature,
    InvalidCertificateData,
    InvalidSct,
    General,
    FailedToGetCurrentTime,
    FailedToGetRandomBytes,
    HandshakeNotComplete,
    PeerSentOversizedRecord,
    NoApplicationProtocol,
    BadMaxFragmentSize,
    IOWouldBlock,
    IO,
}

// Implementing `From` (rather than `Into` directly) also provides
// `Into<TlsError> for i32` through the standard blanket impl, so existing
// `code.into()` call sites keep working.
impl From<i32> for TlsError {
    /// Maps a host status code to its `TlsError` variant.
    ///
    /// Codes outside `-26..=-1` (including non-negative values) are reported
    /// as `ParamError`.
    fn from(code: i32) -> Self {
        match code {
            -1 => TlsError::ParamError,
            -2 => TlsError::InappropriateMessage,
            -3 => TlsError::InappropriateHandshakeMessage,
            -4 => TlsError::CorruptMessage,
            -5 => TlsError::CorruptMessagePayload,
            -6 => TlsError::NoCertificatesPresented,
            -7 => TlsError::UnsupportedNameType,
            -8 => TlsError::DecryptError,
            -9 => TlsError::EncryptError,
            -10 => TlsError::PeerIncompatibleError,
            -11 => TlsError::PeerMisbehavedError,
            -12 => TlsError::AlertReceived,
            -13 => TlsError::InvalidCertificateEncoding,
            -14 => TlsError::InvalidCertificateSignatureType,
            -15 => TlsError::InvalidCertificateSignature,
            -16 => TlsError::InvalidCertificateData,
            -17 => TlsError::InvalidSct,
            -18 => TlsError::General,
            -19 => TlsError::FailedToGetCurrentTime,
            -20 => TlsError::FailedToGetRandomBytes,
            -21 => TlsError::HandshakeNotComplete,
            -22 => TlsError::PeerSentOversizedRecord,
            -23 => TlsError::NoApplicationProtocol,
            -24 => TlsError::BadMaxFragmentSize,
            -25 => TlsError::IOWouldBlock,
            -26 => TlsError::IO,
            _ => TlsError::ParamError,
        }
    }
}

impl Display for TlsError {
    /// Formats the error as its variant name (same text as `Debug`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self)
    }
}

impl std::error::Error for TlsError {}

impl From<TlsError> for std::io::Error {
    /// `IOWouldBlock` becomes a bare `ErrorKind::WouldBlock` so non-blocking
    /// callers can recognise it; every other variant is wrapped as the inner
    /// error of an `ErrorKind::InvalidInput` I/O error.
    fn from(value: TlsError) -> Self {
        match value {
            TlsError::IOWouldBlock => std::io::ErrorKind::WouldBlock.into(),
            other => std::io::Error::new(std::io::ErrorKind::InvalidInput, other),
        }
    }
}
113pub struct ClientConfig {
114    id: i32,
115}
116
117impl ClientConfig {
118    pub fn new_codec<S: AsRef<str>>(&self, server_name: S) -> Result<TlsClientCodec, TlsError> {
119        unsafe {
120            let server_name = server_name.as_ref();
121            let server_ptr = server_name.as_ptr();
122            let server_len = server_name.len();
123            let id = rustls_client::new_codec(self.id, server_ptr as i32, server_len as i32);
124            if id < 0 {
125                return Err(id.into());
126            }
127            Ok(TlsClientCodec {
128                id,
129                read_buf: VecBuffer::new(vec![0; 1024 * 4]),
130                write_buf: VecBuffer::new(vec![0; 1024 * 4]),
131            })
132        }
133    }
134}
135
136impl Default for ClientConfig {
137    fn default() -> Self {
138        let id = unsafe { rustls_client::default_config() };
139        Self { id }
140    }
141}
142
/// Fixed-capacity byte staging buffer with a consumed/filled cursor pair.
///
/// `buf[used..filled]` holds pending bytes (written in, not yet consumed) and
/// `buf[filled..]` is free space for new data. The capacity is the length of
/// the `Vec` passed to [`VecBuffer::new`].
#[derive(Debug)]
pub struct VecBuffer {
    buf: Vec<u8>,
    /// End of the already-consumed prefix.
    pub used: usize,
    /// End of the region holding data (consumed or pending).
    pub filled: usize,
}

impl VecBuffer {
    /// Wraps `buf` as an empty buffer; `buf.len()` becomes the capacity.
    pub fn new(buf: Vec<u8>) -> Self {
        Self {
            buf,
            used: 0,
            filled: 0,
        }
    }

    /// Performs one `read` from `rd` into the free space and records the bytes read.
    ///
    /// # Errors
    /// Propagates any error returned by `rd.read`.
    pub fn from_reader<R: Read>(&mut self, rd: &mut R) -> std::io::Result<()> {
        let n = rd.read(self.mut_rest_buf())?;
        self.filled += n;
        Ok(())
    }

    /// Returns the free space after the filled region.
    ///
    /// When the free space is exhausted but a prefix has already been
    /// consumed, the pending bytes are first moved to the front. Without this,
    /// a full buffer with a consumed prefix would return an empty slice
    /// forever and no further data could be written in.
    pub fn mut_rest_buf(&mut self) -> &mut [u8] {
        if self.filled == self.buf.len() && self.used > 0 {
            self.compact();
        }
        &mut self.buf[self.filled..]
    }

    /// Returns the pending (unconsumed) bytes.
    pub fn get_available_buf(&self) -> &[u8] {
        &self.buf[self.used..self.filled]
    }

    /// Passes the pending bytes to `f` and marks as consumed the count it returns.
    ///
    /// Resets the buffer to empty if everything has been consumed.
    ///
    /// # Errors
    /// Propagates any error returned by `f`; nothing is consumed in that case.
    pub fn write_to(
        &mut self,
        f: &mut dyn FnMut(&[u8]) -> std::io::Result<usize>,
    ) -> std::io::Result<usize> {
        let n = f(self.get_available_buf())?;
        self.used += n;
        self.clear();
        Ok(n)
    }

    /// Lets `f` write into the free space and records the count it returns as filled.
    ///
    /// # Errors
    /// Propagates any error returned by `f`; nothing is recorded in that case.
    pub fn read_from(
        &mut self,
        f: &mut dyn FnMut(&mut [u8]) -> std::io::Result<usize>,
    ) -> std::io::Result<usize> {
        let n = f(self.mut_rest_buf())?;
        self.filled += n;
        Ok(n)
    }

    /// Resets both cursors to zero once every pending byte has been consumed.
    pub fn clear(&mut self) {
        if self.used == self.filled {
            self.used = 0;
            self.filled = 0;
        }
    }

    /// Moves the pending bytes to the front, reclaiming the consumed prefix.
    fn compact(&mut self) {
        self.buf.copy_within(self.used..self.filled, 0);
        self.filled -= self.used;
        self.used = 0;
    }
}
198
199#[derive(Debug)]
200pub struct TlsClientCodec {
201    id: i32,
202    pub read_buf: VecBuffer,
203    pub write_buf: VecBuffer,
204}
205
/// What a codec is waiting on, decoded from the host's `codec_wants` bitmask.
#[derive(Debug)]
pub struct WantsResult {
    /// The codec is ready to accept more TLS bytes from the peer.
    pub wants_read: bool,
    /// The codec has TLS bytes waiting to be sent to the peer.
    pub wants_write: bool,
}
211
212impl TlsClientCodec {
213    pub fn is_handshaking(&self) -> bool {
214        unsafe { rustls_client::codec_is_handshaking(self.id) > 0 }
215    }
216
217    pub fn wants(&self) -> WantsResult {
219        unsafe {
220            let i = rustls_client::codec_wants(self.id);
221            WantsResult {
222                wants_read: i & 0b01 > 0,
223                wants_write: i & 0b010 > 0,
224            }
225        }
226    }
227
228    pub fn send_close_notify(&mut self) -> Result<(), TlsError> {
229        unsafe {
230            let e = rustls_client::send_close_notify(self.id);
231            if e < 0 {
232                Err(e.into())
233            } else {
234                Ok(())
235            }
236        }
237    }
238
239    pub fn process_new_packets(&mut self) -> Result<TlsIoState, TlsError> {
240        unsafe {
241            let mut io_state = TlsIoState {
242                tls_bytes_to_write: 0,
243                plaintext_bytes_to_read: 0,
244                peer_has_closed: false,
245            };
246            let e = rustls_client::process_new_packets(
247                self.id,
248                (&mut io_state) as *mut _ as usize as i32,
249            );
250            if e < 0 {
251                Err(e.into())
252            } else {
253                Ok(io_state)
254            }
255        }
256    }
257
258    pub fn write_tls(&mut self, tls_buf: &mut [u8]) -> Result<usize, TlsError> {
259        unsafe {
260            let e =
261                rustls_client::write_tls(self.id, tls_buf.as_ptr() as i32, tls_buf.len() as i32);
262            if e < 0 {
263                Err(e.into())
264            } else {
265                Ok(e as usize)
266            }
267        }
268    }
269
270    pub fn write_raw(&mut self, raw_buf: &[u8]) -> Result<usize, TlsError> {
271        unsafe {
272            let e =
273                rustls_client::write_raw(self.id, raw_buf.as_ptr() as i32, raw_buf.len() as i32);
274            if e < 0 {
275                Err(e.into())
276            } else {
277                Ok(e as usize)
278            }
279        }
280    }
281
282    pub fn read_tls(&mut self, tls_buf: &[u8]) -> Result<usize, TlsError> {
283        unsafe {
284            let e = rustls_client::read_tls(self.id, tls_buf.as_ptr() as i32, tls_buf.len() as i32);
285            if e < 0 {
286                Err(e.into())
287            } else {
288                Ok(e as usize)
289            }
290        }
291    }
292
293    pub fn read_raw(&mut self, raw_buf: &mut [u8]) -> Result<usize, TlsError> {
294        unsafe {
295            let e = rustls_client::read_raw(self.id, raw_buf.as_ptr() as i32, raw_buf.len() as i32);
296            if e < 0 {
297                Err(e.into())
298            } else {
299                Ok(e as usize)
300            }
301        }
302    }
303}
304
305impl TlsClientCodec {
306    pub fn read_tls_from_io<R: Read>(&mut self, io: &mut R) -> std::io::Result<usize> {
307        self.read_buf.from_reader(io)?;
308        let id = self.id;
309        self.read_buf.write_to(&mut |tls_buf| unsafe {
310            let e = rustls_client::read_tls(id, tls_buf.as_ptr() as i32, tls_buf.len() as i32);
311            if e < 0 {
312                let tls_err: TlsError = e.into();
313                Err(tls_err.into())
314            } else {
315                Ok(e as usize)
316            }
317        })
318    }
319
320    pub fn write_tls_to_io<W: Write>(&mut self, io: &mut W) -> std::io::Result<usize> {
321        let id = self.id;
322        self.write_buf.read_from(&mut |tls_buf| unsafe {
323            let e = rustls_client::write_tls(id, tls_buf.as_ptr() as i32, tls_buf.len() as i32);
324            if e < 0 {
325                let tls_err: TlsError = e.into();
326                Err(tls_err.into())
327            } else {
328                Ok(e as usize)
329            }
330        })?;
331        self.write_buf.write_to(&mut |buf| io.write(buf))
332    }
333
334    pub fn poll_write_tls_to_io(
335        &mut self,
336        f: &mut dyn FnMut(&[u8]) -> std::task::Poll<std::io::Result<usize>>,
337    ) -> std::task::Poll<std::io::Result<usize>> {
338        let n = self.write_buf.read_from(&mut |tls_buf| unsafe {
339            let e =
340                rustls_client::write_tls(self.id, tls_buf.as_ptr() as i32, tls_buf.len() as i32);
341            if e < 0 {
342                let tls_err: TlsError = e.into();
343                Err(tls_err.into())
344            } else {
345                Ok(e as usize)
346            }
347        });
348        if let Err(e) = n {
349            return std::task::Poll::Ready(Err(e));
350        }
351        let n = f(self.write_buf.get_available_buf());
352        if let std::task::Poll::Ready(Ok(n)) = &n {
353            self.write_buf.used += *n;
354            self.write_buf.clear();
355        }
356        n
357    }
358
359    pub fn poll_read_tls_from_io(
360        &mut self,
361        f: &mut dyn FnMut(&mut [u8]) -> std::task::Poll<std::io::Result<usize>>,
362    ) -> std::task::Poll<std::io::Result<usize>> {
363        let n = f(self.read_buf.mut_rest_buf());
364        if let std::task::Poll::Ready(Ok(n)) = &n {
365            self.read_buf.filled += *n;
366            let r = self.read_buf.write_to(&mut |tls_buf| unsafe {
367                let e =
368                    rustls_client::read_tls(self.id, tls_buf.as_ptr() as i32, tls_buf.len() as i32);
369                if e < 0 {
370                    let tls_err: TlsError = e.into();
371                    Err(tls_err.into())
372                } else {
373                    Ok(e as usize)
374                }
375            });
376            std::task::Poll::Ready(r)
377        } else {
378            n
379        }
380    }
381}
382
383impl Drop for TlsClientCodec {
384    fn drop(&mut self) {
385        unsafe { rustls_client::delete_codec(self.id) };
386    }
387}
388
389pub fn complete_io<T>(codec: &mut TlsClientCodec, io: &mut T) -> std::io::Result<(usize, usize)>
390where
391    T: std::io::Read + std::io::Write,
392{
393    let until_handshaked = codec.is_handshaking();
394    let mut eof = false;
395    let mut wrlen = 0;
396    let mut rdlen = 0;
397    let mut buf = [0u8; 1024 * 4];
398
399    loop {
400        while codec.wants().wants_write {
401            let n = codec.write_tls_to_io(io)?;
402            wrlen += n;
403        }
404
405        if !until_handshaked && wrlen > 0 {
406            return Ok((rdlen, wrlen));
407        }
408
409        if !eof && codec.wants().wants_read {
410            match codec.read_tls_from_io(io) {
411                Ok(0) => {
412                    eof = true;
413                }
414                Ok(n) => {
415                    rdlen += n;
416                }
417                Err(err) => return Err(err.into()),
418            };
419        }
420
421        match codec.process_new_packets() {
422            Ok(_) => {}
423            Err(e) => {
424                let n = codec.write_tls(&mut buf)?;
428                let _ignored = io.write_all(&buf[0..n]);
429
430                return Err(e.into());
431            }
432        };
433
434        match (eof, until_handshaked, codec.is_handshaking()) {
435            (_, true, false) => return Ok((rdlen, wrlen)),
436            (_, false, _) => return Ok((rdlen, wrlen)),
437            (true, true, true) => {
438                return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))
439            }
440            (..) => {}
441        }
442    }
443}