exocore_protos/
reflect.rs

1use std::{collections::HashMap, convert::TryFrom, fmt::Debug, sync::Arc};
2
3pub use protobuf::{descriptor::FileDescriptorSet, Message};
4use protobuf::{
5    reflect::{FieldDescriptor as FieldDescriptorProto, ReflectFieldRef, ReflectValueRef},
6    well_known_types::any::Any,
7    MessageDyn,
8};
9
10use super::{registry::Registry, Error};
11use crate::generated::exocore_store::Reference;
12
13pub trait ReflectMessage: Debug + Sized {
14    fn descriptor(&self) -> &ReflectMessageDescriptor;
15
16    fn full_name(&self) -> &str {
17        &self.descriptor().name
18    }
19
20    fn fields(&self) -> &HashMap<FieldId, FieldDescriptor> {
21        &self.descriptor().fields
22    }
23
24    fn get_field(&self, id: FieldId) -> Option<&FieldDescriptor> {
25        self.descriptor().fields.get(&id)
26    }
27
28    fn get_field_value(&self, field_id: FieldId) -> Result<FieldValue, Error>;
29
30    fn encode(&self) -> Result<Vec<u8>, Error>;
31
32    fn encode_to_prost_any(&self) -> Result<prost_types::Any, Error> {
33        let bytes = self.encode()?;
34        Ok(prost_types::Any {
35            type_url: format!("type.googleapis.com/{}", self.descriptor().name),
36            value: bytes,
37        })
38    }
39
40    fn encode_json(&self, registry: &Registry) -> Result<serde_json::Value, Error> {
41        message_to_json(self, registry)
42    }
43}
44
45fn message_to_json<M: ReflectMessage>(
46    msg: &M,
47    registry: &Registry,
48) -> Result<serde_json::Value, Error> {
49    use serde_json::Value;
50
51    let mut values = serde_json::Map::<String, serde_json::Value>::new();
52    for (id, desc) in msg.fields() {
53        let name = desc.name.clone();
54        match msg.get_field_value(*id) {
55            Ok(value) => {
56                let json_value = field_value_to_json(value, registry)?;
57                values.insert(name, json_value);
58            }
59            Err(Error::NoSuchField(_)) => {
60                // field is not set
61            }
62            Err(other) => return Err(other),
63        }
64    }
65
66    let mut obj = serde_json::Map::new();
67    obj.insert(
68        "type".to_string(),
69        Value::String(msg.full_name().to_string()),
70    );
71    obj.insert("value".to_string(), Value::Object(values));
72
73    Ok(serde_json::Value::Object(obj))
74}
75
76fn field_value_to_json(value: FieldValue, registry: &Registry) -> Result<serde_json::Value, Error> {
77    use serde_json::Value;
78    Ok(match value {
79        FieldValue::String(v) => Value::String(v),
80        FieldValue::Int32(v) => Value::Number(v.into()),
81        FieldValue::Uint32(v) => Value::Number(v.into()),
82        FieldValue::Int64(v) => Value::Number(v.into()),
83        FieldValue::Uint64(v) => Value::Number(v.into()),
84        FieldValue::Reference(reference) => {
85            let mut obj = serde_json::Map::new();
86            obj.insert("entity_id".to_string(), Value::String(reference.entity_id));
87            obj.insert("trait_id".to_string(), Value::String(reference.trait_id));
88            Value::Object(obj)
89        }
90        FieldValue::DateTime(v) => Value::String(v.to_rfc3339()),
91        FieldValue::Message(typ, msg) => {
92            let msg = FieldValue::Message(typ, msg).into_message(registry)?;
93            msg.encode_json(registry)?
94        }
95        FieldValue::Repeated(values) => {
96            let arr = values
97                .into_iter()
98                .map(|v| field_value_to_json(v, registry))
99                .collect::<Result<Vec<_>, Error>>()?;
100            Value::Array(arr)
101        }
102    })
103}
104
/// Extension of [`ReflectMessage`] for messages whose fields can be modified
/// in place.
pub trait MutableReflectMessage: ReflectMessage {
    /// Clears the value of field `field_id` on this message.
    fn clear_field_value(&mut self, field_id: FieldId) -> Result<(), Error>;
}
108
/// A protobuf message handled dynamically, paired with the reflection
/// descriptor used to read, clear and encode its fields.
pub struct DynamicMessage {
    // The underlying message, accessed through protobuf's dynamic API.
    message: Box<dyn MessageDyn>,
    // Shared descriptor giving the message's name and field metadata.
    descriptor: Arc<ReflectMessageDescriptor>,
}
113
114impl ReflectMessage for DynamicMessage {
115    fn descriptor(&self) -> &ReflectMessageDescriptor {
116        self.descriptor.as_ref()
117    }
118
119    fn get_field_value(&self, field_id: FieldId) -> Result<FieldValue, Error> {
120        let field = self
121            .get_field(field_id)
122            .ok_or(Error::NoSuchField(field_id))?;
123
124        let reflect_field = field.descriptor.get_reflect(self.message.as_ref());
125        convert_field_ref(field_id, &field.field_type, reflect_field)
126    }
127
128    fn encode(&self) -> Result<Vec<u8>, Error> {
129        let bytes = self.message.write_to_bytes_dyn()?;
130        Ok(bytes)
131    }
132}
133
134fn convert_field_ref(
135    field_id: FieldId,
136    field_type: &FieldType,
137    field_ref: ReflectFieldRef,
138) -> Result<FieldValue, Error> {
139    match field_ref {
140        ReflectFieldRef::Optional(v) => match v.value() {
141            Some(v) => convert_field_value(field_type, v),
142            None => Err(Error::NoSuchField(field_id)),
143        },
144        ReflectFieldRef::Repeated(r) => {
145            let FieldType::Repeated(inner_field_type) = field_type else {
146                return Err(Error::Other(anyhow!(
147                    "expected repeated field type, got {field_type:?} at field {field_id:?}"
148                )));
149            };
150
151            let mut values = Vec::new();
152            for i in 0..r.len() {
153                values.push(convert_field_value(inner_field_type, r.get(i))?);
154            }
155            Ok(FieldValue::Repeated(values))
156        }
157        ReflectFieldRef::Map(_) => {
158            // TODO: Implement me
159            Err(Error::NoSuchField(field_id))
160        }
161    }
162}
163
164fn convert_field_value(
165    field_type: &FieldType,
166    value: ReflectValueRef,
167) -> Result<FieldValue, Error> {
168    match field_type {
169        FieldType::String => match value {
170            ReflectValueRef::String(v) => Ok(FieldValue::String(v.to_string())),
171            v => Err(Error::Other(anyhow!("expected string field, got: {v:?}"))),
172        },
173        FieldType::Int32 => match value {
174            ReflectValueRef::I32(v) => Ok(FieldValue::Int32(v)),
175            v => Err(Error::Other(anyhow!("expected int32 field, got: {v:?}"))),
176        },
177        FieldType::Uint32 => match value {
178            ReflectValueRef::U32(v) => Ok(FieldValue::Uint32(v)),
179            v => Err(Error::Other(anyhow!("expected uint32 field, got: {v:?}"))),
180        },
181        FieldType::Int64 => match value {
182            ReflectValueRef::I64(v) => Ok(FieldValue::Int64(v)),
183            v => Err(Error::Other(anyhow!("expected int64 field, got: {v:?}"))),
184        },
185        FieldType::Uint64 => match value {
186            ReflectValueRef::U64(v) => Ok(FieldValue::Uint64(v)),
187            v => Err(Error::Other(anyhow!("expected uint64 field, got: {v:?}"))),
188        },
189        FieldType::DateTime => match value {
190            ReflectValueRef::Message(msg) => {
191                let msg_desc = msg.descriptor_dyn();
192                let secs_desc = msg_desc.field_by_number(1).unwrap();
193                let secs = secs_desc
194                    .get_singular(&*msg)
195                    .and_then(|v| v.to_i64())
196                    .unwrap_or_default();
197
198                let nanos_desc = msg_desc.field_by_number(2).unwrap();
199                let nanos = nanos_desc
200                    .get_singular(&*msg)
201                    .and_then(|v| v.to_i32())
202                    .unwrap_or_default();
203
204                Ok(FieldValue::DateTime(
205                    crate::time::timestamp_parts_to_datetime(secs, nanos),
206                ))
207            }
208            v => Err(Error::Other(anyhow!(
209                "expected message as timestamp field, got: {v:?}"
210            ))),
211        },
212        FieldType::Reference => match value {
213            ReflectValueRef::Message(msg) => {
214                let msg_desc = msg.descriptor_dyn();
215                let et_desc = msg_desc.field_by_number(1).unwrap();
216                let entity_id = et_desc
217                    .get_singular(&*msg)
218                    .and_then(|v| v.to_str().map(|v| v.to_string()))
219                    .unwrap_or_default();
220
221                let trt_desc = msg_desc.field_by_number(2).unwrap();
222                let trait_id = trt_desc
223                    .get_singular(&*msg)
224                    .and_then(|v| v.to_str().map(|v| v.to_string()))
225                    .unwrap_or_default();
226
227                Ok(FieldValue::Reference(Reference {
228                    entity_id,
229                    trait_id,
230                }))
231            }
232            v => Err(Error::Other(anyhow!(
233                "expected message as reference field, got: {v:?}"
234            ))),
235        },
236        FieldType::Message(msg_type) => match value {
237            ReflectValueRef::Message(msg) => {
238                let dyn_msg = msg.clone_box();
239                Ok(FieldValue::Message(msg_type.clone(), dyn_msg))
240            }
241            v => Err(Error::Other(anyhow!(
242                "expected field to be a message, got: {v:?}"
243            ))),
244        },
245        FieldType::Repeated(_) => {
246            unreachable!("repeated fields should have been handled in convert_field_ref");
247        }
248    }
249}
250
251impl MutableReflectMessage for DynamicMessage {
252    fn clear_field_value(&mut self, field_id: FieldId) -> Result<(), Error> {
253        let field = self
254            .descriptor
255            .fields
256            .get(&field_id)
257            .ok_or(Error::NoSuchField(field_id))?;
258
259        if !field.descriptor.has_field(self.message.as_ref()) {
260            return Ok(());
261        }
262
263        if field.descriptor.is_repeated() {
264            let mut repeated = field.descriptor.mut_repeated(self.message.as_mut());
265            repeated.clear();
266        } else if field.descriptor.is_map() {
267            let mut map = field.descriptor.mut_map(self.message.as_mut());
268            map.clear();
269        } else {
270            field.descriptor.clear_field(self.message.as_mut());
271        }
272
273        Ok(())
274    }
275}
276
277impl Debug for DynamicMessage {
278    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
279        f.debug_struct("DynamicMessage")
280            .field("full_name", &self.descriptor.name)
281            .finish()
282    }
283}
284
/// Identifier of a field within a message; key of
/// `ReflectMessageDescriptor::fields`.
// NOTE(review): presumably the protobuf field number — confirm against the registry.
pub type FieldId = u32;

/// Identifier of a field group, as listed in `FieldDescriptor::groups`.
pub type FieldGroupId = u32;
288
/// Describes a message type: its name, its fields and the underlying protobuf
/// descriptor used to parse message bytes.
pub struct ReflectMessageDescriptor {
    pub name: String, // full name of the message
    // Field descriptors keyed by field id.
    pub fields: HashMap<FieldId, FieldDescriptor>,
    // Underlying protobuf reflection descriptor of the message.
    pub message: protobuf::reflect::MessageDescriptor,

    // see exocore/store/options.proto
    pub short_names: Vec<String>,
}
297
/// Describes one field of a message: its id, name, value type and the
/// protobuf reflection descriptor used to read or clear its value.
pub struct FieldDescriptor {
    pub id: FieldId,
    // Protobuf reflection descriptor used to access the field on a message.
    pub descriptor: FieldDescriptorProto,
    // Field name; used as the key when the message is rendered as JSON.
    pub name: String,
    // Type that values of this field are converted to.
    pub field_type: FieldType,

    // see exocore/store/options.proto
    pub indexed_flag: bool,
    pub sorted_flag: bool,
    pub text_flag: bool,
    pub groups: Vec<FieldGroupId>,
}
310
/// Type of a message field, as understood by the reflection layer.
#[derive(Debug, Clone, PartialEq)]
pub enum FieldType {
    String,
    Int32,
    Uint32,
    Int64,
    Uint64,
    // Timestamp message, converted to a `chrono` date-time when read.
    DateTime,
    // Reference message carrying an entity id and a trait id.
    Reference,
    // Nested message; the string is the name of the message type.
    Message(String),
    // Repeated field; the boxed type is the type of each element.
    Repeated(Box<FieldType>),
}
323
/// Value read from a message field; each variant mirrors a [`FieldType`].
#[derive(Debug)]
pub enum FieldValue {
    String(String),
    Int32(i32),
    Uint32(u32),
    Int64(i64),
    Uint64(u64),
    Reference(Reference),
    DateTime(chrono::DateTime<chrono::Utc>),
    // Nested message: the name of its type and the message itself.
    Message(String, Box<dyn MessageDyn>),
    Repeated(Vec<FieldValue>),
}
336
337impl FieldValue {
338    pub fn as_str(&self) -> Result<&str, Error> {
339        if let FieldValue::String(value) = self {
340            Ok(value.as_ref())
341        } else {
342            Err(Error::InvalidFieldType)
343        }
344    }
345
346    pub fn as_datetime(&self) -> Result<&chrono::DateTime<chrono::Utc>, Error> {
347        if let FieldValue::DateTime(value) = self {
348            Ok(value)
349        } else {
350            Err(Error::InvalidFieldType)
351        }
352    }
353
354    pub fn as_reference(&self) -> Result<&Reference, Error> {
355        if let FieldValue::Reference(value) = self {
356            Ok(value)
357        } else {
358            Err(Error::InvalidFieldType)
359        }
360    }
361
362    pub fn into_message(self, registry: &Registry) -> Result<DynamicMessage, Error> {
363        if let FieldValue::Message(typ, message) = self {
364            let descriptor = registry.get_message_descriptor(&typ)?;
365            Ok(DynamicMessage {
366                message,
367                descriptor,
368            })
369        } else {
370            Err(Error::InvalidFieldType)
371        }
372    }
373}
374
375impl<'s> TryFrom<&'s FieldValue> for &'s str {
376    type Error = Error;
377
378    fn try_from(value: &'s FieldValue) -> Result<Self, Error> {
379        match value {
380            FieldValue::String(value) => Ok(value),
381            _ => Err(Error::InvalidFieldType),
382        }
383    }
384}
385
386pub fn from_stepan_any(registry: &Registry, any: &Any) -> Result<DynamicMessage, Error> {
387    from_any_url_and_data(registry, &any.type_url, &any.value)
388}
389
390pub fn from_prost_any(
391    registry: &Registry,
392    any: &prost_types::Any,
393) -> Result<DynamicMessage, Error> {
394    from_any_url_and_data(registry, &any.type_url, &any.value)
395}
396
397pub fn from_any_url_and_data(
398    registry: &Registry,
399    url: &str,
400    data: &[u8],
401) -> Result<DynamicMessage, Error> {
402    let full_name = any_url_to_full_name(url);
403
404    let descriptor = registry.get_message_descriptor(&full_name)?;
405    let message = descriptor.message.parse_from_bytes(data)?;
406
407    Ok(DynamicMessage {
408        message,
409        descriptor,
410    })
411}
412
/// Extracts the fully qualified message name from an `Any` type url.
///
/// The name is the last path segment of the url, i.e. everything after the
/// final `/` (`type.googleapis.com/pkg.Message` gives `pkg.Message`). Any
/// host prefix is accepted, not only `type.googleapis.com`. A `url` without
/// a `/` is returned unchanged.
pub fn any_url_to_full_name(url: &str) -> String {
    url.rsplit_once('/')
        .map_or(url, |(_, full_name)| full_name)
        .to_string()
}
416
// Tests for `DynamicMessage`: reading fields, clearing fields, binary encoding
// and JSON encoding, all driven by the `TestMessage` test type.
#[cfg(test)]
mod tests {
    use chrono::Utc;

    use super::*;
    use crate::{
        generated::exocore_test::TestMessage,
        prost::{ProstAnyPackMessageExt, ProstDateTimeExt},
        test::TestStruct,
    };

    // Packs a `TestMessage` into an `Any`, decodes it as a `DynamicMessage`
    // and checks field descriptors and values of several field kinds.
    #[test]
    fn reflect_dyn_message() -> anyhow::Result<()> {
        let registry = Registry::new_with_exocore_types();

        let mut map1 = HashMap::new();
        map1.insert("key1".to_string(), "value1".to_string());
        map1.insert("key2".to_string(), "value2".to_string());

        let now = Utc::now();
        let msg = TestMessage {
            string1: "val1".to_string(),
            date1: Some(now.to_proto_timestamp()),
            ref1: Some(Reference {
                entity_id: "et1".to_string(),
                trait_id: "trt1".to_string(),
            }),
            ref2: Some(Reference {
                entity_id: "et2".to_string(),
                trait_id: String::new(),
            }),
            struct1: Some(TestStruct {
                string1: "str1".to_string(),
            }),
            map1,
            ..Default::default()
        };

        let msg_any = msg.pack_to_stepan_any()?;
        let dyn_msg = from_stepan_any(&registry, &msg_any)?;

        assert_eq!("exocore.test.TestMessage", dyn_msg.full_name());
        assert!(dyn_msg.fields().len() > 10);

        // String field.
        let field1 = dyn_msg.get_field(1).unwrap();
        assert!(field1.text_flag);
        assert_eq!(dyn_msg.get_field_value(1)?.as_str()?, "val1");

        let field2 = dyn_msg.get_field(2).unwrap();
        assert!(!field2.text_flag);

        // Date-time field.
        let field8 = dyn_msg.get_field(8).unwrap();
        assert_eq!(dyn_msg.get_field_value(8)?.as_datetime()?, &now);
        assert!(field8.indexed_flag);

        // Reference fields; the second one has an empty trait id.
        let field_value = dyn_msg.get_field_value(13)?;
        let value_ref = field_value.as_reference()?;
        assert_eq!(value_ref.entity_id, "et1");
        assert_eq!(value_ref.trait_id, "trt1");

        let field_value = dyn_msg.get_field_value(14)?;
        let value_ref = field_value.as_reference()?;
        assert_eq!(value_ref.entity_id, "et2");
        assert_eq!(value_ref.trait_id, "");

        // Nested message field, converted into its own `DynamicMessage`.
        let field3 = dyn_msg.get_field(3).unwrap();
        assert_eq!(
            field3.field_type,
            FieldType::Message("exocore.test.TestStruct".to_string())
        );
        let dyn_struct = dyn_msg.get_field_value(3)?.into_message(&registry)?;
        assert_eq!(dyn_struct.get_field_value(1)?.as_str()?, "str1");

        // TODO: Maps not supported yet
        // let field22 = dyn_msg.get_field(22).unwrap();
        // assert_eq!(
        //     field22.field_type,
        //     FieldType::Repeated(Box::new(FieldType::Message(
        //         "exocore.test.TestMessage.Map1Entry".to_string()
        //     )))
        // );
        // let _field_value = dyn_msg.get_field_value(22)?;

        Ok(())
    }

    // Clearing a set field makes a later read of that field fail.
    #[test]
    fn clear_value_dyn_message() -> anyhow::Result<()> {
        let registry = Registry::new_with_exocore_types();

        let msg = TestMessage {
            string1: "val1".to_string(),
            ..Default::default()
        };

        let msg_any = msg.pack_to_stepan_any()?;
        let mut dyn_msg = from_stepan_any(&registry, &msg_any)?;

        assert!(dyn_msg.get_field_value(1).is_ok());

        dyn_msg.clear_field_value(1).unwrap();

        assert!(dyn_msg.get_field_value(1).is_err());

        Ok(())
    }

    // Re-encoding a decoded message yields the original bytes, both directly
    // and through a `prost_types::Any`.
    #[test]
    fn dyn_message_encode() -> anyhow::Result<()> {
        let registry = Registry::new_with_exocore_types();

        let msg = TestMessage {
            string1: "val1".to_string(),
            ..Default::default()
        };

        let msg_any = msg.pack_to_stepan_any()?;
        let dyn_msg = from_stepan_any(&registry, &msg_any)?;

        let bytes = dyn_msg.encode()?;
        assert_eq!(bytes, msg_any.value);

        let prost_any = dyn_msg.encode_to_prost_any()?;
        assert_eq!(bytes, prost_any.value);

        Ok(())
    }

    // JSON encoding contains only the fields that are set, with nested
    // messages rendered as their own `type`/`value` objects.
    #[test]
    fn dyn_message_encode_json() -> anyhow::Result<()> {
        let registry = Registry::new_with_exocore_types();

        let date = "2022-02-25T02:11:27.793936+00:00";
        let date = chrono::DateTime::parse_from_rfc3339(date)?;
        let msg = TestMessage {
            string1: "val1".to_string(),
            int1: 1,
            date1: Some(date.to_proto_timestamp()),
            ref1: Some(Reference {
                entity_id: "et1".to_string(),
                trait_id: "trt1".to_string(),
            }),
            struct1: Some(TestStruct {
                string1: "str1".to_string(),
            }),
            ..Default::default()
        };

        let msg_any = msg.pack_to_stepan_any()?;
        let dyn_msg = from_stepan_any(&registry, &msg_any)?;

        let value = dyn_msg.encode_json(&registry)?;
        let expected = serde_json::json!({
            "type": "exocore.test.TestMessage",
            "value": {
                "string1": "val1",
                "int1": 1,
                "date1": "2022-02-25T02:11:27.793936+00:00",
                "ref1": {
                    "entity_id": "et1",
                    "trait_id": "trt1"
                },
                "struct1": {
                    "type": "exocore.test.TestStruct",
                    "value": {
                        "string1": "str1"
                    },
                },
            }
        });

        assert_eq!(value, expected);

        Ok(())
    }
}