use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use buffa::Message;
use buffa_descriptor::generated::descriptor::{FileDescriptorProto, FileDescriptorSet};
use buffa_descriptor::{DescriptorPool, PoolError};
/// Errors raised while building or extending a [`Reflector`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ReflectionError {
    /// The input is not a decodable `FileDescriptorSet`.
    #[error("failed to decode FileDescriptorSet: {0}")]
    Decode(#[from] buffa::DecodeError),

    /// The descriptor pool rejected the decoded descriptors.
    #[error("invalid descriptor set: {0}")]
    Pool(#[from] PoolError),

    /// The top-level wire framing is invalid; `offset` is the byte position of
    /// the field tag at which splitting the set failed.
    #[error("malformed FileDescriptorSet framing at byte {offset}")]
    MalformedFraming { offset: usize },

    /// The `FileDescriptorProto` at position `index` has no `name`.
    #[error("FileDescriptorProto at index {index} has no name")]
    UnnamedFile { index: usize },

    /// Splitting the raw bytes and decoding them disagree on the file count.
    #[error("FileDescriptorSet framing yields {framed} files but decoding yields {decoded}")]
    CountMismatch { framed: usize, decoded: usize },

    /// The pool is shared with other `Arc` handles and cannot be mutated.
    #[error("cannot add to a descriptor pool with outstanding references")]
    SharedPool,
}
/// Result of a single reflection query.
///
/// Derives `Debug`/`PartialEq` so answers can be printed in diagnostics and
/// compared directly in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Answer {
    /// Serialized `FileDescriptorProto`s: the matched file first, then its
    /// transitive dependencies.
    Files(Vec<Vec<u8>>),
    /// Extension field numbers declared for a message.
    ExtensionNumbers {
        /// Fully-qualified message name, without a leading `.`.
        base_type: String,
        /// Field numbers of the extensions found for `base_type`.
        numbers: Vec<i32>,
    },
    /// Fully-qualified names of the advertised services.
    Services(Vec<String>),
    /// Human-readable description of what was not found.
    NotFound(String),
}
/// Answers gRPC server-reflection queries from a descriptor pool.
///
/// Lookups consult this reflector's own pool first and then fall back to the
/// descriptors embedded in the crate (`crate::FILE_DESCRIPTOR_SET`).
pub struct Reflector {
    /// Linked descriptors. Held in an `Arc` so an existing pool can be reused.
    pool: Arc<DescriptorPool>,
    /// Serialized `FileDescriptorProto` bytes returned to clients, by file name.
    response_bytes: HashMap<String, Vec<u8>>,
    /// When set, replaces the computed list of advertised service names.
    services_override: Option<Vec<String>>,
}
impl std::fmt::Debug for Reflector {
    // Summarise as a file count plus the advertised services instead of
    // dumping every descriptor.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let file_count = self.pool.files().len();
        let advertised = self.service_names();
        let mut summary = f.debug_struct("Reflector");
        summary.field("files", &file_count);
        summary.field("services", &advertised);
        summary.finish_non_exhaustive()
    }
}
impl Reflector {
    /// Builds a reflector from an encoded `FileDescriptorSet`.
    ///
    /// Each file is later served as the exact bytes it occupied in `bytes`, so
    /// fields the decoder does not model are passed through unchanged.
    ///
    /// # Errors
    ///
    /// Fails for the same reasons as
    /// [`add_descriptor_set_bytes`](Self::add_descriptor_set_bytes), except that
    /// [`ReflectionError::SharedPool`] cannot occur for the freshly created pool.
    pub fn from_descriptor_set_bytes(bytes: &[u8]) -> Result<Self, ReflectionError> {
        let mut reflector = Self {
            pool: Arc::new(DescriptorPool::default()),
            response_bytes: HashMap::new(),
            services_override: None,
        };
        reflector.add_descriptor_set_bytes(bytes)?;
        Ok(reflector)
    }

    /// Builds a reflector over an existing descriptor pool.
    ///
    /// Files are served by re-encoding the pool's `FileDescriptorProto`s, so
    /// the bytes may differ from whatever the pool was originally built from.
    /// When several files share a name, the first in `pool.files()` is served.
    ///
    /// While other clones of `pool` are alive,
    /// [`add_descriptor_set_bytes`](Self::add_descriptor_set_bytes) fails with
    /// [`ReflectionError::SharedPool`].
    ///
    /// # Errors
    ///
    /// [`ReflectionError::UnnamedFile`] if any file in the pool has no name.
    pub fn from_descriptor_pool(pool: Arc<DescriptorPool>) -> Result<Self, ReflectionError> {
        let mut response_bytes = HashMap::with_capacity(pool.files().len());
        for (index, fd) in pool.files().iter().enumerate() {
            let name = fd
                .name
                .clone()
                .ok_or(ReflectionError::UnnamedFile { index })?;
            // First file with a given name wins, as in `add_descriptor_set_bytes`.
            response_bytes
                .entry(name)
                .or_insert_with(|| fd.encode_to_vec());
        }
        Ok(Self {
            pool,
            response_bytes,
            services_override: None,
        })
    }

    /// Adds every file of an encoded `FileDescriptorSet` to this reflector.
    ///
    /// The set is split into per-file byte slices and also decoded. The slices
    /// are what clients receive; the decoded protos go into the pool. If a file
    /// name is already known, the bytes stored for it first are kept.
    ///
    /// # Errors
    ///
    /// - [`ReflectionError::MalformedFraming`] if the top-level framing is invalid.
    /// - [`ReflectionError::Decode`] if the set does not decode.
    /// - [`ReflectionError::CountMismatch`] if framing and decoding disagree on
    ///   the number of files.
    /// - [`ReflectionError::UnnamedFile`] if a file has no name.
    /// - [`ReflectionError::SharedPool`] if other `Arc` handles to the pool exist.
    /// - [`ReflectionError::Pool`] if the pool rejects the descriptors.
    ///
    /// Response bytes are only recorded after the pool accepts the set.
    pub fn add_descriptor_set_bytes(&mut self, bytes: &[u8]) -> Result<(), ReflectionError> {
        let raw_files = split_descriptor_set(bytes)?;
        let set = FileDescriptorSet::decode_from_slice(bytes)?;
        // `raw_files` and `set.file` are zipped by position below, so they must
        // describe the same number of files.
        if raw_files.len() != set.file.len() {
            return Err(ReflectionError::CountMismatch {
                framed: raw_files.len(),
                decoded: set.file.len(),
            });
        }
        let mut names = Vec::with_capacity(set.file.len());
        for (index, fd) in set.file.iter().enumerate() {
            names.push(
                fd.name
                    .clone()
                    .ok_or(ReflectionError::UnnamedFile { index })?,
            );
        }
        let pool = Arc::get_mut(&mut self.pool).ok_or(ReflectionError::SharedPool)?;
        pool.add_file_descriptor_set(set)?;
        for (name, raw) in names.into_iter().zip(raw_files) {
            self.response_bytes
                .entry(name)
                .or_insert_with(|| raw.to_vec());
        }
        Ok(())
    }

    /// Replaces the advertised service list with `names`.
    ///
    /// This only affects [`service_names`](Self::service_names) and the
    /// service-listing answer. Symbol and file lookups still see every
    /// descriptor.
    #[must_use]
    pub fn with_services<I, S>(mut self, names: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.services_override = Some(names.into_iter().map(Into::into).collect());
        self
    }

    /// Returns the service names advertised to clients.
    ///
    /// Without an override set by [`with_services`](Self::with_services), this
    /// is every service in the pool, followed by the services from the crate's
    /// embedded descriptors that are not already listed.
    #[must_use]
    pub fn service_names(&self) -> Vec<String> {
        self.services_override.clone().unwrap_or_else(|| {
            let mut names: Vec<String> = self
                .pool
                .services()
                .iter()
                .map(|svc| svc.full_name().to_owned())
                .collect();
            for own in self_descriptors().pool.services() {
                if !names.iter().any(|name| name == own.full_name()) {
                    names.push(own.full_name().to_owned());
                }
            }
            names
        })
    }

    /// Returns this reflector's descriptor pool.
    ///
    /// The crate's embedded descriptors are held separately and are not part
    /// of this pool.
    #[must_use]
    pub fn pool(&self) -> &DescriptorPool {
        &self.pool
    }

    /// Finds the file called `name` and returns it with its dependencies.
    pub(crate) fn file_by_filename(&self, name: &str) -> Answer {
        for source in self.sources() {
            if let Some(fd) = source.pool.file_by_name(name) {
                return Answer::Files(source.closure(fd));
            }
        }
        Answer::NotFound(format!("file {name:?} not found"))
    }

    /// Finds the file declaring `symbol` and returns it with its dependencies.
    pub(crate) fn file_containing_symbol(&self, symbol: &str) -> Answer {
        for source in self.sources() {
            if let Some(fd) = source.pool.file_containing_symbol(symbol) {
                return Answer::Files(source.closure(fd));
            }
        }
        Answer::NotFound(format!("symbol {symbol:?} not found"))
    }

    /// Finds the file declaring extension `number` of message `containing_type`
    /// and returns it with its dependencies.
    ///
    /// `containing_type` may carry a leading `.`, as with
    /// `all_extension_numbers_of_type`. The first source that defines the
    /// message is authoritative: if it has no such extension the answer is
    /// `NotFound`. Negative numbers are never field numbers and yield `NotFound`.
    pub(crate) fn file_containing_extension(&self, containing_type: &str, number: i32) -> Answer {
        let not_found = || {
            Answer::NotFound(format!(
                "extension {number} of type {containing_type:?} not found"
            ))
        };
        let Ok(number) = u32::try_from(number) else {
            return not_found();
        };
        // Normalize `.pkg.Msg` to `pkg.Msg` before `message_index`, exactly as
        // `all_extension_numbers_of_type` does, so both extension queries accept
        // the same spellings.
        let normalized = containing_type
            .strip_prefix('.')
            .unwrap_or(containing_type);
        for source in self.sources() {
            let Some(extendee) = source.pool.message_index(normalized) else {
                continue;
            };
            let Some(extension) = source.pool.extension_for(extendee, number) else {
                return not_found();
            };
            return match source.pool.file_containing_symbol(extension.full_name()) {
                Some(fd) => Answer::Files(source.closure(fd)),
                None => not_found(),
            };
        }
        not_found()
    }

    /// Lists the extension numbers declared for message `name`.
    ///
    /// A leading `.` on `name` is ignored. The first source that defines the
    /// message supplies the answer. Numbers that do not fit in an `i32` are
    /// omitted.
    pub(crate) fn all_extension_numbers_of_type(&self, name: &str) -> Answer {
        let normalized = name.strip_prefix('.').unwrap_or(name);
        for source in self.sources() {
            let Some(extendee) = source.pool.message_index(normalized) else {
                continue;
            };
            let numbers = source
                .pool
                .extensions_of(extendee)
                .filter_map(|ext| i32::try_from(ext.field().number()).ok())
                .collect();
            return Answer::ExtensionNumbers {
                base_type: normalized.to_owned(),
                numbers,
            };
        }
        Answer::NotFound(format!("message {normalized:?} not found"))
    }

    /// Returns the advertised service names (see `service_names`).
    pub(crate) fn list_services(&self) -> Answer {
        Answer::Services(self.service_names())
    }

    /// Lookup order: this reflector's pool first, then the crate's embedded
    /// descriptors, each paired with its response bytes.
    fn sources(&self) -> [DescriptorSource<'_>; 2] {
        let own = self_descriptors();
        [
            DescriptorSource {
                pool: &self.pool,
                response_bytes: &self.response_bytes,
            },
            DescriptorSource {
                pool: &own.pool,
                response_bytes: &own.response_bytes,
            },
        ]
    }
}
/// A descriptor pool paired with the serialized file bytes served for it.
struct DescriptorSource<'src> {
    /// Pool used to resolve file names, symbols and dependencies.
    pool: &'src DescriptorPool,
    /// Response payloads keyed by file name.
    response_bytes: &'src HashMap<String, Vec<u8>>,
}
impl DescriptorSource<'_> {
    /// Serializes `fd` together with every file it transitively depends on.
    ///
    /// Files are taken from a LIFO work list seeded with `fd`, so `fd` itself
    /// comes first, and each file name is emitted at most once. Dependencies
    /// this source's pool cannot resolve, files with no name, and files with no
    /// stored response bytes are skipped.
    fn closure(&self, fd: &FileDescriptorProto) -> Vec<Vec<u8>> {
        let mut pending = vec![fd];
        let mut visited: HashSet<&str> = HashSet::new();
        let mut payloads = Vec::new();
        while let Some(file) = pending.pop() {
            let Some(file_name) = file.name.as_deref() else {
                continue;
            };
            if visited.insert(file_name) {
                if let Some(encoded) = self.response_bytes.get(file_name) {
                    payloads.push(encoded.clone());
                }
                for dep_name in &file.dependency {
                    if let Some(dep) = self.pool.file_by_name(dep_name) {
                        pending.push(dep);
                    }
                }
            }
        }
        payloads
    }
}
/// The crate's own embedded descriptors, built once by `self_descriptors`.
struct SelfDescriptors {
    /// Pool linked from `crate::FILE_DESCRIPTOR_SET`.
    pool: DescriptorPool,
    /// Raw bytes of each embedded file, keyed by file name.
    response_bytes: HashMap<String, Vec<u8>>,
}
/// Returns the descriptors embedded in this crate (`crate::FILE_DESCRIPTOR_SET`).
///
/// They are framed, decoded and linked on first use and cached in a `OnceLock`.
///
/// # Panics
///
/// Panics if the embedded set cannot be framed, decoded or linked, or if framing
/// and decoding disagree on the number of files. Each of these indicates broken
/// embedded data rather than a runtime input problem.
fn self_descriptors() -> &'static SelfDescriptors {
    static SELF: std::sync::OnceLock<SelfDescriptors> = std::sync::OnceLock::new();
    SELF.get_or_init(|| {
        let bytes = crate::FILE_DESCRIPTOR_SET;
        let raw_files = split_descriptor_set(bytes).expect("embedded descriptor set is framed");
        let set =
            FileDescriptorSet::decode_from_slice(bytes).expect("embedded descriptor set decodes");
        // `zip` pairs slices with files by position. On a count mismatch it would
        // quietly attach the wrong bytes to a name, so check it the same way
        // `add_descriptor_set_bytes` does.
        assert_eq!(
            raw_files.len(),
            set.file.len(),
            "embedded descriptor set framing and decoding disagree on file count"
        );
        let response_bytes = set
            .file
            .iter()
            .zip(&raw_files)
            .filter_map(|(fd, raw)| Some((fd.name.clone()?, raw.to_vec())))
            .collect();
        let pool = DescriptorPool::new(set).expect("embedded descriptor set links");
        SelfDescriptors {
            pool,
            response_bytes,
        }
    })
}
/// Splits an encoded `FileDescriptorSet` into the raw byte slices of its `file`
/// entries (field 1) without decoding them.
///
/// Slicing the input lets each `FileDescriptorProto` be served exactly as it was
/// supplied. Other top-level fields are skipped according to their wire type.
///
/// # Errors
///
/// [`ReflectionError::MalformedFraming`], carrying the offset of the offending
/// tag, for any of:
/// - a truncated or overflowing varint;
/// - field number 0;
/// - an unsupported wire type (including the group types 3 and 4);
/// - a length or fixed-width value that runs past the end of the input.
fn split_descriptor_set(bytes: &[u8]) -> Result<Vec<&[u8]>, ReflectionError> {
    let mut files = Vec::new();
    let mut pos = 0;
    while pos < bytes.len() {
        let tag_offset = pos;
        let malformed = || ReflectionError::MalformedFraming { offset: tag_offset };
        let tag = read_varint(bytes, &mut pos).ok_or_else(malformed)?;
        let (field, wire_type) = (tag >> 3, tag & 0x7);
        // Field number 0 is never valid in protobuf.
        if field == 0 {
            return Err(malformed());
        }
        match wire_type {
            0 => {
                read_varint(bytes, &mut pos).ok_or_else(malformed)?;
            }
            1 => pos += 8,
            2 => {
                let len = read_varint(bytes, &mut pos).ok_or_else(malformed)?;
                // `try_from` rather than `as`: on 32-bit targets a length above
                // `usize::MAX` must be rejected, not truncated into a bogus slice.
                let end = usize::try_from(len)
                    .ok()
                    .and_then(|len| pos.checked_add(len))
                    .filter(|&end| end <= bytes.len())
                    .ok_or_else(malformed)?;
                if field == 1 {
                    files.push(&bytes[pos..end]);
                }
                pos = end;
            }
            5 => pos += 4,
            _ => return Err(malformed()),
        }
        // Fixed-width skips (wire types 1 and 5) are bounds-checked here.
        if pos > bytes.len() {
            return Err(malformed());
        }
    }
    Ok(files)
}
/// Reads a protobuf base-128 varint from `bytes` starting at `*pos`.
///
/// On success, returns the value and leaves `*pos` just past the varint.
///
/// Returns `None` in any of these cases:
/// - the input ends mid-varint;
/// - the varint is longer than ten bytes;
/// - the tenth byte carries bits beyond bit 63, so the value would overflow `u64`.
///
/// On `None`, `*pos` may already have been advanced.
fn read_varint(bytes: &[u8], pos: &mut usize) -> Option<u64> {
    let mut value = 0u64;
    for shift in (0..64).step_by(7) {
        let byte = *bytes.get(*pos)?;
        *pos += 1;
        // The tenth byte (shift 63) can only contribute bit 63. Anything larger,
        // including a continuation bit, overflows u64 and is rejected, as the
        // Go and prost protobuf decoders do, instead of being silently truncated.
        if shift == 63 && byte > 1 {
            return None;
        }
        value |= u64::from(byte & 0x7f) << shift;
        if byte & 0x80 == 0 {
            return Some(value);
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use buffa_descriptor::generated::descriptor::field_descriptor_proto::{Label, Type};
    use buffa_descriptor::generated::descriptor::{
        DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
        MethodDescriptorProto, OneofDescriptorProto, ServiceDescriptorProto,
    };

    use super::*;

    // Services contributed by the crate's embedded descriptors.
    const SELF_V1: &str = "grpc.reflection.v1.ServerReflection";
    const SELF_V1ALPHA: &str = "grpc.reflection.v1alpha.ServerReflection";

    /// Two files: `acme/base.proto` declares an extendable message, and
    /// `acme/api.proto` depends on it and declares one of every symbol kind plus
    /// an extension (`tag`, number 150) of `acme.base.Shared`.
    fn test_set() -> FileDescriptorSet {
        let base = FileDescriptorProto {
            name: Some("acme/base.proto".into()),
            package: Some("acme.base".into()),
            message_type: vec![DescriptorProto {
                name: Some("Shared".into()),
                extension_range: vec![
                    buffa_descriptor::generated::descriptor::descriptor_proto::ExtensionRange {
                        start: Some(100),
                        end: Some(200),
                        ..Default::default()
                    },
                ],
                ..Default::default()
            }],
            ..Default::default()
        };
        let api = FileDescriptorProto {
            name: Some("acme/api.proto".into()),
            package: Some("acme.api".into()),
            dependency: vec!["acme/base.proto".into()],
            message_type: vec![DescriptorProto {
                name: Some("Request".into()),
                field: vec![FieldDescriptorProto {
                    name: Some("query".into()),
                    number: Some(1),
                    label: Some(Label::LABEL_OPTIONAL),
                    r#type: Some(Type::TYPE_STRING),
                    ..Default::default()
                }],
                oneof_decl: vec![OneofDescriptorProto {
                    name: Some("variant".into()),
                    ..Default::default()
                }],
                nested_type: vec![DescriptorProto {
                    name: Some("Inner".into()),
                    ..Default::default()
                }],
                enum_type: vec![EnumDescriptorProto {
                    name: Some("Kind".into()),
                    value: vec![EnumValueDescriptorProto {
                        name: Some("KIND_UNSPECIFIED".into()),
                        number: Some(0),
                        ..Default::default()
                    }],
                    ..Default::default()
                }],
                ..Default::default()
            }],
            enum_type: vec![EnumDescriptorProto {
                name: Some("Code".into()),
                value: vec![EnumValueDescriptorProto {
                    name: Some("CODE_OK".into()),
                    number: Some(0),
                    ..Default::default()
                }],
                ..Default::default()
            }],
            service: vec![ServiceDescriptorProto {
                name: Some("Search".into()),
                method: vec![MethodDescriptorProto {
                    name: Some("Query".into()),
                    input_type: Some(".acme.api.Request".into()),
                    output_type: Some(".acme.api.Request".into()),
                    ..Default::default()
                }],
                ..Default::default()
            }],
            extension: vec![FieldDescriptorProto {
                name: Some("tag".into()),
                number: Some(150),
                label: Some(Label::LABEL_OPTIONAL),
                r#type: Some(Type::TYPE_INT32),
                extendee: Some(".acme.base.Shared".into()),
                ..Default::default()
            }],
            ..Default::default()
        };
        FileDescriptorSet {
            file: vec![base, api],
            ..Default::default()
        }
    }

    fn test_reflector() -> Reflector {
        Reflector::from_descriptor_set_bytes(&test_set().encode_to_vec()).unwrap()
    }

    /// Unwraps an `Answer::Files`, panicking on any other variant.
    fn files(answer: Answer) -> Vec<Vec<u8>> {
        match answer {
            Answer::Files(files) => files,
            _ => panic!("expected Answer::Files"),
        }
    }

    fn assert_not_found(answer: &Answer) {
        assert!(matches!(answer, Answer::NotFound(_)));
    }

    #[test]
    fn file_by_filename_returns_raw_bytes_and_closure() {
        let set = test_set();
        let reflector = test_reflector();
        // The requested file comes first, followed by its dependency.
        let got = files(reflector.file_by_filename("acme/api.proto"));
        assert_eq!(got.len(), 2);
        assert_eq!(got[0], set.file[1].encode_to_vec());
        assert_eq!(got[1], set.file[0].encode_to_vec());
        let got = files(reflector.file_by_filename("acme/base.proto"));
        assert_eq!(got.len(), 1);
        assert_not_found(&reflector.file_by_filename("nope.proto"));
    }

    #[test]
    fn raw_bytes_survive_unknown_fields() {
        let mut file = test_set().file[0].encode_to_vec();
        // Field 12345, varint wire type, value 1: unknown to FileDescriptorProto.
        let unknown = [0xc8, 0x83, 0x06, 0x01];
        file.extend_from_slice(&unknown);
        // The set is framed by hand with a one-byte length varint, which can only
        // encode lengths below 0x80. Fail loudly if the fixture outgrows that.
        let len = u8::try_from(file.len())
            .ok()
            .filter(|&len| len < 0x80)
            .expect("test file must fit a single-byte length varint");
        let mut set_bytes = vec![0x0a, len];
        set_bytes.extend_from_slice(&file);
        let reflector = Reflector::from_descriptor_set_bytes(&set_bytes).unwrap();
        let got = files(reflector.file_by_filename("acme/base.proto"));
        assert_eq!(got, vec![file]);
    }

    #[test]
    fn symbol_lookup_covers_every_kind() {
        let reflector = test_reflector();
        for symbol in [
            "acme.api.Request",
            "acme.api.Request.query",
            "acme.api.Request.variant",
            "acme.api.Request.Inner",
            "acme.api.Request.Kind",
            "acme.api.Request.KIND_UNSPECIFIED",
            "acme.api.Code",
            "acme.api.CODE_OK",
            "acme.api.Search",
            "acme.api.Search.Query",
            "acme.api.tag",
            // A leading dot is accepted.
            ".acme.api.Request",
        ] {
            let got = files(reflector.file_containing_symbol(symbol));
            assert_eq!(got.len(), 2, "symbol {symbol}");
        }
        // Enum values are addressed as siblings of their enum, not children of it.
        assert_not_found(&reflector.file_containing_symbol("acme.api.Code.CODE_OK"));
        // A package on its own is not a symbol.
        assert_not_found(&reflector.file_containing_symbol("acme.api"));
        assert_not_found(&reflector.file_containing_symbol("acme.api.Missing"));
    }

    #[test]
    fn extension_queries() {
        let reflector = test_reflector();
        let got = files(reflector.file_containing_extension("acme.base.Shared", 150));
        assert_eq!(got.len(), 2);
        assert_not_found(&reflector.file_containing_extension("acme.base.Shared", 151));
        assert_not_found(&reflector.file_containing_extension("acme.base.Shared", -1));
        assert_not_found(&reflector.file_containing_extension("acme.api.Request", 150));
        match reflector.all_extension_numbers_of_type("acme.base.Shared") {
            Answer::ExtensionNumbers { base_type, numbers } => {
                assert_eq!(base_type, "acme.base.Shared");
                assert_eq!(numbers, vec![150]);
            }
            _ => panic!("expected extension numbers"),
        }
        match reflector.all_extension_numbers_of_type("acme.api.Request") {
            Answer::ExtensionNumbers { numbers, .. } => assert!(numbers.is_empty()),
            _ => panic!("expected extension numbers"),
        }
        assert_not_found(&reflector.all_extension_numbers_of_type("acme.Missing"));
        // Services are not messages, so they have no extension numbers.
        assert_not_found(&reflector.all_extension_numbers_of_type("acme.api.Search"));
    }

    #[test]
    fn list_services() {
        match test_reflector().list_services() {
            Answer::Services(names) => {
                assert_eq!(names, vec!["acme.api.Search", SELF_V1, SELF_V1ALPHA]);
            }
            _ => panic!("expected services"),
        }
    }

    #[test]
    fn with_services_overrides_advertised_list_only() {
        let reflector = test_reflector().with_services(["acme.api.Curated"]);
        assert_eq!(reflector.service_names(), ["acme.api.Curated"]);
        match reflector.list_services() {
            Answer::Services(names) => assert_eq!(names, vec!["acme.api.Curated"]),
            _ => panic!("expected services"),
        }
        // Symbol lookup still sees services that are no longer advertised.
        let got = files(reflector.file_containing_symbol("acme.api.Search"));
        assert_eq!(got.len(), 2);
    }

    #[test]
    fn merging_sets_skips_duplicate_files() {
        let mut reflector = test_reflector();
        let second = FileDescriptorSet {
            file: vec![
                FileDescriptorProto {
                    name: Some("acme/base.proto".into()),
                    package: Some("acme.other".into()),
                    ..Default::default()
                },
                FileDescriptorProto {
                    name: Some("acme/extra.proto".into()),
                    package: Some("acme.extra".into()),
                    service: vec![ServiceDescriptorProto {
                        name: Some("Extra".into()),
                        ..Default::default()
                    }],
                    ..Default::default()
                },
            ],
            ..Default::default()
        };
        reflector
            .add_descriptor_set_bytes(&second.encode_to_vec())
            .unwrap();
        // The original `acme/base.proto` is kept, so its symbols still resolve.
        assert!(matches!(
            reflector.file_containing_symbol("acme.base.Shared"),
            Answer::Files(_)
        ));
        match reflector.list_services() {
            Answer::Services(names) => {
                assert_eq!(
                    names,
                    vec!["acme.api.Search", "acme.extra.Extra", SELF_V1, SELF_V1ALPHA]
                );
            }
            _ => panic!("expected services"),
        }
    }

    #[test]
    fn from_descriptor_pool_serves_reencoded_files() {
        let set = test_set();
        let pool = Arc::new(DescriptorPool::new(set.clone()).unwrap());
        let reflector = Reflector::from_descriptor_pool(Arc::clone(&pool)).unwrap();
        let got = files(reflector.file_containing_symbol("acme.api.Search"));
        assert_eq!(got.len(), 2);
        let decoded = FileDescriptorProto::decode_from_slice(&got[0]).unwrap();
        assert_eq!(decoded, set.file[1]);
        match reflector.list_services() {
            Answer::Services(names) => {
                assert_eq!(names, vec!["acme.api.Search", SELF_V1, SELF_V1ALPHA]);
            }
            _ => panic!("expected services"),
        }
        // `pool` is still alive here, so the reflector cannot mutate it.
        let mut reflector = reflector;
        let err = reflector
            .add_descriptor_set_bytes(&FileDescriptorSet::default().encode_to_vec())
            .unwrap_err();
        assert!(matches!(err, ReflectionError::SharedPool));
    }

    #[test]
    fn construction_errors() {
        // Length prefix with a dangling continuation bit.
        let err = Reflector::from_descriptor_set_bytes(&[0x0a, 0xff]).unwrap_err();
        assert!(matches!(err, ReflectionError::MalformedFraming { .. }));
        let set = FileDescriptorSet {
            file: vec![FileDescriptorProto::default()],
            ..Default::default()
        };
        let err = Reflector::from_descriptor_set_bytes(&set.encode_to_vec()).unwrap_err();
        assert!(matches!(err, ReflectionError::UnnamedFile { index: 0 }));
        // An empty set is valid and still advertises the embedded services.
        let reflector = Reflector::from_descriptor_set_bytes(&[]).unwrap();
        assert_not_found(&reflector.file_by_filename("x.proto"));
        match reflector.list_services() {
            Answer::Services(names) => assert_eq!(names, vec![SELF_V1, SELF_V1ALPHA]),
            _ => panic!("expected services"),
        }
    }
}