mod helpers;
mod packed;
mod varint;
use std::cell::Cell;
use std::collections::HashMap;
use std::sync::Arc;
use prost_reflect::{Cardinality, ExtensionDescriptor, FieldDescriptor, Kind, MessageDescriptor};
use crate::helpers::{
decode_double, decode_fixed32, decode_fixed64, decode_float, decode_sfixed32, decode_sfixed64,
};
use crate::helpers::{
parse_varint, parse_wiretag, WiretagResult, WT_END_GROUP, WT_I32, WT_I64, WT_LEN,
WT_START_GROUP, WT_VARINT,
};
use crate::schema::ParsedSchema;
use crate::serialize::common::{
format_double_protoc, format_fixed32_protoc, format_fixed64_protoc, format_float_protoc,
format_sfixed32_protoc, format_sfixed64_protoc, format_wire_fixed32_protoc,
format_wire_fixed64_protoc,
};
use helpers::{
render_group_field, render_invalid, render_invalid_tag_type, render_len_field, render_scalar,
render_truncated_bytes, ScalarCtx,
};
use varint::{decode_varint_typed, render_varint_field, VarintKind};
/// Byte prefix identifying prototext output; `is_prototext_text` tests for it,
/// and the header line emitted by `decode_and_render` begins with it.
const PROTOTEXT_MAGIC: &[u8] = "#@ prototext:".as_bytes();
/// Schema entry resolved for a field number: either a field declared on the
/// message itself or an extension found on it.
pub(super) enum FieldOrExt {
    /// A regular field of the message.
    Field(FieldDescriptor),
    /// An extension, looked up on the message by field number.
    Ext(ExtensionDescriptor),
}

impl FieldOrExt {
    /// Value kind declared in the schema for this entry.
    pub(super) fn kind(&self) -> Kind {
        match self {
            Self::Field(field) => field.kind(),
            Self::Ext(ext) => ext.kind(),
        }
    }

    /// Cardinality declared in the schema for this entry.
    pub(super) fn cardinality(&self) -> Cardinality {
        match self {
            Self::Field(field) => field.cardinality(),
            Self::Ext(ext) => ext.cardinality(),
        }
    }

    /// `true` only for a regular field that the schema marks as a group.
    ///
    /// NOTE(review): extensions always report `false` here, so a group-typed
    /// extension is not treated as a group — confirm this is intended.
    pub(super) fn is_group(&self) -> bool {
        matches!(self, Self::Field(field) if field.is_group())
    }

    /// `true` only for a regular field that the schema marks as packed.
    ///
    /// NOTE(review): extensions always report `false` here, even if declared
    /// packed — confirm this is intended.
    pub(super) fn is_packed(&self) -> bool {
        matches!(self, Self::Field(field) if field.is_packed())
    }

    /// Name used in text output: the plain field name for regular fields,
    /// or the fully-qualified name wrapped in brackets for extensions.
    pub(super) fn display_name(&self) -> String {
        match self {
            Self::Field(field) => String::from(field.name()),
            Self::Ext(ext) => format!("[{}]", ext.full_name()),
        }
    }

    /// Borrows the underlying descriptor when this entry is a regular field.
    // `dead_code` is allowed because this accessor may have no callers at present.
    #[allow(dead_code)]
    pub(super) fn as_field(&self) -> Option<&FieldDescriptor> {
        if let Self::Field(field) = self {
            Some(field)
        } else {
            None
        }
    }
}
// Per-thread rendering configuration; `decode_and_render` overwrites both
// values at the start of every call.
thread_local! {
    /// Whether annotations were requested for the current render.
    pub(super) static ANNOTATIONS: Cell<bool> = const { Cell::new(false) };
    /// Indentation width per nesting level, taken from `decode_and_render`'s
    /// `indent_size` argument.
    pub(super) static INDENT_SIZE: Cell<usize> = const { Cell::new(2) };
}

// Per-thread traversal state; `decode_and_render` resets it per call.
thread_local! {
    /// Byte offset in the output buffer just past the header line.
    pub(super) static CBL_START: Cell<usize> = const { Cell::new(0) };
    /// Current nesting depth, raised by `enter_level` and lowered again when
    /// the returned `LevelGuard` is dropped.
    pub(super) static LEVEL: Cell<usize> = const { Cell::new(0) };
}
/// Scope guard that undoes one `enter_level` increment of the per-thread
/// `LEVEL` counter when it is dropped.
///
/// Marked `#[must_use]`: writing `enter_level();` as a bare statement would
/// drop the guard immediately and leave the level unchanged for the code that
/// follows, which is always a bug.
#[must_use = "bind the guard (e.g. `let _guard = enter_level();`) to keep the level raised"]
pub(super) struct LevelGuard;

impl Drop for LevelGuard {
    fn drop(&mut self) {
        LEVEL.with(|level| {
            let depth = level.get();
            // An unbalanced drop (e.g. a `LevelGuard` constructed directly rather
            // than through `enter_level`) is a bug. Report it in debug builds, but
            // never let the depth wrap around to `usize::MAX` in release builds.
            debug_assert!(depth > 0, "LevelGuard dropped while LEVEL is already 0");
            level.set(depth.saturating_sub(1));
        });
    }
}

/// Increments the per-thread nesting `LEVEL` and returns a guard that
/// decrements it again when dropped.
///
/// Because the decrement runs in `Drop`, the depth stays balanced on every
/// exit path of the caller's scope, including early returns.
pub(super) fn enter_level() -> LevelGuard {
    LEVEL.with(|level| level.set(level.get() + 1));
    LevelGuard
}
pub fn is_prototext_text(data: &[u8]) -> bool {
data.starts_with(PROTOTEXT_MAGIC)
}
pub fn decode_and_render(
buf: &[u8],
schema: Option<&ParsedSchema>,
annotations: bool,
indent_size: usize,
) -> Vec<u8> {
let capacity = buf.len() * 8;
let mut out = Vec::with_capacity(capacity);
out.extend_from_slice(b"#@ prototext: protoc\n");
CBL_START.with(|c| c.set(out.len()));
ANNOTATIONS.with(|c| c.set(annotations));
INDENT_SIZE.with(|c| c.set(indent_size));
LEVEL.with(|c| c.set(0));
let all_descriptors: Option<HashMap<String, Arc<MessageDescriptor>>> =
schema.map(|s| build_descriptor_map(s));
let all_schemas = all_descriptors.as_ref();
let root_desc: Option<MessageDescriptor> = schema.and_then(|s| s.root_descriptor());
render_message(buf, 0, None, root_desc.as_ref(), all_schemas, &mut out);
#[cfg(debug_assertions)]
{
let actual = out.len();
if actual < capacity {
eprintln!(
"[render_text] truncate: input_len={} capacity={} actual={} ratio={:.2}x",
buf.len(),
capacity,
actual,
actual as f64 / buf.len().max(1) as f64
);
}
}
out
}
/// Indexes every message type in the schema's descriptor pool by its
/// fully-qualified name.
///
/// If two messages share a name, the one visited last wins.
fn build_descriptor_map(schema: &ParsedSchema) -> HashMap<String, Arc<MessageDescriptor>> {
    let pool = schema.pool();
    let mut by_name = HashMap::new();
    for message in pool.all_messages() {
        let key = message.full_name().to_owned();
        by_name.insert(key, Arc::new(message));
    }
    by_name
}
/// Renders the wire-format fields of one message, reading from `buf[start..]`
/// and appending text to `out`.
///
/// Fields are parsed one after another until one of the following happens:
///
/// * End of buffer: returns `(buf.len(), None)`.
/// * END_GROUP tag while `my_group` is `Some`: returns the position just after
///   the tag together with the parsed tag. The END_GROUP field number is not
///   compared with `my_group` here. NOTE(review): presumably the caller
///   validates it.
/// * END_GROUP tag while `my_group` is `None`, or any malformed or truncated
///   field: renders an invalid/truncated marker and returns
///   `(buf.len(), None)`, so the rest of the buffer counts as consumed.
///
/// `schema`, when present, is the descriptor of the message being rendered.
/// It resolves each field number to a regular field or an extension.
/// `all_schemas` is passed through to nested length-delimited and group fields.
pub(super) fn render_message(
    buf: &[u8],
    start: usize,
    my_group: Option<u64>,
    schema: Option<&MessageDescriptor>,
    all_schemas: Option<&HashMap<String, Arc<MessageDescriptor>>>,
    out: &mut Vec<u8>,
) -> (usize, Option<WiretagResult>) {
    let buflen = buf.len();
    let mut pos = start;
    loop {
        if pos == buflen {
            return (pos, None);
        }

        let tag = parse_wiretag(buf, pos);
        if let Some(ref wtag_gar) = tag.wtag_gar {
            render_invalid_tag_type(wtag_gar, out);
            return (buflen, None);
        }
        // parse_wiretag is relied on to fill in both parts whenever it reports
        // no garbage.
        let field_number = tag
            .wfield
            .expect("parse_wiretag: field number missing without wtag_gar");
        let wire_type = tag
            .wtype
            .expect("parse_wiretag: wire type missing without wtag_gar");
        let tag_ohb = tag.wfield_ohb;
        let tag_oor = tag.wfield_oor.is_some();
        pos = tag.next_pos;

        // Resolve the field (or extension) in the schema. A field number that
        // does not fit in `u32` cannot belong to any descriptor. Skipping the
        // lookup avoids a truncating cast aliasing it onto an unrelated
        // low-numbered field.
        let field_schema: Option<FieldOrExt> = schema
            .zip(u32::try_from(field_number).ok())
            .and_then(|(s, number)| {
                s.get_field(number)
                    .map(FieldOrExt::Field)
                    .or_else(|| s.get_extension(number).map(FieldOrExt::Ext))
            });

        match wire_type {
            WT_VARINT => {
                let vr = parse_varint(buf, pos);
                if let Some(ref varint_gar) = vr.varint_gar {
                    render_invalid(
                        field_number,
                        field_schema.as_ref(),
                        tag_ohb,
                        tag_oor,
                        "INVALID_VARINT",
                        varint_gar,
                        out,
                    );
                    return (buflen, None);
                }
                pos = vr.next_pos;
                let val_ohb = vr.varint_ohb;
                let val = vr
                    .varint
                    .expect("parse_varint: value missing without varint_gar");
                let (content_kind, typed_val) = if let Some(ref fs) = field_schema {
                    decode_varint_typed(val, fs)
                } else {
                    (VarintKind::Wire, val)
                };
                render_varint_field(
                    field_number,
                    field_schema.as_ref(),
                    tag_ohb,
                    tag_oor,
                    val_ohb,
                    content_kind,
                    typed_val,
                    out,
                );
            }
            WT_I64 => {
                if pos + 8 > buflen {
                    let raw = &buf[pos..];
                    render_invalid(
                        field_number,
                        field_schema.as_ref(),
                        tag_ohb,
                        tag_oor,
                        "INVALID_FIXED64",
                        raw,
                        out,
                    );
                    return (buflen, None);
                }
                let data = &buf[pos..pos + 8];
                pos += 8;
                // (rendered text, schema/wire mismatch, non-canonical NaN bits)
                let (value_str, is_mismatch, nan_bits) =
                    match field_schema.as_ref().map(FieldOrExt::kind) {
                        Some(Kind::Double) => {
                            let v = decode_double(data);
                            let bits = v.to_bits();
                            // Only NaN payloads differing from the canonical NaN
                            // are reported.
                            let nan_bits = if v.is_nan() && bits != f64::NAN.to_bits() {
                                Some(bits)
                            } else {
                                None
                            };
                            (format_double_protoc(v), false, nan_bits)
                        }
                        Some(Kind::Fixed64) => {
                            (format_fixed64_protoc(decode_fixed64(data)), false, None)
                        }
                        Some(Kind::Sfixed64) => {
                            (format_sfixed64_protoc(decode_sfixed64(data)), false, None)
                        }
                        // The schema field exists but is not an I64-encoded type.
                        Some(_) => (format_wire_fixed64_protoc(decode_fixed64(data)), true, None),
                        None => (format_wire_fixed64_protoc(decode_fixed64(data)), false, None),
                    };
                render_scalar(
                    &ScalarCtx {
                        field_number,
                        field_schema: field_schema.as_ref(),
                        tag_ohb,
                        tag_oor,
                        len_ohb: None,
                        wire_type_name: "fixed64",
                        nan_bits,
                    },
                    &value_str,
                    is_mismatch,
                    out,
                );
            }
            WT_LEN => {
                let lr = parse_varint(buf, pos);
                if let Some(ref varint_gar) = lr.varint_gar {
                    render_invalid(
                        field_number,
                        field_schema.as_ref(),
                        tag_ohb,
                        tag_oor,
                        "INVALID_LEN",
                        varint_gar,
                        out,
                    );
                    return (buflen, None);
                }
                let len_ohb = lr.varint_ohb;
                pos = lr.next_pos;
                let declared_len = lr
                    .varint
                    .expect("parse_varint: length missing without varint_gar");
                let remaining = buflen - pos;
                // Compare in u64: `pos + len` can overflow `usize` for a hostile
                // length prefix (then the slice below panics), and casting the
                // length to `usize` first would truncate it on 32-bit targets.
                if declared_len > remaining as u64 {
                    let missing = declared_len - remaining as u64;
                    let raw = &buf[pos..];
                    render_truncated_bytes(
                        field_number,
                        tag_ohb,
                        tag_oor,
                        len_ohb,
                        missing,
                        raw,
                        out,
                    );
                    return (buflen, None);
                }
                // Lossless: `declared_len <= remaining`, and `remaining` is a usize.
                let length = declared_len as usize;
                let data = &buf[pos..pos + length];
                pos += length;
                render_len_field(
                    field_number,
                    field_schema.as_ref(),
                    all_schemas,
                    tag_ohb,
                    tag_oor,
                    len_ohb,
                    data,
                    out,
                );
            }
            WT_START_GROUP => {
                render_group_field(
                    buf,
                    &mut pos,
                    field_number,
                    field_schema.as_ref(),
                    all_schemas,
                    tag_ohb,
                    tag_oor,
                    out,
                );
            }
            WT_END_GROUP => {
                if my_group.is_none() {
                    let raw = &buf[pos..];
                    render_invalid(
                        field_number,
                        field_schema.as_ref(),
                        tag_ohb,
                        tag_oor,
                        "INVALID_GROUP_END",
                        raw,
                        out,
                    );
                    return (buflen, None);
                }
                return (pos, Some(tag));
            }
            WT_I32 => {
                if pos + 4 > buflen {
                    let raw = &buf[pos..];
                    render_invalid(
                        field_number,
                        field_schema.as_ref(),
                        tag_ohb,
                        tag_oor,
                        "INVALID_FIXED32",
                        raw,
                        out,
                    );
                    return (buflen, None);
                }
                let data = &buf[pos..pos + 4];
                pos += 4;
                // (rendered text, schema/wire mismatch, non-canonical NaN bits)
                let (value_str, is_mismatch, nan_bits) =
                    match field_schema.as_ref().map(FieldOrExt::kind) {
                        Some(Kind::Float) => {
                            let v = decode_float(data);
                            let bits = v.to_bits();
                            // Only NaN payloads differing from the canonical NaN
                            // are reported.
                            let nan_bits = if v.is_nan() && bits != f32::NAN.to_bits() {
                                Some(u64::from(bits))
                            } else {
                                None
                            };
                            (format_float_protoc(v), false, nan_bits)
                        }
                        Some(Kind::Fixed32) => {
                            (format_fixed32_protoc(decode_fixed32(data)), false, None)
                        }
                        Some(Kind::Sfixed32) => {
                            (format_sfixed32_protoc(decode_sfixed32(data)), false, None)
                        }
                        // The schema field exists but is not an I32-encoded type.
                        Some(_) => (format_wire_fixed32_protoc(decode_fixed32(data)), true, None),
                        None => (format_wire_fixed32_protoc(decode_fixed32(data)), false, None),
                    };
                render_scalar(
                    &ScalarCtx {
                        field_number,
                        field_schema: field_schema.as_ref(),
                        tag_ohb,
                        tag_oor,
                        len_ohb: None,
                        wire_type_name: "fixed32",
                        nan_bits,
                    },
                    &value_str,
                    is_mismatch,
                    out,
                );
            }
            _ => unreachable!("wire type > 5 caught by parse_wiretag"),
        }
    }
}