1use crate::checker::CheckError;
7use crate::proof::{Proof, ProofNode, ProofNodeId};
8use rustc_hash::FxHashSet;
9use std::sync::Arc;
10
11pub type ParallelCheckResult<T> = Result<T, Vec<CheckError>>;
13
/// Tuning knobs for [`ParallelProcessor`] and [`ParallelStatsComputer`].
///
/// Start from [`ParallelConfig::new`] (equivalent to [`Default::default`]) and
/// adjust it with the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParallelConfig {
    /// Requested worker-thread count; `None` (the default) means no explicit
    /// request.
    ///
    /// NOTE(review): nothing in this module reads this field — confirm whether
    /// another module honours it.
    pub num_threads: Option<usize>,
    /// Number of proof nodes handled per chunk (default `100`).
    ///
    /// It is used as a slice chunk size, which must be non-zero;
    /// [`ParallelConfig::with_batch_size`] clamps `0` up to `1`.
    pub batch_size: usize,
    /// Whether progress reporting is requested (default `false`).
    ///
    /// NOTE(review): nothing in this module reads this field.
    pub report_progress: bool,
}

impl Default for ParallelConfig {
    fn default() -> Self {
        Self {
            num_threads: None,
            batch_size: 100,
            report_progress: false,
        }
    }
}

impl ParallelConfig {
    /// Creates a configuration with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the requested worker-thread count.
    #[must_use]
    pub fn with_threads(mut self, threads: usize) -> Self {
        self.num_threads = Some(threads);
        self
    }

    /// Sets the number of nodes processed per chunk.
    ///
    /// A `size` of `0` is stored as `1`: the value is used as a slice chunk
    /// size, and `slice::chunks(0)` panics.
    #[must_use]
    pub fn with_batch_size(mut self, size: usize) -> Self {
        self.batch_size = size.max(1);
        self
    }

    /// Enables or disables progress reporting.
    #[must_use]
    pub fn with_progress(mut self, enabled: bool) -> Self {
        self.report_progress = enabled;
        self
    }
}
59
60pub struct ParallelProcessor {
62 config: ParallelConfig,
63}
64
65impl Default for ParallelProcessor {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl ParallelProcessor {
72 pub fn new() -> Self {
74 Self {
75 config: ParallelConfig::default(),
76 }
77 }
78
79 pub fn with_config(config: ParallelConfig) -> Self {
81 Self { config }
82 }
83
84 pub fn check_proof_parallel(&self, proof: &Proof) -> ParallelCheckResult<()> {
89 let proof_arc = Arc::new(proof);
91
92 let node_ids: Vec<ProofNodeId> = proof.nodes().iter().map(|n| n.id).collect();
94
95 let errors: Vec<CheckError> = node_ids
97 .chunks(self.config.batch_size)
98 .flat_map(|chunk| {
99 chunk
100 .iter()
101 .filter_map(|&node_id| {
102 let proof_ref = Arc::clone(&proof_arc);
103 self.check_node_validity(proof_ref.as_ref(), node_id).err()
104 })
105 .collect::<Vec<_>>()
106 })
107 .collect();
108
109 if errors.is_empty() {
110 Ok(())
111 } else {
112 Err(errors)
113 }
114 }
115
116 pub fn validate_dependencies_parallel(&self, proof: &Proof) -> ParallelCheckResult<()> {
118 let proof_arc = Arc::new(proof);
119 let node_ids: Vec<ProofNodeId> = proof.nodes().iter().map(|n| n.id).collect();
120
121 let errors: Vec<CheckError> = node_ids
122 .chunks(self.config.batch_size)
123 .flat_map(|chunk| {
124 chunk
125 .iter()
126 .filter_map(|&node_id| {
127 let proof_ref = Arc::clone(&proof_arc);
128 self.check_node_dependencies(proof_ref.as_ref(), node_id)
129 .err()
130 })
131 .collect::<Vec<_>>()
132 })
133 .collect();
134
135 if errors.is_empty() {
136 Ok(())
137 } else {
138 Err(errors)
139 }
140 }
141
142 pub fn find_nodes_parallel<F>(&self, proof: &Proof, predicate: F) -> Vec<ProofNodeId>
144 where
145 F: Fn(&ProofNode) -> bool + Send + Sync,
146 {
147 let predicate_arc = Arc::new(predicate);
148 let nodes: Vec<&ProofNode> = proof.nodes().iter().collect();
149
150 nodes
151 .chunks(self.config.batch_size)
152 .flat_map(|chunk| {
153 chunk
154 .iter()
155 .filter_map(|node| {
156 let pred = Arc::clone(&predicate_arc);
157 if pred(node) { Some(node.id) } else { None }
158 })
159 .collect::<Vec<_>>()
160 })
161 .collect()
162 }
163
164 fn check_node_validity(&self, proof: &Proof, node_id: ProofNodeId) -> Result<(), CheckError> {
166 if let Some(_node) = proof.get_node(node_id) {
169 Ok(())
171 } else {
172 Err(CheckError::Custom(format!("Node {} not found", node_id)))
173 }
174 }
175
176 fn check_node_dependencies(
178 &self,
179 proof: &Proof,
180 node_id: ProofNodeId,
181 ) -> Result<(), CheckError> {
182 if let Some(node) = proof.get_node(node_id) {
183 if let crate::proof::ProofStep::Inference { premises, .. } = &node.step {
185 for &premise_id in premises.iter() {
186 if proof.get_node(premise_id).is_none() {
187 return Err(CheckError::Custom(format!(
188 "Premise {} not found for node {}",
189 premise_id, node_id
190 )));
191 }
192 }
193 }
194 Ok(())
195 } else {
196 Err(CheckError::Custom(format!("Node {} not found", node_id)))
197 }
198 }
199}
200
201pub struct ParallelStatsComputer {
203 config: ParallelConfig,
204}
205
206impl Default for ParallelStatsComputer {
207 fn default() -> Self {
208 Self::new()
209 }
210}
211
212impl ParallelStatsComputer {
213 pub fn new() -> Self {
215 Self {
216 config: ParallelConfig::default(),
217 }
218 }
219
220 pub fn with_config(config: ParallelConfig) -> Self {
222 Self { config }
223 }
224
225 pub fn compute_rule_frequency(&self, proof: &Proof) -> rustc_hash::FxHashMap<String, usize> {
227 let nodes: Vec<&ProofNode> = proof.nodes().iter().collect();
228 let mut frequency = rustc_hash::FxHashMap::default();
229
230 for chunk in nodes.chunks(self.config.batch_size) {
231 for node in chunk {
232 if let crate::proof::ProofStep::Inference { rule, .. } = &node.step {
233 *frequency.entry(rule.clone()).or_insert(0) += 1;
234 }
235 }
236 }
237
238 frequency
239 }
240
241 pub fn find_unique_conclusions(&self, proof: &Proof) -> FxHashSet<String> {
243 let nodes: Vec<&ProofNode> = proof.nodes().iter().collect();
244 let mut conclusions = FxHashSet::default();
245
246 for chunk in nodes.chunks(self.config.batch_size) {
247 for node in chunk {
248 conclusions.insert(node.conclusion().to_string());
249 }
250 }
251
252 conclusions
253 }
254
255 pub fn compute_depth_histogram(&self, proof: &Proof) -> rustc_hash::FxHashMap<usize, usize> {
257 let nodes: Vec<&ProofNode> = proof.nodes().iter().collect();
258 let mut histogram = rustc_hash::FxHashMap::default();
259
260 for chunk in nodes.chunks(self.config.batch_size) {
261 for node in chunk {
262 *histogram.entry(node.depth as usize).or_insert(0) += 1;
263 }
264 }
265
266 histogram
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a proof with axioms `x = x` and `y = y` plus one `resolution`
    /// inference that uses both as premises.
    fn axioms_with_resolution() -> Proof {
        let mut proof = Proof::new();
        let left = proof.add_axiom("x = x");
        let right = proof.add_axiom("y = y");
        proof.add_inference("resolution", vec![left, right], "x = x or y = y");
        proof
    }

    #[test]
    fn test_parallel_config_new() {
        let cfg = ParallelConfig::new();
        assert_eq!(cfg.num_threads, None);
        assert_eq!(cfg.batch_size, 100);
        assert!(!cfg.report_progress);
    }

    #[test]
    fn test_parallel_config_with_settings() {
        let cfg = ParallelConfig::new()
            .with_threads(4)
            .with_batch_size(50)
            .with_progress(true);
        assert_eq!(cfg.num_threads, Some(4));
        assert_eq!(cfg.batch_size, 50);
        assert!(cfg.report_progress);
    }

    #[test]
    fn test_parallel_processor_new() {
        assert_eq!(ParallelProcessor::new().config.batch_size, 100);
    }

    #[test]
    fn test_check_proof_parallel_empty() {
        let empty = Proof::new();
        assert!(ParallelProcessor::new().check_proof_parallel(&empty).is_ok());
    }

    #[test]
    fn test_validate_dependencies_parallel() {
        let mut proof = Proof::new();
        proof.add_axiom("x = x");
        assert!(ParallelProcessor::new()
            .validate_dependencies_parallel(&proof)
            .is_ok());
    }

    #[test]
    fn test_find_nodes_parallel() {
        let mut proof = Proof::new();
        let wanted = proof.add_axiom("x = x");
        let _other = proof.add_axiom("y = y");

        let found = ParallelProcessor::new().find_nodes_parallel(&proof, |n| n.id == wanted);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0], wanted);
    }

    #[test]
    fn test_parallel_stats_computer_new() {
        assert_eq!(ParallelStatsComputer::new().config.batch_size, 100);
    }

    #[test]
    fn test_compute_rule_frequency() {
        let proof = axioms_with_resolution();
        let freq = ParallelStatsComputer::new().compute_rule_frequency(&proof);
        assert!(freq.contains_key("resolution"));
        assert_eq!(*freq.get("resolution").expect("key should exist in map"), 1);
    }

    #[test]
    fn test_find_unique_conclusions() {
        let mut proof = Proof::new();
        for text in ["x = x", "y = y", "x = x"] {
            proof.add_axiom(text);
        }

        let conclusions = ParallelStatsComputer::new().find_unique_conclusions(&proof);
        assert_eq!(conclusions.len(), 2);
        assert!(conclusions.contains("x = x"));
        assert!(conclusions.contains("y = y"));
    }

    #[test]
    fn test_compute_depth_histogram() {
        let proof = axioms_with_resolution();
        let histogram = ParallelStatsComputer::new().compute_depth_histogram(&proof);
        assert!(histogram.contains_key(&0));
        assert!(histogram.contains_key(&1));
    }
}