phantom_protocol/transport/
framing.rs1use crate::crypto::adaptive_crypto::{CryptoSession, AEAD_OVERHEAD};
9
10use tokio::io::{AsyncReadExt, AsyncWriteExt};
11use tokio::net::TcpStream;
12
/// Size in bytes of the length prefix that precedes every frame.
///
/// The prefix is a big-endian `u32`, so this is the width of a `u32`.
pub const FRAME_HEADER_SIZE: usize = std::mem::size_of::<u32>();
15
/// Largest plaintext payload, in bytes, carried by a single frame (64 KiB).
pub const MAX_FRAME_PAYLOAD: usize = 64 * 1024;

/// Writes encrypted, length-prefixed frames to a TCP stream.
///
/// The writer carries no state; every call receives the stream and the
/// crypto session it should use.
#[derive(Debug, Clone, Copy)]
pub struct FrameWriter;
21
22impl Default for FrameWriter {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl FrameWriter {
29 pub fn new() -> Self {
31 Self
32 }
33
34 pub const SPAWN_BLOCKING_THRESHOLD: usize = 256 * 1024; #[inline]
42 pub async fn write_frame(
43 &self,
44 stream: &mut TcpStream,
45 session: &CryptoSession,
46 data: &[u8],
47 ) -> Result<usize, FrameError> {
48 if data.is_empty() {
49 return Ok(0);
50 }
51
52 let total_len = data.len();
53 let num_chunks = total_len.div_ceil(MAX_FRAME_PAYLOAD);
54
55 let total_cap = num_chunks * (FRAME_HEADER_SIZE + AEAD_OVERHEAD) + total_len;
57 let mut batch_buf = Vec::with_capacity(total_cap);
58
59 if total_len > Self::SPAWN_BLOCKING_THRESHOLD {
60 let session = session.clone();
62 let data = data.to_vec();
63
64 batch_buf = tokio::task::spawn_blocking(move || {
65 let mut buf = Vec::with_capacity(total_cap);
66 for chunk in data.chunks(MAX_FRAME_PAYLOAD) {
67 let frame_start = buf.len();
68 let ct_len = chunk.len() + AEAD_OVERHEAD;
69 let len_bytes = (ct_len as u32).to_be_bytes();
70 buf.extend_from_slice(&len_bytes);
72 buf.extend_from_slice(chunk);
74 session
76 .encrypt_in_place_offset(
77 &len_bytes,
78 &mut buf,
79 frame_start + FRAME_HEADER_SIZE,
80 )
81 .map_err(|_| FrameError::EncryptFailed)?;
82 }
83 Ok::<Vec<u8>, FrameError>(buf)
84 })
85 .await
86 .map_err(|_| FrameError::EncryptFailed)??;
87 } else {
88 for chunk in data.chunks(MAX_FRAME_PAYLOAD) {
90 let frame_start = batch_buf.len();
91 let ct_len = chunk.len() + AEAD_OVERHEAD;
92 let len_bytes = (ct_len as u32).to_be_bytes();
93 batch_buf.extend_from_slice(&len_bytes);
95 batch_buf.extend_from_slice(chunk);
97 session
99 .encrypt_in_place_offset(
100 &len_bytes,
101 &mut batch_buf,
102 frame_start + FRAME_HEADER_SIZE,
103 )
104 .map_err(|_| FrameError::EncryptFailed)?;
105 }
106 }
107
108 stream.write_all(&batch_buf).await.map_err(FrameError::Io)?;
110
111 Ok(total_len)
112 }
113
114 #[inline]
117 pub async fn write_frames_batch(
118 &self,
119 stream: &mut TcpStream,
120 session: &CryptoSession,
121 payloads: &[&[u8]],
122 ) -> Result<usize, FrameError> {
123 if payloads.is_empty() {
124 return Ok(0);
125 }
126
127 let total_size: usize = payloads
129 .iter()
130 .map(|p| FRAME_HEADER_SIZE + p.len() + AEAD_OVERHEAD)
131 .sum();
132
133 let mut batch_buf = Vec::with_capacity(total_size);
134 let mut total_payload = 0usize;
135
136 for payload in payloads {
137 let frame_start = batch_buf.len();
138 let ct_len = payload.len() + AEAD_OVERHEAD;
139 let len_bytes = (ct_len as u32).to_be_bytes();
140
141 batch_buf.extend_from_slice(&len_bytes);
143 batch_buf.extend_from_slice(payload);
145
146 let encrypt_start = frame_start + FRAME_HEADER_SIZE;
148 session
149 .encrypt_in_place_offset(&len_bytes, &mut batch_buf, encrypt_start)
150 .map_err(|_| FrameError::EncryptFailed)?;
151
152 total_payload += payload.len();
153 }
154
155 stream.write_all(&batch_buf).await.map_err(FrameError::Io)?;
157
158 Ok(total_payload)
159 }
160}
161
162pub struct FrameReader {
164 header_buf: [u8; FRAME_HEADER_SIZE],
166}
167
168impl Default for FrameReader {
169 fn default() -> Self {
170 Self::new()
171 }
172}
173
174impl FrameReader {
175 pub fn new() -> Self {
176 Self {
177 header_buf: [0u8; FRAME_HEADER_SIZE],
178 }
179 }
180
181 #[inline]
184 pub async fn read_frame(
185 &mut self,
186 stream: &mut TcpStream,
187 session: &CryptoSession,
188 ) -> Result<Vec<u8>, FrameError> {
189 stream
191 .read_exact(&mut self.header_buf)
192 .await
193 .map_err(FrameError::Io)?;
194
195 let ct_len = u32::from_be_bytes(self.header_buf) as usize;
196
197 if ct_len > MAX_FRAME_PAYLOAD + AEAD_OVERHEAD {
198 return Err(FrameError::FrameTooLarge(ct_len));
199 }
200
201 let mut ct = vec![0u8; ct_len];
203 stream.read_exact(&mut ct).await.map_err(FrameError::Io)?;
204
205 if ct_len > FrameWriter::SPAWN_BLOCKING_THRESHOLD {
208 let session = session.clone();
209 let header_buf = self.header_buf; ct = tokio::task::spawn_blocking(move || {
211 let pt = session
212 .decrypt_in_place(&header_buf, &mut ct)
213 .map_err(|_| FrameError::DecryptFailed)?;
214 let pt_len = pt.len();
215 ct.truncate(pt_len);
216 Ok::<Vec<u8>, FrameError>(ct)
217 })
218 .await
219 .map_err(|_| FrameError::DecryptFailed)??;
220 } else {
221 let pt = session
222 .decrypt_in_place(&self.header_buf, &mut ct)
223 .map_err(|_| FrameError::DecryptFailed)?;
224 let pt_len = pt.len();
225 ct.truncate(pt_len);
226 }
227
228 Ok(ct)
229 }
230}
231
/// Errors produced while writing or reading frames.
#[derive(Debug)]
pub enum FrameError {
    /// The underlying stream reported an I/O error.
    Io(std::io::Error),
    /// Encrypting a frame failed.
    EncryptFailed,
    /// Decrypting or authenticating a frame failed.
    DecryptFailed,
    /// A frame length was over the allowed maximum; carries the offending
    /// length in bytes.
    FrameTooLarge(usize),
}

impl std::fmt::Display for FrameError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(e) => write!(f, "Frame I/O error: {}", e),
            Self::EncryptFailed => write!(f, "Frame encryption failed"),
            Self::DecryptFailed => write!(f, "Frame decryption / auth failed"),
            Self::FrameTooLarge(n) => write!(f, "Frame too large: {} bytes", n),
        }
    }
}

impl std::error::Error for FrameError {
    /// Exposes the wrapped I/O error so callers walking the error chain can
    /// inspect it (for example its `ErrorKind`); other variants have no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(e) => Some(e),
            Self::EncryptFailed | Self::DecryptFailed | Self::FrameTooLarge(_) => None,
        }
    }
}

impl From<std::io::Error> for FrameError {
    /// Wraps an I/O error in [`FrameError::Io`], enabling `?` on I/O calls.
    fn from(e: std::io::Error) -> Self {
        Self::Io(e)
    }
}
259
/// Selects how `adaptive_pad_size` rounds a payload length up.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PaddingProfile {
    /// No padding: the length is returned unchanged.
    None,
    /// Round up to the next size in `DTLS_SRTP_BUCKETS`.
    DtlsSrtp,
    /// Round up to the next size in `HTTPS_TLS_BUCKETS`.
    HttpsTls,
    /// Round anything up to `DEFAULT_MTU` bytes to exactly `DEFAULT_MTU`.
    FixedMtu,
}
270
/// Ascending target sizes used by `PaddingProfile::DtlsSrtp`.
const DTLS_SRTP_BUCKETS: &[usize] = &[64, 128, 256, 512, 1_024, 1_200];
/// Ascending target sizes used by `PaddingProfile::HttpsTls`.
const HTTPS_TLS_BUCKETS: &[usize] = &[128, 256, 512, 1_024, 2_048, 4_096, 8_192];
/// Target size, in bytes, used by `PaddingProfile::FixedMtu`.
const DEFAULT_MTU: usize = 1_400;
274
275pub fn adaptive_pad_size(payload_len: usize, profile: PaddingProfile) -> usize {
276 match profile {
277 PaddingProfile::None => payload_len,
278 PaddingProfile::DtlsSrtp => pad_to_bucket(payload_len, DTLS_SRTP_BUCKETS),
279 PaddingProfile::HttpsTls => pad_to_bucket(payload_len, HTTPS_TLS_BUCKETS),
280 PaddingProfile::FixedMtu => {
281 if payload_len <= DEFAULT_MTU {
282 DEFAULT_MTU
283 } else {
284 payload_len
285 }
286 }
287 }
288}
289
290pub fn apply_adaptive_padding(payload: &[u8], profile: PaddingProfile) -> Vec<u8> {
291 let orig_len = payload.len();
292 let padded_size = adaptive_pad_size(orig_len + 2, profile);
293
294 let mut buf = Vec::with_capacity(padded_size);
295 buf.extend_from_slice(&(orig_len as u16).to_be_bytes());
296 buf.extend_from_slice(payload);
297 buf.resize(padded_size, 0);
298 buf
299}
300
/// Extracts the original payload from a buffer laid out as a 2-byte
/// big-endian length prefix followed by payload and padding.
///
/// Returns `None` when the buffer is shorter than the prefix or when the
/// prefix announces more bytes than follow it.
pub fn strip_adaptive_padding(padded: &[u8]) -> Option<&[u8]> {
    let [hi, lo, body @ ..] = padded else {
        return None;
    };
    let payload_len = usize::from(u16::from_be_bytes([*hi, *lo]));
    body.get(..payload_len)
}
311
/// Returns the first entry of `buckets` that is at least `payload_len`, or
/// `payload_len` itself when no entry is large enough.
///
/// Entries are scanned in slice order, so the table should be ascending for
/// the result to be the tightest fit.
fn pad_to_bucket(payload_len: usize, buckets: &[usize]) -> usize {
    buckets
        .iter()
        .copied()
        .find(|&bucket| payload_len <= bucket)
        .unwrap_or(payload_len)
}
320
321impl FrameWriter {
322 pub async fn write_frame_padded(
323 &self,
324 stream: &mut TcpStream,
325 session: &CryptoSession,
326 data: &[u8],
327 profile: PaddingProfile,
328 ) -> Result<usize, FrameError> {
329 let padded = apply_adaptive_padding(data, profile);
330 self.write_frame(stream, session, &padded).await
331 }
332}
333
334impl FrameReader {
335 pub async fn read_frame_padded(
336 &mut self,
337 stream: &mut TcpStream,
338 session: &CryptoSession,
339 ) -> Result<Vec<u8>, FrameError> {
340 let padded = self.read_frame(stream, session).await?;
341 match strip_adaptive_padding(&padded) {
342 Some(payload) => Ok(payload.to_vec()),
343 None => Ok(padded),
344 }
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::net::TcpListener;

    /// Builds both ends of a session from a 32-byte secret filled with `fill`.
    fn session_pair(fill: u8) -> (Arc<CryptoSession>, Arc<CryptoSession>) {
        let secret = [fill; 32];
        let local = Arc::new(CryptoSession::from_shared_secret(&secret).unwrap());
        let peer = Arc::new(CryptoSession::from_shared_secret_peer(&secret).unwrap());
        (local, peer)
    }

    /// Binds a listener on an ephemeral loopback port and returns its address.
    async fn loopback_listener() -> (TcpListener, std::net::SocketAddr) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        (listener, addr)
    }

    /// A single small message survives a write/read cycle.
    #[tokio::test]
    async fn frame_round_trip() {
        const MESSAGE: &[u8] = b"Hello, zero-copy framing!";

        let (local, peer) = session_pair(0xAB);
        let (listener, addr) = loopback_listener().await;

        let server = tokio::spawn(async move {
            let (mut conn, _) = listener.accept().await.unwrap();
            let mut reader = FrameReader::new();
            let received = reader.read_frame(&mut conn, &peer).await.unwrap();
            assert_eq!(received, MESSAGE);
        });

        let mut conn = TcpStream::connect(addr).await.unwrap();
        FrameWriter::new()
            .write_frame(&mut conn, &local, MESSAGE)
            .await
            .unwrap();

        server.await.unwrap();
    }

    /// A 1 MiB message is split into frames and reassembles byte-for-byte.
    #[tokio::test]
    async fn large_message_round_trip() {
        let (local, peer) = session_pair(0x12);
        let (listener, addr) = loopback_listener().await;

        let sent = vec![0x42u8; 1024 * 1024];
        let expected = sent.clone();

        let server = tokio::spawn(async move {
            let (mut conn, _) = listener.accept().await.unwrap();
            let mut reader = FrameReader::new();
            let mut reassembled = Vec::new();

            let frame_count = expected.len().div_ceil(MAX_FRAME_PAYLOAD);
            for _ in 0..frame_count {
                let frame = reader.read_frame(&mut conn, &peer).await.unwrap();
                reassembled.extend_from_slice(&frame);
            }
            assert_eq!(reassembled, expected);
        });

        let mut conn = TcpStream::connect(addr).await.unwrap();
        FrameWriter::new()
            .write_frame(&mut conn, &local, &sent)
            .await
            .unwrap();

        server.await.unwrap();
    }

    /// A batch of payloads arrives as the same number of frames, in order.
    #[tokio::test]
    async fn frame_batch_round_trip() {
        const FRAMES: [&[u8]; 3] = [b"Frame 1", b"Frame 2", b"Frame 3"];

        let (local, peer) = session_pair(0xCD);
        let (listener, addr) = loopback_listener().await;

        let server = tokio::spawn(async move {
            let (mut conn, _) = listener.accept().await.unwrap();
            let mut reader = FrameReader::new();

            let mut received = Vec::with_capacity(FRAMES.len());
            for _ in 0..FRAMES.len() {
                received.push(reader.read_frame(&mut conn, &peer).await.unwrap());
            }
            for (got, want) in received.iter().zip(FRAMES) {
                assert_eq!(got.as_slice(), want);
            }
        });

        let mut conn = TcpStream::connect(addr).await.unwrap();
        FrameWriter::new()
            .write_frames_batch(&mut conn, &local, &FRAMES)
            .await
            .unwrap();

        server.await.unwrap();
    }

    /// The DTLS-SRTP profile rounds lengths up to the next bucket.
    #[test]
    fn test_adaptive_padding_dtls() {
        let cases = [(52, 64), (100, 128), (1000, 1024)];
        for (len, expected) in cases {
            assert_eq!(adaptive_pad_size(len, PaddingProfile::DtlsSrtp), expected);
        }
    }

    /// Padding followed by stripping returns the original bytes.
    #[test]
    fn test_padding_roundtrip() {
        let original = b"Hello, adaptive padding!";
        let padded = apply_adaptive_padding(original, PaddingProfile::DtlsSrtp);
        assert!(padded.len() >= 64);

        let recovered = strip_adaptive_padding(&padded).unwrap();
        assert_eq!(recovered, original);
    }

    /// A padded frame is written, read back and unpadded transparently.
    #[tokio::test]
    async fn frame_padded_round_trip() {
        const MESSAGE: &[u8] = b"Padded message!";

        let (local, peer) = session_pair(0xEF);
        let (listener, addr) = loopback_listener().await;

        let server = tokio::spawn(async move {
            let (mut conn, _) = listener.accept().await.unwrap();
            let mut reader = FrameReader::new();
            let received = reader.read_frame_padded(&mut conn, &peer).await.unwrap();
            assert_eq!(received, MESSAGE);
        });

        let mut conn = TcpStream::connect(addr).await.unwrap();
        FrameWriter::new()
            .write_frame_padded(&mut conn, &local, MESSAGE, PaddingProfile::DtlsSrtp)
            .await
            .unwrap();

        server.await.unwrap();
    }
}