futures_loco_protocol/
secure.rs1use std::{
2 io::{self, Cursor, ErrorKind, Read},
3 mem,
4 pin::Pin,
5 task::{ready, Context, Poll},
6};
7
8use futures_io::{AsyncRead, AsyncWrite};
9use loco_protocol::secure::{
10 client::{LocoClientSecureLayer, ReadState as LayerReadState, rsa::RsaPublicKey},
11 SecurePacket,
12};
13use rand::RngCore;
14
15pub use loco_protocol::secure::client::rsa;
16
pin_project_lite::pin_project! {
    /// Client-side adapter that runs an inner byte stream `T` through the LOCO secure layer.
    ///
    /// Writes are framed into `SecurePacket`s by a `LocoClientSecureLayer` (the first write
    /// also runs the layer handshake with the RSA public key given to `new`), and reads return
    /// the payloads of packets the layer reassembles from bytes read off `inner`.
    #[derive(Debug)]
    pub struct LocoSecureStream<T> {
        // Progress of `poll_read` (see `ReadState`).
        read_state: ReadState,
        // Progress of `poll_write` (see `WriteState`); starts out holding the RSA public key
        // used for the handshake.
        write_state: WriteState,

        // Secure-layer state, including the `read_buffer` / `write_buffer` byte queues that
        // the I/O impls fill and drain.
        layer: LocoClientSecureLayer,

        // The wrapped transport. `#[pin]` makes it structurally pinned so it can be polled
        // through `Pin<&mut T>` without requiring `T: Unpin`.
        #[pin]
        inner: T,
    }
}
29
30impl<T> LocoSecureStream<T> {
31 pub const MAX_IO_SIZE: u64 = 16 * 1024 * 1024;
32
33 pub fn new(rsa_key: RsaPublicKey, inner: T) -> Self {
34 let mut key = [0_u8; 16];
35 rand::thread_rng().fill_bytes(&mut key);
36
37 Self {
38 read_state: ReadState::Pending,
39 write_state: WriteState::Initial(rsa_key),
40
41 layer: LocoClientSecureLayer::new(key),
42
43 inner,
44 }
45 }
46
47 pub fn inner(&self) -> &T {
48 &self.inner
49 }
50
51 pub fn inner_mut(&mut self) -> &mut T {
52 &mut self.inner
53 }
54
55 pub fn into_inner(self) -> T {
56 self.inner
57 }
58}
59
60impl<T: AsyncRead> AsyncRead for LocoSecureStream<T> {
61 fn poll_read(
62 self: Pin<&mut Self>,
63 cx: &mut Context<'_>,
64 buf: &mut [u8],
65 ) -> Poll<io::Result<usize>> {
66 let mut this = self.project();
67
68 loop {
69 match mem::replace(this.read_state, ReadState::Corrupted) {
70 ReadState::Pending => {
71 if let Some(packet) = this.layer.read() {
72 *this.read_state = ReadState::Reading(Cursor::new(packet.data));
73 } else {
74 if let LayerReadState::Header(header) = this.layer.read_state() {
75 if header.size as u64 - 16 > Self::MAX_IO_SIZE {
76 *this.read_state = ReadState::PacketTooLarge;
77 continue;
78 }
79 }
80
81 let mut read_buf = [0_u8; 1024];
82
83 *this.read_state = ReadState::Pending;
84
85 let read = ready!(this.inner.as_mut().poll_read(cx, &mut read_buf))?;
86 if read == 0 {
87 *this.read_state = ReadState::Done;
88 continue;
89 }
90
91 this.layer.read_buffer.extend(&read_buf[..read]);
92 }
93 }
94
95 ReadState::Reading(mut cursor) => {
96 let read = cursor.read(buf)?;
97
98 *this.read_state = if cursor.position() as usize == cursor.get_ref().len() {
99 ReadState::Pending
100 } else {
101 ReadState::Reading(cursor)
102 };
103
104 break Poll::Ready(Ok(read));
105 }
106
107 ReadState::PacketTooLarge => {
108 *this.read_state = ReadState::PacketTooLarge;
109
110 break Poll::Ready(Err(io::Error::new(
111 ErrorKind::InvalidData,
112 "packet is too large",
113 )));
114 }
115
116 ReadState::Done => break Poll::Ready(Err(ErrorKind::UnexpectedEof.into())),
117
118 ReadState::Corrupted => unreachable!(),
119 }
120 }
121 }
122}
123
124impl<T: AsyncWrite> AsyncWrite for LocoSecureStream<T> {
125 fn poll_write(
126 self: Pin<&mut Self>,
127 cx: &mut Context<'_>,
128 buf: &[u8],
129 ) -> Poll<io::Result<usize>> {
130 let mut this = self.project();
131
132 loop {
133 match mem::replace(this.write_state, WriteState::Corrupted) {
134 WriteState::Initial(key) => {
135 this.layer.handshake(&key);
136
137 *this.write_state = WriteState::Pending;
138 }
139
140 WriteState::Pending => {
141 let data = if buf.len() as u64 > Self::MAX_IO_SIZE {
142 &buf[..Self::MAX_IO_SIZE as usize]
143 } else {
144 buf
145 };
146
147 let mut iv = [0_u8; 16];
148 rand::thread_rng().fill_bytes(&mut iv);
149
150 *this.write_state = WriteState::Writing(data.len());
151 this.layer.send(SecurePacket { iv, data });
152 }
153
154 WriteState::Writing(size) => {
155 let write_buffer = &mut this.layer.write_buffer;
156
157 loop {
158 let slice = {
159 let slices = write_buffer.as_slices();
160
161 if !slices.0.is_empty() {
162 slices.0
163 } else {
164 slices.1
165 }
166 };
167
168 match this.inner.as_mut().poll_write(cx, slice)? {
169 Poll::Ready(written) => {
170 write_buffer.drain(..written);
171 }
172
173 Poll::Pending => {
174 *this.write_state = WriteState::Writing(size);
175 return Poll::Pending;
176 }
177 }
178
179 if write_buffer.is_empty() {
180 *this.write_state = WriteState::Pending;
181 return Poll::Ready(Ok(size));
182 }
183 }
184 }
185
186 WriteState::Corrupted => unreachable!(),
187 }
188 }
189 }
190
191 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
192 self.project().inner.poll_flush(cx)
193 }
194
195 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
196 self.project().inner.poll_close(cx)
197 }
198}
199
/// State machine driving `LocoSecureStream::poll_read`.
#[derive(Debug)]
enum ReadState {
    /// No packet is being handed out; input may still be needed from the inner stream.
    Pending,

    /// Payload of a completed packet, handed out across one or more reads.
    Reading(Cursor<Box<[u8]>>),

    /// The inner stream returned 0 bytes; reads in this state fail with `UnexpectedEof`.
    Done,

    /// A packet header announced an oversized packet; reads in this state fail with
    /// `InvalidData`.
    PacketTooLarge,

    /// Placeholder swapped in by `mem::replace` while a state is being processed;
    /// observing it at the start of a poll is a bug.
    Corrupted,
}
208
209#[derive(Debug)]
210enum WriteState {
211 Initial(RsaPublicKey),
212 Pending,
213 Writing(usize),
214 Corrupted,
215}