1use std::{
12 pin::Pin,
13 task::{Context, Poll, ready},
14};
15
16use async_bincode::{
17 AsyncDestination,
18 tokio::{AsyncBincodeReader, AsyncBincodeWriter},
19};
20use futures::{Sink, Stream};
21use secrecy::ExposeSecret;
22use tokio::{
23 io::{AsyncRead, AsyncWrite, DuplexStream, ReadHalf, SimplexStream, WriteHalf},
24 net::{
25 TcpStream,
26 tcp::{OwnedReadHalf, OwnedWriteHalf},
27 },
28};
29
30use crate::{
31 base::{Constant, SharedSecret},
32 protocol::{ProtocolMessage, ProtocolMessageWrapper},
33 utils::{decrypt, encrypt},
34};
35
/// Pins the `inner` field of `$owner` without consuming `$owner`.
///
/// `$owner` must be a mutable place (for example `mut self: Pin<&mut Self>`)
/// whose `inner` field is `Unpin`.
macro_rules! pinned_inner {
    ($owner:ident) => {
        ::std::pin::Pin::new(&mut $owner.inner)
    };
}
44
/// Consumes the pinned `$owner` (via `get_mut`) and pins its `inner` field.
///
/// Use this when `$owner` is a `Pin<&mut Self>` that is not needed afterwards.
macro_rules! take_pinned_inner {
    ($owner:ident) => {
        ::std::pin::Pin::new(&mut $owner.get_mut().inner)
    };
}
51
/// Consumes the pinned `$owner` (via `get_mut`) and pins its `inner_read` field.
macro_rules! take_pinned_inner_read {
    ($owner:ident) => {
        ::std::pin::Pin::new(&mut $owner.get_mut().inner_read)
    };
}
58
/// Consumes the pinned `$owner` (via `get_mut`) and pins its `inner_write` field.
macro_rules! take_pinned_inner_write {
    ($owner:ident) => {
        ::std::pin::Pin::new(&mut $owner.get_mut().inner_write)
    };
}
65
/// Pins the value stored in `$owner.read_stream` without consuming `$owner`.
///
/// Panics if `read_stream` is `None`.
macro_rules! pinned_read_stream {
    ($owner:ident) => {
        ::std::pin::Pin::new($owner.read_stream.as_mut().unwrap())
    };
}
72
/// Consumes the pinned `$owner` (via `get_mut`) and pins the value stored in
/// its `read_stream` field.
///
/// Panics if `read_stream` is `None`.
macro_rules! take_pinned_read_stream {
    ($owner:ident) => {
        ::std::pin::Pin::new($owner.get_mut().read_stream.as_mut().unwrap())
    };
}
79
80pub type BuffedTcpStream = BuffedStream<OwnedReadHalf, OwnedWriteHalf>;
83pub type BuffedDuplexStream = BuffedStream<ReadHalf<DuplexStream>, WriteHalf<DuplexStream>>;
84
85pub struct BuffedStream<R, W> {
99 inner_read: BuffedStreamReadHalf<R>,
101 inner_write: BuffedStreamWriteHalf<W>,
103}
104
105impl<R, W> BuffedStream<R, W> {
108 pub fn with_encryption(mut self, shared_secret: SharedSecret) -> Self {
110 let secret_clone = SharedSecret::init_with(|| *shared_secret.expose_secret());
111
112 self.inner_read.shared_secret = Some(secret_clone);
113 self.inner_read.read_stream = Some(SimplexStream::new_unsplit(Constant::BUFFER_SIZE));
114 self.inner_write.shared_secret = Some(shared_secret);
115
116 self
117 }
118
119 pub fn into_split(self) -> (BuffedStreamReadHalf<R>, BuffedStreamWriteHalf<W>) {
124 (self.inner_read, self.inner_write)
125 }
126}
127
128impl From<TcpStream> for BuffedStream<OwnedReadHalf, OwnedWriteHalf> {
129 fn from(stream: TcpStream) -> Self {
130 let (read, write) = stream.into_split();
131
132 Self {
133 inner_read: BuffedStreamReadHalf::new(read),
134 inner_write: BuffedStreamWriteHalf::new(write),
135 }
136 }
137}
138
139impl<T> From<T> for BuffedStream<ReadHalf<T>, WriteHalf<T>>
140where
141 T: AsyncRead + AsyncWrite + Unpin,
142{
143 fn from(stream: T) -> Self {
144 let (read, write) = tokio::io::split(stream);
145
146 Self {
147 inner_read: BuffedStreamReadHalf::new(read),
148 inner_write: BuffedStreamWriteHalf::new(write),
149 }
150 }
151}
152
153impl<R, W> BuffedStream<R, W>
154where
155 R: AsyncRead + Unpin,
156 W: AsyncWrite + Unpin,
157{
158 pub fn new(inner_read: R, inner_write: W) -> Self {
159 Self {
160 inner_read: BuffedStreamReadHalf::new(inner_read),
161 inner_write: BuffedStreamWriteHalf::new(inner_write),
162 }
163 }
164}
165
166impl<R> BuffedStream<R, OwnedWriteHalf> {
167 pub fn as_inner_tcp_write_ref(&self) -> &OwnedWriteHalf {
168 self.inner_write.inner.get_ref()
169 }
170}
171
172impl<W> BuffedStream<OwnedReadHalf, W> {
173 pub fn as_inner_tcp_read_ref(&self) -> &OwnedReadHalf {
174 self.inner_read.inner.get_ref()
175 }
176}
177
178impl<R, W> Stream for BuffedStream<R, W>
181where
182 R: AsyncRead + Unpin,
183 W: Unpin,
184{
185 type Item = std::io::Result<ProtocolMessage>;
186
187 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
188 take_pinned_inner_read!(self).poll_next(cx)
189 }
190}
191
192impl<R, W> Sink<ProtocolMessage> for BuffedStream<R, W>
193where
194 R: Unpin,
195 W: AsyncWrite + Unpin,
196{
197 type Error = std::io::Error;
198
199 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
200 take_pinned_inner_write!(self).poll_ready(cx)
201 }
202
203 fn start_send(self: Pin<&mut Self>, item: ProtocolMessage) -> Result<(), Self::Error> {
204 take_pinned_inner_write!(self).start_send(item)
205 }
206
207 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
208 futures::Sink::<ProtocolMessage>::poll_flush(take_pinned_inner_write!(self), cx)
209 }
210
211 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
212 take_pinned_inner_write!(self).poll_close(cx)
213 }
214}
215
216impl<R, W> AsyncRead for BuffedStream<R, W>
217where
218 R: AsyncRead + Unpin,
219 W: Unpin,
220{
221 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> {
222 take_pinned_inner_read!(self).poll_read(cx, buf)
223 }
224}
225
226impl<R, W> AsyncWrite for BuffedStream<R, W>
227where
228 R: Unpin,
229 W: AsyncWrite + Unpin,
230{
231 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
232 take_pinned_inner_write!(self).poll_write(cx, buf)
233 }
234
235 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
236 AsyncWrite::poll_flush(take_pinned_inner_write!(self), cx)
237 }
238
239 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
240 take_pinned_inner_write!(self).poll_shutdown(cx)
241 }
242}
243
244pub struct BuffedStreamReadHalf<T> {
247 inner: AsyncBincodeReader<T, ProtocolMessageWrapper>,
248 shared_secret: Option<SharedSecret>,
249 read_stream: Option<SimplexStream>,
250}
251
252impl<T> BuffedStreamReadHalf<T>
253where
254 T: AsyncRead + Unpin,
255{
256 fn new(stream: T) -> Self {
257 Self {
258 inner: AsyncBincodeReader::from(stream),
259 shared_secret: None,
260 read_stream: None,
261 }
262 }
263}
264
265impl<T> Stream for BuffedStreamReadHalf<T>
266where
267 T: AsyncRead + Unpin,
268{
269 type Item = std::io::Result<ProtocolMessage>;
270
271 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
272 let key = self.shared_secret.as_ref().map(|s| SharedSecret::init_with(|| *s.expose_secret()));
274
275 match take_pinned_inner!(self).poll_next(cx) {
276 Poll::Ready(Some(Ok(wrapper))) => match wrapper {
277 ProtocolMessageWrapper::Plain(message) => Poll::Ready(Some(Ok(message))),
278 ProtocolMessageWrapper::Encrypted { nonce, data } => {
279 let Some(key) = key else {
280 return Poll::Ready(Some(Err(std::io::Error::new(
281 std::io::ErrorKind::InvalidData,
282 "Received encrypted message without shared secret on this end",
283 ))));
284 };
285
286 let Ok(decrypted_data) = decrypt(&key, &data, &nonce) else {
287 return Poll::Ready(Some(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Decryption failed"))));
288 };
289
290 let Ok(message) = bincode::deserialize::<ProtocolMessage>(&decrypted_data) else {
291 return Poll::Ready(Some(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to deserialize decrypted data"))));
292 };
293
294 Poll::Ready(Some(Ok(message)))
295 }
296 },
297 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(std::io::Error::new(
298 std::io::ErrorKind::InvalidData,
299 format!("Error on bincode reading during stream next: {}", e),
300 )))),
301 Poll::Ready(None) => Poll::Ready(None),
302 Poll::Pending => Poll::Pending,
303 }
304 }
305}
306
307impl<T> AsyncRead for BuffedStreamReadHalf<T>
308where
309 T: AsyncRead + Unpin,
310{
311 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> {
312 if self.shared_secret.is_none() {
318 return Pin::new(self.inner.get_mut()).poll_read(cx, buf);
319 }
320
321 let result = self.as_mut().poll_next(cx);
325
326 match result {
327 Poll::Ready(Some(Ok(message))) => {
328 let ProtocolMessage::Data(data) = message else {
329 return Poll::Ready(Err(std::io::Error::new(
330 std::io::ErrorKind::InvalidData,
331 "Received non-data message during `poll_read`, which shouldn't happen",
332 )));
333 };
334
335 let written = ready!(pinned_read_stream!(self).poll_write(cx, &data)?);
337
338 if written < data.len() {
340 return Poll::Ready(Err(std::io::Error::new(
341 std::io::ErrorKind::InvalidData,
342 "Decryption stream buffer overflow (shouldn't happen unless there is a mismatched buffer size between client and server)",
343 )));
344 }
345
346 ready!(pinned_read_stream!(self).poll_flush(cx)?);
348 }
349 Poll::Ready(Some(Err(e))) => {
350 return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Error on bincode reading during pump: {}", e))));
353 }
354 Poll::Ready(None) => {
355 ready!(pinned_read_stream!(self).poll_shutdown(cx)?);
360 }
361 Poll::Pending => {
362 }
365 }
366
367 take_pinned_read_stream!(self).poll_read(cx, buf)
370 }
371}
372
373pub struct BuffedStreamWriteHalf<T> {
374 inner: AsyncBincodeWriter<T, ProtocolMessageWrapper, AsyncDestination>,
375 shared_secret: Option<SharedSecret>,
376}
377
378impl<T> BuffedStreamWriteHalf<T>
379where
380 T: AsyncWrite + Unpin,
381{
382 fn new(stream: T) -> Self {
383 Self {
384 inner: AsyncBincodeWriter::from(stream).for_async(),
385 shared_secret: None,
386 }
387 }
388}
389
390impl<T> Sink<ProtocolMessage> for BuffedStreamWriteHalf<T>
391where
392 T: AsyncWrite + Unpin,
393{
394 type Error = std::io::Error;
395
396 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
397 take_pinned_inner!(self)
398 .poll_ready(cx)
399 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to flush inner stream: {}", e)))
400 }
401
402 fn start_send(self: Pin<&mut Self>, item: ProtocolMessage) -> Result<(), Self::Error> {
403 if let Some(key) = self.shared_secret.as_ref() {
404 let encrypted_data = encrypt(
405 key,
406 &bincode::serialize(&item).map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to serialize message"))?,
407 )
408 .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Encryption failed"))?;
409
410 let message = ProtocolMessageWrapper::Encrypted {
411 nonce: encrypted_data.nonce,
412 data: encrypted_data.data,
413 };
414
415 take_pinned_inner!(self)
416 .start_send(message)
417 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to write encrypted packet: {}", e)))?;
418
419 return Ok(());
420 }
421
422 take_pinned_inner!(self)
423 .start_send(ProtocolMessageWrapper::Plain(item))
424 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to write plain packet: {}", e)))
425 }
426
427 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
428 futures::Sink::<ProtocolMessageWrapper>::poll_flush(take_pinned_inner!(self), cx)
429 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to flush inner stream: {}", e)))
430 }
431
432 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
433 take_pinned_inner!(self)
434 .poll_close(cx)
435 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to shutdown inner stream: {}", e)))
436 }
437}
438
439impl<T> AsyncWrite for BuffedStreamWriteHalf<T>
440where
441 T: AsyncWrite + Unpin,
442{
443 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
444 if self.shared_secret.is_none() {
450 return Pin::new(self.inner.get_mut()).poll_write(cx, buf);
451 }
452
453 let max_size = Constant::BUFFER_SIZE - Constant::ENCRYPTION_OVERHEAD;
455 let amt = std::cmp::min(buf.len(), max_size);
456 let buf = &buf[..amt];
457
458 let message = ProtocolMessage::Data(buf.to_vec());
459
460 self.as_mut()
462 .start_send(message)
463 .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to write encrypted packet"))?;
464
465 Poll::Ready(Ok(buf.len()))
468 }
469
470 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
471 pinned_inner!(self)
472 .poll_flush(cx)
473 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to flush inner stream: {}", e)))
474 }
475
476 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
477 pinned_inner!(self)
478 .poll_close(cx)
479 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to shutdown inner stream: {}", e)))
480 }
481}
482
#[cfg(test)]
mod tests {
    use futures::future::join_all;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use crate::utils::tests::{generate_test_duplex, generate_test_duplex_with_encryption};

    /// Plain mode: bytes written on one end arrive unchanged on the other.
    #[tokio::test]
    async fn test_unencrypted_buffed_stream() {
        let (mut client, mut server) = generate_test_duplex();

        let data = b"Hello, world!";

        client.write_all(data).await.unwrap();
        client.shutdown().await.unwrap();

        let mut received = Vec::new();
        server.read_to_end(&mut received).await.unwrap();

        assert_eq!(data, &received[..]);
    }

    /// Encrypted mode: a single write round-trips unchanged.
    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();

        let data = b"Hello, world!";

        client.write_all(data).await.unwrap();
        client.shutdown().await.unwrap();

        let mut received = Vec::new();
        server.read_to_end(&mut received).await.unwrap();

        assert_eq!(data, &received[..]);
    }

    /// Encrypted mode: two writes arrive complete and in order. The payloads
    /// differ so that a swap or duplication cannot go unnoticed.
    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream_with_multiple_packets() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();

        let data1 = b"Hello, world!";
        let data2 = b"Goodbye, world!";

        client.write_all(data1).await.unwrap();
        client.write_all(data2).await.unwrap();
        client.shutdown().await.unwrap();

        let mut received = Vec::new();
        server.read_to_end(&mut received).await.unwrap();

        let mut expected = data1.to_vec();
        expected.extend_from_slice(data2);

        assert_eq!(expected, received);
    }

    /// Encrypted mode: a payload written and read concurrently arrives
    /// byte-for-byte identical.
    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream_with_large_data() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();

        let data = b"Hello, world!";
        let data = data.repeat(10000);

        let data_clone = data.clone();

        let write_task = tokio::spawn(async move {
            client.write_all(&data_clone).await.unwrap();
            client.shutdown().await.unwrap();
        });

        let read_task = tokio::spawn(async move {
            let mut received = Vec::new();
            server.read_to_end(&mut received).await.unwrap();

            assert_eq!(data.len(), received.len());
            // `assert!` rather than `assert_eq!` so that a mismatch does not
            // dump both large buffers.
            assert!(data == received, "received payload differs from the data that was sent");
        });

        join_all([write_task, read_task]).await.into_iter().collect::<Result<Vec<_>, _>>().unwrap();
    }
}