1use std::collections::{HashMap, HashSet};
25use std::sync::Arc;
26
27use buffa::Message;
28use buffa_descriptor::generated::descriptor::{FileDescriptorProto, FileDescriptorSet};
29use buffa_descriptor::{DescriptorPool, PoolError};
30
31#[derive(Debug, thiserror::Error)]
33#[non_exhaustive]
34pub enum ReflectionError {
35 #[error("failed to decode FileDescriptorSet: {0}")]
37 Decode(#[from] buffa::DecodeError),
38 #[error("invalid descriptor set: {0}")]
41 Pool(#[from] PoolError),
42 #[error("malformed FileDescriptorSet framing at byte {offset}")]
46 MalformedFraming {
47 offset: usize,
49 },
50 #[error("FileDescriptorProto at index {index} has no name")]
53 UnnamedFile {
54 index: usize,
56 },
57 #[error("FileDescriptorSet framing yields {framed} files but decoding yields {decoded}")]
61 CountMismatch {
62 framed: usize,
64 decoded: usize,
66 },
67 #[error("cannot add to a descriptor pool with outstanding references")]
73 SharedPool,
74}
75
/// Outcome of a single reflection lookup performed by [`Reflector`].
pub(crate) enum Answer {
    /// Encoded `FileDescriptorProto`s: the matched file (when it has stored
    /// bytes) followed by its transitive dependencies, each at most once.
    Files(Vec<Vec<u8>>),
    /// Extension field numbers registered for a message type.
    ExtensionNumbers {
        /// Fully-qualified message name, without a leading dot.
        base_type: String,
        /// Field numbers of the known extensions of `base_type`.
        numbers: Vec<i32>,
    },
    /// Fully-qualified names of the advertised services.
    Services(Vec<String>),
    /// Human-readable description of what could not be found.
    NotFound(String),
}
94
95pub struct Reflector {
119 pool: Arc<DescriptorPool>,
120 response_bytes: HashMap<String, Vec<u8>>,
124 services_override: Option<Vec<String>>,
127}
128
129impl std::fmt::Debug for Reflector {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 f.debug_struct("Reflector")
132 .field("files", &self.pool.files().len())
133 .field("services", &self.service_names())
134 .finish_non_exhaustive()
135 }
136}
137
138impl Reflector {
139 pub fn from_descriptor_set_bytes(bytes: &[u8]) -> Result<Self, ReflectionError> {
152 let mut reflector = Self {
153 pool: Arc::new(DescriptorPool::default()),
154 response_bytes: HashMap::new(),
155 services_override: None,
156 };
157 reflector.add_descriptor_set_bytes(bytes)?;
158 Ok(reflector)
159 }
160
161 pub fn from_descriptor_pool(pool: Arc<DescriptorPool>) -> Result<Self, ReflectionError> {
190 let mut response_bytes = HashMap::with_capacity(pool.files().len());
191 for (index, fd) in pool.files().iter().enumerate() {
192 let name = fd
193 .name
194 .clone()
195 .ok_or(ReflectionError::UnnamedFile { index })?;
196 response_bytes
197 .entry(name)
198 .or_insert_with(|| fd.encode_to_vec());
199 }
200 Ok(Self {
201 pool,
202 response_bytes,
203 services_override: None,
204 })
205 }
206
207 pub fn add_descriptor_set_bytes(&mut self, bytes: &[u8]) -> Result<(), ReflectionError> {
222 let raw_files = split_descriptor_set(bytes)?;
223 let set = FileDescriptorSet::decode_from_slice(bytes)?;
224 if raw_files.len() != set.file.len() {
228 return Err(ReflectionError::CountMismatch {
229 framed: raw_files.len(),
230 decoded: set.file.len(),
231 });
232 }
233 let mut names = Vec::with_capacity(set.file.len());
234 for (index, fd) in set.file.iter().enumerate() {
235 names.push(
236 fd.name
237 .clone()
238 .ok_or(ReflectionError::UnnamedFile { index })?,
239 );
240 }
241
242 let pool = Arc::get_mut(&mut self.pool).ok_or(ReflectionError::SharedPool)?;
243 pool.add_file_descriptor_set(set)?;
244
245 for (name, raw) in names.into_iter().zip(raw_files) {
246 self.response_bytes
247 .entry(name)
248 .or_insert_with(|| raw.to_vec());
249 }
250 Ok(())
251 }
252
253 #[must_use]
266 pub fn with_services<I, S>(mut self, names: I) -> Self
267 where
268 I: IntoIterator<Item = S>,
269 S: Into<String>,
270 {
271 self.services_override = Some(names.into_iter().map(Into::into).collect());
272 self
273 }
274
275 #[must_use]
281 pub fn service_names(&self) -> Vec<String> {
282 self.services_override.clone().unwrap_or_else(|| {
283 let mut names: Vec<String> = self
284 .pool
285 .services()
286 .iter()
287 .map(|svc| svc.full_name().to_owned())
288 .collect();
289 for own in self_descriptors().pool.services() {
290 if !names.iter().any(|name| name == own.full_name()) {
291 names.push(own.full_name().to_owned());
292 }
293 }
294 names
295 })
296 }
297
298 #[must_use]
301 pub fn pool(&self) -> &DescriptorPool {
302 &self.pool
303 }
304
305 pub(crate) fn file_by_filename(&self, name: &str) -> Answer {
315 for source in self.sources() {
316 if let Some(fd) = source.pool.file_by_name(name) {
317 return Answer::Files(source.closure(fd));
318 }
319 }
320 Answer::NotFound(format!("file {name:?} not found"))
321 }
322
323 pub(crate) fn file_containing_symbol(&self, symbol: &str) -> Answer {
324 for source in self.sources() {
325 if let Some(fd) = source.pool.file_containing_symbol(symbol) {
326 return Answer::Files(source.closure(fd));
327 }
328 }
329 Answer::NotFound(format!("symbol {symbol:?} not found"))
330 }
331
332 pub(crate) fn file_containing_extension(&self, containing_type: &str, number: i32) -> Answer {
333 let not_found = || {
334 Answer::NotFound(format!(
335 "extension {number} of type {containing_type:?} not found"
336 ))
337 };
338 let Ok(number) = u32::try_from(number) else {
339 return not_found();
340 };
341 for source in self.sources() {
342 let Some(extendee) = source.pool.message_index(containing_type) else {
343 continue;
344 };
345 let Some(extension) = source.pool.extension_for(extendee, number) else {
346 return not_found();
347 };
348 return match source.pool.file_containing_symbol(extension.full_name()) {
349 Some(fd) => Answer::Files(source.closure(fd)),
350 None => not_found(),
351 };
352 }
353 not_found()
354 }
355
356 pub(crate) fn all_extension_numbers_of_type(&self, name: &str) -> Answer {
357 let normalized = name.strip_prefix('.').unwrap_or(name);
358 for source in self.sources() {
359 let Some(extendee) = source.pool.message_index(normalized) else {
360 continue;
361 };
362 let numbers = source
365 .pool
366 .extensions_of(extendee)
367 .filter_map(|ext| i32::try_from(ext.field().number()).ok())
368 .collect();
369 return Answer::ExtensionNumbers {
370 base_type: normalized.to_owned(),
371 numbers,
372 };
373 }
374 Answer::NotFound(format!("message {normalized:?} not found"))
375 }
376
377 pub(crate) fn list_services(&self) -> Answer {
378 Answer::Services(self.service_names())
379 }
380
381 fn sources(&self) -> [DescriptorSource<'_>; 2] {
384 let own = self_descriptors();
385 [
386 DescriptorSource {
387 pool: &self.pool,
388 response_bytes: &self.response_bytes,
389 },
390 DescriptorSource {
391 pool: &own.pool,
392 response_bytes: &own.response_bytes,
393 },
394 ]
395 }
396}
397
398struct DescriptorSource<'a> {
401 pool: &'a DescriptorPool,
402 response_bytes: &'a HashMap<String, Vec<u8>>,
403}
404
405impl DescriptorSource<'_> {
406 fn closure(&self, fd: &FileDescriptorProto) -> Vec<Vec<u8>> {
411 let mut seen = HashSet::new();
412 let mut out = Vec::new();
413 let mut stack = vec![fd];
414 while let Some(fd) = stack.pop() {
415 let Some(name) = fd.name.as_deref() else {
416 continue;
417 };
418 if !seen.insert(name) {
419 continue;
420 }
421 if let Some(bytes) = self.response_bytes.get(name) {
422 out.push(bytes.clone());
423 }
424 stack.extend(
425 fd.dependency
426 .iter()
427 .filter_map(|dep| self.pool.file_by_name(dep)),
428 );
429 }
430 out
431 }
432}
433
434struct SelfDescriptors {
438 pool: DescriptorPool,
439 response_bytes: HashMap<String, Vec<u8>>,
440}
441
442fn self_descriptors() -> &'static SelfDescriptors {
443 static SELF: std::sync::OnceLock<SelfDescriptors> = std::sync::OnceLock::new();
444 SELF.get_or_init(|| {
445 let bytes = crate::FILE_DESCRIPTOR_SET;
448 let raw_files = split_descriptor_set(bytes).expect("embedded descriptor set is framed");
449 let set =
450 FileDescriptorSet::decode_from_slice(bytes).expect("embedded descriptor set decodes");
451 let response_bytes = set
452 .file
453 .iter()
454 .zip(&raw_files)
455 .filter_map(|(fd, raw)| Some((fd.name.clone()?, raw.to_vec())))
456 .collect();
457 let pool = DescriptorPool::new(set).expect("embedded descriptor set links");
458 SelfDescriptors {
459 pool,
460 response_bytes,
461 }
462 })
463}
464
465fn split_descriptor_set(bytes: &[u8]) -> Result<Vec<&[u8]>, ReflectionError> {
468 let mut files = Vec::new();
469 let mut pos = 0;
470 while pos < bytes.len() {
471 let tag_offset = pos;
472 let tag = read_varint(bytes, &mut pos)
473 .ok_or(ReflectionError::MalformedFraming { offset: tag_offset })?;
474 let (field, wire_type) = (tag >> 3, tag & 0x7);
475 match wire_type {
476 0 => {
477 read_varint(bytes, &mut pos)
478 .ok_or(ReflectionError::MalformedFraming { offset: tag_offset })?;
479 }
480 1 => pos += 8,
481 2 => {
482 let len = read_varint(bytes, &mut pos)
483 .ok_or(ReflectionError::MalformedFraming { offset: tag_offset })?
484 as usize;
485 let end = pos
486 .checked_add(len)
487 .filter(|&end| end <= bytes.len())
488 .ok_or(ReflectionError::MalformedFraming { offset: tag_offset })?;
489 if field == 1 {
490 files.push(&bytes[pos..end]);
491 }
492 pos = end;
493 }
494 5 => pos += 4,
495 _ => return Err(ReflectionError::MalformedFraming { offset: tag_offset }),
496 }
497 if pos > bytes.len() {
498 return Err(ReflectionError::MalformedFraming { offset: tag_offset });
499 }
500 }
501 Ok(files)
502}
503
/// Decodes one protobuf base-128 varint from `bytes` starting at `*pos`.
///
/// `*pos` is advanced past every byte consumed. Returns `None` if the input
/// ends mid-varint or the varint runs longer than ten bytes; in both cases
/// `*pos` has still been advanced over the bytes read. Bits beyond the 64th
/// are discarded.
fn read_varint(bytes: &[u8], pos: &mut usize) -> Option<u64> {
    let mut acc = 0u64;
    let mut shift = 0u32;
    while shift < 64 {
        let b = *bytes.get(*pos)?;
        *pos += 1;
        acc |= u64::from(b & 0x7f) << shift;
        let has_more = b & 0x80 != 0;
        if !has_more {
            return Some(acc);
        }
        shift += 7;
    }
    None
}
519
#[cfg(test)]
mod tests {
    use buffa_descriptor::generated::descriptor::field_descriptor_proto::{Label, Type};
    use buffa_descriptor::generated::descriptor::{
        DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
        MethodDescriptorProto, OneofDescriptorProto, ServiceDescriptorProto,
    };

    use super::*;

    const SELF_V1: &str = "grpc.reflection.v1.ServerReflection";
    const SELF_V1ALPHA: &str = "grpc.reflection.v1alpha.ServerReflection";

    /// Appends `value` to `buf` as a protobuf base-128 varint.
    fn push_varint(buf: &mut Vec<u8>, mut value: usize) {
        while value >= 0x80 {
            buf.push(u8::try_from(value & 0x7f).unwrap() | 0x80);
            value >>= 7;
        }
        buf.push(u8::try_from(value).unwrap());
    }

    /// Two files: `acme/base.proto` declares an extendable message, and
    /// `acme/api.proto` depends on it and declares one of every symbol kind
    /// plus an extension of `acme.base.Shared`.
    fn test_set() -> FileDescriptorSet {
        let base = FileDescriptorProto {
            name: Some("acme/base.proto".into()),
            package: Some("acme.base".into()),
            message_type: vec![DescriptorProto {
                name: Some("Shared".into()),
                extension_range: vec![
                    buffa_descriptor::generated::descriptor::descriptor_proto::ExtensionRange {
                        start: Some(100),
                        end: Some(200),
                        ..Default::default()
                    },
                ],
                ..Default::default()
            }],
            ..Default::default()
        };
        let api = FileDescriptorProto {
            name: Some("acme/api.proto".into()),
            package: Some("acme.api".into()),
            dependency: vec!["acme/base.proto".into()],
            message_type: vec![DescriptorProto {
                name: Some("Request".into()),
                field: vec![FieldDescriptorProto {
                    name: Some("query".into()),
                    number: Some(1),
                    label: Some(Label::LABEL_OPTIONAL),
                    r#type: Some(Type::TYPE_STRING),
                    ..Default::default()
                }],
                oneof_decl: vec![OneofDescriptorProto {
                    name: Some("variant".into()),
                    ..Default::default()
                }],
                nested_type: vec![DescriptorProto {
                    name: Some("Inner".into()),
                    ..Default::default()
                }],
                enum_type: vec![EnumDescriptorProto {
                    name: Some("Kind".into()),
                    value: vec![EnumValueDescriptorProto {
                        name: Some("KIND_UNSPECIFIED".into()),
                        number: Some(0),
                        ..Default::default()
                    }],
                    ..Default::default()
                }],
                ..Default::default()
            }],
            enum_type: vec![EnumDescriptorProto {
                name: Some("Code".into()),
                value: vec![EnumValueDescriptorProto {
                    name: Some("CODE_OK".into()),
                    number: Some(0),
                    ..Default::default()
                }],
                ..Default::default()
            }],
            service: vec![ServiceDescriptorProto {
                name: Some("Search".into()),
                method: vec![MethodDescriptorProto {
                    name: Some("Query".into()),
                    input_type: Some(".acme.api.Request".into()),
                    output_type: Some(".acme.api.Request".into()),
                    ..Default::default()
                }],
                ..Default::default()
            }],
            extension: vec![FieldDescriptorProto {
                name: Some("tag".into()),
                number: Some(150),
                label: Some(Label::LABEL_OPTIONAL),
                r#type: Some(Type::TYPE_INT32),
                extendee: Some(".acme.base.Shared".into()),
                ..Default::default()
            }],
            ..Default::default()
        };
        FileDescriptorSet {
            file: vec![base, api],
            ..Default::default()
        }
    }

    fn test_reflector() -> Reflector {
        Reflector::from_descriptor_set_bytes(&test_set().encode_to_vec()).unwrap()
    }

    fn files(answer: Answer) -> Vec<Vec<u8>> {
        match answer {
            Answer::Files(files) => files,
            _ => panic!("expected Answer::Files"),
        }
    }

    fn assert_not_found(answer: &Answer) {
        assert!(matches!(answer, Answer::NotFound(_)));
    }

    #[test]
    fn file_by_filename_returns_raw_bytes_and_closure() {
        let set = test_set();
        let reflector = test_reflector();

        let got = files(reflector.file_by_filename("acme/api.proto"));
        assert_eq!(got.len(), 2);
        // The requested file comes first, followed by its dependency.
        assert_eq!(got[0], set.file[1].encode_to_vec());
        assert_eq!(got[1], set.file[0].encode_to_vec());

        let got = files(reflector.file_by_filename("acme/base.proto"));
        // base.proto has no dependencies.
        assert_eq!(got.len(), 1);

        assert_not_found(&reflector.file_by_filename("nope.proto"));
    }

    #[test]
    fn raw_bytes_survive_unknown_fields() {
        let mut file = test_set().file[0].encode_to_vec();
        // Field 12345, wire type VARINT, value 1: unknown to FileDescriptorProto,
        // so it only survives if the original bytes are served verbatim.
        let unknown = [0xc8, 0x83, 0x06, 0x01];
        file.extend_from_slice(&unknown);
        // Frame the file by hand as field 1 (LEN) of a FileDescriptorSet, with a
        // proper varint length so the framing stays valid for any file size.
        let mut set_bytes = vec![0x0a];
        push_varint(&mut set_bytes, file.len());
        set_bytes.extend_from_slice(&file);

        let reflector = Reflector::from_descriptor_set_bytes(&set_bytes).unwrap();
        let got = files(reflector.file_by_filename("acme/base.proto"));
        assert_eq!(got, vec![file]);
    }

    #[test]
    fn symbol_lookup_covers_every_kind() {
        let reflector = test_reflector();
        for symbol in [
            "acme.api.Request",
            "acme.api.Request.query",
            "acme.api.Request.variant",
            "acme.api.Request.Inner",
            "acme.api.Request.Kind",
            // Enum values are scoped as siblings of their enum, not inside it.
            "acme.api.Request.KIND_UNSPECIFIED",
            "acme.api.Code",
            "acme.api.CODE_OK",
            "acme.api.Search",
            "acme.api.Search.Query",
            "acme.api.tag",
            // A leading dot is accepted.
            ".acme.api.Request",
        ] {
            let got = files(reflector.file_containing_symbol(symbol));
            assert_eq!(got.len(), 2, "symbol {symbol}");
        }
        // Enum values are not nested under the enum's own name.
        assert_not_found(&reflector.file_containing_symbol("acme.api.Code.CODE_OK"));
        // A package alone is not a symbol.
        assert_not_found(&reflector.file_containing_symbol("acme.api"));
        assert_not_found(&reflector.file_containing_symbol("acme.api.Missing"));
    }

    #[test]
    fn extension_queries() {
        let reflector = test_reflector();

        let got = files(reflector.file_containing_extension("acme.base.Shared", 150));
        assert_eq!(got.len(), 2);
        // 151 is inside the extension range but nothing declares it.
        assert_not_found(&reflector.file_containing_extension("acme.base.Shared", 151));
        assert_not_found(&reflector.file_containing_extension("acme.base.Shared", -1));
        assert_not_found(&reflector.file_containing_extension("acme.api.Request", 150));

        match reflector.all_extension_numbers_of_type("acme.base.Shared") {
            Answer::ExtensionNumbers { base_type, numbers } => {
                assert_eq!(base_type, "acme.base.Shared");
                assert_eq!(numbers, vec![150]);
            }
            _ => panic!("expected extension numbers"),
        }
        // A known message with no extensions yields an empty list, not NotFound.
        match reflector.all_extension_numbers_of_type("acme.api.Request") {
            Answer::ExtensionNumbers { numbers, .. } => assert!(numbers.is_empty()),
            _ => panic!("expected extension numbers"),
        }
        assert_not_found(&reflector.all_extension_numbers_of_type("acme.Missing"));
        // Services are not messages.
        assert_not_found(&reflector.all_extension_numbers_of_type("acme.api.Search"));
    }

    #[test]
    fn list_services() {
        match test_reflector().list_services() {
            Answer::Services(names) => {
                assert_eq!(names, vec!["acme.api.Search", SELF_V1, SELF_V1ALPHA]);
            }
            _ => panic!("expected services"),
        }
    }

    #[test]
    fn with_services_overrides_advertised_list_only() {
        let reflector = test_reflector().with_services(["acme.api.Curated"]);
        assert_eq!(reflector.service_names(), ["acme.api.Curated"]);
        match reflector.list_services() {
            Answer::Services(names) => assert_eq!(names, vec!["acme.api.Curated"]),
            _ => panic!("expected services"),
        }
        // Services left out of the override still resolve by symbol.
        let got = files(reflector.file_containing_symbol("acme.api.Search"));
        assert_eq!(got.len(), 2);
    }

    #[test]
    fn merging_sets_skips_duplicate_files() {
        let mut reflector = test_reflector();
        // The first file reuses an existing name with different contents; the
        // second is new.
        let second = FileDescriptorSet {
            file: vec![
                FileDescriptorProto {
                    name: Some("acme/base.proto".into()),
                    package: Some("acme.other".into()),
                    ..Default::default()
                },
                FileDescriptorProto {
                    name: Some("acme/extra.proto".into()),
                    package: Some("acme.extra".into()),
                    service: vec![ServiceDescriptorProto {
                        name: Some("Extra".into()),
                        ..Default::default()
                    }],
                    ..Default::default()
                },
            ],
            ..Default::default()
        };
        reflector
            .add_descriptor_set_bytes(&second.encode_to_vec())
            .unwrap();

        // The original base.proto is still the one in effect.
        assert!(matches!(
            reflector.file_containing_symbol("acme.base.Shared"),
            Answer::Files(_)
        ));
        match reflector.list_services() {
            Answer::Services(names) => {
                assert_eq!(
                    names,
                    vec!["acme.api.Search", "acme.extra.Extra", SELF_V1, SELF_V1ALPHA]
                );
            }
            _ => panic!("expected services"),
        }
    }

    #[test]
    fn from_descriptor_pool_serves_reencoded_files() {
        let set = test_set();
        let pool = Arc::new(DescriptorPool::new(set.clone()).unwrap());
        let reflector = Reflector::from_descriptor_pool(Arc::clone(&pool)).unwrap();

        let got = files(reflector.file_containing_symbol("acme.api.Search"));
        assert_eq!(got.len(), 2);
        let decoded = FileDescriptorProto::decode_from_slice(&got[0]).unwrap();
        assert_eq!(decoded, set.file[1]);

        match reflector.list_services() {
            Answer::Services(names) => {
                assert_eq!(names, vec!["acme.api.Search", SELF_V1, SELF_V1ALPHA]);
            }
            _ => panic!("expected services"),
        }

        // `pool` still holds a second Arc handle, so the pool cannot be mutated.
        let mut reflector = reflector;
        let err = reflector
            .add_descriptor_set_bytes(&FileDescriptorSet::default().encode_to_vec())
            .unwrap_err();
        assert!(matches!(err, ReflectionError::SharedPool));
    }

    #[test]
    fn construction_errors() {
        // Field 1 (LEN) whose length varint is truncated.
        let err = Reflector::from_descriptor_set_bytes(&[0x0a, 0xff]).unwrap_err();
        assert!(matches!(err, ReflectionError::MalformedFraming { .. }));

        let set = FileDescriptorSet {
            file: vec![FileDescriptorProto::default()],
            ..Default::default()
        };
        let err = Reflector::from_descriptor_set_bytes(&set.encode_to_vec()).unwrap_err();
        assert!(matches!(err, ReflectionError::UnnamedFile { index: 0 }));

        // An empty set is valid and still serves the embedded descriptors.
        let reflector = Reflector::from_descriptor_set_bytes(&[]).unwrap();
        assert_not_found(&reflector.file_by_filename("x.proto"));
        match reflector.list_services() {
            Answer::Services(names) => assert_eq!(names, vec![SELF_V1, SELF_V1ALPHA]),
            _ => panic!("expected services"),
        }
    }
}