prototext_core/
instantiate.rs1use prost::bytes::Bytes;
11use prost::Message as ProstMessage;
12use prost_reflect::{Cardinality, DynamicMessage, FieldDescriptor, Kind, MessageDescriptor, Value};
13use rand::rngs::StdRng;
14use rand::{Rng, SeedableRng};
15use sha2::{Digest, Sha256};
16
/// Fully-qualified names of well-known types that the generator does not try to populate.
///
/// `generate_message` returns an empty instance for any message whose full name appears here
/// (printing a warning unless `InstantiateOpts::quiet` is set). Matching is on the name without
/// a leading dot.
const UNSUPPORTED_WKTS: &[&str] = &[
    // Type-erased wrapper.
    "google.protobuf.Any",
    // The JSON-like `Struct` / `Value` / `ListValue` family.
    "google.protobuf.Struct",
    "google.protobuf.Value",
    "google.protobuf.ListValue",
];
25
/// Options controlling how random protobuf messages are instantiated.
///
/// Construct with [`Default::default`] and override individual fields with struct-update
/// syntax, e.g. `InstantiateOpts { max_depth: 2, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct InstantiateOpts {
    /// Base seed. The RNG for a message is seeded with the SHA-256 of
    /// `"<seed>:.<message full name>"`, so each message type gets its own stream.
    pub seed: i64,
    /// Maximum nesting depth: message-typed fields of a message at depth `d` are only
    /// populated while `d < max_depth` (the root message is depth 0).
    pub max_depth: usize,
    /// Inclusive upper bound on how many elements a repeated or map field receives
    /// (the actual count is drawn uniformly from `0..=max_repeated`).
    pub max_repeated: usize,
    /// Probability of populating an optional field, and of setting one member of a
    /// (non-synthetic) oneof. Expected to lie in `0.0..=1.0`.
    pub p_optional: f64,
    /// When `true`, suppress warnings printed to stderr (e.g. for unsupported well-known
    /// types that are left empty).
    pub quiet: bool,
}

impl Default for InstantiateOpts {
    /// Moderate defaults: seed 0, depth 4, up to 3 repeated elements, 70% optional fill,
    /// warnings enabled.
    fn default() -> Self {
        InstantiateOpts {
            seed: 0,
            max_depth: 4,
            max_repeated: 3,
            p_optional: 0.7,
            quiet: false,
        }
    }
}
53
54pub fn generate_message_bytes(descriptor: &MessageDescriptor, opts: &InstantiateOpts) -> Vec<u8> {
59 let fqdn = format!(".{}", descriptor.full_name());
60 let seed_input = format!("{}:{}", opts.seed, fqdn);
61 let hash = Sha256::digest(seed_input.as_bytes());
62 let mut seed_bytes = [0u8; 32];
63 seed_bytes.copy_from_slice(&hash);
64 let mut rng = StdRng::from_seed(seed_bytes);
65
66 let msg = generate_message(descriptor, &mut rng, 0, opts);
67 msg.encode_to_vec()
68}
69
70fn generate_message(
73 descriptor: &MessageDescriptor,
74 rng: &mut StdRng,
75 depth: usize,
76 opts: &InstantiateOpts,
77) -> DynamicMessage {
78 let mut msg = DynamicMessage::new(descriptor.clone());
79
80 if UNSUPPORTED_WKTS.contains(&descriptor.full_name()) {
81 if !opts.quiet {
82 eprintln!(
83 "warning: leaving {} empty (unsupported WKT)",
84 descriptor.full_name()
85 );
86 }
87 return msg;
88 }
89
90 let mut oneof_field_numbers: std::collections::HashSet<u32> = std::collections::HashSet::new();
93
94 for oneof in descriptor.oneofs() {
95 if oneof.is_synthetic() {
97 continue;
98 }
99 for f in oneof.fields() {
101 oneof_field_numbers.insert(f.number());
102 }
103 if rng.gen::<f64>() > opts.p_optional {
105 continue;
106 }
107 let fields: Vec<FieldDescriptor> = oneof.fields().collect();
109 let chosen = &fields[rng.gen_range(0..fields.len())];
110 if let Some(value) = generate_value(chosen, rng, depth, opts) {
111 msg.set_field(chosen, value);
112 }
113 }
114
115 for field in descriptor.fields() {
117 if oneof_field_numbers.contains(&field.number()) {
118 continue;
119 }
120
121 match field.cardinality() {
122 Cardinality::Required => {
123 if let Some(value) = generate_value(&field, rng, depth, opts) {
124 msg.set_field(&field, value);
125 }
126 }
127 Cardinality::Repeated => {
128 let count = rng.gen_range(0..=opts.max_repeated);
129 if count == 0 {
130 continue;
131 }
132 if field.is_map() {
133 let entry_desc = match field.kind() {
136 Kind::Message(m) => m,
137 _ => continue,
138 };
139 let key_field = match entry_desc.get_field(1) {
140 Some(f) => f,
141 None => continue,
142 };
143 let val_field = match entry_desc.get_field(2) {
144 Some(f) => f,
145 None => continue,
146 };
147 let mut map = std::collections::HashMap::new();
148 for _ in 0..count {
149 let k = match generate_value(&key_field, rng, depth, opts) {
150 Some(v) => v,
151 None => continue,
152 };
153 let v = generate_value(&val_field, rng, depth, opts)
154 .unwrap_or_else(|| Value::default_value_for_field(&val_field));
155 if let Some(map_key) = k.into_map_key() {
156 map.insert(map_key, v);
157 }
158 }
159 if !map.is_empty() {
160 msg.set_field(&field, Value::Map(map));
161 }
162 } else {
163 let values: Vec<Value> = (0..count)
164 .filter_map(|_| generate_value(&field, rng, depth, opts))
165 .collect();
166 if !values.is_empty() {
167 msg.set_field(&field, Value::List(values));
168 }
169 }
170 }
171 Cardinality::Optional => {
172 if rng.gen::<f64>() <= opts.p_optional {
173 if let Some(value) = generate_value(&field, rng, depth, opts) {
174 msg.set_field(&field, value);
175 }
176 }
177 }
178 }
179 }
180
181 msg
182}
183
184fn generate_value(
187 field: &FieldDescriptor,
188 rng: &mut StdRng,
189 depth: usize,
190 opts: &InstantiateOpts,
191) -> Option<Value> {
192 match field.kind() {
193 Kind::Message(msg_desc) => {
194 if depth >= opts.max_depth {
195 return None;
196 }
197 let nested = generate_message(&msg_desc, rng, depth + 1, opts);
201 Some(Value::Message(nested))
202 }
203 Kind::Enum(enum_desc) => {
204 let values: Vec<i32> = enum_desc.values().map(|v| v.number()).collect();
205 let idx = rng.gen_range(0..values.len());
206 Some(Value::EnumNumber(values[idx]))
207 }
208 Kind::Bool => Some(Value::Bool(rng.gen())),
209 Kind::String => Some(Value::String(format!("s{}", rng.gen_range(0u32..10000)))),
210 Kind::Bytes => {
211 let len = rng.gen_range(0..=8usize);
212 let b: Vec<u8> = (0..len).map(|_| rng.gen()).collect();
213 Some(Value::Bytes(Bytes::from(b)))
214 }
215 Kind::Float => Some(Value::F32(rng.gen_range(0.0f32..1000.0))),
216 Kind::Double => Some(Value::F64(rng.gen_range(0.0f64..1000.0))),
217 Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => Some(Value::I32(rng.gen_range(0..=1000))),
219 Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => Some(Value::I64(rng.gen_range(0..=1000))),
220 Kind::Uint32 | Kind::Fixed32 => Some(Value::U32(rng.gen_range(0..=1000))),
221 Kind::Uint64 | Kind::Fixed64 => Some(Value::U64(rng.gen_range(0..=1000))),
222 }
223}