1use thiserror::Error;
10
/// Maximum message nesting depth, expressed as a number of levels.
///
/// Nothing in this part of the file reads the constant.
/// NOTE(review): presumably nested-message decoding elsewhere compares the
/// reader's `depth` counter against it and reports `Error::DepthExceeded` —
/// confirm against the callers.
pub const MAX_NESTING_DEPTH: usize = 100;
16
/// Failures reported while decoding wire-format data.
///
/// `Display` text comes from the `#[error(...)]` attributes; the strings are
/// part of the observable behaviour and should not be edited casually.
#[derive(Debug, Error)]
pub enum Error {
    /// The input ended while a varint still had its continuation bit set.
    #[error("truncated varint")]
    TruncatedVarint,
    /// A varint still had its continuation bit set on its tenth byte.
    #[error("varint exceeds 10 bytes")]
    VarintTooLong,
    /// Fewer than four bytes were left where a fixed 32-bit value was expected.
    #[error("truncated fixed32")]
    TruncatedFixed32,
    /// Fewer than eight bytes were left where a fixed 64-bit value was expected.
    #[error("truncated fixed64")]
    TruncatedFixed64,
    /// A length prefix did not fit in `usize`, overflowed when added to the
    /// read position, or pointed past the end of the input.
    #[error("truncated length-delimited")]
    TruncatedLengthDelim,
    /// A tag carried an invalid field number; `Reader::tag` raises it for
    /// field number 0. The payload is the reader position at the time of the
    /// error, i.e. just past the tag varint.
    #[error("invalid tag: field number 0 at offset {0}")]
    InvalidTag(usize),
    /// A wire-type value outside `0..=5`; the payload is the offending value.
    #[error("unknown wire type {0}")]
    UnknownWireType(u8),
    /// A start-group or end-group wire type was passed to `Reader::skip`.
    #[error("group wire types are not supported")]
    GroupNotSupported,
    /// A string field held bytes that are not valid UTF-8. Converted
    /// automatically from `FromUtf8Error` via `#[from]`.
    #[error("invalid utf-8 in string field: {0}")]
    InvalidUtf8(#[from] std::string::FromUtf8Error),
    /// Not produced by the code visible in this part of the file.
    /// NOTE(review): presumably a nested message's declared length ran past
    /// its enclosing buffer — confirm against the code that constructs it.
    #[error("nested message exceeds buffer")]
    NestedExceedsBuffer,
    /// Not produced by the code visible in this part of the file.
    /// NOTE(review): presumably the read position `pos` ended up beyond the
    /// expected message end `end` — confirm against the code that constructs it.
    #[error("message overran (pos={pos}, end={end})")]
    Overrun { pos: usize, end: usize },
    /// Not produced by the code visible in this part of the file.
    /// NOTE(review): presumably raised when nesting goes deeper than
    /// `MAX_NESTING_DEPTH`, with the payload being the limit or the depth
    /// reached — confirm. The message spells the limit `MaxNestingDepth`,
    /// which differs from the constant's actual name.
    #[error("nesting depth exceeds MaxNestingDepth ({0})")]
    DepthExceeded(usize),
}
44
/// Shorthand for a `std::result::Result` whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
46
47#[derive(Debug, Clone, Copy, PartialEq, Eq)]
48#[repr(u8)]
49pub enum WireType {
50 Varint = 0,
51 Fixed64 = 1,
52 LengthDelimited = 2,
53 StartGroup = 3,
54 EndGroup = 4,
55 Fixed32 = 5,
56}
57
58impl WireType {
59 pub fn from_u8(v: u8) -> Result<Self> {
60 match v {
61 0 => Ok(Self::Varint),
62 1 => Ok(Self::Fixed64),
63 2 => Ok(Self::LengthDelimited),
64 3 => Ok(Self::StartGroup),
65 4 => Ok(Self::EndGroup),
66 5 => Ok(Self::Fixed32),
67 _ => Err(Error::UnknownWireType(v)),
68 }
69 }
70}
71
72#[derive(Debug, Default)]
73pub struct Writer {
74 buf: Vec<u8>,
75}
76
77impl Writer {
78 pub fn new() -> Self {
79 Self::default()
80 }
81
82 pub fn with_capacity(cap: usize) -> Self {
83 Self {
84 buf: Vec::with_capacity(cap),
85 }
86 }
87
88 pub fn finish(self) -> Vec<u8> {
89 self.buf
90 }
91
92 pub fn len(&self) -> usize {
93 self.buf.len()
94 }
95
96 pub fn is_empty(&self) -> bool {
97 self.buf.is_empty()
98 }
99
100 pub fn raw(&mut self, b: &[u8]) {
102 self.buf.extend_from_slice(b);
103 }
104
105 pub fn varint(&mut self, mut v: u64) {
107 while v >= 0x80 {
108 self.buf.push(((v & 0x7f) as u8) | 0x80);
109 v >>= 7;
110 }
111 self.buf.push(v as u8);
112 }
113
114 pub fn varint_i32(&mut self, v: i32) {
117 self.varint(v as i64 as u64);
118 }
119
120 pub fn varint_i64(&mut self, v: i64) {
122 self.varint(v as u64);
123 }
124
125 pub fn zigzag32(&mut self, v: i32) {
127 let u = ((v << 1) ^ (v >> 31)) as u32;
128 self.varint(u as u64);
129 }
130
131 pub fn zigzag64(&mut self, v: i64) {
133 let u = ((v << 1) ^ (v >> 63)) as u64;
134 self.varint(u);
135 }
136
137 pub fn fixed32(&mut self, v: u32) {
139 self.buf.extend_from_slice(&v.to_le_bytes());
140 }
141
142 pub fn fixed64(&mut self, v: u64) {
144 self.buf.extend_from_slice(&v.to_le_bytes());
145 }
146
147 pub fn float(&mut self, v: f32) {
149 self.buf.extend_from_slice(&v.to_le_bytes());
150 }
151
152 pub fn double(&mut self, v: f64) {
154 self.buf.extend_from_slice(&v.to_le_bytes());
155 }
156
157 pub fn string(&mut self, v: &str) {
159 self.bytes(v.as_bytes());
160 }
161
162 pub fn bytes(&mut self, v: &[u8]) {
164 self.varint(v.len() as u64);
165 self.raw(v);
166 }
167
168 pub fn tag(&mut self, field_number: u32, wire_type: WireType) {
173 assert!(
174 (1..=0x1fff_ffff).contains(&field_number),
175 "field number out of range: {field_number}"
176 );
177 self.varint(((field_number as u64) << 3) | (wire_type as u64));
178 }
179}
180
181pub struct Reader<'a> {
182 pub(crate) data: &'a [u8],
183 pub pos: usize,
184 pub(crate) depth: usize,
189}
190
191impl<'a> Reader<'a> {
192 pub fn new(data: &'a [u8]) -> Self {
193 Self {
194 data,
195 pos: 0,
196 depth: 0,
197 }
198 }
199
200 pub fn data(&self) -> &'a [u8] {
201 self.data
202 }
203
204 pub fn eof(&self) -> bool {
205 self.pos >= self.data.len()
206 }
207
208 pub fn remaining(&self) -> usize {
209 self.data.len().saturating_sub(self.pos)
210 }
211
212 pub fn varint(&mut self) -> Result<u64> {
214 let mut result: u64 = 0;
215 let mut shift = 0u32;
216 for i in 0..10 {
217 if self.pos >= self.data.len() {
218 return Err(Error::TruncatedVarint);
219 }
220 let byte = self.data[self.pos];
221 self.pos += 1;
222 result |= ((byte & 0x7f) as u64) << shift;
223 if byte & 0x80 == 0 {
224 return Ok(result);
225 }
226 shift += 7;
227 if i == 9 {
228 return Err(Error::VarintTooLong);
229 }
230 }
231 Err(Error::VarintTooLong)
232 }
233
234 pub fn zigzag32(&mut self) -> Result<i32> {
236 let u = self.varint()? as u32;
237 Ok(((u >> 1) as i32) ^ -((u & 1) as i32))
238 }
239
240 pub fn zigzag64(&mut self) -> Result<i64> {
242 let u = self.varint()?;
243 Ok(((u >> 1) as i64) ^ -((u & 1) as i64))
244 }
245
246 pub fn fixed32(&mut self) -> Result<u32> {
247 if self.pos + 4 > self.data.len() {
248 return Err(Error::TruncatedFixed32);
249 }
250 let v = u32::from_le_bytes(self.data[self.pos..self.pos + 4].try_into().unwrap());
251 self.pos += 4;
252 Ok(v)
253 }
254
255 pub fn fixed64(&mut self) -> Result<u64> {
256 if self.pos + 8 > self.data.len() {
257 return Err(Error::TruncatedFixed64);
258 }
259 let v = u64::from_le_bytes(self.data[self.pos..self.pos + 8].try_into().unwrap());
260 self.pos += 8;
261 Ok(v)
262 }
263
264 pub fn float(&mut self) -> Result<f32> {
265 Ok(f32::from_bits(self.fixed32()?))
266 }
267
268 pub fn double(&mut self) -> Result<f64> {
269 Ok(f64::from_bits(self.fixed64()?))
270 }
271
272 pub fn bytes(&mut self) -> Result<Vec<u8>> {
274 Ok(self.bytes_view()?.to_vec())
275 }
276
277 pub fn bytes_view(&mut self) -> Result<&'a [u8]> {
285 let len = self.read_length()?;
286 let end = self.pos + len;
287 let view = &self.data[self.pos..end];
288 self.pos = end;
289 Ok(view)
290 }
291
292 fn read_length(&mut self) -> Result<usize> {
297 let len = self.varint()?;
298 let len = usize::try_from(len).map_err(|_| Error::TruncatedLengthDelim)?;
299 let end = self
300 .pos
301 .checked_add(len)
302 .ok_or(Error::TruncatedLengthDelim)?;
303 if end > self.data.len() {
304 return Err(Error::TruncatedLengthDelim);
305 }
306 Ok(len)
307 }
308
309 pub fn string(&mut self) -> Result<String> {
311 let bytes = self.bytes_view()?.to_vec();
312 Ok(String::from_utf8(bytes)?)
313 }
314
315 pub fn tag(&mut self) -> Result<(u32, WireType)> {
317 let t = self.varint()?;
318 let wire_type = WireType::from_u8((t & 0x7) as u8)?;
319 let field_number = (t >> 3) as u32;
320 if field_number == 0 {
321 return Err(Error::InvalidTag(self.pos));
322 }
323 Ok((field_number, wire_type))
324 }
325
326 pub fn skip(&mut self, wire_type: WireType) -> Result<()> {
328 match wire_type {
329 WireType::Varint => {
330 self.varint()?;
331 }
332 WireType::Fixed64 => {
333 if self.pos + 8 > self.data.len() {
334 return Err(Error::TruncatedFixed64);
335 }
336 self.pos += 8;
337 }
338 WireType::LengthDelimited => {
339 let len = self.read_length()?;
340 self.pos += len;
341 }
342 WireType::Fixed32 => {
343 if self.pos + 4 > self.data.len() {
344 return Err(Error::TruncatedFixed32);
345 }
346 self.pos += 4;
347 }
348 WireType::StartGroup | WireType::EndGroup => {
349 return Err(Error::GroupNotSupported);
350 }
351 }
352 Ok(())
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten-byte varint encoding of `u64::MAX`, used as a hostile length prefix.
    const MAX_VARINT: [u8; 10] = [0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01];

    /// Tag byte for field 1 with the length-delimited wire type.
    const FIELD_1_LENGTH_DELIMITED: u8 = 0x0a;

    /// Runs `build` against a fresh writer and returns the bytes it produced.
    fn encode(build: impl FnOnce(&mut Writer)) -> Vec<u8> {
        let mut writer = Writer::new();
        build(&mut writer);
        writer.finish()
    }

    /// Returns a field-1 length-delimited tag followed by a `u64::MAX` length prefix.
    fn field_with_max_length_prefix() -> Vec<u8> {
        let mut input = vec![FIELD_1_LENGTH_DELIMITED];
        input.extend_from_slice(&MAX_VARINT);
        input
    }

    /// Encodes `v` as a varint, decodes it again and checks that the whole
    /// encoding was consumed.
    fn round_trip_varint(v: u64) -> u64 {
        let encoded = encode(|w| w.varint(v));
        let mut reader = Reader::new(&encoded);
        let decoded = reader.varint().unwrap();
        assert!(reader.eof());
        decoded
    }

    #[test]
    fn varint_encodes_zero_as_single_byte() {
        assert_eq!(encode(|w| w.varint(0)), vec![0]);
    }

    #[test]
    fn varint_round_trips_small_numbers() {
        let samples = [0u64, 1, 127, 128, 255, 256, 16383, 16384];
        for &v in samples.iter() {
            assert_eq!(round_trip_varint(v), v);
        }
    }

    #[test]
    fn varint_round_trips_up_to_i64_max() {
        let v = i64::MAX as u64;
        assert_eq!(round_trip_varint(v), v);
    }

    #[test]
    fn varint_round_trips_full_uint64_range() {
        let samples = [0u64, 1, 0x80, 0xff, 0xffff, 0xffff_ffff, u64::MAX];
        for &v in samples.iter() {
            assert_eq!(round_trip_varint(v), v);
        }
    }

    #[test]
    fn varint_encodes_150_as_canonical_proto_example() {
        assert_eq!(encode(|w| w.varint(150)), vec![0x96, 0x01]);
    }

    #[test]
    fn zigzag32_matches_proto3_spec() {
        // (signed input, expected unsigned value on the wire)
        let cases: &[(i32, u32)] = &[
            (0, 0),
            (-1, 1),
            (1, 2),
            (-2, 3),
            (2147483647, 4294967294),
            (-2147483648, 4294967295),
        ];
        for &(signed, encoded) in cases {
            let wire = encode(|w| w.zigzag32(signed));

            let mut as_varint = Reader::new(&wire);
            assert_eq!(as_varint.varint().unwrap() as u32, encoded);

            let mut as_zigzag = Reader::new(&wire);
            assert_eq!(as_zigzag.zigzag32().unwrap(), signed);
        }
    }

    #[test]
    fn zigzag64_round_trips_boundary_values() {
        let samples = [0i64, -1, 1, -2, i64::MAX, i64::MIN];
        for &v in samples.iter() {
            let wire = encode(|w| w.zigzag64(v));
            let mut reader = Reader::new(&wire);
            assert_eq!(reader.zigzag64().unwrap(), v);
        }
    }

    #[test]
    fn fixed32_round_trips() {
        let samples = [0u32, 1, 0x7fff_ffff, 0xffff_ffff];
        for &v in samples.iter() {
            let wire = encode(|w| w.fixed32(v));
            let mut reader = Reader::new(&wire);
            assert_eq!(reader.fixed32().unwrap(), v);
        }
    }

    #[test]
    fn fixed64_round_trips_uint64() {
        let samples = [0u64, 1, 0xffff_ffff, u64::MAX];
        for &v in samples.iter() {
            let wire = encode(|w| w.fixed64(v));
            let mut reader = Reader::new(&wire);
            assert_eq!(reader.fixed64().unwrap(), v);
        }
    }

    #[test]
    fn float_and_double_round_trip() {
        let wire = encode(|w| {
            w.float(2.5);
            w.double(std::f64::consts::PI);
        });
        let mut reader = Reader::new(&wire);
        assert!((reader.float().unwrap() - 2.5).abs() < 1e-5);
        assert_eq!(reader.double().unwrap(), std::f64::consts::PI);
    }

    #[test]
    fn utf8_strings_round_trip() {
        let wire = encode(|w| w.string("héllo, 世界"));
        let mut reader = Reader::new(&wire);
        assert_eq!(reader.string().unwrap(), "héllo, 世界");
    }

    #[test]
    fn bytes_round_trip() {
        let wire = encode(|w| w.bytes(&[0xde, 0xad, 0xbe, 0xef]));
        let mut reader = Reader::new(&wire);
        assert_eq!(reader.bytes().unwrap(), vec![0xde, 0xad, 0xbe, 0xef]);
    }

    #[test]
    fn tag_for_field_1_varint_is_0x08() {
        assert_eq!(encode(|w| w.tag(1, WireType::Varint)), vec![0x08]);
    }

    #[test]
    fn tag_decodes_back_to_field_number_and_wire_type() {
        let wire = encode(|w| w.tag(15, WireType::LengthDelimited));
        let mut reader = Reader::new(&wire);
        assert_eq!(reader.tag().unwrap(), (15, WireType::LengthDelimited));
    }

    #[test]
    fn skip_handles_each_wire_type() {
        // Four fields to skip, one per supported wire type, then the one to keep.
        let wire = encode(|w| {
            w.tag(1, WireType::Varint);
            w.varint(150);
            w.tag(2, WireType::Fixed32);
            w.fixed32(0xdead_beef);
            w.tag(3, WireType::Fixed64);
            w.fixed64(0xdead_beef_cafe_babe);
            w.tag(4, WireType::LengthDelimited);
            w.string("skip me");
            w.tag(5, WireType::Varint);
            w.varint(7);
        });

        let mut reader = Reader::new(&wire);
        let mut keep5: Option<u64> = None;
        while !reader.eof() {
            match reader.tag().unwrap() {
                (5, _) => keep5 = Some(reader.varint().unwrap()),
                (_, wire_type) => reader.skip(wire_type).unwrap(),
            }
        }
        assert_eq!(keep5, Some(7));
    }

    #[test]
    fn truncated_varint_is_rejected() {
        // A lone byte with the continuation bit set promises more input.
        let mut reader = Reader::new(&[0x80]);
        assert!(matches!(reader.varint(), Err(Error::TruncatedVarint)));
    }

    #[test]
    fn length_prefix_max_varint_does_not_overflow() {
        let input = field_with_max_length_prefix();
        let mut reader = Reader::new(&input);
        reader.tag().unwrap();
        assert!(matches!(
            reader.bytes_view(),
            Err(Error::TruncatedLengthDelim)
        ));
    }

    #[test]
    fn length_prefix_overflow_during_skip_does_not_panic() {
        let input = field_with_max_length_prefix();
        let mut reader = Reader::new(&input);
        let (_, wire_type) = reader.tag().unwrap();
        assert!(matches!(
            reader.skip(wire_type),
            Err(Error::TruncatedLengthDelim)
        ));
    }
}