1use crate::{
19 BarrierKind, Loop, LoopAttrs, LoopIR, LoopId, LoopMetadata, ReduceOp, Stmt, TripCount,
20};
21use rustc_hash::FxHashMap;
22use thiserror::Error;
23
24#[derive(Clone, Debug, Error)]
26pub enum ParallelError {
27 #[error("loop {loop_id:?} cannot be parallelized: {reason}")]
29 NotParallelizable {
30 loop_id: LoopId,
32 reason: String,
34 },
35
36 #[error("invalid chunk size {chunk_size} for trip count {trip_count}")]
38 InvalidChunkSize {
39 chunk_size: usize,
41 trip_count: usize,
43 },
44}
45
/// Tuning parameters shared by the parallelization pass and the
/// `ParFor` / `ParMap` / `ParReduce` helpers.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Number of workers to split work across.
    pub worker_count: usize,
    /// Minimum iterations each worker should receive; multiplied by
    /// `worker_count` to form the trip-count threshold below which a loop
    /// is left sequential.
    pub min_iterations_per_worker: usize,
    /// When `true` the pass selects `ParallelStrategy::Static`, otherwise
    /// `ParallelStrategy::Dynamic`.
    pub deterministic: bool,
    /// Explicit chunk size; `0` means "derive it from the trip count and
    /// the worker count".
    pub chunk_size: usize,
}
58
59impl Default for ParallelConfig {
60 fn default() -> Self {
61 Self {
62 worker_count: num_cpus(),
63 min_iterations_per_worker: 64,
64 deterministic: true, chunk_size: 0, }
67 }
68}
69
/// Returns the default worker count.
///
/// This is a fixed value and does not query the host, so default chunking
/// is identical on every machine.
fn num_cpus() -> usize {
    const DEFAULT_WORKER_COUNT: usize = 8;
    DEFAULT_WORKER_COUNT
}
75
/// Scheduling strategy recorded for a parallelized loop.
///
/// The pass in this file only ever selects `Static` (deterministic
/// configuration) or `Dynamic` (non-deterministic configuration); `Guided`
/// is not selected anywhere in the visible code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelStrategy {
    /// Chosen when `ParallelConfig::deterministic` is `true`.
    Static,
    /// Chosen when `ParallelConfig::deterministic` is `false`.
    Dynamic,
    /// Not produced by the analysis in this file.
    Guided,
}
89
90#[derive(Clone, Debug)]
92pub struct ParallelInfo {
93 pub parallelizable: bool,
95 pub reason: Option<String>,
97 pub chunk_size: usize,
99 pub num_chunks: usize,
101 pub strategy: ParallelStrategy,
103 pub is_reduction: bool,
105}
106
107impl Default for ParallelInfo {
108 fn default() -> Self {
109 Self {
110 parallelizable: false,
111 reason: Some("not analyzed".to_string()),
112 chunk_size: 0,
113 num_chunks: 0,
114 strategy: ParallelStrategy::Static,
115 is_reduction: false,
116 }
117 }
118}
119
120pub struct ParallelPass {
122 config: ParallelConfig,
123 analysis: FxHashMap<LoopId, ParallelInfo>,
125}
126
127impl ParallelPass {
128 pub fn new(config: ParallelConfig) -> Self {
130 Self {
131 config,
132 analysis: FxHashMap::default(),
133 }
134 }
135
136 pub fn analyze(&mut self, ir: &LoopIR) -> FxHashMap<LoopId, ParallelInfo> {
138 self.analysis.clear();
139
140 for stmt in &ir.body.stmts {
141 self.analyze_stmt(stmt, &ir.loop_info);
142 }
143
144 self.analysis.clone()
145 }
146
147 fn analyze_stmt(&mut self, stmt: &Stmt, loop_info: &[LoopMetadata]) {
149 if let Stmt::Loop(lp) = stmt {
150 let info = self.analyze_loop(lp, loop_info);
151 self.analysis.insert(lp.id, info);
152
153 for inner_stmt in &lp.body.stmts {
155 self.analyze_stmt(inner_stmt, loop_info);
156 }
157 }
158 }
159
160 fn analyze_loop(&self, lp: &Loop, loop_info: &[LoopMetadata]) -> ParallelInfo {
162 let mut info = ParallelInfo::default();
163
164 if !lp.attrs.contains(LoopAttrs::PARALLEL) {
166 info.reason = Some("loop not marked PARALLEL".to_string());
167 return info;
168 }
169
170 if !lp.attrs.contains(LoopAttrs::INDEPENDENT) {
172 info.reason = Some("loop has dependencies".to_string());
173 return info;
174 }
175
176 let metadata = loop_info.iter().find(|m| m.id == lp.id);
178 let trip_count = match metadata.map(|m| &m.trip_count) {
179 Some(TripCount::Static(n)) => *n,
180 Some(TripCount::Bounded(n)) => *n,
181 _ => {
182 info.reason = Some("dynamic trip count".to_string());
183 return info;
184 }
185 };
186
187 let min_total = self.config.worker_count * self.config.min_iterations_per_worker;
189 if trip_count < min_total {
190 info.reason = Some(format!(
191 "trip count {} below threshold {}",
192 trip_count, min_total
193 ));
194 return info;
195 }
196
197 let chunk_size = if self.config.chunk_size > 0 {
199 self.config.chunk_size
200 } else {
201 compute_chunk_size(trip_count, self.config.worker_count)
202 };
203
204 let is_reduction = lp.attrs.contains(LoopAttrs::REDUCTION);
206
207 info.parallelizable = true;
208 info.reason = None;
209 info.chunk_size = chunk_size;
210 info.num_chunks = trip_count.div_ceil(chunk_size);
211 info.is_reduction = is_reduction;
212 info.strategy = if self.config.deterministic {
213 ParallelStrategy::Static
214 } else {
215 ParallelStrategy::Dynamic
216 };
217
218 info
219 }
220
221 pub fn parallelize(&self, ir: &mut LoopIR) -> Result<ParallelReport, ParallelError> {
223 let mut report = ParallelReport::default();
224
225 for stmt in &mut ir.body.stmts {
226 self.parallelize_stmt(stmt, &mut ir.loop_info, &mut report)?;
227 }
228
229 Ok(report)
230 }
231
232 fn parallelize_stmt(
234 &self,
235 stmt: &mut Stmt,
236 loop_info: &mut [LoopMetadata],
237 report: &mut ParallelReport,
238 ) -> Result<(), ParallelError> {
239 if let Stmt::Loop(lp) = stmt {
240 if let Some(info) = self.analysis.get(&lp.id) {
241 if info.parallelizable {
242 self.parallelize_loop(lp, info, loop_info, report)?;
243 }
244 }
245 }
246 Ok(())
247 }
248
249 fn parallelize_loop(
251 &self,
252 lp: &mut Loop,
253 info: &ParallelInfo,
254 loop_info: &mut [LoopMetadata],
255 report: &mut ParallelReport,
256 ) -> Result<(), ParallelError> {
257 if let Some(meta) = loop_info.iter_mut().find(|m| m.id == lp.id) {
259 meta.parallel_chunk = Some(info.chunk_size);
260 }
261
262 if info.is_reduction {
264 self.parallelize_reduction(lp, info)?;
265 }
266
267 report.parallelized_loops.push(ParallelizedLoopInfo {
269 loop_id: lp.id,
270 chunk_size: info.chunk_size,
271 num_chunks: info.num_chunks,
272 strategy: info.strategy,
273 is_reduction: info.is_reduction,
274 });
275
276 Ok(())
277 }
278
279 fn parallelize_reduction(
281 &self,
282 lp: &mut Loop,
283 _info: &ParallelInfo,
284 ) -> Result<(), ParallelError> {
285 lp.body.push(Stmt::Barrier(BarrierKind::ThreadGroup));
291
292 Ok(())
293 }
294}
295
/// Returns the number of iterations each worker should take so that
/// `trip_count` iterations are covered by at most `worker_count` chunks.
///
/// The result is `ceil(trip_count / worker_count)`, with two guards:
/// a `worker_count` of zero is treated as one worker (instead of dividing
/// by zero), and the result is never smaller than 1 so callers can safely
/// divide by it even when `trip_count` is zero.
fn compute_chunk_size(trip_count: usize, worker_count: usize) -> usize {
    let workers = worker_count.max(1);
    trip_count.div_ceil(workers).max(1)
}
302
303#[derive(Clone, Debug, Default)]
305pub struct ParallelReport {
306 pub parallelized_loops: Vec<ParallelizedLoopInfo>,
308 pub failed_loops: Vec<(LoopId, String)>,
310}
311
312impl ParallelReport {
313 pub fn any_parallelized(&self) -> bool {
315 !self.parallelized_loops.is_empty()
316 }
317
318 pub fn count(&self) -> usize {
320 self.parallelized_loops.len()
321 }
322}
323
324impl std::fmt::Display for ParallelReport {
325 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326 writeln!(f, "Parallelization Report")?;
327 writeln!(f, "======================")?;
328 writeln!(f, "Parallelized loops: {}", self.parallelized_loops.len())?;
329
330 for info in &self.parallelized_loops {
331 writeln!(
332 f,
333 " Loop {:?}: chunks={}, chunk_size={}, strategy={:?}, reduction={}",
334 info.loop_id, info.num_chunks, info.chunk_size, info.strategy, info.is_reduction
335 )?;
336 }
337
338 if !self.failed_loops.is_empty() {
339 writeln!(f, "\nFailed loops: {}", self.failed_loops.len())?;
340 for (id, reason) in &self.failed_loops {
341 writeln!(f, " Loop {:?}: {}", id, reason)?;
342 }
343 }
344
345 Ok(())
346 }
347}
348
349#[derive(Clone, Debug)]
351pub struct ParallelizedLoopInfo {
352 pub loop_id: LoopId,
354 pub chunk_size: usize,
356 pub num_chunks: usize,
358 pub strategy: ParallelStrategy,
360 pub is_reduction: bool,
362}
363
/// Half-open iteration range `start, start + step, ...` that stops before
/// reaching `end` (moving upwards for a positive `step`, downwards for a
/// negative one).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Range {
    /// First value of the iteration.
    pub start: i64,
    /// Exclusive bound of the iteration.
    pub end: i64,
    /// Increment between consecutive values; a step of `0` makes the range
    /// empty.
    pub step: i64,
}

impl Range {
    /// Creates the unit-step range `[start, end)`.
    pub fn new(start: i64, end: i64) -> Self {
        Self {
            start,
            end,
            step: 1,
        }
    }

    /// Creates a range with an explicit step (which may be negative).
    pub fn with_step(start: i64, end: i64, step: i64) -> Self {
        Self { start, end, step }
    }

    /// Returns the number of iterations in the range.
    ///
    /// A range whose `step` is zero, or whose `end` lies on the wrong side
    /// of `start` for the sign of `step`, has length 0. The count is
    /// computed in unsigned 64-bit arithmetic so extreme bounds cannot
    /// overflow; it saturates at `usize::MAX` on narrower targets.
    pub fn len(&self) -> usize {
        let span = if self.step > 0 && self.end > self.start {
            self.end.abs_diff(self.start)
        } else if self.step < 0 && self.start > self.end {
            self.start.abs_diff(self.end)
        } else {
            // Zero step, or nothing to iterate in the step's direction.
            return 0;
        };

        let count = span.div_ceil(self.step.unsigned_abs());
        count.min(usize::MAX as u64) as usize
    }

    /// Returns `true` when the range contains no iterations.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Splits the range into at most `num_chunks` contiguous sub-ranges.
    ///
    /// Every chunk except possibly the last holds
    /// `ceil(len / num_chunks)` iterations; fewer than `num_chunks` chunks
    /// are returned when that chunk size already covers the range. Chunks
    /// are contiguous (`chunks[i].start == chunks[i - 1].end`), share the
    /// parent's `step`, and the last chunk ends exactly at the parent's
    /// `end`. Returns an empty vector when `num_chunks` is 0 or the range
    /// is empty.
    pub fn chunk(&self, num_chunks: usize) -> Vec<Range> {
        let total = self.len();
        if num_chunks == 0 || total == 0 {
            return Vec::new();
        }

        let chunk_size = total.div_ceil(num_chunks);

        // Reserve for the chunks actually produced, not for `num_chunks`,
        // which may be far larger than the number of iterations.
        let mut chunks = Vec::with_capacity(total.div_ceil(chunk_size));
        let mut current = self.start;
        // Counting down the remaining iterations (rather than computing
        // `total - i * chunk_size`) cannot underflow.
        let mut remaining = total;

        while remaining > 0 {
            let chunk_iters = chunk_size.min(remaining);

            let chunk_end = if chunk_iters == remaining {
                // Last chunk: stop at the parent's bound so the chunk never
                // extends past it when the span is not a multiple of `step`.
                self.end
            } else {
                // Interior chunk ends lie strictly inside the parent range,
                // so the result fits in i64; widening only protects the
                // intermediate product.
                let offset = chunk_iters as i128 * i128::from(self.step);
                (i128::from(current) + offset) as i64
            };

            chunks.push(Range {
                start: current,
                end: chunk_end,
                step: self.step,
            });

            current = chunk_end;
            remaining -= chunk_iters;
        }

        chunks
    }
}
446
447#[derive(Clone, Debug)]
455pub struct ParFor {
456 pub range: Range,
458 pub config: ParallelConfig,
460}
461
462impl ParFor {
463 pub fn new(range: Range) -> Self {
465 Self {
466 range,
467 config: ParallelConfig::default(),
468 }
469 }
470
471 pub fn with_config(mut self, config: ParallelConfig) -> Self {
473 self.config = config;
474 self
475 }
476
477 pub fn chunk_assignments(&self) -> Vec<Range> {
479 self.range.chunk(self.config.worker_count)
480 }
481}
482
483#[derive(Clone, Debug)]
490pub struct ParMap {
491 pub size: usize,
493 pub config: ParallelConfig,
495}
496
497impl ParMap {
498 pub fn new(size: usize) -> Self {
500 Self {
501 size,
502 config: ParallelConfig::default(),
503 }
504 }
505
506 pub fn chunk_assignments(&self) -> Vec<Range> {
508 let range = Range::new(0, self.size as i64);
509 range.chunk(self.config.worker_count)
510 }
511}
512
513#[derive(Clone, Debug)]
520pub struct ParReduce {
521 pub size: usize,
523 pub op: ReduceOp,
525 pub config: ParallelConfig,
527}
528
529impl ParReduce {
530 pub fn new(size: usize, op: ReduceOp) -> Self {
532 Self {
533 size,
534 op,
535 config: ParallelConfig::default(),
536 }
537 }
538
539 pub fn deterministic(mut self, det: bool) -> Self {
541 self.config.deterministic = det;
542 self
543 }
544
545 pub fn chunk_assignments(&self) -> Vec<Range> {
554 let range = Range::new(0, self.size as i64);
555 range.chunk(self.config.worker_count)
556 }
557
558 pub fn identity(&self) -> f64 {
560 match self.op {
561 ReduceOp::Add => 0.0,
562 ReduceOp::Mul => 1.0,
563 ReduceOp::Min => f64::INFINITY,
564 ReduceOp::Max => f64::NEG_INFINITY,
565 ReduceOp::And => 1.0, ReduceOp::Or => 0.0,
567 ReduceOp::Xor => 0.0,
568 }
569 }
570}
571
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AccessPattern, BinOp, Body, LoopType, MemRef, Op, Param, Value, ValueId};
    use bhc_index::Idx;
    use bhc_intern::Symbol;
    use bhc_tensor_ir::BufferId;

    /// Shorthand for the `f32` scalar loop type.
    fn f32_ty() -> LoopType {
        LoopType::Scalar(crate::ScalarType::F32)
    }

    /// Builds a kernel containing a single `PARALLEL | INDEPENDENT` loop with
    /// a static trip count of `trip_count`. The loop body loads an element,
    /// multiplies it by 2.0 and stores it back through the same reference.
    fn make_parallelizable_loop(trip_count: usize) -> (LoopIR, LoopId) {
        let loop_id = LoopId::new(0);
        let loop_var = ValueId::new(0);

        let mem_ref = MemRef {
            buffer: BufferId::new(0),
            index: Value::Var(loop_var, LoopType::Scalar(crate::ScalarType::I64)),
            elem_ty: f32_ty(),
            access: AccessPattern::Sequential,
        };

        let mut body = Body::new();

        let load_result = ValueId::new(1);
        let load = Op::Load(mem_ref.clone());
        body.push(Stmt::Assign(load_result, load));

        let mul_result = ValueId::new(2);
        let doubled = Op::Binary(
            BinOp::Mul,
            Value::Var(load_result, f32_ty()),
            Value::float(2.0, 32),
        );
        body.push(Stmt::Assign(mul_result, doubled));

        let stored = Value::Var(mul_result, f32_ty());
        body.push(Stmt::Store(mem_ref, stored));

        let mut outer_body = Body::new();
        outer_body.push(Stmt::Loop(Loop {
            id: loop_id,
            var: loop_var,
            lower: Value::i64(0),
            upper: Value::i64(trip_count as i64),
            step: Value::i64(1),
            body,
            attrs: LoopAttrs::PARALLEL | LoopAttrs::INDEPENDENT,
        }));

        let name = Symbol::intern("test_kernel");
        let data_param = Param {
            name: Symbol::intern("data"),
            ty: LoopType::Ptr(Box::new(f32_ty())),
            is_ptr: true,
        };
        let metadata = LoopMetadata {
            id: loop_id,
            trip_count: TripCount::Static(trip_count),
            vector_width: None,
            parallel_chunk: None,
            unroll_factor: None,
            dependencies: Vec::new(),
        };

        let ir = LoopIR {
            name,
            params: vec![data_param],
            return_ty: LoopType::Void,
            body: outer_body,
            allocs: vec![],
            loop_info: vec![metadata],
        };

        (ir, loop_id)
    }

    /// A loop well above the default threshold is accepted.
    #[test]
    fn test_parallel_analysis() {
        let (ir, loop_id) = make_parallelizable_loop(10000);

        let analysis = ParallelPass::new(ParallelConfig::default()).analyze(&ir);

        let info = analysis.get(&loop_id).expect("loop should be analyzed");
        assert!(info.parallelizable, "loop should be parallelizable");
        assert!(info.chunk_size > 0, "should have positive chunk size");
    }

    /// A loop below `worker_count * min_iterations_per_worker` is rejected.
    #[test]
    fn test_parallel_below_threshold() {
        let (ir, loop_id) = make_parallelizable_loop(100);

        let analysis = ParallelPass::new(ParallelConfig::default()).analyze(&ir);

        let info = analysis.get(&loop_id).expect("loop should be analyzed");
        assert!(
            !info.parallelizable,
            "small loop should not be parallelizable"
        );
    }

    /// An evenly divisible range yields the requested number of contiguous
    /// chunks covering every iteration.
    #[test]
    fn test_range_chunking() {
        let chunks = Range::new(0, 1000).chunk(8);

        assert_eq!(chunks.len(), 8);

        let total_iters: usize = chunks.iter().map(Range::len).sum();
        assert_eq!(total_iters, 1000);

        for pair in chunks.windows(2) {
            assert_eq!(pair[1].start, pair[0].end);
        }
    }

    /// A range that does not divide evenly still covers every iteration.
    #[test]
    fn test_range_chunking_uneven() {
        let chunks = Range::new(0, 103).chunk(8);

        let total_iters: usize = chunks.iter().map(Range::len).sum();
        assert_eq!(total_iters, 103);
    }

    /// `ParFor` produces one chunk per worker with balanced sizes.
    #[test]
    fn test_par_for_chunks() {
        let config = ParallelConfig {
            worker_count: 8,
            ..Default::default()
        };
        let par_for = ParFor::new(Range::new(0, 10000)).with_config(config);

        let chunks = par_for.chunk_assignments();
        assert_eq!(chunks.len(), 8);

        let sizes: Vec<usize> = chunks.iter().map(Range::len).collect();
        let total: usize = sizes.iter().sum();
        let avg = total / sizes.len();
        for &size in &sizes {
            assert!((size as i64 - avg as i64).abs() <= 1);
        }
    }

    /// Chunk assignments of a deterministic reduction are repeatable.
    #[test]
    fn test_par_reduce_deterministic() {
        let par_reduce = ParReduce::new(10000, ReduceOp::Add).deterministic(true);

        assert!(par_reduce.config.deterministic);

        let first = par_reduce.chunk_assignments();
        let second = par_reduce.chunk_assignments();

        for (c1, c2) in first.iter().zip(second.iter()) {
            assert_eq!(c1.start, c2.start);
            assert_eq!(c1.end, c2.end);
        }
    }

    /// Identity elements of the arithmetic and min/max operators.
    #[test]
    fn test_par_reduce_identity() {
        let identity_of = |op: ReduceOp| ParReduce::new(100, op).identity();

        assert_eq!(identity_of(ReduceOp::Add), 0.0);
        assert_eq!(identity_of(ReduceOp::Mul), 1.0);
        assert_eq!(identity_of(ReduceOp::Min), f64::INFINITY);
        assert_eq!(identity_of(ReduceOp::Max), f64::NEG_INFINITY);
    }

    /// The textual report mentions the loop count, chunk count and strategy.
    #[test]
    fn test_parallel_report_display() {
        let entry = ParallelizedLoopInfo {
            loop_id: LoopId::new(0),
            chunk_size: 1250,
            num_chunks: 8,
            strategy: ParallelStrategy::Static,
            is_reduction: false,
        };
        let report = ParallelReport {
            parallelized_loops: vec![entry],
            failed_loops: vec![],
        };

        let output = report.to_string();
        for needle in ["Parallelized loops: 1", "chunks=8", "Static"] {
            assert!(output.contains(needle));
        }
    }

    /// The `deterministic` flag selects between static and dynamic
    /// scheduling.
    #[test]
    fn test_deterministic_vs_dynamic_strategy() {
        let (ir, loop_id) = make_parallelizable_loop(10000);

        let cases = [
            (true, ParallelStrategy::Static),
            (false, ParallelStrategy::Dynamic),
        ];
        for (deterministic, expected) in cases {
            let config = ParallelConfig {
                deterministic,
                ..Default::default()
            };
            let analysis = ParallelPass::new(config).analyze(&ir);
            assert_eq!(analysis[&loop_id].strategy, expected);
        }
    }
}