Skip to main content

prototext_core/serialize/encode_text/
mod.rs

1// SPDX-FileCopyrightText: 2025-2026 Frederic Ruget <fred@atlant.is> (GitHub: @douzebis)
2// SPDX-FileCopyrightText: 2025-2026 THALES CLOUD SECURISE SAS
3//
4// SPDX-License-Identifier: MIT
5
6use crate::helpers::{write_varint_ohb, WT_END_GROUP, WT_LEN, WT_START_GROUP};
7use memchr::memrchr;
8
9mod encode_annotation;
10mod fields;
11mod frame;
12mod placeholder;
13
14#[cfg(test)]
15use encode_annotation::parse_field_decl_into;
16use encode_annotation::{parse_annotation, Ann};
17use fields::{encode_packed_elem, encode_scalar_line, write_tag_ohb_local};
18use frame::Frame;
19use placeholder::{compact, fill_placeholder, write_placeholder};
20
21// ── Helpers: field number and line classification ─────────────────────────────
22
23/// Extract the field number from the LHS of a line and/or annotation.
24///
25/// Precedence: annotation's field_decl (`= N`) > numeric LHS.
26#[inline]
27fn extract_field_number(lhs: &str, ann: &Ann<'_>) -> u64 {
28    if let Some(fn_) = ann.field_number {
29        return fn_;
30    }
31    lhs.trim().parse::<u64>().unwrap_or(0)
32}
33
34/// Split a line into `(value_part, annotation_str)`.
35///
36/// The separator is `  #@ ` (2 spaces + `#` + `@` + space).  We scan right-to-left
37/// so that quoted string values containing `  #@ ` don't confuse the split.
38#[inline]
39fn split_at_annotation(line: &str) -> (&str, &str) {
40    // Find the rightmost "  #@ " separator using SIMD-accelerated memrchr for '#',
41    // then verify the surrounding bytes.  Falls back leftward on false positives
42    // (a bare '#' inside a string value).
43    let b = line.as_bytes();
44    let mut end = b.len();
45    while let Some(p) = memrchr(b'#', &b[..end]) {
46        if p >= 2
47            && b[p - 1] == b' '
48            && b[p - 2] == b' '
49            && p + 2 < b.len()
50            && b[p + 1] == b'@'
51            && b[p + 2] == b' '
52        {
53            // "  #@ " confirmed: field part ends at p-2, annotation starts at p+3
54            return (&line[..p - 2], &line[p + 3..]);
55        }
56        // Also recognize a line whose non-whitespace content starts with "#@ "
57        // (comment-only annotation line, no value token before it).
58        if b[..p].iter().all(|c| *c == b' ' || *c == b'\t')
59            && p + 2 < b.len()
60            && b[p + 1] == b'@'
61            && b[p + 2] == b' '
62        {
63            return ("", &line[p + 3..]);
64        }
65        end = p; // keep searching leftward
66    }
67    (line, "")
68}
69
70// ── Public entry point ────────────────────────────────────────────────────────
71
72/// Decode a textual prototext byte string directly to binary wire bytes.
73///
74/// Input must start with `b"#@ prototext:"`.
75/// The line-by-line format must have been produced with `include_annotations=true`
76/// (the annotation comment on each line is required to reconstruct field numbers
77/// and types when field names are used on the LHS).
78///
79/// Implements Proposal F — Strategy C2 for MESSAGE frames.
80pub fn encode_text_to_binary(text: &[u8]) -> Vec<u8> {
81    let capacity = (text.len() / 6).max(64);
82    let mut out = Vec::with_capacity(capacity);
83
84    let mut stack: Vec<Frame> = Vec::new();
85    let mut first_placeholder: Option<usize> = None;
86    let mut last_placeholder: Option<usize> = None;
87
88    // ── Per-line packed state ─────────────────────────────────────────────────
89    // When non-None, we are buffering elements for a per-line packed record.
90    // `packed_field_number`: the field number of the active record.
91    // `packed_tag_ohb`: tag overhang for the record's wire tag.
92    // `packed_len_ohb`: length overhang for the record's LEN prefix.
93    // `packed_remaining`: how many more element lines to consume.
94    // `packed_payload`: accumulated payload bytes.
95    let mut packed_field_number: u64 = 0;
96    let mut packed_tag_ohb: Option<u64> = None;
97    let mut packed_len_ohb: Option<u64> = None;
98    let mut packed_remaining: usize = 0;
99    let mut packed_payload: Vec<u8> = Vec::new();
100
101    // The text is always valid ASCII (a subset of UTF-8).
102    let text_str = match std::str::from_utf8(text) {
103        Ok(s) => s,
104        Err(_) => return out,
105    };
106
107    let mut lines = text_str.lines();
108
109    // Skip the first line: "#@ prototext: protoc"
110    lines.next();
111
112    for line in lines {
113        let line = line.trim_end(); // strip trailing CR/spaces
114
115        if line.is_empty() {
116            continue;
117        }
118
119        // ── Close brace ───────────────────────────────────────────────────────
120        //
121        // Brace-folding may place multiple `}` on one line, separated by spaces
122        // (e.g. `}}` for indent_size=1, `} } }` for indent_size=2).  A close-
123        // brace line consists solely of `}` and space characters after the
124        // leading indentation.  Walk the trimmed line byte-by-byte and pop the
125        // stack once per `}` found.
126
127        let trimmed = line.trim_start();
128        if !trimmed.is_empty() && trimmed.bytes().all(|b| b == b'}' || b == b' ') {
129            for b in trimmed.bytes() {
130                if b == b'}' {
131                    match stack.pop() {
132                        Some(Frame::Message {
133                            placeholder_start,
134                            ohb,
135                            content_start,
136                            acw,
137                        }) => {
138                            let total_waste = fill_placeholder(
139                                &mut out,
140                                placeholder_start,
141                                ohb,
142                                content_start,
143                                acw,
144                            );
145                            // Propagate waste to parent frame.
146                            if let Some(parent) = stack.last_mut() {
147                                *parent.acw_mut() += total_waste;
148                            }
149                        }
150                        Some(Frame::Group {
151                            field_number,
152                            open_ended,
153                            mismatched_end,
154                            end_tag_ohb,
155                            acw,
156                        }) => {
157                            if !open_ended {
158                                let end_fn = mismatched_end.unwrap_or(field_number);
159                                write_tag_ohb_local(end_fn, WT_END_GROUP, end_tag_ohb, &mut out);
160                            }
161                            // Propagate accumulated waste from inner MESSAGE placeholders.
162                            if acw > 0 {
163                                if let Some(parent) = stack.last_mut() {
164                                    *parent.acw_mut() += acw;
165                                }
166                            }
167                        }
168                        None => { /* unmatched `}` — ignore */ }
169                    }
170                }
171            }
172            continue;
173        }
174
175        // Skip plain comments (`# ...`) that carry no wire semantics.
176        // Only `#@ ...` lines (handled via split_at_annotation) have meaning.
177        if trimmed.starts_with('#') && !trimmed.starts_with("#@") {
178            continue;
179        }
180
181        // Split value part from annotation.
182        let (value_part, ann_str) = split_at_annotation(line);
183
184        // ── Open brace ────────────────────────────────────────────────────────
185
186        // Detect `name {` (possibly indented, before the annotation).
187        let vp_trimmed = value_part.trim_end();
188        let is_open_brace = vp_trimmed.ends_with(" {") || vp_trimmed == "{";
189
190        if is_open_brace {
191            let ann = parse_annotation(ann_str);
192
193            // Extract the field name (LHS of `name {`).
194            let lhs = vp_trimmed.trim_start().trim_end_matches('{').trim_end();
195
196            let field_number = extract_field_number(lhs, &ann);
197            let tag_ohb = ann.tag_overhang_count;
198
199            if ann.wire_type == "group" {
200                write_tag_ohb_local(field_number, WT_START_GROUP, tag_ohb, &mut out);
201                stack.push(Frame::Group {
202                    field_number,
203                    open_ended: ann.open_ended_group,
204                    mismatched_end: ann.mismatched_group_end,
205                    end_tag_ohb: ann.end_tag_overhang_count,
206                    acw: 0,
207                });
208            } else {
209                // MESSAGE (wire type BYTES or unspecified).
210                write_tag_ohb_local(field_number, WT_LEN, tag_ohb, &mut out);
211                let ohb = ann.length_overhang_count.unwrap_or(0) as usize;
212                let (ph_start, content_start) =
213                    write_placeholder(&mut out, ohb, &mut first_placeholder, &mut last_placeholder);
214                stack.push(Frame::Message {
215                    placeholder_start: ph_start,
216                    ohb,
217                    content_start,
218                    acw: 0,
219                });
220            }
221            continue;
222        }
223
224        // ── Scalar field line ─────────────────────────────────────────────────
225
226        // Detect a comment-only annotation line (no LHS colon, starts with `#@ `).
227        // This is used for empty packed records: `pack_size: 0`.
228        let trimmed_vp = value_part.trim();
229        if trimmed_vp.is_empty() && !ann_str.is_empty() {
230            // Comment-only line — parse annotation to handle pack_size: 0.
231            let ann = parse_annotation(ann_str);
232            if let Some(0) = ann.pack_size {
233                // Empty packed record: emit tag + len=0.
234                write_tag_ohb_local(
235                    ann.field_number.unwrap_or(0),
236                    WT_LEN,
237                    ann.tag_overhang_count,
238                    &mut out,
239                );
240                write_varint_ohb(0, ann.length_overhang_count, &mut out);
241            }
242            continue;
243        }
244
245        // Find the colon separating LHS from value.
246        let Some(colon_pos) = value_part.find(':') else {
247            continue;
248        };
249        let lhs = value_part[..colon_pos].trim_start(); // may be indented
250        let value_str = value_part[colon_pos + 1..].trim();
251
252        let ann = parse_annotation(ann_str);
253        let field_number = extract_field_number(lhs, &ann);
254
255        // ── Per-line packed: continuation element ─────────────────────────────
256        if packed_remaining > 0 {
257            encode_packed_elem(value_str, &ann, &mut packed_payload);
258            packed_remaining -= 1;
259            if packed_remaining == 0 {
260                // Flush the completed wire record.
261                write_tag_ohb_local(packed_field_number, WT_LEN, packed_tag_ohb, &mut out);
262                write_varint_ohb(packed_payload.len() as u64, packed_len_ohb, &mut out);
263                out.extend_from_slice(&packed_payload);
264                packed_payload.clear();
265            }
266            continue;
267        }
268
269        // ── Per-line packed: first element (pack_size: N) ─────────────────────
270        if ann.is_packed {
271            if let Some(n) = ann.pack_size {
272                if n == 0 {
273                    // Empty record — emit immediately.
274                    write_tag_ohb_local(field_number, WT_LEN, ann.tag_overhang_count, &mut out);
275                    write_varint_ohb(0, ann.length_overhang_count, &mut out);
276                } else {
277                    // Start buffering.
278                    packed_field_number = field_number;
279                    packed_tag_ohb = ann.tag_overhang_count;
280                    packed_len_ohb = ann.length_overhang_count;
281                    packed_remaining = n - 1; // this line is element 0
282                    packed_payload.clear();
283                    encode_packed_elem(value_str, &ann, &mut packed_payload);
284                    if packed_remaining == 0 {
285                        // Single-element record — flush immediately.
286                        write_tag_ohb_local(packed_field_number, WT_LEN, packed_tag_ohb, &mut out);
287                        write_varint_ohb(packed_payload.len() as u64, packed_len_ohb, &mut out);
288                        out.extend_from_slice(&packed_payload);
289                        packed_payload.clear();
290                    }
291                }
292                continue;
293            }
294        }
295
296        encode_scalar_line(field_number, value_str, &ann, &mut out);
297    }
298
299    // ── Forward compaction pass ───────────────────────────────────────────────
300
301    if let Some(first_ph) = first_placeholder {
302        compact(&mut out, first_ph);
303    }
304
305    // Development instrumentation — size ratio
306    #[cfg(debug_assertions)]
307    {
308        let ratio = out.len() as f64 / text.len().max(1) as f64;
309        eprintln!(
310            "[encode_text] input_len={} output_len={} ratio={:.2}",
311            text.len(),
312            out.len(),
313            ratio
314        );
315    }
316
317    out
318}
319
320// ── Unit tests ────────────────────────────────────────────────────────────────
321
#[cfg(test)]
mod tests {
    use super::*;

    // ── split_at_annotation ───────────────────────────────────────────────────

    /// Check that `line` splits into exactly `field` and `ann`.
    fn assert_split(line: &str, field: &str, ann: &str) {
        let (got_field, got_ann) = split_at_annotation(line);
        assert_eq!(got_field, field);
        assert_eq!(got_ann, ann);
    }

    #[test]
    fn split_bare() {
        // No annotation at all: the line is returned untouched.
        assert_split("name: 42", "name: 42", "");
    }

    #[test]
    fn split_hash_at_space() {
        // Well-formed "  #@ " separator.
        assert_split("name: 42  #@ varint = 1", "name: 42", "varint = 1");
    }

    #[test]
    fn split_hash_only() {
        // Bare '#' without '@': not a separator.
        assert_split("name: 42  #", "name: 42  #", "");
    }

    #[test]
    fn split_hash_at_end() {
        // "#@" at end with no space after '@': not a separator.
        assert_split("name: 42  #@", "name: 42  #@", "");
    }

    #[test]
    fn split_hash_at_no_space() {
        // "#@x" — '@' not followed by space: not a separator.
        assert_split("name: 42  #@x", "name: 42  #@x", "");
    }

    // ── parse_field_decl_into — enum suffix forms ─────────────────────────────

    /// An annotation with every field unset / empty.
    fn blank_ann() -> Ann<'static> {
        Ann {
            wire_type: "",
            field_type: "",
            field_number: None,
            is_packed: false,
            tag_overhang_count: None,
            value_overhang_count: None,
            length_overhang_count: None,
            missing_bytes_count: None,
            mismatched_group_end: None,
            open_ended_group: false,
            end_tag_overhang_count: None,
            records_overhung_count: vec![],
            neg_int32_truncated: false,
            records_neg_int32_truncated: vec![],
            enum_scalar_value: None,
            enum_packed_values: vec![],
            nan_bits: None,
            pack_size: None,
            elem_ohb: None,
            elem_neg_trunc: false,
        }
    }

    /// Run `parse_field_decl_into` on `decl` against a blank annotation and
    /// return the result.
    fn parse_decl(decl: &'static str) -> Ann<'static> {
        let mut ann = blank_ann();
        parse_field_decl_into(decl, &mut ann);
        ann
    }

    #[test]
    fn parse_scalar_enum() {
        let parsed = parse_decl("Type(9) = 5");
        assert_eq!(parsed.field_type, "enum");
        assert_eq!(parsed.enum_scalar_value, Some(9));
        assert_eq!(parsed.field_number, Some(5));
    }

    #[test]
    fn parse_scalar_enum_neg() {
        let parsed = parse_decl("Color(-1) = 3");
        assert_eq!(parsed.field_type, "enum");
        assert_eq!(parsed.enum_scalar_value, Some(-1));
        assert_eq!(parsed.field_number, Some(3));
    }

    #[test]
    fn parse_packed_enum() {
        let parsed = parse_decl("Label([1, 2, 3]) [packed=true] = 4");
        assert_eq!(parsed.field_type, "enum");
        assert!(parsed.is_packed);
        assert_eq!(parsed.enum_packed_values, vec![1, 2, 3]);
        assert_eq!(parsed.field_number, Some(4));
    }

    #[test]
    fn parse_primitive_int32() {
        let parsed = parse_decl("int32 = 25");
        assert_eq!(parsed.field_type, "int32");
        assert_eq!(parsed.field_number, Some(25));
        assert_eq!(parsed.enum_scalar_value, None);
    }

    #[test]
    fn parse_enum_named_float() {
        // Latent-bug regression (spec 0004 §5.1): an enum whose type name
        // collides with the 'float' primitive must route to varint, not fixed32.
        let parsed = parse_decl("float(1) = 1");
        assert_eq!(
            parsed.field_type, "enum",
            "enum named 'float' must set field_type='enum', not 'float'"
        );
        assert_eq!(parsed.enum_scalar_value, Some(1));
    }

    // ── ENUM_UNKNOWN silencing ────────────────────────────────────────────────

    #[test]
    fn enum_unknown_encodes_correctly() {
        // A field annotated with ENUM_UNKNOWN must encode the varint from the
        // annotation's EnumType(N) suffix, not fail or produce wrong bytes.
        // Field 1, value 99 → tag 0x08 (field=1, wire=varint), varint 0x63.
        let encoded =
            encode_text_to_binary(b"#@ prototext: protoc\nkind: 99  #@ Type(99) = 1; ENUM_UNKNOWN\n");
        assert_eq!(
            encoded,
            vec![0x08, 0x63],
            "ENUM_UNKNOWN field 1 value 99: expected [0x08, 0x63]"
        );
    }
}