Skip to main content

prototext_core/
instantiate.rs

1// SPDX-FileCopyrightText: 2025-2026 Frederic Ruget <fred@atlant.is> (GitHub: @douzebis)
2// SPDX-FileCopyrightText: 2025-2026 THALES CLOUD SECURISE SAS
3//
4// SPDX-License-Identifier: MIT
5
6//! Pseudo-random protobuf instance generator for `prototext instantiate-schema`.
7//!
8//! Implements spec 0056 §"prototext instantiate-schema".
9
10use prost::bytes::Bytes;
11use prost::Message as ProstMessage;
12use prost_reflect::{Cardinality, DynamicMessage, FieldDescriptor, Kind, MessageDescriptor, Value};
13use rand::rngs::StdRng;
14use rand::{Rng, SeedableRng};
15use sha2::{Digest, Sha256};
16
17// ── Well-known types left empty in v1 ────────────────────────────────────────
18
/// Fully-qualified names of well-known types the generator does not populate
/// yet; messages of these types are emitted empty (v1 limitation).
const UNSUPPORTED_WKTS: [&str; 4] = [
    "google.protobuf.Any",
    "google.protobuf.Struct",
    "google.protobuf.Value",
    "google.protobuf.ListValue",
];
25
26// ── Public entry point ────────────────────────────────────────────────────────
27
/// Generation parameters for [`generate_message_bytes`].
///
/// Start from [`InstantiateOpts::default`] and override individual fields with
/// struct-update syntax, e.g. `InstantiateOpts { seed: 7, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct InstantiateOpts {
    /// User-visible integer seed (recorded in `# seed:` comment).
    pub seed: i64,
    /// Maximum recursion depth for nested messages (default 4).
    pub max_depth: usize,
    /// Maximum number of elements for repeated fields, including the number
    /// of entries drawn for map fields (default 3).
    pub max_repeated: usize,
    /// Probability [0,1] of populating an optional field or a (non-synthetic)
    /// oneof (default 0.7).
    pub p_optional: f64,
    /// Suppress warnings to stderr.
    pub quiet: bool,
}

impl Default for InstantiateOpts {
    fn default() -> Self {
        Self {
            seed: 0,
            max_depth: 4,
            max_repeated: 3,
            p_optional: 0.7,
            quiet: false,
        }
    }
}
53
54/// Generate a pseudo-random binary protobuf for `descriptor`.
55///
56/// The effective PRNG seed is `SHA256("<seed>:<fqdn>")` → `StdRng::from_seed`.
57/// Returns the raw wire bytes.
58pub fn generate_message_bytes(descriptor: &MessageDescriptor, opts: &InstantiateOpts) -> Vec<u8> {
59    let fqdn = format!(".{}", descriptor.full_name());
60    let seed_input = format!("{}:{}", opts.seed, fqdn);
61    let hash = Sha256::digest(seed_input.as_bytes());
62    let mut seed_bytes = [0u8; 32];
63    seed_bytes.copy_from_slice(&hash);
64    let mut rng = StdRng::from_seed(seed_bytes);
65
66    let msg = generate_message(descriptor, &mut rng, 0, opts);
67    msg.encode_to_vec()
68}
69
70// ── Recursive message generator ───────────────────────────────────────────────
71
72fn generate_message(
73    descriptor: &MessageDescriptor,
74    rng: &mut StdRng,
75    depth: usize,
76    opts: &InstantiateOpts,
77) -> DynamicMessage {
78    let mut msg = DynamicMessage::new(descriptor.clone());
79
80    if UNSUPPORTED_WKTS.contains(&descriptor.full_name()) {
81        if !opts.quiet {
82            eprintln!(
83                "warning: leaving {} empty (unsupported WKT)",
84                descriptor.full_name()
85            );
86        }
87        return msg;
88    }
89
90    // Process oneofs first: for each oneof, decide once whether to populate it
91    // and which field to use.  Track which fields are covered by a oneof.
92    let mut oneof_field_numbers: std::collections::HashSet<u32> = std::collections::HashSet::new();
93
94    for oneof in descriptor.oneofs() {
95        // Skip synthetic oneofs (proto3 optional — treat the field as optional below).
96        if oneof.is_synthetic() {
97            continue;
98        }
99        // Mark all fields in this oneof as handled here.
100        for f in oneof.fields() {
101            oneof_field_numbers.insert(f.number());
102        }
103        // Decide whether to populate the oneof at all.
104        if rng.gen::<f64>() > opts.p_optional {
105            continue;
106        }
107        // Pick one member uniformly at random.
108        let fields: Vec<FieldDescriptor> = oneof.fields().collect();
109        let chosen = &fields[rng.gen_range(0..fields.len())];
110        if let Some(value) = generate_value(chosen, rng, depth, opts) {
111            msg.set_field(chosen, value);
112        }
113    }
114
115    // Process non-oneof fields.
116    for field in descriptor.fields() {
117        if oneof_field_numbers.contains(&field.number()) {
118            continue;
119        }
120
121        match field.cardinality() {
122            Cardinality::Required => {
123                if let Some(value) = generate_value(&field, rng, depth, opts) {
124                    msg.set_field(&field, value);
125                }
126            }
127            Cardinality::Repeated => {
128                let count = rng.gen_range(0..=opts.max_repeated);
129                if count == 0 {
130                    continue;
131                }
132                if field.is_map() {
133                    // Map fields require Value::Map, not Value::List.
134                    // The entry message has exactly two fields: key (1) and value (2).
135                    let entry_desc = match field.kind() {
136                        Kind::Message(m) => m,
137                        _ => continue,
138                    };
139                    let key_field = match entry_desc.get_field(1) {
140                        Some(f) => f,
141                        None => continue,
142                    };
143                    let val_field = match entry_desc.get_field(2) {
144                        Some(f) => f,
145                        None => continue,
146                    };
147                    let mut map = std::collections::HashMap::new();
148                    for _ in 0..count {
149                        let k = match generate_value(&key_field, rng, depth, opts) {
150                            Some(v) => v,
151                            None => continue,
152                        };
153                        let v = generate_value(&val_field, rng, depth, opts)
154                            .unwrap_or_else(|| Value::default_value_for_field(&val_field));
155                        if let Some(map_key) = k.into_map_key() {
156                            map.insert(map_key, v);
157                        }
158                    }
159                    if !map.is_empty() {
160                        msg.set_field(&field, Value::Map(map));
161                    }
162                } else {
163                    let values: Vec<Value> = (0..count)
164                        .filter_map(|_| generate_value(&field, rng, depth, opts))
165                        .collect();
166                    if !values.is_empty() {
167                        msg.set_field(&field, Value::List(values));
168                    }
169                }
170            }
171            Cardinality::Optional => {
172                if rng.gen::<f64>() <= opts.p_optional {
173                    if let Some(value) = generate_value(&field, rng, depth, opts) {
174                        msg.set_field(&field, value);
175                    }
176                }
177            }
178        }
179    }
180
181    msg
182}
183
184// ── Value generator ───────────────────────────────────────────────────────────
185
186fn generate_value(
187    field: &FieldDescriptor,
188    rng: &mut StdRng,
189    depth: usize,
190    opts: &InstantiateOpts,
191) -> Option<Value> {
192    match field.kind() {
193        Kind::Message(msg_desc) => {
194            if depth >= opts.max_depth {
195                return None;
196            }
197            // Map fields: treat the synthetic entry message as a single repeated
198            // message (count already handled by the caller for Repeated fields;
199            // here we just generate one entry message).
200            let nested = generate_message(&msg_desc, rng, depth + 1, opts);
201            Some(Value::Message(nested))
202        }
203        Kind::Enum(enum_desc) => {
204            let values: Vec<i32> = enum_desc.values().map(|v| v.number()).collect();
205            let idx = rng.gen_range(0..values.len());
206            Some(Value::EnumNumber(values[idx]))
207        }
208        Kind::Bool => Some(Value::Bool(rng.gen())),
209        Kind::String => Some(Value::String(format!("s{}", rng.gen_range(0u32..10000)))),
210        Kind::Bytes => {
211            let len = rng.gen_range(0..=8usize);
212            let b: Vec<u8> = (0..len).map(|_| rng.gen()).collect();
213            Some(Value::Bytes(Bytes::from(b)))
214        }
215        Kind::Float => Some(Value::F32(rng.gen_range(0.0f32..1000.0))),
216        Kind::Double => Some(Value::F64(rng.gen_range(0.0f64..1000.0))),
217        // All integer kinds.
218        Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => Some(Value::I32(rng.gen_range(0..=1000))),
219        Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => Some(Value::I64(rng.gen_range(0..=1000))),
220        Kind::Uint32 | Kind::Fixed32 => Some(Value::U32(rng.gen_range(0..=1000))),
221        Kind::Uint64 | Kind::Fixed64 => Some(Value::U64(rng.gen_range(0..=1000))),
222    }
223}