1use std::{
12 pin::Pin,
13 task::{Context, Poll, ready},
14};
15
16use async_bincode::{
17 AsyncDestination,
18 tokio::{AsyncBincodeReader, AsyncBincodeWriter},
19};
20use futures::{Sink, Stream};
21use secrecy::ExposeSecret;
22use tokio::{
23 io::{AsyncRead, AsyncWrite, DuplexStream, ReadHalf, SimplexStream, WriteHalf},
24 net::{
25 TcpStream,
26 tcp::{OwnedReadHalf, OwnedWriteHalf},
27 },
28};
29
30use crate::{
31 base::{Constant, SharedSecret},
32 protocol::{ProtocolMessage, ProtocolMessageWrapper},
33 utils::{decrypt, encrypt},
34};
35
/// Pins the `inner` field of a value that is reachable mutably by name.
///
/// Expands to `Pin::new(&mut <value>.inner)`, so the field type must be `Unpin`.
macro_rules! pinned_inner {
    ($owner:ident) => {
        Pin::new(&mut $owner.inner)
    };
}
44
/// Consumes a `Pin<&mut _>` and pins the `inner` field of the value behind it.
///
/// Expands to `Pin::new(&mut <pin>.get_mut().inner)`; both the pinned value and the field must
/// be `Unpin`.
macro_rules! take_pinned_inner {
    ($pinned:ident) => {
        Pin::new(&mut $pinned.get_mut().inner)
    };
}
51
/// Consumes a `Pin<&mut _>` and pins the `inner_read` field of the value behind it.
///
/// Expands to `Pin::new(&mut <pin>.get_mut().inner_read)`; both the pinned value and the field
/// must be `Unpin`.
macro_rules! take_pinned_inner_read {
    ($pinned:ident) => {
        Pin::new(&mut $pinned.get_mut().inner_read)
    };
}
58
/// Consumes a `Pin<&mut _>` and pins the `inner_write` field of the value behind it.
///
/// Expands to `Pin::new(&mut <pin>.get_mut().inner_write)`; both the pinned value and the field
/// must be `Unpin`.
macro_rules! take_pinned_inner_write {
    ($pinned:ident) => {
        Pin::new(&mut $pinned.get_mut().inner_write)
    };
}
65
/// Pins the value stored in the optional `read_stream` field of a mutably reachable value.
///
/// Expands to `Pin::new(<value>.read_stream.as_mut().unwrap())`, so it panics when the field is
/// `None`.
macro_rules! pinned_read_stream {
    ($owner:ident) => {
        Pin::new($owner.read_stream.as_mut().unwrap())
    };
}
72
/// Consumes a `Pin<&mut _>` and pins the value stored in the optional `read_stream` field behind it.
///
/// Expands to `Pin::new(<pin>.get_mut().read_stream.as_mut().unwrap())`, so it panics when the
/// field is `None`.
macro_rules! take_pinned_read_stream {
    ($pinned:ident) => {
        Pin::new($pinned.get_mut().read_stream.as_mut().unwrap())
    };
}
79
/// A [`BuffedStream`] over the owned read and write halves of a [`TcpStream`].
pub type BuffedTcpStream = BuffedStream<OwnedReadHalf, OwnedWriteHalf>;
/// A [`BuffedStream`] over the split read and write halves of a [`DuplexStream`].
pub type BuffedDuplexStream = BuffedStream<ReadHalf<DuplexStream>, WriteHalf<DuplexStream>>;
84
/// A bidirectional stream assembled from an independent read half and write half.
///
/// The `Stream`, `Sink`, `AsyncRead` and `AsyncWrite` implementations on this type forward to the
/// matching half.
pub struct BuffedStream<R, W> {
    /// Receiving side; serves `Stream` and `AsyncRead`.
    inner_read: BuffedStreamReadHalf<R>,
    /// Sending side; serves `Sink` and `AsyncWrite`.
    inner_write: BuffedStreamWriteHalf<W>,
}
102
103impl<R, W> BuffedStream<R, W> {
106 pub fn with_encryption(mut self, shared_secret: SharedSecret) -> Self {
108 let secret_clone = SharedSecret::init_with(|| *shared_secret.expose_secret());
109
110 self.inner_read.shared_secret = Some(secret_clone);
111 self.inner_read.read_stream = Some(SimplexStream::new_unsplit(Constant::BUFFER_SIZE));
112 self.inner_write.shared_secret = Some(shared_secret);
113
114 self
115 }
116
117 pub fn into_split(self) -> (BuffedStreamReadHalf<R>, BuffedStreamWriteHalf<W>) {
118 (self.inner_read, self.inner_write)
119 }
120}
121
122impl From<TcpStream> for BuffedStream<OwnedReadHalf, OwnedWriteHalf> {
123 fn from(stream: TcpStream) -> Self {
124 let (read, write) = stream.into_split();
125
126 Self {
127 inner_read: BuffedStreamReadHalf::new(read),
128 inner_write: BuffedStreamWriteHalf::new(write),
129 }
130 }
131}
132
133impl<T> From<T> for BuffedStream<ReadHalf<T>, WriteHalf<T>>
134where
135 T: AsyncRead + AsyncWrite + Unpin,
136{
137 fn from(stream: T) -> Self {
138 let (read, write) = tokio::io::split(stream);
139
140 Self {
141 inner_read: BuffedStreamReadHalf::new(read),
142 inner_write: BuffedStreamWriteHalf::new(write),
143 }
144 }
145}
146
147impl<R, W> BuffedStream<R, W>
148where
149 R: AsyncRead + Unpin,
150 W: AsyncWrite + Unpin,
151{
152 pub fn new(inner_read: R, inner_write: W) -> Self {
153 Self {
154 inner_read: BuffedStreamReadHalf::new(inner_read),
155 inner_write: BuffedStreamWriteHalf::new(inner_write),
156 }
157 }
158}
159
160impl<R> BuffedStream<R, OwnedWriteHalf> {
161 pub fn as_inner_tcp_write_ref(&self) -> &OwnedWriteHalf {
162 self.inner_write.inner.get_ref()
163 }
164}
165
166impl<W> BuffedStream<OwnedReadHalf, W> {
167 pub fn as_inner_tcp_read_ref(&self) -> &OwnedReadHalf {
168 self.inner_read.inner.get_ref()
169 }
170}
171
172impl<R, W> Stream for BuffedStream<R, W>
175where
176 R: AsyncRead + Unpin,
177 W: Unpin,
178{
179 type Item = std::io::Result<ProtocolMessage>;
180
181 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
182 take_pinned_inner_read!(self).poll_next(cx)
183 }
184}
185
186impl<R, W> Sink<ProtocolMessage> for BuffedStream<R, W>
187where
188 R: Unpin,
189 W: AsyncWrite + Unpin,
190{
191 type Error = std::io::Error;
192
193 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
194 take_pinned_inner_write!(self).poll_ready(cx)
195 }
196
197 fn start_send(self: Pin<&mut Self>, item: ProtocolMessage) -> Result<(), Self::Error> {
198 take_pinned_inner_write!(self).start_send(item)
199 }
200
201 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
202 futures::Sink::<ProtocolMessage>::poll_flush(take_pinned_inner_write!(self), cx)
203 }
204
205 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
206 take_pinned_inner_write!(self).poll_close(cx)
207 }
208}
209
210impl<R, W> AsyncRead for BuffedStream<R, W>
211where
212 R: AsyncRead + Unpin,
213 W: Unpin,
214{
215 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> {
216 take_pinned_inner_read!(self).poll_read(cx, buf)
217 }
218}
219
220impl<R, W> AsyncWrite for BuffedStream<R, W>
221where
222 R: Unpin,
223 W: AsyncWrite + Unpin,
224{
225 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
226 take_pinned_inner_write!(self).poll_write(cx, buf)
227 }
228
229 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
230 AsyncWrite::poll_flush(take_pinned_inner_write!(self), cx)
231 }
232
233 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
234 take_pinned_inner_write!(self).poll_shutdown(cx)
235 }
236}
237
238pub struct BuffedStreamReadHalf<T> {
241 inner: AsyncBincodeReader<T, ProtocolMessageWrapper>,
242 shared_secret: Option<SharedSecret>,
243 read_stream: Option<SimplexStream>,
244}
245
246impl<T> BuffedStreamReadHalf<T>
247where
248 T: AsyncRead + Unpin,
249{
250 fn new(stream: T) -> Self {
251 Self {
252 inner: AsyncBincodeReader::from(stream),
253 shared_secret: None,
254 read_stream: None,
255 }
256 }
257}
258
259impl<T> Stream for BuffedStreamReadHalf<T>
260where
261 T: AsyncRead + Unpin,
262{
263 type Item = std::io::Result<ProtocolMessage>;
264
265 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
266 let key = self.shared_secret.as_ref().map(|s| SharedSecret::init_with(|| *s.expose_secret()));
268
269 match take_pinned_inner!(self).poll_next(cx) {
270 Poll::Ready(Some(Ok(wrapper))) => match wrapper {
271 ProtocolMessageWrapper::Plain(message) => Poll::Ready(Some(Ok(message))),
272 ProtocolMessageWrapper::Encrypted { nonce, data } => {
273 let Some(key) = key else {
274 return Poll::Ready(Some(Err(std::io::Error::new(
275 std::io::ErrorKind::InvalidData,
276 "Received encrypted message without shared secret on this end",
277 ))));
278 };
279
280 let Ok(decrypted_data) = decrypt(&key, &data, &nonce) else {
281 return Poll::Ready(Some(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Decryption failed"))));
282 };
283
284 let Ok(message) = bincode::deserialize::<ProtocolMessage>(&decrypted_data) else {
285 return Poll::Ready(Some(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to deserialize decrypted data"))));
286 };
287
288 Poll::Ready(Some(Ok(message)))
289 }
290 },
291 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(std::io::Error::new(
292 std::io::ErrorKind::InvalidData,
293 format!("Error on bincode reading during stream next: {}", e),
294 )))),
295 Poll::Ready(None) => Poll::Ready(None),
296 Poll::Pending => Poll::Pending,
297 }
298 }
299}
300
301impl<T> AsyncRead for BuffedStreamReadHalf<T>
302where
303 T: AsyncRead + Unpin,
304{
305 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> {
306 if self.shared_secret.is_none() {
312 return Pin::new(self.inner.get_mut()).poll_read(cx, buf);
313 }
314
315 let result = self.as_mut().poll_next(cx);
319
320 match result {
321 Poll::Ready(Some(Ok(message))) => {
322 let ProtocolMessage::Data(data) = message else {
323 return Poll::Ready(Err(std::io::Error::new(
324 std::io::ErrorKind::InvalidData,
325 "Received non-data message during `poll_read`, which shouldn't happen",
326 )));
327 };
328
329 let written = ready!(pinned_read_stream!(self).poll_write(cx, &data)?);
331
332 if written < data.len() {
334 return Poll::Ready(Err(std::io::Error::new(
335 std::io::ErrorKind::InvalidData,
336 "Decryption stream buffer overflow (shouldn't happen unless there is a mismatched buffer size between client and server)",
337 )));
338 }
339
340 ready!(pinned_read_stream!(self).poll_flush(cx)?);
342 }
343 Poll::Ready(Some(Err(e))) => {
344 return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Error on bincode reading during pump: {}", e))));
347 }
348 Poll::Ready(None) => {
349 ready!(pinned_read_stream!(self).poll_shutdown(cx)?);
354 }
355 Poll::Pending => {
356 }
359 }
360
361 take_pinned_read_stream!(self).poll_read(cx, buf)
364 }
365}
366
/// Sending half of a [`BuffedStream`].
///
/// Encodes outgoing messages as [`ProtocolMessageWrapper`] frames; frames are encrypted when a
/// shared secret is installed.
pub struct BuffedStreamWriteHalf<T> {
    /// Bincode frame encoder wrapped around the underlying writer.
    inner: AsyncBincodeWriter<T, ProtocolMessageWrapper, AsyncDestination>,
    /// Key used to encrypt outgoing frames; `None` while encryption is disabled.
    shared_secret: Option<SharedSecret>,
}
371
372impl<T> BuffedStreamWriteHalf<T>
373where
374 T: AsyncWrite + Unpin,
375{
376 fn new(stream: T) -> Self {
377 Self {
378 inner: AsyncBincodeWriter::from(stream).for_async(),
379 shared_secret: None,
380 }
381 }
382}
383
384impl<T> Sink<ProtocolMessage> for BuffedStreamWriteHalf<T>
385where
386 T: AsyncWrite + Unpin,
387{
388 type Error = std::io::Error;
389
390 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
391 take_pinned_inner!(self)
392 .poll_ready(cx)
393 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to flush inner stream: {}", e)))
394 }
395
396 fn start_send(self: Pin<&mut Self>, item: ProtocolMessage) -> Result<(), Self::Error> {
397 if let Some(key) = self.shared_secret.as_ref() {
398 let encrypted_data = encrypt(
399 key,
400 &bincode::serialize(&item).map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to serialize message"))?,
401 )
402 .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Encryption failed"))?;
403
404 let message = ProtocolMessageWrapper::Encrypted {
405 nonce: encrypted_data.nonce,
406 data: encrypted_data.data,
407 };
408
409 take_pinned_inner!(self)
410 .start_send(message)
411 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to write encrypted packet: {}", e)))?;
412
413 return Ok(());
414 }
415
416 take_pinned_inner!(self)
417 .start_send(ProtocolMessageWrapper::Plain(item))
418 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to write plain packet: {}", e)))
419 }
420
421 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
422 futures::Sink::<ProtocolMessageWrapper>::poll_flush(take_pinned_inner!(self), cx)
423 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to flush inner stream: {}", e)))
424 }
425
426 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
427 take_pinned_inner!(self)
428 .poll_close(cx)
429 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to shutdown inner stream: {}", e)))
430 }
431}
432
433impl<T> AsyncWrite for BuffedStreamWriteHalf<T>
434where
435 T: AsyncWrite + Unpin,
436{
437 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
438 if self.shared_secret.is_none() {
444 return Pin::new(self.inner.get_mut()).poll_write(cx, buf);
445 }
446
447 let max_size = Constant::BUFFER_SIZE - Constant::ENCRYPTION_OVERHEAD;
449 let amt = std::cmp::min(buf.len(), max_size);
450 let buf = &buf[..amt];
451
452 let message = ProtocolMessage::Data(buf.to_vec());
453
454 self.as_mut()
456 .start_send(message)
457 .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to write encrypted packet"))?;
458
459 Poll::Ready(Ok(buf.len()))
462 }
463
464 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
465 pinned_inner!(self)
466 .poll_flush(cx)
467 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to flush inner stream: {}", e)))
468 }
469
470 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
471 pinned_inner!(self)
472 .poll_close(cx)
473 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to shutdown inner stream: {}", e)))
474 }
475}
476
/// Round-trip tests over in-memory duplex pairs, with and without encryption.
#[cfg(test)]
mod tests {
    use futures::future::join_all;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use crate::utils::tests::{generate_test_duplex, generate_test_duplex_with_encryption};

    /// Plain bytes written on one side arrive unchanged on the other.
    #[tokio::test]
    async fn test_unencrypted_buffed_stream() {
        let (mut client, mut server) = generate_test_duplex();

        let data = b"Hello, world!";

        client.write_all(data).await.unwrap();
        client.shutdown().await.unwrap();

        let mut received = Vec::new();
        server.read_to_end(&mut received).await.unwrap();

        assert_eq!(data, &received[..]);
    }

    /// A single encrypted write is decrypted back to the original bytes.
    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();

        let data = b"Hello, world!";

        client.write_all(data).await.unwrap();
        client.shutdown().await.unwrap();

        let mut received = Vec::new();
        server.read_to_end(&mut received).await.unwrap();

        assert_eq!(data, &received[..]);
    }

    /// Several encrypted writes arrive complete and in order.
    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream_with_multiple_packets() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();

        // Distinct payloads so that reordered or duplicated packets are detected.
        let data1 = b"Hello, ";
        let data2 = b"world!";

        client.write_all(data1).await.unwrap();
        client.write_all(data2).await.unwrap();
        client.shutdown().await.unwrap();

        let mut received = Vec::new();
        server.read_to_end(&mut received).await.unwrap();

        let expected = [&data1[..], &data2[..]].concat();
        assert_eq!(expected, received);
    }

    /// A payload spanning many frames arrives complete and uncorrupted.
    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream_with_large_data() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();

        let data = b"Hello, world!";
        let data = data.repeat(10000);

        let data_clone = data.clone();

        let write_task = tokio::spawn(async move {
            client.write_all(&data_clone).await.unwrap();
            client.shutdown().await.unwrap();
        });

        let read_task = tokio::spawn(async move {
            let mut received = Vec::new();
            server.read_to_end(&mut received).await.unwrap();
            assert_eq!(data.len(), received.len());
            // `assert!` rather than `assert_eq!` keeps a failure from dumping the whole payload.
            assert!(data == received, "received payload differs from what was sent");
        });

        join_all([write_task, read_task]).await.into_iter().collect::<Result<Vec<_>, _>>().unwrap();
    }
}