1use crate::{
19 BarrierKind, Loop, LoopAttrs, LoopIR, LoopId, LoopMetadata, ReduceOp, Stmt, TripCount,
20};
21use rustc_hash::FxHashMap;
22use thiserror::Error;
23
24#[derive(Clone, Debug, Error)]
26pub enum ParallelError {
27 #[error("loop {loop_id:?} cannot be parallelized: {reason}")]
29 NotParallelizable {
30 loop_id: LoopId,
32 reason: String,
34 },
35
36 #[error("invalid chunk size {chunk_size} for trip count {trip_count}")]
38 InvalidChunkSize {
39 chunk_size: usize,
41 trip_count: usize,
43 },
44}
45
/// Tuning knobs shared by [`ParallelPass`] and the `Par*` helpers.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Number of workers that iterations are split across.
    pub worker_count: usize,
    /// A loop is only parallelized when its trip count reaches
    /// `worker_count * min_iterations_per_worker`.
    pub min_iterations_per_worker: usize,
    /// `true` makes the pass choose `ParallelStrategy::Static`;
    /// `false` makes it choose `ParallelStrategy::Dynamic`.
    pub deterministic: bool,
    /// Fixed chunk size; `0` means "derive it from trip count and worker count".
    pub chunk_size: usize,
}

impl Default for ParallelConfig {
    /// `num_cpus()` workers, at least 64 iterations per worker, deterministic
    /// scheduling and an automatically derived chunk size.
    fn default() -> Self {
        let worker_count = num_cpus();
        ParallelConfig {
            chunk_size: 0,
            deterministic: true,
            min_iterations_per_worker: 64,
            worker_count,
        }
    }
}

/// Worker count used by [`ParallelConfig::default`].
///
/// This is a fixed value, not a query of the host machine.
/// NOTE(review): a fixed value keeps the pass output independent of the
/// build host; confirm before switching to a runtime CPU query.
fn num_cpus() -> usize {
    const FIXED_WORKER_COUNT: usize = 8;
    FIXED_WORKER_COUNT
}
75
/// How the iterations of a parallel loop are scheduled onto workers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelStrategy {
    /// Static scheduling; chosen by the pass when `deterministic` is set.
    Static,
    /// Dynamic scheduling; chosen by the pass when `deterministic` is unset.
    Dynamic,
    /// Guided scheduling. Not selected by `ParallelPass` in this module.
    Guided,
}

/// Outcome of analysing one loop for parallelization.
#[derive(Debug, Clone)]
pub struct ParallelInfo {
    /// Whether the loop passed every analysis check.
    pub parallelizable: bool,
    /// Why the loop was rejected; `None` for parallelizable loops.
    pub reason: Option<String>,
    /// Iterations per chunk (left at 0 for rejected loops).
    pub chunk_size: usize,
    /// Number of chunks the trip count is split into (left at 0 for rejected loops).
    pub num_chunks: usize,
    /// Scheduling strategy selected for the loop.
    pub strategy: ParallelStrategy,
    /// Whether the loop carries the `REDUCTION` attribute.
    pub is_reduction: bool,
}

impl Default for ParallelInfo {
    /// A "not analyzed" placeholder: not parallelizable, static strategy, no chunks.
    fn default() -> Self {
        let reason = Some(String::from("not analyzed"));
        ParallelInfo {
            reason,
            parallelizable: false,
            strategy: ParallelStrategy::Static,
            chunk_size: 0,
            num_chunks: 0,
            is_reduction: false,
        }
    }
}
119
120pub struct ParallelPass {
122 config: ParallelConfig,
123 analysis: FxHashMap<LoopId, ParallelInfo>,
125}
126
127impl ParallelPass {
128 pub fn new(config: ParallelConfig) -> Self {
130 Self {
131 config,
132 analysis: FxHashMap::default(),
133 }
134 }
135
136 pub fn analyze(&mut self, ir: &LoopIR) -> FxHashMap<LoopId, ParallelInfo> {
138 self.analysis.clear();
139
140 for stmt in &ir.body.stmts {
141 self.analyze_stmt(stmt, &ir.loop_info);
142 }
143
144 self.analysis.clone()
145 }
146
147 fn analyze_stmt(&mut self, stmt: &Stmt, loop_info: &[LoopMetadata]) {
149 match stmt {
150 Stmt::Loop(lp) => {
151 let info = self.analyze_loop(lp, loop_info);
152 self.analysis.insert(lp.id, info);
153
154 for inner_stmt in &lp.body.stmts {
156 self.analyze_stmt(inner_stmt, loop_info);
157 }
158 }
159 _ => {}
160 }
161 }
162
163 fn analyze_loop(&self, lp: &Loop, loop_info: &[LoopMetadata]) -> ParallelInfo {
165 let mut info = ParallelInfo::default();
166
167 if !lp.attrs.contains(LoopAttrs::PARALLEL) {
169 info.reason = Some("loop not marked PARALLEL".to_string());
170 return info;
171 }
172
173 if !lp.attrs.contains(LoopAttrs::INDEPENDENT) {
175 info.reason = Some("loop has dependencies".to_string());
176 return info;
177 }
178
179 let metadata = loop_info.iter().find(|m| m.id == lp.id);
181 let trip_count = match metadata.map(|m| &m.trip_count) {
182 Some(TripCount::Static(n)) => *n,
183 Some(TripCount::Bounded(n)) => *n,
184 _ => {
185 info.reason = Some("dynamic trip count".to_string());
186 return info;
187 }
188 };
189
190 let min_total = self.config.worker_count * self.config.min_iterations_per_worker;
192 if trip_count < min_total {
193 info.reason = Some(format!(
194 "trip count {} below threshold {}",
195 trip_count, min_total
196 ));
197 return info;
198 }
199
200 let chunk_size = if self.config.chunk_size > 0 {
202 self.config.chunk_size
203 } else {
204 compute_chunk_size(trip_count, self.config.worker_count)
205 };
206
207 let is_reduction = lp.attrs.contains(LoopAttrs::REDUCTION);
209
210 info.parallelizable = true;
211 info.reason = None;
212 info.chunk_size = chunk_size;
213 info.num_chunks = (trip_count + chunk_size - 1) / chunk_size;
214 info.is_reduction = is_reduction;
215 info.strategy = if self.config.deterministic {
216 ParallelStrategy::Static
217 } else {
218 ParallelStrategy::Dynamic
219 };
220
221 info
222 }
223
224 pub fn parallelize(&self, ir: &mut LoopIR) -> Result<ParallelReport, ParallelError> {
226 let mut report = ParallelReport::default();
227
228 for stmt in &mut ir.body.stmts {
229 self.parallelize_stmt(stmt, &mut ir.loop_info, &mut report)?;
230 }
231
232 Ok(report)
233 }
234
235 fn parallelize_stmt(
237 &self,
238 stmt: &mut Stmt,
239 loop_info: &mut Vec<LoopMetadata>,
240 report: &mut ParallelReport,
241 ) -> Result<(), ParallelError> {
242 match stmt {
243 Stmt::Loop(lp) => {
244 if let Some(info) = self.analysis.get(&lp.id) {
245 if info.parallelizable {
246 self.parallelize_loop(lp, info, loop_info, report)?;
247 }
248 }
249 }
250 _ => {}
251 }
252 Ok(())
253 }
254
255 fn parallelize_loop(
257 &self,
258 lp: &mut Loop,
259 info: &ParallelInfo,
260 loop_info: &mut Vec<LoopMetadata>,
261 report: &mut ParallelReport,
262 ) -> Result<(), ParallelError> {
263 if let Some(meta) = loop_info.iter_mut().find(|m| m.id == lp.id) {
265 meta.parallel_chunk = Some(info.chunk_size);
266 }
267
268 if info.is_reduction {
270 self.parallelize_reduction(lp, info)?;
271 }
272
273 report.parallelized_loops.push(ParallelizedLoopInfo {
275 loop_id: lp.id,
276 chunk_size: info.chunk_size,
277 num_chunks: info.num_chunks,
278 strategy: info.strategy,
279 is_reduction: info.is_reduction,
280 });
281
282 Ok(())
283 }
284
285 fn parallelize_reduction(
287 &self,
288 lp: &mut Loop,
289 _info: &ParallelInfo,
290 ) -> Result<(), ParallelError> {
291 lp.body.push(Stmt::Barrier(BarrierKind::ThreadGroup));
297
298 Ok(())
299 }
300}
301
/// Per-worker chunk size that splits `trip_count` iterations over
/// `worker_count` workers, rounded up so every iteration is covered.
///
/// A `worker_count` of 0 is treated as a single worker instead of dividing by
/// zero. The rounding avoids computing `trip_count + worker_count - 1`, which
/// could overflow for very large trip counts.
fn compute_chunk_size(trip_count: usize, worker_count: usize) -> usize {
    let workers = worker_count.max(1);
    trip_count / workers + usize::from(trip_count % workers != 0)
}
308
309#[derive(Clone, Debug, Default)]
311pub struct ParallelReport {
312 pub parallelized_loops: Vec<ParallelizedLoopInfo>,
314 pub failed_loops: Vec<(LoopId, String)>,
316}
317
318impl ParallelReport {
319 pub fn any_parallelized(&self) -> bool {
321 !self.parallelized_loops.is_empty()
322 }
323
324 pub fn count(&self) -> usize {
326 self.parallelized_loops.len()
327 }
328}
329
330impl std::fmt::Display for ParallelReport {
331 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332 writeln!(f, "Parallelization Report")?;
333 writeln!(f, "======================")?;
334 writeln!(f, "Parallelized loops: {}", self.parallelized_loops.len())?;
335
336 for info in &self.parallelized_loops {
337 writeln!(
338 f,
339 " Loop {:?}: chunks={}, chunk_size={}, strategy={:?}, reduction={}",
340 info.loop_id, info.num_chunks, info.chunk_size, info.strategy, info.is_reduction
341 )?;
342 }
343
344 if !self.failed_loops.is_empty() {
345 writeln!(f, "\nFailed loops: {}", self.failed_loops.len())?;
346 for (id, reason) in &self.failed_loops {
347 writeln!(f, " Loop {:?}: {}", id, reason)?;
348 }
349 }
350
351 Ok(())
352 }
353}
354
355#[derive(Clone, Debug)]
357pub struct ParallelizedLoopInfo {
358 pub loop_id: LoopId,
360 pub chunk_size: usize,
362 pub num_chunks: usize,
364 pub strategy: ParallelStrategy,
366 pub is_reduction: bool,
368}
369
/// Half-open iteration range: `start`, `start + step`, ... stopping before `end`.
///
/// A positive `step` counts up towards `end` and a negative `step` counts down
/// towards it. A zero `step`, or a range pointing away from its bound, is empty.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Range {
    /// First value produced.
    pub start: i64,
    /// Exclusive bound.
    pub end: i64,
    /// Increment between consecutive values (may be negative).
    pub step: i64,
}

impl Range {
    /// Creates the unit-stride range `start..end`.
    pub fn new(start: i64, end: i64) -> Self {
        Self {
            start,
            end,
            step: 1,
        }
    }

    /// Creates a range with an explicit (possibly negative) stride.
    pub fn with_step(start: i64, end: i64, step: i64) -> Self {
        Self { start, end, step }
    }

    /// Number of values the range yields.
    ///
    /// Reversed ranges (e.g. `10..0` with a positive step) and zero-step ranges
    /// have length 0. The arithmetic is done in `i128`, so extreme `i64` bounds
    /// cannot overflow. Lengths beyond `usize::MAX` saturate.
    pub fn len(&self) -> usize {
        let (start, end, step) = (
            i128::from(self.start),
            i128::from(self.end),
            i128::from(self.step),
        );
        let (span, stride) = if step > 0 {
            (end - start, step)
        } else if step < 0 {
            (start - end, -step)
        } else {
            return 0;
        };
        if span <= 0 {
            // Casting a negative quotient to usize would yield a huge length.
            return 0;
        }
        // Ceiling division: a partial final stride still yields one value.
        let count = (span + stride - 1) / stride;
        usize::try_from(count).unwrap_or(usize::MAX)
    }

    /// `true` when the range yields no values.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Splits the range into at most `num_chunks` contiguous sub-ranges.
    ///
    /// Every chunk except possibly the last holds `ceil(len / num_chunks)`
    /// iterations. A short range therefore produces fewer than `num_chunks`
    /// chunks; for example, 9 iterations over 8 chunks yields 5 chunks. Each
    /// chunk starts where the previous one ended and keeps the original `step`.
    /// Returns an empty vector when `num_chunks == 0` or the range is empty.
    pub fn chunk(&self, num_chunks: usize) -> Vec<Range> {
        if num_chunks == 0 || self.is_empty() {
            return vec![];
        }

        let total = self.len();
        let chunk_size = total / num_chunks + usize::from(total % num_chunks != 0);

        // No more than `total` chunks can be produced; bounding the capacity
        // avoids a capacity-overflow panic for huge `num_chunks`.
        let mut chunks = Vec::with_capacity(num_chunks.min(total));
        let mut current = self.start;
        // Track what is left explicitly. After a short chunk, the expression
        // `total - i * chunk_size` would underflow.
        let mut remaining = total;

        while remaining > 0 {
            let chunk_iters = chunk_size.min(remaining);
            let chunk_end = Self::advance(current, chunk_iters, self.step);
            chunks.push(Range {
                start: current,
                end: chunk_end,
                step: self.step,
            });
            current = chunk_end;
            remaining -= chunk_iters;
        }

        chunks
    }

    /// Position reached after `iters` strides of `step` from `from`, clamped to
    /// the `i64` range.
    ///
    /// Clamping only affects the final chunk. Every value the range yields lies
    /// strictly between `i64::MIN` and `i64::MAX`, so the clamped value is still
    /// a correct exclusive bound.
    fn advance(from: i64, iters: usize, step: i64) -> i64 {
        let iters = i128::try_from(iters).unwrap_or(i128::MAX);
        let target = i128::from(from).saturating_add(iters.saturating_mul(i128::from(step)));
        i64::try_from(target).unwrap_or(if step < 0 { i64::MIN } else { i64::MAX })
    }
}
452
453#[derive(Clone, Debug)]
461pub struct ParFor {
462 pub range: Range,
464 pub config: ParallelConfig,
466}
467
468impl ParFor {
469 pub fn new(range: Range) -> Self {
471 Self {
472 range,
473 config: ParallelConfig::default(),
474 }
475 }
476
477 pub fn with_config(mut self, config: ParallelConfig) -> Self {
479 self.config = config;
480 self
481 }
482
483 pub fn chunk_assignments(&self) -> Vec<Range> {
485 self.range.chunk(self.config.worker_count)
486 }
487}
488
489#[derive(Clone, Debug)]
496pub struct ParMap {
497 pub size: usize,
499 pub config: ParallelConfig,
501}
502
503impl ParMap {
504 pub fn new(size: usize) -> Self {
506 Self {
507 size,
508 config: ParallelConfig::default(),
509 }
510 }
511
512 pub fn chunk_assignments(&self) -> Vec<Range> {
514 let range = Range::new(0, self.size as i64);
515 range.chunk(self.config.worker_count)
516 }
517}
518
519#[derive(Clone, Debug)]
526pub struct ParReduce {
527 pub size: usize,
529 pub op: ReduceOp,
531 pub config: ParallelConfig,
533}
534
535impl ParReduce {
536 pub fn new(size: usize, op: ReduceOp) -> Self {
538 Self {
539 size,
540 op,
541 config: ParallelConfig::default(),
542 }
543 }
544
545 pub fn deterministic(mut self, det: bool) -> Self {
547 self.config.deterministic = det;
548 self
549 }
550
551 pub fn chunk_assignments(&self) -> Vec<Range> {
560 let range = Range::new(0, self.size as i64);
561 range.chunk(self.config.worker_count)
562 }
563
564 pub fn identity(&self) -> f64 {
566 match self.op {
567 ReduceOp::Add => 0.0,
568 ReduceOp::Mul => 1.0,
569 ReduceOp::Min => f64::INFINITY,
570 ReduceOp::Max => f64::NEG_INFINITY,
571 ReduceOp::And => 1.0, ReduceOp::Or => 0.0,
573 ReduceOp::Xor => 0.0,
574 }
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AccessPattern, BinOp, Body, LoopType, MemRef, Op, Param, Value, ValueId};
    use bhc_index::Idx;
    use bhc_intern::Symbol;
    use bhc_tensor_ir::BufferId;

    /// Shorthand for the `f32` scalar loop type used throughout the fixture.
    fn f32_ty() -> LoopType {
        LoopType::Scalar(crate::ScalarType::F32)
    }

    /// Builds a kernel with one `PARALLEL | INDEPENDENT` loop over
    /// `0..trip_count` whose body loads `data[i]`, multiplies it by 2.0 and
    /// stores it back. The static trip count is recorded in the loop metadata.
    fn make_parallelizable_loop(trip_count: usize) -> (LoopIR, LoopId) {
        let loop_id = LoopId::new(0);
        let idx = ValueId::new(0);
        let loaded = ValueId::new(1);
        let doubled = ValueId::new(2);

        let slot = MemRef {
            buffer: BufferId::new(0),
            index: Value::Var(idx, LoopType::Scalar(crate::ScalarType::I64)),
            elem_ty: f32_ty(),
            access: AccessPattern::Sequential,
        };

        let mut body = Body::new();
        body.push(Stmt::Assign(loaded, Op::Load(slot.clone())));
        body.push(Stmt::Assign(
            doubled,
            Op::Binary(BinOp::Mul, Value::Var(loaded, f32_ty()), Value::float(2.0, 32)),
        ));
        body.push(Stmt::Store(slot, Value::Var(doubled, f32_ty())));

        let mut outer = Body::new();
        outer.push(Stmt::Loop(Loop {
            id: loop_id,
            var: idx,
            lower: Value::i64(0),
            upper: Value::i64(trip_count as i64),
            step: Value::i64(1),
            body,
            attrs: LoopAttrs::PARALLEL | LoopAttrs::INDEPENDENT,
        }));

        let metadata = LoopMetadata {
            id: loop_id,
            trip_count: TripCount::Static(trip_count),
            vector_width: None,
            parallel_chunk: None,
            unroll_factor: None,
            dependencies: Vec::new(),
        };

        let ir = LoopIR {
            name: Symbol::intern("test_kernel"),
            params: vec![Param {
                name: Symbol::intern("data"),
                ty: LoopType::Ptr(Box::new(f32_ty())),
                is_ptr: true,
            }],
            return_ty: LoopType::Void,
            body: outer,
            allocs: vec![],
            loop_info: vec![metadata],
        };

        (ir, loop_id)
    }

    #[test]
    fn test_parallel_analysis() {
        let (ir, loop_id) = make_parallelizable_loop(10000);
        let analysis = ParallelPass::new(ParallelConfig::default()).analyze(&ir);

        let info = analysis.get(&loop_id).expect("loop should be analyzed");
        assert!(info.parallelizable, "loop should be parallelizable");
        assert!(info.chunk_size > 0, "should have positive chunk size");
    }

    #[test]
    fn test_parallel_below_threshold() {
        let (ir, loop_id) = make_parallelizable_loop(100);
        let analysis = ParallelPass::new(ParallelConfig::default()).analyze(&ir);

        let info = analysis.get(&loop_id).expect("loop should be analyzed");
        assert!(
            !info.parallelizable,
            "small loop should not be parallelizable"
        );
    }

    #[test]
    fn test_range_chunking() {
        let chunks = Range::new(0, 1000).chunk(8);
        assert_eq!(chunks.len(), 8);

        let covered: usize = chunks.iter().map(Range::len).sum();
        assert_eq!(covered, 1000);

        // Chunks must tile the range with no gaps or overlaps.
        assert!(chunks.windows(2).all(|pair| pair[0].end == pair[1].start));
    }

    #[test]
    fn test_range_chunking_uneven() {
        let chunks = Range::new(0, 103).chunk(8);
        let covered: usize = chunks.iter().map(Range::len).sum();
        assert_eq!(covered, 103);
    }

    #[test]
    fn test_par_for_chunks() {
        let config = ParallelConfig {
            worker_count: 8,
            ..Default::default()
        };
        let chunks = ParFor::new(Range::new(0, 10000))
            .with_config(config)
            .chunk_assignments();
        assert_eq!(chunks.len(), 8);

        // Work should be balanced to within one iteration of the mean.
        let sizes: Vec<usize> = chunks.iter().map(Range::len).collect();
        let avg = sizes.iter().sum::<usize>() / sizes.len();
        for size in sizes {
            assert!((size as i64 - avg as i64).abs() <= 1);
        }
    }

    #[test]
    fn test_par_reduce_deterministic() {
        let reducer = ParReduce::new(10000, ReduceOp::Add).deterministic(true);
        assert!(reducer.config.deterministic);

        // Deterministic mode must hand out identical chunks every time.
        let first = reducer.chunk_assignments();
        let second = reducer.chunk_assignments();
        for (a, b) in first.iter().zip(&second) {
            assert_eq!((a.start, a.end), (b.start, b.end));
        }
    }

    #[test]
    fn test_par_reduce_identity() {
        let cases = [
            (ReduceOp::Add, 0.0),
            (ReduceOp::Mul, 1.0),
            (ReduceOp::Min, f64::INFINITY),
            (ReduceOp::Max, f64::NEG_INFINITY),
        ];
        for (op, expected) in cases {
            assert_eq!(ParReduce::new(100, op).identity(), expected);
        }
    }

    #[test]
    fn test_parallel_report_display() {
        let entry = ParallelizedLoopInfo {
            loop_id: LoopId::new(0),
            chunk_size: 1250,
            num_chunks: 8,
            strategy: ParallelStrategy::Static,
            is_reduction: false,
        };
        let report = ParallelReport {
            parallelized_loops: vec![entry],
            failed_loops: vec![],
        };

        let rendered = report.to_string();
        assert!(rendered.contains("Parallelized loops: 1"));
        assert!(rendered.contains("chunks=8"));
        assert!(rendered.contains("Static"));
    }

    #[test]
    fn test_deterministic_vs_dynamic_strategy() {
        let (ir, loop_id) = make_parallelizable_loop(10000);
        let expectations = [
            (true, ParallelStrategy::Static),
            (false, ParallelStrategy::Dynamic),
        ];

        for (deterministic, expected) in expectations {
            let config = ParallelConfig {
                deterministic,
                ..ParallelConfig::default()
            };
            let analysis = ParallelPass::new(config).analyze(&ir);
            assert_eq!(analysis[&loop_id].strategy, expected);
        }
    }
}