1use prost_types::{
8 field_descriptor_proto::{Label, Type},
9 FieldDescriptorProto,
10};
11
12use crate::descriptor::{Descriptor, DescriptorSet, MessageDescriptor, Syntax, TypeName, TypePath};
13use crate::escape::{escape_ident, escape_type};
14
/// A protobuf scalar type, collapsed onto the Rust primitive used to represent it.
///
/// Several protobuf wire encodings share one variant: for example `int64`, `sint64` and
/// `sfixed64` all resolve to [`ScalarType::I64`] (see `field_type`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScalarType {
    /// `double`
    F64,
    /// `float`
    F32,
    /// `int32`, `sint32`, `sfixed32`
    I32,
    /// `int64`, `sint64`, `sfixed64`
    I64,
    /// `uint32`, `fixed32`
    U32,
    /// `uint64`, `fixed64`
    U64,
    /// `bool`
    Bool,
    /// `string`
    String,
    /// `bytes`
    Bytes,
}
27
28impl ScalarType {
29 pub fn rust_type(&self) -> &'static str {
30 match self {
31 Self::F64 => "f64",
32 Self::F32 => "f32",
33 Self::I32 => "i32",
34 Self::I64 => "i64",
35 Self::U32 => "u32",
36 Self::U64 => "u64",
37 Self::Bool => "bool",
38 Self::String => "String",
39 Self::Bytes => "Vec<u8>",
40 }
41 }
42
43 pub fn is_numeric(&self) -> bool {
44 matches!(
45 self,
46 Self::F64 | Self::F32 | Self::I32 | Self::I64 | Self::U32 | Self::U64
47 )
48 }
49}
50
/// The resolved type of a message field.
#[derive(Debug, Clone)]
pub enum FieldType {
    /// A protobuf scalar, collapsed onto its Rust representation.
    Scalar(ScalarType),
    /// An enum, identified by the path of its descriptor.
    Enum(TypePath),
    /// A message that is not a synthetic map entry, identified by the path of its descriptor.
    Message(TypePath),
    /// A `map<K, V>` field. The key is always a scalar (`resolve_type` panics otherwise); the
    /// value is resolved like any other field type. It is boxed because `FieldType` is recursive.
    Map(ScalarType, Box<FieldType>),
}
58
/// How a field's presence and cardinality are modelled, as decided by `field_modifier`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FieldModifier {
    /// A field declared with the `required` label.
    Required,
    /// A field with explicit presence: proto2 `optional`, proto3 `optional`, a member of a
    /// `oneof`, or a singular proto3 message field.
    Optional,
    /// A singular proto3 non-message field without explicit presence.
    UseDefault,
    /// A `repeated` field, including map fields.
    Repeated,
}
66
67impl FieldModifier {
68 pub fn is_required(&self) -> bool {
69 matches!(self, Self::Required)
70 }
71}
72
/// A resolved message field.
#[derive(Debug, Clone)]
pub struct Field {
    /// The field name taken from the field descriptor.
    pub name: String,
    /// The `json_name` recorded in the descriptor, if any. The `json_name()` method falls back
    /// to the lowerCamelCase form of `name` when this is `None`.
    pub json_name: Option<String>,
    /// Presence/cardinality of the field.
    pub field_modifier: FieldModifier,
    /// The resolved type of the field.
    pub field_type: FieldType,
}
80
81impl Field {
82 pub fn rust_type_name(&self) -> String {
83 use heck::ToUpperCamelCase;
84 escape_type(self.name.to_upper_camel_case())
85 }
86
87 pub fn rust_field_name(&self) -> String {
88 use heck::ToSnakeCase;
89 escape_ident(self.name.to_snake_case())
90 }
91
92 pub fn json_name(&self) -> String {
93 use heck::ToLowerCamelCase;
94 self.json_name
95 .clone()
96 .unwrap_or_else(|| self.name.to_lower_camel_case())
97 }
98}
99
/// A `oneof` group together with the resolved fields that belong to it.
#[derive(Debug, Clone)]
pub struct OneOf {
    /// The oneof name taken from its descriptor.
    pub name: String,
    /// Path of the oneof; `resolve_message` builds it as a child of the containing message's path.
    pub path: TypePath,
    /// Member fields, in descriptor order.
    pub fields: Vec<Field>,
}
106
107impl OneOf {
108 pub fn rust_field_name(&self) -> String {
109 use heck::ToSnakeCase;
110 escape_ident(self.name.to_snake_case())
111 }
112}
113
/// A resolved protobuf message, as produced by `resolve_message`.
#[derive(Debug, Clone)]
pub struct Message {
    /// Path of the message's descriptor.
    pub path: TypePath,
    /// Fields that are not grouped into a `OneOf`, in descriptor order.
    pub fields: Vec<Field>,
    /// The message's oneof groups; `resolve_message` only keeps groups that have members.
    pub one_ofs: Vec<OneOf>,
}
120
121impl Message {
122 pub fn all_fields(&self) -> impl Iterator<Item = &Field> + '_ {
123 self.fields
124 .iter()
125 .chain(self.one_ofs.iter().flat_map(|one_of| one_of.fields.iter()))
126 }
127}
128
129pub fn resolve_message(
133 descriptors: &DescriptorSet,
134 message: &MessageDescriptor,
135) -> Option<Message> {
136 if message.is_map() {
137 return None;
138 }
139
140 let mut fields = Vec::new();
141 let mut one_of_fields = vec![Vec::new(); message.one_of.len()];
142
143 for field in &message.fields {
144 let field_type = field_type(descriptors, field);
145 let field_modifier = field_modifier(message, field, &field_type);
146
147 let resolved = Field {
148 name: field.name.clone().expect("expected field to have name"),
149 json_name: field.json_name.clone(),
150 field_type,
151 field_modifier,
152 };
153
154 let proto3_optional = field.proto3_optional.unwrap_or(false);
156 match (field.oneof_index, proto3_optional) {
157 (Some(idx), false) => one_of_fields[idx as usize].push(resolved),
158 _ => fields.push(resolved),
159 }
160 }
161
162 let mut one_ofs = Vec::new();
163
164 for (fields, descriptor) in one_of_fields.into_iter().zip(&message.one_of) {
165 if !fields.is_empty() {
167 let name = descriptor.name.clone().expect("oneof with no name");
168 let path = message.path.child(TypeName::new(&name));
169
170 one_ofs.push(OneOf { name, path, fields })
171 }
172 }
173
174 Some(Message {
175 path: message.path.clone(),
176 fields,
177 one_ofs,
178 })
179}
180
181fn field_modifier(
182 message: &MessageDescriptor,
183 field: &FieldDescriptorProto,
184 field_type: &FieldType,
185) -> FieldModifier {
186 let label = Label::try_from(field.label.expect("expected label")).expect("valid label");
187 if field.proto3_optional.unwrap_or(false) {
188 assert_eq!(label, Label::Optional);
189 return FieldModifier::Optional;
190 }
191
192 if field.oneof_index.is_some() {
193 assert_eq!(label, Label::Optional);
194 return FieldModifier::Optional;
195 }
196
197 if matches!(field_type, FieldType::Map(_, _)) {
198 assert_eq!(label, Label::Repeated);
199 return FieldModifier::Repeated;
200 }
201
202 match label {
203 Label::Optional => match message.syntax {
204 Syntax::Proto2 => FieldModifier::Optional,
205 Syntax::Proto3 => match field_type {
206 FieldType::Message(_) => FieldModifier::Optional,
207 _ => FieldModifier::UseDefault,
208 },
209 },
210 Label::Required => FieldModifier::Required,
211 Label::Repeated => FieldModifier::Repeated,
212 }
213}
214
215fn field_type(descriptors: &DescriptorSet, field: &FieldDescriptorProto) -> FieldType {
216 match field.type_name.as_ref() {
217 Some(type_name) => resolve_type(descriptors, type_name.as_str()),
218 None => {
219 let scalar =
220 match Type::try_from(field.r#type.expect("expected type")).expect("valid type") {
221 Type::Double => ScalarType::F64,
222 Type::Float => ScalarType::F32,
223 Type::Int64 | Type::Sfixed64 | Type::Sint64 => ScalarType::I64,
224 Type::Int32 | Type::Sfixed32 | Type::Sint32 => ScalarType::I32,
225 Type::Uint64 | Type::Fixed64 => ScalarType::U64,
226 Type::Uint32 | Type::Fixed32 => ScalarType::U32,
227 Type::Bool => ScalarType::Bool,
228 Type::String => ScalarType::String,
229 Type::Bytes => ScalarType::Bytes,
230 Type::Message | Type::Enum | Type::Group => panic!("no type name specified"),
231 };
232 FieldType::Scalar(scalar)
233 }
234 }
235}
236
237fn resolve_type(descriptors: &DescriptorSet, type_name: &str) -> FieldType {
238 assert!(
239 type_name.starts_with('.'),
240 "pbjson does not currently support resolving relative types"
241 );
242 let maybe_descriptor = descriptors
243 .iter()
244 .find(|(path, _)| path.prefix_match(type_name).is_some());
245
246 match maybe_descriptor {
247 Some((path, Descriptor::Enum(_))) => FieldType::Enum(path.clone()),
248 Some((path, Descriptor::Message(descriptor))) => match descriptor.is_map() {
249 true => {
250 assert_eq!(descriptor.fields.len(), 2, "expected map to have 2 fields");
251 let key = &descriptor.fields[0];
252 let value = &descriptor.fields[1];
253
254 assert_eq!("key", key.name());
255 assert_eq!("value", value.name());
256
257 let key_type = match field_type(descriptors, key) {
258 FieldType::Scalar(scalar) => scalar,
259 _ => panic!("non scalar map key"),
260 };
261 let value_type = field_type(descriptors, value);
262 FieldType::Map(key_type, Box::new(value_type))
263 }
264 false => FieldType::Message(path.clone()),
267 },
268 None => panic!("failed to resolve type: {}", type_name),
269 }
270}