1use crate::error::DecodeError;
6use crate::limits::MAX_VARINT_BYTES;
7use crate::model::Id;
8
9#[derive(Debug, Clone)]
18pub struct Reader<'a> {
19 data: &'a [u8],
20 pos: usize,
21}
22
23impl<'a> Reader<'a> {
24 pub fn new(data: &'a [u8]) -> Self {
26 Self { data, pos: 0 }
27 }
28
29 pub fn position(&self) -> usize {
31 self.pos
32 }
33
34 pub fn remaining(&self) -> &'a [u8] {
36 &self.data[self.pos..]
37 }
38
39 pub fn remaining_len(&self) -> usize {
41 self.data.len() - self.pos
42 }
43
44 pub fn is_empty(&self) -> bool {
46 self.pos >= self.data.len()
47 }
48
49 #[inline]
51 pub fn read_byte(&mut self, context: &'static str) -> Result<u8, DecodeError> {
52 if self.pos >= self.data.len() {
53 return Err(DecodeError::UnexpectedEof { context });
54 }
55 let byte = self.data[self.pos];
56 self.pos += 1;
57 Ok(byte)
58 }
59
60 #[inline]
62 pub fn read_bytes(&mut self, n: usize, context: &'static str) -> Result<&'a [u8], DecodeError> {
63 if self.pos + n > self.data.len() {
64 return Err(DecodeError::UnexpectedEof { context });
65 }
66 let bytes = &self.data[self.pos..self.pos + n];
67 self.pos += n;
68 Ok(bytes)
69 }
70
71 #[inline]
73 pub fn read_id(&mut self, context: &'static str) -> Result<Id, DecodeError> {
74 let bytes = self.read_bytes(16, context)?;
75 Ok(bytes.try_into().unwrap())
77 }
78
79 #[inline]
81 pub fn read_varint(&mut self, context: &'static str) -> Result<u64, DecodeError> {
82 let mut result: u64 = 0;
83 let mut shift = 0;
84
85 for i in 0..MAX_VARINT_BYTES {
86 let byte = self.read_byte(context)?;
87 let value = (byte & 0x7F) as u64;
88
89 if shift >= 64 || (shift == 63 && value > 1) {
91 return Err(DecodeError::VarintOverflow);
92 }
93
94 result |= value << shift;
95
96 if byte & 0x80 == 0 {
97 return Ok(result);
98 }
99 shift += 7;
100
101 if i == MAX_VARINT_BYTES - 1 {
102 return Err(DecodeError::VarintTooLong);
103 }
104 }
105
106 Err(DecodeError::VarintTooLong)
107 }
108
109 pub fn read_signed_varint(&mut self, context: &'static str) -> Result<i64, DecodeError> {
111 let unsigned = self.read_varint(context)?;
112 Ok(zigzag_decode(unsigned))
113 }
114
115 #[inline]
117 pub fn read_string(
118 &mut self,
119 max_len: usize,
120 field: &'static str,
121 ) -> Result<String, DecodeError> {
122 let len = self.read_varint(field)? as usize;
123 if len > max_len {
124 return Err(DecodeError::LengthExceedsLimit {
125 field,
126 len,
127 max: max_len,
128 });
129 }
130 let bytes = self.read_bytes(len, field)?;
131 std::str::from_utf8(bytes)
133 .map(|s| s.to_string())
134 .map_err(|_| DecodeError::InvalidUtf8 { field })
135 }
136
137 #[inline]
139 pub fn read_str(
140 &mut self,
141 max_len: usize,
142 field: &'static str,
143 ) -> Result<&'a str, DecodeError> {
144 let len = self.read_varint(field)? as usize;
145 if len > max_len {
146 return Err(DecodeError::LengthExceedsLimit {
147 field,
148 len,
149 max: max_len,
150 });
151 }
152 let bytes = self.read_bytes(len, field)?;
153 std::str::from_utf8(bytes).map_err(|_| DecodeError::InvalidUtf8 { field })
154 }
155
156 pub fn read_bytes_prefixed(
158 &mut self,
159 max_len: usize,
160 field: &'static str,
161 ) -> Result<Vec<u8>, DecodeError> {
162 let len = self.read_varint(field)? as usize;
163 if len > max_len {
164 return Err(DecodeError::LengthExceedsLimit {
165 field,
166 len,
167 max: max_len,
168 });
169 }
170 let bytes = self.read_bytes(len, field)?;
171 Ok(bytes.to_vec())
172 }
173
174 #[inline]
176 pub fn read_f64(&mut self, context: &'static str) -> Result<f64, DecodeError> {
177 let bytes = self.read_bytes(8, context)?;
178 let value = f64::from_le_bytes(bytes.try_into().unwrap());
180 if value.is_nan() {
181 return Err(DecodeError::FloatIsNan);
182 }
183 Ok(value)
184 }
185
186 #[inline]
188 pub fn read_f64_unchecked(&mut self, context: &'static str) -> Result<f64, DecodeError> {
189 let bytes = self.read_bytes(8, context)?;
190 Ok(f64::from_le_bytes(bytes.try_into().unwrap()))
192 }
193
194 pub fn read_id_vec(
196 &mut self,
197 max_len: usize,
198 field: &'static str,
199 ) -> Result<Vec<Id>, DecodeError> {
200 let count = self.read_varint(field)? as usize;
201 if count > max_len {
202 return Err(DecodeError::LengthExceedsLimit {
203 field,
204 len: count,
205 max: max_len,
206 });
207 }
208 let mut ids = Vec::with_capacity(count);
209 for _ in 0..count {
210 ids.push(self.read_id(field)?);
211 }
212 Ok(ids)
213 }
214}
215
216#[derive(Debug, Clone, Default)]
222pub struct Writer {
223 buf: Vec<u8>,
224}
225
226impl Writer {
227 pub fn new() -> Self {
229 Self { buf: Vec::new() }
230 }
231
232 pub fn with_capacity(capacity: usize) -> Self {
234 Self {
235 buf: Vec::with_capacity(capacity),
236 }
237 }
238
239 pub fn into_bytes(self) -> Vec<u8> {
241 self.buf
242 }
243
244 pub fn as_bytes(&self) -> &[u8] {
246 &self.buf
247 }
248
249 pub fn len(&self) -> usize {
251 self.buf.len()
252 }
253
254 pub fn is_empty(&self) -> bool {
256 self.buf.is_empty()
257 }
258
259 #[inline]
261 pub fn write_byte(&mut self, byte: u8) {
262 self.buf.push(byte);
263 }
264
265 #[inline]
267 pub fn write_bytes(&mut self, bytes: &[u8]) {
268 self.buf.extend_from_slice(bytes);
269 }
270
271 #[inline]
273 pub fn write_id(&mut self, id: &Id) {
274 self.buf.extend_from_slice(id);
275 }
276
277 #[inline]
279 pub fn write_varint(&mut self, mut value: u64) {
280 let mut buf = [0u8; 10]; let mut len = 0;
283 loop {
284 let mut byte = (value & 0x7F) as u8;
285 value >>= 7;
286 if value != 0 {
287 byte |= 0x80;
288 }
289 buf[len] = byte;
290 len += 1;
291 if value == 0 {
292 break;
293 }
294 }
295 self.buf.extend_from_slice(&buf[..len]);
296 }
297
298 pub fn write_signed_varint(&mut self, value: i64) {
300 self.write_varint(zigzag_encode(value));
301 }
302
303 pub fn write_string(&mut self, s: &str) {
305 self.write_varint(s.len() as u64);
306 self.buf.extend_from_slice(s.as_bytes());
307 }
308
309 pub fn write_bytes_prefixed(&mut self, bytes: &[u8]) {
311 self.write_varint(bytes.len() as u64);
312 self.buf.extend_from_slice(bytes);
313 }
314
315 pub fn write_f64(&mut self, value: f64) {
317 self.buf.extend_from_slice(&value.to_le_bytes());
318 }
319
320 pub fn write_id_vec(&mut self, ids: &[Id]) {
322 self.write_varint(ids.len() as u64);
323 for id in ids {
324 self.write_id(id);
325 }
326 }
327}
328
/// Maps a signed integer onto an unsigned one so small magnitudes stay small:
/// `0 → 0`, `-1 → 1`, `1 → 2`, `-2 → 3`, …
///
/// This keeps negative numbers compact when they are subsequently written as varints.
#[inline]
pub fn zigzag_encode(n: i64) -> u64 {
    // The arithmetic shift smears the sign bit into an all-zeros or all-ones mask.
    let sign_mask = (n >> 63) as u64;
    ((n as u64) << 1) ^ sign_mask
}
341
/// Inverse of `zigzag_encode`: maps `0 → 0`, `1 → -1`, `2 → 1`, `3 → -2`, …
#[inline]
pub fn zigzag_decode(n: u64) -> i64 {
    // The low bit holds the sign: negate it into an all-ones mask when set.
    let sign_mask = (n & 1).wrapping_neg();
    ((n >> 1) ^ sign_mask) as i64
}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `v` with a fresh `Writer` and returns the raw wire bytes.
    fn varint_bytes(v: u64) -> Vec<u8> {
        let mut writer = Writer::new();
        writer.write_varint(v);
        writer.into_bytes()
    }

    #[test]
    fn test_zigzag_roundtrip() {
        for v in [0i64, 1, -1, 127, -128, i64::MAX, i64::MIN] {
            assert_eq!(zigzag_decode(zigzag_encode(v)), v);
        }
    }

    #[test]
    fn test_zigzag_values() {
        assert_eq!(zigzag_encode(0), 0);
        assert_eq!(zigzag_encode(-1), 1);
        assert_eq!(zigzag_encode(1), 2);
        assert_eq!(zigzag_encode(-2), 3);
        assert_eq!(zigzag_encode(2), 4);
        // The extremes land at the top of the unsigned range.
        assert_eq!(zigzag_encode(i64::MAX), u64::MAX - 1);
        assert_eq!(zigzag_encode(i64::MIN), u64::MAX);
    }

    #[test]
    fn test_varint_roundtrip() {
        let test_values = [0u64, 1, 127, 128, 255, 256, 16383, 16384, u64::MAX];

        for v in test_values {
            let mut writer = Writer::new();
            writer.write_varint(v);

            let mut reader = Reader::new(writer.as_bytes());
            let decoded = reader.read_varint("test").unwrap();
            assert_eq!(v, decoded, "failed for {}", v);
            assert!(reader.is_empty(), "trailing bytes for {}", v);
        }
    }

    #[test]
    fn test_varint_wire_bytes() {
        // Pin the exact encoding so a bug mirrored in both encoder and decoder
        // can't hide behind a passing round-trip.
        assert_eq!(varint_bytes(0), [0x00]);
        assert_eq!(varint_bytes(127), [0x7F]);
        assert_eq!(varint_bytes(128), [0x80, 0x01]);
        assert_eq!(varint_bytes(300), [0xAC, 0x02]);
        let mut max = vec![0xFFu8; 9];
        max.push(0x01);
        assert_eq!(varint_bytes(u64::MAX), max);
    }

    #[test]
    fn test_varint_overflow() {
        // Nine continuation bytes carry 63 bits; the tenth may only contribute bit 63.
        let mut data = vec![0xFFu8; 9];
        data.push(0x02);
        let mut reader = Reader::new(&data);
        let result = reader.read_varint("test");
        assert!(matches!(result, Err(DecodeError::VarintOverflow)));
    }

    #[test]
    fn test_signed_varint_roundtrip() {
        let test_values = [0i64, 1, -1, 127, -128, i64::MAX, i64::MIN];

        for v in test_values {
            let mut writer = Writer::new();
            writer.write_signed_varint(v);

            let mut reader = Reader::new(writer.as_bytes());
            let decoded = reader.read_signed_varint("test").unwrap();
            assert_eq!(v, decoded, "failed for {}", v);
        }
    }

    #[test]
    fn test_string_roundtrip() {
        let test_strings = ["", "hello", "hello world", "unicode: \u{1F600}"];

        for s in test_strings {
            let mut writer = Writer::new();
            writer.write_string(s);

            let mut reader = Reader::new(writer.as_bytes());
            let decoded = reader.read_string(1000, "test").unwrap();
            assert_eq!(s, decoded);
        }
    }

    #[test]
    fn test_str_borrows_input() {
        let mut writer = Writer::new();
        writer.write_string("borrowed");
        let bytes = writer.into_bytes();

        let mut reader = Reader::new(&bytes);
        let s = reader.read_str(100, "test").unwrap();
        assert_eq!(s, "borrowed");
        assert!(reader.is_empty());
    }

    #[test]
    fn test_invalid_utf8_rejected() {
        let data = [2u8, 0xFF, 0xFE];
        let mut reader = Reader::new(&data);
        let result = reader.read_string(10, "name");
        assert!(matches!(result, Err(DecodeError::InvalidUtf8 { field: "name" })));
    }

    #[test]
    fn test_bytes_prefixed_roundtrip() {
        let mut writer = Writer::new();
        writer.write_bytes_prefixed(&[1, 2, 3]);

        let mut reader = Reader::new(writer.as_bytes());
        let decoded = reader.read_bytes_prefixed(10, "blob").unwrap();
        assert_eq!(decoded, vec![1u8, 2, 3]);
        assert!(reader.is_empty());
    }

    #[test]
    fn test_id_roundtrip() {
        let id = [1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];

        let mut writer = Writer::new();
        writer.write_id(&id);

        let mut reader = Reader::new(writer.as_bytes());
        let decoded = reader.read_id("test").unwrap();
        assert_eq!(id, decoded);
    }

    #[test]
    fn test_id_vec_roundtrip() {
        let ids = [[7u8; 16], [9u8; 16]];

        let mut writer = Writer::new();
        writer.write_id_vec(&ids);

        let mut reader = Reader::new(writer.as_bytes());
        let decoded = reader.read_id_vec(10, "ids").unwrap();
        assert_eq!(decoded, ids.to_vec());
        assert!(reader.is_empty());
    }

    #[test]
    fn test_id_vec_too_many() {
        let mut writer = Writer::new();
        writer.write_id_vec(&[[0u8; 16]; 3]);

        let mut reader = Reader::new(writer.as_bytes());
        let result = reader.read_id_vec(2, "ids");
        assert!(matches!(
            result,
            Err(DecodeError::LengthExceedsLimit { field: "ids", len: 3, max: 2 })
        ));
    }

    #[test]
    fn test_id_vec_truncated() {
        // The count claims 1000 ids but only one follows.
        let mut writer = Writer::new();
        writer.write_varint(1000);
        writer.write_id(&[0u8; 16]);

        let mut reader = Reader::new(writer.as_bytes());
        let result = reader.read_id_vec(2000, "ids");
        assert!(matches!(result, Err(DecodeError::UnexpectedEof { context: "ids" })));
    }

    #[test]
    fn test_f64_roundtrip() {
        let test_values = [
            0.0,
            1.0,
            -1.0,
            f64::INFINITY,
            f64::NEG_INFINITY,
            123.456,
            f64::MIN_POSITIVE,
        ];

        for v in test_values {
            let mut writer = Writer::new();
            writer.write_f64(v);

            let mut reader = Reader::new(writer.as_bytes());
            let decoded = reader.read_f64("test").unwrap();
            assert_eq!(v, decoded, "failed for {}", v);
        }
    }

    #[test]
    fn test_f64_nan_rejected() {
        let mut writer = Writer::new();
        writer.write_f64(f64::NAN);

        let mut reader = Reader::new(writer.as_bytes());
        let result = reader.read_f64("test");
        assert!(matches!(result, Err(DecodeError::FloatIsNan)));
    }

    #[test]
    fn test_f64_unchecked_allows_nan() {
        let mut writer = Writer::new();
        writer.write_f64(f64::NAN);

        let mut reader = Reader::new(writer.as_bytes());
        assert!(reader.read_f64_unchecked("test").unwrap().is_nan());
    }

    #[test]
    fn test_varint_too_long() {
        // Every byte has the continuation bit set, so the limit is hit first.
        let data = [0x80u8; 11];
        let mut reader = Reader::new(&data);
        let result = reader.read_varint("test");
        assert!(matches!(result, Err(DecodeError::VarintTooLong)));
    }

    #[test]
    fn test_string_too_long() {
        let mut writer = Writer::new();
        writer.write_varint(1000);
        writer.write_bytes(&[0u8; 1000]);

        let mut reader = Reader::new(writer.as_bytes());
        let result = reader.read_string(100, "test");
        assert!(matches!(
            result,
            Err(DecodeError::LengthExceedsLimit { max: 100, .. })
        ));
    }

    #[test]
    fn test_unexpected_eof() {
        let data = [0u8; 5];
        let mut reader = Reader::new(&data);
        let result = reader.read_bytes(10, "test");
        assert!(matches!(result, Err(DecodeError::UnexpectedEof { .. })));
    }

    #[test]
    fn test_position_tracking() {
        let data = [1u8, 2, 3, 4];
        let mut reader = Reader::new(&data);
        assert_eq!(reader.read_byte("test").unwrap(), 1);
        assert_eq!(reader.position(), 1);
        assert_eq!(reader.remaining(), &[2u8, 3, 4][..]);
        assert_eq!(reader.remaining_len(), 3);
        assert_eq!(reader.read_bytes(3, "test").unwrap(), &[2u8, 3, 4][..]);
        assert!(reader.is_empty());
        assert!(matches!(
            reader.read_byte("test"),
            Err(DecodeError::UnexpectedEof { .. })
        ));
    }
}