looking_glass_protobuf/
message.rs

1use crate::{DescriptorDatabase, Error, Tranche};
2use crate::{FieldType, MessageView};
3use bytes::{Buf, BufMut, Bytes};
4use looking_glass::{CowValue, Instance, IntoInner, OwnedValue, SmolStr, StructInstance, Typed, ValueTy};
5use prost::encoding::{encode_key, encode_varint, encoded_len_varint, key_len, WireType};
6use std::{
7    any::TypeId,
8    collections::{BTreeMap, HashMap},
9    convert::TryFrom,
10    sync::Arc,
11};
12
13/// A looking_glass representation of a ProtocolBuffer message.
14///
15/// DynamicMessage allows users to create and read ProtocolBuffer messages at runtime. A DynamicMessage is created from a [`MessageView`], but unlike a [`MessageView`] it owns its data.
16/// This means that DynamicMessages can be modified and re-encoded. In the name of efficiency string and byte, types are reference counted from their source.
17#[derive(Debug, PartialEq, Clone)]
18pub struct DynamicMessage {
19    values: BTreeMap<u32, Field<OwnedValue<'static>>>,
20    descriptor_name: String,
21    descriptor_database: Arc<DescriptorDatabase>,
22}
23
24impl DynamicMessage {
25    /// Builds a new `DynamicMessage` from a [`MessageView`].
26    pub fn new<T: Tranche>(view: &MessageView<T>) -> Result<DynamicMessage, Error> {
27        let descriptor_database = view.descriptor_database.clone();
28        let descriptor_name = view.descriptor_name.clone();
29        let descriptor = descriptor_database
30            .descriptor(&descriptor_name)
31            .ok_or_else(|| looking_glass::Error::NotFound("descriptor".into()))?;
32        let values = descriptor
33            .fields
34            .iter()
35            .map(|(tag, field)| match view.view_tag(*tag) {
36                Ok(view) => Ok((
37                    *tag,
38                    Field {
39                        field_type: field.field_type.clone(),
40                        attr: OwnedValue::try_from(view)?,
41                    },
42                )),
43                Err(e) => Err(e),
44            })
45            .collect::<Result<BTreeMap<_, _>, Error>>()?;
46        Ok(DynamicMessage {
47            values,
48            descriptor_name,
49            descriptor_database,
50        })
51    }
52
53    /// The length of the mesage once encoded
54    pub fn encoded_len(&self) -> usize {
55        self.values
56            .iter()
57            .map(|(tag, field)| field.encoded_len(*tag))
58            .sum()
59    }
60
61    /// Encoded a [`DynamicMessage`] to the passed buffer
62    pub fn encode(&self, buf: &mut impl BufMut) {
63        for (tag, field) in &self.values {
64            field.encode(*tag, buf)
65        }
66    }
67
68    pub fn descriptor_name(&self) -> String {
69        self.descriptor_name.clone()
70    }
71
72    pub fn descriptor_database(&self) -> Arc<DescriptorDatabase> {
73        self.descriptor_database.clone()
74    }
75}
76
77impl Instance<'static> for DynamicMessage {
78    fn name(&self) -> SmolStr {
79        SmolStr::new(&self.descriptor_name)
80    }
81
82    fn as_inst(&self) -> &(dyn Instance<'static> + 'static) {
83        self
84    }
85}
86
87impl StructInstance<'static> for DynamicMessage {
88    fn get_value<'a>(&'a self, field: &str) -> Option<CowValue<'a, 'static>>
89    where
90        'static: 'a,
91    {
92        let descriptor = self.descriptor_database.descriptor(&self.descriptor_name)?;
93        let tag = descriptor.tags_by_name.get(field)?;
94        Some(CowValue::from(self.values.get(tag)?.attr.as_ref()))
95    }
96
97    fn update<'a>(
98        &'a mut self,
99        update: &'a (dyn StructInstance<'static> + 'static),
100        field_mask: Option<&looking_glass::FieldMask>,
101        replace_repeated: bool,
102    ) -> Result<(), looking_glass::Error> {
103        let descriptor = self
104            .descriptor_database
105            .descriptor(&self.descriptor_name)
106            .ok_or_else(|| looking_glass::Error::NotFound("descriptor".into()))?;
107        for (key, update_value) in update.values() {
108            let tag = descriptor
109                .tags_by_name
110                .get(&key)
111                .ok_or_else(|| looking_glass::Error::NotFound("tag".into()))?;
112            let new_mask = field_mask.and_then(|m| m.child(&key));
113            if new_mask.is_some() || field_mask.is_none() {
114                let field = self.values.get_mut(tag);
115                let attr = field.map(|f| &mut f.attr);
116                match attr {
117                    Some(OwnedValue::Struct(inst)) => {
118                        if let Some(update_inst) = update_value.as_ref().as_reflected_struct() {
119                            inst.update(update_inst, new_mask, replace_repeated)?;
120                        }
121                    }
122                    Some(OwnedValue::Vec(ref mut v)) => {
123                        if let Some(update_vec) = update_value.as_ref().as_reflected_vec() {
124                            v.update(update_vec, replace_repeated)?;
125                        }
126                    }
127                    _ => {
128                        let field = descriptor
129                            .fields
130                            .get(tag)
131                            .ok_or_else(|| looking_glass::Error::NotFound("descriptor".into()))?;
132                        let update_value = if let Some(view) =
133                            update_value.as_ref().borrow::<&MessageView<Bytes>>()
134                        {
135                            OwnedValue::Option(Box::new(Some(DynamicMessage::new(view).map_err(
136                                |_| looking_glass::Error::TypeError {
137                                    expected: "valid message view".into(),
138                                    found: "invalid message view".into(),
139                                },
140                            )?)))
141                        } else {
142                            update_value.to_owned()
143                        };
144                        let field = Field {
145                            field_type: field.field_type.clone(),
146                            attr: update_value,
147                        };
148                        self.values.insert(*tag, field);
149                    }
150                }
151            }
152        }
153        Ok(())
154    }
155
156    fn values<'a>(&'a self) -> HashMap<SmolStr, CowValue<'a, 'static>> {
157        if let Some(descriptor) = self.descriptor_database.descriptor(&self.descriptor_name) {
158            self.values
159                .iter()
160                .filter_map(|(t, v)| {
161                    let name = descriptor.fields.get(t)?.name.clone();
162                    Some((name, CowValue::Ref(v.attr.as_ref())))
163                })
164                .collect()
165        } else {
166            HashMap::new()
167        }
168    }
169
170    fn boxed_clone(&self) -> Box<dyn StructInstance<'static> + 'static> {
171        Box::new(self.clone())
172    }
173
174    fn into_boxed_instance(self: Box<Self>) -> Box<dyn Instance<'static> + 'static> {
175        self
176    }
177}
178
179impl Typed<'static> for DynamicMessage {
180    fn ty() -> looking_glass::ValueTy {
181        ValueTy::Struct(TypeId::of::<Self>())
182    }
183
184    fn as_value<'a>(&'a self) -> looking_glass::Value<'a, 'static>
185    where
186        'static: 'a,
187    {
188        looking_glass::Value::from_struct(self)
189    }
190}
191
192fn is_default_owned_value(value: &OwnedValue<'static>) -> bool {
193    match value {
194        OwnedValue::U64(u) => *u == 0,
195        OwnedValue::U32(u) => *u == 0,
196        OwnedValue::U16(u) => *u == 0,
197        OwnedValue::U8(u) => *u == 0,
198        OwnedValue::I64(u) => *u == 0,
199        OwnedValue::I32(u) => *u == 0,
200        OwnedValue::I16(u) => *u == 0,
201        OwnedValue::I8(u) => *u == 0,
202        OwnedValue::Bool(b) => !b,
203        OwnedValue::String(s) => s.is_empty(),
204        OwnedValue::Vec(v) => v.is_empty(),
205        OwnedValue::Bytes(b) => b.is_empty(),
206        OwnedValue::Struct(m) => {
207            if let Some(msg) = m.as_inst().downcast_ref::<DynamicMessage>() {
208                for field in msg.values.values() {
209                    if !field.is_default() {
210                        return false;
211                    }
212                }
213                true
214            } else {
215                false
216            }
217        }
218        OwnedValue::Option(option) => {
219            if let Some(msg) = option.as_inst().downcast_ref::<Option<DynamicMessage>>() {
220                if let Some(msg) = msg {
221                    for field in msg.values.values() {
222                        if !field.is_default() {
223                            return false;
224                        }
225                    }
226                    true
227                } else {
228                    true
229                }
230            } else {
231                false
232            }
233        }
234        _ => false,
235    }
236}
237
238#[derive(Clone, PartialEq, Debug)]
239pub struct Field<A> {
240    field_type: FieldType,
241    attr: A,
242}
243
244impl Field<OwnedValue<'static>> {
245    /// Checks if the field is a default value
246    pub fn is_default(&self) -> bool {
247        is_default_owned_value(&self.attr)
248    }
249
250    /// Encodes the tag for a field with the appropriate wire-type
251    pub fn encode_key<B: BufMut>(&self, tag: u32, buf: &mut B) {
252        if let OwnedValue::Vec(_) = self.attr {
253            return;
254        }
255        match self.field_type {
256            FieldType::Bool
257            | FieldType::Int32
258            | FieldType::Int64
259            | FieldType::UInt32
260            | FieldType::UInt64
261            | FieldType::Enum(_) => encode_key(tag, WireType::Varint, buf),
262            FieldType::Double | FieldType::Fixed64 | FieldType::SInt64 | FieldType::SFixed64 => {
263                encode_key(tag, WireType::SixtyFourBit, buf)
264            }
265            FieldType::Float | FieldType::SInt32 | FieldType::Fixed32 | FieldType::SFixed32 => {
266                encode_key(tag, WireType::ThirtyTwoBit, buf)
267            }
268            FieldType::String | FieldType::Message(_) | FieldType::Bytes => {
269                encode_key(tag, WireType::LengthDelimited, buf)
270            }
271            FieldType::Group => {}
272        };
273    }
274
275    /// Encodes just the contents of a field
276    pub fn encode_raw<B: BufMut>(&self, buf: &mut B) {
277        match (&self.field_type, &self.attr) {
278            (FieldType::Bool, OwnedValue::Bool(b)) => {
279                encode_varint(if *b { 1u64 } else { 0u64 }, buf)
280            }
281            (FieldType::Int32, OwnedValue::I32(i)) => encode_varint(*i as u64, buf),
282            (FieldType::Int64, OwnedValue::I64(i)) => encode_varint(*i as u64, buf),
283            (FieldType::UInt32, OwnedValue::U32(i)) => encode_varint(*i as u64, buf),
284            (FieldType::UInt64, OwnedValue::U64(i)) => encode_varint(*i as u64, buf),
285            (FieldType::SInt32, OwnedValue::I32(value)) => {
286                encode_varint(((value << 1) ^ (value >> 31)) as u32 as u64, buf)
287            }
288            (FieldType::SInt64, OwnedValue::I64(value)) => {
289                encode_varint(((value << 1) ^ (value >> 63)) as u64, buf)
290            }
291            (FieldType::Fixed64, OwnedValue::U64(i)) => buf.put_u64_le(*i),
292            (FieldType::Fixed32, OwnedValue::U32(i)) => buf.put_u32_le(*i),
293            (FieldType::String, OwnedValue::String(s)) => {
294                let bytes: &[u8] = s.as_ref();
295                encode_varint(bytes.len() as u64, buf);
296                buf.put_slice(bytes)
297            }
298            (FieldType::Float, OwnedValue::F32(f)) => buf.put_f32_le(*f),
299            (FieldType::Double, OwnedValue::F64(f)) => buf.put_f64_le(*f),
300            (FieldType::Enum(_), OwnedValue::I32(i)) => encode_varint(*i as u64, buf),
301            (FieldType::Message(_), OwnedValue::Option(m)) => {
302                if let Some(v) = m.value() {
303                    if let Some(msg) = v.borrow::<&DynamicMessage>() {
304                        let n = msg.encoded_len();
305                        encode_varint(n as u64, buf);
306                        msg.encode(buf);
307                    }
308                }
309                // For this method in the original implementation I used an extra Vec as the buffer,
310                // and then wrote that buffer. I don't remember why I did that, but if there is bug look here
311                // Also I don't like that I ignore the error here, but I also don't want to propogate errors throughout
312                // -- SPHW
313            }
314            (FieldType::Message(_), OwnedValue::Struct(m)) => {
315                if let Some(msg) = m.as_value().borrow::<&DynamicMessage>() {
316                    let n = msg.encoded_len();
317                    encode_varint(n as u64, buf);
318                    msg.encode(buf);
319                }
320            }
321            (FieldType::Bytes, OwnedValue::Bytes(b)) => {
322                encode_varint(b.remaining() as u64, buf);
323                buf.put_slice(b.chunk());
324            }
325            _ => {
326                //TODO: Handle the type error
327            }
328        }
329    }
330
331    /// Encodes a the field with a given tag
332    pub fn encode<B: BufMut>(&self, tag: u32, buf: &mut B) {
333        match &self.attr {
334            OwnedValue::Vec(r) if r.is_empty() => {}
335            OwnedValue::Vec(r) => {
336                let r: Vec<_> = r
337                    .values()
338                    .iter()
339                    .map(|a| Field {
340                        field_type: self.field_type.clone(),
341                        attr: a.to_owned(),
342                    })
343                    .collect();
344                match self.field_type {
345                    FieldType::Bool
346                    | FieldType::Int32
347                    | FieldType::Int64
348                    | FieldType::SInt32
349                    | FieldType::SInt64
350                    | FieldType::UInt32
351                    | FieldType::UInt64
352                    | FieldType::Float
353                    | FieldType::Double
354                    | FieldType::SFixed32
355                    | FieldType::SFixed64
356                    | FieldType::Fixed32
357                    | FieldType::Fixed64
358                    | FieldType::Enum(_) => {
359                        encode_key(tag, WireType::LengthDelimited, buf);
360                        let len: usize = r.iter().map(|value| value.encoded_len_raw()).sum();
361                        encode_varint(len as u64, buf);
362                        for value in r {
363                            value.encode_raw(buf);
364                        }
365                    }
366                    _ => {
367                        for value in r {
368                            value.encode_key(tag, buf);
369                            value.encode_raw(buf);
370                        }
371                    }
372                };
373            }
374            _ => {
375                if !self.is_default() {
376                    self.encode_key(tag, buf);
377                    self.encode_raw(buf);
378                }
379            }
380        }
381    }
382
383    /// Returns the encoded length of a field
384    pub fn encoded_len(&self, tag: u32) -> usize {
385        match &self.attr {
386            OwnedValue::Vec(r) if r.is_empty() => 0,
387            OwnedValue::Vec(r) => {
388                let values = r.values();
389                let iter = values.iter().map(|a| Field {
390                    field_type: self.field_type.clone(),
391                    attr: a.to_owned(),
392                });
393                let len = iter.map(|f| f.encoded_len_raw()).sum::<usize>();
394                let key_len: usize = match self.field_type {
395                    FieldType::Bool
396                    | FieldType::Int32
397                    | FieldType::Int64
398                    | FieldType::SInt32
399                    | FieldType::SInt64
400                    | FieldType::UInt32
401                    | FieldType::UInt64
402                    | FieldType::Float
403                    | FieldType::Double
404                    | FieldType::SFixed32
405                    | FieldType::SFixed64
406                    | FieldType::Fixed32
407                    | FieldType::Fixed64
408                    | FieldType::Enum(_) => key_len(tag) + encoded_len_varint(len as u64),
409                    _ => key_len(tag) * r.len(),
410                };
411                key_len + len
412            }
413            _ => {
414                if !self.is_default() {
415                    key_len(tag) + self.encoded_len_raw()
416                } else {
417                    0
418                }
419            }
420        }
421    }
422    ///Returns the encoded length of a field, not taking into account compacted fields
423    pub fn encoded_len_raw(&self) -> usize {
424        match (&self.field_type, &self.attr) {
425            (FieldType::Bool, OwnedValue::Bool(b)) => {
426                encoded_len_varint(if *b { 1u64 } else { 0u64 })
427            }
428            (FieldType::Int32, OwnedValue::I32(i)) => encoded_len_varint(*i as u64),
429            (FieldType::Int64, OwnedValue::I64(i)) => encoded_len_varint(*i as u64),
430            (FieldType::UInt32, OwnedValue::U32(i)) => encoded_len_varint(*i as u64),
431            (FieldType::UInt64, OwnedValue::U64(i)) => encoded_len_varint(*i as u64),
432            (FieldType::SInt32, OwnedValue::I32(value)) => {
433                encoded_len_varint(((value << 1) ^ (value >> 31)) as u32 as u64)
434            }
435            (FieldType::SInt64, OwnedValue::I64(value)) => {
436                encoded_len_varint(((value << 1) ^ (value >> 63)) as u64)
437            }
438            (FieldType::Fixed64, OwnedValue::U64(_)) => 8,
439            (FieldType::Fixed32, OwnedValue::U32(_)) => 4,
440            (FieldType::String, OwnedValue::String(s)) => {
441                let bytes: &[u8] = s.as_ref();
442                encoded_len_varint(bytes.len() as u64) + bytes.len()
443            }
444            (FieldType::Float, OwnedValue::F32(_)) => 4,
445            (FieldType::Double, OwnedValue::F64(_)) => 4,
446            (FieldType::Enum(_), OwnedValue::I32(i)) => encoded_len_varint(*i as u64),
447            (FieldType::Message(_), v @ OwnedValue::Struct(_)) => {
448                if let Ok(msg) = IntoInner::<DynamicMessage>::into_inner(v.clone()) {
449                    let len = msg.encoded_len();
450                    encoded_len_varint(len as u64) + len
451                } else {
452                    0
453                }
454            }
455            (FieldType::Message(_), v @ OwnedValue::Option(_)) => {
456                if let Ok(Some(msg)) = IntoInner::<Option<DynamicMessage>>::into_inner(v.clone()) {
457                    let len = msg.encoded_len();
458                    encoded_len_varint(len as u64) + len
459                } else {
460                    0
461                }
462            }
463            (FieldType::Bytes, OwnedValue::Bytes(b)) => {
464                encoded_len_varint(b.remaining() as u64) + b.remaining()
465            }
466            _ => 0,
467        }
468    }
469}