Skip to main content

oxiproto_reflect/native/
wire_codec.rs

1//! Protobuf wire-format encode/decode for [`DynamicMessage`].
2//!
3//! All wire primitives are reused from [`oxiproto_core::wire`] — this module
4//! contains no hand-rolled varint logic. It implements the *schema-aware*
5//! layer on top: choosing wire types per [`Kind`], packed vs unpacked repeated
6//! scalars, `map<K, V>` as repeated synthetic entries, proto3 default
7//! omission, and unknown-field preservation.
8//!
9//! Groups (wire types 3/4) are explicitly rejected with
10//! [`ReflectError::Field`].
11
12use oxiproto_core::wire::{
13    zigzag_decode32, zigzag_decode64, zigzag_encode32, zigzag_encode64, DecodeBuffer, EncodeBuffer,
14    Tag, UnknownFields, WireType,
15};
16
17use super::descriptor::{Cardinality, FieldDescriptor, Kind, MessageDescriptor};
18use super::dynamic::{default_scalar_value, is_field_value_default, DynamicMessage};
19use super::value::{MapKey, Value};
20use crate::ReflectError;
21
22impl DynamicMessage {
23    /// Encode this message to a freshly-allocated byte vector.
24    ///
25    /// Fields are written in ascending field-number order; unknown fields are
26    /// appended afterwards. Proto3 singular fields equal to their type default
27    /// are omitted.
28    ///
29    /// # Errors
30    ///
31    /// Returns [`ReflectError::Field`] if the message (or a nested message)
32    /// contains a group-kind field, which is unsupported.
33    pub fn encode_to_vec(&self) -> Result<Vec<u8>, ReflectError> {
34        let mut buf = EncodeBuffer::new();
35        self.encode(&mut buf)?;
36        Ok(buf.into_vec())
37    }
38
39    /// Encode this message into an existing [`EncodeBuffer`].
40    ///
41    /// # Errors
42    ///
43    /// See [`DynamicMessage::encode_to_vec`].
44    pub fn encode(&self, buf: &mut EncodeBuffer) -> Result<(), ReflectError> {
45        for (field, value) in self.iter_fields() {
46            // Skip singular fields whose value equals the default (proto3
47            // omission). Repeated/map empties are also skipped.
48            if is_field_value_default(&field, value) {
49                continue;
50            }
51            encode_field(buf, &field, value)?;
52        }
53        // Re-emit preserved unknown fields so that data added by a newer schema
54        // survives a decode → encode round-trip.
55        self.unknown.encode_to(buf);
56        Ok(())
57    }
58
59    /// Decode a message of the given descriptor from `bytes`.
60    ///
61    /// Repeated scalar fields accept both packed and unpacked encodings.
62    /// Fields whose numbers are not in the descriptor are preserved as unknown
63    /// fields.
64    ///
65    /// # Errors
66    ///
67    /// Returns [`ReflectError::Field`] on malformed wire data (propagated from
68    /// the wire layer) or if a group-kind field is encountered.
69    pub fn decode(desc: MessageDescriptor, bytes: &[u8]) -> Result<Self, ReflectError> {
70        let mut msg = DynamicMessage::new(desc);
71        let mut dec = DecodeBuffer::new(bytes);
72        decode_into(&mut msg, &mut dec)?;
73        Ok(msg)
74    }
75}
76
77/// Decode the contents of `dec` into `msg` until the buffer is exhausted.
78fn decode_into(msg: &mut DynamicMessage, dec: &mut DecodeBuffer<'_>) -> Result<(), ReflectError> {
79    while !dec.is_empty() {
80        let tag = dec.read_tag().map_err(wire_err)?;
81        let desc = msg.descriptor();
82        match desc.get_field(tag.field_number) {
83            Some(field) => decode_known_field(msg, &field, tag, dec)?,
84            None => decode_unknown_field(&mut msg.unknown, tag, dec)?,
85        }
86    }
87    Ok(())
88}
89
90/// Decode a field that is present in the descriptor.
91fn decode_known_field(
92    msg: &mut DynamicMessage,
93    field: &FieldDescriptor,
94    tag: Tag,
95    dec: &mut DecodeBuffer<'_>,
96) -> Result<(), ReflectError> {
97    if field.is_map() {
98        return decode_map_entry(msg, field, tag, dec);
99    }
100
101    match field.cardinality() {
102        Cardinality::Repeated => decode_repeated(msg, field, tag, dec),
103        Cardinality::Optional | Cardinality::Required => {
104            let value = decode_single_value(field, tag, dec)?;
105            // For singular fields, last-one-wins (protobuf merge semantics for
106            // scalars); set_field also enforces oneof exclusivity.
107            msg.set_field(field, value);
108            Ok(())
109        }
110    }
111}
112
113/// Decode a repeated field element, appending to (or creating) its list.
114fn decode_repeated(
115    msg: &mut DynamicMessage,
116    field: &FieldDescriptor,
117    tag: Tag,
118    dec: &mut DecodeBuffer<'_>,
119) -> Result<(), ReflectError> {
120    // Packed encoding: a single length-delimited blob of back-to-back scalars.
121    if tag.wire_type == WireType::Len && field.kind().is_packable() {
122        let payload = dec.read_length_delimited().map_err(wire_err)?;
123        let mut inner = DecodeBuffer::new(payload);
124        let mut decoded = Vec::new();
125        while !inner.is_empty() {
126            decoded.push(decode_scalar_from(field.kind(), &mut inner)?);
127        }
128        append_to_list(msg, field, decoded);
129        return Ok(());
130    }
131
132    // Unpacked encoding: one tag+value per element.
133    let value = decode_single_value(field, tag, dec)?;
134    append_to_list(msg, field, vec![value]);
135    Ok(())
136}
137
138/// Append elements to the field's list value, creating the list if absent.
139fn append_to_list(msg: &mut DynamicMessage, field: &FieldDescriptor, mut elems: Vec<Value>) {
140    let entry = msg
141        .fields
142        .entry(field.number())
143        .or_insert_with(|| Value::List(Vec::new()));
144    match entry {
145        Value::List(list) => list.append(&mut elems),
146        // If a non-list somehow occupies the slot (e.g. a prior singular
147        // decode), replace it with a fresh list.
148        other => {
149            let mut list = Vec::new();
150            list.append(&mut elems);
151            *other = Value::List(list);
152        }
153    }
154}
155
156/// Decode a single scalar/message value for a field given its tag.
157fn decode_single_value(
158    field: &FieldDescriptor,
159    tag: Tag,
160    dec: &mut DecodeBuffer<'_>,
161) -> Result<Value, ReflectError> {
162    match field.kind() {
163        Kind::Group(_) => Err(group_unsupported()),
164        Kind::Message(idx) => {
165            if tag.wire_type != WireType::Len {
166                return Err(ReflectError::Field(format!(
167                    "message field '{}' expected length-delimited wire type, got {}",
168                    field.name(),
169                    tag.wire_type
170                )));
171            }
172            let payload = dec.read_length_delimited().map_err(wire_err)?;
173            let nested_desc = MessageDescriptor {
174                pool: field.pool.clone(),
175                index: idx,
176            };
177            let nested = DynamicMessage::decode(nested_desc, payload)?;
178            Ok(Value::Message(Box::new(nested)))
179        }
180        kind => decode_scalar_with_tag(kind, tag, dec, field),
181    }
182}
183
184/// Decode a scalar value, validating the tag's wire type.
185fn decode_scalar_with_tag(
186    kind: Kind,
187    tag: Tag,
188    dec: &mut DecodeBuffer<'_>,
189    field: &FieldDescriptor,
190) -> Result<Value, ReflectError> {
191    let expected = scalar_wire_type(kind)?;
192    if tag.wire_type != expected {
193        return Err(ReflectError::Field(format!(
194            "field '{}' expected wire type {expected}, got {}",
195            field.name(),
196            tag.wire_type
197        )));
198    }
199    decode_scalar_from(kind, dec)
200}
201
202/// Decode a scalar value of `kind` from the buffer (wire type already known to
203/// match). Used for both single values and packed elements.
204fn decode_scalar_from(kind: Kind, dec: &mut DecodeBuffer<'_>) -> Result<Value, ReflectError> {
205    let value = match kind {
206        Kind::Double => Value::F64(dec.read_double().map_err(wire_err)?),
207        Kind::Float => Value::F32(dec.read_float().map_err(wire_err)?),
208        Kind::Int32 => Value::I32(dec.read_varint().map_err(wire_err)? as i32),
209        Kind::Int64 => Value::I64(dec.read_varint().map_err(wire_err)? as i64),
210        Kind::Uint32 => {
211            let v = dec.read_varint().map_err(wire_err)?;
212            Value::U32(v as u32)
213        }
214        Kind::Uint64 => Value::U64(dec.read_varint().map_err(wire_err)?),
215        Kind::Sint32 => {
216            let raw = dec.read_varint().map_err(wire_err)? as u32;
217            Value::I32(zigzag_decode32(raw))
218        }
219        Kind::Sint64 => {
220            let raw = dec.read_varint().map_err(wire_err)?;
221            Value::I64(zigzag_decode64(raw))
222        }
223        Kind::Fixed32 => Value::U32(dec.read_fixed32().map_err(wire_err)?),
224        Kind::Fixed64 => Value::U64(dec.read_fixed64().map_err(wire_err)?),
225        Kind::Sfixed32 => Value::I32(dec.read_fixed32().map_err(wire_err)? as i32),
226        Kind::Sfixed64 => Value::I64(dec.read_fixed64().map_err(wire_err)? as i64),
227        Kind::Bool => Value::Bool(dec.read_varint().map_err(wire_err)? != 0),
228        Kind::String => Value::String(dec.read_string().map_err(wire_err)?.to_owned()),
229        Kind::Bytes => Value::Bytes(dec.read_length_delimited().map_err(wire_err)?.to_vec()),
230        Kind::Enum(_) => Value::EnumNumber(dec.read_varint().map_err(wire_err)? as i32),
231        Kind::Message(_) | Kind::Group(_) => {
232            return Err(ReflectError::Field(
233                "message/group kind is not a scalar".to_owned(),
234            ))
235        }
236    };
237    Ok(value)
238}
239
240/// Decode one `map<K, V>` synthetic entry message and merge it into the map.
241fn decode_map_entry(
242    msg: &mut DynamicMessage,
243    field: &FieldDescriptor,
244    tag: Tag,
245    dec: &mut DecodeBuffer<'_>,
246) -> Result<(), ReflectError> {
247    if tag.wire_type != WireType::Len {
248        return Err(ReflectError::Field(format!(
249            "map field '{}' expected length-delimited entries, got {}",
250            field.name(),
251            tag.wire_type
252        )));
253    }
254    let payload = dec.read_length_delimited().map_err(wire_err)?;
255
256    let key_field = field
257        .map_entry_key_field()
258        .ok_or_else(|| ReflectError::Field("map field missing entry key field".to_owned()))?;
259    let value_field = field
260        .map_entry_value_field()
261        .ok_or_else(|| ReflectError::Field("map field missing entry value field".to_owned()))?;
262
263    // A map entry omits key/value when they equal the default; supply defaults.
264    let mut key_val = default_scalar_value(key_field.kind());
265    let mut val_val = match value_field.kind() {
266        Kind::Message(idx) => {
267            let nested_desc = MessageDescriptor {
268                pool: value_field.pool.clone(),
269                index: idx,
270            };
271            Value::Message(Box::new(DynamicMessage::new(nested_desc)))
272        }
273        other => default_scalar_value(other),
274    };
275
276    let mut entry_dec = DecodeBuffer::new(payload);
277    while !entry_dec.is_empty() {
278        let entry_tag = entry_dec.read_tag().map_err(wire_err)?;
279        match entry_tag.field_number {
280            1 => key_val = decode_single_value(&key_field, entry_tag, &mut entry_dec)?,
281            2 => val_val = decode_single_value(&value_field, entry_tag, &mut entry_dec)?,
282            _ => entry_dec
283                .skip_field(entry_tag.wire_type)
284                .map_err(wire_err)?,
285        }
286    }
287
288    let map_key = value_to_map_key(&key_val).ok_or_else(|| {
289        ReflectError::Field(format!(
290            "map field '{}' has an unsupported key type",
291            field.name()
292        ))
293    })?;
294
295    let entry = msg
296        .fields
297        .entry(field.number())
298        .or_insert_with(|| Value::Map(std::collections::HashMap::new()));
299    match entry {
300        Value::Map(map) => {
301            map.insert(map_key, val_val);
302        }
303        other => {
304            let mut map = std::collections::HashMap::new();
305            map.insert(map_key, val_val);
306            *other = Value::Map(map);
307        }
308    }
309    Ok(())
310}
311
312/// Decode an unknown field (one whose number is absent from the descriptor),
313/// preserving its raw bytes.
314fn decode_unknown_field(
315    unknown: &mut UnknownFields,
316    tag: Tag,
317    dec: &mut DecodeBuffer<'_>,
318) -> Result<(), ReflectError> {
319    match tag.wire_type {
320        WireType::Varint => {
321            let v = dec.read_varint().map_err(wire_err)?;
322            unknown.push_varint(tag.field_number, v);
323        }
324        WireType::I64 => {
325            let v = dec.read_fixed64().map_err(wire_err)?;
326            unknown.push_fixed64(tag.field_number, v);
327        }
328        WireType::I32 => {
329            let v = dec.read_fixed32().map_err(wire_err)?;
330            unknown.push_fixed32(tag.field_number, v);
331        }
332        WireType::Len => {
333            let payload = dec.read_length_delimited().map_err(wire_err)?;
334            unknown.push_length_delimited(tag.field_number, payload.to_vec());
335        }
336        WireType::SGroup | WireType::EGroup => return Err(group_unsupported()),
337    }
338    Ok(())
339}
340
341// ---------------------------------------------------------------------------
342// Encoding
343// ---------------------------------------------------------------------------
344
345/// Encode a single (already non-default) field.
346fn encode_field(
347    buf: &mut EncodeBuffer,
348    field: &FieldDescriptor,
349    value: &Value,
350) -> Result<(), ReflectError> {
351    if field.is_map() {
352        return encode_map(buf, field, value);
353    }
354    match field.cardinality() {
355        Cardinality::Repeated => encode_repeated(buf, field, value),
356        Cardinality::Optional | Cardinality::Required => {
357            encode_single(buf, field, value, field.number())
358        }
359    }
360}
361
362/// Encode a repeated field (packed for packable scalars when the field's
363/// `packed` flag is set, otherwise unpacked).
364fn encode_repeated(
365    buf: &mut EncodeBuffer,
366    field: &FieldDescriptor,
367    value: &Value,
368) -> Result<(), ReflectError> {
369    let list = match value {
370        Value::List(l) => l,
371        _ => {
372            return Err(ReflectError::Field(format!(
373                "repeated field '{}' holds a non-list value",
374                field.name()
375            )))
376        }
377    };
378    if list.is_empty() {
379        return Ok(());
380    }
381
382    if field.is_packed() && field.kind().is_packable() {
383        // Packed: a single length-delimited payload of back-to-back scalars.
384        let mut payload = EncodeBuffer::new();
385        for elem in list {
386            encode_scalar_payload(&mut payload, field.kind(), elem, field)?;
387        }
388        buf.write_tag(field.number(), WireType::Len)
389            .map_err(wire_err)?;
390        buf.write_length_delimited(payload.as_bytes());
391    } else {
392        for elem in list {
393            encode_single(buf, field, elem, field.number())?;
394        }
395    }
396    Ok(())
397}
398
399/// Encode a `map<K, V>` field as a series of synthetic entry messages.
400fn encode_map(
401    buf: &mut EncodeBuffer,
402    field: &FieldDescriptor,
403    value: &Value,
404) -> Result<(), ReflectError> {
405    let map = match value {
406        Value::Map(m) => m,
407        _ => {
408            return Err(ReflectError::Field(format!(
409                "map field '{}' holds a non-map value",
410                field.name()
411            )))
412        }
413    };
414    let key_field = field
415        .map_entry_key_field()
416        .ok_or_else(|| ReflectError::Field("map field missing entry key field".to_owned()))?;
417    let value_field = field
418        .map_entry_value_field()
419        .ok_or_else(|| ReflectError::Field("map field missing entry value field".to_owned()))?;
420
421    for (k, v) in map {
422        let key_value = k.to_value();
423        let mut entry = EncodeBuffer::new();
424        // Map entries always write key (field 1) and value (field 2), even at
425        // default, to match the canonical encoding produced by protoc/prost.
426        encode_single(&mut entry, &key_field, &key_value, 1)?;
427        encode_single(&mut entry, &value_field, v, 2)?;
428        buf.write_tag(field.number(), WireType::Len)
429            .map_err(wire_err)?;
430        buf.write_length_delimited(entry.as_bytes());
431    }
432    Ok(())
433}
434
435/// Encode a single value (scalar or message) with the given field number.
436fn encode_single(
437    buf: &mut EncodeBuffer,
438    field: &FieldDescriptor,
439    value: &Value,
440    field_number: u32,
441) -> Result<(), ReflectError> {
442    match field.kind() {
443        Kind::Group(_) => Err(group_unsupported()),
444        Kind::Message(_) => {
445            let nested = match value {
446                Value::Message(m) => m,
447                _ => {
448                    return Err(ReflectError::Field(format!(
449                        "message field '{}' holds a non-message value",
450                        field.name()
451                    )))
452                }
453            };
454            let payload = nested.encode_to_vec()?;
455            buf.write_tag(field_number, WireType::Len)
456                .map_err(wire_err)?;
457            buf.write_length_delimited(&payload);
458            Ok(())
459        }
460        kind => {
461            let wt = scalar_wire_type(kind)?;
462            buf.write_tag(field_number, wt).map_err(wire_err)?;
463            encode_scalar_payload(buf, kind, value, field)
464        }
465    }
466}
467
468/// Encode just the payload of a scalar (no tag), used for both singular and
469/// packed-repeated elements.
470fn encode_scalar_payload(
471    buf: &mut EncodeBuffer,
472    kind: Kind,
473    value: &Value,
474    field: &FieldDescriptor,
475) -> Result<(), ReflectError> {
476    match kind {
477        Kind::Double => buf.write_double(expect_f64(value, field)?),
478        Kind::Float => buf.write_float(expect_f32(value, field)?),
479        Kind::Int32 => buf.write_varint_i32(expect_i32(value, field)?),
480        Kind::Int64 => buf.write_varint_i64(expect_i64(value, field)?),
481        Kind::Uint32 => buf.write_varint32(expect_u32(value, field)?),
482        Kind::Uint64 => buf.write_varint(expect_u64(value, field)?),
483        Kind::Sint32 => buf.write_varint32(zigzag_encode32(expect_i32(value, field)?)),
484        Kind::Sint64 => buf.write_varint(zigzag_encode64(expect_i64(value, field)?)),
485        Kind::Fixed32 => buf.write_fixed32(expect_u32(value, field)?),
486        Kind::Fixed64 => buf.write_fixed64(expect_u64(value, field)?),
487        Kind::Sfixed32 => buf.write_fixed32(expect_i32(value, field)? as u32),
488        Kind::Sfixed64 => buf.write_fixed64(expect_i64(value, field)? as u64),
489        Kind::Bool => buf.write_bool(expect_bool(value, field)?),
490        Kind::String => buf.write_string(expect_str(value, field)?),
491        Kind::Bytes => buf.write_length_delimited(expect_bytes(value, field)?),
492        Kind::Enum(_) => buf.write_varint_i32(expect_enum(value, field)?),
493        Kind::Message(_) | Kind::Group(_) => {
494            return Err(ReflectError::Field(
495                "message/group kind has no scalar payload".to_owned(),
496            ))
497        }
498    }
499    Ok(())
500}
501
502// ---------------------------------------------------------------------------
503// Helpers
504// ---------------------------------------------------------------------------
505
506/// The wire type used to encode a scalar of `kind`.
507fn scalar_wire_type(kind: Kind) -> Result<WireType, ReflectError> {
508    let wt = match kind {
509        Kind::Int32
510        | Kind::Int64
511        | Kind::Uint32
512        | Kind::Uint64
513        | Kind::Sint32
514        | Kind::Sint64
515        | Kind::Bool
516        | Kind::Enum(_) => WireType::Varint,
517        Kind::Fixed64 | Kind::Sfixed64 | Kind::Double => WireType::I64,
518        Kind::Fixed32 | Kind::Sfixed32 | Kind::Float => WireType::I32,
519        Kind::String | Kind::Bytes => WireType::Len,
520        Kind::Message(_) | Kind::Group(_) => {
521            return Err(ReflectError::Field(
522                "message/group kind has no scalar wire type".to_owned(),
523            ))
524        }
525    };
526    Ok(wt)
527}
528
529/// Convert a decoded scalar [`Value`] into a [`MapKey`], if it is a valid key
530/// type.
531fn value_to_map_key(value: &Value) -> Option<MapKey> {
532    match value {
533        Value::String(s) => Some(MapKey::String(s.clone())),
534        Value::I32(v) => Some(MapKey::I32(*v)),
535        Value::I64(v) => Some(MapKey::I64(*v)),
536        Value::U32(v) => Some(MapKey::U32(*v)),
537        Value::U64(v) => Some(MapKey::U64(*v)),
538        Value::Bool(v) => Some(MapKey::Bool(*v)),
539        _ => None,
540    }
541}
542
543/// Build the canonical "groups unsupported" error.
544fn group_unsupported() -> ReflectError {
545    ReflectError::Field("protobuf groups (wire types 3/4) are unsupported".to_owned())
546}
547
548/// Map a [`oxiproto_core::wire::WireError`] to a [`ReflectError`].
549fn wire_err(e: oxiproto_core::wire::WireError) -> ReflectError {
550    ReflectError::Field(format!("wire format error: {e}"))
551}
552
553// Typed accessors used during encode, producing a descriptive error on a type
554// mismatch rather than panicking.
555
556fn type_mismatch(field: &FieldDescriptor, expected: &str) -> ReflectError {
557    ReflectError::Field(format!(
558        "field '{}' expected a {expected} value",
559        field.name()
560    ))
561}
562
563fn expect_f64(value: &Value, field: &FieldDescriptor) -> Result<f64, ReflectError> {
564    value.as_f64().ok_or_else(|| type_mismatch(field, "f64"))
565}
566fn expect_f32(value: &Value, field: &FieldDescriptor) -> Result<f32, ReflectError> {
567    value.as_f32().ok_or_else(|| type_mismatch(field, "f32"))
568}
569fn expect_i32(value: &Value, field: &FieldDescriptor) -> Result<i32, ReflectError> {
570    value.as_i32().ok_or_else(|| type_mismatch(field, "i32"))
571}
572fn expect_i64(value: &Value, field: &FieldDescriptor) -> Result<i64, ReflectError> {
573    value.as_i64().ok_or_else(|| type_mismatch(field, "i64"))
574}
575fn expect_u32(value: &Value, field: &FieldDescriptor) -> Result<u32, ReflectError> {
576    value.as_u32().ok_or_else(|| type_mismatch(field, "u32"))
577}
578fn expect_u64(value: &Value, field: &FieldDescriptor) -> Result<u64, ReflectError> {
579    value.as_u64().ok_or_else(|| type_mismatch(field, "u64"))
580}
581fn expect_bool(value: &Value, field: &FieldDescriptor) -> Result<bool, ReflectError> {
582    value.as_bool().ok_or_else(|| type_mismatch(field, "bool"))
583}
584fn expect_str<'a>(value: &'a Value, field: &FieldDescriptor) -> Result<&'a str, ReflectError> {
585    value.as_str().ok_or_else(|| type_mismatch(field, "string"))
586}
587fn expect_bytes<'a>(value: &'a Value, field: &FieldDescriptor) -> Result<&'a [u8], ReflectError> {
588    value
589        .as_bytes()
590        .ok_or_else(|| type_mismatch(field, "bytes"))
591}
592fn expect_enum(value: &Value, field: &FieldDescriptor) -> Result<i32, ReflectError> {
593    value
594        .as_enum_number()
595        .or_else(|| value.as_i32())
596        .ok_or_else(|| type_mismatch(field, "enum number"))
597}