1use std::{collections::HashMap, convert::TryFrom, fmt::Debug, sync::Arc};
2
3pub use protobuf::{descriptor::FileDescriptorSet, Message};
4use protobuf::{
5 reflect::{FieldDescriptor as FieldDescriptorProto, ReflectFieldRef, ReflectValueRef},
6 well_known_types::any::Any,
7 MessageDyn,
8};
9
10use super::{registry::Registry, Error};
11use crate::generated::exocore_store::Reference;
12
13pub trait ReflectMessage: Debug + Sized {
14 fn descriptor(&self) -> &ReflectMessageDescriptor;
15
16 fn full_name(&self) -> &str {
17 &self.descriptor().name
18 }
19
20 fn fields(&self) -> &HashMap<FieldId, FieldDescriptor> {
21 &self.descriptor().fields
22 }
23
24 fn get_field(&self, id: FieldId) -> Option<&FieldDescriptor> {
25 self.descriptor().fields.get(&id)
26 }
27
28 fn get_field_value(&self, field_id: FieldId) -> Result<FieldValue, Error>;
29
30 fn encode(&self) -> Result<Vec<u8>, Error>;
31
32 fn encode_to_prost_any(&self) -> Result<prost_types::Any, Error> {
33 let bytes = self.encode()?;
34 Ok(prost_types::Any {
35 type_url: format!("type.googleapis.com/{}", self.descriptor().name),
36 value: bytes,
37 })
38 }
39
40 fn encode_json(&self, registry: &Registry) -> Result<serde_json::Value, Error> {
41 message_to_json(self, registry)
42 }
43}
44
45fn message_to_json<M: ReflectMessage>(
46 msg: &M,
47 registry: &Registry,
48) -> Result<serde_json::Value, Error> {
49 use serde_json::Value;
50
51 let mut values = serde_json::Map::<String, serde_json::Value>::new();
52 for (id, desc) in msg.fields() {
53 let name = desc.name.clone();
54 match msg.get_field_value(*id) {
55 Ok(value) => {
56 let json_value = field_value_to_json(value, registry)?;
57 values.insert(name, json_value);
58 }
59 Err(Error::NoSuchField(_)) => {
60 }
62 Err(other) => return Err(other),
63 }
64 }
65
66 let mut obj = serde_json::Map::new();
67 obj.insert(
68 "type".to_string(),
69 Value::String(msg.full_name().to_string()),
70 );
71 obj.insert("value".to_string(), Value::Object(values));
72
73 Ok(serde_json::Value::Object(obj))
74}
75
76fn field_value_to_json(value: FieldValue, registry: &Registry) -> Result<serde_json::Value, Error> {
77 use serde_json::Value;
78 Ok(match value {
79 FieldValue::String(v) => Value::String(v),
80 FieldValue::Int32(v) => Value::Number(v.into()),
81 FieldValue::Uint32(v) => Value::Number(v.into()),
82 FieldValue::Int64(v) => Value::Number(v.into()),
83 FieldValue::Uint64(v) => Value::Number(v.into()),
84 FieldValue::Reference(reference) => {
85 let mut obj = serde_json::Map::new();
86 obj.insert("entity_id".to_string(), Value::String(reference.entity_id));
87 obj.insert("trait_id".to_string(), Value::String(reference.trait_id));
88 Value::Object(obj)
89 }
90 FieldValue::DateTime(v) => Value::String(v.to_rfc3339()),
91 FieldValue::Message(typ, msg) => {
92 let msg = FieldValue::Message(typ, msg).into_message(registry)?;
93 msg.encode_json(registry)?
94 }
95 FieldValue::Repeated(values) => {
96 let arr = values
97 .into_iter()
98 .map(|v| field_value_to_json(v, registry))
99 .collect::<Result<Vec<_>, Error>>()?;
100 Value::Array(arr)
101 }
102 })
103}
104
/// Extension of [`ReflectMessage`] for messages whose fields can be modified in place.
pub trait MutableReflectMessage: ReflectMessage {
    /// Clears the value of the field with the given id.
    fn clear_field_value(&mut self, field_id: FieldId) -> Result<(), Error>;
}
108
/// A protobuf message accessed through reflection, paired with the descriptor used to look up
/// its fields.
pub struct DynamicMessage {
    // Boxed dynamic protobuf message that field values are read from and cleared on.
    message: Box<dyn MessageDyn>,
    // Shared descriptor (name and field table) describing `message`.
    descriptor: Arc<ReflectMessageDescriptor>,
}
113
114impl ReflectMessage for DynamicMessage {
115 fn descriptor(&self) -> &ReflectMessageDescriptor {
116 self.descriptor.as_ref()
117 }
118
119 fn get_field_value(&self, field_id: FieldId) -> Result<FieldValue, Error> {
120 let field = self
121 .get_field(field_id)
122 .ok_or(Error::NoSuchField(field_id))?;
123
124 let reflect_field = field.descriptor.get_reflect(self.message.as_ref());
125 convert_field_ref(field_id, &field.field_type, reflect_field)
126 }
127
128 fn encode(&self) -> Result<Vec<u8>, Error> {
129 let bytes = self.message.write_to_bytes_dyn()?;
130 Ok(bytes)
131 }
132}
133
134fn convert_field_ref(
135 field_id: FieldId,
136 field_type: &FieldType,
137 field_ref: ReflectFieldRef,
138) -> Result<FieldValue, Error> {
139 match field_ref {
140 ReflectFieldRef::Optional(v) => match v.value() {
141 Some(v) => convert_field_value(field_type, v),
142 None => Err(Error::NoSuchField(field_id)),
143 },
144 ReflectFieldRef::Repeated(r) => {
145 let FieldType::Repeated(inner_field_type) = field_type else {
146 return Err(Error::Other(anyhow!(
147 "expected repeated field type, got {field_type:?} at field {field_id:?}"
148 )));
149 };
150
151 let mut values = Vec::new();
152 for i in 0..r.len() {
153 values.push(convert_field_value(inner_field_type, r.get(i))?);
154 }
155 Ok(FieldValue::Repeated(values))
156 }
157 ReflectFieldRef::Map(_) => {
158 Err(Error::NoSuchField(field_id))
160 }
161 }
162}
163
164fn convert_field_value(
165 field_type: &FieldType,
166 value: ReflectValueRef,
167) -> Result<FieldValue, Error> {
168 match field_type {
169 FieldType::String => match value {
170 ReflectValueRef::String(v) => Ok(FieldValue::String(v.to_string())),
171 v => Err(Error::Other(anyhow!("expected string field, got: {v:?}"))),
172 },
173 FieldType::Int32 => match value {
174 ReflectValueRef::I32(v) => Ok(FieldValue::Int32(v)),
175 v => Err(Error::Other(anyhow!("expected int32 field, got: {v:?}"))),
176 },
177 FieldType::Uint32 => match value {
178 ReflectValueRef::U32(v) => Ok(FieldValue::Uint32(v)),
179 v => Err(Error::Other(anyhow!("expected uint32 field, got: {v:?}"))),
180 },
181 FieldType::Int64 => match value {
182 ReflectValueRef::I64(v) => Ok(FieldValue::Int64(v)),
183 v => Err(Error::Other(anyhow!("expected int64 field, got: {v:?}"))),
184 },
185 FieldType::Uint64 => match value {
186 ReflectValueRef::U64(v) => Ok(FieldValue::Uint64(v)),
187 v => Err(Error::Other(anyhow!("expected uint64 field, got: {v:?}"))),
188 },
189 FieldType::DateTime => match value {
190 ReflectValueRef::Message(msg) => {
191 let msg_desc = msg.descriptor_dyn();
192 let secs_desc = msg_desc.field_by_number(1).unwrap();
193 let secs = secs_desc
194 .get_singular(&*msg)
195 .and_then(|v| v.to_i64())
196 .unwrap_or_default();
197
198 let nanos_desc = msg_desc.field_by_number(2).unwrap();
199 let nanos = nanos_desc
200 .get_singular(&*msg)
201 .and_then(|v| v.to_i32())
202 .unwrap_or_default();
203
204 Ok(FieldValue::DateTime(
205 crate::time::timestamp_parts_to_datetime(secs, nanos),
206 ))
207 }
208 v => Err(Error::Other(anyhow!(
209 "expected message as timestamp field, got: {v:?}"
210 ))),
211 },
212 FieldType::Reference => match value {
213 ReflectValueRef::Message(msg) => {
214 let msg_desc = msg.descriptor_dyn();
215 let et_desc = msg_desc.field_by_number(1).unwrap();
216 let entity_id = et_desc
217 .get_singular(&*msg)
218 .and_then(|v| v.to_str().map(|v| v.to_string()))
219 .unwrap_or_default();
220
221 let trt_desc = msg_desc.field_by_number(2).unwrap();
222 let trait_id = trt_desc
223 .get_singular(&*msg)
224 .and_then(|v| v.to_str().map(|v| v.to_string()))
225 .unwrap_or_default();
226
227 Ok(FieldValue::Reference(Reference {
228 entity_id,
229 trait_id,
230 }))
231 }
232 v => Err(Error::Other(anyhow!(
233 "expected message as reference field, got: {v:?}"
234 ))),
235 },
236 FieldType::Message(msg_type) => match value {
237 ReflectValueRef::Message(msg) => {
238 let dyn_msg = msg.clone_box();
239 Ok(FieldValue::Message(msg_type.clone(), dyn_msg))
240 }
241 v => Err(Error::Other(anyhow!(
242 "expected field to be a message, got: {v:?}"
243 ))),
244 },
245 FieldType::Repeated(_) => {
246 unreachable!("repeated fields should have been handled in convert_field_ref");
247 }
248 }
249}
250
251impl MutableReflectMessage for DynamicMessage {
252 fn clear_field_value(&mut self, field_id: FieldId) -> Result<(), Error> {
253 let field = self
254 .descriptor
255 .fields
256 .get(&field_id)
257 .ok_or(Error::NoSuchField(field_id))?;
258
259 if !field.descriptor.has_field(self.message.as_ref()) {
260 return Ok(());
261 }
262
263 if field.descriptor.is_repeated() {
264 let mut repeated = field.descriptor.mut_repeated(self.message.as_mut());
265 repeated.clear();
266 } else if field.descriptor.is_map() {
267 let mut map = field.descriptor.mut_map(self.message.as_mut());
268 map.clear();
269 } else {
270 field.descriptor.clear_field(self.message.as_mut());
271 }
272
273 Ok(())
274 }
275}
276
277impl Debug for DynamicMessage {
278 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
279 f.debug_struct("DynamicMessage")
280 .field("full_name", &self.descriptor.name)
281 .finish()
282 }
283}
284
/// Identifier of a message field; used as the key of `ReflectMessageDescriptor::fields`.
pub type FieldId = u32;

/// Identifier of a field group (see `FieldDescriptor::groups`).
pub type FieldGroupId = u32;
288
/// Description of a message type: its name, its fields and the underlying protobuf descriptor.
pub struct ReflectMessageDescriptor {
    // `name` is what `ReflectMessage::full_name` returns and what type URLs are built from;
    // `fields` holds the descriptor of every field keyed by field id.
    pub name: String, pub fields: HashMap<FieldId, FieldDescriptor>,
    // Underlying protobuf message descriptor; used to parse message bytes.
    pub message: protobuf::reflect::MessageDescriptor,

    // NOTE(review): presumably shorter aliases for the message name — not used in this file,
    // confirm against the registry.
    pub short_names: Vec<String>,
}
297
/// Description of a single message field.
pub struct FieldDescriptor {
    /// Id of the field.
    pub id: FieldId,
    /// Underlying protobuf field descriptor, used to read and clear the field via reflection.
    pub descriptor: FieldDescriptorProto,
    /// Name of the field; used as the key when the message is encoded to JSON.
    pub name: String,
    /// Type used to convert the reflected value into a `FieldValue`.
    pub field_type: FieldType,

    // NOTE(review): the three flags below are presumably populated from field options by the
    // registry (indexing / sorting / full-text) — their exact meaning is not visible in this
    // file, confirm there.
    pub indexed_flag: bool,
    pub sorted_flag: bool,
    pub text_flag: bool,
    /// Ids of the groups this field is part of (semantics defined outside this file).
    pub groups: Vec<FieldGroupId>,
}
310
/// Type of a message field, as used when converting reflected values to `FieldValue`.
#[derive(Debug, Clone, PartialEq)]
pub enum FieldType {
    String,
    Int32,
    Uint32,
    Int64,
    Uint64,
    /// A message read as a timestamp (field 1: seconds, field 2: nanoseconds).
    DateTime,
    /// A message read as a reference (field 1: entity id, field 2: trait id).
    Reference,
    /// A nested message; the string is the type name used to resolve it in the registry.
    Message(String),
    /// A repeated field whose elements have the boxed type.
    Repeated(Box<FieldType>),
}
323
/// Value of a message field extracted through reflection.
#[derive(Debug)]
pub enum FieldValue {
    String(String),
    Int32(i32),
    Uint32(u32),
    Int64(i64),
    Uint64(u64),
    Reference(Reference),
    DateTime(chrono::DateTime<chrono::Utc>),
    /// A nested message: its type name and a boxed copy of the message.
    Message(String, Box<dyn MessageDyn>),
    /// Values of a repeated field, in order.
    Repeated(Vec<FieldValue>),
}
336
337impl FieldValue {
338 pub fn as_str(&self) -> Result<&str, Error> {
339 if let FieldValue::String(value) = self {
340 Ok(value.as_ref())
341 } else {
342 Err(Error::InvalidFieldType)
343 }
344 }
345
346 pub fn as_datetime(&self) -> Result<&chrono::DateTime<chrono::Utc>, Error> {
347 if let FieldValue::DateTime(value) = self {
348 Ok(value)
349 } else {
350 Err(Error::InvalidFieldType)
351 }
352 }
353
354 pub fn as_reference(&self) -> Result<&Reference, Error> {
355 if let FieldValue::Reference(value) = self {
356 Ok(value)
357 } else {
358 Err(Error::InvalidFieldType)
359 }
360 }
361
362 pub fn into_message(self, registry: &Registry) -> Result<DynamicMessage, Error> {
363 if let FieldValue::Message(typ, message) = self {
364 let descriptor = registry.get_message_descriptor(&typ)?;
365 Ok(DynamicMessage {
366 message,
367 descriptor,
368 })
369 } else {
370 Err(Error::InvalidFieldType)
371 }
372 }
373}
374
375impl<'s> TryFrom<&'s FieldValue> for &'s str {
376 type Error = Error;
377
378 fn try_from(value: &'s FieldValue) -> Result<Self, Error> {
379 match value {
380 FieldValue::String(value) => Ok(value),
381 _ => Err(Error::InvalidFieldType),
382 }
383 }
384}
385
386pub fn from_stepan_any(registry: &Registry, any: &Any) -> Result<DynamicMessage, Error> {
387 from_any_url_and_data(registry, &any.type_url, &any.value)
388}
389
390pub fn from_prost_any(
391 registry: &Registry,
392 any: &prost_types::Any,
393) -> Result<DynamicMessage, Error> {
394 from_any_url_and_data(registry, &any.type_url, &any.value)
395}
396
397pub fn from_any_url_and_data(
398 registry: &Registry,
399 url: &str,
400 data: &[u8],
401) -> Result<DynamicMessage, Error> {
402 let full_name = any_url_to_full_name(url);
403
404 let descriptor = registry.get_message_descriptor(&full_name)?;
405 let message = descriptor.message.parse_from_bytes(data)?;
406
407 Ok(DynamicMessage {
408 message,
409 descriptor,
410 })
411}
412
/// Extracts the fully-qualified message name from an `Any` type URL.
///
/// The name is the part after the last `/`, so `type.googleapis.com/pkg.Msg` and URLs with
/// any other host or path prefix both yield `pkg.Msg`. A string without `/` is returned
/// unchanged.
pub fn any_url_to_full_name(url: &str) -> String {
    url.rsplit_once('/')
        .map_or(url, |(_, full_name)| full_name)
        .to_string()
}
416
#[cfg(test)]
mod tests {
    use chrono::Utc;

    use super::*;
    use crate::{
        generated::exocore_test::TestMessage,
        prost::{ProstAnyPackMessageExt, ProstDateTimeExt},
        test::TestStruct,
    };

    /// Reads string, date, reference and nested-message fields back through reflection.
    #[test]
    fn reflect_dyn_message() -> anyhow::Result<()> {
        let registry = Registry::new_with_exocore_types();

        let mut map1 = HashMap::new();
        map1.insert("key1".to_string(), "value1".to_string());
        map1.insert("key2".to_string(), "value2".to_string());

        let now = Utc::now();
        let msg = TestMessage {
            string1: "val1".to_string(),
            date1: Some(now.to_proto_timestamp()),
            ref1: Some(Reference {
                entity_id: "et1".to_string(),
                trait_id: "trt1".to_string(),
            }),
            ref2: Some(Reference {
                entity_id: "et2".to_string(),
                trait_id: String::new(),
            }),
            struct1: Some(TestStruct {
                string1: "str1".to_string(),
            }),
            map1,
            ..Default::default()
        };

        let msg_any = msg.pack_to_stepan_any()?;
        let dyn_msg = from_stepan_any(&registry, &msg_any)?;

        assert_eq!("exocore.test.TestMessage", dyn_msg.full_name());
        assert!(dyn_msg.fields().len() > 10);

        let field1 = dyn_msg.get_field(1).unwrap();
        assert!(field1.text_flag);
        assert_eq!(dyn_msg.get_field_value(1)?.as_str()?, "val1");

        let field2 = dyn_msg.get_field(2).unwrap();
        assert!(!field2.text_flag);

        let field8 = dyn_msg.get_field(8).unwrap();
        assert_eq!(dyn_msg.get_field_value(8)?.as_datetime()?, &now);
        assert!(field8.indexed_flag);

        let field_value = dyn_msg.get_field_value(13)?;
        let value_ref = field_value.as_reference()?;
        assert_eq!(value_ref.entity_id, "et1");
        assert_eq!(value_ref.trait_id, "trt1");

        let field_value = dyn_msg.get_field_value(14)?;
        let value_ref = field_value.as_reference()?;
        assert_eq!(value_ref.entity_id, "et2");
        assert_eq!(value_ref.trait_id, "");

        let field3 = dyn_msg.get_field(3).unwrap();
        assert_eq!(
            field3.field_type,
            FieldType::Message("exocore.test.TestStruct".to_string())
        );
        let dyn_struct = dyn_msg.get_field_value(3)?.into_message(&registry)?;
        assert_eq!(dyn_struct.get_field_value(1)?.as_str()?, "str1");

        Ok(())
    }

    /// A cleared field can no longer be read.
    #[test]
    fn clear_value_dyn_message() -> anyhow::Result<()> {
        let registry = Registry::new_with_exocore_types();

        let msg = TestMessage {
            string1: "val1".to_string(),
            ..Default::default()
        };

        let msg_any = msg.pack_to_stepan_any()?;
        let mut dyn_msg = from_stepan_any(&registry, &msg_any)?;

        assert!(dyn_msg.get_field_value(1).is_ok());

        dyn_msg.clear_field_value(1).unwrap();

        assert!(dyn_msg.get_field_value(1).is_err());

        Ok(())
    }

    /// Re-encoding a decoded message yields the original bytes.
    #[test]
    fn dyn_message_encode() -> anyhow::Result<()> {
        let registry = Registry::new_with_exocore_types();

        let msg = TestMessage {
            string1: "val1".to_string(),
            ..Default::default()
        };

        let msg_any = msg.pack_to_stepan_any()?;
        let dyn_msg = from_stepan_any(&registry, &msg_any)?;

        let bytes = dyn_msg.encode()?;
        assert_eq!(bytes, msg_any.value);

        let prost_any = dyn_msg.encode_to_prost_any()?;
        assert_eq!(bytes, prost_any.value);

        Ok(())
    }

    /// JSON encoding wraps each message in a `type` / `value` envelope.
    #[test]
    fn dyn_message_encode_json() -> anyhow::Result<()> {
        let registry = Registry::new_with_exocore_types();

        let date = "2022-02-25T02:11:27.793936+00:00";
        let date = chrono::DateTime::parse_from_rfc3339(date)?;
        let msg = TestMessage {
            string1: "val1".to_string(),
            int1: 1,
            date1: Some(date.to_proto_timestamp()),
            ref1: Some(Reference {
                entity_id: "et1".to_string(),
                trait_id: "trt1".to_string(),
            }),
            struct1: Some(TestStruct {
                string1: "str1".to_string(),
            }),
            ..Default::default()
        };

        let msg_any = msg.pack_to_stepan_any()?;
        let dyn_msg = from_stepan_any(&registry, &msg_any)?;

        let value = dyn_msg.encode_json(&registry)?;
        let expected = serde_json::json!({
            "type": "exocore.test.TestMessage",
            "value": {
                "string1": "val1",
                "int1": 1,
                "date1": "2022-02-25T02:11:27.793936+00:00",
                "ref1": {
                    "entity_id": "et1",
                    "trait_id": "trt1"
                },
                "struct1": {
                    "type": "exocore.test.TestStruct",
                    "value": {
                        "string1": "str1"
                    },
                },
            }
        });

        assert_eq!(value, expected);

        Ok(())
    }
}