Skip to main content

oxiproto_reflect/native/
pool.rs

//! Native [`DescriptorPool`] built from a [`prost_types::FileDescriptorSet`].
//!
//! The pool builds an in-memory, index-based descriptor model
//! ([`PoolInner`]) in two passes:
//!
//! 1. **Registration** — walk every file and recursively register every
//!    message (and its nested messages) and enum, assigning each a stable
//!    index and recording its fully-qualified name.
//! 2. **Resolution** — walk every message field and service method, resolving
//!    type-name references (e.g. `.my.pkg.Other`) against the name table to
//!    concrete indices, producing [`Kind`] values and method input/output
//!    indices. Services themselves are also registered during this pass.
//!
//! Because descriptors are index handles over a shared [`Arc<PoolInner>`],
//! circular references between messages are represented naturally.
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use prost_types::field_descriptor_proto::{Label, Type};
21use prost_types::{
22    DescriptorProto, EnumDescriptorProto, FileDescriptorSet, ServiceDescriptorProto,
23};
24
25use super::descriptor::{
26    Cardinality, EnumData, EnumDescriptor, EnumValueData, FieldData, FileData, Kind, MessageData,
27    MessageDescriptor, MethodData, OneofData, ServiceData, ServiceDescriptor,
28};
29use crate::ReflectError;
30
31/// The shared, immutable backing store for all native descriptors in a pool.
32///
33/// All public descriptor handles hold an [`Arc`] to one of these and an index
34/// into the relevant vector.
35#[derive(Debug)]
36pub struct PoolInner {
37    pub(crate) files: Vec<FileData>,
38    pub(crate) messages: Vec<MessageData>,
39    pub(crate) enums: Vec<EnumData>,
40    pub(crate) services: Vec<ServiceData>,
41    /// Fully-qualified message name → index into `messages`.
42    pub(crate) message_by_name: HashMap<String, usize>,
43    /// Fully-qualified enum name → index into `enums`.
44    pub(crate) enum_by_name: HashMap<String, usize>,
45    /// Fully-qualified service name → index into `services`.
46    pub(crate) service_by_name: HashMap<String, usize>,
47}
48
49/// A native protobuf descriptor pool.
50///
51/// Built from a [`prost_types::FileDescriptorSet`] via
52/// [`DescriptorPool::from_file_descriptor_set`]. Cheaply cloneable (it wraps an
53/// [`Arc`]).
54#[derive(Clone, Debug)]
55pub struct DescriptorPool {
56    inner: Arc<PoolInner>,
57}
58
/// Target of a fully-qualified type name in the builder's combined type
/// table: either a message index or an enum index.
#[derive(Copy, Clone)]
enum TypeRef {
    /// Position of the message in the message vector.
    Message(usize),
    /// Position of the enum in the enum vector.
    Enum(usize),
}
66
67impl DescriptorPool {
68    /// Build a pool from a decoded [`FileDescriptorSet`].
69    ///
70    /// # Errors
71    ///
72    /// Returns [`ReflectError::Pool`] if a descriptor is malformed (missing
73    /// required name/number fields) or if a field, method input, or method
74    /// output references a type name that is not present in the set.
75    pub fn from_file_descriptor_set(fds: FileDescriptorSet) -> Result<Self, ReflectError> {
76        let mut builder = Builder::default();
77        builder.register(&fds)?;
78        builder.resolve(&fds)?;
79        Ok(Self {
80            inner: Arc::new(builder.into_inner()),
81        })
82    }
83
84    /// Look up a message by its fully-qualified name (no leading dot).
85    pub fn get_message_by_name(&self, full_name: &str) -> Option<MessageDescriptor> {
86        self.inner
87            .message_by_name
88            .get(full_name)
89            .map(|&index| MessageDescriptor {
90                pool: Arc::clone(&self.inner),
91                index,
92            })
93    }
94
95    /// Look up an enum by its fully-qualified name (no leading dot).
96    pub fn get_enum_by_name(&self, full_name: &str) -> Option<EnumDescriptor> {
97        self.inner
98            .enum_by_name
99            .get(full_name)
100            .map(|&index| EnumDescriptor {
101                pool: Arc::clone(&self.inner),
102                index,
103            })
104    }
105
106    /// Look up a service by its fully-qualified name (no leading dot).
107    pub fn get_service_by_name(&self, full_name: &str) -> Option<ServiceDescriptor> {
108        self.inner
109            .service_by_name
110            .get(full_name)
111            .map(|&index| ServiceDescriptor {
112                pool: Arc::clone(&self.inner),
113                index,
114            })
115    }
116
117    /// Iterate over every message in the pool (including nested messages and
118    /// synthetic map-entry types), in registration order.
119    pub fn all_messages(&self) -> impl ExactSizeIterator<Item = MessageDescriptor> + '_ {
120        let inner = Arc::clone(&self.inner);
121        (0..self.inner.messages.len()).map(move |index| MessageDescriptor {
122            pool: Arc::clone(&inner),
123            index,
124        })
125    }
126
127    /// Iterate over every enum in the pool, in registration order.
128    pub fn all_enums(&self) -> impl ExactSizeIterator<Item = EnumDescriptor> + '_ {
129        let inner = Arc::clone(&self.inner);
130        (0..self.inner.enums.len()).map(move |index| EnumDescriptor {
131            pool: Arc::clone(&inner),
132            index,
133        })
134    }
135
136    /// Iterate over every service in the pool, in registration order.
137    pub fn services(&self) -> impl ExactSizeIterator<Item = ServiceDescriptor> + '_ {
138        let inner = Arc::clone(&self.inner);
139        (0..self.inner.services.len()).map(move |index| ServiceDescriptor {
140            pool: Arc::clone(&inner),
141            index,
142        })
143    }
144}
145
146/// Mutable accumulator used while building a [`PoolInner`].
147#[derive(Default)]
148struct Builder {
149    files: Vec<FileData>,
150    messages: Vec<MessageData>,
151    enums: Vec<EnumData>,
152    services: Vec<ServiceData>,
153    message_by_name: HashMap<String, usize>,
154    enum_by_name: HashMap<String, usize>,
155    service_by_name: HashMap<String, usize>,
156    /// Combined type table (messages + enums) keyed by fully-qualified name,
157    /// used during resolution.
158    type_by_name: HashMap<String, TypeRef>,
159}
160
161impl Builder {
162    fn into_inner(self) -> PoolInner {
163        PoolInner {
164            files: self.files,
165            messages: self.messages,
166            enums: self.enums,
167            services: self.services,
168            message_by_name: self.message_by_name,
169            enum_by_name: self.enum_by_name,
170            service_by_name: self.service_by_name,
171        }
172    }
173
174    /// First pass: register every file, message, nested message, and enum,
175    /// assigning indices and building the name tables.
176    fn register(&mut self, fds: &FileDescriptorSet) -> Result<(), ReflectError> {
177        for file in &fds.file {
178            let package = file.package.clone().unwrap_or_default();
179            let file_index = self.files.len();
180            let (java_pkg, go_pkg, java_outer, deprecated, optimize_for) =
181                if let Some(opts) = &file.options {
182                    (
183                        opts.java_package.clone(),
184                        opts.go_package.clone(),
185                        opts.java_outer_classname.clone(),
186                        opts.deprecated.unwrap_or(false),
187                        opts.optimize_for.unwrap_or(0),
188                    )
189                } else {
190                    (None, None, None, false, 0)
191                };
192            self.files.push(FileData {
193                name: file.name.clone().unwrap_or_default(),
194                package: package.clone(),
195                syntax: file.syntax.clone().unwrap_or_else(|| "proto2".to_owned()),
196                dependencies: file.dependency.clone(),
197                java_package: java_pkg,
198                go_package: go_pkg,
199                java_outer_classname: java_outer,
200                deprecated,
201                optimize_for,
202            });
203
204            for msg in &file.message_type {
205                self.register_message(msg, &package, file_index)?;
206            }
207            for en in &file.enum_type {
208                self.register_enum(en, &package, file_index)?;
209            }
210        }
211        Ok(())
212    }
213
214    /// Register a message (and recursively its nested messages and enums).
215    /// Returns the assigned message index.
216    fn register_message(
217        &mut self,
218        msg: &DescriptorProto,
219        scope: &str,
220        file_index: usize,
221    ) -> Result<usize, ReflectError> {
222        let name = msg
223            .name
224            .clone()
225            .ok_or_else(|| ReflectError::Pool("message without a name".to_owned()))?;
226        let full_name = qualify(scope, &name);
227
228        let is_map_entry = msg
229            .options
230            .as_ref()
231            .and_then(|o| o.map_entry)
232            .unwrap_or(false);
233
234        // Reserve this message's slot before recursing so nested types get
235        // larger indices and the parent index is stable.
236        let index = self.messages.len();
237        self.messages.push(MessageData {
238            full_name: full_name.clone(),
239            name,
240            file_index,
241            fields: Vec::new(),
242            field_by_number: HashMap::new(),
243            field_by_name: HashMap::new(),
244            field_by_json_name: HashMap::new(),
245            oneofs: Vec::new(),
246            nested_messages: Vec::new(),
247            nested_enums: Vec::new(),
248            is_map_entry,
249        });
250        if self
251            .message_by_name
252            .insert(full_name.clone(), index)
253            .is_some()
254        {
255            return Err(ReflectError::Pool(format!(
256                "duplicate message name '{full_name}'"
257            )));
258        }
259        self.type_by_name
260            .insert(full_name.clone(), TypeRef::Message(index));
261
262        let mut nested_messages = Vec::with_capacity(msg.nested_type.len());
263        for nested in &msg.nested_type {
264            let child = self.register_message(nested, &full_name, file_index)?;
265            nested_messages.push(child);
266        }
267        let mut nested_enums = Vec::with_capacity(msg.enum_type.len());
268        for nested in &msg.enum_type {
269            let child = self.register_enum(nested, &full_name, file_index)?;
270            nested_enums.push(child);
271        }
272        self.messages[index].nested_messages = nested_messages;
273        self.messages[index].nested_enums = nested_enums;
274
275        Ok(index)
276    }
277
278    /// Register an enum. Returns the assigned enum index.
279    fn register_enum(
280        &mut self,
281        en: &EnumDescriptorProto,
282        scope: &str,
283        file_index: usize,
284    ) -> Result<usize, ReflectError> {
285        let name = en
286            .name
287            .clone()
288            .ok_or_else(|| ReflectError::Pool("enum without a name".to_owned()))?;
289        let full_name = qualify(scope, &name);
290
291        let mut values = Vec::with_capacity(en.value.len());
292        let mut value_by_number = HashMap::new();
293        let mut value_by_name = HashMap::new();
294        for value in &en.value {
295            let value_name = value
296                .name
297                .clone()
298                .ok_or_else(|| ReflectError::Pool("enum value without a name".to_owned()))?;
299            let number = value
300                .number
301                .ok_or_else(|| ReflectError::Pool("enum value without a number".to_owned()))?;
302            let value_index = values.len();
303            // Enum value names are scoped to the *enclosing* scope of the enum,
304            // not the enum itself (C++ scoping rules), but for lookup we record
305            // the qualified-under-enum name which is what most tooling expects.
306            let value_full_name = qualify(&full_name, &value_name);
307            values.push(EnumValueData {
308                name: value_name.clone(),
309                full_name: value_full_name,
310                number,
311            });
312            // First occurrence of a number wins for the by-number map (protobuf
313            // allows aliases when `allow_alias` is set).
314            value_by_number.entry(number).or_insert(value_index);
315            value_by_name.insert(value_name, value_index);
316        }
317
318        let index = self.enums.len();
319        self.enums.push(EnumData {
320            full_name: full_name.clone(),
321            name,
322            file_index,
323            values,
324            value_by_number,
325            value_by_name,
326        });
327        if self.enum_by_name.insert(full_name.clone(), index).is_some() {
328            return Err(ReflectError::Pool(format!(
329                "duplicate enum name '{full_name}'"
330            )));
331        }
332        self.type_by_name.insert(full_name, TypeRef::Enum(index));
333
334        Ok(index)
335    }
336
337    /// Second pass: resolve all field type references and service methods.
338    fn resolve(&mut self, fds: &FileDescriptorSet) -> Result<(), ReflectError> {
339        // Resolve message fields. We re-walk the FDS in the same order as
340        // registration so message indices line up.
341        let mut message_cursor = 0usize;
342        for file in &fds.file {
343            let syntax = file.syntax.as_deref().unwrap_or("proto2");
344            for msg in &file.message_type {
345                self.resolve_message(msg, &mut message_cursor, syntax)?;
346            }
347        }
348
349        // Resolve services.
350        for file in &fds.file {
351            let package = file.package.clone().unwrap_or_default();
352            for svc in &file.service {
353                self.resolve_service(svc, &package)?;
354            }
355        }
356
357        Ok(())
358    }
359
360    /// Resolve a single message's fields, advancing `cursor` over this message
361    /// and all of its nested messages (matching registration order).
362    fn resolve_message(
363        &mut self,
364        msg: &DescriptorProto,
365        cursor: &mut usize,
366        syntax: &str,
367    ) -> Result<(), ReflectError> {
368        let index = *cursor;
369        *cursor += 1;
370
371        let message_full_name = self.messages[index].full_name.clone();
372
373        // Build field data.
374        let mut fields: Vec<FieldData> = Vec::with_capacity(msg.field.len());
375        let mut field_by_number = HashMap::new();
376        let mut field_by_name = HashMap::new();
377        let mut field_by_json_name = HashMap::new();
378
379        for field in &msg.field {
380            let fname = field
381                .name
382                .clone()
383                .ok_or_else(|| ReflectError::Pool("field without a name".to_owned()))?;
384            let number = field
385                .number
386                .ok_or_else(|| ReflectError::Pool(format!("field '{fname}' without a number")))?;
387            let number = u32::try_from(number).map_err(|_| {
388                ReflectError::Pool(format!("field '{fname}' has invalid number {number}"))
389            })?;
390
391            let kind = self.resolve_kind(field, &fname)?;
392
393            let label = field
394                .label
395                .and_then(|l| Label::try_from(l).ok())
396                .unwrap_or(Label::Optional);
397            let cardinality = match label {
398                Label::Optional => Cardinality::Optional,
399                Label::Required => Cardinality::Required,
400                Label::Repeated => Cardinality::Repeated,
401            };
402
403            let proto3_optional = field.proto3_optional.unwrap_or(false);
404
405            let packed = compute_packed(field, kind, cardinality, syntax);
406
407            let oneof_index = field
408                .oneof_index
409                .map(|i| usize::try_from(i).unwrap_or(usize::MAX));
410
411            let json_name = field
412                .json_name
413                .clone()
414                .unwrap_or_else(|| to_json_name(&fname));
415
416            let field_full_name = qualify(&message_full_name, &fname);
417            let pos = fields.len();
418            field_by_number.insert(number, pos);
419            field_by_name.insert(fname.clone(), pos);
420            field_by_json_name.insert(json_name.clone(), pos);
421
422            fields.push(FieldData {
423                name: fname,
424                full_name: field_full_name,
425                json_name,
426                number,
427                kind,
428                cardinality,
429                packed,
430                oneof_index,
431                proto3_optional,
432            });
433        }
434
435        // Build oneof data, then attach field indices.
436        let mut oneofs: Vec<OneofData> = Vec::with_capacity(msg.oneof_decl.len());
437        for decl in &msg.oneof_decl {
438            let oname = decl
439                .name
440                .clone()
441                .ok_or_else(|| ReflectError::Pool("oneof without a name".to_owned()))?;
442            let oneof_full_name = qualify(&message_full_name, &oname);
443            oneofs.push(OneofData {
444                name: oname,
445                full_name: oneof_full_name,
446                field_indices: Vec::new(),
447                // Provisionally non-synthetic; refined below.
448                is_synthetic: false,
449            });
450        }
451        for (pos, field) in fields.iter().enumerate() {
452            if let Some(oi) = field.oneof_index {
453                if let Some(oneof) = oneofs.get_mut(oi) {
454                    oneof.field_indices.push(pos);
455                    // A proto3 `optional` field is implemented as a synthetic
456                    // single-field oneof.
457                    if field.proto3_optional {
458                        oneof.is_synthetic = true;
459                    }
460                }
461            }
462        }
463
464        self.messages[index].fields = fields;
465        self.messages[index].field_by_number = field_by_number;
466        self.messages[index].field_by_name = field_by_name;
467        self.messages[index].field_by_json_name = field_by_json_name;
468        self.messages[index].oneofs = oneofs;
469
470        // Recurse into nested messages, keeping the cursor in registration
471        // order.
472        for nested in &msg.nested_type {
473            self.resolve_message(nested, cursor, syntax)?;
474        }
475
476        Ok(())
477    }
478
479    /// Resolve a field's [`Kind`] from its protobuf type and (for
480    /// message/enum) its `type_name`.
481    fn resolve_kind(
482        &self,
483        field: &prost_types::FieldDescriptorProto,
484        fname: &str,
485    ) -> Result<Kind, ReflectError> {
486        let ty = field
487            .r#type
488            .and_then(|t| Type::try_from(t).ok())
489            .ok_or_else(|| ReflectError::Pool(format!("field '{fname}' without a type")))?;
490
491        let kind = match ty {
492            Type::Double => Kind::Double,
493            Type::Float => Kind::Float,
494            Type::Int64 => Kind::Int64,
495            Type::Uint64 => Kind::Uint64,
496            Type::Int32 => Kind::Int32,
497            Type::Fixed64 => Kind::Fixed64,
498            Type::Fixed32 => Kind::Fixed32,
499            Type::Bool => Kind::Bool,
500            Type::String => Kind::String,
501            Type::Bytes => Kind::Bytes,
502            Type::Uint32 => Kind::Uint32,
503            Type::Sfixed32 => Kind::Sfixed32,
504            Type::Sfixed64 => Kind::Sfixed64,
505            Type::Sint32 => Kind::Sint32,
506            Type::Sint64 => Kind::Sint64,
507            Type::Group => {
508                let idx = self.resolve_type_name(field, fname, true)?;
509                Kind::Group(idx)
510            }
511            Type::Message => {
512                let idx = self.resolve_type_name(field, fname, true)?;
513                Kind::Message(idx)
514            }
515            Type::Enum => {
516                let idx = self.resolve_type_name(field, fname, false)?;
517                Kind::Enum(idx)
518            }
519        };
520        Ok(kind)
521    }
522
523    /// Resolve a `type_name` reference (e.g. `.pkg.Msg` or `pkg.Msg`) to a
524    /// message or enum index.
525    fn resolve_type_name(
526        &self,
527        field: &prost_types::FieldDescriptorProto,
528        fname: &str,
529        expect_message: bool,
530    ) -> Result<usize, ReflectError> {
531        let raw = field.type_name.as_deref().ok_or_else(|| {
532            ReflectError::Pool(format!(
533                "field '{fname}' is a message/enum but has no type_name"
534            ))
535        })?;
536        let key = raw.strip_prefix('.').unwrap_or(raw);
537        match self.type_by_name.get(key) {
538            Some(TypeRef::Message(i)) if expect_message => Ok(*i),
539            Some(TypeRef::Enum(i)) if !expect_message => Ok(*i),
540            Some(_) => Err(ReflectError::Pool(format!(
541                "field '{fname}' type '{key}' resolved to the wrong kind"
542            ))),
543            None => Err(ReflectError::Pool(format!(
544                "field '{fname}' references unknown type '{key}'"
545            ))),
546        }
547    }
548
549    /// Resolve a service and its methods.
550    fn resolve_service(
551        &mut self,
552        svc: &ServiceDescriptorProto,
553        package: &str,
554    ) -> Result<(), ReflectError> {
555        let name = svc
556            .name
557            .clone()
558            .ok_or_else(|| ReflectError::Pool("service without a name".to_owned()))?;
559        let full_name = qualify(package, &name);
560
561        let mut methods = Vec::with_capacity(svc.method.len());
562        for method in &svc.method {
563            let mname = method
564                .name
565                .clone()
566                .ok_or_else(|| ReflectError::Pool("method without a name".to_owned()))?;
567            let input_index =
568                self.resolve_message_ref(method.input_type.as_deref(), &mname, "input")?;
569            let output_index =
570                self.resolve_message_ref(method.output_type.as_deref(), &mname, "output")?;
571            let method_full_name = qualify(&full_name, &mname);
572            methods.push(MethodData {
573                name: mname,
574                full_name: method_full_name,
575                input_index,
576                output_index,
577                client_streaming: method.client_streaming.unwrap_or(false),
578                server_streaming: method.server_streaming.unwrap_or(false),
579            });
580        }
581
582        let index = self.services.len();
583        self.services.push(ServiceData {
584            full_name: full_name.clone(),
585            name,
586            file_index: self.file_index_for_package(package),
587            methods,
588        });
589        if self
590            .service_by_name
591            .insert(full_name.clone(), index)
592            .is_some()
593        {
594            return Err(ReflectError::Pool(format!(
595                "duplicate service name '{full_name}'"
596            )));
597        }
598        Ok(())
599    }
600
601    /// Resolve a method input/output message type name to a message index.
602    fn resolve_message_ref(
603        &self,
604        type_name: Option<&str>,
605        method_name: &str,
606        role: &str,
607    ) -> Result<usize, ReflectError> {
608        let raw = type_name.ok_or_else(|| {
609            ReflectError::Pool(format!("method '{method_name}' has no {role} type"))
610        })?;
611        let key = raw.strip_prefix('.').unwrap_or(raw);
612        match self.type_by_name.get(key) {
613            Some(TypeRef::Message(i)) => Ok(*i),
614            _ => Err(ReflectError::Pool(format!(
615                "method '{method_name}' {role} type '{key}' is not a known message"
616            ))),
617        }
618    }
619
620    /// Best-effort lookup of a file index for a package, used to set a
621    /// service's parent file. Falls back to the first file (index 0) if no
622    /// match is found and at least one file exists.
623    fn file_index_for_package(&self, package: &str) -> usize {
624        self.files
625            .iter()
626            .position(|f| f.package == package)
627            .unwrap_or(0)
628    }
629}
630
/// Build a fully-qualified name by prefixing `name` with `scope` and a `.`.
///
/// An empty `scope` yields `name` unchanged, with no leading dot.
fn qualify(scope: &str, name: &str) -> String {
    match scope {
        "" => name.to_owned(),
        _ => [scope, name].join("."),
    }
}
640
641/// Compute the effective `packed` flag for a field.
642///
643/// Only repeated packable scalars can be packed. proto3 packs by default;
644/// proto2 does not. An explicit `options.packed` overrides the default.
645fn compute_packed(
646    field: &prost_types::FieldDescriptorProto,
647    kind: Kind,
648    cardinality: Cardinality,
649    syntax: &str,
650) -> bool {
651    if !matches!(cardinality, Cardinality::Repeated) || !kind.is_packable() {
652        return false;
653    }
654    if let Some(opts) = field.options.as_ref() {
655        if let Some(packed) = opts.packed {
656            return packed;
657        }
658    }
659    syntax == "proto3"
660}
661
/// Compute the default JSON (lowerCamelCase) name for a snake_case field
/// name.
///
/// Every `_` is dropped and the next character that is not `_` is
/// upper-cased; all other characters are copied through unchanged.
fn to_json_name(name: &str) -> String {
    let mut json = String::with_capacity(name.len());
    let mut capitalize = false;
    for c in name.chars() {
        match (c, capitalize) {
            ('_', _) => capitalize = true,
            (_, true) => {
                json.extend(c.to_uppercase());
                capitalize = false;
            }
            (_, false) => json.push(c),
        }
    }
    json
}
678}