Skip to main content

apfsds_transport/
frame_codec.rs

1//! Frame codec for encoding/decoding ProxyFrames over WebSocket
2
3use apfsds_obfuscation::{PaddingStrategy, XorMask, compress_if_needed, decompress};
4use apfsds_protocol::ProxyFrame;
5use thiserror::Error;
6use tracing::trace;
7
8#[derive(Error, Debug)]
9pub enum CodecError {
10    #[error("Serialization failed: {0}")]
11    SerializationFailed(String),
12
13    #[error("Deserialization failed: {0}")]
14    DeserializationFailed(String),
15
16    #[error("Compression failed: {0}")]
17    CompressionFailed(String),
18
19    #[error("Decompression failed: {0}")]
20    DecompressionFailed(String),
21
22    #[error("Invalid frame format")]
23    InvalidFrameFormat,
24}
25
26/// Frame codec for encoding/decoding ProxyFrames
27pub struct FrameCodec {
28    xor_mask: XorMask,
29    padding: PaddingStrategy,
30    compression_enabled: bool,
31}
32
33impl FrameCodec {
34    /// Create a new codec with the given session key
35    pub fn new(session_key: u64) -> Self {
36        Self {
37            xor_mask: XorMask::new(session_key),
38            padding: PaddingStrategy::default(),
39            compression_enabled: true,
40        }
41    }
42
43    /// Create without compression
44    pub fn without_compression(session_key: u64) -> Self {
45        Self {
46            xor_mask: XorMask::new(session_key),
47            padding: PaddingStrategy::default(),
48            compression_enabled: false,
49        }
50    }
51
52    /// Encode a ProxyFrame for transmission
53    pub fn encode(&self, frame: &ProxyFrame) -> Result<Vec<u8>, CodecError> {
54        // 1. Serialize with rkyv
55        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(frame)
56            .map_err(|e| CodecError::SerializationFailed(e.to_string()))?
57            .to_vec();
58
59        trace!("Serialized frame: {} bytes", bytes.len());
60
61        // 2. Compress if needed
62        let (data, compressed) = if self.compression_enabled {
63            compress_if_needed(&bytes).map_err(|e| CodecError::CompressionFailed(e.to_string()))?
64        } else {
65            (bytes, false)
66        };
67
68        trace!(
69            "After compression: {} bytes (compressed: {})",
70            data.len(),
71            compressed
72        );
73
74        // 3. XOR mask
75        let masked = self.xor_mask.apply(&data);
76
77        // 4. Add padding
78        let mut padded = self.padding.pad(&masked);
79
80        // 5. Prepend flags byte (bit 0 = compressed)
81        let flags = if compressed { 0x01 } else { 0x00 };
82        padded.insert(0, flags);
83
84        trace!("Final encoded size: {} bytes", padded.len());
85
86        Ok(padded)
87    }
88
89    /// Decode a ProxyFrame from received data
90    pub fn decode(&self, data: &[u8]) -> Result<ProxyFrame, CodecError> {
91        if data.is_empty() {
92            return Err(CodecError::InvalidFrameFormat);
93        }
94
95        // 1. Extract flags byte
96        let flags = data[0];
97        let compressed = (flags & 0x01) != 0;
98        let remaining = &data[1..];
99
100        trace!(
101            "Decoding frame: {} bytes, compressed: {}",
102            data.len(),
103            compressed
104        );
105
106        // 2. Remove padding
107        let unpadded = PaddingStrategy::unpad(remaining).ok_or(CodecError::InvalidFrameFormat)?;
108
109        // 3. XOR unmask
110        let unmasked = self.xor_mask.apply(&unpadded);
111
112        // 4. Decompress if needed
113        let bytes = if compressed {
114            decompress(&unmasked).map_err(|e| CodecError::DecompressionFailed(e.to_string()))?
115        } else {
116            unmasked
117        };
118
119        // 5. Deserialize with rkyv
120        let archived =
121            rkyv::access::<apfsds_protocol::ArchivedProxyFrame, rkyv::rancor::Error>(&bytes)
122                .map_err(|e| CodecError::DeserializationFailed(e.to_string()))?;
123
124        let frame: ProxyFrame = rkyv::deserialize::<ProxyFrame, rkyv::rancor::Error>(archived)
125            .map_err(|e| CodecError::DeserializationFailed(e.to_string()))?;
126
127        Ok(frame)
128    }
129
130    /// Encode frame as binary WebSocket message
131    pub fn encode_to_message(
132        &self,
133        frame: &ProxyFrame,
134    ) -> Result<tokio_tungstenite::tungstenite::Message, CodecError> {
135        let bytes = self.encode(frame)?;
136        Ok(tokio_tungstenite::tungstenite::Message::Binary(
137            bytes.into(),
138        ))
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// 2000-byte repeating payload shared by the large-frame tests.
    fn large_payload() -> Vec<u8> {
        (0..2000).map(|i| (i % 256) as u8).collect()
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let codec = FrameCodec::new(12345);

        let frame = ProxyFrame::new_data(
            42,
            ProxyFrame::ipv4_to_mapped([192, 168, 1, 1]),
            8080,
            vec![1, 2, 3, 4, 5],
        );

        let encoded = codec.encode(&frame).unwrap();
        let decoded = codec.decode(&encoded).unwrap();

        assert_eq!(frame.conn_id, decoded.conn_id);
        assert_eq!(frame.rport, decoded.rport);
        assert_eq!(frame.payload, decoded.payload);
    }

    #[test]
    fn test_large_payload_compression() {
        let codec = FrameCodec::new(12345);

        // Large payload that should be compressed
        let frame = ProxyFrame::new_data(1, [0; 16], 443, large_payload());

        let encoded = codec.encode(&frame).unwrap();

        // Check that compression flag is set
        assert_eq!(encoded[0] & 0x01, 0x01);

        let decoded = codec.decode(&encoded).unwrap();
        assert_eq!(frame.payload, decoded.payload);
    }

    #[test]
    fn test_without_compression() {
        let codec = FrameCodec::without_compression(12345);

        let frame = ProxyFrame::new_data(1, [0; 16], 443, large_payload());

        let encoded = codec.encode(&frame).unwrap();

        // Check that compression flag is NOT set
        assert_eq!(encoded[0] & 0x01, 0x00);

        // The uncompressed path must round-trip as well.
        let decoded = codec.decode(&encoded).unwrap();
        assert_eq!(frame.payload, decoded.payload);
    }

    #[test]
    fn test_decode_empty_input() {
        let codec = FrameCodec::new(12345);

        // No flags byte at all: must be rejected, not panic.
        assert!(matches!(
            codec.decode(&[]),
            Err(CodecError::InvalidFrameFormat)
        ));
    }
}