1use std::collections::HashMap;
18use std::sync::Arc;
19
20use prost_types::field_descriptor_proto::{Label, Type};
21use prost_types::{
22 DescriptorProto, EnumDescriptorProto, FileDescriptorSet, ServiceDescriptorProto,
23};
24
25use super::descriptor::{
26 Cardinality, EnumData, EnumDescriptor, EnumValueData, FieldData, FileData, Kind, MessageData,
27 MessageDescriptor, MethodData, OneofData, ServiceData, ServiceDescriptor,
28};
29use crate::ReflectError;
30
/// Resolved descriptor tables shared (through an `Arc`) by a [`DescriptorPool`]
/// and every descriptor handle it returns.
///
/// Descriptors refer to entries by index into the `Vec` tables below; the
/// `*_by_name` maps translate a fully-qualified name (no leading `.`) into
/// such an index.
#[derive(Debug)]
pub struct PoolInner {
    /// Every file of the source `FileDescriptorSet`, in input order.
    pub(crate) files: Vec<FileData>,
    /// All messages, including nested ones.
    pub(crate) messages: Vec<MessageData>,
    /// All enums, including nested ones.
    pub(crate) enums: Vec<EnumData>,
    /// All services.
    pub(crate) services: Vec<ServiceData>,
    /// Fully-qualified message name -> index into `messages`.
    pub(crate) message_by_name: HashMap<String, usize>,
    /// Fully-qualified enum name -> index into `enums`.
    pub(crate) enum_by_name: HashMap<String, usize>,
    /// Fully-qualified service name -> index into `services`.
    pub(crate) service_by_name: HashMap<String, usize>,
}
48
/// A set of resolved protobuf descriptors built from a `FileDescriptorSet`.
///
/// Cloning is cheap: clones share the same [`PoolInner`] through an `Arc`.
#[derive(Clone, Debug)]
pub struct DescriptorPool {
    inner: Arc<PoolInner>,
}
58
/// Kind-tagged index of a named type, used while resolving field and method
/// type references.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TypeRef {
    /// Index into the builder's message table.
    Message(usize),
    /// Index into the builder's enum table.
    Enum(usize),
}
66
67impl DescriptorPool {
68 pub fn from_file_descriptor_set(fds: FileDescriptorSet) -> Result<Self, ReflectError> {
76 let mut builder = Builder::default();
77 builder.register(&fds)?;
78 builder.resolve(&fds)?;
79 Ok(Self {
80 inner: Arc::new(builder.into_inner()),
81 })
82 }
83
84 pub fn get_message_by_name(&self, full_name: &str) -> Option<MessageDescriptor> {
86 self.inner
87 .message_by_name
88 .get(full_name)
89 .map(|&index| MessageDescriptor {
90 pool: Arc::clone(&self.inner),
91 index,
92 })
93 }
94
95 pub fn get_enum_by_name(&self, full_name: &str) -> Option<EnumDescriptor> {
97 self.inner
98 .enum_by_name
99 .get(full_name)
100 .map(|&index| EnumDescriptor {
101 pool: Arc::clone(&self.inner),
102 index,
103 })
104 }
105
106 pub fn get_service_by_name(&self, full_name: &str) -> Option<ServiceDescriptor> {
108 self.inner
109 .service_by_name
110 .get(full_name)
111 .map(|&index| ServiceDescriptor {
112 pool: Arc::clone(&self.inner),
113 index,
114 })
115 }
116
117 pub fn all_messages(&self) -> impl ExactSizeIterator<Item = MessageDescriptor> + '_ {
120 let inner = Arc::clone(&self.inner);
121 (0..self.inner.messages.len()).map(move |index| MessageDescriptor {
122 pool: Arc::clone(&inner),
123 index,
124 })
125 }
126
127 pub fn all_enums(&self) -> impl ExactSizeIterator<Item = EnumDescriptor> + '_ {
129 let inner = Arc::clone(&self.inner);
130 (0..self.inner.enums.len()).map(move |index| EnumDescriptor {
131 pool: Arc::clone(&inner),
132 index,
133 })
134 }
135
136 pub fn services(&self) -> impl ExactSizeIterator<Item = ServiceDescriptor> + '_ {
138 let inner = Arc::clone(&self.inner);
139 (0..self.inner.services.len()).map(move |index| ServiceDescriptor {
140 pool: Arc::clone(&inner),
141 index,
142 })
143 }
144}
145
/// Mutable state used while constructing a [`PoolInner`].
///
/// Construction runs in two passes: `register` assigns an index to every
/// file, message and enum, then `resolve` fills in fields, oneofs and
/// services using those indices.
#[derive(Default)]
struct Builder {
    files: Vec<FileData>,
    messages: Vec<MessageData>,
    enums: Vec<EnumData>,
    services: Vec<ServiceData>,
    message_by_name: HashMap<String, usize>,
    enum_by_name: HashMap<String, usize>,
    service_by_name: HashMap<String, usize>,
    /// Combined message/enum lookup used to resolve field and method type
    /// references; it is not carried over by `into_inner`.
    type_by_name: HashMap<String, TypeRef>,
}
160
161impl Builder {
162 fn into_inner(self) -> PoolInner {
163 PoolInner {
164 files: self.files,
165 messages: self.messages,
166 enums: self.enums,
167 services: self.services,
168 message_by_name: self.message_by_name,
169 enum_by_name: self.enum_by_name,
170 service_by_name: self.service_by_name,
171 }
172 }
173
174 fn register(&mut self, fds: &FileDescriptorSet) -> Result<(), ReflectError> {
177 for file in &fds.file {
178 let package = file.package.clone().unwrap_or_default();
179 let file_index = self.files.len();
180 let (java_pkg, go_pkg, java_outer, deprecated, optimize_for) =
181 if let Some(opts) = &file.options {
182 (
183 opts.java_package.clone(),
184 opts.go_package.clone(),
185 opts.java_outer_classname.clone(),
186 opts.deprecated.unwrap_or(false),
187 opts.optimize_for.unwrap_or(0),
188 )
189 } else {
190 (None, None, None, false, 0)
191 };
192 self.files.push(FileData {
193 name: file.name.clone().unwrap_or_default(),
194 package: package.clone(),
195 syntax: file.syntax.clone().unwrap_or_else(|| "proto2".to_owned()),
196 dependencies: file.dependency.clone(),
197 java_package: java_pkg,
198 go_package: go_pkg,
199 java_outer_classname: java_outer,
200 deprecated,
201 optimize_for,
202 });
203
204 for msg in &file.message_type {
205 self.register_message(msg, &package, file_index)?;
206 }
207 for en in &file.enum_type {
208 self.register_enum(en, &package, file_index)?;
209 }
210 }
211 Ok(())
212 }
213
214 fn register_message(
217 &mut self,
218 msg: &DescriptorProto,
219 scope: &str,
220 file_index: usize,
221 ) -> Result<usize, ReflectError> {
222 let name = msg
223 .name
224 .clone()
225 .ok_or_else(|| ReflectError::Pool("message without a name".to_owned()))?;
226 let full_name = qualify(scope, &name);
227
228 let is_map_entry = msg
229 .options
230 .as_ref()
231 .and_then(|o| o.map_entry)
232 .unwrap_or(false);
233
234 let index = self.messages.len();
237 self.messages.push(MessageData {
238 full_name: full_name.clone(),
239 name,
240 file_index,
241 fields: Vec::new(),
242 field_by_number: HashMap::new(),
243 field_by_name: HashMap::new(),
244 field_by_json_name: HashMap::new(),
245 oneofs: Vec::new(),
246 nested_messages: Vec::new(),
247 nested_enums: Vec::new(),
248 is_map_entry,
249 });
250 if self
251 .message_by_name
252 .insert(full_name.clone(), index)
253 .is_some()
254 {
255 return Err(ReflectError::Pool(format!(
256 "duplicate message name '{full_name}'"
257 )));
258 }
259 self.type_by_name
260 .insert(full_name.clone(), TypeRef::Message(index));
261
262 let mut nested_messages = Vec::with_capacity(msg.nested_type.len());
263 for nested in &msg.nested_type {
264 let child = self.register_message(nested, &full_name, file_index)?;
265 nested_messages.push(child);
266 }
267 let mut nested_enums = Vec::with_capacity(msg.enum_type.len());
268 for nested in &msg.enum_type {
269 let child = self.register_enum(nested, &full_name, file_index)?;
270 nested_enums.push(child);
271 }
272 self.messages[index].nested_messages = nested_messages;
273 self.messages[index].nested_enums = nested_enums;
274
275 Ok(index)
276 }
277
278 fn register_enum(
280 &mut self,
281 en: &EnumDescriptorProto,
282 scope: &str,
283 file_index: usize,
284 ) -> Result<usize, ReflectError> {
285 let name = en
286 .name
287 .clone()
288 .ok_or_else(|| ReflectError::Pool("enum without a name".to_owned()))?;
289 let full_name = qualify(scope, &name);
290
291 let mut values = Vec::with_capacity(en.value.len());
292 let mut value_by_number = HashMap::new();
293 let mut value_by_name = HashMap::new();
294 for value in &en.value {
295 let value_name = value
296 .name
297 .clone()
298 .ok_or_else(|| ReflectError::Pool("enum value without a name".to_owned()))?;
299 let number = value
300 .number
301 .ok_or_else(|| ReflectError::Pool("enum value without a number".to_owned()))?;
302 let value_index = values.len();
303 let value_full_name = qualify(&full_name, &value_name);
307 values.push(EnumValueData {
308 name: value_name.clone(),
309 full_name: value_full_name,
310 number,
311 });
312 value_by_number.entry(number).or_insert(value_index);
315 value_by_name.insert(value_name, value_index);
316 }
317
318 let index = self.enums.len();
319 self.enums.push(EnumData {
320 full_name: full_name.clone(),
321 name,
322 file_index,
323 values,
324 value_by_number,
325 value_by_name,
326 });
327 if self.enum_by_name.insert(full_name.clone(), index).is_some() {
328 return Err(ReflectError::Pool(format!(
329 "duplicate enum name '{full_name}'"
330 )));
331 }
332 self.type_by_name.insert(full_name, TypeRef::Enum(index));
333
334 Ok(index)
335 }
336
337 fn resolve(&mut self, fds: &FileDescriptorSet) -> Result<(), ReflectError> {
339 let mut message_cursor = 0usize;
342 for file in &fds.file {
343 let syntax = file.syntax.as_deref().unwrap_or("proto2");
344 for msg in &file.message_type {
345 self.resolve_message(msg, &mut message_cursor, syntax)?;
346 }
347 }
348
349 for file in &fds.file {
351 let package = file.package.clone().unwrap_or_default();
352 for svc in &file.service {
353 self.resolve_service(svc, &package)?;
354 }
355 }
356
357 Ok(())
358 }
359
360 fn resolve_message(
363 &mut self,
364 msg: &DescriptorProto,
365 cursor: &mut usize,
366 syntax: &str,
367 ) -> Result<(), ReflectError> {
368 let index = *cursor;
369 *cursor += 1;
370
371 let message_full_name = self.messages[index].full_name.clone();
372
373 let mut fields: Vec<FieldData> = Vec::with_capacity(msg.field.len());
375 let mut field_by_number = HashMap::new();
376 let mut field_by_name = HashMap::new();
377 let mut field_by_json_name = HashMap::new();
378
379 for field in &msg.field {
380 let fname = field
381 .name
382 .clone()
383 .ok_or_else(|| ReflectError::Pool("field without a name".to_owned()))?;
384 let number = field
385 .number
386 .ok_or_else(|| ReflectError::Pool(format!("field '{fname}' without a number")))?;
387 let number = u32::try_from(number).map_err(|_| {
388 ReflectError::Pool(format!("field '{fname}' has invalid number {number}"))
389 })?;
390
391 let kind = self.resolve_kind(field, &fname)?;
392
393 let label = field
394 .label
395 .and_then(|l| Label::try_from(l).ok())
396 .unwrap_or(Label::Optional);
397 let cardinality = match label {
398 Label::Optional => Cardinality::Optional,
399 Label::Required => Cardinality::Required,
400 Label::Repeated => Cardinality::Repeated,
401 };
402
403 let proto3_optional = field.proto3_optional.unwrap_or(false);
404
405 let packed = compute_packed(field, kind, cardinality, syntax);
406
407 let oneof_index = field
408 .oneof_index
409 .map(|i| usize::try_from(i).unwrap_or(usize::MAX));
410
411 let json_name = field
412 .json_name
413 .clone()
414 .unwrap_or_else(|| to_json_name(&fname));
415
416 let field_full_name = qualify(&message_full_name, &fname);
417 let pos = fields.len();
418 field_by_number.insert(number, pos);
419 field_by_name.insert(fname.clone(), pos);
420 field_by_json_name.insert(json_name.clone(), pos);
421
422 fields.push(FieldData {
423 name: fname,
424 full_name: field_full_name,
425 json_name,
426 number,
427 kind,
428 cardinality,
429 packed,
430 oneof_index,
431 proto3_optional,
432 });
433 }
434
435 let mut oneofs: Vec<OneofData> = Vec::with_capacity(msg.oneof_decl.len());
437 for decl in &msg.oneof_decl {
438 let oname = decl
439 .name
440 .clone()
441 .ok_or_else(|| ReflectError::Pool("oneof without a name".to_owned()))?;
442 let oneof_full_name = qualify(&message_full_name, &oname);
443 oneofs.push(OneofData {
444 name: oname,
445 full_name: oneof_full_name,
446 field_indices: Vec::new(),
447 is_synthetic: false,
449 });
450 }
451 for (pos, field) in fields.iter().enumerate() {
452 if let Some(oi) = field.oneof_index {
453 if let Some(oneof) = oneofs.get_mut(oi) {
454 oneof.field_indices.push(pos);
455 if field.proto3_optional {
458 oneof.is_synthetic = true;
459 }
460 }
461 }
462 }
463
464 self.messages[index].fields = fields;
465 self.messages[index].field_by_number = field_by_number;
466 self.messages[index].field_by_name = field_by_name;
467 self.messages[index].field_by_json_name = field_by_json_name;
468 self.messages[index].oneofs = oneofs;
469
470 for nested in &msg.nested_type {
473 self.resolve_message(nested, cursor, syntax)?;
474 }
475
476 Ok(())
477 }
478
479 fn resolve_kind(
482 &self,
483 field: &prost_types::FieldDescriptorProto,
484 fname: &str,
485 ) -> Result<Kind, ReflectError> {
486 let ty = field
487 .r#type
488 .and_then(|t| Type::try_from(t).ok())
489 .ok_or_else(|| ReflectError::Pool(format!("field '{fname}' without a type")))?;
490
491 let kind = match ty {
492 Type::Double => Kind::Double,
493 Type::Float => Kind::Float,
494 Type::Int64 => Kind::Int64,
495 Type::Uint64 => Kind::Uint64,
496 Type::Int32 => Kind::Int32,
497 Type::Fixed64 => Kind::Fixed64,
498 Type::Fixed32 => Kind::Fixed32,
499 Type::Bool => Kind::Bool,
500 Type::String => Kind::String,
501 Type::Bytes => Kind::Bytes,
502 Type::Uint32 => Kind::Uint32,
503 Type::Sfixed32 => Kind::Sfixed32,
504 Type::Sfixed64 => Kind::Sfixed64,
505 Type::Sint32 => Kind::Sint32,
506 Type::Sint64 => Kind::Sint64,
507 Type::Group => {
508 let idx = self.resolve_type_name(field, fname, true)?;
509 Kind::Group(idx)
510 }
511 Type::Message => {
512 let idx = self.resolve_type_name(field, fname, true)?;
513 Kind::Message(idx)
514 }
515 Type::Enum => {
516 let idx = self.resolve_type_name(field, fname, false)?;
517 Kind::Enum(idx)
518 }
519 };
520 Ok(kind)
521 }
522
523 fn resolve_type_name(
526 &self,
527 field: &prost_types::FieldDescriptorProto,
528 fname: &str,
529 expect_message: bool,
530 ) -> Result<usize, ReflectError> {
531 let raw = field.type_name.as_deref().ok_or_else(|| {
532 ReflectError::Pool(format!(
533 "field '{fname}' is a message/enum but has no type_name"
534 ))
535 })?;
536 let key = raw.strip_prefix('.').unwrap_or(raw);
537 match self.type_by_name.get(key) {
538 Some(TypeRef::Message(i)) if expect_message => Ok(*i),
539 Some(TypeRef::Enum(i)) if !expect_message => Ok(*i),
540 Some(_) => Err(ReflectError::Pool(format!(
541 "field '{fname}' type '{key}' resolved to the wrong kind"
542 ))),
543 None => Err(ReflectError::Pool(format!(
544 "field '{fname}' references unknown type '{key}'"
545 ))),
546 }
547 }
548
549 fn resolve_service(
551 &mut self,
552 svc: &ServiceDescriptorProto,
553 package: &str,
554 ) -> Result<(), ReflectError> {
555 let name = svc
556 .name
557 .clone()
558 .ok_or_else(|| ReflectError::Pool("service without a name".to_owned()))?;
559 let full_name = qualify(package, &name);
560
561 let mut methods = Vec::with_capacity(svc.method.len());
562 for method in &svc.method {
563 let mname = method
564 .name
565 .clone()
566 .ok_or_else(|| ReflectError::Pool("method without a name".to_owned()))?;
567 let input_index =
568 self.resolve_message_ref(method.input_type.as_deref(), &mname, "input")?;
569 let output_index =
570 self.resolve_message_ref(method.output_type.as_deref(), &mname, "output")?;
571 let method_full_name = qualify(&full_name, &mname);
572 methods.push(MethodData {
573 name: mname,
574 full_name: method_full_name,
575 input_index,
576 output_index,
577 client_streaming: method.client_streaming.unwrap_or(false),
578 server_streaming: method.server_streaming.unwrap_or(false),
579 });
580 }
581
582 let index = self.services.len();
583 self.services.push(ServiceData {
584 full_name: full_name.clone(),
585 name,
586 file_index: self.file_index_for_package(package),
587 methods,
588 });
589 if self
590 .service_by_name
591 .insert(full_name.clone(), index)
592 .is_some()
593 {
594 return Err(ReflectError::Pool(format!(
595 "duplicate service name '{full_name}'"
596 )));
597 }
598 Ok(())
599 }
600
601 fn resolve_message_ref(
603 &self,
604 type_name: Option<&str>,
605 method_name: &str,
606 role: &str,
607 ) -> Result<usize, ReflectError> {
608 let raw = type_name.ok_or_else(|| {
609 ReflectError::Pool(format!("method '{method_name}' has no {role} type"))
610 })?;
611 let key = raw.strip_prefix('.').unwrap_or(raw);
612 match self.type_by_name.get(key) {
613 Some(TypeRef::Message(i)) => Ok(*i),
614 _ => Err(ReflectError::Pool(format!(
615 "method '{method_name}' {role} type '{key}' is not a known message"
616 ))),
617 }
618 }
619
620 fn file_index_for_package(&self, package: &str) -> usize {
624 self.files
625 .iter()
626 .position(|f| f.package == package)
627 .unwrap_or(0)
628 }
629}
630
/// Joins a scope and a simple name with `.`, returning `name` unchanged when
/// the scope is empty (e.g. a top-level type in a file without a package).
fn qualify(scope: &str, name: &str) -> String {
    match scope {
        "" => String::from(name),
        _ => [scope, name].join("."),
    }
}
640
641fn compute_packed(
646 field: &prost_types::FieldDescriptorProto,
647 kind: Kind,
648 cardinality: Cardinality,
649 syntax: &str,
650) -> bool {
651 if !matches!(cardinality, Cardinality::Repeated) || !kind.is_packable() {
652 return false;
653 }
654 if let Some(opts) = field.options.as_ref() {
655 if let Some(packed) = opts.packed {
656 return packed;
657 }
658 }
659 syntax == "proto3"
660}
661
/// Derives the default JSON name of a field: every `_` is dropped and the
/// character following it is upper-cased (`foo_bar` -> `fooBar`).
fn to_json_name(name: &str) -> String {
    let mut json = String::with_capacity(name.len());
    let mut capitalize = false;
    for ch in name.chars() {
        match (ch, capitalize) {
            ('_', _) => capitalize = true,
            (c, true) => {
                json.extend(c.to_uppercase());
                capitalize = false;
            }
            (c, false) => json.push(c),
        }
    }
    json
}