1use anyhow::{Context as _, anyhow};
12use bincode::Encode;
13use bytes::{BufMut, Bytes, BytesMut};
14use secrecy::ExposeSecret;
15use tokio::{
16 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf},
17 net::{
18 TcpStream,
19 tcp::{OwnedReadHalf, OwnedWriteHalf},
20 },
21};
22use tracing::{error, warn};
23
24use crate::{
25 base::{Constant, Res, SharedSecret, Void},
26 protocol::{BincodeReceive, BincodeSend, ProtocolMessage, ProtocolMessageGuard, ProtocolMessageGuardBuilder},
27 utils::{decrypt_in_place, encrypt_into},
28};
29
/// A bincode-capable stream that can be separated into a receiving half and a sending half.
pub trait BincodeSplit {
    /// The half that can receive messages (implements [`BincodeReceive`]).
    type ReadHalf: BincodeReceive;
    /// The half that can send messages (implements [`BincodeSend`]).
    type WriteHalf: BincodeSend;

    /// Consumes the stream and returns both halves by value.
    fn into_split(self) -> (Self::ReadHalf, Self::WriteHalf);

    /// Mutably borrows both halves at once without consuming the stream.
    fn split(&mut self) -> (&mut Self::ReadHalf, &mut Self::WriteHalf);
}
49
/// A [`BuffedStream`] over the owned read/write halves of a tokio `TcpStream`.
pub type BuffedTcpStream = BuffedStream<OwnedReadHalf, OwnedWriteHalf>;
/// A [`BuffedStream`] over the `tokio::io` split halves of a `DuplexStream`.
pub type BuffedDuplexStream = BuffedStream<ReadHalf<DuplexStream>, WriteHalf<DuplexStream>>;
56
/// A message stream composed of an independent read half over `R` and write half over `W`.
///
/// Each half owns its own buffer(s) and optional shared secret; see
/// [`BuffedStreamReadHalf`] and [`BuffedStreamWriteHalf`].
pub struct BuffedStream<R, W> {
    /// Receiving side of the stream.
    inner_read: BuffedStreamReadHalf<R>,
    /// Sending side of the stream.
    inner_write: BuffedStreamWriteHalf<W>,
}
73
74impl<R, W> BuffedStream<R, W> {
77 pub fn with_encryption(mut self, shared_secret: SharedSecret) -> Self {
79 let secret_clone = SharedSecret::init_with(|| *shared_secret.expose_secret());
80
81 self.inner_read.shared_secret = Some(secret_clone);
82 self.inner_write.shared_secret = Some(shared_secret);
83
84 self
85 }
86}
87
88impl<R, W> BincodeSplit for BuffedStream<R, W>
89where
90 R: AsyncRead + Unpin,
91 W: AsyncWrite + Unpin,
92{
93 type ReadHalf = BuffedStreamReadHalf<R>;
94 type WriteHalf = BuffedStreamWriteHalf<W>;
95
96 fn into_split(self) -> (Self::ReadHalf, Self::WriteHalf) {
97 (self.inner_read, self.inner_write)
98 }
99
100 fn split(&mut self) -> (&mut Self::ReadHalf, &mut Self::WriteHalf) {
101 (&mut self.inner_read, &mut self.inner_write)
102 }
103}
104
105impl From<TcpStream> for BuffedStream<OwnedReadHalf, OwnedWriteHalf> {
106 fn from(stream: TcpStream) -> Self {
107 let (read, write) = stream.into_split();
108
109 Self {
110 inner_read: BuffedStreamReadHalf::new(read),
111 inner_write: BuffedStreamWriteHalf::new(write),
112 }
113 }
114}
115
116impl<T> From<T> for BuffedStream<ReadHalf<T>, WriteHalf<T>>
117where
118 T: AsyncRead + AsyncWrite + Unpin,
119{
120 fn from(stream: T) -> Self {
121 let (read, write) = tokio::io::split(stream);
122
123 Self {
124 inner_read: BuffedStreamReadHalf::new(read),
125 inner_write: BuffedStreamWriteHalf::new(write),
126 }
127 }
128}
129
130impl<R, W> BuffedStream<R, W>
131where
132 R: AsyncRead + Unpin,
133 W: AsyncWrite + Unpin,
134{
135 pub fn from_splits(inner_read: R, inner_write: W) -> Self {
136 Self {
137 inner_read: BuffedStreamReadHalf::new(inner_read),
138 inner_write: BuffedStreamWriteHalf::new(inner_write),
139 }
140 }
141}
142
143impl<R> BuffedStream<R, OwnedWriteHalf> {
144 pub fn as_inner_tcp_write_ref(&self) -> &OwnedWriteHalf {
145 &self.inner_write.inner
146 }
147
148 pub fn as_inner_tcp_write_mut(&mut self) -> &mut OwnedWriteHalf {
149 &mut self.inner_write.inner
150 }
151}
152
153impl<W> BuffedStream<OwnedReadHalf, W> {
154 pub fn as_inner_tcp_read_ref(&self) -> &OwnedReadHalf {
155 &self.inner_read.inner
156 }
157
158 pub fn as_inner_tcp_read_mut(&mut self) -> &mut OwnedReadHalf {
159 &mut self.inner_read.inner
160 }
161}
162
163impl BuffedStream<OwnedReadHalf, OwnedWriteHalf> {
164 pub fn take(self) -> Res<TcpStream> {
165 let read = self.inner_read.take();
166 let write = self.inner_write.take();
167
168 read.reunite(write).context("Failed to reunite read and write halves")
169 }
170}
171
172impl<R, W> BincodeSend for BuffedStream<R, W>
175where
176 R: Unpin,
177 W: AsyncWrite + Unpin,
178{
179 async fn push<E>(&mut self, message: E) -> Void
180 where
181 E: Encode,
182 {
183 self.inner_write.push(message).await?;
184
185 Ok(())
186 }
187
188 async fn close(&mut self) -> Void {
189 self.inner_write.close().await?;
190
191 Ok(())
192 }
193}
194
195impl<R, W> BincodeReceive for BuffedStream<R, W>
196where
197 R: AsyncRead + Unpin,
198 W: Unpin,
199{
200 async fn pull(&mut self) -> Res<ProtocolMessageGuard> {
201 self.inner_read.pull().await
202 }
203}
204
/// The receiving half of a [`BuffedStream`]: reads framed messages from `inner`.
pub struct BuffedStreamReadHalf<T> {
    /// The underlying reader.
    inner: T,
    /// Key used to decrypt frames flagged as encrypted; `None` until encryption is enabled.
    shared_secret: Option<SharedSecret>,
    /// Buffer the raw payload of each frame is read into.
    buffer: BytesMut,
    /// Buffer an encrypted payload is copied into before being decrypted in place.
    decryption_buffer: BytesMut,
}
214
215impl<T> BuffedStreamReadHalf<T>
216where
217 T: AsyncRead + Unpin,
218{
219 fn new(async_read: T) -> Self {
221 Self {
222 inner: async_read,
223 shared_secret: None,
224 buffer: BytesMut::with_capacity(2 * Constant::BUFFER_SIZE),
225 decryption_buffer: BytesMut::with_capacity(2 * Constant::BUFFER_SIZE),
226 }
227 }
228
229 fn take(self) -> T {
231 if !self.buffer.is_empty() {
232 warn!("Buffer was not empty when taking the stream");
233 }
234
235 self.inner
236 }
237}
238
239impl<T> BincodeReceive for BuffedStreamReadHalf<T>
240where
241 T: AsyncRead + Unpin,
242{
243 async fn pull(&mut self) -> Res<ProtocolMessageGuard> {
244 self.buffer.clear();
255 self.decryption_buffer.clear();
256 self.buffer.reserve(Constant::BUFFER_SIZE);
257 self.decryption_buffer.reserve(Constant::BUFFER_SIZE);
258
259 let is_encrypted = self.inner.read_u8().await.context("Failed to read encryption flag")? == 1;
262
263 let maybe_encryption_data = if is_encrypted {
266 let mut nonce = [0; Constant::SHARED_SECRET_NONCE_SIZE];
267 self.inner.read_exact(&mut nonce).await.context("Failed to read nonce")?;
268
269 Some(nonce)
270 } else {
271 None
272 };
273
274 let message_size = self.inner.read_u64().await.context("Failed to read size")? as usize;
277
278 if message_size == 0 {
281 let guard = ProtocolMessageGuardBuilder {
282 buffer: Bytes::new(),
283 inner_builder: |_| ProtocolMessage::Shutdown,
284 }
285 .build();
286
287 return Ok(guard);
288 }
289
290 if message_size > Constant::BUFFER_SIZE {
291 return Err(anyhow!("Message size is too large for the buffer during pull"));
292 }
293
294 unsafe { self.buffer.set_len(message_size) };
300 let n = self.inner.read_exact(&mut self.buffer).await?;
301
302 if message_size != n {
303 return Err(anyhow!("Failed to read message: expected {} bytes, got {}", message_size, n));
304 }
305
306 let data_buffer = self.buffer.split().freeze();
309
310 let data_buffer = if let Some(nonce) = maybe_encryption_data {
313 let Some(key) = self.shared_secret.as_ref() else {
314 return Err(anyhow!("Shared secret is not set when receiving encrypted message"));
315 };
316
317 self.decryption_buffer.put(data_buffer);
318
319 decrypt_in_place(key, &nonce, &mut self.decryption_buffer).context("Failed to decrypt message")?;
321
322 self.decryption_buffer.split().freeze()
323 } else {
324 data_buffer
325 };
326
327 Ok(ProtocolMessageGuardBuilder {
330 buffer: data_buffer,
331 inner_builder: |data| match bincode::borrow_decode_from_slice::<ProtocolMessage<'_>, _>(data, Constant::BINCODE_CONFIG) {
332 Ok((message, n)) => {
333 if n != data.len() {
334 error!("Failed to decrypt message: expected {} bytes, got {}", data.len(), n);
335 return ProtocolMessage::Shutdown;
336 }
337
338 message
339 }
340 Err(e) => {
341 error!("Failed to decode message: {}", e);
342 ProtocolMessage::Shutdown
343 }
344 },
345 }
346 .build())
347 }
348}
349
/// The sending half of a [`BuffedStream`]: writes framed messages to `inner`.
pub struct BuffedStreamWriteHalf<T> {
    /// The underlying writer.
    inner: T,
    /// Key used to encrypt outgoing messages; `None` means messages are sent as plaintext.
    shared_secret: Option<SharedSecret>,
    /// Scratch buffer each message is encoded (and, if enabled, encrypted) into before sending.
    buffer: BytesMut,
}
356
357impl<T> BuffedStreamWriteHalf<T>
358where
359 T: AsyncWrite + Unpin,
360{
361 fn new(async_write: T) -> Self {
363 Self {
364 inner: async_write,
365 shared_secret: None,
366 buffer: BytesMut::with_capacity(2 * Constant::BUFFER_SIZE),
367 }
368 }
369
370 fn take(self) -> T {
372 self.inner
373 }
374}
375
376impl<T> BincodeSend for BuffedStreamWriteHalf<T>
377where
378 T: AsyncWrite + Unpin,
379{
380 async fn push<E>(&mut self, message: E) -> Void
381 where
382 E: Encode,
383 {
384 self.buffer.clear();
391 assert!(self.buffer.try_reclaim(2 * Constant::BUFFER_SIZE));
392
393 unsafe { self.buffer.set_len(Constant::BUFFER_SIZE) };
399 let n = bincode::encode_into_slice(message, &mut self.buffer, Constant::BINCODE_CONFIG)?;
400 unsafe { self.buffer.set_len(n) };
401
402 let maybe_nonce = if let Some(key) = self.shared_secret.as_ref() {
403 let nonce = encrypt_into(key, &mut self.buffer).context("Encryption failed")?;
405
406 Some(nonce)
407 } else {
408 None
409 };
410
411 let data_length = self.buffer.len();
414
415 if data_length == 0 {
416 return Err(anyhow!("Buffer is empty"));
417 }
418
419 if data_length > Constant::BUFFER_SIZE {
420 return Err(anyhow!("Buffer is too large"));
421 }
422
423 if let Some(nonce) = maybe_nonce {
426 self.inner.write_u8(1).await.context("Failed to write encryption flag")?;
427 self.inner.write_all(&nonce).await.context("Failed to write nonce")?;
428 } else {
429 self.inner.write_u8(0).await.context("Failed to write encryption flag")?;
430 }
431
432 self.inner.write_u64(data_length as u64).await.context("Failed to write size")?;
435
436 self.inner.write_all(&self.buffer).await?;
439
440 self.inner.flush().await.context("Failed to flush stream")?;
443
444 Ok(())
445 }
446
447 async fn close(&mut self) -> Void {
448 self.inner.shutdown().await.context("Failed to close stream")?;
449
450 Ok(())
451 }
452}
453
#[cfg(test)]
mod tests {
    use crate::{
        protocol::{BincodeReceive, BincodeSend, ProtocolMessage},
        utils::tests::{generate_test_duplex, generate_test_duplex_with_encryption},
    };

    /// A single data message travels over a duplex pair without encryption.
    #[tokio::test]
    async fn test_unencrypted_buffed_stream() {
        let payload = b"Hello, world!";
        let (mut sender, mut receiver) = generate_test_duplex();

        sender.push(ProtocolMessage::Data(payload)).await.unwrap();
        sender.close().await.unwrap();

        let guard = receiver.pull().await.unwrap();
        match *guard.message() {
            ProtocolMessage::Data(received) => assert_eq!(payload, received),
            _ => panic!("Failed to receive message"),
        }
    }

    /// A single data message travels over a duplex pair with encryption enabled.
    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream() {
        let payload = b"Hello, world!";
        let (mut sender, mut receiver) = generate_test_duplex_with_encryption();

        sender.push(ProtocolMessage::Data(payload)).await.unwrap();
        sender.close().await.unwrap();

        let guard = receiver.pull().await.unwrap();
        match *guard.message() {
            ProtocolMessage::Data(received) => assert_eq!(payload, received),
            _ => panic!("Failed to receive message"),
        }
    }

    /// Two encrypted messages pushed back to back are received separately and in order.
    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream_with_multiple_packets() {
        let first_payload = b"Hello, world!";
        let second_payload = b"Hello, wold!";
        let (mut sender, mut receiver) = generate_test_duplex_with_encryption();

        sender.push(ProtocolMessage::Data(first_payload)).await.unwrap();
        sender.push(ProtocolMessage::Data(second_payload)).await.unwrap();
        sender.close().await.unwrap();

        let guard = receiver.pull().await.unwrap();
        match *guard.message() {
            ProtocolMessage::Data(received) => assert_eq!(first_payload, received),
            _ => panic!("Failed to receive message"),
        }

        let guard = receiver.pull().await.unwrap();
        match *guard.message() {
            ProtocolMessage::Data(received) => assert_eq!(second_payload, received),
            _ => panic!("Failed to receive message"),
        }
    }
}