pbjson_build/
message.rs

//! The raw descriptor format is not very easy to work with, a fact not aided
//! by prost making almost all members of proto2-syntax messages optional.
//!
//! This module therefore extracts a slightly less obtuse representation of a
//! message that can be used by the code generation logic.
6
7use prost_types::{
8    field_descriptor_proto::{Label, Type},
9    FieldDescriptorProto,
10};
11
12use crate::descriptor::{Descriptor, DescriptorSet, MessageDescriptor, Syntax, TypeName, TypePath};
13use crate::escape::{escape_ident, escape_type};
14
/// A protobuf scalar type, collapsed onto the Rust primitive used to hold it.
///
/// Several wire encodings share a variant, e.g. `int64`, `sint64` and
/// `sfixed64` all resolve to [`ScalarType::I64`].
#[derive(Debug, Clone, Copy)]
pub enum ScalarType {
    F64,
    F32,
    I32,
    I64,
    U32,
    U64,
    Bool,
    String,
    Bytes,
}

impl ScalarType {
    /// Returns the Rust type used to represent this scalar, as source text.
    pub fn rust_type(&self) -> &'static str {
        match *self {
            // Floating point
            Self::F64 => "f64",
            Self::F32 => "f32",
            // Signed integers
            Self::I32 => "i32",
            Self::I64 => "i64",
            // Unsigned integers
            Self::U32 => "u32",
            Self::U64 => "u64",
            // Non-numeric
            Self::Bool => "bool",
            Self::String => "String",
            Self::Bytes => "Vec<u8>",
        }
    }

    /// Returns `true` for the integer and floating point variants, and
    /// `false` for `Bool`, `String` and `Bytes`.
    pub fn is_numeric(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here
        match self {
            Self::F64 | Self::F32 | Self::I32 | Self::I64 | Self::U32 | Self::U64 => true,
            Self::Bool | Self::String | Self::Bytes => false,
        }
    }
}
50
51#[derive(Debug, Clone)]
52pub enum FieldType {
53    Scalar(ScalarType),
54    Enum(TypePath),
55    Message(TypePath),
56    Map(ScalarType, Box<FieldType>),
57}
58
/// How the presence / multiplicity of a field is modelled.
#[derive(Debug, Clone, Copy)]
pub enum FieldModifier {
    /// A proto2 `required` field.
    Required,
    /// A field that may be absent: proto2 `optional`, proto3 `optional`,
    /// oneof members, and proto3 message-typed singular fields.
    Optional,
    /// A proto3 singular field that is neither message-typed nor explicitly
    /// `optional`.
    UseDefault,
    /// A `repeated` field, including map fields.
    Repeated,
}

impl FieldModifier {
    /// Returns `true` only for [`FieldModifier::Required`].
    pub fn is_required(&self) -> bool {
        match self {
            Self::Required => true,
            Self::Optional | Self::UseDefault | Self::Repeated => false,
        }
    }
}
72
73#[derive(Debug, Clone)]
74pub struct Field {
75    pub name: String,
76    pub json_name: Option<String>,
77    pub field_modifier: FieldModifier,
78    pub field_type: FieldType,
79}
80
81impl Field {
82    pub fn rust_type_name(&self) -> String {
83        use heck::ToUpperCamelCase;
84        escape_type(self.name.to_upper_camel_case())
85    }
86
87    pub fn rust_field_name(&self) -> String {
88        use heck::ToSnakeCase;
89        escape_ident(self.name.to_snake_case())
90    }
91
92    pub fn json_name(&self) -> String {
93        use heck::ToLowerCamelCase;
94        self.json_name
95            .clone()
96            .unwrap_or_else(|| self.name.to_lower_camel_case())
97    }
98}
99
100#[derive(Debug, Clone)]
101pub struct OneOf {
102    pub name: String,
103    pub path: TypePath,
104    pub fields: Vec<Field>,
105}
106
107impl OneOf {
108    pub fn rust_field_name(&self) -> String {
109        use heck::ToSnakeCase;
110        escape_ident(self.name.to_snake_case())
111    }
112}
113
114#[derive(Debug, Clone)]
115pub struct Message {
116    pub path: TypePath,
117    pub fields: Vec<Field>,
118    pub one_ofs: Vec<OneOf>,
119}
120
121impl Message {
122    pub fn all_fields(&self) -> impl Iterator<Item = &Field> + '_ {
123        self.fields
124            .iter()
125            .chain(self.one_ofs.iter().flat_map(|one_of| one_of.fields.iter()))
126    }
127}
128
129/// Resolve the provided message descriptor into a slightly less obtuse representation
130///
131/// Returns None if the provided provided message is auto-generated
132pub fn resolve_message(
133    descriptors: &DescriptorSet,
134    message: &MessageDescriptor,
135) -> Option<Message> {
136    if message.is_map() {
137        return None;
138    }
139
140    let mut fields = Vec::new();
141    let mut one_of_fields = vec![Vec::new(); message.one_of.len()];
142
143    for field in &message.fields {
144        let field_type = field_type(descriptors, field);
145        let field_modifier = field_modifier(message, field, &field_type);
146
147        let resolved = Field {
148            name: field.name.clone().expect("expected field to have name"),
149            json_name: field.json_name.clone(),
150            field_type,
151            field_modifier,
152        };
153
154        // Treat synthetic one-of as normal
155        let proto3_optional = field.proto3_optional.unwrap_or(false);
156        match (field.oneof_index, proto3_optional) {
157            (Some(idx), false) => one_of_fields[idx as usize].push(resolved),
158            _ => fields.push(resolved),
159        }
160    }
161
162    let mut one_ofs = Vec::new();
163
164    for (fields, descriptor) in one_of_fields.into_iter().zip(&message.one_of) {
165        // Might be empty in the event of a synthetic one-of
166        if !fields.is_empty() {
167            let name = descriptor.name.clone().expect("oneof with no name");
168            let path = message.path.child(TypeName::new(&name));
169
170            one_ofs.push(OneOf { name, path, fields })
171        }
172    }
173
174    Some(Message {
175        path: message.path.clone(),
176        fields,
177        one_ofs,
178    })
179}
180
181fn field_modifier(
182    message: &MessageDescriptor,
183    field: &FieldDescriptorProto,
184    field_type: &FieldType,
185) -> FieldModifier {
186    let label = Label::try_from(field.label.expect("expected label")).expect("valid label");
187    if field.proto3_optional.unwrap_or(false) {
188        assert_eq!(label, Label::Optional);
189        return FieldModifier::Optional;
190    }
191
192    if field.oneof_index.is_some() {
193        assert_eq!(label, Label::Optional);
194        return FieldModifier::Optional;
195    }
196
197    if matches!(field_type, FieldType::Map(_, _)) {
198        assert_eq!(label, Label::Repeated);
199        return FieldModifier::Repeated;
200    }
201
202    match label {
203        Label::Optional => match message.syntax {
204            Syntax::Proto2 => FieldModifier::Optional,
205            Syntax::Proto3 => match field_type {
206                FieldType::Message(_) => FieldModifier::Optional,
207                _ => FieldModifier::UseDefault,
208            },
209        },
210        Label::Required => FieldModifier::Required,
211        Label::Repeated => FieldModifier::Repeated,
212    }
213}
214
215fn field_type(descriptors: &DescriptorSet, field: &FieldDescriptorProto) -> FieldType {
216    match field.type_name.as_ref() {
217        Some(type_name) => resolve_type(descriptors, type_name.as_str()),
218        None => {
219            let scalar =
220                match Type::try_from(field.r#type.expect("expected type")).expect("valid type") {
221                    Type::Double => ScalarType::F64,
222                    Type::Float => ScalarType::F32,
223                    Type::Int64 | Type::Sfixed64 | Type::Sint64 => ScalarType::I64,
224                    Type::Int32 | Type::Sfixed32 | Type::Sint32 => ScalarType::I32,
225                    Type::Uint64 | Type::Fixed64 => ScalarType::U64,
226                    Type::Uint32 | Type::Fixed32 => ScalarType::U32,
227                    Type::Bool => ScalarType::Bool,
228                    Type::String => ScalarType::String,
229                    Type::Bytes => ScalarType::Bytes,
230                    Type::Message | Type::Enum | Type::Group => panic!("no type name specified"),
231                };
232            FieldType::Scalar(scalar)
233        }
234    }
235}
236
237fn resolve_type(descriptors: &DescriptorSet, type_name: &str) -> FieldType {
238    assert!(
239        type_name.starts_with('.'),
240        "pbjson does not currently support resolving relative types"
241    );
242    let maybe_descriptor = descriptors
243        .iter()
244        .find(|(path, _)| path.prefix_match(type_name).is_some());
245
246    match maybe_descriptor {
247        Some((path, Descriptor::Enum(_))) => FieldType::Enum(path.clone()),
248        Some((path, Descriptor::Message(descriptor))) => match descriptor.is_map() {
249            true => {
250                assert_eq!(descriptor.fields.len(), 2, "expected map to have 2 fields");
251                let key = &descriptor.fields[0];
252                let value = &descriptor.fields[1];
253
254                assert_eq!("key", key.name());
255                assert_eq!("value", value.name());
256
257                let key_type = match field_type(descriptors, key) {
258                    FieldType::Scalar(scalar) => scalar,
259                    _ => panic!("non scalar map key"),
260                };
261                let value_type = field_type(descriptors, value);
262                FieldType::Map(key_type, Box::new(value_type))
263            }
264            // Note: This may actually be a group but it is non-trivial to detect this,
265            // they're deprecated, and pbjson doesn't need to be able to distinguish
266            false => FieldType::Message(path.clone()),
267        },
268        None => panic!("failed to resolve type: {}", type_name),
269    }
270}