Skip to main content

prost_protovalidate/validator/
editions.rs

1//! Normalize protobuf Edition 2023 descriptors to proto3 format.
2//!
3//! `prost-reflect` 0.16 does not support `syntax = "editions"` and panics
4//! when decoding such descriptors. This module rewrites Edition 2023
5//! descriptors at the wire level so that they are valid proto3, preserving
6//! all extension and option bytes.
7
8use std::collections::HashMap;
9
10use prost::encoding::{WireType, decode_key, decode_varint, encode_key, encode_varint};
11
/// `FeatureSet.field_presence` value `EXPLICIT`.
const FIELD_PRESENCE_EXPLICIT: i32 = 1;
/// `FeatureSet.field_presence` value `LEGACY_REQUIRED`.
const FIELD_PRESENCE_LEGACY_REQUIRED: i32 = 3;

/// `FeatureSet.message_encoding` value `DELIMITED`.
const MESSAGE_ENCODING_DELIMITED: i32 = 2;

/// `FieldDescriptorProto.Label` value for optional fields.
const LABEL_OPTIONAL: i32 = 1;
/// `FieldDescriptorProto.Label` value for required fields.
const LABEL_REQUIRED: i32 = 2;
/// `FieldDescriptorProto.Label` value for repeated fields.
const LABEL_REPEATED: i32 = 3;

/// `FieldDescriptorProto.Type` value for group fields.
const TYPE_GROUP: i32 = 10;
/// `FieldDescriptorProto.Type` value for message fields.
const TYPE_MESSAGE: i32 = 11;
27
/// Field numbers of the `FileDescriptorProto` members this module reads or rewrites.
mod file_tags {
    /// `message_type`: top-level message descriptors.
    pub const MESSAGE_TYPE: u32 = 4;
    /// `extension`: top-level extension field descriptors.
    pub const EXTENSION: u32 = 7;
    /// `options`: the file options submessage.
    pub const OPTIONS: u32 = 8;
    /// `syntax`: the syntax string (compared against `"editions"`).
    pub const SYNTAX: u32 = 12;
    /// `edition`: the edition varint, which is stripped during normalization.
    pub const EDITION: u32 = 14;
}
36
/// Field numbers of the `DescriptorProto` members this module reads or rewrites.
mod message_tags {
    /// `field`: the message's own field descriptors.
    pub const FIELD: u32 = 2;
    /// `nested_type`: message descriptors declared inside this message.
    pub const NESTED_TYPE: u32 = 3;
    /// `extension`: extension field descriptors declared inside this message.
    pub const EXTENSION: u32 = 6;
    /// `oneof_decl`: oneof declarations; synthetic ones are appended under this number.
    pub const ONEOF_DECL: u32 = 8;
}
44
/// Field numbers of the `FieldDescriptorProto` members this module reads or rewrites.
mod field_tags {
    /// `name`: the field name.
    pub const NAME: u32 = 1;
    /// `label`: the field label (optional / required / repeated).
    pub const LABEL: u32 = 4;
    /// `type`: the field type.
    pub const TYPE: u32 = 5;
    /// `options`: the field options submessage.
    pub const OPTIONS: u32 = 8;
    /// `oneof_index`: index of the oneof the field belongs to.
    pub const ONEOF_INDEX: u32 = 9;
    /// `proto3_optional`: flag marking a proto3 `optional` field.
    pub const PROTO3_OPTIONAL: u32 = 17;
}
54
/// Field numbers of the `FieldOptions` members this module reads.
mod field_option_tags {
    /// `features`: the per-field `FeatureSet` override.
    pub const FEATURES: u32 = 21;
}
59
/// Field numbers of the `FeatureSet` members this module reads.
mod feature_tags {
    /// `field_presence`: presence discipline applied to fields.
    pub const FIELD_PRESENCE: u32 = 1;
    /// `message_encoding`: wire encoding applied to message-typed fields.
    pub const MESSAGE_ENCODING: u32 = 5;
}
65
/// Field numbers of the `FileOptions` members this module reads.
///
/// NOTE(review): in `descriptor.proto` the number 50 is `FileOptions.features`;
/// `MessageOptions.features` is field 12, so this constant should not be used
/// when scanning a `MessageOptions` payload.
mod option_tags {
    /// `features`: the file-wide `FeatureSet`.
    pub const FEATURES: u32 = 50;
}
70
71/// Normalize a `FileDescriptorSet` so that any Edition 2023 files are
72/// rewritten as `proto3`. Returns the original bytes unchanged if no
73/// edition files are detected.
74#[must_use]
75pub fn normalize_edition_descriptor_set(bytes: &[u8]) -> Vec<u8> {
76    let mut cursor = bytes;
77    let mut has_editions = false;
78
79    // Quick scan: do any entries use editions?
80    while !cursor.is_empty() {
81        let Ok((tag, wire_type)) = decode_key(&mut cursor) else {
82            return bytes.to_vec();
83        };
84        match (tag, wire_type) {
85            (1, WireType::LengthDelimited) => {
86                let Ok(len) = decode_len(&mut cursor) else {
87                    return bytes.to_vec();
88                };
89                if cursor.len() < len {
90                    return bytes.to_vec();
91                }
92                if file_has_editions_syntax(&cursor[..len]) {
93                    has_editions = true;
94                    break;
95                }
96                cursor = &cursor[len..];
97            }
98            _ => {
99                if skip_wire_value_simple(&mut cursor, wire_type).is_err() {
100                    return bytes.to_vec();
101                }
102            }
103        }
104    }
105
106    if !has_editions {
107        return bytes.to_vec();
108    }
109
110    // Full rewrite pass.
111    let mut cursor = bytes;
112    let mut out = Vec::with_capacity(bytes.len());
113
114    while !cursor.is_empty() {
115        let Ok((tag, wire_type)) = decode_key(&mut cursor) else {
116            return bytes.to_vec();
117        };
118        if let (1, WireType::LengthDelimited) = (tag, wire_type) {
119            let Ok(len) = decode_len(&mut cursor) else {
120                return bytes.to_vec();
121            };
122            if cursor.len() < len {
123                return bytes.to_vec();
124            }
125            let file_bytes = &cursor[..len];
126            cursor = &cursor[len..];
127
128            let normalized = normalize_file_descriptor(file_bytes);
129            encode_key(1, WireType::LengthDelimited, &mut out);
130            encode_varint(normalized.len() as u64, &mut out);
131            out.extend_from_slice(&normalized);
132        } else {
133            let start = cursor;
134            if skip_wire_value_simple(&mut cursor, wire_type).is_err() {
135                return bytes.to_vec();
136            }
137            // Re-encode the tag + data.
138            encode_key(tag, wire_type, &mut out);
139            out.extend_from_slice(&start[..start.len() - cursor.len()]);
140        }
141    }
142
143    out
144}
145
146/// Check whether a `FileDescriptorProto` has `syntax = "editions"`.
147fn file_has_editions_syntax(bytes: &[u8]) -> bool {
148    let mut cursor = bytes;
149    while !cursor.is_empty() {
150        let Ok((tag, wire_type)) = decode_key(&mut cursor) else {
151            return false;
152        };
153        match (tag, wire_type) {
154            (file_tags::SYNTAX, WireType::LengthDelimited) => {
155                let Ok(len) = decode_len(&mut cursor) else {
156                    return false;
157                };
158                if cursor.len() < len {
159                    return false;
160                }
161                return &cursor[..len] == b"editions";
162            }
163            _ => {
164                if skip_wire_value_simple(&mut cursor, wire_type).is_err() {
165                    return false;
166                }
167            }
168        }
169    }
170    false
171}
172
173/// Extract a varint field from a `FeatureSet` message by tag number.
174#[allow(clippy::cast_possible_truncation)] // Protobuf enum values fit in i32.
175fn extract_feature_set_varint(bytes: &[u8], field_tag: u32) -> i32 {
176    let mut cursor = bytes;
177    while !cursor.is_empty() {
178        let Ok((tag, wire_type)) = decode_key(&mut cursor) else {
179            break;
180        };
181        match (tag, wire_type) {
182            (t, WireType::Varint) if t == field_tag => {
183                let Ok(v) = decode_varint(&mut cursor) else {
184                    break;
185                };
186                return v as i32;
187            }
188            _ => {
189                if skip_wire_value_simple(&mut cursor, wire_type).is_err() {
190                    break;
191                }
192            }
193        }
194    }
195    0
196}
197
198/// Extract a feature varint from an options message (`FileOptions` / `MessageOptions` / `FieldOptions`).
199///
200/// Scans for the `features` submessage at `features_tag`, then reads
201/// the specified `field_tag` varint from the `FeatureSet` inside.
202fn extract_feature_varint(options_bytes: &[u8], features_tag: u32, field_tag: u32) -> i32 {
203    let mut cursor = options_bytes;
204    while !cursor.is_empty() {
205        let Ok((tag, wire_type)) = decode_key(&mut cursor) else {
206            break;
207        };
208        match (tag, wire_type) {
209            (t, WireType::LengthDelimited) if t == features_tag => {
210                let Ok(len) = decode_len(&mut cursor) else {
211                    break;
212                };
213                if cursor.len() < len {
214                    break;
215                }
216                let feature_set = &cursor[..len];
217                let val = extract_feature_set_varint(feature_set, field_tag);
218                if val != 0 {
219                    return val;
220                }
221                cursor = &cursor[len..];
222            }
223            _ => {
224                if skip_wire_value_simple(&mut cursor, wire_type).is_err() {
225                    break;
226                }
227            }
228        }
229    }
230    0
231}
232
233/// Extract a `FeatureSet` field value from file-level options.
234///
235/// Scans `FileDescriptorProto` bytes for the options submessage (tag 8),
236/// then reads the specified feature field. Returns `0` if not found.
237fn extract_file_level_feature(bytes: &[u8], feature_field_tag: u32) -> i32 {
238    let mut cursor = bytes;
239    while !cursor.is_empty() {
240        let Ok((tag, wire_type)) = decode_key(&mut cursor) else {
241            break;
242        };
243        match (tag, wire_type) {
244            (file_tags::OPTIONS, WireType::LengthDelimited) => {
245                let Ok(len) = decode_len(&mut cursor) else {
246                    break;
247                };
248                if cursor.len() < len {
249                    break;
250                }
251                let options_bytes = &cursor[..len];
252                cursor = &cursor[len..];
253                let val =
254                    extract_feature_varint(options_bytes, option_tags::FEATURES, feature_field_tag);
255                if val != 0 {
256                    return val;
257                }
258            }
259            _ => {
260                if skip_wire_value_simple(&mut cursor, wire_type).is_err() {
261                    break;
262                }
263            }
264        }
265    }
266    0
267}
268
269/// Normalize a single `FileDescriptorProto`.
270/// If `syntax != "editions"`, returns the bytes unchanged.
271fn normalize_file_descriptor(bytes: &[u8]) -> Vec<u8> {
272    if !file_has_editions_syntax(bytes) {
273        return bytes.to_vec();
274    }
275
276    let presence = extract_file_level_feature(bytes, feature_tags::FIELD_PRESENCE);
277    let file_default_presence = if presence != 0 {
278        presence
279    } else {
280        FIELD_PRESENCE_EXPLICIT
281    };
282    let file_default_encoding = extract_file_level_feature(bytes, feature_tags::MESSAGE_ENCODING);
283
284    let mut cursor = bytes;
285    let mut out = Vec::with_capacity(bytes.len());
286
287    while !cursor.is_empty() {
288        let Ok((tag, wire_type)) = decode_key(&mut cursor) else {
289            return bytes.to_vec();
290        };
291
292        match (tag, wire_type) {
293            // Rewrite syntax.
294            (file_tags::SYNTAX, WireType::LengthDelimited) => {
295                let Ok(len) = decode_len(&mut cursor) else {
296                    return bytes.to_vec();
297                };
298                if cursor.len() < len {
299                    return bytes.to_vec();
300                }
301                cursor = &cursor[len..];
302                // Write "proto3" instead.
303                encode_key(file_tags::SYNTAX, WireType::LengthDelimited, &mut out);
304                encode_varint(6, &mut out); // len("proto3")
305                out.extend_from_slice(b"proto3");
306            }
307            // Strip edition field (tag 14).
308            (file_tags::EDITION, WireType::Varint) => {
309                let Ok(_) = decode_varint(&mut cursor) else {
310                    return bytes.to_vec();
311                };
312                // Drop this field.
313            }
314            // Normalize message_type.
315            (file_tags::MESSAGE_TYPE, WireType::LengthDelimited) => {
316                let Ok(len) = decode_len(&mut cursor) else {
317                    return bytes.to_vec();
318                };
319                if cursor.len() < len {
320                    return bytes.to_vec();
321                }
322                let msg_bytes = &cursor[..len];
323                cursor = &cursor[len..];
324                let normalized = normalize_message_descriptor(
325                    msg_bytes,
326                    file_default_presence,
327                    file_default_encoding,
328                );
329                encode_key(file_tags::MESSAGE_TYPE, WireType::LengthDelimited, &mut out);
330                encode_varint(normalized.len() as u64, &mut out);
331                out.extend_from_slice(&normalized);
332            }
333            // Normalize top-level extension fields.
334            (file_tags::EXTENSION, WireType::LengthDelimited) => {
335                let Ok(len) = decode_len(&mut cursor) else {
336                    return bytes.to_vec();
337                };
338                if cursor.len() < len {
339                    return bytes.to_vec();
340                }
341                let field_bytes = &cursor[..len];
342                cursor = &cursor[len..];
343                let normalized = normalize_field_descriptor(
344                    field_bytes,
345                    file_default_presence,
346                    file_default_encoding,
347                );
348                encode_key(file_tags::EXTENSION, WireType::LengthDelimited, &mut out);
349                encode_varint(normalized.len() as u64, &mut out);
350                out.extend_from_slice(&normalized);
351            }
352            // Pass through all other fields unchanged.
353            _ => {
354                let pre = cursor;
355                if skip_wire_value_simple(&mut cursor, wire_type).is_err() {
356                    return bytes.to_vec();
357                }
358                encode_key(tag, wire_type, &mut out);
359                out.extend_from_slice(&pre[..pre.len() - cursor.len()]);
360            }
361        }
362    }
363
364    out
365}
366
367/// Normalize a `DescriptorProto` (message type).
368#[allow(clippy::too_many_lines)] // Wire-level rewriting requires sequential field processing.
369fn normalize_message_descriptor(
370    bytes: &[u8],
371    parent_presence: i32,
372    parent_encoding: i32,
373) -> Vec<u8> {
374    // Extract message-level feature overrides.
375    let msg_presence = extract_message_level_feature(bytes, feature_tags::FIELD_PRESENCE)
376        .unwrap_or(parent_presence);
377    let msg_encoding = extract_message_level_feature(bytes, feature_tags::MESSAGE_ENCODING)
378        .unwrap_or(parent_encoding);
379
380    let mut cursor = bytes;
381    let mut out = Vec::with_capacity(bytes.len());
382    let mut oneof_count = 0u32;
383
384    // First pass: count existing oneofs.
385    {
386        let mut scan = bytes;
387        while !scan.is_empty() {
388            let Ok((tag, wire_type)) = decode_key(&mut scan) else {
389                break;
390            };
391            if tag == message_tags::ONEOF_DECL && wire_type == WireType::LengthDelimited {
392                oneof_count += 1;
393            }
394            if skip_wire_value_simple(&mut scan, wire_type).is_err() {
395                break;
396            }
397        }
398    }
399
400    // We need to collect fields that need synthetic oneofs.
401    let mut fields_needing_synthetic_oneof = Vec::new();
402    let mut field_index = 0u32;
403
404    // Collect field info in a first pass.
405    {
406        let mut scan = bytes;
407        while !scan.is_empty() {
408            let Ok((tag, wire_type)) = decode_key(&mut scan) else {
409                break;
410            };
411            if tag == message_tags::FIELD && wire_type == WireType::LengthDelimited {
412                let Ok(len) = decode_len(&mut scan) else {
413                    break;
414                };
415                if scan.len() < len {
416                    break;
417                }
418                let field_bytes = &scan[..len];
419                scan = &scan[len..];
420                let info = analyze_field(field_bytes, msg_presence, msg_encoding);
421                if info.needs_proto3_optional && !info.has_oneof_index {
422                    fields_needing_synthetic_oneof.push((field_index, info.name.clone()));
423                }
424                field_index += 1;
425            } else if skip_wire_value_simple(&mut scan, wire_type).is_err() {
426                break;
427            }
428        }
429    }
430
431    // Build a map of field_index → synthetic oneof index.
432    let mut synthetic_oneof_map: HashMap<u32, u32> = HashMap::new();
433    for (i, (fi, _)) in fields_needing_synthetic_oneof.iter().enumerate() {
434        #[allow(clippy::cast_possible_truncation)] // Oneof index fits in u32.
435        let idx = i as u32;
436        synthetic_oneof_map.insert(*fi, oneof_count + idx);
437    }
438
439    // Second pass: rewrite.
440    field_index = 0;
441    while !cursor.is_empty() {
442        let Ok((tag, wire_type)) = decode_key(&mut cursor) else {
443            return bytes.to_vec();
444        };
445
446        match (tag, wire_type) {
447            (message_tags::FIELD, WireType::LengthDelimited) => {
448                let Ok(len) = decode_len(&mut cursor) else {
449                    return bytes.to_vec();
450                };
451                if cursor.len() < len {
452                    return bytes.to_vec();
453                }
454                let field_bytes = &cursor[..len];
455                cursor = &cursor[len..];
456                let synthetic_oneof = synthetic_oneof_map.get(&field_index).copied();
457                let normalized = normalize_field_descriptor_with_oneof(
458                    field_bytes,
459                    msg_presence,
460                    msg_encoding,
461                    synthetic_oneof,
462                );
463                encode_key(message_tags::FIELD, WireType::LengthDelimited, &mut out);
464                encode_varint(normalized.len() as u64, &mut out);
465                out.extend_from_slice(&normalized);
466                field_index += 1;
467            }
468            (message_tags::NESTED_TYPE, WireType::LengthDelimited) => {
469                let Ok(len) = decode_len(&mut cursor) else {
470                    return bytes.to_vec();
471                };
472                if cursor.len() < len {
473                    return bytes.to_vec();
474                }
475                let nested_bytes = &cursor[..len];
476                cursor = &cursor[len..];
477                let normalized =
478                    normalize_message_descriptor(nested_bytes, msg_presence, msg_encoding);
479                encode_key(
480                    message_tags::NESTED_TYPE,
481                    WireType::LengthDelimited,
482                    &mut out,
483                );
484                encode_varint(normalized.len() as u64, &mut out);
485                out.extend_from_slice(&normalized);
486            }
487            (message_tags::EXTENSION, WireType::LengthDelimited) => {
488                let Ok(len) = decode_len(&mut cursor) else {
489                    return bytes.to_vec();
490                };
491                if cursor.len() < len {
492                    return bytes.to_vec();
493                }
494                let ext_bytes = &cursor[..len];
495                cursor = &cursor[len..];
496                let normalized = normalize_field_descriptor(ext_bytes, msg_presence, msg_encoding);
497                encode_key(message_tags::EXTENSION, WireType::LengthDelimited, &mut out);
498                encode_varint(normalized.len() as u64, &mut out);
499                out.extend_from_slice(&normalized);
500            }
501            _ => {
502                let pre = cursor;
503                if skip_wire_value_simple(&mut cursor, wire_type).is_err() {
504                    return bytes.to_vec();
505                }
506                encode_key(tag, wire_type, &mut out);
507                out.extend_from_slice(&pre[..pre.len() - cursor.len()]);
508            }
509        }
510    }
511
512    // Append synthetic OneofDescriptorProto entries for proto3_optional fields.
513    for (_, name) in &fields_needing_synthetic_oneof {
514        let oneof_name = format!("_{name}");
515        let mut oneof_bytes = Vec::new();
516        // OneofDescriptorProto.name (tag 1)
517        encode_key(1, WireType::LengthDelimited, &mut oneof_bytes);
518        encode_varint(oneof_name.len() as u64, &mut oneof_bytes);
519        oneof_bytes.extend_from_slice(oneof_name.as_bytes());
520
521        encode_key(
522            message_tags::ONEOF_DECL,
523            WireType::LengthDelimited,
524            &mut out,
525        );
526        encode_varint(oneof_bytes.len() as u64, &mut out);
527        out.extend_from_slice(&oneof_bytes);
528    }
529
530    out
531}
532
533/// Extract a `FeatureSet` field value from message-level options.
534///
535/// Scans `DescriptorProto` bytes for the `MessageOptions` submessage (tag 7),
536/// then reads the specified feature field. Returns `None` if not found.
537fn extract_message_level_feature(bytes: &[u8], feature_field_tag: u32) -> Option<i32> {
538    let mut cursor = bytes;
539    while !cursor.is_empty() {
540        let Ok((tag, wire_type)) = decode_key(&mut cursor) else {
541            break;
542        };
543        match (tag, wire_type) {
544            // MessageOptions is tag 7 in DescriptorProto.
545            (7, WireType::LengthDelimited) => {
546                let Ok(len) = decode_len(&mut cursor) else {
547                    break;
548                };
549                if cursor.len() < len {
550                    break;
551                }
552                let options_bytes = &cursor[..len];
553                cursor = &cursor[len..];
554                let val =
555                    extract_feature_varint(options_bytes, option_tags::FEATURES, feature_field_tag);
556                if val != 0 {
557                    return Some(val);
558                }
559            }
560            _ => {
561                if skip_wire_value_simple(&mut cursor, wire_type).is_err() {
562                    break;
563                }
564            }
565        }
566    }
567    None
568}
569
/// Facts about one `FieldDescriptorProto`, as gathered by `analyze_field`.
#[allow(clippy::struct_excessive_bools)] // Wire-format analysis produces independent boolean flags.
struct FieldInfo {
    /// Field name, decoded lossily as UTF-8; empty when no name was found.
    name: String,
    /// The field has explicit presence but is not repeated, not in a oneof,
    /// not message/group typed, and not already flagged `proto3_optional`.
    needs_proto3_optional: bool,
    /// A `oneof_index` entry is present on the field.
    has_oneof_index: bool,
    /// The field is message-typed and its effective encoding is `DELIMITED`.
    is_delimited: bool,
    /// The field's effective presence is `LEGACY_REQUIRED`.
    is_legacy_required: bool,
}
578
579#[allow(clippy::too_many_lines, clippy::cast_possible_truncation)]
580// Protobuf field metadata values fit in i32.
581fn analyze_field(bytes: &[u8], parent_presence: i32, parent_encoding: i32) -> FieldInfo {
582    let mut cursor = bytes;
583    let mut name = String::new();
584    let mut label = 0i32;
585    let mut field_type = 0i32;
586    let mut has_oneof_index = false;
587    let mut field_presence = 0i32;
588    let mut field_encoding = 0i32;
589    let mut has_proto3_optional = false;
590
591    while !cursor.is_empty() {
592        let Ok((tag, wire_type)) = decode_key(&mut cursor) else {
593            break;
594        };
595        match (tag, wire_type) {
596            (field_tags::NAME, WireType::LengthDelimited) => {
597                let Ok(len) = decode_len(&mut cursor) else {
598                    break;
599                };
600                if cursor.len() < len {
601                    break;
602                }
603                name = String::from_utf8_lossy(&cursor[..len]).to_string();
604                cursor = &cursor[len..];
605            }
606            (field_tags::LABEL, WireType::Varint) => {
607                let Ok(v) = decode_varint(&mut cursor) else {
608                    break;
609                };
610                label = v as i32;
611            }
612            (field_tags::TYPE, WireType::Varint) => {
613                let Ok(v) = decode_varint(&mut cursor) else {
614                    break;
615                };
616                field_type = v as i32;
617            }
618            (field_tags::ONEOF_INDEX, WireType::Varint) => {
619                let Ok(_) = decode_varint(&mut cursor) else {
620                    break;
621                };
622                has_oneof_index = true;
623            }
624            (field_tags::PROTO3_OPTIONAL, WireType::Varint) => {
625                let Ok(v) = decode_varint(&mut cursor) else {
626                    break;
627                };
628                has_proto3_optional = v != 0;
629            }
630            (field_tags::OPTIONS, WireType::LengthDelimited) => {
631                let Ok(len) = decode_len(&mut cursor) else {
632                    break;
633                };
634                if cursor.len() < len {
635                    break;
636                }
637                let options = &cursor[..len];
638                field_presence = extract_feature_varint(
639                    options,
640                    field_option_tags::FEATURES,
641                    feature_tags::FIELD_PRESENCE,
642                );
643                field_encoding = extract_feature_varint(
644                    options,
645                    field_option_tags::FEATURES,
646                    feature_tags::MESSAGE_ENCODING,
647                );
648                cursor = &cursor[len..];
649            }
650            _ => {
651                if skip_wire_value_simple(&mut cursor, wire_type).is_err() {
652                    break;
653                }
654            }
655        }
656    }
657
658    let effective_presence = if field_presence != 0 {
659        field_presence
660    } else {
661        parent_presence
662    };
663
664    // Determine if this field needs proto3_optional.
665    let is_repeated = label == LABEL_REPEATED;
666    let is_message = field_type == TYPE_MESSAGE || field_type == TYPE_GROUP;
667    let needs_proto3_optional = !has_proto3_optional
668        && !is_repeated
669        && !has_oneof_index
670        && effective_presence == FIELD_PRESENCE_EXPLICIT
671        && !is_message;
672
673    // Determine if this message field uses DELIMITED (group) encoding.
674    let effective_encoding = if field_encoding != 0 {
675        field_encoding
676    } else {
677        parent_encoding
678    };
679    let is_delimited =
680        field_type == TYPE_MESSAGE && effective_encoding == MESSAGE_ENCODING_DELIMITED;
681    let is_legacy_required = effective_presence == FIELD_PRESENCE_LEGACY_REQUIRED;
682
683    FieldInfo {
684        name,
685        needs_proto3_optional,
686        has_oneof_index,
687        is_delimited,
688        is_legacy_required,
689    }
690}
691
692/// Normalize a `FieldDescriptorProto` (simple version, no synthetic oneof).
693fn normalize_field_descriptor(bytes: &[u8], parent_presence: i32, parent_encoding: i32) -> Vec<u8> {
694    normalize_field_descriptor_with_oneof(bytes, parent_presence, parent_encoding, None)
695}
696
697#[allow(clippy::cast_possible_truncation)] // Protobuf field metadata values fit in i32.
698fn normalize_field_descriptor_with_oneof(
699    bytes: &[u8],
700    parent_presence: i32,
701    parent_encoding: i32,
702    synthetic_oneof_index: Option<u32>,
703) -> Vec<u8> {
704    let info = analyze_field(bytes, parent_presence, parent_encoding);
705
706    let mut cursor = bytes;
707    let mut out = Vec::with_capacity(bytes.len() + 8);
708
709    while !cursor.is_empty() {
710        let Ok((tag, wire_type)) = decode_key(&mut cursor) else {
711            return bytes.to_vec();
712        };
713
714        match (tag, wire_type) {
715            // Rewrite label for LEGACY_REQUIRED.
716            (field_tags::LABEL, WireType::Varint) => {
717                let Ok(v) = decode_varint(&mut cursor) else {
718                    return bytes.to_vec();
719                };
720                encode_key(field_tags::LABEL, WireType::Varint, &mut out);
721                if info.is_legacy_required && v as i32 == LABEL_OPTIONAL {
722                    encode_varint(LABEL_REQUIRED as u64, &mut out);
723                } else {
724                    encode_varint(v, &mut out);
725                }
726            }
727            // Rewrite TYPE_MESSAGE → TYPE_GROUP for DELIMITED encoding.
728            (field_tags::TYPE, WireType::Varint) => {
729                let Ok(v) = decode_varint(&mut cursor) else {
730                    return bytes.to_vec();
731                };
732                encode_key(field_tags::TYPE, WireType::Varint, &mut out);
733                if info.is_delimited && v as i32 == TYPE_MESSAGE {
734                    encode_varint(TYPE_GROUP as u64, &mut out);
735                } else {
736                    encode_varint(v, &mut out);
737                }
738            }
739            // Pass through other fields.
740            _ => {
741                let pre = cursor;
742                if skip_wire_value_simple(&mut cursor, wire_type).is_err() {
743                    return bytes.to_vec();
744                }
745                encode_key(tag, wire_type, &mut out);
746                out.extend_from_slice(&pre[..pre.len() - cursor.len()]);
747            }
748        }
749    }
750
751    // Add proto3_optional if needed.
752    if (info.needs_proto3_optional || synthetic_oneof_index.is_some())
753        && !has_field_tag(bytes, field_tags::PROTO3_OPTIONAL)
754    {
755        encode_key(field_tags::PROTO3_OPTIONAL, WireType::Varint, &mut out);
756        encode_varint(1, &mut out);
757    }
758
759    // Add synthetic oneof_index if needed.
760    if let Some(idx) = synthetic_oneof_index {
761        if !info.has_oneof_index {
762            encode_key(field_tags::ONEOF_INDEX, WireType::Varint, &mut out);
763            encode_varint(u64::from(idx), &mut out);
764        }
765    }
766
767    out
768}
769
770/// Check if a message has a specific tag.
771fn has_field_tag(bytes: &[u8], target_tag: u32) -> bool {
772    let mut cursor = bytes;
773    while !cursor.is_empty() {
774        let Ok((tag, wire_type)) = decode_key(&mut cursor) else {
775            return false;
776        };
777        if tag == target_tag {
778            return true;
779        }
780        if skip_wire_value_simple(&mut cursor, wire_type).is_err() {
781            return false;
782        }
783    }
784    false
785}
786
787fn decode_len(cursor: &mut &[u8]) -> Result<usize, ()> {
788    let v = decode_varint(cursor).map_err(|_| ())?;
789    usize::try_from(v).map_err(|_| ())
790}
791
792fn skip_wire_value_simple(cursor: &mut &[u8], wire_type: WireType) -> Result<(), ()> {
793    match wire_type {
794        WireType::Varint => {
795            decode_varint(cursor).map_err(|_| ())?;
796            Ok(())
797        }
798        WireType::LengthDelimited => {
799            let len = decode_len(cursor)?;
800            if cursor.len() < len {
801                return Err(());
802            }
803            *cursor = &cursor[len..];
804            Ok(())
805        }
806        WireType::ThirtyTwoBit => {
807            if cursor.len() < 4 {
808                return Err(());
809            }
810            *cursor = &cursor[4..];
811            Ok(())
812        }
813        WireType::SixtyFourBit => {
814            if cursor.len() < 8 {
815                return Err(());
816            }
817            *cursor = &cursor[8..];
818            Ok(())
819        }
820        WireType::StartGroup => {
821            // Skip group contents until EndGroup.
822            loop {
823                let (inner_tag, inner_wt) = decode_key(cursor).map_err(|_| ())?;
824                if inner_wt == WireType::EndGroup {
825                    let _ = inner_tag;
826                    break;
827                }
828                skip_wire_value_simple(cursor, inner_wt)?;
829            }
830            Ok(())
831        }
832        WireType::EndGroup => Ok(()),
833    }
834}
835
#[cfg(test)]
mod tests {
    use super::normalize_edition_descriptor_set;
    use proptest::collection::vec;
    use proptest::prelude::*;

    proptest! {
        // Normalizing already-normalized output must leave it untouched,
        // whatever the input bytes are.
        #[test]
        fn normalization_is_idempotent_for_arbitrary_bytes(raw in vec(any::<u8>(), 0..2048)) {
            let first_pass = normalize_edition_descriptor_set(&raw);
            let second_pass = normalize_edition_descriptor_set(&first_pass);
            prop_assert_eq!(second_pass, first_pass);
        }
    }
}