// Witness splitting: partitions witness builders into a pre-challenge group
// (w1) and a post-challenge group (w2).
use {
crate::witness::{scheduling::DependencyInfo, WitnessBuilder},
std::{
collections::{HashMap, HashSet, VecDeque},
fmt,
},
};
/// Reasons why validating a witness split can fail.
///
/// Both variants concern public inputs: each public input ACIR index needs a
/// builder, and that builder has to be part of the w1 group.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SplitError {
    /// None of the witness builders produces this public input ACIR index.
    NoBuilderForPublicInput {
        /// ACIR index of the public input that has no builder.
        acir_idx: u32,
    },
    /// The builder for a public input was assigned to w2 rather than w1.
    PublicInputNotInW1 {
        /// ACIR index of the affected public input.
        acir_idx: u32,
        /// Position of the offending builder in the builder list.
        builder_idx: usize,
    },
}
/// Human-readable messages for [`SplitError`], one per variant.
impl fmt::Display for SplitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // All fields are `Copy`, so matching on `*self` binds them by value.
        match *self {
            Self::PublicInputNotInW1 { acir_idx, builder_idx } => write!(
                f,
                "public input ACIR index {acir_idx} (builder {builder_idx}) not in w1"
            ),
            Self::NoBuilderForPublicInput { acir_idx } => {
                write!(f, "no builder for public input ACIR index {acir_idx}")
            }
        }
    }
}
// Marker impl with an empty body: every `Error` method keeps its default, and
// the message text comes from the `Display` impl for `SplitError`.
impl std::error::Error for SplitError {}
/// Partitions witness builders into a pre-challenge group (w1) and a
/// post-challenge group (w2) based on their dependencies.
///
/// Builders that consume challenge outputs, directly or transitively, have to
/// run after the challenge and therefore belong to w2. Whatever the direct
/// challenge consumers (lookup builders) read must already be committed, so
/// the builders reachable backwards from them belong to w1. Builders
/// constrained by neither rule are distributed so that the witness counts of
/// the two groups stay balanced.
pub struct WitnessSplitter<'a> {
    /// The builders being partitioned; indices produced by the splitter are
    /// positions in this slice.
    witness_builders: &'a [WitnessBuilder],
    /// Dependency information derived from `witness_builders`.
    deps: DependencyInfo,
}
impl<'a> WitnessSplitter<'a> {
/// Creates a splitter over `witness_builders`, computing their dependency
/// information up front with [`DependencyInfo::new`].
pub fn new(witness_builders: &'a [WitnessBuilder]) -> Self {
    Self {
        deps: DependencyInfo::new(witness_builders),
        witness_builders,
    }
}
/// Identifies which builders should be in w1 (pre-challenge) vs w2
/// (post-challenge).
///
/// The split is computed as follows:
/// 1. Every `Challenge` builder and everything that transitively consumes a
///    challenge output is forced into w2.
/// 2. Everything the direct challenge consumers (lookup builders) read,
///    transitively and outside of w2, is forced into w1, together with
///    builder 0.
/// 3. The remaining "free" builders are assigned in ascending index order:
///    builders of public-input ACIR witnesses go to w1, builders reading a
///    w2 output go to w2, and the rest go to whichever group currently holds
///    fewer witnesses (ties go to w1).
///
/// If there is no `Challenge` builder at all, every builder is placed in w1
/// and w2 is empty.
///
/// `acir_public_inputs_indices_set` holds the ACIR indices of the public
/// inputs.
///
/// Returns `(w1_builder_indices, w2_builder_indices)`. The w1 indices are in
/// the canonical order produced by `rearrange_w1`; the w2 indices are sorted
/// ascending.
///
/// # Errors
///
/// Propagates the [`SplitError`] returned by `rearrange_w1` when a public
/// input has no builder or its builder did not end up in w1.
pub fn split_builders(
    &self,
    acir_public_inputs_indices_set: HashSet<u32>,
) -> Result<(Vec<usize>, Vec<usize>), SplitError> {
    let builder_count = self.witness_builders.len();
    let witness_producer = &self.deps.witness_producer;
    // Number of witnesses written by the builder at `idx`; used for balancing.
    let written_by =
        |idx: usize| DependencyInfo::extract_writes(&self.witness_builders[idx]).len();

    // Step 1: find all Challenge builders.
    let challenge_builders: HashSet<usize> = self
        .witness_builders
        .iter()
        .enumerate()
        .filter_map(|(idx, builder)| {
            matches!(builder, WitnessBuilder::Challenge(_)).then_some(idx)
        })
        .collect();

    // Without challenges nothing has to wait: everything goes to w1.
    if challenge_builders.is_empty() {
        let w1_indices = self.rearrange_w1(
            (0..builder_count).collect(),
            &acir_public_inputs_indices_set,
        )?;
        return Ok((w1_indices, Vec::new()));
    }

    // Step 2: forward breadth-first traversal from the challenges to find
    // mandatory_w2 (the challenges plus every builder that transitively
    // depends on a challenge output). The direct consumers of challenges are
    // additionally recorded as lookup builders.
    let mut mandatory_w2 = challenge_builders.clone();
    let mut lookup_builders = HashSet::new();
    let mut forward_visited = vec![false; builder_count];
    let mut forward_queue = VecDeque::new();
    for &challenge_idx in &challenge_builders {
        forward_visited[challenge_idx] = true;
        forward_queue.push_back(challenge_idx);
        for &consumer_idx in &self.deps.adjacency_list[challenge_idx] {
            lookup_builders.insert(consumer_idx);
        }
    }
    while let Some(current_idx) = forward_queue.pop_front() {
        for &consumer_idx in &self.deps.adjacency_list[current_idx] {
            if !forward_visited[consumer_idx] {
                forward_visited[consumer_idx] = true;
                mandatory_w2.insert(consumer_idx);
                forward_queue.push_back(consumer_idx);
            }
        }
    }

    // Step 3: backward breadth-first traversal from the lookup builders to
    // find mandatory_w1. Anything in mandatory_w2 is neither added nor
    // expanded, which keeps the two sets disjoint. The challenges and lookup
    // builders are themselves members of mandatory_w2, so the single
    // membership test below excludes them as well.
    let mut mandatory_w1 = HashSet::new();
    let mut backward_visited = vec![false; builder_count];
    let mut backward_queue: VecDeque<usize> = lookup_builders.iter().copied().collect();
    while let Some(current_idx) = backward_queue.pop_front() {
        if backward_visited[current_idx] {
            continue;
        }
        backward_visited[current_idx] = true;
        if !mandatory_w2.contains(&current_idx) {
            mandatory_w1.insert(current_idx);
        }
        for &witness_idx in &self.deps.reads[current_idx] {
            if let Some(&producer_idx) = witness_producer.get(&witness_idx) {
                if !backward_visited[producer_idx] && !mandatory_w2.contains(&producer_idx) {
                    backward_queue.push_back(producer_idx);
                }
            }
        }
    }
    // witness_one (builder 0) must always be in w1 to preserve the R1CS
    // index 0 invariant.
    mandatory_w1.insert(0);

    // Step 4: identify free builders (in neither mandatory set), in ascending
    // index order.
    let free_builders: Vec<usize> = (0..builder_count)
        .filter(|idx| !mandatory_w1.contains(idx) && !mandatory_w2.contains(idx))
        .collect();

    // Step 5: witness counts of the mandatory sets, the starting point for
    // balancing.
    let mut w1_witness_count: usize = mandatory_w1.iter().map(|&idx| written_by(idx)).sum();
    let mut w2_witness_count: usize = mandatory_w2.iter().map(|&idx| written_by(idx)).sum();

    // Step 6: assign free builders greedily while respecting dependencies.
    // Rule: if any dependency is in w2, the builder must also be in w2
    // (because w1 is solved before w2). A free builder that writes a public
    // input witness always goes to w1, taking precedence over that rule.
    let mut w1_set = mandatory_w1;
    let mut w2_set = mandatory_w2;
    for idx in free_builders {
        let witness_count = written_by(idx);
        // NOTE(review): `as u32` truncates silently if the ACIR index type is
        // wider than 32 bits — confirm indices always fit.
        let writes_public_input = matches!(
            &self.witness_builders[idx],
            WitnessBuilder::Acir(_, acir_idx)
                if acir_public_inputs_indices_set.contains(&(*acir_idx as u32))
        );
        let must_be_w2 = self.deps.reads[idx].iter().any(|&read_witness| {
            witness_producer
                .get(&read_witness)
                .map_or(false, |&producer| w2_set.contains(&producer))
        });

        let goes_to_w1 =
            writes_public_input || (!must_be_w2 && w1_witness_count <= w2_witness_count);
        if goes_to_w1 {
            w1_set.insert(idx);
            w1_witness_count += witness_count;
        } else {
            w2_set.insert(idx);
            w2_witness_count += witness_count;
        }
    }

    // Step 7: convert the sets to vectors; w1 gets its canonical order (which
    // also validates the public inputs), w2 is sorted by builder index.
    let w1_indices = self.rearrange_w1(
        w1_set.into_iter().collect(),
        &acir_public_inputs_indices_set,
    )?;
    let mut w2_indices: Vec<usize> = w2_set.into_iter().collect();
    w2_indices.sort_unstable();
    Ok((w1_indices, w2_indices))
}
/// Validates the public inputs against the w1 group and puts the w1 builder
/// indices into their canonical order.
///
/// The returned vector is laid out as:
/// 1. index 0 (the constant builder), always emitted first to preserve the
///    R1CS index 0 invariant;
/// 2. the builders of public-input ACIR witnesses, ascending by ACIR index;
/// 3. all other w1 builders, ascending by builder index.
///
/// # Errors
///
/// * [`SplitError::NoBuilderForPublicInput`] if a non-zero public input ACIR
///   index has no `Acir` builder.
/// * [`SplitError::PublicInputNotInW1`] if such a builder exists but is not
///   contained in `w1_indices`.
fn rearrange_w1(
    &self,
    w1_indices: Vec<usize>,
    acir_public_inputs_indices_set: &HashSet<u32>,
) -> Result<Vec<usize>, SplitError> {
    let w1_members: HashSet<usize> = w1_indices.iter().copied().collect();

    // A single pass over the builders yields an ACIR index -> builder index
    // table, so each public input below costs one hash lookup.
    let mut acir_to_builder: HashMap<u32, usize> = HashMap::new();
    for (builder_idx, builder) in self.witness_builders.iter().enumerate() {
        if let WitnessBuilder::Acir(_, acir_idx) = builder {
            acir_to_builder.insert(*acir_idx as u32, builder_idx);
        }
    }

    // Sanity check: every public input needs a builder, and it must be in w1.
    // ACIR index 0 is exempt: it is treated as the constant-one witness, whose
    // builder (index 0) is unconditionally placed first in the result.
    for acir_idx in acir_public_inputs_indices_set
        .iter()
        .copied()
        .filter(|&acir_idx| acir_idx != 0)
    {
        let builder_idx = *acir_to_builder
            .get(&acir_idx)
            .ok_or(SplitError::NoBuilderForPublicInput { acir_idx })?;
        if !w1_members.contains(&builder_idx) {
            return Err(SplitError::PublicInputNotInW1 {
                acir_idx,
                builder_idx,
            });
        }
    }

    // Drop index 0 (it is re-added at the front) and split what is left into
    // public-input builders and everything else.
    let writes_public_input = |builder_idx: usize| {
        matches!(
            &self.witness_builders[builder_idx],
            WitnessBuilder::Acir(_, acir_idx)
                if acir_public_inputs_indices_set.contains(&(*acir_idx as u32))
        )
    };
    let (mut public_input_builders, mut other_builders): (Vec<usize>, Vec<usize>) = w1_indices
        .into_iter()
        .filter(|&builder_idx| builder_idx != 0)
        .partition(|&builder_idx| writes_public_input(builder_idx));

    // Public inputs are ordered by ACIR index so that their order in the proof
    // does not depend on the order in which the indices arrived here (which
    // may stem from HashSet iteration and thus vary between runs).
    public_input_builders.sort_unstable_by_key(|&builder_idx| {
        if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[builder_idx] {
            *acir_idx as u32
        } else {
            u32::MAX
        }
    });
    other_builders.sort_unstable();

    // Final layout: 0, then public inputs (by ACIR index), then the rest.
    Ok(std::iter::once(0)
        .chain(public_input_builders)
        .chain(other_builders)
        .collect())
}
}