Skip to main content

phantom_protocol/transport/
framing.rs

//! Zero-Copy TCP Framing Pipeline
//!
//! Problem: the previous approach did `data.clone()` + `encrypt_in_place()` + `write(len)` +
//! `write(data)`, i.e. two syscalls and one full clone per message, whereas TLS 1.3 (rustls)
//! does all of this with a single internal write.
//!
//! Solution: prepend a 4-byte length header → encrypt the payload in place right after it →
//! send the whole buffer with a single `write_all()`.
7
8use crate::crypto::adaptive_crypto::{CryptoSession, AEAD_OVERHEAD};
9
10use tokio::io::{AsyncReadExt, AsyncWriteExt};
11use tokio::net::TcpStream;
12
/// Size of the frame header in bytes: one big-endian `u32` holding the ciphertext length.
pub const FRAME_HEADER_SIZE: usize = std::mem::size_of::<u32>();

/// Largest plaintext carried by a single frame (64 KiB); longer messages are split
/// across several frames by `FrameWriter::write_frame`.
pub const MAX_FRAME_PAYLOAD: usize = 1 << 16;
18
/// Frame writer: encrypts payloads into length-prefixed frames and sends all frames of
/// one call with a single `write_all`.
///
/// The writer itself carries no data; the key material comes from the `CryptoSession`
/// passed to each write call.
#[derive(Debug, Default)]
pub struct FrameWriter;
27
28impl FrameWriter {
29    /// Create a new frame writer
30    pub fn new() -> Self {
31        Self
32    }
33
34    /// Threshold size (bytes) above which we use tokio::task::spawn_blocking for encryption
35    pub const SPAWN_BLOCKING_THRESHOLD: usize = 256 * 1024; // 256 KB
36
37    /// Write a single message as one or more frames: [len:4][encrypted_payload + tag:16]
38    ///
39    /// If data exceeds MAX_FRAME_PAYLOAD, it is split into multiple frames.
40    /// If total data exceeds SPAWN_BLOCKING_THRESHOLD, encryption is offloaded to spawn_blocking.
41    #[inline]
42    pub async fn write_frame(
43        &self,
44        stream: &mut TcpStream,
45        session: &CryptoSession,
46        data: &[u8],
47    ) -> Result<usize, FrameError> {
48        if data.is_empty() {
49            return Ok(0);
50        }
51
52        let total_len = data.len();
53        let num_chunks = total_len.div_ceil(MAX_FRAME_PAYLOAD);
54
55        // Calculate total buffer size needed for all chunks
56        let total_cap = num_chunks * (FRAME_HEADER_SIZE + AEAD_OVERHEAD) + total_len;
57        let mut batch_buf = Vec::with_capacity(total_cap);
58
59        if total_len > Self::SPAWN_BLOCKING_THRESHOLD {
60            // Offload encryption to blocking thread pool
61            let session = session.clone();
62            let data = data.to_vec();
63
64            batch_buf = tokio::task::spawn_blocking(move || {
65                let mut buf = Vec::with_capacity(total_cap);
66                for chunk in data.chunks(MAX_FRAME_PAYLOAD) {
67                    let frame_start = buf.len();
68                    let ct_len = chunk.len() + AEAD_OVERHEAD;
69                    let len_bytes = (ct_len as u32).to_be_bytes();
70                    // Length placeholder
71                    buf.extend_from_slice(&len_bytes);
72                    // Payload
73                    buf.extend_from_slice(chunk);
74                    // Encrypt in-place at offset
75                    session
76                        .encrypt_in_place_offset(
77                            &len_bytes,
78                            &mut buf,
79                            frame_start + FRAME_HEADER_SIZE,
80                        )
81                        .map_err(|_| FrameError::EncryptFailed)?;
82                }
83                Ok::<Vec<u8>, FrameError>(buf)
84            })
85            .await
86            .map_err(|_| FrameError::EncryptFailed)??;
87        } else {
88            // Synchronous encryption (on current Tokio worker)
89            for chunk in data.chunks(MAX_FRAME_PAYLOAD) {
90                let frame_start = batch_buf.len();
91                let ct_len = chunk.len() + AEAD_OVERHEAD;
92                let len_bytes = (ct_len as u32).to_be_bytes();
93                // Length placeholder
94                batch_buf.extend_from_slice(&len_bytes);
95                // Payload
96                batch_buf.extend_from_slice(chunk);
97                // Encrypt in-place at offset
98                session
99                    .encrypt_in_place_offset(
100                        &len_bytes,
101                        &mut batch_buf,
102                        frame_start + FRAME_HEADER_SIZE,
103                    )
104                    .map_err(|_| FrameError::EncryptFailed)?;
105            }
106        }
107
108        // Single syscall write for all chunks
109        stream.write_all(&batch_buf).await.map_err(FrameError::Io)?;
110
111        Ok(total_len)
112    }
113
114    /// Write multiple frames in a batch (TCP write coalescing).
115    /// Accumulates all frames into a single buffer → one write_all().
116    #[inline]
117    pub async fn write_frames_batch(
118        &self,
119        stream: &mut TcpStream,
120        session: &CryptoSession,
121        payloads: &[&[u8]],
122    ) -> Result<usize, FrameError> {
123        if payloads.is_empty() {
124            return Ok(0);
125        }
126
127        // Calculate total buffer size needed
128        let total_size: usize = payloads
129            .iter()
130            .map(|p| FRAME_HEADER_SIZE + p.len() + AEAD_OVERHEAD)
131            .sum();
132
133        let mut batch_buf = Vec::with_capacity(total_size);
134        let mut total_payload = 0usize;
135
136        for payload in payloads {
137            let frame_start = batch_buf.len();
138            let ct_len = payload.len() + AEAD_OVERHEAD;
139            let len_bytes = (ct_len as u32).to_be_bytes();
140
141            // Length placeholder
142            batch_buf.extend_from_slice(&len_bytes);
143            // Payload
144            batch_buf.extend_from_slice(payload);
145
146            // Encrypt in-place at offset
147            let encrypt_start = frame_start + FRAME_HEADER_SIZE;
148            session
149                .encrypt_in_place_offset(&len_bytes, &mut batch_buf, encrypt_start)
150                .map_err(|_| FrameError::EncryptFailed)?;
151
152            total_payload += payload.len();
153        }
154
155        // Single write for all frames
156        stream.write_all(&batch_buf).await.map_err(FrameError::Io)?;
157
158        Ok(total_payload)
159    }
160}
161
162/// Zero-copy frame reader — reads and decrypts from TCP stream
163pub struct FrameReader {
164    /// Internal read buffer
165    header_buf: [u8; FRAME_HEADER_SIZE],
166}
167
168impl Default for FrameReader {
169    fn default() -> Self {
170        Self::new()
171    }
172}
173
174impl FrameReader {
175    pub fn new() -> Self {
176        Self {
177            header_buf: [0u8; FRAME_HEADER_SIZE],
178        }
179    }
180
181    /// Read a single frame: reads `[len:4]`, then reads `[encrypted_payload]`, decrypts in-place.
182    /// Returns decrypted plaintext as `Vec<u8>`.
183    #[inline]
184    pub async fn read_frame(
185        &mut self,
186        stream: &mut TcpStream,
187        session: &CryptoSession,
188    ) -> Result<Vec<u8>, FrameError> {
189        // Read length header
190        stream
191            .read_exact(&mut self.header_buf)
192            .await
193            .map_err(FrameError::Io)?;
194
195        let ct_len = u32::from_be_bytes(self.header_buf) as usize;
196
197        if ct_len > MAX_FRAME_PAYLOAD + AEAD_OVERHEAD {
198            return Err(FrameError::FrameTooLarge(ct_len));
199        }
200
201        // Read ciphertext
202        let mut ct = vec![0u8; ct_len];
203        stream.read_exact(&mut ct).await.map_err(FrameError::Io)?;
204
205        // Decrypt in-place
206        // Offload to spawn_blocking if frame is large
207        if ct_len > FrameWriter::SPAWN_BLOCKING_THRESHOLD {
208            let session = session.clone();
209            let header_buf = self.header_buf; // Copy for closure
210            ct = tokio::task::spawn_blocking(move || {
211                let pt = session
212                    .decrypt_in_place(&header_buf, &mut ct)
213                    .map_err(|_| FrameError::DecryptFailed)?;
214                let pt_len = pt.len();
215                ct.truncate(pt_len);
216                Ok::<Vec<u8>, FrameError>(ct)
217            })
218            .await
219            .map_err(|_| FrameError::DecryptFailed)??;
220        } else {
221            let pt = session
222                .decrypt_in_place(&self.header_buf, &mut ct)
223                .map_err(|_| FrameError::DecryptFailed)?;
224            let pt_len = pt.len();
225            ct.truncate(pt_len);
226        }
227
228        Ok(ct)
229    }
230}
231
/// Errors produced while framing, encrypting, sending or receiving frames.
#[derive(Debug)]
pub enum FrameError {
    /// The underlying stream returned an I/O error.
    Io(std::io::Error),
    /// Encrypting an outgoing frame failed, or the blocking encryption task did not
    /// complete.
    EncryptFailed,
    /// Decrypting/authenticating an incoming frame failed, or the blocking decryption
    /// task did not complete.
    DecryptFailed,
    /// A frame exceeded the size limit; holds the rejected size in bytes.
    FrameTooLarge(usize),
}

impl std::fmt::Display for FrameError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FrameError::Io(err) => write!(f, "Frame I/O error: {}", err),
            FrameError::EncryptFailed => f.write_str("Frame encryption failed"),
            FrameError::DecryptFailed => f.write_str("Frame decryption / auth failed"),
            FrameError::FrameTooLarge(size) => write!(f, "Frame too large: {} bytes", size),
        }
    }
}

impl std::error::Error for FrameError {}

impl From<std::io::Error> for FrameError {
    /// Wrap an I/O error as [`FrameError::Io`].
    fn from(err: std::io::Error) -> Self {
        FrameError::Io(err)
    }
}
259
260// ─── Adaptive Padding ───────────────────────────────────────────────────────
261
/// Padding profile — mimics real-world traffic distributions.
///
/// Selects how [`adaptive_pad_size`] rounds a message length up before it is framed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PaddingProfile {
    /// No padding: the size is left unchanged.
    None,
    /// Round up to the next size in `DTLS_SRTP_BUCKETS` (64 … 1200 bytes).
    DtlsSrtp,
    /// Round up to the next size in `HTTPS_TLS_BUCKETS` (128 … 8192 bytes).
    HttpsTls,
    /// Pad anything shorter than `DEFAULT_MTU` (1400 bytes) up to exactly that size.
    FixedMtu,
}

/// Size buckets for [`PaddingProfile::DtlsSrtp`], ascending.
const DTLS_SRTP_BUCKETS: &[usize] = &[64, 128, 256, 512, 1024, 1200];
/// Size buckets for [`PaddingProfile::HttpsTls`], ascending.
const HTTPS_TLS_BUCKETS: &[usize] = &[128, 256, 512, 1024, 2048, 4096, 8192];
/// Target size for [`PaddingProfile::FixedMtu`].
const DEFAULT_MTU: usize = 1400;

/// Size, in bytes, that a message of `payload_len` bytes is padded to under `profile`.
///
/// Bucketed profiles round up to the smallest bucket that fits. Lengths above the largest
/// bucket (or above `DEFAULT_MTU` for `FixedMtu`) are returned unchanged.
pub fn adaptive_pad_size(payload_len: usize, profile: PaddingProfile) -> usize {
    match profile {
        PaddingProfile::None => payload_len,
        PaddingProfile::DtlsSrtp => pad_to_bucket(payload_len, DTLS_SRTP_BUCKETS),
        PaddingProfile::HttpsTls => pad_to_bucket(payload_len, HTTPS_TLS_BUCKETS),
        PaddingProfile::FixedMtu => payload_len.max(DEFAULT_MTU),
    }
}

/// Prefix `payload` with its length as a big-endian `u16`, then zero-pad the result to
/// `adaptive_pad_size(payload.len() + 2, profile)` bytes.
///
/// Inverse of [`strip_adaptive_padding`].
///
/// # Panics
/// Panics if `payload.len() > u16::MAX`. The length cannot be represented in the 2-byte
/// prefix, and truncating it (as a plain `as u16` cast would) makes the receiver strip
/// the payload to the wrong length without any error.
pub fn apply_adaptive_padding(payload: &[u8], profile: PaddingProfile) -> Vec<u8> {
    let len_prefix = u16::try_from(payload.len())
        .expect("apply_adaptive_padding: payload longer than u16::MAX bytes");
    let padded_size = adaptive_pad_size(payload.len() + 2, profile);

    let mut buf = Vec::with_capacity(padded_size);
    buf.extend_from_slice(&len_prefix.to_be_bytes());
    buf.extend_from_slice(payload);
    // adaptive_pad_size never returns less than its input, so this only appends zeros.
    buf.resize(padded_size, 0);
    buf
}

/// Recover the original payload from a buffer produced by [`apply_adaptive_padding`].
///
/// Returns `None` if `padded` is shorter than the 2-byte length prefix, or if the prefix
/// claims more bytes than follow it. Trailing padding bytes are ignored, not checked.
pub fn strip_adaptive_padding(padded: &[u8]) -> Option<&[u8]> {
    let prefix = padded.get(..2)?;
    let orig_len = usize::from(u16::from_be_bytes([prefix[0], prefix[1]]));
    padded[2..].get(..orig_len)
}

/// First bucket in `buckets` that is at least `payload_len` (the smallest such bucket when
/// `buckets` is ascending), or `payload_len` itself if no bucket is large enough.
fn pad_to_bucket(payload_len: usize, buckets: &[usize]) -> usize {
    buckets
        .iter()
        .copied()
        .find(|&bucket| payload_len <= bucket)
        .unwrap_or(payload_len)
}
320
321impl FrameWriter {
322    pub async fn write_frame_padded(
323        &self,
324        stream: &mut TcpStream,
325        session: &CryptoSession,
326        data: &[u8],
327        profile: PaddingProfile,
328    ) -> Result<usize, FrameError> {
329        let padded = apply_adaptive_padding(data, profile);
330        self.write_frame(stream, session, &padded).await
331    }
332}
333
334impl FrameReader {
335    pub async fn read_frame_padded(
336        &mut self,
337        stream: &mut TcpStream,
338        session: &CryptoSession,
339    ) -> Result<Vec<u8>, FrameError> {
340        let padded = self.read_frame(stream, session).await?;
341        match strip_adaptive_padding(&padded) {
342            Some(payload) => Ok(payload.to_vec()),
343            None => Ok(padded),
344        }
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;
    use std::sync::Arc;
    use tokio::net::TcpListener;

    /// Build both ends of a session from one shared secret: `(sender, receiver)`.
    fn session_pair(secret: [u8; 32]) -> (Arc<CryptoSession>, Arc<CryptoSession>) {
        let sender = CryptoSession::from_shared_secret(&secret).unwrap();
        let receiver = CryptoSession::from_shared_secret_peer(&secret).unwrap();
        (Arc::new(sender), Arc::new(receiver))
    }

    /// Bind a listener on an ephemeral loopback port; returns it with its address.
    async fn loopback_listener() -> (TcpListener, SocketAddr) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        (listener, addr)
    }

    #[tokio::test]
    async fn frame_round_trip() {
        let (cs, ss) = session_pair([0xABu8; 32]);
        let (listener, addr) = loopback_listener().await;

        let handle = tokio::spawn(async move {
            let (mut tcp, _) = listener.accept().await.unwrap();
            let mut reader = FrameReader::new();
            let data = reader.read_frame(&mut tcp, &ss).await.unwrap();
            assert_eq!(&data, b"Hello, zero-copy framing!");
        });

        let mut tcp = TcpStream::connect(addr).await.unwrap();
        let writer = FrameWriter::new();
        writer
            .write_frame(&mut tcp, &cs, b"Hello, zero-copy framing!")
            .await
            .unwrap();

        handle.await.unwrap();
    }

    /// A 1 MiB message exceeds both SPAWN_BLOCKING_THRESHOLD and MAX_FRAME_PAYLOAD, so it
    /// is encrypted on the blocking pool and arrives as several frames.
    #[tokio::test]
    async fn large_message_round_trip() {
        let (cs, ss) = session_pair([0x12u8; 32]);
        let (listener, addr) = loopback_listener().await;

        let original_data = vec![0x42u8; 1024 * 1024]; // 1 MiB
        let expected = original_data.clone();

        let handle = tokio::spawn(async move {
            let (mut tcp, _) = listener.accept().await.unwrap();
            let mut reader = FrameReader::new();
            let mut received_data = Vec::with_capacity(expected.len());

            for _ in 0..expected.len().div_ceil(MAX_FRAME_PAYLOAD) {
                let chunk = reader.read_frame(&mut tcp, &ss).await.unwrap();
                received_data.extend_from_slice(&chunk);
            }
            assert_eq!(received_data, expected);
        });

        let mut tcp = TcpStream::connect(addr).await.unwrap();
        let writer = FrameWriter::new();
        writer
            .write_frame(&mut tcp, &cs, &original_data)
            .await
            .unwrap();

        handle.await.unwrap();
    }

    #[tokio::test]
    async fn frame_batch_round_trip() {
        let (cs, ss) = session_pair([0xCDu8; 32]);
        let (listener, addr) = loopback_listener().await;

        let handle = tokio::spawn(async move {
            let (mut tcp, _) = listener.accept().await.unwrap();
            let mut reader = FrameReader::new();
            let d1 = reader.read_frame(&mut tcp, &ss).await.unwrap();
            let d2 = reader.read_frame(&mut tcp, &ss).await.unwrap();
            let d3 = reader.read_frame(&mut tcp, &ss).await.unwrap();
            assert_eq!(&d1, b"Frame 1");
            assert_eq!(&d2, b"Frame 2");
            assert_eq!(&d3, b"Frame 3");
        });

        let mut tcp = TcpStream::connect(addr).await.unwrap();
        let writer = FrameWriter::new();
        let payloads: [&[u8]; 3] = [b"Frame 1", b"Frame 2", b"Frame 3"];
        writer
            .write_frames_batch(&mut tcp, &cs, &payloads)
            .await
            .unwrap();

        handle.await.unwrap();
    }

    #[test]
    fn test_adaptive_padding_dtls() {
        let padded_size = adaptive_pad_size(52, PaddingProfile::DtlsSrtp);
        assert_eq!(padded_size, 64);
        assert_eq!(adaptive_pad_size(100, PaddingProfile::DtlsSrtp), 128);
        assert_eq!(adaptive_pad_size(1000, PaddingProfile::DtlsSrtp), 1024);
    }

    #[test]
    fn test_adaptive_padding_other_profiles() {
        assert_eq!(adaptive_pad_size(10, PaddingProfile::None), 10);
        assert_eq!(adaptive_pad_size(129, PaddingProfile::HttpsTls), 256);
        // Beyond the largest bucket the length is left unchanged.
        assert_eq!(adaptive_pad_size(9000, PaddingProfile::HttpsTls), 9000);
        assert_eq!(adaptive_pad_size(1201, PaddingProfile::DtlsSrtp), 1201);
        assert_eq!(adaptive_pad_size(0, PaddingProfile::FixedMtu), 1400);
        assert_eq!(adaptive_pad_size(1400, PaddingProfile::FixedMtu), 1400);
        assert_eq!(adaptive_pad_size(1500, PaddingProfile::FixedMtu), 1500);
    }

    #[test]
    fn test_padding_roundtrip() {
        let original = b"Hello, adaptive padding!";
        let padded = apply_adaptive_padding(original, PaddingProfile::DtlsSrtp);
        assert!(padded.len() >= 64);
        let stripped = strip_adaptive_padding(&padded).unwrap();
        assert_eq!(stripped, original);
    }

    #[test]
    fn test_strip_rejects_malformed_padding() {
        assert_eq!(strip_adaptive_padding(&[]), None);
        assert_eq!(strip_adaptive_padding(&[0x00]), None);
        // Prefix claims 5 bytes but only 1 follows.
        assert_eq!(strip_adaptive_padding(&[0x00, 0x05, 0xAA]), None);
        assert_eq!(strip_adaptive_padding(&[0x00, 0x00]), Some(&b""[..]));
    }

    #[tokio::test]
    async fn frame_padded_round_trip() {
        let (cs, ss) = session_pair([0xEFu8; 32]);
        let (listener, addr) = loopback_listener().await;

        let handle = tokio::spawn(async move {
            let (mut tcp, _) = listener.accept().await.unwrap();
            let mut reader = FrameReader::new();
            let data = reader.read_frame_padded(&mut tcp, &ss).await.unwrap();
            assert_eq!(&data, b"Padded message!");
        });

        let mut tcp = TcpStream::connect(addr).await.unwrap();
        let writer = FrameWriter::new();
        writer
            .write_frame_padded(&mut tcp, &cs, b"Padded message!", PaddingProfile::DtlsSrtp)
            .await
            .unwrap();

        handle.await.unwrap();
    }
}