/// The wire type carried in the low three bits of every protobuf field tag.
///
/// Discriminants match the raw on-the-wire values, so `wire_type as u32`
/// yields the bits that go into a tag.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u32)]
pub enum WireType {
    /// Base-128 varint payload (raw value 0).
    Varint = 0,
    /// Eight-byte fixed-width payload (raw value 1).
    Fixed64 = 1,
    /// Varint length prefix followed by that many bytes (raw value 2).
    LengthDelimited = 2,
    /// Opens a group; closed by a matching `EndGroup` (raw value 3).
    StartGroup = 3,
    /// Closes the innermost open group (raw value 4).
    EndGroup = 4,
    /// Four-byte fixed-width payload (raw value 5).
    Fixed32 = 5,
}

impl WireType {
    /// Converts a raw wire-type value into a [`WireType`].
    ///
    /// Returns `None` for any value outside `0..=5`.
    pub fn from_raw(value: u32) -> Option<WireType> {
        // Indexed by raw value; position i holds the variant whose discriminant is i.
        const BY_RAW: [WireType; 6] = [
            WireType::Varint,
            WireType::Fixed64,
            WireType::LengthDelimited,
            WireType::StartGroup,
            WireType::EndGroup,
            WireType::Fixed32,
        ];
        // Range-check in u32 first so the index cast below can never truncate.
        if value >= BY_RAW.len() as u32 {
            return None;
        }
        Some(BY_RAW[value as usize])
    }
}
/// Test helper: packs a field number and a wire type into a wire-format tag.
///
/// The wire type fills the low three bits and the field number sits above
/// them. Bits 29..=31 of `field_number` are shifted out of the `u32` and lost.
#[cfg(test)]
#[inline]
pub(crate) fn make_tag(field_number: u32, wire_type: WireType) -> u32 {
    let type_bits = wire_type as u32;
    let field_bits = field_number << 3;
    field_bits | type_bits
}
/// Extracts the wire type stored in the low three bits of `tag`.
///
/// Returns `None` when those bits hold one of the unassigned values 6 or 7.
#[inline]
pub fn tag_wire_type(tag: u32) -> Option<WireType> {
    // `% 8` on an unsigned value keeps exactly the low three bits.
    let raw_type = tag % 8;
    WireType::from_raw(raw_type)
}
/// Extracts the field number from `tag` by discarding its wire-type bits.
///
/// No validation is performed: a result of 0 (tags below 8) is returned as-is.
#[inline]
pub fn tag_field_number(tag: u32) -> u32 {
    // Width of the wire-type field at the bottom of every tag.
    const WIRE_TYPE_BITS: u32 = 3;
    tag >> WIRE_TYPE_BITS
}
/// Reports whether `tag` has a non-zero field number and a recognised wire type.
pub(crate) fn is_valid_proto_tag(tag: u32) -> bool {
    // Every tag below 8 encodes field number 0, which is rejected.
    tag >= 8 && tag_wire_type(tag).is_some()
}
/// Longest base-128 encoding of a `u32`: ceil(32 / 7) = 5 bytes.
const MAX_VARINT32_LEN: usize = (u32::BITS as usize + 6) / 7;
/// Longest base-128 encoding of a `u64`: ceil(64 / 7) = 10 bytes.
const MAX_VARINT64_LEN: usize = (u64::BITS as usize + 6) / 7;
/// Decodes one minimally encoded base-128 varint that must fit in a `u32`,
/// starting at byte offset `pos` of `data`.
///
/// On success returns `(value, bytes_consumed)`. Returns `None` in these cases:
/// - the input ends before a terminating byte (high bit clear);
/// - no terminator appears within [`MAX_VARINT32_LEN`] bytes;
/// - the fifth byte carries bits above bit 31;
/// - a terminator after the first byte is `0x00` (an overlong encoding).
///
/// # Panics
///
/// Panics if `pos > data.len()`.
fn read_canonical_varint32(data: &[u8], pos: usize) -> Option<(u32, usize)> {
    let mut value: u32 = 0;
    for (index, &octet) in data[pos..].iter().take(MAX_VARINT32_LEN).enumerate() {
        let payload = u32::from(octet & 0x7F);
        // The fifth 7-bit group only has room for bits 28..=31 of a u32.
        if index == 4 && payload > 0x0F {
            return None;
        }
        value |= payload << (7 * index);
        if octet & 0x80 == 0 {
            // A zero terminator after the first byte adds no bits, so a shorter
            // encoding of the same value exists.
            let overlong = index > 0 && octet == 0;
            return if overlong {
                None
            } else {
                Some((value, index + 1))
            };
        }
    }
    // Ran out of input, or hit the length cap without seeing a terminator.
    None
}
/// Skips one minimally encoded base-128 varint of at most 64 bits that starts
/// at byte offset `pos` of `data`, returning how many bytes it occupies.
///
/// Returns `None` in these cases:
/// - the input ends before a terminating byte (high bit clear);
/// - no terminator appears within [`MAX_VARINT64_LEN`] bytes;
/// - the tenth byte carries bits above bit 63, so the value would not fit in a `u64`;
/// - a terminator after the first byte is `0x00` (an overlong encoding).
///
/// # Panics
///
/// Panics if `pos > data.len()`.
fn skip_canonical_varint64(data: &[u8], pos: usize) -> Option<usize> {
    for (i, &byte) in data[pos..].iter().take(MAX_VARINT64_LEN).enumerate() {
        // Nine 7-bit groups already cover bits 0..=62, so the tenth group may
        // only contribute bit 63. A larger payload overflows a u64 and cannot
        // come from a conforming encoder. `read_canonical_varint32` applies the
        // same rule to its fifth byte.
        if i == 9 && (byte & 0x7F) > 0x01 {
            return None;
        }
        if byte & 0x80 == 0 {
            // A zero terminator after the first byte adds no bits: overlong.
            if i > 0 && byte == 0 {
                return None;
            }
            return Some(i + 1);
        }
    }
    // Ran out of input, or hit the length cap without seeing a terminator.
    None
}
/// Reports whether `data` is a well-formed, canonically encoded protobuf
/// message at the wire level.
///
/// Every field in the buffer is walked and checked:
/// - its tag must be a minimal varint32 with a non-zero field number and a
///   known wire type;
/// - varint payloads must pass `skip_canonical_varint64`;
/// - fixed32/fixed64 payloads and length-delimited bodies must fit inside `data`;
/// - group start/end tags must nest correctly and match by field number.
///
/// The bytes inside length-delimited fields are skipped, not parsed
/// recursively. An empty slice is accepted.
pub fn is_proto_message(data: &[u8]) -> bool {
    let mut pos: usize = 0;
    // Field numbers of groups opened but not yet closed, innermost last.
    let mut started_groups: Vec<u32> = Vec::new();
    // Invariant: `pos <= data.len()` at the top of every iteration and after
    // every arm below, so `data.len() - pos` never underflows.
    while pos < data.len() {
        let Some((tag, consumed)) = read_canonical_varint32(data, pos) else {
            return false;
        };
        pos += consumed;
        let field_number = tag_field_number(tag);
        if field_number == 0 {
            return false;
        }
        let Some(wire_type) = tag_wire_type(tag) else {
            return false;
        };
        match wire_type {
            WireType::Varint => {
                let Some(n) = skip_canonical_varint64(data, pos) else {
                    return false;
                };
                pos += n;
            }
            WireType::Fixed32 => {
                if data.len() - pos < 4 {
                    return false;
                }
                pos += 4;
            }
            WireType::Fixed64 => {
                if data.len() - pos < 8 {
                    return false;
                }
                pos += 8;
            }
            WireType::LengthDelimited => {
                let Some((length, consumed)) = read_canonical_varint32(data, pos) else {
                    return false;
                };
                pos += consumed;
                // Compare against the bytes that remain rather than computing
                // `pos + length`. On 32-bit targets that sum can overflow
                // `usize`: a panic in debug builds, and in release a wrapped
                // value that passes the check and moves `pos` backwards
                // (potentially looping forever on crafted input).
                let remaining = data.len() - pos;
                if u64::from(length) > remaining as u64 {
                    return false;
                }
                // Lossless: `length <= remaining`, which is itself a `usize`.
                pos += length as usize;
            }
            WireType::StartGroup => started_groups.push(field_number),
            WireType::EndGroup => {
                // The end tag must close the innermost open group. An empty
                // stack pops `None`, which also fails this comparison.
                if started_groups.pop() != Some(field_number) {
                    return false;
                }
            }
        }
    }
    // Any group still open at end of input makes the message malformed.
    started_groups.is_empty()
}
// Unit tests for the tag helpers and the `is_proto_message` wire-format checker.
// Byte vectors are hand-encoded; a tag byte is `(field_number << 3) | wire_type`.
#[cfg(test)]
mod tests {
use super::*;
// `make_tag` output must decompose back into the same field number / wire type.
#[test]
fn test_make_tag_and_decompose() {
let tag = make_tag(1, WireType::Varint);
assert_eq!(tag_wire_type(tag), Some(WireType::Varint));
assert_eq!(tag_field_number(tag), 1);
let tag2 = make_tag(5, WireType::Fixed32);
assert_eq!(tag_wire_type(tag2), Some(WireType::Fixed32));
assert_eq!(tag_field_number(tag2), 5);
}
// Known tag bytes: 0x08 = field 1 varint, 0x09 = field 1 fixed64,
// 0x12 = field 2 length-delimited.
#[test]
fn test_tag_values() {
assert_eq!(make_tag(1, WireType::Varint), 0x08);
assert_eq!(make_tag(1, WireType::Fixed64), 0x09);
assert_eq!(make_tag(2, WireType::LengthDelimited), 0x12);
}
// Raw wire types 6 and 7 are unassigned and must not map to a variant.
#[test]
fn test_wire_type_from_raw_invalid() {
assert_eq!(WireType::from_raw(6), None);
assert_eq!(WireType::from_raw(7), None);
assert_eq!(tag_wire_type(6), None); assert_eq!(tag_wire_type(7), None); }
// An empty buffer is accepted.
#[test]
fn test_empty_is_valid() {
assert!(is_proto_message(b""));
}
// field 1 varint = 1; field 1 fixed64 = 0; field 2 length-delimited "abc".
#[test]
fn test_valid_proto_mixed_fields() {
let mut data = Vec::new();
data.push(0x08);
data.push(0x01);
data.push(0x09);
data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
data.push(0x12);
data.push(0x03);
data.extend_from_slice(b"abc");
assert!(is_proto_message(&data));
}
// Varint payload 0 spelled as three bytes (0x80 0x80 0x00) is overlong.
#[test]
fn test_overlong_varint_rejected() {
let data = [0x08, 0x80, 0x80, 0x00];
assert!(!is_proto_message(&data));
}
// The tag itself is an overlong two-byte encoding (0x80 0x00).
#[test]
fn test_overlong_varint_tag_rejected() {
let data = [0x80, 0x00];
assert!(!is_proto_message(&data));
}
// 0x0B = field 1 start-group with no matching end.
#[test]
fn test_unclosed_start_group() {
let data = [0x0B];
assert!(!is_proto_message(&data));
}
// 0x0B opens field 1's group but 0x14 is an end-group for field 2.
#[test]
fn test_mismatched_end_group() {
let data = [0x0B, 0x14]; assert!(!is_proto_message(&data));
}
// 0x0B / 0x0C: start and end group for field 1.
#[test]
fn test_matched_group() {
let data = [0x0B, 0x0C]; assert!(is_proto_message(&data));
}
// 0x0E = field 1 with unassigned wire type 6.
#[test]
fn test_wire_type_6_rejected() {
let data = [0x0E];
assert!(!is_proto_message(&data));
}
// 0x0F = field 1 with unassigned wire type 7.
#[test]
fn test_wire_type_7_rejected() {
let data = [0x0F];
assert!(!is_proto_message(&data));
}
// 0x00 = field number 0, which is reserved.
#[test]
fn test_field_number_zero_rejected() {
let data = [0x00];
assert!(!is_proto_message(&data));
}
// 0x0D = field 1 fixed32, but only 2 of the 4 payload bytes are present.
#[test]
fn test_truncated_fixed32() {
let data = [0x0D, 0x00, 0x00];
assert!(!is_proto_message(&data));
}
// 0x09 = field 1 fixed64, but only 4 of the 8 payload bytes are present.
#[test]
fn test_truncated_fixed64() {
let data = [0x09, 0x00, 0x00, 0x00, 0x00];
assert!(!is_proto_message(&data));
}
// 0x12 = field 2 length-delimited, declared length 100 (0x64) but 3 bytes follow.
#[test]
fn test_length_delimited_overflow() {
let data = [0x12, 0x64, 0x00, 0x00, 0x00];
assert!(!is_proto_message(&data));
}
// Tag byte with its continuation bit set and nothing after it.
#[test]
fn test_truncated_varint_at_end() {
let data = [0x88];
assert!(!is_proto_message(&data));
}
// 0x0D = field 1 fixed32 followed by exactly 4 payload bytes.
#[test]
fn test_valid_fixed32_field() {
let data = [0x0D, 0x01, 0x02, 0x03, 0x04];
assert!(is_proto_message(&data));
}
// A six-byte tag exceeds the five-byte varint32 limit.
#[test]
fn test_too_long_varint32_tag() {
let data = [0x80, 0x80, 0x80, 0x80, 0x80, 0x01];
assert!(!is_proto_message(&data));
}
// 0x0C = field 1 end-group with no open group.
#[test]
fn test_end_group_without_start() {
let data = [0x0C]; assert!(!is_proto_message(&data));
}
// Field 1 group containing field 2 group: start 1, start 2, end 2, end 1.
#[test]
fn test_nested_groups() {
let data = [0x0B, 0x13, 0x14, 0x0C];
assert!(is_proto_message(&data));
}
// Field 1 varint holding u64::MAX: nine 0xFF bytes then 0x01 (ten bytes total).
#[test]
fn test_varint_max_length_valid() {
let mut data = vec![0x08]; data.extend_from_slice(&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01]);
assert!(is_proto_message(&data));
}
// Ten bytes all with the continuation bit set, so no terminator appears
// within the ten-byte varint64 limit (the encoding would need an 11th byte).
#[test]
fn test_varint_11_bytes_rejected() {
let mut data = vec![0x08]; data.extend_from_slice(&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]);
assert!(!is_proto_message(&data));
}
// Every wire type round-trips for field numbers up to the 29-bit maximum.
#[test]
fn test_tag_round_trip_all_wire_types() {
let wire_types = [
WireType::Varint,
WireType::Fixed64,
WireType::LengthDelimited,
WireType::StartGroup,
WireType::EndGroup,
WireType::Fixed32,
];
for &wt in &wire_types {
for field in [1u32, 2, 127, 1000, 0x1FFFFFFF] {
let tag = make_tag(field, wt);
assert_eq!(tag_wire_type(tag), Some(wt));
assert_eq!(tag_field_number(tag), field);
}
}
}
// Unclosed, mismatched, inner-only-closed, and surplus end-group cases.
#[test]
fn test_unclosed_group_rejected() {
assert!(!is_proto_message(&[0x0B]));
assert!(!is_proto_message(&[0x0B, 0x14]));
assert!(!is_proto_message(&[0x0B, 0x13, 0x14]));
assert!(!is_proto_message(&[0x0B, 0x0C, 0x0C]));
}
// Wire types 6/7 as single-byte tags; 0xA6 0x06 is a two-byte tag 806 =
// field 100, wire type 6; and a valid field followed by a type-7 tag.
#[test]
fn test_invalid_wire_type_rejected() {
assert!(!is_proto_message(&[0x0E])); assert!(!is_proto_message(&[0x0F])); assert!(!is_proto_message(&[0xA6, 0x06])); assert!(!is_proto_message(&[0x08, 0x01, 0x0F])); }
// Every proper prefix of a valid message must be handled without panicking.
#[test]
fn test_malformed_input_no_panic_truncated() {
let valid = [
0x08, 0x96, 0x01, 0x12, 0x03, b'a', b'b', b'c', 0x0D, 0x01, 0x02, 0x03, 0x04,
];
for i in 1..valid.len() {
let _ = is_proto_message(&valid[..i]);
}
}
// Every possible single-byte input must be handled without panicking.
#[test]
fn test_malformed_input_no_panic_single_bytes() {
for b in 0u8..=255 {
let _ = is_proto_message(&[b]);
}
}
// Assorted malformed inputs must be handled without panicking.
#[test]
fn test_malformed_input_no_panic_random() {
let cases: &[&[u8]] = &[
&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF],
&[0x00],
&[0x80],
&[0x80, 0x80, 0x80, 0x80, 0x80, 0x80],
&[
0x08, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
],
];
for case in cases {
let _ = is_proto_message(case);
}
}
// 0x09 = field 1 fixed64 with all 8 payload bytes present.
#[test]
fn test_fixed64_field_valid() {
let mut data = vec![0x09];
data.extend_from_slice(&[0x00; 8]);
assert!(is_proto_message(&data));
}
// 0x12 = field 2 length-delimited, length 3, body "abc".
#[test]
fn test_length_delimited_field_valid() {
assert!(is_proto_message(&[0x12, 0x03, b'a', b'b', b'c']));
}
// field 1 varint 150 (0x96 0x01); 0x11 = field 2 fixed64; 0x1A 0x00 = field 3
// empty length-delimited; 0x25 = field 4 fixed32; 0x2B/0x2C = field 5 group.
#[test]
fn test_mixed_valid_fields_all_types() {
let mut data = Vec::new();
data.extend_from_slice(&[0x08, 0x96, 0x01]);
data.push(0x11);
data.extend_from_slice(&[0xFF; 8]);
data.push(0x1A);
data.push(0x00);
data.push(0x25);
data.extend_from_slice(&[0xAA; 4]);
data.push(0x2B); data.push(0x2C); assert!(is_proto_message(&data));
}
}