provekit_common/witness/scheduling/
splitter.rs1use {
2 crate::witness::{scheduling::DependencyInfo, WitnessBuilder},
3 std::{
4 collections::{HashMap, HashSet, VecDeque},
5 fmt,
6 },
7};
8
/// Reasons why [`WitnessSplitter::split_builders`] can fail.
///
/// Both variants concern public-input witness builders, which the splitter
/// requires to end up in `w1`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SplitError {
    /// A public-input ACIR index has no matching `WitnessBuilder::Acir` builder.
    NoBuilderForPublicInput { acir_idx: u32 },
    /// The builder of a public input was not placed in `w1`.
    PublicInputNotInW1 {
        /// ACIR witness index of the offending public input.
        acir_idx: u32,
        /// Index of the builder associated with that ACIR index.
        builder_idx: usize,
    },
}

impl fmt::Display for SplitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Both variants only carry `Copy` fields, so matching on the value is cheap.
        match *self {
            SplitError::NoBuilderForPublicInput { acir_idx } => {
                write!(f, "no builder for public input ACIR index {acir_idx}")
            }
            SplitError::PublicInputNotInW1 {
                acir_idx,
                builder_idx,
            } => write!(
                f,
                "public input ACIR index {acir_idx} (builder {builder_idx}) not in w1"
            ),
        }
    }
}

/// `SplitError` wraps no underlying cause, so the default `source()` is kept.
impl std::error::Error for SplitError {}
41
42pub struct WitnessSplitter<'a> {
48 witness_builders: &'a [WitnessBuilder],
49 deps: DependencyInfo,
50}
51
52impl<'a> WitnessSplitter<'a> {
53 pub fn new(witness_builders: &'a [WitnessBuilder]) -> Self {
54 let deps = DependencyInfo::new(witness_builders);
55 Self {
56 witness_builders,
57 deps,
58 }
59 }
60
61 pub fn split_builders(
66 &self,
67 acir_public_inputs_indices_set: HashSet<u32>,
68 ) -> Result<(Vec<usize>, Vec<usize>), SplitError> {
69 let builder_count = self.witness_builders.len();
70
71 let challenge_builders: HashSet<usize> = self
73 .witness_builders
74 .iter()
75 .enumerate()
76 .filter_map(|(idx, builder)| {
77 matches!(builder, WitnessBuilder::Challenge(_)).then_some(idx)
78 })
79 .collect();
80
81 if challenge_builders.is_empty() {
82 let w1_indices = self.rearrange_w1(
83 (0..builder_count).collect(),
84 &acir_public_inputs_indices_set,
85 )?;
86 return Ok((w1_indices, Vec::new()));
87 }
88
89 let mut mandatory_w2 = challenge_builders.clone();
93 let mut lookup_builders = HashSet::new();
94 let mut forward_visited = vec![false; builder_count];
95 let mut forward_stack = VecDeque::new();
96
97 for &challenge_idx in &challenge_builders {
98 forward_visited[challenge_idx] = true;
99 for &consumer_idx in &self.deps.adjacency_list[challenge_idx] {
101 lookup_builders.insert(consumer_idx);
102 if !forward_visited[consumer_idx] {
103 forward_visited[consumer_idx] = true;
104 mandatory_w2.insert(consumer_idx);
105 forward_stack.push_back(consumer_idx);
106 }
107 }
108 }
109
110 while let Some(current_idx) = forward_stack.pop_front() {
112 for &consumer_idx in &self.deps.adjacency_list[current_idx] {
113 if !forward_visited[consumer_idx] {
114 forward_visited[consumer_idx] = true;
115 mandatory_w2.insert(consumer_idx);
116 forward_stack.push_back(consumer_idx);
117 }
118 }
119 }
120
121 let witness_producer = &self.deps.witness_producer;
124 let mut mandatory_w1 = HashSet::new();
125 let mut backward_visited = vec![false; builder_count];
126 let mut backward_stack = VecDeque::new();
127
128 for &lookup_idx in &lookup_builders {
129 backward_stack.push_back(lookup_idx);
130 }
131
132 while let Some(current_idx) = backward_stack.pop_front() {
133 if backward_visited[current_idx] {
134 continue;
135 }
136 backward_visited[current_idx] = true;
137
138 if !mandatory_w2.contains(¤t_idx)
140 && !challenge_builders.contains(¤t_idx)
141 && !lookup_builders.contains(¤t_idx)
142 {
143 mandatory_w1.insert(current_idx);
144 }
145
146 for &witness_idx in &self.deps.reads[current_idx] {
147 if let Some(&producer_idx) = witness_producer.get(&witness_idx) {
148 if !backward_visited[producer_idx] && !mandatory_w2.contains(&producer_idx) {
149 backward_stack.push_back(producer_idx);
150 }
151 }
152 }
153 }
154
155 mandatory_w1.insert(0);
158
159 let mut free_builders = Vec::new();
161 for idx in 0..builder_count {
162 if !mandatory_w1.contains(&idx) && !mandatory_w2.contains(&idx) {
163 free_builders.push(idx);
164 }
165 }
166
167 let mut w1_witness_count: usize = mandatory_w1
169 .iter()
170 .map(|&idx| DependencyInfo::extract_writes(&self.witness_builders[idx]).len())
171 .sum();
172
173 let mut w2_witness_count: usize = mandatory_w2
174 .iter()
175 .map(|&idx| DependencyInfo::extract_writes(&self.witness_builders[idx]).len())
176 .sum();
177
178 let mut w1_set = mandatory_w1;
183 let mut w2_set = mandatory_w2;
184
185 for idx in free_builders {
186 let must_be_w2 = self.deps.reads[idx].iter().any(|&read_witness| {
188 self.deps
189 .witness_producer
190 .get(&read_witness)
191 .map_or(false, |&producer| w2_set.contains(&producer))
192 });
193
194 let witness_count = DependencyInfo::extract_writes(&self.witness_builders[idx]).len();
195
196 if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[idx] {
198 if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) {
199 w1_set.insert(idx);
200 w1_witness_count += witness_count;
201 continue;
202 }
203 }
204
205 if must_be_w2 {
206 w2_set.insert(idx);
207 w2_witness_count += witness_count;
208 } else if w1_witness_count <= w2_witness_count {
209 w1_set.insert(idx);
210 w1_witness_count += witness_count;
211 } else {
212 w2_set.insert(idx);
213 w2_witness_count += witness_count;
214 }
215 }
216
217 let mut w1_indices: Vec<usize> = w1_set.into_iter().collect();
219 let mut w2_indices: Vec<usize> = w2_set.into_iter().collect();
220
221 w1_indices = self.rearrange_w1(w1_indices, &acir_public_inputs_indices_set)?;
222 w2_indices.sort_unstable();
223
224 Ok((w1_indices, w2_indices))
225 }
226
227 fn rearrange_w1(
232 &self,
233 w1_indices: Vec<usize>,
234 acir_public_inputs_indices_set: &HashSet<u32>,
235 ) -> Result<Vec<usize>, SplitError> {
236 let mut public_input_builder_indices = Vec::new();
237 let mut rest_indices = Vec::new();
238
239 let w1_indices_set = w1_indices.iter().copied().collect::<HashSet<_>>();
240
241 let acir_to_builder: HashMap<u32, usize> = self
243 .witness_builders
244 .iter()
245 .enumerate()
246 .filter_map(|(builder_idx, builder)| {
247 if let WitnessBuilder::Acir(_, acir_idx) = builder {
248 Some((*acir_idx as u32, builder_idx))
249 } else {
250 None
251 }
252 })
253 .collect();
254
255 for &acir_idx in acir_public_inputs_indices_set {
257 if acir_idx == 0 {
260 continue;
261 }
262 match acir_to_builder.get(&acir_idx) {
263 Some(&builder_idx) if w1_indices_set.contains(&builder_idx) => {}
264 Some(&builder_idx) => {
265 return Err(SplitError::PublicInputNotInW1 {
266 acir_idx,
267 builder_idx,
268 })
269 }
270 None => return Err(SplitError::NoBuilderForPublicInput { acir_idx }),
271 }
272 }
273
274 for builder_idx in w1_indices {
276 if builder_idx == 0 {
277 continue; } else if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[builder_idx] {
279 if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) {
280 public_input_builder_indices.push(builder_idx);
281 continue;
282 }
283 }
284 rest_indices.push(builder_idx);
285 }
286
287 public_input_builder_indices.sort_unstable_by_key(|&builder_idx| {
292 match &self.witness_builders[builder_idx] {
293 WitnessBuilder::Acir(_, acir_idx) => *acir_idx as u32,
294 _ => u32::MAX,
295 }
296 });
297 rest_indices.sort_unstable();
298
299 let mut new_w1_indices = vec![0];
301 new_w1_indices.extend(public_input_builder_indices);
302 new_w1_indices.extend(rest_indices);
303 Ok(new_w1_indices)
304 }
305}