1use crate::{DescriptorDatabase, Error, Tranche};
2use crate::{FieldType, MessageView};
3use bytes::{Buf, BufMut, Bytes};
4use looking_glass::{CowValue, Instance, IntoInner, OwnedValue, SmolStr, StructInstance, Typed, ValueTy};
5use prost::encoding::{encode_key, encode_varint, encoded_len_varint, key_len, WireType};
6use std::{
7 any::TypeId,
8 collections::{BTreeMap, HashMap},
9 convert::TryFrom,
10 sync::Arc,
11};
12
13#[derive(Debug, PartialEq, Clone)]
18pub struct DynamicMessage {
19 values: BTreeMap<u32, Field<OwnedValue<'static>>>,
20 descriptor_name: String,
21 descriptor_database: Arc<DescriptorDatabase>,
22}
23
24impl DynamicMessage {
25 pub fn new<T: Tranche>(view: &MessageView<T>) -> Result<DynamicMessage, Error> {
27 let descriptor_database = view.descriptor_database.clone();
28 let descriptor_name = view.descriptor_name.clone();
29 let descriptor = descriptor_database
30 .descriptor(&descriptor_name)
31 .ok_or_else(|| looking_glass::Error::NotFound("descriptor".into()))?;
32 let values = descriptor
33 .fields
34 .iter()
35 .map(|(tag, field)| match view.view_tag(*tag) {
36 Ok(view) => Ok((
37 *tag,
38 Field {
39 field_type: field.field_type.clone(),
40 attr: OwnedValue::try_from(view)?,
41 },
42 )),
43 Err(e) => Err(e),
44 })
45 .collect::<Result<BTreeMap<_, _>, Error>>()?;
46 Ok(DynamicMessage {
47 values,
48 descriptor_name,
49 descriptor_database,
50 })
51 }
52
53 pub fn encoded_len(&self) -> usize {
55 self.values
56 .iter()
57 .map(|(tag, field)| field.encoded_len(*tag))
58 .sum()
59 }
60
61 pub fn encode(&self, buf: &mut impl BufMut) {
63 for (tag, field) in &self.values {
64 field.encode(*tag, buf)
65 }
66 }
67
68 pub fn descriptor_name(&self) -> String {
69 self.descriptor_name.clone()
70 }
71
72 pub fn descriptor_database(&self) -> Arc<DescriptorDatabase> {
73 self.descriptor_database.clone()
74 }
75}
76
77impl Instance<'static> for DynamicMessage {
78 fn name(&self) -> SmolStr {
79 SmolStr::new(&self.descriptor_name)
80 }
81
82 fn as_inst(&self) -> &(dyn Instance<'static> + 'static) {
83 self
84 }
85}
86
87impl StructInstance<'static> for DynamicMessage {
88 fn get_value<'a>(&'a self, field: &str) -> Option<CowValue<'a, 'static>>
89 where
90 'static: 'a,
91 {
92 let descriptor = self.descriptor_database.descriptor(&self.descriptor_name)?;
93 let tag = descriptor.tags_by_name.get(field)?;
94 Some(CowValue::from(self.values.get(tag)?.attr.as_ref()))
95 }
96
97 fn update<'a>(
98 &'a mut self,
99 update: &'a (dyn StructInstance<'static> + 'static),
100 field_mask: Option<&looking_glass::FieldMask>,
101 replace_repeated: bool,
102 ) -> Result<(), looking_glass::Error> {
103 let descriptor = self
104 .descriptor_database
105 .descriptor(&self.descriptor_name)
106 .ok_or_else(|| looking_glass::Error::NotFound("descriptor".into()))?;
107 for (key, update_value) in update.values() {
108 let tag = descriptor
109 .tags_by_name
110 .get(&key)
111 .ok_or_else(|| looking_glass::Error::NotFound("tag".into()))?;
112 let new_mask = field_mask.and_then(|m| m.child(&key));
113 if new_mask.is_some() || field_mask.is_none() {
114 let field = self.values.get_mut(tag);
115 let attr = field.map(|f| &mut f.attr);
116 match attr {
117 Some(OwnedValue::Struct(inst)) => {
118 if let Some(update_inst) = update_value.as_ref().as_reflected_struct() {
119 inst.update(update_inst, new_mask, replace_repeated)?;
120 }
121 }
122 Some(OwnedValue::Vec(ref mut v)) => {
123 if let Some(update_vec) = update_value.as_ref().as_reflected_vec() {
124 v.update(update_vec, replace_repeated)?;
125 }
126 }
127 _ => {
128 let field = descriptor
129 .fields
130 .get(tag)
131 .ok_or_else(|| looking_glass::Error::NotFound("descriptor".into()))?;
132 let update_value = if let Some(view) =
133 update_value.as_ref().borrow::<&MessageView<Bytes>>()
134 {
135 OwnedValue::Option(Box::new(Some(DynamicMessage::new(view).map_err(
136 |_| looking_glass::Error::TypeError {
137 expected: "valid message view".into(),
138 found: "invalid message view".into(),
139 },
140 )?)))
141 } else {
142 update_value.to_owned()
143 };
144 let field = Field {
145 field_type: field.field_type.clone(),
146 attr: update_value,
147 };
148 self.values.insert(*tag, field);
149 }
150 }
151 }
152 }
153 Ok(())
154 }
155
156 fn values<'a>(&'a self) -> HashMap<SmolStr, CowValue<'a, 'static>> {
157 if let Some(descriptor) = self.descriptor_database.descriptor(&self.descriptor_name) {
158 self.values
159 .iter()
160 .filter_map(|(t, v)| {
161 let name = descriptor.fields.get(t)?.name.clone();
162 Some((name, CowValue::Ref(v.attr.as_ref())))
163 })
164 .collect()
165 } else {
166 HashMap::new()
167 }
168 }
169
170 fn boxed_clone(&self) -> Box<dyn StructInstance<'static> + 'static> {
171 Box::new(self.clone())
172 }
173
174 fn into_boxed_instance(self: Box<Self>) -> Box<dyn Instance<'static> + 'static> {
175 self
176 }
177}
178
179impl Typed<'static> for DynamicMessage {
180 fn ty() -> looking_glass::ValueTy {
181 ValueTy::Struct(TypeId::of::<Self>())
182 }
183
184 fn as_value<'a>(&'a self) -> looking_glass::Value<'a, 'static>
185 where
186 'static: 'a,
187 {
188 looking_glass::Value::from_struct(self)
189 }
190}
191
192fn is_default_owned_value(value: &OwnedValue<'static>) -> bool {
193 match value {
194 OwnedValue::U64(u) => *u == 0,
195 OwnedValue::U32(u) => *u == 0,
196 OwnedValue::U16(u) => *u == 0,
197 OwnedValue::U8(u) => *u == 0,
198 OwnedValue::I64(u) => *u == 0,
199 OwnedValue::I32(u) => *u == 0,
200 OwnedValue::I16(u) => *u == 0,
201 OwnedValue::I8(u) => *u == 0,
202 OwnedValue::Bool(b) => !b,
203 OwnedValue::String(s) => s.is_empty(),
204 OwnedValue::Vec(v) => v.is_empty(),
205 OwnedValue::Bytes(b) => b.is_empty(),
206 OwnedValue::Struct(m) => {
207 if let Some(msg) = m.as_inst().downcast_ref::<DynamicMessage>() {
208 for field in msg.values.values() {
209 if !field.is_default() {
210 return false;
211 }
212 }
213 true
214 } else {
215 false
216 }
217 }
218 OwnedValue::Option(option) => {
219 if let Some(msg) = option.as_inst().downcast_ref::<Option<DynamicMessage>>() {
220 if let Some(msg) = msg {
221 for field in msg.values.values() {
222 if !field.is_default() {
223 return false;
224 }
225 }
226 true
227 } else {
228 true
229 }
230 } else {
231 false
232 }
233 }
234 _ => false,
235 }
236}
237
238#[derive(Clone, PartialEq, Debug)]
239pub struct Field<A> {
240 field_type: FieldType,
241 attr: A,
242}
243
244impl Field<OwnedValue<'static>> {
245 pub fn is_default(&self) -> bool {
247 is_default_owned_value(&self.attr)
248 }
249
250 pub fn encode_key<B: BufMut>(&self, tag: u32, buf: &mut B) {
252 if let OwnedValue::Vec(_) = self.attr {
253 return;
254 }
255 match self.field_type {
256 FieldType::Bool
257 | FieldType::Int32
258 | FieldType::Int64
259 | FieldType::UInt32
260 | FieldType::UInt64
261 | FieldType::Enum(_) => encode_key(tag, WireType::Varint, buf),
262 FieldType::Double | FieldType::Fixed64 | FieldType::SInt64 | FieldType::SFixed64 => {
263 encode_key(tag, WireType::SixtyFourBit, buf)
264 }
265 FieldType::Float | FieldType::SInt32 | FieldType::Fixed32 | FieldType::SFixed32 => {
266 encode_key(tag, WireType::ThirtyTwoBit, buf)
267 }
268 FieldType::String | FieldType::Message(_) | FieldType::Bytes => {
269 encode_key(tag, WireType::LengthDelimited, buf)
270 }
271 FieldType::Group => {}
272 };
273 }
274
275 pub fn encode_raw<B: BufMut>(&self, buf: &mut B) {
277 match (&self.field_type, &self.attr) {
278 (FieldType::Bool, OwnedValue::Bool(b)) => {
279 encode_varint(if *b { 1u64 } else { 0u64 }, buf)
280 }
281 (FieldType::Int32, OwnedValue::I32(i)) => encode_varint(*i as u64, buf),
282 (FieldType::Int64, OwnedValue::I64(i)) => encode_varint(*i as u64, buf),
283 (FieldType::UInt32, OwnedValue::U32(i)) => encode_varint(*i as u64, buf),
284 (FieldType::UInt64, OwnedValue::U64(i)) => encode_varint(*i as u64, buf),
285 (FieldType::SInt32, OwnedValue::I32(value)) => {
286 encode_varint(((value << 1) ^ (value >> 31)) as u32 as u64, buf)
287 }
288 (FieldType::SInt64, OwnedValue::I64(value)) => {
289 encode_varint(((value << 1) ^ (value >> 63)) as u64, buf)
290 }
291 (FieldType::Fixed64, OwnedValue::U64(i)) => buf.put_u64_le(*i),
292 (FieldType::Fixed32, OwnedValue::U32(i)) => buf.put_u32_le(*i),
293 (FieldType::String, OwnedValue::String(s)) => {
294 let bytes: &[u8] = s.as_ref();
295 encode_varint(bytes.len() as u64, buf);
296 buf.put_slice(bytes)
297 }
298 (FieldType::Float, OwnedValue::F32(f)) => buf.put_f32_le(*f),
299 (FieldType::Double, OwnedValue::F64(f)) => buf.put_f64_le(*f),
300 (FieldType::Enum(_), OwnedValue::I32(i)) => encode_varint(*i as u64, buf),
301 (FieldType::Message(_), OwnedValue::Option(m)) => {
302 if let Some(v) = m.value() {
303 if let Some(msg) = v.borrow::<&DynamicMessage>() {
304 let n = msg.encoded_len();
305 encode_varint(n as u64, buf);
306 msg.encode(buf);
307 }
308 }
309 }
314 (FieldType::Message(_), OwnedValue::Struct(m)) => {
315 if let Some(msg) = m.as_value().borrow::<&DynamicMessage>() {
316 let n = msg.encoded_len();
317 encode_varint(n as u64, buf);
318 msg.encode(buf);
319 }
320 }
321 (FieldType::Bytes, OwnedValue::Bytes(b)) => {
322 encode_varint(b.remaining() as u64, buf);
323 buf.put_slice(b.chunk());
324 }
325 _ => {
326 }
328 }
329 }
330
331 pub fn encode<B: BufMut>(&self, tag: u32, buf: &mut B) {
333 match &self.attr {
334 OwnedValue::Vec(r) if r.is_empty() => {}
335 OwnedValue::Vec(r) => {
336 let r: Vec<_> = r
337 .values()
338 .iter()
339 .map(|a| Field {
340 field_type: self.field_type.clone(),
341 attr: a.to_owned(),
342 })
343 .collect();
344 match self.field_type {
345 FieldType::Bool
346 | FieldType::Int32
347 | FieldType::Int64
348 | FieldType::SInt32
349 | FieldType::SInt64
350 | FieldType::UInt32
351 | FieldType::UInt64
352 | FieldType::Float
353 | FieldType::Double
354 | FieldType::SFixed32
355 | FieldType::SFixed64
356 | FieldType::Fixed32
357 | FieldType::Fixed64
358 | FieldType::Enum(_) => {
359 encode_key(tag, WireType::LengthDelimited, buf);
360 let len: usize = r.iter().map(|value| value.encoded_len_raw()).sum();
361 encode_varint(len as u64, buf);
362 for value in r {
363 value.encode_raw(buf);
364 }
365 }
366 _ => {
367 for value in r {
368 value.encode_key(tag, buf);
369 value.encode_raw(buf);
370 }
371 }
372 };
373 }
374 _ => {
375 if !self.is_default() {
376 self.encode_key(tag, buf);
377 self.encode_raw(buf);
378 }
379 }
380 }
381 }
382
383 pub fn encoded_len(&self, tag: u32) -> usize {
385 match &self.attr {
386 OwnedValue::Vec(r) if r.is_empty() => 0,
387 OwnedValue::Vec(r) => {
388 let values = r.values();
389 let iter = values.iter().map(|a| Field {
390 field_type: self.field_type.clone(),
391 attr: a.to_owned(),
392 });
393 let len = iter.map(|f| f.encoded_len_raw()).sum::<usize>();
394 let key_len: usize = match self.field_type {
395 FieldType::Bool
396 | FieldType::Int32
397 | FieldType::Int64
398 | FieldType::SInt32
399 | FieldType::SInt64
400 | FieldType::UInt32
401 | FieldType::UInt64
402 | FieldType::Float
403 | FieldType::Double
404 | FieldType::SFixed32
405 | FieldType::SFixed64
406 | FieldType::Fixed32
407 | FieldType::Fixed64
408 | FieldType::Enum(_) => key_len(tag) + encoded_len_varint(len as u64),
409 _ => key_len(tag) * r.len(),
410 };
411 key_len + len
412 }
413 _ => {
414 if !self.is_default() {
415 key_len(tag) + self.encoded_len_raw()
416 } else {
417 0
418 }
419 }
420 }
421 }
422 pub fn encoded_len_raw(&self) -> usize {
424 match (&self.field_type, &self.attr) {
425 (FieldType::Bool, OwnedValue::Bool(b)) => {
426 encoded_len_varint(if *b { 1u64 } else { 0u64 })
427 }
428 (FieldType::Int32, OwnedValue::I32(i)) => encoded_len_varint(*i as u64),
429 (FieldType::Int64, OwnedValue::I64(i)) => encoded_len_varint(*i as u64),
430 (FieldType::UInt32, OwnedValue::U32(i)) => encoded_len_varint(*i as u64),
431 (FieldType::UInt64, OwnedValue::U64(i)) => encoded_len_varint(*i as u64),
432 (FieldType::SInt32, OwnedValue::I32(value)) => {
433 encoded_len_varint(((value << 1) ^ (value >> 31)) as u32 as u64)
434 }
435 (FieldType::SInt64, OwnedValue::I64(value)) => {
436 encoded_len_varint(((value << 1) ^ (value >> 63)) as u64)
437 }
438 (FieldType::Fixed64, OwnedValue::U64(_)) => 8,
439 (FieldType::Fixed32, OwnedValue::U32(_)) => 4,
440 (FieldType::String, OwnedValue::String(s)) => {
441 let bytes: &[u8] = s.as_ref();
442 encoded_len_varint(bytes.len() as u64) + bytes.len()
443 }
444 (FieldType::Float, OwnedValue::F32(_)) => 4,
445 (FieldType::Double, OwnedValue::F64(_)) => 4,
446 (FieldType::Enum(_), OwnedValue::I32(i)) => encoded_len_varint(*i as u64),
447 (FieldType::Message(_), v @ OwnedValue::Struct(_)) => {
448 if let Ok(msg) = IntoInner::<DynamicMessage>::into_inner(v.clone()) {
449 let len = msg.encoded_len();
450 encoded_len_varint(len as u64) + len
451 } else {
452 0
453 }
454 }
455 (FieldType::Message(_), v @ OwnedValue::Option(_)) => {
456 if let Ok(Some(msg)) = IntoInner::<Option<DynamicMessage>>::into_inner(v.clone()) {
457 let len = msg.encoded_len();
458 encoded_len_varint(len as u64) + len
459 } else {
460 0
461 }
462 }
463 (FieldType::Bytes, OwnedValue::Bytes(b)) => {
464 encoded_len_varint(b.remaining() as u64) + b.remaining()
465 }
466 _ => 0,
467 }
468 }
469}