#![forbid(unsafe_code)]
use super::buf::{DecodeBuffer, EncodeBuffer};
use super::{Tag, WireError, WireType};
/// Allocation counters accumulated while decoding.
///
/// The values are whatever callers report through the `record_*` methods (or
/// combine with `merge`); this type does not observe the allocator itself.
/// All fields are public, and `AllocReport::from_wire_bytes` also assigns
/// them directly from serialized data.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DecodeStats {
/// Number of owned-string allocations reported via `record_string`.
pub string_alloc_count: usize,
/// Sum of the byte sizes reported via `record_string`.
pub string_alloc_bytes: usize,
/// Number of byte-buffer allocations reported via `record_bytes`.
pub bytes_alloc_count: usize,
/// Sum of the byte sizes reported via `record_bytes`.
pub bytes_alloc_bytes: usize,
/// Number of repeated-field resizes reported via `record_repeated_resize`.
pub repeated_resize_count: usize,
/// Number of repeated-field elements reported; incremented by both
/// `record_repeated_resize` and `record_repeated_element`.
pub repeated_element_count: usize,
/// Sum of the byte sizes passed to every `record_*` method.
pub total_alloc_bytes: usize,
/// Allocation count across strings, byte buffers and repeated resizes;
/// `record_repeated_element` does not contribute to it.
pub total_alloc_count: usize,
}
impl DecodeStats {
    /// Creates a zeroed set of counters; equivalent to `DecodeStats::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `count` allocations totalling `bytes` bytes to the overall totals.
    ///
    /// All counter updates in this type saturate at `usize::MAX` instead of
    /// overflowing. These are diagnostic counters that can accumulate across
    /// many decodes via [`merge`](Self::merge). A profiling aid must neither
    /// panic (debug builds) nor silently wrap to a tiny value (release builds).
    #[inline]
    fn add_totals(&mut self, count: usize, bytes: usize) {
        self.total_alloc_count = self.total_alloc_count.saturating_add(count);
        self.total_alloc_bytes = self.total_alloc_bytes.saturating_add(bytes);
    }

    /// Records one owned-string allocation of `bytes` bytes.
    #[inline]
    pub fn record_string(&mut self, bytes: usize) {
        self.string_alloc_count = self.string_alloc_count.saturating_add(1);
        self.string_alloc_bytes = self.string_alloc_bytes.saturating_add(bytes);
        self.add_totals(1, bytes);
    }

    /// Records one byte-buffer allocation of `bytes` bytes.
    #[inline]
    pub fn record_bytes(&mut self, bytes: usize) {
        self.bytes_alloc_count = self.bytes_alloc_count.saturating_add(1);
        self.bytes_alloc_bytes = self.bytes_alloc_bytes.saturating_add(bytes);
        self.add_totals(1, bytes);
    }

    /// Records a repeated-field push that triggered a resize.
    ///
    /// Counts as one resize, one element, and one allocation of
    /// `element_bytes` bytes.
    #[inline]
    pub fn record_repeated_resize(&mut self, element_bytes: usize) {
        self.repeated_resize_count = self.repeated_resize_count.saturating_add(1);
        self.repeated_element_count = self.repeated_element_count.saturating_add(1);
        self.add_totals(1, element_bytes);
    }

    /// Records a repeated-field push that fit in existing capacity.
    ///
    /// Counts one element and adds `element_bytes` to the byte total, but does
    /// not count as an allocation.
    #[inline]
    pub fn record_repeated_element(&mut self, element_bytes: usize) {
        self.repeated_element_count = self.repeated_element_count.saturating_add(1);
        self.add_totals(0, element_bytes);
    }

    /// Adds every counter of `other` into `self` (saturating).
    pub fn merge(&mut self, other: &DecodeStats) {
        self.string_alloc_count = self.string_alloc_count.saturating_add(other.string_alloc_count);
        self.string_alloc_bytes = self.string_alloc_bytes.saturating_add(other.string_alloc_bytes);
        self.bytes_alloc_count = self.bytes_alloc_count.saturating_add(other.bytes_alloc_count);
        self.bytes_alloc_bytes = self.bytes_alloc_bytes.saturating_add(other.bytes_alloc_bytes);
        self.repeated_resize_count =
            self.repeated_resize_count.saturating_add(other.repeated_resize_count);
        self.repeated_element_count =
            self.repeated_element_count.saturating_add(other.repeated_element_count);
        self.add_totals(other.total_alloc_count, other.total_alloc_bytes);
    }

    /// Resets every counter to zero.
    pub fn reset(&mut self) {
        *self = Self::default();
    }

    /// Returns `true` when every counter is zero.
    ///
    /// All fields are checked, not just the totals. A
    /// `record_repeated_element(0)` call leaves the totals at zero but
    /// `repeated_element_count` at one, and that is not "zero".
    pub fn is_zero(&self) -> bool {
        *self == Self::default()
    }

    /// Renders a one-line human-readable summary of all counters.
    pub fn summary(&self) -> prost::alloc::string::String {
        let Self {
            string_alloc_count: sc,
            string_alloc_bytes: sb,
            bytes_alloc_count: bc,
            bytes_alloc_bytes: bb,
            repeated_resize_count: rc,
            repeated_element_count: re,
            total_alloc_bytes: tb,
            total_alloc_count: total,
        } = *self;
        prost::alloc::format!(
            "allocs: {total} (strings={sc}/{sb}B, bytes={bc}/{bb}B, repeated={rc}/elem={re}, total_bytes={tb}B)",
        )
    }
}
impl core::fmt::Display for DecodeStats {
    /// Emits exactly the text produced by [`DecodeStats::summary`].
    ///
    /// Width, fill and precision flags on the outer format spec are not
    /// applied to the summary text.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{}", self.summary())
    }
}
/// A [`DecodeBuffer`] paired with a mutable borrow of a [`DecodeStats`].
///
/// Reads are forwarded to the inner buffer unchanged and never update the
/// stats. Callers record allocations explicitly through the `record_*`
/// methods when they copy decoded data into owned storage.
pub struct ProfiledDecodeBuffer<'buf, 'stats> {
// The decoder that actually walks the input bytes.
inner: DecodeBuffer<'buf>,
// Counters that the `record_*` methods write into.
stats: &'stats mut DecodeStats,
}
impl<'buf, 'stats> ProfiledDecodeBuffer<'buf, 'stats> {
pub fn new(bytes: &'buf [u8], stats: &'stats mut DecodeStats) -> Self {
Self {
inner: DecodeBuffer::new(bytes),
stats,
}
}
pub fn inner(&self) -> &DecodeBuffer<'buf> {
&self.inner
}
pub fn inner_mut(&mut self) -> &mut DecodeBuffer<'buf> {
&mut self.inner
}
pub fn stats(&self) -> &DecodeStats {
self.stats
}
pub fn read_tag(&mut self) -> Result<Tag, WireError> {
self.inner.read_tag()
}
pub fn read_varint(&mut self) -> Result<u64, WireError> {
self.inner.read_varint()
}
pub fn read_varint32(&mut self) -> Result<u32, WireError> {
self.inner.read_varint32()
}
pub fn read_varint_i64(&mut self) -> Result<i64, WireError> {
self.inner.read_varint_i64()
}
pub fn read_varint_i32(&mut self) -> Result<i32, WireError> {
self.inner.read_varint_i32()
}
pub fn read_bool(&mut self) -> Result<bool, WireError> {
self.inner.read_bool()
}
pub fn read_fixed32(&mut self) -> Result<u32, WireError> {
self.inner.read_fixed32()
}
pub fn read_fixed64(&mut self) -> Result<u64, WireError> {
self.inner.read_fixed64()
}
pub fn read_float(&mut self) -> Result<f32, WireError> {
self.inner.read_float()
}
pub fn read_double(&mut self) -> Result<f64, WireError> {
self.inner.read_double()
}
pub fn read_length_delimited(&mut self) -> Result<&'buf [u8], WireError> {
self.inner.read_length_delimited()
}
pub fn read_string(&mut self) -> Result<&'buf str, WireError> {
self.inner.read_string()
}
pub fn skip_field(&mut self, wire_type: WireType) -> Result<(), WireError> {
self.inner.skip_field(wire_type)
}
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
pub fn remaining(&self) -> usize {
self.inner.remaining()
}
#[inline]
pub fn record_string_alloc(&mut self, bytes: usize) {
self.stats.record_string(bytes);
}
#[inline]
pub fn record_bytes_alloc(&mut self, bytes: usize) {
self.stats.record_bytes(bytes);
}
#[inline]
pub fn record_repeated_resize(&mut self, element_bytes: usize) {
self.stats.record_repeated_resize(element_bytes);
}
#[inline]
pub fn record_repeated_element(&mut self, element_bytes: usize) {
self.stats.record_repeated_element(element_bytes);
}
}
/// Raw decode counters plus ratios derived from them by `AllocReport::from_stats`.
#[derive(Debug, Clone)]
pub struct AllocReport {
/// The counters the ratios were derived from.
pub stats: DecodeStats,
/// `stats.total_alloc_bytes / stats.total_alloc_count` (integer division),
/// or 0 when no allocations were counted.
pub avg_bytes_per_alloc: usize,
/// `stats.total_alloc_bytes * 100 / wire_bytes_read`, clamped to 100,
/// or 0 when `wire_bytes_read` is 0.
pub heap_fraction_pct: u8,
/// Number of encoded input bytes the caller says were decoded.
pub wire_bytes_read: usize,
}
impl AllocReport {
pub fn from_stats(stats: DecodeStats, wire_bytes_read: usize) -> Self {
let avg_bytes_per_alloc = stats
.total_alloc_bytes
.checked_div(stats.total_alloc_count)
.unwrap_or(0);
let heap_fraction_pct = (stats.total_alloc_bytes * 100)
.checked_div(wire_bytes_read)
.map(|frac| frac.min(100) as u8)
.unwrap_or(0u8);
AllocReport {
stats,
avg_bytes_per_alloc,
heap_fraction_pct,
wire_bytes_read,
}
}
pub fn to_wire_bytes(&self) -> prost::alloc::vec::Vec<u8> {
use super::encode_varint;
let s = &self.stats;
let mut out = prost::alloc::vec::Vec::with_capacity(64);
out.push(0x08); encode_varint(s.string_alloc_count as u64, &mut out);
out.push(0x10); encode_varint(s.string_alloc_bytes as u64, &mut out);
out.push(0x18); encode_varint(s.bytes_alloc_count as u64, &mut out);
out.push(0x20); encode_varint(s.bytes_alloc_bytes as u64, &mut out);
out.push(0x28); encode_varint(s.repeated_resize_count as u64, &mut out);
out.push(0x30); encode_varint(s.repeated_element_count as u64, &mut out);
out.push(0x38); encode_varint(s.total_alloc_count as u64, &mut out);
out.push(0x40); encode_varint(s.total_alloc_bytes as u64, &mut out);
out.push(0x48); encode_varint(self.wire_bytes_read as u64, &mut out);
out
}
pub fn from_wire_bytes(bytes: &[u8]) -> Result<Self, WireError> {
let mut dec = DecodeBuffer::new(bytes);
let mut s = DecodeStats::new();
let mut wire_bytes_read = 0usize;
while !dec.is_empty() {
let tag = dec.read_tag()?;
match tag.field_number {
1 => s.string_alloc_count = dec.read_varint()? as usize,
2 => s.string_alloc_bytes = dec.read_varint()? as usize,
3 => s.bytes_alloc_count = dec.read_varint()? as usize,
4 => s.bytes_alloc_bytes = dec.read_varint()? as usize,
5 => s.repeated_resize_count = dec.read_varint()? as usize,
6 => s.repeated_element_count = dec.read_varint()? as usize,
7 => s.total_alloc_count = dec.read_varint()? as usize,
8 => s.total_alloc_bytes = dec.read_varint()? as usize,
9 => wire_bytes_read = dec.read_varint()? as usize,
_ => dec.skip_field(tag.wire_type)?,
}
}
Ok(Self::from_stats(s, wire_bytes_read))
}
}
/// Upper limits on allocation totals, enforced by `AllocBudget::check`.
///
/// A limit is exceeded only when the corresponding total is strictly greater
/// than it; reaching the limit exactly is allowed.
#[derive(Debug, Clone)]
pub struct AllocBudget {
/// Maximum allowed `DecodeStats::total_alloc_bytes`.
pub max_bytes: usize,
/// Maximum allowed `DecodeStats::total_alloc_count`.
pub max_allocs: usize,
}
impl AllocBudget {
pub fn new(max_bytes: usize, max_allocs: usize) -> Self {
Self {
max_bytes,
max_allocs,
}
}
pub fn check(&self, stats: &DecodeStats) -> Result<(), BudgetExceeded> {
if stats.total_alloc_bytes > self.max_bytes {
return Err(BudgetExceeded::Bytes {
used: stats.total_alloc_bytes,
limit: self.max_bytes,
});
}
if stats.total_alloc_count > self.max_allocs {
return Err(BudgetExceeded::Count {
used: stats.total_alloc_count,
limit: self.max_allocs,
});
}
Ok(())
}
}
/// Which limit of an `AllocBudget` was exceeded, with the observed value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BudgetExceeded {
/// `total_alloc_bytes` (`used`) was greater than `max_bytes` (`limit`).
Bytes {
used: usize,
limit: usize,
},
/// `total_alloc_count` (`used`) was greater than `max_allocs` (`limit`).
Count {
used: usize,
limit: usize,
},
}
impl core::fmt::Display for BudgetExceeded {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
BudgetExceeded::Bytes { used, limit } => {
write!(f, "allocation byte budget exceeded: {used} > {limit}")
}
BudgetExceeded::Count { used, limit } => {
write!(f, "allocation count budget exceeded: {used} > {limit}")
}
}
}
}
// Marker impl so `BudgetExceeded` can be used as a standard error (e.g. boxed
// as `dyn Error`). It relies on the derived `Debug` and the hand-written
// `Display`, and has no source error.
impl core::error::Error for BudgetExceeded {}
/// Lets a value report its own allocation into a [`DecodeStats`].
pub trait EncodeAllocProfile {
/// Adds this value's allocation to `stats`; what is counted is up to the implementor.
fn record_alloc(&self, stats: &mut DecodeStats);
}
impl EncodeAllocProfile for EncodeBuffer {
    /// Records the buffer as one byte-buffer allocation sized by its current
    /// encoded length.
    // NOTE(review): this counts `len()` (bytes written), not the buffer's
    // capacity; confirm that is the intended measure.
    fn record_alloc(&self, stats: &mut DecodeStats) {
        let encoded_len = self.len();
        stats.record_bytes(encoded_len);
    }
}
// Unit tests for the counters, the profiling wrapper, reports and budgets.
#[cfg(test)]
mod tests {
use super::*;
use crate::wire::{EncodeBuffer, WireType};
// Encodes three fields: 1 = "hello" (Len), 2 = "world" (Len), 3 = 42 (Varint).
fn make_payload() -> prost::alloc::vec::Vec<u8> {
let mut enc = EncodeBuffer::new();
enc.write_tag(1, WireType::Len).expect("tag1");
enc.write_string("hello");
enc.write_tag(2, WireType::Len).expect("tag2");
enc.write_string("world");
enc.write_tag(3, WireType::Varint).expect("tag3");
enc.write_varint(42);
enc.into_vec()
}
// A freshly constructed `DecodeStats` reports zero.
#[test]
fn decode_stats_default_is_zero() {
let s = DecodeStats::new();
assert!(s.is_zero());
}
// `record_string` bumps the string counters and both totals.
#[test]
fn record_string_increments_counters() {
let mut s = DecodeStats::new();
s.record_string(10);
assert_eq!(s.string_alloc_count, 1);
assert_eq!(s.string_alloc_bytes, 10);
assert_eq!(s.total_alloc_count, 1);
assert_eq!(s.total_alloc_bytes, 10);
}
// `record_bytes` bumps the byte-buffer counters and both totals.
#[test]
fn record_bytes_increments_counters() {
let mut s = DecodeStats::new();
s.record_bytes(20);
assert_eq!(s.bytes_alloc_count, 1);
assert_eq!(s.bytes_alloc_bytes, 20);
assert_eq!(s.total_alloc_count, 1);
assert_eq!(s.total_alloc_bytes, 20);
}
// A resize counts as a resize, an element, and an allocation.
#[test]
fn record_repeated_resize_increments_counters() {
let mut s = DecodeStats::new();
s.record_repeated_resize(8);
assert_eq!(s.repeated_resize_count, 1);
assert_eq!(s.repeated_element_count, 1);
assert_eq!(s.total_alloc_count, 1);
assert_eq!(s.total_alloc_bytes, 8);
}
// A non-resizing push adds bytes but is not counted as an allocation.
#[test]
fn record_repeated_element_no_resize_no_count() {
let mut s = DecodeStats::new();
s.record_repeated_element(4);
assert_eq!(s.repeated_element_count, 1);
assert_eq!(s.repeated_resize_count, 0);
assert_eq!(s.total_alloc_count, 0);
assert_eq!(s.total_alloc_bytes, 4);
}
// `merge` sums per-category counters and totals from both sides.
#[test]
fn merge_adds_counters() {
let mut a = DecodeStats::new();
a.record_string(5);
let mut b = DecodeStats::new();
b.record_bytes(10);
a.merge(&b);
assert_eq!(a.string_alloc_count, 1);
assert_eq!(a.bytes_alloc_count, 1);
assert_eq!(a.total_alloc_count, 2);
assert_eq!(a.total_alloc_bytes, 15);
}
// `reset` returns a used instance to the zero state.
#[test]
fn reset_clears_all() {
let mut s = DecodeStats::new();
s.record_string(100);
s.record_bytes(200);
s.reset();
assert!(s.is_zero());
}
// Reads go through the wrapper in payload order. Only the explicit
// `record_string_alloc` calls change the stats.
#[test]
fn profiled_buffer_delegates_reads() {
let payload = make_payload();
let mut stats = DecodeStats::new();
let mut prof = ProfiledDecodeBuffer::new(&payload, &mut stats);
let t1 = prof.read_tag().expect("tag1");
assert_eq!(t1.field_number, 1);
let s1 = prof.read_string().expect("str1");
assert_eq!(s1, "hello");
prof.record_string_alloc(s1.len());
let t2 = prof.read_tag().expect("tag2");
assert_eq!(t2.field_number, 2);
let s2 = prof.read_string().expect("str2");
assert_eq!(s2, "world");
prof.record_string_alloc(s2.len());
let t3 = prof.read_tag().expect("tag3");
assert_eq!(t3.field_number, 3);
let v = prof.read_varint().expect("varint");
assert_eq!(v, 42);
assert!(prof.is_empty());
// `prof` is no longer used here, so its `&mut stats` borrow has ended.
assert_eq!(stats.string_alloc_count, 2);
assert_eq!(stats.string_alloc_bytes, 10);
assert_eq!(stats.total_alloc_count, 2);
assert_eq!(stats.total_alloc_bytes, 10);
}
// 200 heap bytes over 2 allocs and 400 wire bytes -> avg 100, 50%.
#[test]
fn alloc_report_avg_and_fraction() {
let mut s = DecodeStats::new();
s.record_string(100);
s.record_bytes(100);
let report = AllocReport::from_stats(s, 400);
assert_eq!(report.avg_bytes_per_alloc, 100);
assert_eq!(report.heap_fraction_pct, 50);
}
// Zero divisors yield 0 ratios rather than panicking.
#[test]
fn alloc_report_zero_wire_bytes() {
let s = DecodeStats::new();
let report = AllocReport::from_stats(s, 0);
assert_eq!(report.heap_fraction_pct, 0);
assert_eq!(report.avg_bytes_per_alloc, 0);
}
// Encoding then decoding preserves the counters and `wire_bytes_read`,
// and the derived ratios are recomputed identically.
#[test]
fn alloc_report_wire_round_trip() {
let mut s = DecodeStats::new();
s.record_string(30);
s.record_bytes(60);
s.record_repeated_resize(4);
s.record_repeated_element(4);
let report = AllocReport::from_stats(s, 200);
let wire = report.to_wire_bytes();
let decoded = AllocReport::from_wire_bytes(&wire).expect("decode");
assert_eq!(decoded.stats, report.stats);
assert_eq!(decoded.wire_bytes_read, 200);
assert_eq!(decoded.avg_bytes_per_alloc, report.avg_bytes_per_alloc);
}
// Totals under both limits pass.
#[test]
fn budget_ok_within_limits() {
let mut s = DecodeStats::new();
s.record_string(100);
let budget = AllocBudget::new(200, 10);
assert!(budget.check(&s).is_ok());
}
// Exceeding the byte limit reports the `Bytes` variant.
#[test]
fn budget_exceeded_bytes() {
let mut s = DecodeStats::new();
s.record_string(300);
let budget = AllocBudget::new(200, 10);
let err = budget.check(&s).unwrap_err();
assert!(matches!(err, BudgetExceeded::Bytes { .. }));
}
// Exceeding only the count limit reports the `Count` variant.
#[test]
fn budget_exceeded_count() {
let mut s = DecodeStats::new();
for _ in 0..5 {
s.record_string(1);
}
let budget = AllocBudget::new(10000, 3);
let err = budget.check(&s).unwrap_err();
assert!(matches!(err, BudgetExceeded::Count { .. }));
}
// An `EncodeBuffer` records its current length as one byte allocation.
#[test]
fn encode_alloc_profile_records_len() {
let mut enc = EncodeBuffer::new();
enc.write_varint(42);
let len = enc.len();
let mut stats = DecodeStats::new();
enc.record_alloc(&mut stats);
assert_eq!(stats.bytes_alloc_bytes, len);
assert_eq!(stats.bytes_alloc_count, 1);
}
// `summary` and `Display` both produce non-empty text.
#[test]
fn decode_stats_summary_no_panic() {
let mut s = DecodeStats::new();
s.record_string(5);
let summary = s.summary();
assert!(!summary.is_empty());
let display_str = prost::alloc::format!("{s}");
assert!(!display_str.is_empty());
}
// The error message includes both the used value and the limit.
#[test]
fn budget_exceeded_display() {
let b = BudgetExceeded::Bytes {
used: 300,
limit: 200,
};
let s = prost::alloc::format!("{b}");
assert!(s.contains("300"));
assert!(s.contains("200"));
}
// Copying a length-delimited slice and reporting it updates the byte counters.
#[test]
fn profiled_buffer_record_bytes_alloc() {
let mut enc = EncodeBuffer::new();
enc.write_tag(1, WireType::Len).expect("tag");
enc.write_length_delimited(&[1, 2, 3]);
let payload = enc.into_vec();
let mut stats = DecodeStats::new();
let mut prof = ProfiledDecodeBuffer::new(&payload, &mut stats);
let t = prof.read_tag().expect("tag");
assert_eq!(t.field_number, 1);
let raw = prof.read_length_delimited().expect("bytes");
let owned = raw.to_vec();
assert_eq!(owned, [1, 2, 3]);
prof.record_bytes_alloc(owned.len());
assert_eq!(stats.bytes_alloc_count, 1);
assert_eq!(stats.bytes_alloc_bytes, 3);
}
// Fixed-width reads through the wrapper leave the stats untouched.
#[test]
fn profiled_buffer_fixed_reads() {
let mut enc = EncodeBuffer::new();
enc.write_tag(1, WireType::I32).expect("tag");
enc.write_fixed32(0xDEAD);
enc.write_tag(2, WireType::I64).expect("tag");
enc.write_fixed64(0xCAFE_BABE);
let payload = enc.into_vec();
let mut stats = DecodeStats::new();
let mut prof = ProfiledDecodeBuffer::new(&payload, &mut stats);
let _t1 = prof.read_tag().expect("t1");
let _ = prof.read_fixed32().expect("f32");
let _t2 = prof.read_tag().expect("t2");
let _ = prof.read_fixed64().expect("f64");
assert!(stats.is_zero());
}
// Before any read, `remaining` equals the payload length and the buffer is
// not empty. Despite the name, no position accessor is exercised here.
#[test]
fn profiled_buffer_remaining_and_position() {
let payload = make_payload();
let total = payload.len();
let mut stats = DecodeStats::new();
let prof = ProfiledDecodeBuffer::new(&payload, &mut stats);
assert_eq!(prof.remaining(), total);
assert!(!prof.is_empty());
}
}