1use crate::wire::{Error, Reader, Result, WireType, Writer, MAX_NESTING_DEPTH};
23
24pub trait Message: Sized + Default {
27 fn encode_to(&self, w: &mut Writer);
30
31 fn merge_field(
34 &mut self,
35 field_number: u32,
36 wire_type: WireType,
37 r: &mut Reader<'_>,
38 ) -> Result<()>;
39}
40
41pub fn marshal<M: Message>(value: &M) -> Vec<u8> {
43 let mut w = Writer::new();
44 value.encode_to(&mut w);
45 w.finish()
46}
47
48pub fn unmarshal<M: Message>(data: &[u8]) -> Result<M> {
50 let mut r = Reader::new(data);
51 let mut msg = M::default();
52 while !r.eof() {
53 let (num, wt) = r.tag()?;
54 msg.merge_field(num, wt, &mut r)?;
55 }
56 Ok(msg)
57}
58
59pub fn write_message<M: Message>(w: &mut Writer, field_number: u32, msg: &M) {
61 let mut inner = Writer::new();
62 msg.encode_to(&mut inner);
63 let bytes = inner.finish();
64 w.tag(field_number, WireType::LengthDelimited);
65 w.varint(bytes.len() as u64);
66 w.raw(&bytes);
67}
68
69pub fn read_message<M: Message>(r: &mut Reader<'_>) -> Result<M> {
81 let len = r.varint()?;
82 let len = usize::try_from(len).map_err(|_| Error::NestedExceedsBuffer)?;
83 let end = r.pos.checked_add(len).ok_or(Error::NestedExceedsBuffer)?;
84 if end > r.data().len() {
85 return Err(Error::NestedExceedsBuffer);
86 }
87 if r.depth >= MAX_NESTING_DEPTH {
88 return Err(Error::DepthExceeded(MAX_NESTING_DEPTH));
89 }
90 r.depth += 1;
91 let result = (|| -> Result<M> {
92 let mut msg = M::default();
93 while r.pos < end {
94 let (num, wt) = r.tag()?;
95 msg.merge_field(num, wt, r)?;
96 }
97 if r.pos != end {
98 return Err(Error::Overrun { pos: r.pos, end });
99 }
100 Ok(msg)
101 })();
102 r.depth -= 1;
103 result
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    /// Nested message: `name` is field 1 (string), `value` is field 2 (int32).
    #[derive(Debug, Default, Clone, PartialEq)]
    struct Inner {
        name: String,
        value: i32,
    }

    impl Message for Inner {
        fn encode_to(&self, w: &mut Writer) {
            if !self.name.is_empty() {
                w.tag(1, WireType::LengthDelimited);
                w.string(&self.name);
            }
            if self.value != 0 {
                w.tag(2, WireType::Varint);
                w.varint_i32(self.value);
            }
        }

        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => self.name = r.string()?,
                // int32 travels sign-extended to 64 bits; truncating recovers the original value.
                2 => self.value = r.varint()? as i32,
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    /// Covers varint, fixed32, fixed64 and length-delimited fields plus a repeated nested
    /// message (field 6).
    #[derive(Debug, Default, Clone, PartialEq)]
    struct Outer {
        title: String,
        count: u32,
        score: f64,
        active: bool,
        data: Vec<u8>,
        items: Vec<Inner>,
        signed: i64,
        small_f: f32,
    }

    impl Message for Outer {
        fn encode_to(&self, w: &mut Writer) {
            if !self.title.is_empty() {
                w.tag(1, WireType::LengthDelimited);
                w.string(&self.title);
            }
            if self.count != 0 {
                w.tag(2, WireType::Varint);
                w.varint(self.count as u64);
            }
            if self.score != 0.0 {
                w.tag(3, WireType::Fixed64);
                w.double(self.score);
            }
            if self.active {
                w.tag(4, WireType::Varint);
                w.varint(1);
            }
            if !self.data.is_empty() {
                w.tag(5, WireType::LengthDelimited);
                w.bytes(&self.data);
            }
            for item in &self.items {
                write_message(w, 6, item);
            }
            if self.signed != 0 {
                w.tag(8, WireType::Varint);
                w.varint_i64(self.signed);
            }
            if self.small_f != 0.0 {
                w.tag(9, WireType::Fixed32);
                w.float(self.small_f);
            }
        }

        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => self.title = r.string()?,
                2 => self.count = r.varint()? as u32,
                3 => self.score = r.double()?,
                4 => self.active = r.varint()? != 0,
                5 => self.data = r.bytes()?,
                6 => self.items.push(read_message(r)?),
                8 => self.signed = r.varint()? as i64,
                9 => self.small_f = r.float()?,
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    #[test]
    fn populated_message_round_trip() {
        let orig = Outer {
            title: "hello".into(),
            count: 42,
            score: 3.125,
            active: true,
            data: vec![0xde, 0xad],
            items: vec![
                Inner {
                    name: "a".into(),
                    value: 1,
                },
                Inner {
                    name: "b".into(),
                    value: -7,
                },
            ],
            signed: -12345,
            small_f: 2.5,
        };
        let bytes = marshal(&orig);
        let got: Outer = unmarshal(&bytes).unwrap();
        assert_eq!(got, orig);
    }

    #[test]
    fn all_zero_message_marshals_to_empty_bytes() {
        let bytes = marshal(&Outer::default());
        assert!(bytes.is_empty());
    }

    #[test]
    fn empty_bytes_unmarshal_to_default() {
        let got: Outer = unmarshal(&[]).unwrap();
        assert_eq!(got, Outer::default());
    }

    #[test]
    fn unknown_fields_are_skipped() {
        #[derive(Debug, Default, PartialEq)]
        struct Big {
            a: String,
            b: String,
            c: String,
        }
        impl Message for Big {
            fn encode_to(&self, w: &mut Writer) {
                if !self.a.is_empty() {
                    w.tag(1, WireType::LengthDelimited);
                    w.string(&self.a);
                }
                if !self.b.is_empty() {
                    w.tag(2, WireType::LengthDelimited);
                    w.string(&self.b);
                }
                if !self.c.is_empty() {
                    w.tag(3, WireType::LengthDelimited);
                    w.string(&self.c);
                }
            }
            fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
                match num {
                    1 => self.a = r.string()?,
                    2 => self.b = r.string()?,
                    3 => self.c = r.string()?,
                    _ => r.skip(wt)?,
                }
                Ok(())
            }
        }
        #[derive(Debug, Default, PartialEq)]
        struct Small {
            a: String,
        }
        impl Message for Small {
            fn encode_to(&self, w: &mut Writer) {
                if !self.a.is_empty() {
                    w.tag(1, WireType::LengthDelimited);
                    w.string(&self.a);
                }
            }
            fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
                match num {
                    1 => self.a = r.string()?,
                    _ => r.skip(wt)?,
                }
                Ok(())
            }
        }

        let bytes = marshal(&Big {
            a: "aa".into(),
            b: "bb".into(),
            c: "cc".into(),
        });
        let got: Small = unmarshal(&bytes).unwrap();
        assert_eq!(got.a, "aa");
    }

    /// Holds an optional singular nested message in field 1.
    #[derive(Debug, Default, Clone, PartialEq)]
    struct Wrap {
        inner: Option<Inner>,
    }

    impl Message for Wrap {
        fn encode_to(&self, w: &mut Writer) {
            if let Some(inner) = &self.inner {
                write_message(w, 1, inner);
            }
        }
        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => self.inner = Some(read_message(r)?),
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    #[test]
    fn singular_nested_none_omits_tag() {
        let bytes = marshal(&Wrap { inner: None });
        assert!(bytes.is_empty());
        let got: Wrap = unmarshal(&bytes).unwrap();
        assert!(got.inner.is_none());
    }

    #[test]
    fn singular_nested_populated_round_trips() {
        let bytes = marshal(&Wrap {
            inner: Some(Inner {
                name: "x".into(),
                value: 9,
            }),
        });
        let got: Wrap = unmarshal(&bytes).unwrap();
        assert_eq!(
            got.inner,
            Some(Inner {
                name: "x".into(),
                value: 9
            })
        );
    }

    #[test]
    fn singular_nested_empty_emits_zero_length_blob() {
        // A present-but-empty submessage must still be emitted (tag + zero length) so that the
        // decoder can tell it apart from an absent one.
        let bytes = marshal(&Wrap {
            inner: Some(Inner::default()),
        });
        assert_eq!(bytes, vec![0x0a, 0x00]);
        let got: Wrap = unmarshal(&bytes).unwrap();
        assert_eq!(got.inner, Some(Inner::default()));
    }

    /// Wire form of one `map<string, string>` entry: key is field 1, value is field 2.
    ///
    /// On the wire a protobuf map field is a repeated field of entry messages. Routing entries
    /// through `write_message`/`read_message` therefore gives them the same length, overrun and
    /// nesting-depth checks as any other submessage.
    #[derive(Debug, Default)]
    struct StringStringEntry {
        key: String,
        value: String,
    }

    impl Message for StringStringEntry {
        fn encode_to(&self, w: &mut Writer) {
            if !self.key.is_empty() {
                w.tag(1, WireType::LengthDelimited);
                w.string(&self.key);
            }
            if !self.value.is_empty() {
                w.tag(2, WireType::LengthDelimited);
                w.string(&self.value);
            }
        }
        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => self.key = r.string()?,
                2 => self.value = r.string()?,
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    /// `map<string, string> meta = 1;`
    #[derive(Debug, Default, Clone, PartialEq)]
    struct WithStringMap {
        meta: BTreeMap<String, String>,
    }

    impl Message for WithStringMap {
        fn encode_to(&self, w: &mut Writer) {
            for (k, v) in &self.meta {
                // The entry type owns its strings so that it can also be decoded; the copies are
                // an acceptable cost in a test.
                let entry = StringStringEntry {
                    key: k.clone(),
                    value: v.clone(),
                };
                write_message(w, 1, &entry);
            }
        }
        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => {
                    let entry: StringStringEntry = read_message(r)?;
                    self.meta.insert(entry.key, entry.value);
                }
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    #[test]
    fn map_string_string_round_trips() {
        let mut meta = BTreeMap::new();
        meta.insert("a".into(), "1".into());
        meta.insert("b".into(), "2".into());
        meta.insert("key with space".into(), "v".into());
        let bytes = marshal(&WithStringMap { meta: meta.clone() });
        let got: WithStringMap = unmarshal(&bytes).unwrap();
        assert_eq!(got.meta, meta);
    }

    #[test]
    fn map_string_string_empty_produces_empty_bytes() {
        let bytes = marshal(&WithStringMap::default());
        assert!(bytes.is_empty());
        let got: WithStringMap = unmarshal(&bytes).unwrap();
        assert!(got.meta.is_empty());
    }

    #[test]
    fn map_entry_overrunning_its_length_is_rejected() {
        // The map entry declares 2 bytes, but the key inside it claims 3 bytes of string data,
        // so decoding the key runs past the entry's end into the rest of the buffer.
        let bytes = [0x0a, 0x02, 0x0a, 0x03, b'a', b'b', b'c'];
        let res: Result<WithStringMap> = unmarshal(&bytes);
        assert!(matches!(res, Err(Error::Overrun { .. })), "got {:?}", res);
    }

    /// Wire form of one `map<int32, string>` entry: key is field 1, value is field 2.
    #[derive(Debug, Default)]
    struct Int32StringEntry {
        key: i32,
        value: String,
    }

    impl Message for Int32StringEntry {
        fn encode_to(&self, w: &mut Writer) {
            if self.key != 0 {
                w.tag(1, WireType::Varint);
                w.varint_i32(self.key);
            }
            if !self.value.is_empty() {
                w.tag(2, WireType::LengthDelimited);
                w.string(&self.value);
            }
        }
        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => self.key = r.varint()? as i32,
                2 => self.value = r.string()?,
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    /// `map<int32, string> codes = 1;`
    #[derive(Debug, Default, Clone, PartialEq)]
    struct WithIntMap {
        codes: BTreeMap<i32, String>,
    }

    impl Message for WithIntMap {
        fn encode_to(&self, w: &mut Writer) {
            for (k, v) in &self.codes {
                let entry = Int32StringEntry {
                    key: *k,
                    value: v.clone(),
                };
                write_message(w, 1, &entry);
            }
        }
        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => {
                    let entry: Int32StringEntry = read_message(r)?;
                    self.codes.insert(entry.key, entry.value);
                }
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    #[test]
    fn map_int32_string_round_trips() {
        let mut codes = BTreeMap::new();
        codes.insert(404, "Not Found".into());
        codes.insert(500, "Internal".into());
        let bytes = marshal(&WithIntMap {
            codes: codes.clone(),
        });
        let got: WithIntMap = unmarshal(&bytes).unwrap();
        assert_eq!(got.codes, codes);
    }

    /// Plain `int32` field 1 (sign-extended varint encoding).
    #[derive(Debug, Default, Clone, PartialEq)]
    struct SignedI32 {
        v: i32,
    }
    impl Message for SignedI32 {
        fn encode_to(&self, w: &mut Writer) {
            if self.v != 0 {
                w.tag(1, WireType::Varint);
                w.varint_i32(self.v);
            }
        }
        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => self.v = r.varint()? as i32,
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    #[test]
    fn proto3_int32_negative_sign_extends_to_10_byte_varint() {
        let bytes = marshal(&SignedI32 { v: -1 });
        assert_eq!(
            bytes,
            vec![0x08, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01]
        );
        let got: SignedI32 = unmarshal(&bytes).unwrap();
        assert_eq!(got.v, -1);
    }

    /// `sint32` field 1 (zigzag varint encoding).
    #[derive(Debug, Default, Clone, PartialEq)]
    struct ZigzagI32 {
        v: i32,
    }
    impl Message for ZigzagI32 {
        fn encode_to(&self, w: &mut Writer) {
            if self.v != 0 {
                w.tag(1, WireType::Varint);
                w.zigzag32(self.v);
            }
        }
        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => self.v = r.zigzag32()?,
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    /// Self-recursive message used to build arbitrarily deep nesting through field 1.
    #[derive(Default, Debug)]
    struct Tree {
        child: Option<Box<Tree>>,
    }

    impl Message for Tree {
        fn encode_to(&self, w: &mut Writer) {
            if let Some(c) = &self.child {
                write_message(w, 1, c.as_ref());
            }
        }
        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => self.child = Some(Box::new(read_message(r)?)),
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    /// Returns the encoding of `depth` `Tree`s nested through field 1, the innermost one empty.
    ///
    /// Each level is a `0x0a` tag (field 1, length-delimited) followed by the varint length of
    /// everything inside it. The bytes are produced back to front: `out` holds the encoding
    /// reversed, so its current length is exactly the payload length that the next (outer)
    /// header has to announce. This keeps construction O(depth). Re-wrapping a growing payload
    /// at every level copies O(depth²) bytes, on the order of 10^10 for the 100 000-level test.
    fn build_tree_bytes(depth: usize) -> Vec<u8> {
        let mut out: Vec<u8> = Vec::new();
        let mut varint: Vec<u8> = Vec::with_capacity(10);
        for _ in 0..depth {
            varint.clear();
            let mut len = out.len() as u64;
            while len >= 0x80 {
                varint.push(((len & 0x7f) as u8) | 0x80);
                len >>= 7;
            }
            varint.push(len as u8);
            // Pushed in reverse (last varint byte first, tag last) so the final reverse restores
            // the forward order: tag, varint, payload.
            out.extend(varint.iter().rev());
            out.push(0x0a);
        }
        out.reverse();
        out
    }

    // The depth tests below assume the nesting limit is 100; this pins that assumption so a
    // change to the constant fails here with a clear message.
    #[test]
    fn nesting_limit_is_100() {
        assert_eq!(MAX_NESTING_DEPTH, 100);
    }

    #[test]
    fn deep_submessage_at_limit_is_accepted() {
        let bytes = build_tree_bytes(100);
        let _: Tree = unmarshal(&bytes).unwrap();
    }

    #[test]
    fn deep_submessage_one_past_limit_is_rejected() {
        let bytes = build_tree_bytes(101);
        let res: Result<Tree> = unmarshal(&bytes);
        assert!(
            matches!(res, Err(Error::DepthExceeded(100))),
            "got {:?}",
            res
        );
    }

    #[test]
    fn deep_submessage_past_limit_returns_depth_exceeded() {
        let bytes = build_tree_bytes(200);
        let res: Result<Tree> = unmarshal(&bytes);
        assert!(
            matches!(res, Err(Error::DepthExceeded(100))),
            "got {:?}",
            res
        );
    }

    #[test]
    fn deep_submessage_at_extreme_depth_rejects_without_stack_overflow() {
        // Decoding must stop at the depth limit rather than recurse 100 000 times.
        let bytes = build_tree_bytes(100_000);
        let res: Result<Tree> = unmarshal(&bytes);
        assert!(
            matches!(res, Err(Error::DepthExceeded(100))),
            "got {:?}",
            res
        );
    }

    #[test]
    fn sint32_zigzag_is_compact_for_negative_values() {
        // zigzag maps -1 to 1, so a negative value needs a single payload byte instead of ten.
        let bytes = marshal(&ZigzagI32 { v: -1 });
        assert_eq!(bytes, vec![0x08, 0x01]);
        let got: ZigzagI32 = unmarshal(&bytes).unwrap();
        assert_eq!(got.v, -1);
    }
}