1use std::{
4 fmt::{self, Display, Formatter},
5 io,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll},
9 time::Duration,
10};
11
12use futures_core::{ready, Future};
13use pin_project_lite::pin_project;
14use tokio::{
15 io::{AsyncRead, AsyncWrite, ReadBuf},
16 time::{Instant, Sleep},
17};
18
19use crate::{
20 context::Ctx,
21 crypto::{
22 cipher::{Cipher, Method},
23 hkdf_sha1, Nonce,
24 },
25 net::{buf::OwnedReadBuf, constants::MAXIMUM_PAYLOAD_SIZE, poll_read_exact},
26};
27
28pub struct TcpStream<T> {
30 inner_stream: T,
31
32 cipher_method: Method,
33 cipher_key: Vec<u8>,
34
35 enc_cipher: Option<Cipher>,
36 dec_cipher: Option<Cipher>,
37
38 enc_nonce: Nonce,
39 dec_nonce: Nonce,
40
41 incoming_salt: Option<Vec<u8>>, read_state: ReadState,
44 write_state: WriteState,
45
46 in_payload: Vec<u8>, out_payload: Vec<u8>, read_buf: OwnedReadBuf,
50
51 ctx: Arc<Ctx>,
52}
53
54impl<T> TcpStream<T> {
55 pub fn new(inner_stream: T, cipher_method: Method, cipher_key: &[u8], ctx: Arc<Ctx>) -> Self {
57 TcpStream {
58 inner_stream,
59 cipher_method,
60 cipher_key: cipher_key.to_owned(),
61 enc_cipher: None,
62 dec_cipher: None,
63 enc_nonce: Nonce::new(cipher_method.iv_size()),
64 dec_nonce: Nonce::new(cipher_method.iv_size()),
65 incoming_salt: None,
66 read_state: ReadState::ReadSalt,
67 write_state: WriteState::WriteSalt,
68 in_payload: Vec::new(),
69 out_payload: Vec::new(),
70 read_buf: OwnedReadBuf::new(),
71 ctx: ctx.clone(),
72 }
73 }
74}
75
76impl<T> TcpStream<T> {
77 fn encrypt(&mut self, plaintext: &[u8]) -> io::Result<Vec<u8>> {
78 match self
79 .enc_cipher
80 .as_ref()
81 .expect("no salt received")
82 .encrypt(&self.enc_nonce, plaintext)
83 {
84 Ok(data) => {
85 self.enc_nonce.increment();
86 Ok(data)
87 }
88 Err(_) => Err(io::Error::new(io::ErrorKind::Other, Error::Encryption)),
89 }
90 }
91
92 fn decrypt(&mut self, ciphertext: &[u8]) -> io::Result<Vec<u8>> {
93 match self
94 .dec_cipher
95 .as_ref()
96 .expect("no salt received")
97 .decrypt(&self.dec_nonce, ciphertext)
98 {
99 Ok(data) => {
100 self.dec_nonce.increment();
101 Ok(data)
102 }
103 Err(_) => Err(io::Error::new(io::ErrorKind::Other, Error::Decryption)),
104 }
105 }
106}
107
108impl<T> TcpStream<T>
109where
110 T: AsyncRead + Unpin,
111{
112 fn poll_read_decrypt_helper(
113 &mut self,
114 cx: &mut Context<'_>,
115 buf: &mut ReadBuf<'_>,
116 ) -> Poll<io::Result<()>> {
117 let res = ready!(self.poll_read_decrypt(cx, buf));
118
119 if let Err(e) = res {
120 if e.kind() != io::ErrorKind::UnexpectedEof {
121 return Err(e).into();
122 }
123 }
124
125 Ok(()).into()
126 }
127
128 fn poll_read_decrypt(
129 &mut self,
130 cx: &mut Context<'_>,
131 buf: &mut ReadBuf<'_>,
132 ) -> Poll<io::Result<()>> {
133 loop {
134 match self.read_state {
135 ReadState::ReadSalt => {
136 ready!(self.poll_read_salt(cx))?;
137 self.read_state = ReadState::ReadLength;
138 }
139 ReadState::ReadLength => {
140 let len = ready!(self.poll_read_length(cx))?;
141 self.read_state = ReadState::ReadPayload(len);
142 }
143 ReadState::ReadPayload(payload_len) => {
144 self.in_payload = ready!(self.poll_read_payload(cx, payload_len))?;
145 self.read_state = ReadState::ReadPayloadOut;
146 }
147 ReadState::ReadPayloadOut => {
148 let buf_remaining = buf.remaining();
149 let payload_len = self.in_payload.len();
150
151 if buf_remaining >= payload_len {
152 buf.put_slice(&self.in_payload);
153 self.read_state = ReadState::ReadLength;
154 } else {
155 let (data, remaining) = self.in_payload.split_at(buf_remaining);
156 buf.put_slice(data);
157 self.in_payload = remaining.to_owned();
158 self.read_state = ReadState::ReadPayloadOut;
159 }
160
161 return Ok(()).into();
162 }
163 }
164 }
165 }
166
167 fn poll_read_salt(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
168 if self.dec_cipher.is_none() {
169 let mut salt = vec![0u8; self.cipher_method.salt_size()];
170 ready!(poll_read_exact(
171 &mut self.inner_stream,
172 &mut self.read_buf,
173 cx,
174 &mut salt
175 ))?;
176
177 self.incoming_salt = Some(salt.clone());
178
179 let mut subkey = vec![0u8; self.cipher_method.key_size()];
180 hkdf_sha1(&self.cipher_key, &salt, &mut subkey);
181
182 let cipher = Cipher::new(self.cipher_method, &mut subkey);
183 self.dec_cipher.replace(cipher);
184 }
185
186 Ok(()).into()
187 }
188
189 fn poll_read_length(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
190 let mut buf = vec![0u8; 2 + self.cipher_method.tag_size()];
191 ready!(poll_read_exact(
192 &mut self.inner_stream,
193 &mut self.read_buf,
194 cx,
195 &mut buf
196 ))?;
197
198 let len = self.decrypt(&buf)?;
199 let len = [len[0], len[1]];
200 let payload_len = (u16::from_be_bytes(len) as usize) & MAXIMUM_PAYLOAD_SIZE;
201
202 if let Some(salt) = self.incoming_salt.take() {
203 if !self.ctx.check_replay(&salt) {
204 return Err(io::Error::new(io::ErrorKind::Other, Error::DuplicateSalt)).into();
205 }
206 }
207
208 Ok(payload_len).into()
209 }
210
211 fn poll_read_payload(
212 &mut self,
213 cx: &mut Context<'_>,
214 payload_len: usize,
215 ) -> Poll<io::Result<Vec<u8>>> {
216 let mut buf = vec![0u8; payload_len + self.cipher_method.tag_size()];
217 ready!(poll_read_exact(
218 &mut self.inner_stream,
219 &mut self.read_buf,
220 cx,
221 &mut buf
222 ))?;
223 let payload = self.decrypt(&buf)?;
224
225 Ok(payload).into()
226 }
227}
228
229impl<T> TcpStream<T>
230where
231 T: AsyncWrite + Unpin,
232{
233 fn poll_write_encrypt(
234 &mut self,
235 cx: &mut Context<'_>,
236 payload: &[u8],
237 ) -> Poll<io::Result<usize>> {
238 loop {
239 match self.write_state {
240 WriteState::WriteSalt => {
241 ready!(self.poll_write_salt(cx))?;
242 self.write_state = WriteState::WriteLength;
243 }
244 WriteState::WriteLength => {
245 ready!(self.poll_write_length(cx, payload))?;
246 self.write_state = WriteState::WritePayload;
247 }
248 WriteState::WritePayload => {
249 ready!(self.poll_write_payload(cx, payload))?;
250 self.write_state = WriteState::WritePayloadOut;
251 }
252 WriteState::WritePayloadOut => {
253 while !self.out_payload.is_empty() {
254 let nwrite = ready!(
255 Pin::new(&mut self.inner_stream).poll_write(cx, &self.out_payload)
256 )?;
257
258 self.out_payload = self.out_payload[nwrite..].to_vec();
259 }
260
261 self.write_state = WriteState::WriteLength;
262
263 let length = usize::min(payload.len(), MAXIMUM_PAYLOAD_SIZE);
264 return Ok(length).into();
265 }
266 }
267 }
268 }
269
270 fn poll_write_salt(&mut self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
271 use rand::prelude::*;
272
273 if self.enc_cipher.is_none() {
274 let mut salt = vec![0u8; self.cipher_method.salt_size()];
275 let mut rng = StdRng::from_entropy();
276 rng.fill_bytes(&mut salt);
277
278 let mut subkey = vec![0u8; self.cipher_method.key_size()];
279 hkdf_sha1(&self.cipher_key, &salt, &mut subkey);
280
281 let cipher = Cipher::new(self.cipher_method, &mut subkey);
282 self.enc_cipher.replace(cipher);
283
284 self.out_payload.append(&mut salt);
285 }
286
287 Ok(()).into()
288 }
289
290 fn poll_write_length(&mut self, _cx: &mut Context<'_>, payload: &[u8]) -> Poll<io::Result<()>> {
291 let length = usize::min(payload.len(), MAXIMUM_PAYLOAD_SIZE);
292 let len = (length as u16).to_be_bytes();
293
294 let mut buf = self.encrypt(&len)?;
295 self.out_payload.append(&mut buf);
296
297 Ok(()).into()
298 }
299
300 fn poll_write_payload(
301 &mut self,
302 _cx: &mut Context<'_>,
303 payload: &[u8],
304 ) -> Poll<io::Result<()>> {
305 let length = usize::min(payload.len(), MAXIMUM_PAYLOAD_SIZE);
306
307 let mut buf = self.encrypt(&payload[..length])?;
308 self.out_payload.append(&mut buf);
309
310 Ok(()).into()
311 }
312}
313
314impl<T> AsyncRead for TcpStream<T>
315where
316 T: AsyncRead + Unpin,
317{
318 fn poll_read(
319 self: Pin<&mut Self>,
320 cx: &mut Context<'_>,
321 buf: &mut ReadBuf<'_>,
322 ) -> Poll<io::Result<()>> {
323 self.get_mut().poll_read_decrypt_helper(cx, buf)
324 }
325}
326
327impl<T> AsyncWrite for TcpStream<T>
328where
329 T: AsyncWrite + Unpin,
330{
331 fn poll_write(
332 self: Pin<&mut Self>,
333 cx: &mut Context<'_>,
334 buf: &[u8],
335 ) -> Poll<io::Result<usize>> {
336 self.get_mut().poll_write_encrypt(cx, buf)
337 }
338
339 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
340 let inner_stream = &mut self.get_mut().inner_stream;
341 Pin::new(inner_stream).poll_flush(cx)
342 }
343
344 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
345 let inner_stream = &mut self.get_mut().inner_stream;
346 Pin::new(inner_stream).poll_shutdown(cx)
347 }
348}
349
350pin_project! {
351 pub struct TimeoutStream<T> {
355 #[pin]
356 inner_stream: T,
357
358 duration: Duration,
359 sleep: Option<Sleep>,
360 }
361}
362
363impl<T> TimeoutStream<T> {
364 pub fn new(inner_stream: T, duration: Duration) -> Self {
366 TimeoutStream {
367 inner_stream,
368 duration,
369 sleep: None,
370 }
371 }
372}
373
374impl<T> TimeoutStream<T> {
375 fn check_timeout(sleep: Pin<&mut Sleep>, cx: &mut Context<'_>) -> io::Result<()> {
376 match sleep.poll(cx) {
377 Poll::Ready(_) => Err(io::ErrorKind::TimedOut.into()),
378 Poll::Pending => Ok(()),
379 }
380 }
381
382 fn reset_timeout(sleep: Pin<&mut Sleep>, duration: Duration) {
383 sleep.reset(Instant::now() + duration);
384 }
385}
386
387impl<T> AsyncRead for TimeoutStream<T>
388where
389 T: AsyncRead,
390{
391 fn poll_read(
392 self: Pin<&mut Self>,
393 cx: &mut Context<'_>,
394 buf: &mut ReadBuf<'_>,
395 ) -> Poll<io::Result<()>> {
396 let this = self.project();
397 let ret = this.inner_stream.poll_read(cx, buf);
398
399 let sleep = unsafe {
400 Pin::new_unchecked(
401 this.sleep
402 .get_or_insert(tokio::time::sleep_until(Instant::now() + *this.duration)),
403 )
404 };
405
406 match ret {
407 Poll::Ready(_) => Self::reset_timeout(sleep, *this.duration),
408 Poll::Pending => Self::check_timeout(sleep, cx)?,
409 }
410
411 ret
412 }
413}
414
415impl<T> AsyncWrite for TimeoutStream<T>
416where
417 T: AsyncWrite,
418{
419 fn poll_write(
420 self: Pin<&mut Self>,
421 cx: &mut Context<'_>,
422 buf: &[u8],
423 ) -> Poll<Result<usize, io::Error>> {
424 let this = self.project();
425 let ret = this.inner_stream.poll_write(cx, buf);
426
427 let sleep = unsafe {
428 Pin::new_unchecked(
429 this.sleep
430 .get_or_insert(tokio::time::sleep_until(Instant::now() + *this.duration)),
431 )
432 };
433
434 match ret {
435 Poll::Ready(_) => Self::reset_timeout(sleep, *this.duration),
436 Poll::Pending => Self::check_timeout(sleep, cx)?,
437 }
438
439 ret
440 }
441
442 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
443 let this = self.project();
444 this.inner_stream.poll_flush(cx)
445 }
446
447 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
448 let this = self.project();
449 this.inner_stream.poll_shutdown(cx)
450 }
451}
452
/// Failures specific to the encrypted stream. Elsewhere in this file they are wrapped in an
/// `io::Error` of kind `Other`.
#[derive(Debug)]
pub enum Error {
    /// Sealing outgoing data with the encryption cipher failed.
    Encryption,

    /// Opening received data with the decryption cipher failed.
    Decryption,

    /// The peer's salt was rejected by `Ctx::check_replay`.
    DuplicateSalt,
}

impl Display for Error {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let message = match self {
            Error::Encryption => "encryption error",
            Error::Decryption => "decryption error",
            Error::DuplicateSalt => "duplicate salt received, possible replay attack",
        };
        f.write_str(message)
    }
}

impl std::error::Error for Error {}
477
/// Position of the read state machine within the incoming stream.
enum ReadState {
    /// The peer's salt has not been read yet.
    ReadSalt,
    /// Waiting for the next encrypted length header.
    ReadLength,
    /// Waiting for an encrypted payload of the given plaintext length.
    ReadPayload(usize),
    /// Decrypted plaintext is buffered and being handed to the reader.
    ReadPayloadOut,
}
484
/// Position of the write state machine within the current outgoing chunk.
enum WriteState {
    /// The salt has not been generated and queued yet.
    WriteSalt,
    /// The next step is to encrypt and queue a length header.
    WriteLength,
    /// The next step is to encrypt and queue the payload.
    WritePayload,
    /// Queued encrypted bytes are being written to the inner stream.
    WritePayloadOut,
}
491
492