/// Nesting limit used by `DepthLimitedDecoder::default()`.
const DEFAULT_MAX_NESTING_DEPTH: u32 = 64;

/// Protobuf decoder that first measures the nesting depth of the raw wire
/// bytes and refuses payloads nested deeper than a configured limit.
#[derive(Debug, Clone)]
pub struct DepthLimitedDecoder {
    /// Largest measured depth that is still accepted.
    max_depth: u32,
}
impl DepthLimitedDecoder {
    /// Creates a decoder that rejects payloads whose measured nesting depth
    /// is greater than `max_depth`.
    pub fn new(max_depth: u32) -> Self {
        Self { max_depth }
    }

    /// Returns the configured nesting limit.
    pub fn max_depth(&self) -> u32 {
        self.max_depth
    }

    /// Decodes `buf` as `M` after checking its wire-format nesting depth.
    ///
    /// # Errors
    ///
    /// * [`DepthLimitError::ExceededMaxDepth`] when the depth measured from the
    ///   raw bytes is greater than the configured limit. A `tracing` warning is
    ///   emitted and `M::decode` is never called.
    /// * [`DepthLimitError::DecodeError`] when `M::decode` rejects the bytes.
    pub fn decode<M: prost::Message + Default>(&self, buf: &[u8]) -> Result<M, DepthLimitError> {
        let limit = self.max_depth;
        match measure_wire_depth(buf) {
            depth if depth > limit => {
                tracing::warn!(
                    measured_depth = depth,
                    max_depth = limit,
                    "Protobuf message nesting depth exceeded limit"
                );
                Err(DepthLimitError::ExceededMaxDepth { depth, limit })
            }
            _ => M::decode(buf).map_err(DepthLimitError::DecodeError),
        }
    }
}
impl Default for DepthLimitedDecoder {
    /// Builds a decoder whose limit is `DEFAULT_MAX_NESTING_DEPTH`.
    fn default() -> Self {
        Self::new(DEFAULT_MAX_NESTING_DEPTH)
    }
}
/// Failure modes of `DepthLimitedDecoder::decode`.
#[derive(Debug, thiserror::Error)]
pub enum DepthLimitError {
    /// The measured nesting `depth` was greater than the decoder's `limit`;
    /// decoding was not attempted.
    #[error("protobuf nesting depth {depth} exceeds limit of {limit}")]
    ExceededMaxDepth { depth: u32, limit: u32 },

    /// The depth check passed but prost failed to decode the bytes.
    #[error("protobuf decode error: {0}")]
    DecodeError(#[from] prost::DecodeError),
}
/// Estimates the deepest protobuf nesting level in the wire-format `buf`.
///
/// The scan is schema-less: every length-delimited field counts as one extra
/// level, whether it holds a nested message or plain string/bytes data. Bytes
/// that do not parse as a field end the scan of the level they appear in, and
/// the enclosing level resumes right after that sub-buffer.
fn measure_wire_depth(buf: &[u8]) -> u32 {
    // An empty buffer opens no levels, so the scan itself reports depth 0.
    measure_depth_recursive(buf, 0)
}

/// Scans `buf` and returns `current_depth` plus the deepest number of nested
/// length-delimited levels found inside it.
///
/// The name is kept for compatibility, but the scan does not recurse. Open
/// levels are tracked on an explicit heap stack, so a hostile payload nested
/// tens of thousands of levels deep cannot overflow the thread stack before
/// the depth limit is enforced. The stack holds one `usize` per open level.
/// Each level consumes at least a tag byte and a length byte, so the stack
/// grows at most linearly with `buf.len()`, and the scan runs in O(len).
fn measure_depth_recursive(buf: &[u8], current_depth: u32) -> u32 {
    // Exclusive end offsets (into `buf`) of the open length-delimited levels,
    // innermost last.
    let mut open_ends: Vec<usize> = Vec::new();
    let mut max_depth = current_depth;
    let mut pos = 0usize;

    loop {
        // Bytes of the innermost open level end here; the whole buffer at the top.
        let end = open_ends.last().copied().unwrap_or(buf.len());
        if pos >= end {
            // Level finished (or abandoned): resume in the parent right after it,
            // or stop once the outermost level is done.
            if open_ends.pop().is_none() {
                break;
            }
            continue;
        }

        // Every read is confined to `buf[pos..end]`, so `pos <= end` always holds.
        let (tag, tag_len) = match decode_varint(&buf[pos..end]) {
            Some(decoded) => decoded,
            None => {
                pos = end;
                continue;
            }
        };
        pos += tag_len;

        match tag & 0x07 {
            // VARINT: skip the value.
            0 => match decode_varint(&buf[pos..end]) {
                Some((_, value_len)) => pos += value_len,
                None => pos = end,
            },
            // I64 / I32: fixed-width values. A truncated value clamps `pos` to
            // `end`, which abandons this level. Cannot overflow since pos <= end.
            1 => pos = (pos + 8).min(end),
            5 => pos = (pos + 4).min(end),
            // LEN: open a child level spanning the declared number of bytes.
            2 => {
                let (length, length_len) = match decode_varint(&buf[pos..end]) {
                    Some(decoded) => decoded,
                    None => {
                        pos = end;
                        continue;
                    }
                };
                pos += length_len;
                // Compare in u64 against the bytes actually remaining. A hostile
                // length can then neither wrap `pos + length` (which used to
                // panic) nor be truncated by `as usize` on 32-bit targets.
                if length > (end - pos) as u64 {
                    pos = end;
                } else {
                    open_ends.push(pos + length as usize);
                    let open = open_ends.len().min(u32::MAX as usize) as u32;
                    max_depth = max_depth.max(current_depth.saturating_add(open));
                }
            }
            // SGROUP (3), EGROUP (4) and the invalid wire types 6/7 end the scan
            // of this level.
            // NOTE(review): if prost skips unknown groups and keeps decoding the
            // fields after them (verify), a leading empty group hides later
            // nesting from this scan. Counting groups as levels is not a drop-in
            // fix: ordinary ASCII inside string fields (e.g. "sss", "{{{") decodes
            // as start-group tags and would cause false rejections.
            _ => pos = end,
        }
    }

    max_depth
}

/// Decodes a base-128 varint from the start of `buf`.
///
/// Returns the value and the number of bytes consumed. Returns `None` if `buf`
/// ends before a byte without the continuation bit, or if the varint is
/// longer than 10 bytes. Bits beyond 64 in a 10th byte are silently dropped.
fn decode_varint(buf: &[u8]) -> Option<(u64, usize)> {
    let mut value: u64 = 0;
    let mut shift = 0u32;
    for (i, &byte) in buf.iter().enumerate() {
        if shift >= 64 {
            return None;
        }
        value |= ((byte & 0x7F) as u64) << shift;
        if byte & 0x80 == 0 {
            return Some((value, i + 1));
        }
        shift += 7;
    }
    None
}
#[cfg(test)]
mod tests {
    //! Tests for the decoder API and for the wire-format depth scanner.
    use super::*;
    use prost::Message;
    use rstest::rstest;

    #[rstest]
    fn default_depth_limit_is_64() {
        let decoder = DepthLimitedDecoder::default();
        assert_eq!(decoder.max_depth(), DEFAULT_MAX_NESTING_DEPTH);
        assert_eq!(decoder.max_depth(), 64);
    }

    #[rstest]
    #[case(1)]
    #[case(32)]
    #[case(64)]
    #[case(128)]
    #[case(256)]
    fn custom_depth_limit(#[case] limit: u32) {
        let decoder = DepthLimitedDecoder::new(limit);
        assert_eq!(decoder.max_depth(), limit);
    }

    #[rstest]
    fn decode_empty_message_succeeds() {
        let decoder = DepthLimitedDecoder::default();
        let empty = crate::proto::common::Empty {};
        let encoded = empty.encode_to_vec();
        let result = decoder.decode::<crate::proto::common::Empty>(&encoded);
        assert!(result.is_ok());
    }

    #[rstest]
    fn decode_simple_message_succeeds() {
        let decoder = DepthLimitedDecoder::default();
        let timestamp = crate::proto::common::Timestamp {
            seconds: 1_000_000,
            nanos: 500_000,
        };
        let encoded = timestamp.encode_to_vec();
        let result = decoder.decode::<crate::proto::common::Timestamp>(&encoded);
        assert!(result.is_ok());
        let decoded = result.unwrap();
        assert_eq!(decoded.seconds, 1_000_000);
        assert_eq!(decoded.nanos, 500_000);
    }

    #[rstest]
    fn decode_nested_message_within_limit_succeeds() {
        let decoder = DepthLimitedDecoder::new(10);
        let event = crate::proto::graphql::SubscriptionEvent {
            id: "test".to_string(),
            event_type: "update".to_string(),
            payload: Some(crate::proto::graphql::GraphQlResponse {
                data: Some("{}".to_string()),
                errors: vec![],
                extensions: None,
            }),
            timestamp: Some(crate::proto::common::Timestamp {
                seconds: 100,
                nanos: 0,
            }),
        };
        let encoded = event.encode_to_vec();
        let result = decoder.decode::<crate::proto::graphql::SubscriptionEvent>(&encoded);
        assert!(result.is_ok());
    }

    #[rstest]
    fn decode_rejects_message_exceeding_depth_limit() {
        let decoder = DepthLimitedDecoder::new(0);
        let batch = crate::proto::common::BatchResult {
            success_count: 1,
            failure_count: 1,
            errors: vec![crate::proto::common::Error {
                code: "500".to_string(),
                message: "fail".to_string(),
                metadata: Default::default(),
            }],
        };
        let encoded = batch.encode_to_vec();
        let result = decoder.decode::<crate::proto::common::BatchResult>(&encoded);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, DepthLimitError::ExceededMaxDepth { .. }),
            "Expected ExceededMaxDepth error, got: {err:?}"
        );
    }

    #[rstest]
    fn depth_limit_error_display_message() {
        let error = DepthLimitError::ExceededMaxDepth {
            depth: 100,
            limit: 64,
        };
        let message = error.to_string();
        assert_eq!(message, "protobuf nesting depth 100 exceeds limit of 64");
    }

    #[rstest]
    fn measure_empty_buffer_returns_zero() {
        let buf: &[u8] = &[];
        let depth = measure_wire_depth(buf);
        assert_eq!(depth, 0);
    }

    #[rstest]
    fn measure_flat_varint_message_is_depth_zero() {
        // Field 1, wire type VARINT, value 150: no length-delimited fields.
        let buf: &[u8] = &[0x08, 0x96, 0x01];
        assert_eq!(measure_wire_depth(buf), 0);
    }

    #[rstest]
    fn measure_counts_each_length_delimited_level() {
        // Field 1 (LEN, 2 bytes) wrapping field 1 (LEN, empty).
        let buf: &[u8] = &[0x0a, 0x02, 0x0a, 0x00];
        assert_eq!(measure_wire_depth(buf), 2);
    }

    #[rstest]
    fn measure_ignores_truncated_length_delimited_field() {
        // Declared length 5, but only one byte follows.
        let buf: &[u8] = &[0x0a, 0x05, 0x00];
        assert_eq!(measure_wire_depth(buf), 0);
    }

    #[rstest]
    fn decode_varint_reads_multi_byte_values_and_rejects_truncation() {
        assert_eq!(decode_varint(&[0x96, 0x01]), Some((150, 2)));
        assert_eq!(decode_varint(&[0x01, 0xff]), Some((1, 1)));
        assert_eq!(decode_varint(&[0x80]), None);
        assert_eq!(decode_varint(&[]), None);
    }

    #[rstest]
    fn decoder_clone_preserves_limit() {
        let decoder = DepthLimitedDecoder::new(42);
        let cloned = decoder.clone();
        assert_eq!(cloned.max_depth(), 42);
    }
}