1use super::itc::{Event, Id, Stamp};
28use std::error::Error;
29use std::fmt;
30
/// Version byte written as the first byte of every compact encoding.
const FORMAT_VERSION: u8 = 1;
/// High bit of a varint byte; when set, at least one more byte follows.
const VARINT_CONTINUATION_BIT: u8 = 1 << 7;
/// Low seven bits of a varint byte, i.e. everything except the continuation bit.
const VARINT_PAYLOAD_MASK: u8 = !VARINT_CONTINUATION_BIT;
34
/// Failures reported while decoding the compact stamp encoding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CodecError {
    /// The input slice contained no bytes.
    EmptyInput,
    /// The leading version byte (carried in the variant) is not the supported one.
    UnsupportedVersion(u8),
    /// The input, or one of its bit streams, ended before decoding finished.
    UnexpectedEof,
    /// A varint does not fit in the integer type it is decoded into.
    VarintOverflow,
    /// A length prefix is inconsistent with the data it describes.
    InvalidLength,
    /// Bytes were left over after decoding.
    TrailingBytes,
    /// Bits were left over in a bit stream after decoding.
    TrailingBits,
}

impl fmt::Display for CodecError {
    /// Writes a short, lowercase, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only one variant carries data; every other message is a fixed string.
        let message = match self {
            Self::UnsupportedVersion(version) => {
                return write!(f, "unsupported compact ITC version: {version}");
            }
            Self::EmptyInput => "input is empty",
            Self::UnexpectedEof => "unexpected end of input",
            Self::VarintOverflow => "varint overflow",
            Self::InvalidLength => "invalid length prefix",
            Self::TrailingBytes => "trailing bytes after decode",
            Self::TrailingBits => "trailing bits after decode",
        };
        f.write_str(message)
    }
}

impl Error for CodecError {}
71
72impl Stamp {
73 #[must_use]
75 pub fn serialize_compact(&self) -> Vec<u8> {
76 let mut id_bits = BitWriter::default();
77 encode_id(&self.id, &mut id_bits);
78 let (id_bytes, id_bit_len) = id_bits.into_parts();
79
80 let mut event_bits = BitWriter::default();
81 let mut event_values = Vec::new();
82 encode_event(&self.event, &mut event_bits, &mut event_values);
83 let (event_bit_bytes, event_bit_len) = event_bits.into_parts();
84
85 let mut out = Vec::with_capacity(
86 1 + id_bytes.len() + event_bit_bytes.len() + event_values.len() + 12,
87 );
88 out.push(FORMAT_VERSION);
89 encode_usize_varint(id_bit_len, &mut out);
90 out.extend_from_slice(&id_bytes);
91 encode_usize_varint(event_bit_len, &mut out);
92 out.extend_from_slice(&event_bit_bytes);
93 encode_usize_varint(event_values.len(), &mut out);
94 out.extend_from_slice(&event_values);
95 out
96 }
97
98 pub fn deserialize_compact(input: &[u8]) -> Result<Self, CodecError> {
105 if input.is_empty() {
106 return Err(CodecError::EmptyInput);
107 }
108
109 let version = input[0];
110 if version != FORMAT_VERSION {
111 return Err(CodecError::UnsupportedVersion(version));
112 }
113
114 let mut cursor = 1usize;
115
116 let id_bit_len = decode_usize_varint(input, &mut cursor)?;
117 let id_byte_len = bytes_for_bits(id_bit_len);
118 let id_bytes = take_slice(input, &mut cursor, id_byte_len)?;
119
120 let event_bit_len = decode_usize_varint(input, &mut cursor)?;
121 let event_bit_byte_len = bytes_for_bits(event_bit_len);
122 let event_bits = take_slice(input, &mut cursor, event_bit_byte_len)?;
123
124 let event_values_len = decode_usize_varint(input, &mut cursor)?;
125 let event_values = take_slice(input, &mut cursor, event_values_len)?;
126
127 if cursor != input.len() {
128 return Err(CodecError::TrailingBytes);
129 }
130
131 let mut id_reader = BitReader::new(id_bytes, id_bit_len)?;
132 let id = decode_id(&mut id_reader)?;
133 if !id_reader.is_exhausted() {
134 return Err(CodecError::TrailingBits);
135 }
136
137 let mut event_reader = BitReader::new(event_bits, event_bit_len)?;
138 let mut event_cursor = 0usize;
139 let event = decode_event(&mut event_reader, event_values, &mut event_cursor)?;
140 if !event_reader.is_exhausted() {
141 return Err(CodecError::TrailingBits);
142 }
143 if event_cursor != event_values.len() {
144 return Err(CodecError::TrailingBytes);
145 }
146
147 Ok(Self::new(id, event).normalize())
148 }
149}
150
/// Append-only bit buffer that packs bits most-significant-bit first.
#[derive(Default)]
struct BitWriter {
    /// Packed bits; unused low bits of the final byte are zero.
    bytes: Vec<u8>,
    /// Number of bits written so far.
    bit_len: usize,
}

impl BitWriter {
    /// Appends a single bit, growing the byte buffer when a new byte is started.
    fn push_bit(&mut self, bit: bool) {
        let (byte_index, bit_in_byte) = (self.bit_len / 8, self.bit_len % 8);

        if self.bytes.len() == byte_index {
            self.bytes.push(0);
        }

        if bit {
            // Bit 0 of each byte is its most significant bit.
            self.bytes[byte_index] |= 0x80u8 >> bit_in_byte;
        }

        self.bit_len += 1;
    }

    /// Consumes the writer, returning the packed bytes and the number of bits written.
    fn into_parts(self) -> (Vec<u8>, usize) {
        let Self { bytes, bit_len } = self;
        (bytes, bit_len)
    }
}
177
178struct BitReader<'a> {
179 bytes: &'a [u8],
180 bit_len: usize,
181 cursor: usize,
182}
183
184impl<'a> BitReader<'a> {
185 fn new(bytes: &'a [u8], bit_len: usize) -> Result<Self, CodecError> {
186 let total_bits = bytes
187 .len()
188 .checked_mul(8)
189 .ok_or(CodecError::InvalidLength)?;
190 if bit_len > total_bits {
191 return Err(CodecError::InvalidLength);
192 }
193
194 Ok(Self {
195 bytes,
196 bit_len,
197 cursor: 0,
198 })
199 }
200
201 fn read_bit(&mut self) -> Result<bool, CodecError> {
202 if self.cursor >= self.bit_len {
203 return Err(CodecError::UnexpectedEof);
204 }
205
206 let byte_index = self.cursor / 8;
207 let bit_offset = 7 - (self.cursor % 8);
208 let bit = ((self.bytes[byte_index] >> bit_offset) & 1u8) == 1u8;
209 self.cursor += 1;
210 Ok(bit)
211 }
212
213 const fn is_exhausted(&self) -> bool {
214 self.cursor == self.bit_len
215 }
216}
217
/// Returns the number of whole bytes needed to hold `bit_len` bits.
const fn bytes_for_bits(bit_len: usize) -> usize {
    let full_bytes = bit_len / 8;
    if bit_len % 8 == 0 {
        full_bytes
    } else {
        // A partially filled byte still occupies a whole byte.
        full_bytes + 1
    }
}
221
222fn take_slice<'a>(input: &'a [u8], cursor: &mut usize, len: usize) -> Result<&'a [u8], CodecError> {
223 let end = cursor.checked_add(len).ok_or(CodecError::InvalidLength)?;
224 if end > input.len() {
225 return Err(CodecError::UnexpectedEof);
226 }
227
228 let slice = &input[*cursor..end];
229 *cursor = end;
230 Ok(slice)
231}
232
233fn encode_id(id: &Id, out: &mut BitWriter) {
234 match id {
235 Id::Zero => {
236 out.push_bit(false);
237 out.push_bit(false);
238 }
239 Id::One => {
240 out.push_bit(false);
241 out.push_bit(true);
242 }
243 Id::Branch(left, right) => {
244 out.push_bit(true);
245 encode_id(left, out);
246 encode_id(right, out);
247 }
248 }
249}
250
251fn decode_id(bits: &mut BitReader<'_>) -> Result<Id, CodecError> {
252 let is_branch = bits.read_bit()?;
253 if !is_branch {
254 let leaf_is_one = bits.read_bit()?;
255 return Ok(if leaf_is_one { Id::one() } else { Id::zero() });
256 }
257
258 let left = decode_id(bits)?;
259 let right = decode_id(bits)?;
260 Ok(Id::branch(left, right))
261}
262
263fn encode_event(event: &Event, bit_out: &mut BitWriter, value_out: &mut Vec<u8>) {
264 match event {
265 Event::Leaf(value) => {
266 bit_out.push_bit(false);
267 encode_u32_varint(*value, value_out);
268 }
269 Event::Branch(base, left, right) => {
270 bit_out.push_bit(true);
271 encode_u32_varint(*base, value_out);
272 encode_event(left, bit_out, value_out);
273 encode_event(right, bit_out, value_out);
274 }
275 }
276}
277
278fn decode_event(
279 bits: &mut BitReader<'_>,
280 values: &[u8],
281 value_cursor: &mut usize,
282) -> Result<Event, CodecError> {
283 let is_branch = bits.read_bit()?;
284 let value = decode_u32_varint(values, value_cursor)?;
285
286 if !is_branch {
287 return Ok(Event::leaf(value));
288 }
289
290 let left = decode_event(bits, values, value_cursor)?;
291 let right = decode_event(bits, values, value_cursor)?;
292 Ok(Event::branch(value, left, right))
293}
294
295fn encode_usize_varint(mut value: usize, out: &mut Vec<u8>) {
296 loop {
297 let low_bits = value & usize::from(VARINT_PAYLOAD_MASK);
298 let [mut byte, ..] = low_bits.to_le_bytes();
299
300 value >>= 7;
301 if value != 0 {
302 byte |= VARINT_CONTINUATION_BIT;
303 out.push(byte);
304 continue;
305 }
306
307 out.push(byte);
308 break;
309 }
310}
311
312fn decode_usize_varint(input: &[u8], cursor: &mut usize) -> Result<usize, CodecError> {
313 let mut value = 0usize;
314 let mut shift = 0u32;
315
316 loop {
317 if *cursor >= input.len() {
318 return Err(CodecError::UnexpectedEof);
319 }
320
321 let byte = input[*cursor];
322 *cursor += 1;
323
324 let payload = usize::from(byte & VARINT_PAYLOAD_MASK);
325 let shifted = payload
326 .checked_shl(shift)
327 .ok_or(CodecError::VarintOverflow)?;
328 value = value
329 .checked_add(shifted)
330 .ok_or(CodecError::VarintOverflow)?;
331
332 if (byte & VARINT_CONTINUATION_BIT) == 0 {
333 return Ok(value);
334 }
335
336 shift = shift.checked_add(7).ok_or(CodecError::VarintOverflow)?;
337 if shift >= usize::BITS {
338 return Err(CodecError::VarintOverflow);
339 }
340 }
341}
342
343fn encode_u32_varint(mut value: u32, out: &mut Vec<u8>) {
344 loop {
345 let low_bits = value & u32::from(VARINT_PAYLOAD_MASK);
346 let [mut byte, ..] = low_bits.to_le_bytes();
347
348 value >>= 7;
349 if value != 0 {
350 byte |= VARINT_CONTINUATION_BIT;
351 out.push(byte);
352 continue;
353 }
354
355 out.push(byte);
356 break;
357 }
358}
359
360fn decode_u32_varint(input: &[u8], cursor: &mut usize) -> Result<u32, CodecError> {
361 let mut value = 0u32;
362 let mut shift = 0u32;
363
364 loop {
365 if *cursor >= input.len() {
366 return Err(CodecError::UnexpectedEof);
367 }
368
369 let byte = input[*cursor];
370 *cursor += 1;
371
372 let payload = u32::from(byte & VARINT_PAYLOAD_MASK);
373 let shifted = payload
374 .checked_shl(shift)
375 .ok_or(CodecError::VarintOverflow)?;
376 value = value
377 .checked_add(shifted)
378 .ok_or(CodecError::VarintOverflow)?;
379
380 if (byte & VARINT_CONTINUATION_BIT) == 0 {
381 return Ok(value);
382 }
383
384 shift = shift.checked_add(7).ok_or(CodecError::VarintOverflow)?;
385 if shift >= u32::BITS {
386 return Err(CodecError::VarintOverflow);
387 }
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Serializes `stamp`, decodes the bytes again and checks nothing changed.
    fn assert_roundtrip(stamp: Stamp) {
        let bytes = stamp.serialize_compact();
        let decoded = Stamp::deserialize_compact(&bytes);
        assert_eq!(decoded, Ok(stamp));
    }

    #[test]
    fn compact_roundtrip_seed_stamp() {
        assert_roundtrip(Stamp::seed());
    }

    #[test]
    fn compact_roundtrip_complex_stamp() {
        assert_roundtrip(sample_eight_agent_stamp());
    }

    #[test]
    fn compact_single_agent_size_stays_small() {
        let stamp = Stamp::seed();
        let bytes = stamp.serialize_compact();
        assert!(
            bytes.len() <= 20,
            "single-agent compact stamp too large: {} bytes",
            bytes.len()
        );
    }

    #[test]
    fn compact_eight_agent_size_stays_under_target() {
        let stamp = sample_eight_agent_stamp();
        let bytes = stamp.serialize_compact();
        assert!(
            bytes.len() <= 50,
            "8-agent compact stamp too large: {} bytes",
            bytes.len()
        );
    }

    #[test]
    fn rejects_empty_input() {
        let err = Stamp::deserialize_compact(&[]);
        assert_eq!(err, Err(CodecError::EmptyInput));
    }

    #[test]
    fn rejects_unknown_version() {
        let err = Stamp::deserialize_compact(&[99]);
        assert_eq!(err, Err(CodecError::UnsupportedVersion(99)));
    }

    #[test]
    fn rejects_version_byte_without_sections() {
        // A valid version byte with nothing after it: the first length prefix is missing.
        let err = Stamp::deserialize_compact(&[FORMAT_VERSION]);
        assert_eq!(err, Err(CodecError::UnexpectedEof));
    }

    #[test]
    fn rejects_truncated_input() {
        // The encoding always ends with at least one counter byte, so dropping
        // the final byte leaves the last section shorter than its prefix says.
        let mut bytes = Stamp::seed().serialize_compact();
        bytes.pop();
        let err = Stamp::deserialize_compact(&bytes);
        assert_eq!(err, Err(CodecError::UnexpectedEof));
    }

    #[test]
    fn rejects_trailing_bytes() {
        let mut bytes = Stamp::seed().serialize_compact();
        bytes.push(0);
        let err = Stamp::deserialize_compact(&bytes);
        assert_eq!(err, Err(CodecError::TrailingBytes));
    }

    proptest! {
        #[test]
        fn random_stamps_roundtrip(stamp in arb_stamp()) {
            let bytes = stamp.serialize_compact();
            let decoded = Stamp::deserialize_compact(&bytes);
            prop_assert_eq!(decoded, Ok(stamp));
        }
    }

    /// Builds a normalized stamp whose id and event trees are both three levels deep.
    fn sample_eight_agent_stamp() -> Stamp {
        let id = Id::branch(
            Id::branch(
                Id::branch(Id::one(), Id::zero()),
                Id::branch(Id::zero(), Id::one()),
            ),
            Id::branch(
                Id::branch(Id::one(), Id::zero()),
                Id::branch(Id::zero(), Id::one()),
            ),
        );

        let event = Event::branch(
            1,
            Event::branch(
                0,
                Event::branch(0, Event::leaf(3), Event::leaf(1)),
                Event::branch(1, Event::leaf(2), Event::leaf(0)),
            ),
            Event::branch(
                0,
                Event::branch(2, Event::leaf(1), Event::leaf(0)),
                Event::branch(0, Event::leaf(4), Event::leaf(2)),
            ),
        );

        Stamp::new(id, event).normalize()
    }

    /// Strategy producing normalized stamps from arbitrary id and event trees.
    fn arb_stamp() -> impl Strategy<Value = Stamp> {
        (arb_id(), arb_event()).prop_map(|(id, event)| Stamp::new(id, event).normalize())
    }

    /// Strategy producing id trees built from zero/one leaves.
    fn arb_id() -> impl Strategy<Value = Id> {
        let leaf = prop_oneof![Just(Id::zero()), Just(Id::one())];
        leaf.prop_recursive(4, 64, 2, |inner| {
            (inner.clone(), inner).prop_map(|(left, right)| Id::branch(left, right))
        })
    }

    /// Strategy producing event trees with small leaf and base counters.
    fn arb_event() -> impl Strategy<Value = Event> {
        let leaf = (0u32..=25).prop_map(Event::leaf);
        leaf.prop_recursive(4, 128, 2, |inner| {
            (0u32..=10, inner.clone(), inner)
                .prop_map(|(base, left, right)| Event::branch(base, left, right))
        })
    }
}