futures_loco_protocol/
secure.rs

1use std::{
2    io::{self, Cursor, ErrorKind, Read},
3    mem,
4    pin::Pin,
5    task::{ready, Context, Poll},
6};
7
8use futures_io::{AsyncRead, AsyncWrite};
9use loco_protocol::secure::{
10    client::{LocoClientSecureLayer, ReadState as LayerReadState, rsa::RsaPublicKey},
11    SecurePacket,
12};
13use rand::RngCore;
14
15pub use loco_protocol::secure::client::rsa;
16
17pin_project_lite::pin_project! {
18    #[derive(Debug)]
19    pub struct LocoSecureStream<T> {
20        read_state: ReadState,
21        write_state: WriteState,
22
23        layer: LocoClientSecureLayer,
24
25        #[pin]
26        inner: T,
27    }
28}
29
30impl<T> LocoSecureStream<T> {
31    pub const MAX_IO_SIZE: u64 = 16 * 1024 * 1024;
32
33    pub fn new(rsa_key: RsaPublicKey, inner: T) -> Self {
34        let mut key = [0_u8; 16];
35        rand::thread_rng().fill_bytes(&mut key);
36
37        Self {
38            read_state: ReadState::Pending,
39            write_state: WriteState::Initial(rsa_key),
40
41            layer: LocoClientSecureLayer::new(key),
42
43            inner,
44        }
45    }
46
47    pub fn inner(&self) -> &T {
48        &self.inner
49    }
50
51    pub fn inner_mut(&mut self) -> &mut T {
52        &mut self.inner
53    }
54
55    pub fn into_inner(self) -> T {
56        self.inner
57    }
58}
59
60impl<T: AsyncRead> AsyncRead for LocoSecureStream<T> {
61    fn poll_read(
62        self: Pin<&mut Self>,
63        cx: &mut Context<'_>,
64        buf: &mut [u8],
65    ) -> Poll<io::Result<usize>> {
66        let mut this = self.project();
67
68        loop {
69            match mem::replace(this.read_state, ReadState::Corrupted) {
70                ReadState::Pending => {
71                    if let Some(packet) = this.layer.read() {
72                        *this.read_state = ReadState::Reading(Cursor::new(packet.data));
73                    } else {
74                        if let LayerReadState::Header(header) = this.layer.read_state() {
75                            if header.size as u64 - 16 > Self::MAX_IO_SIZE {
76                                *this.read_state = ReadState::PacketTooLarge;
77                                continue;
78                            }
79                        }
80
81                        let mut read_buf = [0_u8; 1024];
82
83                        *this.read_state = ReadState::Pending;
84
85                        let read = ready!(this.inner.as_mut().poll_read(cx, &mut read_buf))?;
86                        if read == 0 {
87                            *this.read_state = ReadState::Done;
88                            continue;
89                        }
90
91                        this.layer.read_buffer.extend(&read_buf[..read]);
92                    }
93                }
94
95                ReadState::Reading(mut cursor) => {
96                    let read = cursor.read(buf)?;
97
98                    *this.read_state = if cursor.position() as usize == cursor.get_ref().len() {
99                        ReadState::Pending
100                    } else {
101                        ReadState::Reading(cursor)
102                    };
103
104                    break Poll::Ready(Ok(read));
105                }
106
107                ReadState::PacketTooLarge => {
108                    *this.read_state = ReadState::PacketTooLarge;
109
110                    break Poll::Ready(Err(io::Error::new(
111                        ErrorKind::InvalidData,
112                        "packet is too large",
113                    )));
114                }
115
116                ReadState::Done => break Poll::Ready(Err(ErrorKind::UnexpectedEof.into())),
117
118                ReadState::Corrupted => unreachable!(),
119            }
120        }
121    }
122}
123
124impl<T: AsyncWrite> AsyncWrite for LocoSecureStream<T> {
125    fn poll_write(
126        self: Pin<&mut Self>,
127        cx: &mut Context<'_>,
128        buf: &[u8],
129    ) -> Poll<io::Result<usize>> {
130        let mut this = self.project();
131
132        loop {
133            match mem::replace(this.write_state, WriteState::Corrupted) {
134                WriteState::Initial(key) => {
135                    this.layer.handshake(&key);
136
137                    *this.write_state = WriteState::Pending;
138                }
139
140                WriteState::Pending => {
141                    let data = if buf.len() as u64 > Self::MAX_IO_SIZE {
142                        &buf[..Self::MAX_IO_SIZE as usize]
143                    } else {
144                        buf
145                    };
146
147                    let mut iv = [0_u8; 16];
148                    rand::thread_rng().fill_bytes(&mut iv);
149
150                    *this.write_state = WriteState::Writing(data.len());
151                    this.layer.send(SecurePacket { iv, data });
152                }
153
154                WriteState::Writing(size) => {
155                    let write_buffer = &mut this.layer.write_buffer;
156
157                    loop {
158                        let slice = {
159                            let slices = write_buffer.as_slices();
160
161                            if !slices.0.is_empty() {
162                                slices.0
163                            } else {
164                                slices.1
165                            }
166                        };
167
168                        match this.inner.as_mut().poll_write(cx, slice)? {
169                            Poll::Ready(written) => {
170                                write_buffer.drain(..written);
171                            }
172
173                            Poll::Pending => {
174                                *this.write_state = WriteState::Writing(size);
175                                return Poll::Pending;
176                            }
177                        }
178
179                        if write_buffer.is_empty() {
180                            *this.write_state = WriteState::Pending;
181                            return Poll::Ready(Ok(size));
182                        }
183                    }
184                }
185
186                WriteState::Corrupted => unreachable!(),
187            }
188        }
189    }
190
191    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
192        self.project().inner.poll_flush(cx)
193    }
194
195    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
196        self.project().inner.poll_close(cx)
197    }
198}
199
/// Where `poll_read` is in turning ciphertext from the inner stream into plaintext.
#[derive(Debug)]
enum ReadState {
    /// No decrypted packet in hand; keep feeding the secure layer until it yields one.
    Pending,
    /// Handing out the plaintext of the current packet; the cursor tracks consumption.
    Reading(Cursor<Box<[u8]>>),
    /// An incoming header declared a size over the limit; `poll_read` answers with
    /// an `InvalidData` error and stays in this state.
    PacketTooLarge,
    /// The inner stream returned a 0-byte read; `poll_read` answers with `UnexpectedEof`.
    Done,
    /// Placeholder stored by `mem::replace` while an arm owns the real state; not
    /// meant to be seen by a fresh call.
    Corrupted,
}
208
209#[derive(Debug)]
210enum WriteState {
211    Initial(RsaPublicKey),
212    Pending,
213    Writing(usize),
214    Corrupted,
215}