Skip to main content

provekit_common/witness/scheduling/
splitter.rs

1use {
2    crate::witness::{scheduling::DependencyInfo, WitnessBuilder},
3    std::{
4        collections::{HashMap, HashSet, VecDeque},
5        fmt,
6    },
7};
8
/// Validation failure reported while splitting witness builders into w1/w2.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SplitError {
    /// A public input's ACIR index has no corresponding witness builder.
    NoBuilderForPublicInput {
        /// ACIR index of the public input that lacks a builder.
        acir_idx: u32,
    },
    /// A public input's builder exists but was not placed in w1.
    PublicInputNotInW1 {
        /// ACIR index of the affected public input.
        acir_idx: u32,
        /// Index of the builder that produces that public input.
        builder_idx: usize,
    },
}
20
21impl fmt::Display for SplitError {
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        match self {
24            Self::NoBuilderForPublicInput { acir_idx } => {
25                write!(f, "no builder for public input ACIR index {acir_idx}")
26            }
27            Self::PublicInputNotInW1 {
28                acir_idx,
29                builder_idx,
30            } => {
31                write!(
32                    f,
33                    "public input ACIR index {acir_idx} (builder {builder_idx}) not in w1"
34                )
35            }
36        }
37    }
38}
39
// Marker impl: `SplitError` uses the default `std::error::Error` methods (it
// wraps no underlying source error); `Debug` is derived and `Display` is
// implemented above, which is all the trait requires.
impl std::error::Error for SplitError {}
41
/// Analyzes witness builder dependencies and splits them into w1/w2 groups.
///
/// Uses backward reachability from challenge consumers (lookup builders) to
/// identify which builders must be committed before challenge extraction (w1),
/// minimizing overhead. Balances witness counts between w1 and w2.
///
/// The builder slice is borrowed for `'a`; the dependency information is
/// computed once in `WitnessSplitter::new` and owned by the splitter.
pub struct WitnessSplitter<'a> {
    /// Builders being partitioned. Every builder index returned by the
    /// splitter is a position in this slice.
    witness_builders: &'a [WitnessBuilder],
    /// Dependency data derived from `witness_builders` via
    /// `DependencyInfo::new`.
    deps:             DependencyInfo,
}
51
52impl<'a> WitnessSplitter<'a> {
53    pub fn new(witness_builders: &'a [WitnessBuilder]) -> Self {
54        let deps = DependencyInfo::new(witness_builders);
55        Self {
56            witness_builders,
57            deps,
58        }
59    }
60
61    /// Identifies which builders should be in w1 (pre-challenge) vs w2
62    /// (post-challenge).
63    ///
64    /// Returns (w1_builder_indices, w2_builder_indices)
65    pub fn split_builders(
66        &self,
67        acir_public_inputs_indices_set: HashSet<u32>,
68    ) -> Result<(Vec<usize>, Vec<usize>), SplitError> {
69        let builder_count = self.witness_builders.len();
70
71        // Step 1: Find all Challenge builders
72        let challenge_builders: HashSet<usize> = self
73            .witness_builders
74            .iter()
75            .enumerate()
76            .filter_map(|(idx, builder)| {
77                matches!(builder, WitnessBuilder::Challenge(_)).then_some(idx)
78            })
79            .collect();
80
81        if challenge_builders.is_empty() {
82            let w1_indices = self.rearrange_w1(
83                (0..builder_count).collect(),
84                &acir_public_inputs_indices_set,
85            )?;
86            return Ok((w1_indices, Vec::new()));
87        }
88
89        // Step 2: Forward DFS from challenges to find mandatory_w2
90        // (all builders that transitively depend on challenge outputs)
91        // Also collect lookup builders (direct challenge consumers)
92        let mut mandatory_w2 = challenge_builders.clone();
93        let mut lookup_builders = HashSet::new();
94        let mut forward_visited = vec![false; builder_count];
95        let mut forward_stack = VecDeque::new();
96
97        for &challenge_idx in &challenge_builders {
98            forward_visited[challenge_idx] = true;
99            // Collect direct consumers as lookup builders
100            for &consumer_idx in &self.deps.adjacency_list[challenge_idx] {
101                lookup_builders.insert(consumer_idx);
102                if !forward_visited[consumer_idx] {
103                    forward_visited[consumer_idx] = true;
104                    mandatory_w2.insert(consumer_idx);
105                    forward_stack.push_back(consumer_idx);
106                }
107            }
108        }
109
110        // Continue DFS to find all transitive dependents
111        while let Some(current_idx) = forward_stack.pop_front() {
112            for &consumer_idx in &self.deps.adjacency_list[current_idx] {
113                if !forward_visited[consumer_idx] {
114                    forward_visited[consumer_idx] = true;
115                    mandatory_w2.insert(consumer_idx);
116                    forward_stack.push_back(consumer_idx);
117                }
118            }
119        }
120
121        // Step 4: Backward DFS from lookup builders to find mandatory_w1
122        // (exclude anything in mandatory_w2 to maintain disjoint sets)
123        let witness_producer = &self.deps.witness_producer;
124        let mut mandatory_w1 = HashSet::new();
125        let mut backward_visited = vec![false; builder_count];
126        let mut backward_stack = VecDeque::new();
127
128        for &lookup_idx in &lookup_builders {
129            backward_stack.push_back(lookup_idx);
130        }
131
132        while let Some(current_idx) = backward_stack.pop_front() {
133            if backward_visited[current_idx] {
134                continue;
135            }
136            backward_visited[current_idx] = true;
137
138            // Only add to w1 if not in mandatory_w2 (maintain disjoint)
139            if !mandatory_w2.contains(&current_idx)
140                && !challenge_builders.contains(&current_idx)
141                && !lookup_builders.contains(&current_idx)
142            {
143                mandatory_w1.insert(current_idx);
144            }
145
146            for &witness_idx in &self.deps.reads[current_idx] {
147                if let Some(&producer_idx) = witness_producer.get(&witness_idx) {
148                    if !backward_visited[producer_idx] && !mandatory_w2.contains(&producer_idx) {
149                        backward_stack.push_back(producer_idx);
150                    }
151                }
152            }
153        }
154
155        // witness_one (builder 0) must always be in w1 to preserve R1CS index 0
156        // invariant
157        mandatory_w1.insert(0);
158
159        // Step 5: Identify free builders (not in either mandatory set)
160        let mut free_builders = Vec::new();
161        for idx in 0..builder_count {
162            if !mandatory_w1.contains(&idx) && !mandatory_w2.contains(&idx) {
163                free_builders.push(idx);
164            }
165        }
166
167        // Step 6: Calculate witness counts for balancing
168        let mut w1_witness_count: usize = mandatory_w1
169            .iter()
170            .map(|&idx| DependencyInfo::extract_writes(&self.witness_builders[idx]).len())
171            .sum();
172
173        let mut w2_witness_count: usize = mandatory_w2
174            .iter()
175            .map(|&idx| DependencyInfo::extract_writes(&self.witness_builders[idx]).len())
176            .sum();
177
178        // Step 7: Assign free builders greedily while respecting dependencies
179        // Rule: if any dependency is in w2, the builder must also be in w2
180        // (because w1 is solved before w2)
181        // A free builder for public input witnesses goes in w1.
182        let mut w1_set = mandatory_w1;
183        let mut w2_set = mandatory_w2;
184
185        for idx in free_builders {
186            // Check if any dependency is in w2
187            let must_be_w2 = self.deps.reads[idx].iter().any(|&read_witness| {
188                self.deps
189                    .witness_producer
190                    .get(&read_witness)
191                    .map_or(false, |&producer| w2_set.contains(&producer))
192            });
193
194            let witness_count = DependencyInfo::extract_writes(&self.witness_builders[idx]).len();
195
196            // If free builder writes a public witness, add it to w1_set.
197            if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[idx] {
198                if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) {
199                    w1_set.insert(idx);
200                    w1_witness_count += witness_count;
201                    continue;
202                }
203            }
204
205            if must_be_w2 {
206                w2_set.insert(idx);
207                w2_witness_count += witness_count;
208            } else if w1_witness_count <= w2_witness_count {
209                w1_set.insert(idx);
210                w1_witness_count += witness_count;
211            } else {
212                w2_set.insert(idx);
213                w2_witness_count += witness_count;
214            }
215        }
216
217        // Step 8: Convert sets to sorted vectors
218        let mut w1_indices: Vec<usize> = w1_set.into_iter().collect();
219        let mut w2_indices: Vec<usize> = w2_set.into_iter().collect();
220
221        w1_indices = self.rearrange_w1(w1_indices, &acir_public_inputs_indices_set)?;
222        w2_indices.sort_unstable();
223
224        Ok((w1_indices, w2_indices))
225    }
226
227    /// Rearranges w1 builder indices into a canonical order:
228    /// 1. Constant builder (index 0) first, to preserve R1CS index 0 invariant
229    /// 2. Public input builders next, grouped together
230    /// 3. All other w1 builders last, sorted by index
231    fn rearrange_w1(
232        &self,
233        w1_indices: Vec<usize>,
234        acir_public_inputs_indices_set: &HashSet<u32>,
235    ) -> Result<Vec<usize>, SplitError> {
236        let mut public_input_builder_indices = Vec::new();
237        let mut rest_indices = Vec::new();
238
239        let w1_indices_set = w1_indices.iter().copied().collect::<HashSet<_>>();
240
241        // Build ACIR index -> builder index map for O(1) lookups (O(B) once)
242        let acir_to_builder: HashMap<u32, usize> = self
243            .witness_builders
244            .iter()
245            .enumerate()
246            .filter_map(|(builder_idx, builder)| {
247                if let WitnessBuilder::Acir(_, acir_idx) = builder {
248                    Some((*acir_idx as u32, builder_idx))
249                } else {
250                    None
251                }
252            })
253            .collect();
254
255        // Sanity check: all public inputs must have builders in w1 (O(P) lookups)
256        for &acir_idx in acir_public_inputs_indices_set {
257            // ACIR witness 0 is always the constant-one witness, handled
258            // separately via mandatory_w1.insert(0) above — not a regular ACIR witness.
259            if acir_idx == 0 {
260                continue;
261            }
262            match acir_to_builder.get(&acir_idx) {
263                Some(&builder_idx) if w1_indices_set.contains(&builder_idx) => {}
264                Some(&builder_idx) => {
265                    return Err(SplitError::PublicInputNotInW1 {
266                        acir_idx,
267                        builder_idx,
268                    })
269                }
270                None => return Err(SplitError::NoBuilderForPublicInput { acir_idx }),
271            }
272        }
273
274        // Separate into: 0, public inputs, and rest
275        for builder_idx in w1_indices {
276            if builder_idx == 0 {
277                continue; // Will add 0 first
278            } else if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[builder_idx] {
279                if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) {
280                    public_input_builder_indices.push(builder_idx);
281                    continue;
282                }
283            }
284            rest_indices.push(builder_idx);
285        }
286
287        // Sort public input builders by ACIR index to guarantee the proof's
288        // public inputs appear in ABI parameter order. Without this, HashSet
289        // iteration order (random per process) would produce non-deterministic
290        // public input ordering across different `prepare` invocations.
291        public_input_builder_indices.sort_unstable_by_key(|&builder_idx| {
292            match &self.witness_builders[builder_idx] {
293                WitnessBuilder::Acir(_, acir_idx) => *acir_idx as u32,
294                _ => u32::MAX,
295            }
296        });
297        rest_indices.sort_unstable();
298
299        // Reorder: 0 first, then public inputs (in ACIR index order), then rest
300        let mut new_w1_indices = vec![0];
301        new_w1_indices.extend(public_input_builder_indices);
302        new_w1_indices.extend(rest_indices);
303        Ok(new_w1_indices)
304    }
305}