1use std::{
2 collections::HashMap,
3 sync::{Arc, RwLock},
4};
5
6use protobuf::{
7 descriptor::{FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet},
8 reflect::FileDescriptor,
9 Message, MessageFull, UnknownValueRef,
10};
11
12use super::{
13 reflect::{FieldDescriptor, FieldType, ReflectMessageDescriptor},
14 Error,
15};
16
/// Registered message descriptors, keyed by the full name they were registered under.
type MessageDescriptorsMap = HashMap<String, Arc<ReflectMessageDescriptor>>;
/// Registered protobuf file descriptors, keyed by the name reported by the file descriptor.
type FileDescriptorsMap = HashMap<String, FileDescriptor>;
19
/// Registry of protobuf file descriptors and of the reflection descriptors
/// built from their messages.
///
/// Both maps sit behind their own `RwLock`, so registration and lookup
/// methods only need a shared `&self` reference.
pub struct Registry {
    /// Reflection descriptors of registered messages, looked up by full name.
    message_descriptors: RwLock<MessageDescriptorsMap>,
    /// File descriptors registered so far; they are also handed out as the
    /// dependency list when new dynamic file descriptors are built.
    file_descriptors: RwLock<FileDescriptorsMap>,
}
24
25impl Registry {
26 pub fn new() -> Registry {
27 Registry {
28 message_descriptors: RwLock::new(HashMap::new()),
29 file_descriptors: RwLock::new(HashMap::new()),
30 }
31 }
32
33 pub fn new_with_exocore_types() -> Registry {
34 let reg = Registry::new();
35
36 reg.register_well_knowns();
37
38 reg.register_file_descriptor_set_bytes(super::generated::STORE_FDSET)
39 .expect("Couldn't register exocore_store FileDescriptorProto");
40
41 reg.register_file_descriptor_set_bytes(super::generated::TEST_FDSET)
42 .expect("Couldn't register exocore_test FileDescriptorProto");
43
44 reg
45 }
46
47 pub fn register_well_knowns(&self) {
48 let fds = &[
49 protobuf::well_known_types::timestamp::Timestamp::descriptor(),
50 protobuf::well_known_types::any::Any::descriptor(),
51 FileDescriptorProto::descriptor(),
52 ];
53
54 for fd in fds {
55 self.register_file_descriptor(fd.file_descriptor_proto().clone());
56 }
57 }
58
59 pub fn register_file_descriptor_set(&self, fd_set: &FileDescriptorSet) {
60 let fds = protobuf::reflect::FileDescriptor::new_dynamic_fds(
61 fd_set.file.clone(),
62 self.dependencies().as_ref(),
63 )
64 .expect("FIX ME");
65
66 for fd in &fds {
67 {
68 let mut file_descriptors = self.file_descriptors.write().unwrap();
69 file_descriptors.insert(fd.name().to_string(), fd.clone());
70 }
71
72 for msg_descriptor in fd.messages() {
73 let full_name = format!("{}.{}", fd.package(), msg_descriptor.name(),);
74 self.register_message_descriptor(full_name, msg_descriptor);
75 }
76 }
77 }
78
79 fn dependencies(&self) -> Vec<FileDescriptor> {
80 let fds = self.file_descriptors.read().unwrap();
81 fds.values().cloned().collect()
82 }
83
84 pub fn register_file_descriptor_set_bytes<R: std::io::Read>(
85 &self,
86 fd_set_bytes: R,
87 ) -> Result<(), Error> {
88 let mut bytes = fd_set_bytes;
89 let fd_set = FileDescriptorSet::parse_from_reader(&mut bytes)?;
90
91 self.register_file_descriptor_set(&fd_set);
92
93 Ok(())
94 }
95
96 pub fn register_file_descriptor(&self, file_descriptor_proto: FileDescriptorProto) {
97 let fd = protobuf::reflect::FileDescriptor::new_dynamic(
98 file_descriptor_proto,
99 self.dependencies().as_ref(),
100 )
101 .expect("FIX ME");
102
103 {
104 let mut file_descriptors = self.file_descriptors.write().unwrap();
105 file_descriptors.insert(fd.name().to_string(), fd.clone());
106 }
107
108 for msg_descriptor in fd.messages() {
109 let full_name = format!("{}.{}", fd.package(), msg_descriptor.name(),);
110 self.register_message_descriptor(full_name, msg_descriptor);
111 }
112 }
113
114 pub fn register_message_descriptor(
115 &self,
116 full_name: String,
117 msg_descriptor: protobuf::reflect::MessageDescriptor,
118 ) -> Arc<ReflectMessageDescriptor> {
119 for sub_msg in msg_descriptor.nested_messages() {
120 let sub_full_name = format!("{}.{}", full_name, sub_msg.name());
121 self.register_message_descriptor(sub_full_name, sub_msg.clone());
122 }
123
124 let mut fields = HashMap::new();
125 for field in msg_descriptor.fields() {
126 let field_proto = field.proto();
127
128 use protobuf::descriptor::field_descriptor_proto::Type as ProtoFieldType;
129 let mut field_type = match field_proto.type_.map(|e| e.enum_value()) {
130 Some(Ok(ProtoFieldType::TYPE_STRING)) => FieldType::String,
131 Some(Ok(ProtoFieldType::TYPE_INT32)) => FieldType::Int32,
132 Some(Ok(ProtoFieldType::TYPE_UINT32)) => FieldType::Uint32,
133 Some(Ok(ProtoFieldType::TYPE_INT64)) => FieldType::Int64,
134 Some(Ok(ProtoFieldType::TYPE_UINT64)) => FieldType::Uint64,
135 Some(Ok(ProtoFieldType::TYPE_MESSAGE)) => {
136 let typ = field_proto.type_name().trim_start_matches('.');
137 match typ {
138 "google.protobuf.Timestamp" => FieldType::DateTime,
139 "exocore.store.Reference" => FieldType::Reference,
140 _ => FieldType::Message(typ.to_string()),
141 }
142 }
143
144 _ => continue,
145 };
146
147 if field_proto.label()
148 == protobuf::descriptor::field_descriptor_proto::Label::LABEL_REPEATED
149 {
150 field_type = FieldType::Repeated(Box::new(field_type));
151 }
152
153 if let Some(number) = field_proto.number {
154 let id = number as u32;
155 fields.insert(
156 id,
157 FieldDescriptor {
158 id,
159 descriptor: field.clone(),
160 name: field.name().to_string(),
161 field_type,
162
163 indexed_flag: Registry::field_has_option(field_proto, 1373),
165 sorted_flag: Registry::field_has_option(field_proto, 1374),
166 text_flag: Registry::field_has_option(field_proto, 1375),
167 groups: Registry::get_field_u32s_option(field_proto, 1376),
168 },
169 );
170 }
171 }
172
173 let short_names = Registry::get_message_strings_option(&msg_descriptor, 1377);
174 let descriptor = Arc::new(ReflectMessageDescriptor {
175 name: full_name.clone(),
176 fields,
177 message: msg_descriptor,
178
179 short_names,
181 });
182
183 let mut file_descriptors = self.file_descriptors.write().unwrap();
184 let fd = descriptor.message.file_descriptor();
185 file_descriptors.insert(fd.name().to_string(), fd.clone());
186
187 let mut message_descriptors = self.message_descriptors.write().unwrap();
188 message_descriptors.insert(full_name, descriptor.clone());
189
190 descriptor
191 }
192
193 pub fn get_message_descriptor(
194 &self,
195 full_name: &str,
196 ) -> Result<Arc<ReflectMessageDescriptor>, Error> {
197 let message_descriptors = self.message_descriptors.read().unwrap();
198 message_descriptors
199 .get(full_name)
200 .cloned()
201 .ok_or_else(|| Error::NotInRegistry(full_name.to_string()))
202 }
203
204 pub fn message_descriptors(&self) -> Vec<Arc<ReflectMessageDescriptor>> {
205 let message_descriptors = self.message_descriptors.read().unwrap();
206 message_descriptors.values().cloned().collect()
207 }
208
209 fn field_has_option(field: &FieldDescriptorProto, option_field_id: u32) -> bool {
210 if let Some(UnknownValueRef::Varint(v)) =
211 field.options.unknown_fields().get(option_field_id)
212 {
213 v == 1
214 } else {
215 false
216 }
217 }
218
219 fn get_field_u32s_option(field: &FieldDescriptorProto, option_field_id: u32) -> Vec<u32> {
220 let mut ret = Vec::new();
221 for (field_id, value) in field.options.unknown_fields().iter() {
222 if field_id != option_field_id {
225 continue;
226 }
227
228 match value {
229 UnknownValueRef::Varint(v) => {
230 ret.push(v as u32);
231 }
232 UnknownValueRef::LengthDelimited(values) => {
233 for value in values {
234 ret.push(*value as u32);
235 }
236 }
237 _ => (),
238 }
239 }
240 ret
241 }
242
243 fn get_message_strings_option(
244 msg_desc: &protobuf::reflect::MessageDescriptor,
245 option_field_id: u32,
246 ) -> Vec<String> {
247 let mut ret = Vec::new();
248 for (field_id, value) in msg_desc.proto().options.unknown_fields().iter() {
249 if let UnknownValueRef::LengthDelimited(bytes) = value {
250 if field_id == option_field_id {
253 ret.push(String::from_utf8_lossy(bytes).to_string());
254 }
255 }
256 }
257 ret
258 }
259}
260
261impl Default for Registry {
262 fn default() -> Self {
263 Registry::new()
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// The exocore registry exposes the store's `Entity` message and more
    /// than 20 message descriptors overall.
    #[test]
    fn with_exocore_types() {
        let registry = Registry::new_with_exocore_types();

        let entity_desc = registry
            .get_message_descriptor("exocore.store.Entity")
            .unwrap();
        assert_eq!(entity_desc.name, "exocore.store.Entity");
        assert!(!entity_desc.fields.is_empty());

        let all_descriptors = registry.message_descriptors();
        assert!(all_descriptors.len() > 20);
    }

    /// Message and field options of the test message are extracted into the
    /// reflection descriptor.
    #[test]
    fn field_and_msg_options() -> anyhow::Result<()> {
        let reg = Registry::new_with_exocore_types();
        let test_msg = reg.get_message_descriptor("exocore.test.TestMessage")?;

        // Accessor for a field by id; panics if the field isn't registered.
        let field = |id: u32| test_msg.fields.get(&id).unwrap();

        assert_eq!(test_msg.short_names, vec![String::from("test")]);

        assert!(field(1).text_flag);
        assert!(!field(2).text_flag);

        assert!(field(8).indexed_flag);
        assert!(!field(9).indexed_flag);

        assert!(field(18).sorted_flag);
        assert!(!field(11).sorted_flag);

        assert!(field(19).groups.is_empty());
        assert_eq!(field(20).groups, vec![1]);
        assert_eq!(field(21).groups, vec![1, 2]);

        Ok(())
    }
}