Skip to main content

sa3p_protocol/
lib.rs

//! Binary framing and multiplexing primitives for SA3P transport boundaries.
//!
//! Frame layout (all multi-byte integers are big-endian):
//! `[MAGIC 4B][STREAM_ID 4B][OPCODE 1B][PAYLOAD_LEN 4B][PAYLOAD]`
//!
//! This crate provides incremental decoding ([`FrameCodec`]) and per-stream queues
//! ([`Multiplexer`]) for socket or vsock transports.
8
9use std::collections::{BTreeMap, VecDeque};
10
11use sa3p_parser::Instruction;
12use thiserror::Error;
13
/// Four-byte marker that starts every SA3P frame.
pub const MAGIC: [u8; 4] = *b"SA3P";

/// Size of the fixed frame header in bytes (13): magic, big-endian `u32` stream id,
/// `u8` opcode, and big-endian `u32` payload length.
const HEADER_LEN: usize = MAGIC.len()
    + std::mem::size_of::<u32>()
    + std::mem::size_of::<u8>()
    + std::mem::size_of::<u32>();
16
/// Operation code carried in byte 8 of a frame header.
///
/// `#[repr(u8)]` makes each discriminant the exact on-wire byte, so `opcode as u8`
/// yields the value written to the header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Opcode {
    /// VFS read, wire byte `0x01`.
    VfsRead = 0x01,
    /// VFS write chunk, wire byte `0x02`.
    VfsWriteChunk = 0x02,
    /// VFS rename, wire byte `0x03`.
    VfsRename = 0x03,
    /// PTY spawn, wire byte `0x04`.
    PtySpawn = 0x04,
    /// PTY input, wire byte `0x05`.
    PtyInput = 0x05,
}

impl Opcode {
    /// Maps a wire byte back to its opcode.
    ///
    /// Returns `None` for any byte that is not an assigned opcode.
    pub fn from_byte(byte: u8) -> Option<Self> {
        let opcode = match byte {
            0x01 => Self::VfsRead,
            0x02 => Self::VfsWriteChunk,
            0x03 => Self::VfsRename,
            0x04 => Self::PtySpawn,
            0x05 => Self::PtyInput,
            _ => return None,
        };
        Some(opcode)
    }
}
39
40#[derive(Debug, Clone, PartialEq, Eq)]
41pub struct Frame {
42    pub stream_id: u32,
43    pub opcode: Opcode,
44    pub payload: Vec<u8>,
45}
46
47impl Frame {
48    pub fn new(stream_id: u32, opcode: Opcode, payload: Vec<u8>) -> Self {
49        Self {
50            stream_id,
51            opcode,
52            payload,
53        }
54    }
55
56    pub fn encode(&self) -> Vec<u8> {
57        let mut out = Vec::with_capacity(HEADER_LEN + self.payload.len());
58        out.extend_from_slice(&MAGIC);
59        out.extend_from_slice(&self.stream_id.to_be_bytes());
60        out.push(self.opcode as u8);
61        out.extend_from_slice(&(self.payload.len() as u32).to_be_bytes());
62        out.extend_from_slice(&self.payload);
63        out
64    }
65
66    pub fn decode(bytes: &[u8]) -> Result<Self, ProtocolError> {
67        let (frame, consumed) = decode_frame(bytes)?;
68        if consumed != bytes.len() {
69            return Err(ProtocolError::TrailingBytes {
70                trailing: bytes.len() - consumed,
71            });
72        }
73        Ok(frame)
74    }
75}
76
77#[derive(Debug, Error, PartialEq, Eq)]
78pub enum ProtocolError {
79    #[error("frame too short: need at least {expected} bytes, got {actual}")]
80    FrameTooShort { expected: usize, actual: usize },
81    #[error("invalid magic: got {found:?}")]
82    InvalidMagic { found: [u8; 4] },
83    #[error("unknown opcode: 0x{0:02x}")]
84    UnknownOpcode(u8),
85    #[error("incomplete frame payload: expected {expected} bytes, got {actual}")]
86    IncompletePayload { expected: usize, actual: usize },
87    #[error("trailing bytes after frame: {trailing}")]
88    TrailingBytes { trailing: usize },
89}
90
91pub fn decode_frame(bytes: &[u8]) -> Result<(Frame, usize), ProtocolError> {
92    if bytes.len() < HEADER_LEN {
93        return Err(ProtocolError::FrameTooShort {
94            expected: HEADER_LEN,
95            actual: bytes.len(),
96        });
97    }
98
99    let found_magic = [bytes[0], bytes[1], bytes[2], bytes[3]];
100    if found_magic != MAGIC {
101        return Err(ProtocolError::InvalidMagic { found: found_magic });
102    }
103
104    let stream_id = u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
105    let opcode_byte = bytes[8];
106    let opcode = Opcode::from_byte(opcode_byte).ok_or(ProtocolError::UnknownOpcode(opcode_byte))?;
107    let payload_len = u32::from_be_bytes([bytes[9], bytes[10], bytes[11], bytes[12]]) as usize;
108
109    let total_len = HEADER_LEN + payload_len;
110    if bytes.len() < total_len {
111        return Err(ProtocolError::IncompletePayload {
112            expected: payload_len,
113            actual: bytes.len().saturating_sub(HEADER_LEN),
114        });
115    }
116
117    let payload = bytes[HEADER_LEN..total_len].to_vec();
118    Ok((Frame::new(stream_id, opcode, payload), total_len))
119}
120
121#[derive(Debug, Default)]
122pub struct FrameCodec {
123    buffer: Vec<u8>,
124}
125
126impl FrameCodec {
127    pub fn new() -> Self {
128        Self::default()
129    }
130
131    pub fn push_bytes(&mut self, bytes: &[u8]) -> Result<Vec<Frame>, ProtocolError> {
132        self.buffer.extend_from_slice(bytes);
133        let mut frames = Vec::new();
134
135        loop {
136            if self.buffer.is_empty() {
137                break;
138            }
139
140            match decode_frame(&self.buffer) {
141                Ok((frame, consumed)) => {
142                    self.buffer.drain(..consumed);
143                    frames.push(frame);
144                }
145                Err(ProtocolError::FrameTooShort { .. })
146                | Err(ProtocolError::IncompletePayload { .. }) => break,
147                Err(err) => return Err(err),
148            }
149        }
150
151        Ok(frames)
152    }
153
154    pub fn buffered_len(&self) -> usize {
155        self.buffer.len()
156    }
157}
158
159#[derive(Debug, Default)]
160pub struct Multiplexer {
161    queues: BTreeMap<u32, VecDeque<Frame>>,
162}
163
164impl Multiplexer {
165    pub fn new() -> Self {
166        Self::default()
167    }
168
169    pub fn push(&mut self, frame: Frame) {
170        self.queues
171            .entry(frame.stream_id)
172            .or_default()
173            .push_back(frame);
174    }
175
176    pub fn pop_next(&mut self, stream_id: u32) -> Option<Frame> {
177        let queue = self.queues.get_mut(&stream_id)?;
178        let frame = queue.pop_front();
179        if queue.is_empty() {
180            self.queues.remove(&stream_id);
181        }
182        frame
183    }
184
185    pub fn stream_ids(&self) -> Vec<u32> {
186        self.queues.keys().copied().collect()
187    }
188
189    pub fn len_for_stream(&self, stream_id: u32) -> usize {
190        self.queues.get(&stream_id).map(VecDeque::len).unwrap_or(0)
191    }
192}
193
194pub fn frame_from_instruction(stream_id: u32, instruction: &Instruction) -> Option<Frame> {
195    match instruction {
196        Instruction::WriteChunk(bytes) => {
197            Some(Frame::new(stream_id, Opcode::VfsWriteChunk, bytes.clone()))
198        }
199        _ => None,
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// A 13-byte header whose magic is `BAD!`, used to exercise fatal decode errors.
    const BAD_MAGIC_HEADER: &[u8] = b"BAD!\0\0\0\x01\x02\0\0\0\x00";

    #[test]
    fn frame_round_trip_encode_decode() {
        let frame = Frame::new(7, Opcode::PtySpawn, b"cargo test".to_vec());
        let encoded = frame.encode();
        let decoded = Frame::decode(&encoded).expect("frame should decode");

        assert_eq!(decoded, frame);
    }

    #[test]
    fn decode_rejects_invalid_magic() {
        let err = Frame::decode(BAD_MAGIC_HEADER).expect_err("decode should fail");

        assert!(matches!(err, ProtocolError::InvalidMagic { .. }));
    }

    #[test]
    fn decode_reports_structural_errors() {
        let mut unknown = Frame::new(1, Opcode::VfsRead, Vec::new()).encode();
        unknown[8] = 0x7f;
        assert_eq!(
            Frame::decode(&unknown),
            Err(ProtocolError::UnknownOpcode(0x7f))
        );

        assert_eq!(
            Frame::decode(&MAGIC),
            Err(ProtocolError::FrameTooShort {
                expected: HEADER_LEN,
                actual: 4,
            })
        );

        let full = Frame::new(1, Opcode::VfsWriteChunk, b"abcd".to_vec()).encode();
        assert_eq!(
            Frame::decode(&full[..full.len() - 1]),
            Err(ProtocolError::IncompletePayload {
                expected: 4,
                actual: 3,
            })
        );

        let mut trailing = full.clone();
        trailing.push(0);
        assert_eq!(
            Frame::decode(&trailing),
            Err(ProtocolError::TrailingBytes { trailing: 1 })
        );
    }

    #[test]
    fn decode_treats_huge_declared_length_as_incomplete() {
        let mut header = Frame::new(1, Opcode::VfsRead, Vec::new()).encode();
        header[9..13].copy_from_slice(&u32::MAX.to_be_bytes());

        assert_eq!(
            decode_frame(&header),
            Err(ProtocolError::IncompletePayload {
                expected: u32::MAX as usize,
                actual: 0,
            })
        );
    }

    #[test]
    fn frame_codec_handles_incremental_feeds() {
        let frame = Frame::new(1, Opcode::VfsRead, b"payload".to_vec());
        let bytes = frame.encode();

        let mut codec = FrameCodec::new();
        let first = codec
            .push_bytes(&bytes[..5])
            .expect("partial push should succeed");
        assert!(first.is_empty());
        assert_eq!(codec.buffered_len(), 5);

        let second = codec
            .push_bytes(&bytes[5..])
            .expect("second push should succeed");
        assert_eq!(second, vec![frame]);
        assert_eq!(codec.buffered_len(), 0);
    }

    #[test]
    fn frame_codec_decodes_back_to_back_frames_and_keeps_partial_tail() {
        let a = Frame::new(1, Opcode::VfsRead, b"a".to_vec());
        let b = Frame::new(2, Opcode::PtyInput, Vec::new());
        let c = Frame::new(3, Opcode::VfsRename, b"rename".to_vec());
        let c_bytes = c.encode();
        let split = HEADER_LEN + 2;

        let mut burst = a.encode();
        burst.extend(b.encode());
        burst.extend_from_slice(&c_bytes[..split]);

        let mut codec = FrameCodec::new();
        let frames = codec.push_bytes(&burst).expect("burst should decode");
        assert_eq!(frames, vec![a, b]);
        assert_eq!(codec.buffered_len(), split);

        let rest = codec
            .push_bytes(&c_bytes[split..])
            .expect("tail should decode");
        assert_eq!(rest, vec![c]);
        assert_eq!(codec.buffered_len(), 0);
    }

    #[test]
    fn frame_codec_surfaces_fatal_errors_and_keeps_offending_bytes() {
        let mut bytes = Frame::new(4, Opcode::VfsRead, vec![9]).encode();
        bytes.extend_from_slice(BAD_MAGIC_HEADER);

        let mut codec = FrameCodec::new();
        let err = codec.push_bytes(&bytes).expect_err("bad magic should fail");
        assert!(matches!(err, ProtocolError::InvalidMagic { found } if &found == b"BAD!"));
        assert_eq!(codec.buffered_len(), BAD_MAGIC_HEADER.len());
    }

    #[test]
    fn multiplexer_isolates_stream_queues() {
        let mut mux = Multiplexer::new();
        mux.push(Frame::new(1, Opcode::VfsRead, vec![1]));
        mux.push(Frame::new(2, Opcode::PtyInput, vec![2]));
        mux.push(Frame::new(1, Opcode::VfsRename, vec![3]));

        assert_eq!(mux.stream_ids(), vec![1, 2]);
        assert_eq!(mux.len_for_stream(1), 2);

        let first = mux.pop_next(1).expect("stream 1 frame");
        assert_eq!(first.payload, vec![1]);
        let second = mux.pop_next(1).expect("stream 1 second frame");
        assert_eq!(second.payload, vec![3]);
        assert_eq!(mux.len_for_stream(1), 0);
        assert_eq!(mux.len_for_stream(2), 1);
    }

    #[test]
    fn multiplexer_drops_drained_streams() {
        let mut mux = Multiplexer::new();
        assert_eq!(mux.pop_next(42), None);
        assert!(mux.stream_ids().is_empty());

        mux.push(Frame::new(1, Opcode::VfsRead, vec![1]));
        mux.push(Frame::new(2, Opcode::PtyInput, vec![2]));
        assert!(mux.pop_next(1).is_some());

        assert_eq!(mux.stream_ids(), vec![2]);
        assert_eq!(mux.pop_next(1), None);
    }

    #[test]
    fn instruction_write_chunk_maps_to_opcode_0x02_frame() {
        let instruction = Instruction::WriteChunk(b"hello".to_vec());
        let frame = frame_from_instruction(9, &instruction).expect("should map");

        assert_eq!(frame.stream_id, 9);
        assert_eq!(frame.opcode as u8, 0x02);
        assert_eq!(frame.payload, b"hello".to_vec());
    }
}