exocore_protos/
registry.rs

1use std::{
2    collections::HashMap,
3    sync::{Arc, RwLock},
4};
5
6use protobuf::{
7    descriptor::{FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet},
8    reflect::FileDescriptor,
9    Message, MessageFull, UnknownValueRef,
10};
11
12use super::{
13    reflect::{FieldDescriptor, FieldType, ReflectMessageDescriptor},
14    Error,
15};
16
17type MessageDescriptorsMap = HashMap<String, Arc<ReflectMessageDescriptor>>;
18type FileDescriptorsMap = HashMap<String, FileDescriptor>;
19
20pub struct Registry {
21    message_descriptors: RwLock<MessageDescriptorsMap>,
22    file_descriptors: RwLock<FileDescriptorsMap>,
23}
24
25impl Registry {
26    pub fn new() -> Registry {
27        Registry {
28            message_descriptors: RwLock::new(HashMap::new()),
29            file_descriptors: RwLock::new(HashMap::new()),
30        }
31    }
32
33    pub fn new_with_exocore_types() -> Registry {
34        let reg = Registry::new();
35
36        reg.register_well_knowns();
37
38        reg.register_file_descriptor_set_bytes(super::generated::STORE_FDSET)
39            .expect("Couldn't register exocore_store FileDescriptorProto");
40
41        reg.register_file_descriptor_set_bytes(super::generated::TEST_FDSET)
42            .expect("Couldn't register exocore_test FileDescriptorProto");
43
44        reg
45    }
46
47    pub fn register_well_knowns(&self) {
48        let fds = &[
49            protobuf::well_known_types::timestamp::Timestamp::descriptor(),
50            protobuf::well_known_types::any::Any::descriptor(),
51            FileDescriptorProto::descriptor(),
52        ];
53
54        for fd in fds {
55            self.register_file_descriptor(fd.file_descriptor_proto().clone());
56        }
57    }
58
59    pub fn register_file_descriptor_set(&self, fd_set: &FileDescriptorSet) {
60        let fds = protobuf::reflect::FileDescriptor::new_dynamic_fds(
61            fd_set.file.clone(),
62            self.dependencies().as_ref(),
63        )
64        .expect("FIX ME");
65
66        for fd in &fds {
67            {
68                let mut file_descriptors = self.file_descriptors.write().unwrap();
69                file_descriptors.insert(fd.name().to_string(), fd.clone());
70            }
71
72            for msg_descriptor in fd.messages() {
73                let full_name = format!("{}.{}", fd.package(), msg_descriptor.name(),);
74                self.register_message_descriptor(full_name, msg_descriptor);
75            }
76        }
77    }
78
79    fn dependencies(&self) -> Vec<FileDescriptor> {
80        let fds = self.file_descriptors.read().unwrap();
81        fds.values().cloned().collect()
82    }
83
84    pub fn register_file_descriptor_set_bytes<R: std::io::Read>(
85        &self,
86        fd_set_bytes: R,
87    ) -> Result<(), Error> {
88        let mut bytes = fd_set_bytes;
89        let fd_set = FileDescriptorSet::parse_from_reader(&mut bytes)?;
90
91        self.register_file_descriptor_set(&fd_set);
92
93        Ok(())
94    }
95
96    pub fn register_file_descriptor(&self, file_descriptor_proto: FileDescriptorProto) {
97        let fd = protobuf::reflect::FileDescriptor::new_dynamic(
98            file_descriptor_proto,
99            self.dependencies().as_ref(),
100        )
101        .expect("FIX ME");
102
103        {
104            let mut file_descriptors = self.file_descriptors.write().unwrap();
105            file_descriptors.insert(fd.name().to_string(), fd.clone());
106        }
107
108        for msg_descriptor in fd.messages() {
109            let full_name = format!("{}.{}", fd.package(), msg_descriptor.name(),);
110            self.register_message_descriptor(full_name, msg_descriptor);
111        }
112    }
113
114    pub fn register_message_descriptor(
115        &self,
116        full_name: String,
117        msg_descriptor: protobuf::reflect::MessageDescriptor,
118    ) -> Arc<ReflectMessageDescriptor> {
119        for sub_msg in msg_descriptor.nested_messages() {
120            let sub_full_name = format!("{}.{}", full_name, sub_msg.name());
121            self.register_message_descriptor(sub_full_name, sub_msg.clone());
122        }
123
124        let mut fields = HashMap::new();
125        for field in msg_descriptor.fields() {
126            let field_proto = field.proto();
127
128            use protobuf::descriptor::field_descriptor_proto::Type as ProtoFieldType;
129            let mut field_type = match field_proto.type_.map(|e| e.enum_value()) {
130                Some(Ok(ProtoFieldType::TYPE_STRING)) => FieldType::String,
131                Some(Ok(ProtoFieldType::TYPE_INT32)) => FieldType::Int32,
132                Some(Ok(ProtoFieldType::TYPE_UINT32)) => FieldType::Uint32,
133                Some(Ok(ProtoFieldType::TYPE_INT64)) => FieldType::Int64,
134                Some(Ok(ProtoFieldType::TYPE_UINT64)) => FieldType::Uint64,
135                Some(Ok(ProtoFieldType::TYPE_MESSAGE)) => {
136                    let typ = field_proto.type_name().trim_start_matches('.');
137                    match typ {
138                        "google.protobuf.Timestamp" => FieldType::DateTime,
139                        "exocore.store.Reference" => FieldType::Reference,
140                        _ => FieldType::Message(typ.to_string()),
141                    }
142                }
143
144                _ => continue,
145            };
146
147            if field_proto.label()
148                == protobuf::descriptor::field_descriptor_proto::Label::LABEL_REPEATED
149            {
150                field_type = FieldType::Repeated(Box::new(field_type));
151            }
152
153            if let Some(number) = field_proto.number {
154                let id = number as u32;
155                fields.insert(
156                    id,
157                    FieldDescriptor {
158                        id,
159                        descriptor: field.clone(),
160                        name: field.name().to_string(),
161                        field_type,
162
163                        // see exocore/store/options.proto
164                        indexed_flag: Registry::field_has_option(field_proto, 1373),
165                        sorted_flag: Registry::field_has_option(field_proto, 1374),
166                        text_flag: Registry::field_has_option(field_proto, 1375),
167                        groups: Registry::get_field_u32s_option(field_proto, 1376),
168                    },
169                );
170            }
171        }
172
173        let short_names = Registry::get_message_strings_option(&msg_descriptor, 1377);
174        let descriptor = Arc::new(ReflectMessageDescriptor {
175            name: full_name.clone(),
176            fields,
177            message: msg_descriptor,
178
179            // see exocore/store/options.proto
180            short_names,
181        });
182
183        let mut file_descriptors = self.file_descriptors.write().unwrap();
184        let fd = descriptor.message.file_descriptor();
185        file_descriptors.insert(fd.name().to_string(), fd.clone());
186
187        let mut message_descriptors = self.message_descriptors.write().unwrap();
188        message_descriptors.insert(full_name, descriptor.clone());
189
190        descriptor
191    }
192
193    pub fn get_message_descriptor(
194        &self,
195        full_name: &str,
196    ) -> Result<Arc<ReflectMessageDescriptor>, Error> {
197        let message_descriptors = self.message_descriptors.read().unwrap();
198        message_descriptors
199            .get(full_name)
200            .cloned()
201            .ok_or_else(|| Error::NotInRegistry(full_name.to_string()))
202    }
203
204    pub fn message_descriptors(&self) -> Vec<Arc<ReflectMessageDescriptor>> {
205        let message_descriptors = self.message_descriptors.read().unwrap();
206        message_descriptors.values().cloned().collect()
207    }
208
209    fn field_has_option(field: &FieldDescriptorProto, option_field_id: u32) -> bool {
210        if let Some(UnknownValueRef::Varint(v)) =
211            field.options.unknown_fields().get(option_field_id)
212        {
213            v == 1
214        } else {
215            false
216        }
217    }
218
219    fn get_field_u32s_option(field: &FieldDescriptorProto, option_field_id: u32) -> Vec<u32> {
220        let mut ret = Vec::new();
221        for (field_id, value) in field.options.unknown_fields().iter() {
222            // unfortunately, doesn't allow getting multiple values for one field other than
223            // iterating on all options
224            if field_id != option_field_id {
225                continue;
226            }
227
228            match value {
229                UnknownValueRef::Varint(v) => {
230                    ret.push(v as u32);
231                }
232                UnknownValueRef::LengthDelimited(values) => {
233                    for value in values {
234                        ret.push(*value as u32);
235                    }
236                }
237                _ => (),
238            }
239        }
240        ret
241    }
242
243    fn get_message_strings_option(
244        msg_desc: &protobuf::reflect::MessageDescriptor,
245        option_field_id: u32,
246    ) -> Vec<String> {
247        let mut ret = Vec::new();
248        for (field_id, value) in msg_desc.proto().options.unknown_fields().iter() {
249            if let UnknownValueRef::LengthDelimited(bytes) = value {
250                // unfortunately, doesn't allow getting multiple values for one field other than
251                // iterating on all options
252                if field_id == option_field_id {
253                    ret.push(String::from_utf8_lossy(bytes).to_string());
254                }
255            }
256        }
257        ret
258    }
259}
260
261impl Default for Registry {
262    fn default() -> Self {
263        Registry::new()
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn with_exocore_types() {
        let registry = Registry::new_with_exocore_types();

        let entity = registry
            .get_message_descriptor("exocore.store.Entity")
            .unwrap();
        assert_eq!(entity.name, "exocore.store.Entity");
        assert!(!entity.fields.is_empty());

        // Store + test + well-known messages should all be registered.
        assert!(registry.message_descriptors().len() > 20);
    }

    #[test]
    fn field_and_msg_options() -> anyhow::Result<()> {
        let registry = Registry::new_with_exocore_types();

        // Expected values mirror `protos/exocore/test/test.proto`.
        let msg = registry.get_message_descriptor("exocore.test.TestMessage")?;
        let fields = &msg.fields;

        assert_eq!(msg.short_names, vec!["test".to_string()]);

        // text option
        assert!(fields.get(&1).unwrap().text_flag);
        assert!(!fields.get(&2).unwrap().text_flag);

        // indexed option
        assert!(fields.get(&8).unwrap().indexed_flag);
        assert!(!fields.get(&9).unwrap().indexed_flag);

        // sorted option
        assert!(fields.get(&18).unwrap().sorted_flag);
        assert!(!fields.get(&11).unwrap().sorted_flag);

        // groups option
        assert!(fields.get(&19).unwrap().groups.is_empty());
        assert_eq!(fields.get(&20).unwrap().groups, vec![1]);
        assert_eq!(fields.get(&21).unwrap().groups, vec![1, 2]);

        Ok(())
    }
}