1use std::collections::{BTreeMap, VecDeque};
10
11use sa3p_parser::Instruction;
12use thiserror::Error;
13
/// Four-byte signature that opens every encoded frame (ASCII `"SA3P"`).
pub const MAGIC: [u8; 4] = [b'S', b'A', b'3', b'P'];

/// Size in bytes of the fixed frame header:
/// magic (4) + big-endian stream id (4) + opcode (1) + big-endian payload length (4).
const HEADER_LEN: usize = MAGIC.len() + 4 + 1 + 4;
16
/// Operation carried by a frame, encoded on the wire as the single byte that follows the
/// stream id. Each variant's discriminant is its wire value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Opcode {
    /// Wire value `0x01`.
    VfsRead = 0x01,
    /// Wire value `0x02`.
    VfsWriteChunk = 0x02,
    /// Wire value `0x03`.
    VfsRename = 0x03,
    /// Wire value `0x04`.
    PtySpawn = 0x04,
    /// Wire value `0x05`.
    PtyInput = 0x05,
}

impl Opcode {
    /// Every opcode; the lookup table used by [`Opcode::from_byte`].
    const ALL: [Self; 5] = [
        Self::VfsRead,
        Self::VfsWriteChunk,
        Self::VfsRename,
        Self::PtySpawn,
        Self::PtyInput,
    ];

    /// Maps a wire byte back to its opcode, returning `None` for unrecognised bytes.
    pub fn from_byte(byte: u8) -> Option<Self> {
        Self::ALL.iter().copied().find(|candidate| *candidate as u8 == byte)
    }
}
39
40#[derive(Debug, Clone, PartialEq, Eq)]
41pub struct Frame {
42 pub stream_id: u32,
43 pub opcode: Opcode,
44 pub payload: Vec<u8>,
45}
46
47impl Frame {
48 pub fn new(stream_id: u32, opcode: Opcode, payload: Vec<u8>) -> Self {
49 Self {
50 stream_id,
51 opcode,
52 payload,
53 }
54 }
55
56 pub fn encode(&self) -> Vec<u8> {
57 let mut out = Vec::with_capacity(HEADER_LEN + self.payload.len());
58 out.extend_from_slice(&MAGIC);
59 out.extend_from_slice(&self.stream_id.to_be_bytes());
60 out.push(self.opcode as u8);
61 out.extend_from_slice(&(self.payload.len() as u32).to_be_bytes());
62 out.extend_from_slice(&self.payload);
63 out
64 }
65
66 pub fn decode(bytes: &[u8]) -> Result<Self, ProtocolError> {
67 let (frame, consumed) = decode_frame(bytes)?;
68 if consumed != bytes.len() {
69 return Err(ProtocolError::TrailingBytes {
70 trailing: bytes.len() - consumed,
71 });
72 }
73 Ok(frame)
74 }
75}
76
77#[derive(Debug, Error, PartialEq, Eq)]
78pub enum ProtocolError {
79 #[error("frame too short: need at least {expected} bytes, got {actual}")]
80 FrameTooShort { expected: usize, actual: usize },
81 #[error("invalid magic: got {found:?}")]
82 InvalidMagic { found: [u8; 4] },
83 #[error("unknown opcode: 0x{0:02x}")]
84 UnknownOpcode(u8),
85 #[error("incomplete frame payload: expected {expected} bytes, got {actual}")]
86 IncompletePayload { expected: usize, actual: usize },
87 #[error("trailing bytes after frame: {trailing}")]
88 TrailingBytes { trailing: usize },
89}
90
91pub fn decode_frame(bytes: &[u8]) -> Result<(Frame, usize), ProtocolError> {
92 if bytes.len() < HEADER_LEN {
93 return Err(ProtocolError::FrameTooShort {
94 expected: HEADER_LEN,
95 actual: bytes.len(),
96 });
97 }
98
99 let found_magic = [bytes[0], bytes[1], bytes[2], bytes[3]];
100 if found_magic != MAGIC {
101 return Err(ProtocolError::InvalidMagic { found: found_magic });
102 }
103
104 let stream_id = u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
105 let opcode_byte = bytes[8];
106 let opcode = Opcode::from_byte(opcode_byte).ok_or(ProtocolError::UnknownOpcode(opcode_byte))?;
107 let payload_len = u32::from_be_bytes([bytes[9], bytes[10], bytes[11], bytes[12]]) as usize;
108
109 let total_len = HEADER_LEN + payload_len;
110 if bytes.len() < total_len {
111 return Err(ProtocolError::IncompletePayload {
112 expected: payload_len,
113 actual: bytes.len().saturating_sub(HEADER_LEN),
114 });
115 }
116
117 let payload = bytes[HEADER_LEN..total_len].to_vec();
118 Ok((Frame::new(stream_id, opcode, payload), total_len))
119}
120
121#[derive(Debug, Default)]
122pub struct FrameCodec {
123 buffer: Vec<u8>,
124}
125
126impl FrameCodec {
127 pub fn new() -> Self {
128 Self::default()
129 }
130
131 pub fn push_bytes(&mut self, bytes: &[u8]) -> Result<Vec<Frame>, ProtocolError> {
132 self.buffer.extend_from_slice(bytes);
133 let mut frames = Vec::new();
134
135 loop {
136 if self.buffer.is_empty() {
137 break;
138 }
139
140 match decode_frame(&self.buffer) {
141 Ok((frame, consumed)) => {
142 self.buffer.drain(..consumed);
143 frames.push(frame);
144 }
145 Err(ProtocolError::FrameTooShort { .. })
146 | Err(ProtocolError::IncompletePayload { .. }) => break,
147 Err(err) => return Err(err),
148 }
149 }
150
151 Ok(frames)
152 }
153
154 pub fn buffered_len(&self) -> usize {
155 self.buffer.len()
156 }
157}
158
159#[derive(Debug, Default)]
160pub struct Multiplexer {
161 queues: BTreeMap<u32, VecDeque<Frame>>,
162}
163
164impl Multiplexer {
165 pub fn new() -> Self {
166 Self::default()
167 }
168
169 pub fn push(&mut self, frame: Frame) {
170 self.queues
171 .entry(frame.stream_id)
172 .or_default()
173 .push_back(frame);
174 }
175
176 pub fn pop_next(&mut self, stream_id: u32) -> Option<Frame> {
177 let queue = self.queues.get_mut(&stream_id)?;
178 let frame = queue.pop_front();
179 if queue.is_empty() {
180 self.queues.remove(&stream_id);
181 }
182 frame
183 }
184
185 pub fn stream_ids(&self) -> Vec<u32> {
186 self.queues.keys().copied().collect()
187 }
188
189 pub fn len_for_stream(&self, stream_id: u32) -> usize {
190 self.queues.get(&stream_id).map(VecDeque::len).unwrap_or(0)
191 }
192}
193
194pub fn frame_from_instruction(stream_id: u32, instruction: &Instruction) -> Option<Frame> {
195 match instruction {
196 Instruction::WriteChunk(bytes) => {
197 Some(Frame::new(stream_id, Opcode::VfsWriteChunk, bytes.clone()))
198 }
199 _ => None,
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn frame_round_trip_encode_decode() {
        let frame = Frame::new(7, Opcode::PtySpawn, b"cargo test".to_vec());
        let encoded = frame.encode();
        let decoded = Frame::decode(&encoded).expect("frame should decode");

        assert_eq!(decoded, frame);
    }

    #[test]
    fn decode_rejects_invalid_magic() {
        let bytes = b"BAD!\0\0\0\x01\x02\0\0\0\x00";
        let err = Frame::decode(bytes).expect_err("decode should fail");

        assert!(matches!(err, ProtocolError::InvalidMagic { .. }));
    }

    #[test]
    fn decode_rejects_unknown_opcode() {
        let mut bytes = Frame::new(3, Opcode::VfsRead, Vec::new()).encode();
        // Byte 8 is the opcode: 4 magic bytes + 4 stream-id bytes precede it.
        bytes[8] = 0xFF;
        let err = Frame::decode(&bytes).expect_err("decode should fail");

        assert_eq!(err, ProtocolError::UnknownOpcode(0xFF));
    }

    #[test]
    fn decode_reports_short_header() {
        let err = Frame::decode(&[]).expect_err("empty input should fail");

        assert_eq!(
            err,
            ProtocolError::FrameTooShort {
                expected: HEADER_LEN,
                actual: 0,
            }
        );
    }

    #[test]
    fn decode_reports_incomplete_payload() {
        let bytes = Frame::new(4, Opcode::PtyInput, b"abc".to_vec()).encode();
        let err = Frame::decode(&bytes[..bytes.len() - 1])
            .expect_err("truncated payload should fail");

        assert_eq!(
            err,
            ProtocolError::IncompletePayload {
                expected: 3,
                actual: 2,
            }
        );
    }

    #[test]
    fn decode_rejects_trailing_bytes() {
        let mut bytes = Frame::new(5, Opcode::VfsRename, b"x".to_vec()).encode();
        bytes.push(0);
        let err = Frame::decode(&bytes).expect_err("trailing data should fail");

        assert_eq!(err, ProtocolError::TrailingBytes { trailing: 1 });
    }

    #[test]
    fn frame_codec_handles_incremental_feeds() {
        let frame = Frame::new(1, Opcode::VfsRead, b"payload".to_vec());
        let bytes = frame.encode();

        let mut codec = FrameCodec::new();
        let first = codec
            .push_bytes(&bytes[..5])
            .expect("partial push should succeed");
        assert!(first.is_empty());
        assert_eq!(codec.buffered_len(), 5);

        let second = codec
            .push_bytes(&bytes[5..])
            .expect("second push should succeed");
        assert_eq!(second, vec![frame]);
        assert_eq!(codec.buffered_len(), 0);
    }

    #[test]
    fn frame_codec_decodes_several_frames_from_one_push() {
        let a = Frame::new(1, Opcode::VfsRead, b"a".to_vec());
        let b = Frame::new(2, Opcode::PtySpawn, b"bb".to_vec());
        let mut bytes = a.encode();
        bytes.extend_from_slice(&b.encode());

        let mut codec = FrameCodec::new();
        let frames = codec.push_bytes(&bytes).expect("push should succeed");

        assert_eq!(frames, vec![a, b]);
        assert_eq!(codec.buffered_len(), 0);
    }

    #[test]
    fn multiplexer_isolates_stream_queues() {
        let mut mux = Multiplexer::new();
        mux.push(Frame::new(1, Opcode::VfsRead, vec![1]));
        mux.push(Frame::new(2, Opcode::PtyInput, vec![2]));
        mux.push(Frame::new(1, Opcode::VfsRename, vec![3]));

        assert_eq!(mux.stream_ids(), vec![1, 2]);
        assert_eq!(mux.len_for_stream(1), 2);

        let first = mux.pop_next(1).expect("stream 1 frame");
        assert_eq!(first.payload, vec![1]);
        let second = mux.pop_next(1).expect("stream 1 second frame");
        assert_eq!(second.payload, vec![3]);
        assert_eq!(mux.len_for_stream(1), 0);
        assert_eq!(mux.len_for_stream(2), 1);

        // A drained stream is no longer tracked and yields nothing.
        assert_eq!(mux.stream_ids(), vec![2]);
        assert_eq!(mux.pop_next(1), None);
    }

    #[test]
    fn instruction_write_chunk_maps_to_opcode_0x02_frame() {
        let instruction = Instruction::WriteChunk(b"hello".to_vec());
        let frame = frame_from_instruction(9, &instruction).expect("should map");

        assert_eq!(frame.stream_id, 9);
        assert_eq!(frame.opcode as u8, 0x02);
        assert_eq!(frame.payload, b"hello".to_vec());
    }
}