1use crate::error::WireError;
2use crate::varint::{decode_varint, encode_varint};
3
/// Bit-set stored in the single flags byte of a `BlockFrame`.
///
/// Bits without a named constant are preserved verbatim by [`BlockFlags::from_raw`]
/// and [`BlockFlags::raw`], so flags this code does not understand survive a
/// read/write round trip.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct BlockFlags(u8);

impl BlockFlags {
    /// No bits set (also the `Default` value).
    pub const NONE: Self = Self(0);
    /// Bit 0 (`0b0000_0001`).
    pub const HAS_SUMMARY: Self = Self(0b0000_0001);
    /// Bit 1 (`0b0000_0010`).
    pub const COMPRESSED: Self = Self(0b0000_0010);
    /// Bit 2 (`0b0000_0100`).
    pub const IS_REFERENCE: Self = Self(0b0000_0100);

    /// Wraps a raw flags byte; every bit is kept as-is.
    pub const fn from_raw(raw: u8) -> Self {
        Self(raw)
    }

    /// Returns the raw flags byte.
    pub const fn raw(self) -> u8 {
        self.0
    }

    /// Returns `true` if every bit set in `other` is also set in `self`.
    ///
    /// `contains(NONE)` is always `true`.
    pub const fn contains(self, other: Self) -> bool {
        self.0 & other.0 == other.0
    }

    /// Returns the union (bitwise OR) of the two flag sets.
    pub const fn union(self, other: Self) -> Self {
        Self(self.0 | other.0)
    }

    /// Returns `true` if [`BlockFlags::HAS_SUMMARY`] is set.
    pub const fn has_summary(self) -> bool {
        self.contains(Self::HAS_SUMMARY)
    }

    /// Returns `true` if [`BlockFlags::COMPRESSED`] is set.
    pub const fn is_compressed(self) -> bool {
        self.contains(Self::COMPRESSED)
    }

    /// Returns `true` if [`BlockFlags::IS_REFERENCE`] is set.
    pub const fn is_reference(self) -> bool {
        self.contains(Self::IS_REFERENCE)
    }
}

/// `a | b` is shorthand for [`BlockFlags::union`].
impl std::ops::BitOr for BlockFlags {
    type Output = Self;

    fn bitor(self, rhs: Self) -> Self {
        self.union(rhs)
    }
}

/// `a |= b` sets every bit of `b` in `a`.
impl std::ops::BitOrAssign for BlockFlags {
    fn bitor_assign(&mut self, rhs: Self) {
        *self = self.union(rhs);
    }
}
40
/// Well-known values for the `block_type` byte of a `BlockFrame`.
///
/// These are plain `u8` constants. `BlockFrame::read_from` accepts any type value
/// that fits in a `u8`, not only the ones listed here.
pub mod block_type {
    // Named block kinds, 0x01..=0x0A.
    pub const CODE: u8 = 0x01;
    pub const CONVERSATION: u8 = 0x02;
    pub const FILE_TREE: u8 = 0x03;
    pub const TOOL_RESULT: u8 = 0x04;
    pub const DOCUMENT: u8 = 0x05;
    pub const STRUCTURED_DATA: u8 = 0x06;
    pub const DIFF: u8 = 0x07;
    pub const ANNOTATION: u8 = 0x08;
    pub const EMBEDDING_REF: u8 = 0x09;
    pub const IMAGE: u8 = 0x0A;

    // Values at the top of the u8 range.

    /// Extension block type. NOTE(review): its payload layout is not defined in this file.
    pub const EXTENSION: u8 = 0xFE;

    /// End marker: `BlockFrame::read_from` returns `Ok(None)` when the type varint
    /// decodes to this value.
    pub const END: u8 = 0xFF;
}
59
60#[derive(Clone, Debug, PartialEq, Eq)]
71pub struct BlockFrame {
72 pub block_type: u8,
74
75 pub flags: BlockFlags,
77
78 pub body: Vec<u8>,
80}
81
/// Size of the scratch buffer handed to `encode_varint`: the longest encoding a `u64`
/// can need.
///
/// Written as ceil(64 / 7) = 10, which assumes 7 payload bits per byte
/// (LEB128-style). TODO confirm against `crate::varint`.
const MAX_VARINT_LEN: usize = (u64::BITS as usize + 6) / 7;
84
85impl BlockFrame {
86 pub fn write_to(&self, w: &mut impl std::io::Write) -> Result<usize, WireError> {
98 let mut bytes_written = 0;
99 let mut varint_buf = [0u8; MAX_VARINT_LEN];
100
101 let n = encode_varint(u64::from(self.block_type), &mut varint_buf);
103 w.write_all(&varint_buf[..n])?;
104 bytes_written += n;
105
106 w.write_all(&[self.flags.raw()])?;
108 bytes_written += 1;
109
110 let n = encode_varint(self.body.len() as u64, &mut varint_buf);
112 w.write_all(&varint_buf[..n])?;
113 bytes_written += n;
114
115 w.write_all(&self.body)?;
117 bytes_written += self.body.len();
118
119 Ok(bytes_written)
120 }
121
122 pub fn read_from(buf: &[u8]) -> Result<Option<(Self, usize)>, WireError> {
135 let mut cursor = 0;
136
137 let (block_type_raw, n) = decode_varint(
139 buf.get(cursor..)
140 .ok_or(WireError::UnexpectedEof { offset: cursor })?,
141 )?;
142 cursor += n;
143
144 if block_type_raw == u64::from(block_type::END) {
148 return Ok(None);
149 }
150
151 let block_type = u8::try_from(block_type_raw).map_err(|_| {
152 WireError::InvalidBlockType {
153 raw: block_type_raw,
154 }
155 })?;
156
157 let flags_byte = *buf
159 .get(cursor)
160 .ok_or(WireError::UnexpectedEof { offset: cursor })?;
161 cursor += 1;
162 let flags = BlockFlags::from_raw(flags_byte);
163
164 let (content_len, n) = decode_varint(
166 buf.get(cursor..)
167 .ok_or(WireError::UnexpectedEof { offset: cursor })?,
168 )?;
169 cursor += n;
170
171 let content_len_usize = usize::try_from(content_len).map_err(|_| {
173 WireError::UnexpectedEof {
174 offset: cursor, }
176 })?;
177
178 let body_end = match cursor.checked_add(content_len_usize) {
180 Some(end) => end,
181 None => {
182 return Err(WireError::UnexpectedEof {
183 offset: buf.len(),
184 })
185 }
186 };
187 if buf.len() < body_end {
188 return Err(WireError::UnexpectedEof { offset: buf.len() });
189 }
190 let body = buf[cursor..body_end].to_vec();
191 cursor = body_end;
192
193 Ok(Some((
194 Self {
195 block_type,
196 flags,
197 body,
198 },
199 cursor,
200 )))
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialises `frame` into a fresh buffer (writing to a `Vec<u8>` cannot fail).
    fn write_frame(frame: &BlockFrame) -> Vec<u8> {
        let mut buf = Vec::new();
        frame.write_to(&mut buf).unwrap();
        buf
    }

    /// Encodes a single varint into a fresh buffer.
    fn varint(value: u64) -> Vec<u8> {
        let mut scratch = [0u8; MAX_VARINT_LEN];
        let n = encode_varint(value, &mut scratch);
        scratch[..n].to_vec()
    }

    #[test]
    fn roundtrip_code_block() {
        let frame = BlockFrame {
            block_type: block_type::CODE,
            flags: BlockFlags::NONE,
            body: b"fn main() {}".to_vec(),
        };
        let bytes = write_frame(&frame);
        let (parsed, consumed) = BlockFrame::read_from(&bytes).unwrap().unwrap();
        assert_eq!(parsed, frame);
        assert_eq!(consumed, bytes.len());
    }

    #[test]
    fn write_to_reports_bytes_written() {
        let frame = BlockFrame {
            block_type: block_type::DIFF,
            flags: BlockFlags::NONE,
            body: vec![7; 300],
        };
        let mut buf = Vec::new();
        let written = frame.write_to(&mut buf).unwrap();
        assert_eq!(written, buf.len());
    }

    #[test]
    fn roundtrip_with_flags() {
        let frame = BlockFrame {
            block_type: block_type::TOOL_RESULT,
            flags: BlockFlags::from_raw(
                BlockFlags::HAS_SUMMARY.raw() | BlockFlags::COMPRESSED.raw(),
            ),
            body: vec![0xDE, 0xAD, 0xBE, 0xEF],
        };
        let bytes = write_frame(&frame);
        let (parsed, _) = BlockFrame::read_from(&bytes).unwrap().unwrap();
        assert!(parsed.flags.has_summary());
        assert!(parsed.flags.is_compressed());
        assert!(!parsed.flags.is_reference());
        assert_eq!(parsed.body, vec![0xDE, 0xAD, 0xBE, 0xEF]);
    }

    #[test]
    fn unknown_flag_bits_roundtrip() {
        let frame = BlockFrame {
            block_type: block_type::DOCUMENT,
            flags: BlockFlags::from_raw(0xFF),
            body: vec![1],
        };
        let bytes = write_frame(&frame);
        let (parsed, _) = BlockFrame::read_from(&bytes).unwrap().unwrap();
        assert_eq!(parsed.flags.raw(), 0xFF);
        assert!(parsed.flags.is_reference());
    }

    #[test]
    fn roundtrip_empty_body() {
        let frame = BlockFrame {
            block_type: block_type::ANNOTATION,
            flags: BlockFlags::NONE,
            body: vec![],
        };
        let bytes = write_frame(&frame);
        let (parsed, consumed) = BlockFrame::read_from(&bytes).unwrap().unwrap();
        assert_eq!(parsed, frame);
        assert_eq!(consumed, bytes.len());
    }

    #[test]
    fn roundtrip_large_body() {
        let frame = BlockFrame {
            block_type: block_type::CODE,
            flags: BlockFlags::NONE,
            body: vec![0xAB; 10_000],
        };
        let bytes = write_frame(&frame);
        let (parsed, consumed) = BlockFrame::read_from(&bytes).unwrap().unwrap();
        assert_eq!(parsed.body.len(), 10_000);
        assert_eq!(consumed, bytes.len());
    }

    #[test]
    fn end_block_returns_none() {
        let buf = varint(u64::from(block_type::END));
        let result = BlockFrame::read_from(&buf).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn empty_input_is_an_error() {
        assert!(BlockFrame::read_from(&[]).is_err());
    }

    #[test]
    fn missing_flags_byte_is_eof_at_flags_offset() {
        let buf = varint(u64::from(block_type::CODE));
        let result = BlockFrame::read_from(&buf);
        assert!(matches!(
            result,
            Err(WireError::UnexpectedEof { offset }) if offset == buf.len()
        ));
    }

    #[test]
    fn block_type_wider_than_u8_is_rejected() {
        let mut buf = varint(256);
        buf.push(BlockFlags::NONE.raw());
        buf.extend(varint(0));
        let result = BlockFrame::read_from(&buf);
        assert!(matches!(
            result,
            Err(WireError::InvalidBlockType { raw: 256 })
        ));
    }

    #[test]
    fn huge_declared_length_is_eof_not_panic() {
        let mut buf = varint(u64::from(block_type::CODE));
        buf.push(BlockFlags::NONE.raw());
        buf.extend(varint(u64::MAX));
        let result = BlockFrame::read_from(&buf);
        assert!(matches!(
            result,
            Err(WireError::UnexpectedEof { offset }) if offset == buf.len()
        ));
    }

    #[test]
    fn read_truncated_body() {
        let frame = BlockFrame {
            block_type: block_type::CODE,
            flags: BlockFlags::NONE,
            body: vec![0xFF; 100],
        };
        let full_bytes = write_frame(&frame);

        // Drop the last 95 of the 100 body bytes; the header stays intact.
        let truncated = &full_bytes[..full_bytes.len() - 95];
        let result = BlockFrame::read_from(truncated);
        assert!(matches!(result, Err(WireError::UnexpectedEof { .. })));
    }

    #[test]
    fn multiple_frames_sequential() {
        let frame1 = BlockFrame {
            block_type: block_type::CODE,
            flags: BlockFlags::NONE,
            body: b"first".to_vec(),
        };
        let frame2 = BlockFrame {
            block_type: block_type::CONVERSATION,
            flags: BlockFlags::NONE,
            body: b"second".to_vec(),
        };

        let mut buf = Vec::new();
        frame1.write_to(&mut buf).unwrap();
        frame2.write_to(&mut buf).unwrap();

        let (parsed1, consumed1) = BlockFrame::read_from(&buf).unwrap().unwrap();
        assert_eq!(parsed1, frame1);

        let (parsed2, consumed2) = BlockFrame::read_from(&buf[consumed1..]).unwrap().unwrap();
        assert_eq!(parsed2, frame2);
        assert_eq!(consumed1 + consumed2, buf.len());
    }

    #[test]
    fn all_block_types_roundtrip() {
        let types = [
            block_type::CODE,
            block_type::CONVERSATION,
            block_type::FILE_TREE,
            block_type::TOOL_RESULT,
            block_type::DOCUMENT,
            block_type::STRUCTURED_DATA,
            block_type::DIFF,
            block_type::ANNOTATION,
            block_type::EMBEDDING_REF,
            block_type::IMAGE,
            block_type::EXTENSION,
        ];
        for &bt in &types {
            let frame = BlockFrame {
                block_type: bt,
                flags: BlockFlags::NONE,
                body: vec![bt],
            };
            let bytes = write_frame(&frame);
            let (parsed, _) = BlockFrame::read_from(&bytes).unwrap().unwrap();
            assert_eq!(parsed.block_type, bt, "failed for block type {bt:#04X}");
            assert_eq!(parsed.body, vec![bt]);
        }
    }
}