1use anyhow::{Context as _, anyhow};
12use bincode::Encode;
13use bytes::{BufMut, Bytes, BytesMut};
14use secrecy::ExposeSecret;
15use tokio::{
16 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf},
17 net::{
18 TcpStream,
19 tcp::{OwnedReadHalf, OwnedWriteHalf},
20 },
21};
22use tracing::{error, warn};
23
24use crate::{
25 base::{Constant, Res, SharedSecret, Void},
26 protocol::{BincodeReceive, BincodeSend, ProtocolMessage, ProtocolMessageGuard, ProtocolMessageGuardBuilder},
27 utils::{decrypt_in_place, encrypt_into},
28};
29
30pub trait BincodeSplit {
34 type ReadHalf: BincodeReceive;
35 type WriteHalf: BincodeSend;
36
37 fn into_split(self) -> (Self::ReadHalf, Self::WriteHalf);
42
43 fn split(&mut self) -> (&mut Self::ReadHalf, &mut Self::WriteHalf);
48}
49
50pub type BuffedTcpStream = BuffedStream<OwnedReadHalf, OwnedWriteHalf>;
54pub type BuffedDuplexStream = BuffedStream<ReadHalf<DuplexStream>, WriteHalf<DuplexStream>>;
56
57pub struct BuffedStream<R, W> {
68 inner_read: BuffedStreamReadHalf<R>,
70 inner_write: BuffedStreamWriteHalf<W>,
72}
73
74impl<R, W> BuffedStream<R, W> {
77 pub fn with_encryption(mut self, shared_secret: SharedSecret) -> Self {
79 let secret_clone = SharedSecret::init_with(|| *shared_secret.expose_secret());
80
81 self.inner_read.shared_secret = Some(secret_clone);
82 self.inner_write.shared_secret = Some(shared_secret);
83
84 self
85 }
86}
87
88impl<R, W> BincodeSplit for BuffedStream<R, W>
89where
90 R: AsyncRead + Unpin,
91 W: AsyncWrite + Unpin,
92{
93 type ReadHalf = BuffedStreamReadHalf<R>;
94 type WriteHalf = BuffedStreamWriteHalf<W>;
95
96 fn into_split(self) -> (Self::ReadHalf, Self::WriteHalf) {
97 (self.inner_read, self.inner_write)
98 }
99
100 fn split(&mut self) -> (&mut Self::ReadHalf, &mut Self::WriteHalf) {
101 (&mut self.inner_read, &mut self.inner_write)
102 }
103}
104
105impl From<TcpStream> for BuffedStream<OwnedReadHalf, OwnedWriteHalf> {
106 fn from(stream: TcpStream) -> Self {
107 let (read, write) = stream.into_split();
108
109 Self {
110 inner_read: BuffedStreamReadHalf::new(read),
111 inner_write: BuffedStreamWriteHalf::new(write),
112 }
113 }
114}
115
116impl<T> From<T> for BuffedStream<ReadHalf<T>, WriteHalf<T>>
117where
118 T: AsyncRead + AsyncWrite + Unpin,
119{
120 fn from(stream: T) -> Self {
121 let (read, write) = tokio::io::split(stream);
122
123 Self {
124 inner_read: BuffedStreamReadHalf::new(read),
125 inner_write: BuffedStreamWriteHalf::new(write),
126 }
127 }
128}
129
130impl<R, W> BuffedStream<R, W>
131where
132 R: AsyncRead + Unpin,
133 W: AsyncWrite + Unpin,
134{
135 pub fn from_splits(inner_read: R, inner_write: W) -> Self {
136 Self {
137 inner_read: BuffedStreamReadHalf::new(inner_read),
138 inner_write: BuffedStreamWriteHalf::new(inner_write),
139 }
140 }
141}
142
143impl<R> BuffedStream<R, OwnedWriteHalf> {
144 pub fn as_inner_tcp_write_ref(&self) -> &OwnedWriteHalf {
145 &self.inner_write.inner
146 }
147
148 pub fn as_inner_tcp_write_mut(&mut self) -> &mut OwnedWriteHalf {
149 &mut self.inner_write.inner
150 }
151}
152
153impl<W> BuffedStream<OwnedReadHalf, W> {
154 pub fn as_inner_tcp_read_ref(&self) -> &OwnedReadHalf {
155 &self.inner_read.inner
156 }
157
158 pub fn as_inner_tcp_read_mut(&mut self) -> &mut OwnedReadHalf {
159 &mut self.inner_read.inner
160 }
161}
162
163impl BuffedStream<OwnedReadHalf, OwnedWriteHalf> {
164 pub fn take(self) -> Res<TcpStream> {
165 let read = self.inner_read.take();
166 let write = self.inner_write.take();
167
168 read.reunite(write).context("Failed to reunite read and write halves")
169 }
170}
171
172impl<R, W> BincodeSend for BuffedStream<R, W>
175where
176 R: Unpin,
177 W: AsyncWrite + Unpin,
178{
179 async fn push<E>(&mut self, message: E) -> Void
180 where
181 E: Encode,
182 {
183 self.inner_write.push(message).await?;
184
185 Ok(())
186 }
187
188 async fn close(&mut self) -> Void {
189 self.inner_write.close().await?;
190
191 Ok(())
192 }
193}
194
195impl<R, W> BincodeReceive for BuffedStream<R, W>
196where
197 R: AsyncRead + Unpin,
198 W: Unpin,
199{
200 async fn pull(&mut self) -> Res<ProtocolMessageGuard> {
201 self.inner_read.pull().await
202 }
203}
204
205pub struct BuffedStreamReadHalf<T> {
209 inner: T,
210 shared_secret: Option<SharedSecret>,
211 buffer: BytesMut,
212 decryption_buffer: BytesMut,
213}
214
215impl<T> BuffedStreamReadHalf<T>
216where
217 T: AsyncRead + Unpin,
218{
219 fn new(async_read: T) -> Self {
221 Self {
222 inner: async_read,
223 shared_secret: None,
224 buffer: BytesMut::with_capacity(2 * Constant::BUFFER_SIZE),
225 decryption_buffer: BytesMut::with_capacity(2 * Constant::BUFFER_SIZE),
226 }
227 }
228
229 fn take(self) -> T {
231 if !self.buffer.is_empty() {
232 warn!("Buffer was not empty when taking the stream");
233 }
234
235 self.inner
236 }
237}
238
239impl<T> BincodeReceive for BuffedStreamReadHalf<T>
240where
241 T: AsyncRead + Unpin,
242{
243 async fn pull(&mut self) -> Res<ProtocolMessageGuard> {
244 self.buffer.clear();
255 self.decryption_buffer.clear();
256 self.buffer.reserve(Constant::BUFFER_SIZE);
257 self.decryption_buffer.reserve(Constant::BUFFER_SIZE);
258
259 let is_encrypted = match self.inner.read_u8().await {
262 Ok(1) => true,
263 Ok(_) => false,
264 Err(e) => {
265 if e.kind() == std::io::ErrorKind::UnexpectedEof {
266 return Ok(ProtocolMessageGuardBuilder {
267 buffer: Bytes::new(),
268 inner_builder: |_| ProtocolMessage::Shutdown,
269 }
270 .build());
271 } else {
272 return Err(anyhow!("Failed to read encryption flag: {}", e));
273 }
274 }
275 };
276
277 let maybe_encryption_data = if is_encrypted {
280 let mut nonce = [0; Constant::SHARED_SECRET_NONCE_SIZE];
281 self.inner.read_exact(&mut nonce).await.context("Failed to read nonce")?;
282
283 Some(nonce)
284 } else {
285 None
286 };
287
288 let message_size = self.inner.read_u64().await.context("Failed to read size")? as usize;
291
292 if message_size == 0 {
295 let guard = ProtocolMessageGuardBuilder {
296 buffer: Bytes::new(),
297 inner_builder: |_| ProtocolMessage::Shutdown,
298 }
299 .build();
300
301 return Ok(guard);
302 }
303
304 if message_size > Constant::BUFFER_SIZE {
305 return Err(anyhow!("Message size is too large for the buffer during pull"));
306 }
307
308 unsafe { self.buffer.set_len(message_size) };
314 let n = self.inner.read_exact(&mut self.buffer).await?;
315
316 if message_size != n {
317 return Err(anyhow!("Failed to read message: expected {} bytes, got {}", message_size, n));
318 }
319
320 let data_buffer = self.buffer.split().freeze();
323
324 let data_buffer = if let Some(nonce) = maybe_encryption_data {
327 let Some(key) = self.shared_secret.as_ref() else {
328 return Err(anyhow!("Shared secret is not set when receiving encrypted message"));
329 };
330
331 self.decryption_buffer.put(data_buffer);
332
333 decrypt_in_place(key, &nonce, &mut self.decryption_buffer).context("Failed to decrypt message")?;
335
336 self.decryption_buffer.split().freeze()
337 } else {
338 data_buffer
339 };
340
341 Ok(ProtocolMessageGuardBuilder {
344 buffer: data_buffer,
345 inner_builder: |data| match bincode::borrow_decode_from_slice::<ProtocolMessage<'_>, _>(data, Constant::BINCODE_CONFIG) {
346 Ok((message, n)) => {
347 if n != data.len() {
348 error!("Failed to decrypt message: expected {} bytes, got {}", data.len(), n);
349 return ProtocolMessage::Shutdown;
350 }
351
352 message
353 }
354 Err(e) => {
355 error!("Failed to decode message: {}", e);
356 ProtocolMessage::Shutdown
357 }
358 },
359 }
360 .build())
361 }
362}
363
364pub struct BuffedStreamWriteHalf<T> {
366 inner: T,
367 shared_secret: Option<SharedSecret>,
368 buffer: BytesMut,
369}
370
371impl<T> BuffedStreamWriteHalf<T>
372where
373 T: AsyncWrite + Unpin,
374{
375 fn new(async_write: T) -> Self {
377 Self {
378 inner: async_write,
379 shared_secret: None,
380 buffer: BytesMut::with_capacity(2 * Constant::BUFFER_SIZE),
381 }
382 }
383
384 fn take(self) -> T {
386 self.inner
387 }
388}
389
390impl<T> BincodeSend for BuffedStreamWriteHalf<T>
391where
392 T: AsyncWrite + Unpin,
393{
394 async fn push<E>(&mut self, message: E) -> Void
395 where
396 E: Encode,
397 {
398 self.buffer.clear();
405 assert!(self.buffer.try_reclaim(2 * Constant::BUFFER_SIZE));
406
407 unsafe { self.buffer.set_len(Constant::BUFFER_SIZE) };
413 let n = bincode::encode_into_slice(message, &mut self.buffer, Constant::BINCODE_CONFIG)?;
414 unsafe { self.buffer.set_len(n) };
415
416 let maybe_nonce = if let Some(key) = self.shared_secret.as_ref() {
419 let nonce = encrypt_into(key, &mut self.buffer).context("Encryption failed")?;
421
422 Some(nonce)
423 } else {
424 None
425 };
426
427 let data_length = self.buffer.len();
430
431 if data_length == 0 {
432 return Err(anyhow!("Buffer is empty"));
433 }
434
435 if data_length > Constant::BUFFER_SIZE {
436 return Err(anyhow!("Buffer is too large"));
437 }
438
439 if let Some(nonce) = maybe_nonce {
442 self.inner.write_u8(1).await.context("Failed to write encryption flag")?;
443 self.inner.write_all(&nonce).await.context("Failed to write nonce")?;
444 } else {
445 self.inner.write_u8(0).await.context("Failed to write encryption flag")?;
446 }
447
448 self.inner.write_u64(data_length as u64).await.context("Failed to write size")?;
451
452 self.inner.write_all(&self.buffer).await?;
455
456 self.inner.flush().await.context("Failed to flush stream")?;
459
460 Ok(())
461 }
462
463 async fn close(&mut self) -> Void {
464 self.inner.shutdown().await.context("Failed to close stream")?;
465
466 Ok(())
467 }
468}
469
#[cfg(test)]
mod tests {
    use crate::{
        protocol::{BincodeReceive, BincodeSend, ProtocolMessage},
        utils::tests::{generate_test_duplex, generate_test_duplex_with_encryption},
    };

    #[tokio::test]
    async fn test_unencrypted_buffed_stream() {
        let (mut client, mut server) = generate_test_duplex();

        let data = b"Hello, world!";

        client.push(ProtocolMessage::Data(data)).await.unwrap();
        client.close().await.unwrap();

        let guard = server.pull().await.unwrap();
        let ProtocolMessage::Data(received) = *guard.message() else {
            panic!("Failed to receive message");
        };

        assert_eq!(data, received);
    }

    #[tokio::test]
    async fn test_unencrypted_buffed_stream_with_multiple_packets() {
        let (mut client, mut server) = generate_test_duplex();

        let data1 = b"Hello, world!";
        let data2 = b"Hello, wold!";

        client.push(ProtocolMessage::Data(data1)).await.unwrap();
        client.push(ProtocolMessage::Data(data2)).await.unwrap();
        client.close().await.unwrap();

        let guard = server.pull().await.unwrap();
        let ProtocolMessage::Data(received) = *guard.message() else {
            panic!("Failed to receive message");
        };
        assert_eq!(data1, received);

        let guard = server.pull().await.unwrap();
        let ProtocolMessage::Data(received) = *guard.message() else {
            panic!("Failed to receive message");
        };
        assert_eq!(data2, received);
    }

    #[tokio::test]
    async fn test_pull_after_close_yields_shutdown() {
        let (mut client, mut server) = generate_test_duplex();

        client.push(ProtocolMessage::Data(b"Hello, world!")).await.unwrap();
        client.close().await.unwrap();

        let guard = server.pull().await.unwrap();
        assert!(matches!(*guard.message(), ProtocolMessage::Data(_)));

        // The peer shut down its write side, so the next read hits EOF at a frame boundary.
        let guard = server.pull().await.unwrap();
        assert!(matches!(*guard.message(), ProtocolMessage::Shutdown));
    }

    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();

        let data = b"Hello, world!";

        client.push(ProtocolMessage::Data(data)).await.unwrap();
        client.close().await.unwrap();

        let guard = server.pull().await.unwrap();
        let ProtocolMessage::Data(received) = *guard.message() else {
            panic!("Failed to receive message");
        };

        assert_eq!(data, received);
    }

    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream_with_multiple_packets() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();

        let data1 = b"Hello, world!";
        let data2 = b"Hello, wold!";

        client.push(ProtocolMessage::Data(data1)).await.unwrap();
        client.push(ProtocolMessage::Data(data2)).await.unwrap();
        client.close().await.unwrap();

        let guard = server.pull().await.unwrap();
        let ProtocolMessage::Data(received) = *guard.message() else {
            panic!("Failed to receive message");
        };
        assert_eq!(data1, received);

        let guard = server.pull().await.unwrap();
        let ProtocolMessage::Data(received) = *guard.message() else {
            panic!("Failed to receive message");
        };
        assert_eq!(data2, received);
    }
}