1use oxiproto_core::wire::{
13 zigzag_decode32, zigzag_decode64, zigzag_encode32, zigzag_encode64, DecodeBuffer, EncodeBuffer,
14 Tag, UnknownFields, WireType,
15};
16
17use super::descriptor::{Cardinality, FieldDescriptor, Kind, MessageDescriptor};
18use super::dynamic::{default_scalar_value, is_field_value_default, DynamicMessage};
19use super::value::{MapKey, Value};
20use crate::ReflectError;
21
22impl DynamicMessage {
23 pub fn encode_to_vec(&self) -> Result<Vec<u8>, ReflectError> {
34 let mut buf = EncodeBuffer::new();
35 self.encode(&mut buf)?;
36 Ok(buf.into_vec())
37 }
38
39 pub fn encode(&self, buf: &mut EncodeBuffer) -> Result<(), ReflectError> {
45 for (field, value) in self.iter_fields() {
46 if is_field_value_default(&field, value) {
49 continue;
50 }
51 encode_field(buf, &field, value)?;
52 }
53 self.unknown.encode_to(buf);
56 Ok(())
57 }
58
59 pub fn decode(desc: MessageDescriptor, bytes: &[u8]) -> Result<Self, ReflectError> {
70 let mut msg = DynamicMessage::new(desc);
71 let mut dec = DecodeBuffer::new(bytes);
72 decode_into(&mut msg, &mut dec)?;
73 Ok(msg)
74 }
75}
76
77fn decode_into(msg: &mut DynamicMessage, dec: &mut DecodeBuffer<'_>) -> Result<(), ReflectError> {
79 while !dec.is_empty() {
80 let tag = dec.read_tag().map_err(wire_err)?;
81 let desc = msg.descriptor();
82 match desc.get_field(tag.field_number) {
83 Some(field) => decode_known_field(msg, &field, tag, dec)?,
84 None => decode_unknown_field(&mut msg.unknown, tag, dec)?,
85 }
86 }
87 Ok(())
88}
89
90fn decode_known_field(
92 msg: &mut DynamicMessage,
93 field: &FieldDescriptor,
94 tag: Tag,
95 dec: &mut DecodeBuffer<'_>,
96) -> Result<(), ReflectError> {
97 if field.is_map() {
98 return decode_map_entry(msg, field, tag, dec);
99 }
100
101 match field.cardinality() {
102 Cardinality::Repeated => decode_repeated(msg, field, tag, dec),
103 Cardinality::Optional | Cardinality::Required => {
104 let value = decode_single_value(field, tag, dec)?;
105 msg.set_field(field, value);
108 Ok(())
109 }
110 }
111}
112
113fn decode_repeated(
115 msg: &mut DynamicMessage,
116 field: &FieldDescriptor,
117 tag: Tag,
118 dec: &mut DecodeBuffer<'_>,
119) -> Result<(), ReflectError> {
120 if tag.wire_type == WireType::Len && field.kind().is_packable() {
122 let payload = dec.read_length_delimited().map_err(wire_err)?;
123 let mut inner = DecodeBuffer::new(payload);
124 let mut decoded = Vec::new();
125 while !inner.is_empty() {
126 decoded.push(decode_scalar_from(field.kind(), &mut inner)?);
127 }
128 append_to_list(msg, field, decoded);
129 return Ok(());
130 }
131
132 let value = decode_single_value(field, tag, dec)?;
134 append_to_list(msg, field, vec![value]);
135 Ok(())
136}
137
138fn append_to_list(msg: &mut DynamicMessage, field: &FieldDescriptor, mut elems: Vec<Value>) {
140 let entry = msg
141 .fields
142 .entry(field.number())
143 .or_insert_with(|| Value::List(Vec::new()));
144 match entry {
145 Value::List(list) => list.append(&mut elems),
146 other => {
149 let mut list = Vec::new();
150 list.append(&mut elems);
151 *other = Value::List(list);
152 }
153 }
154}
155
156fn decode_single_value(
158 field: &FieldDescriptor,
159 tag: Tag,
160 dec: &mut DecodeBuffer<'_>,
161) -> Result<Value, ReflectError> {
162 match field.kind() {
163 Kind::Group(_) => Err(group_unsupported()),
164 Kind::Message(idx) => {
165 if tag.wire_type != WireType::Len {
166 return Err(ReflectError::Field(format!(
167 "message field '{}' expected length-delimited wire type, got {}",
168 field.name(),
169 tag.wire_type
170 )));
171 }
172 let payload = dec.read_length_delimited().map_err(wire_err)?;
173 let nested_desc = MessageDescriptor {
174 pool: field.pool.clone(),
175 index: idx,
176 };
177 let nested = DynamicMessage::decode(nested_desc, payload)?;
178 Ok(Value::Message(Box::new(nested)))
179 }
180 kind => decode_scalar_with_tag(kind, tag, dec, field),
181 }
182}
183
184fn decode_scalar_with_tag(
186 kind: Kind,
187 tag: Tag,
188 dec: &mut DecodeBuffer<'_>,
189 field: &FieldDescriptor,
190) -> Result<Value, ReflectError> {
191 let expected = scalar_wire_type(kind)?;
192 if tag.wire_type != expected {
193 return Err(ReflectError::Field(format!(
194 "field '{}' expected wire type {expected}, got {}",
195 field.name(),
196 tag.wire_type
197 )));
198 }
199 decode_scalar_from(kind, dec)
200}
201
202fn decode_scalar_from(kind: Kind, dec: &mut DecodeBuffer<'_>) -> Result<Value, ReflectError> {
205 let value = match kind {
206 Kind::Double => Value::F64(dec.read_double().map_err(wire_err)?),
207 Kind::Float => Value::F32(dec.read_float().map_err(wire_err)?),
208 Kind::Int32 => Value::I32(dec.read_varint().map_err(wire_err)? as i32),
209 Kind::Int64 => Value::I64(dec.read_varint().map_err(wire_err)? as i64),
210 Kind::Uint32 => {
211 let v = dec.read_varint().map_err(wire_err)?;
212 Value::U32(v as u32)
213 }
214 Kind::Uint64 => Value::U64(dec.read_varint().map_err(wire_err)?),
215 Kind::Sint32 => {
216 let raw = dec.read_varint().map_err(wire_err)? as u32;
217 Value::I32(zigzag_decode32(raw))
218 }
219 Kind::Sint64 => {
220 let raw = dec.read_varint().map_err(wire_err)?;
221 Value::I64(zigzag_decode64(raw))
222 }
223 Kind::Fixed32 => Value::U32(dec.read_fixed32().map_err(wire_err)?),
224 Kind::Fixed64 => Value::U64(dec.read_fixed64().map_err(wire_err)?),
225 Kind::Sfixed32 => Value::I32(dec.read_fixed32().map_err(wire_err)? as i32),
226 Kind::Sfixed64 => Value::I64(dec.read_fixed64().map_err(wire_err)? as i64),
227 Kind::Bool => Value::Bool(dec.read_varint().map_err(wire_err)? != 0),
228 Kind::String => Value::String(dec.read_string().map_err(wire_err)?.to_owned()),
229 Kind::Bytes => Value::Bytes(dec.read_length_delimited().map_err(wire_err)?.to_vec()),
230 Kind::Enum(_) => Value::EnumNumber(dec.read_varint().map_err(wire_err)? as i32),
231 Kind::Message(_) | Kind::Group(_) => {
232 return Err(ReflectError::Field(
233 "message/group kind is not a scalar".to_owned(),
234 ))
235 }
236 };
237 Ok(value)
238}
239
240fn decode_map_entry(
242 msg: &mut DynamicMessage,
243 field: &FieldDescriptor,
244 tag: Tag,
245 dec: &mut DecodeBuffer<'_>,
246) -> Result<(), ReflectError> {
247 if tag.wire_type != WireType::Len {
248 return Err(ReflectError::Field(format!(
249 "map field '{}' expected length-delimited entries, got {}",
250 field.name(),
251 tag.wire_type
252 )));
253 }
254 let payload = dec.read_length_delimited().map_err(wire_err)?;
255
256 let key_field = field
257 .map_entry_key_field()
258 .ok_or_else(|| ReflectError::Field("map field missing entry key field".to_owned()))?;
259 let value_field = field
260 .map_entry_value_field()
261 .ok_or_else(|| ReflectError::Field("map field missing entry value field".to_owned()))?;
262
263 let mut key_val = default_scalar_value(key_field.kind());
265 let mut val_val = match value_field.kind() {
266 Kind::Message(idx) => {
267 let nested_desc = MessageDescriptor {
268 pool: value_field.pool.clone(),
269 index: idx,
270 };
271 Value::Message(Box::new(DynamicMessage::new(nested_desc)))
272 }
273 other => default_scalar_value(other),
274 };
275
276 let mut entry_dec = DecodeBuffer::new(payload);
277 while !entry_dec.is_empty() {
278 let entry_tag = entry_dec.read_tag().map_err(wire_err)?;
279 match entry_tag.field_number {
280 1 => key_val = decode_single_value(&key_field, entry_tag, &mut entry_dec)?,
281 2 => val_val = decode_single_value(&value_field, entry_tag, &mut entry_dec)?,
282 _ => entry_dec
283 .skip_field(entry_tag.wire_type)
284 .map_err(wire_err)?,
285 }
286 }
287
288 let map_key = value_to_map_key(&key_val).ok_or_else(|| {
289 ReflectError::Field(format!(
290 "map field '{}' has an unsupported key type",
291 field.name()
292 ))
293 })?;
294
295 let entry = msg
296 .fields
297 .entry(field.number())
298 .or_insert_with(|| Value::Map(std::collections::HashMap::new()));
299 match entry {
300 Value::Map(map) => {
301 map.insert(map_key, val_val);
302 }
303 other => {
304 let mut map = std::collections::HashMap::new();
305 map.insert(map_key, val_val);
306 *other = Value::Map(map);
307 }
308 }
309 Ok(())
310}
311
312fn decode_unknown_field(
315 unknown: &mut UnknownFields,
316 tag: Tag,
317 dec: &mut DecodeBuffer<'_>,
318) -> Result<(), ReflectError> {
319 match tag.wire_type {
320 WireType::Varint => {
321 let v = dec.read_varint().map_err(wire_err)?;
322 unknown.push_varint(tag.field_number, v);
323 }
324 WireType::I64 => {
325 let v = dec.read_fixed64().map_err(wire_err)?;
326 unknown.push_fixed64(tag.field_number, v);
327 }
328 WireType::I32 => {
329 let v = dec.read_fixed32().map_err(wire_err)?;
330 unknown.push_fixed32(tag.field_number, v);
331 }
332 WireType::Len => {
333 let payload = dec.read_length_delimited().map_err(wire_err)?;
334 unknown.push_length_delimited(tag.field_number, payload.to_vec());
335 }
336 WireType::SGroup | WireType::EGroup => return Err(group_unsupported()),
337 }
338 Ok(())
339}
340
341fn encode_field(
347 buf: &mut EncodeBuffer,
348 field: &FieldDescriptor,
349 value: &Value,
350) -> Result<(), ReflectError> {
351 if field.is_map() {
352 return encode_map(buf, field, value);
353 }
354 match field.cardinality() {
355 Cardinality::Repeated => encode_repeated(buf, field, value),
356 Cardinality::Optional | Cardinality::Required => {
357 encode_single(buf, field, value, field.number())
358 }
359 }
360}
361
362fn encode_repeated(
365 buf: &mut EncodeBuffer,
366 field: &FieldDescriptor,
367 value: &Value,
368) -> Result<(), ReflectError> {
369 let list = match value {
370 Value::List(l) => l,
371 _ => {
372 return Err(ReflectError::Field(format!(
373 "repeated field '{}' holds a non-list value",
374 field.name()
375 )))
376 }
377 };
378 if list.is_empty() {
379 return Ok(());
380 }
381
382 if field.is_packed() && field.kind().is_packable() {
383 let mut payload = EncodeBuffer::new();
385 for elem in list {
386 encode_scalar_payload(&mut payload, field.kind(), elem, field)?;
387 }
388 buf.write_tag(field.number(), WireType::Len)
389 .map_err(wire_err)?;
390 buf.write_length_delimited(payload.as_bytes());
391 } else {
392 for elem in list {
393 encode_single(buf, field, elem, field.number())?;
394 }
395 }
396 Ok(())
397}
398
399fn encode_map(
401 buf: &mut EncodeBuffer,
402 field: &FieldDescriptor,
403 value: &Value,
404) -> Result<(), ReflectError> {
405 let map = match value {
406 Value::Map(m) => m,
407 _ => {
408 return Err(ReflectError::Field(format!(
409 "map field '{}' holds a non-map value",
410 field.name()
411 )))
412 }
413 };
414 let key_field = field
415 .map_entry_key_field()
416 .ok_or_else(|| ReflectError::Field("map field missing entry key field".to_owned()))?;
417 let value_field = field
418 .map_entry_value_field()
419 .ok_or_else(|| ReflectError::Field("map field missing entry value field".to_owned()))?;
420
421 for (k, v) in map {
422 let key_value = k.to_value();
423 let mut entry = EncodeBuffer::new();
424 encode_single(&mut entry, &key_field, &key_value, 1)?;
427 encode_single(&mut entry, &value_field, v, 2)?;
428 buf.write_tag(field.number(), WireType::Len)
429 .map_err(wire_err)?;
430 buf.write_length_delimited(entry.as_bytes());
431 }
432 Ok(())
433}
434
435fn encode_single(
437 buf: &mut EncodeBuffer,
438 field: &FieldDescriptor,
439 value: &Value,
440 field_number: u32,
441) -> Result<(), ReflectError> {
442 match field.kind() {
443 Kind::Group(_) => Err(group_unsupported()),
444 Kind::Message(_) => {
445 let nested = match value {
446 Value::Message(m) => m,
447 _ => {
448 return Err(ReflectError::Field(format!(
449 "message field '{}' holds a non-message value",
450 field.name()
451 )))
452 }
453 };
454 let payload = nested.encode_to_vec()?;
455 buf.write_tag(field_number, WireType::Len)
456 .map_err(wire_err)?;
457 buf.write_length_delimited(&payload);
458 Ok(())
459 }
460 kind => {
461 let wt = scalar_wire_type(kind)?;
462 buf.write_tag(field_number, wt).map_err(wire_err)?;
463 encode_scalar_payload(buf, kind, value, field)
464 }
465 }
466}
467
468fn encode_scalar_payload(
471 buf: &mut EncodeBuffer,
472 kind: Kind,
473 value: &Value,
474 field: &FieldDescriptor,
475) -> Result<(), ReflectError> {
476 match kind {
477 Kind::Double => buf.write_double(expect_f64(value, field)?),
478 Kind::Float => buf.write_float(expect_f32(value, field)?),
479 Kind::Int32 => buf.write_varint_i32(expect_i32(value, field)?),
480 Kind::Int64 => buf.write_varint_i64(expect_i64(value, field)?),
481 Kind::Uint32 => buf.write_varint32(expect_u32(value, field)?),
482 Kind::Uint64 => buf.write_varint(expect_u64(value, field)?),
483 Kind::Sint32 => buf.write_varint32(zigzag_encode32(expect_i32(value, field)?)),
484 Kind::Sint64 => buf.write_varint(zigzag_encode64(expect_i64(value, field)?)),
485 Kind::Fixed32 => buf.write_fixed32(expect_u32(value, field)?),
486 Kind::Fixed64 => buf.write_fixed64(expect_u64(value, field)?),
487 Kind::Sfixed32 => buf.write_fixed32(expect_i32(value, field)? as u32),
488 Kind::Sfixed64 => buf.write_fixed64(expect_i64(value, field)? as u64),
489 Kind::Bool => buf.write_bool(expect_bool(value, field)?),
490 Kind::String => buf.write_string(expect_str(value, field)?),
491 Kind::Bytes => buf.write_length_delimited(expect_bytes(value, field)?),
492 Kind::Enum(_) => buf.write_varint_i32(expect_enum(value, field)?),
493 Kind::Message(_) | Kind::Group(_) => {
494 return Err(ReflectError::Field(
495 "message/group kind has no scalar payload".to_owned(),
496 ))
497 }
498 }
499 Ok(())
500}
501
502fn scalar_wire_type(kind: Kind) -> Result<WireType, ReflectError> {
508 let wt = match kind {
509 Kind::Int32
510 | Kind::Int64
511 | Kind::Uint32
512 | Kind::Uint64
513 | Kind::Sint32
514 | Kind::Sint64
515 | Kind::Bool
516 | Kind::Enum(_) => WireType::Varint,
517 Kind::Fixed64 | Kind::Sfixed64 | Kind::Double => WireType::I64,
518 Kind::Fixed32 | Kind::Sfixed32 | Kind::Float => WireType::I32,
519 Kind::String | Kind::Bytes => WireType::Len,
520 Kind::Message(_) | Kind::Group(_) => {
521 return Err(ReflectError::Field(
522 "message/group kind has no scalar wire type".to_owned(),
523 ))
524 }
525 };
526 Ok(wt)
527}
528
529fn value_to_map_key(value: &Value) -> Option<MapKey> {
532 match value {
533 Value::String(s) => Some(MapKey::String(s.clone())),
534 Value::I32(v) => Some(MapKey::I32(*v)),
535 Value::I64(v) => Some(MapKey::I64(*v)),
536 Value::U32(v) => Some(MapKey::U32(*v)),
537 Value::U64(v) => Some(MapKey::U64(*v)),
538 Value::Bool(v) => Some(MapKey::Bool(*v)),
539 _ => None,
540 }
541}
542
543fn group_unsupported() -> ReflectError {
545 ReflectError::Field("protobuf groups (wire types 3/4) are unsupported".to_owned())
546}
547
548fn wire_err(e: oxiproto_core::wire::WireError) -> ReflectError {
550 ReflectError::Field(format!("wire format error: {e}"))
551}
552
553fn type_mismatch(field: &FieldDescriptor, expected: &str) -> ReflectError {
557 ReflectError::Field(format!(
558 "field '{}' expected a {expected} value",
559 field.name()
560 ))
561}
562
563fn expect_f64(value: &Value, field: &FieldDescriptor) -> Result<f64, ReflectError> {
564 value.as_f64().ok_or_else(|| type_mismatch(field, "f64"))
565}
566fn expect_f32(value: &Value, field: &FieldDescriptor) -> Result<f32, ReflectError> {
567 value.as_f32().ok_or_else(|| type_mismatch(field, "f32"))
568}
569fn expect_i32(value: &Value, field: &FieldDescriptor) -> Result<i32, ReflectError> {
570 value.as_i32().ok_or_else(|| type_mismatch(field, "i32"))
571}
572fn expect_i64(value: &Value, field: &FieldDescriptor) -> Result<i64, ReflectError> {
573 value.as_i64().ok_or_else(|| type_mismatch(field, "i64"))
574}
575fn expect_u32(value: &Value, field: &FieldDescriptor) -> Result<u32, ReflectError> {
576 value.as_u32().ok_or_else(|| type_mismatch(field, "u32"))
577}
578fn expect_u64(value: &Value, field: &FieldDescriptor) -> Result<u64, ReflectError> {
579 value.as_u64().ok_or_else(|| type_mismatch(field, "u64"))
580}
581fn expect_bool(value: &Value, field: &FieldDescriptor) -> Result<bool, ReflectError> {
582 value.as_bool().ok_or_else(|| type_mismatch(field, "bool"))
583}
584fn expect_str<'a>(value: &'a Value, field: &FieldDescriptor) -> Result<&'a str, ReflectError> {
585 value.as_str().ok_or_else(|| type_mismatch(field, "string"))
586}
587fn expect_bytes<'a>(value: &'a Value, field: &FieldDescriptor) -> Result<&'a [u8], ReflectError> {
588 value
589 .as_bytes()
590 .ok_or_else(|| type_mismatch(field, "bytes"))
591}
592fn expect_enum(value: &Value, field: &FieldDescriptor) -> Result<i32, ReflectError> {
593 value
594 .as_enum_number()
595 .or_else(|| value.as_i32())
596 .ok_or_else(|| type_mismatch(field, "enum number"))
597}