use crate::{
BarrierKind, Loop, LoopAttrs, LoopIR, LoopId, LoopMetadata, ReduceOp, Stmt, TripCount,
};
use rustc_hash::FxHashMap;
use thiserror::Error;
/// Errors that can be reported by the parallelization pass.
#[derive(Clone, Debug, Error)]
pub enum ParallelError {
/// A specific loop was rejected for parallelization; `reason` explains why.
#[error("loop {loop_id:?} cannot be parallelized: {reason}")]
NotParallelizable {
loop_id: LoopId,
reason: String,
},
/// A chunk size was not usable for the given trip count.
// NOTE(review): no code in this file constructs this variant — confirm
// whether callers elsewhere rely on it.
#[error("invalid chunk size {chunk_size} for trip count {trip_count}")]
InvalidChunkSize {
chunk_size: usize,
trip_count: usize,
},
}
/// Tuning knobs shared by [`ParallelPass`] and the `Par*` helpers.
#[derive(Clone, Debug)]
pub struct ParallelConfig {
    /// Number of workers that iterations are divided among.
    pub worker_count: usize,
    /// The pass only parallelizes a loop whose trip count is at least
    /// `worker_count * min_iterations_per_worker`.
    pub min_iterations_per_worker: usize,
    /// `true` makes the pass choose `ParallelStrategy::Static`, `false`
    /// makes it choose `ParallelStrategy::Dynamic`.
    pub deterministic: bool,
    /// Explicit chunk size; `0` means "derive it from the trip count and
    /// `worker_count`".
    pub chunk_size: usize,
}

impl Default for ParallelConfig {
    /// Eight workers, at least 64 iterations per worker, deterministic
    /// scheduling, and an automatically derived chunk size.
    fn default() -> Self {
        let worker_count = num_cpus();
        Self {
            worker_count,
            min_iterations_per_worker: 64,
            deterministic: true,
            chunk_size: 0,
        }
    }
}

/// Worker count used by [`ParallelConfig::default`].
///
/// This is a fixed value and does not query the host machine.
fn num_cpus() -> usize {
    const DEFAULT_WORKERS: usize = 8;
    DEFAULT_WORKERS
}
/// How the chunks of a parallel loop are scheduled onto workers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ParallelStrategy {
    /// Chosen by the pass when `ParallelConfig::deterministic` is `true`.
    Static,
    /// Chosen by the pass when `ParallelConfig::deterministic` is `false`.
    Dynamic,
    /// Not currently chosen by [`ParallelPass`].
    Guided,
}

/// Per-loop outcome of [`ParallelPass::analyze`].
#[derive(Clone, Debug)]
pub struct ParallelInfo {
    /// Whether the loop passed every parallelization check.
    pub parallelizable: bool,
    /// Why the loop was rejected; `None` once it is accepted.
    pub reason: Option<String>,
    /// Iterations per chunk (0 until the loop is accepted).
    pub chunk_size: usize,
    /// Number of chunks the trip count splits into (0 until accepted).
    pub num_chunks: usize,
    /// Scheduling strategy selected for the loop.
    pub strategy: ParallelStrategy,
    /// Whether the loop carries the `REDUCTION` attribute.
    pub is_reduction: bool,
}

impl Default for ParallelInfo {
    /// A "not yet analyzed" record: rejected, with zeroed chunk data.
    fn default() -> Self {
        let reason = String::from("not analyzed");
        Self {
            parallelizable: false,
            reason: Some(reason),
            chunk_size: 0,
            num_chunks: 0,
            strategy: ParallelStrategy::Static,
            is_reduction: false,
        }
    }
}
pub struct ParallelPass {
config: ParallelConfig,
analysis: FxHashMap<LoopId, ParallelInfo>,
}
impl ParallelPass {
pub fn new(config: ParallelConfig) -> Self {
Self {
config,
analysis: FxHashMap::default(),
}
}
pub fn analyze(&mut self, ir: &LoopIR) -> FxHashMap<LoopId, ParallelInfo> {
self.analysis.clear();
for stmt in &ir.body.stmts {
self.analyze_stmt(stmt, &ir.loop_info);
}
self.analysis.clone()
}
fn analyze_stmt(&mut self, stmt: &Stmt, loop_info: &[LoopMetadata]) {
if let Stmt::Loop(lp) = stmt {
let info = self.analyze_loop(lp, loop_info);
self.analysis.insert(lp.id, info);
for inner_stmt in &lp.body.stmts {
self.analyze_stmt(inner_stmt, loop_info);
}
}
}
fn analyze_loop(&self, lp: &Loop, loop_info: &[LoopMetadata]) -> ParallelInfo {
let mut info = ParallelInfo::default();
if !lp.attrs.contains(LoopAttrs::PARALLEL) {
info.reason = Some("loop not marked PARALLEL".to_string());
return info;
}
if !lp.attrs.contains(LoopAttrs::INDEPENDENT) {
info.reason = Some("loop has dependencies".to_string());
return info;
}
let metadata = loop_info.iter().find(|m| m.id == lp.id);
let trip_count = match metadata.map(|m| &m.trip_count) {
Some(TripCount::Static(n)) => *n,
Some(TripCount::Bounded(n)) => *n,
_ => {
info.reason = Some("dynamic trip count".to_string());
return info;
}
};
let min_total = self.config.worker_count * self.config.min_iterations_per_worker;
if trip_count < min_total {
info.reason = Some(format!(
"trip count {} below threshold {}",
trip_count, min_total
));
return info;
}
let chunk_size = if self.config.chunk_size > 0 {
self.config.chunk_size
} else {
compute_chunk_size(trip_count, self.config.worker_count)
};
let is_reduction = lp.attrs.contains(LoopAttrs::REDUCTION);
info.parallelizable = true;
info.reason = None;
info.chunk_size = chunk_size;
info.num_chunks = trip_count.div_ceil(chunk_size);
info.is_reduction = is_reduction;
info.strategy = if self.config.deterministic {
ParallelStrategy::Static
} else {
ParallelStrategy::Dynamic
};
info
}
pub fn parallelize(&self, ir: &mut LoopIR) -> Result<ParallelReport, ParallelError> {
let mut report = ParallelReport::default();
for stmt in &mut ir.body.stmts {
self.parallelize_stmt(stmt, &mut ir.loop_info, &mut report)?;
}
Ok(report)
}
fn parallelize_stmt(
&self,
stmt: &mut Stmt,
loop_info: &mut [LoopMetadata],
report: &mut ParallelReport,
) -> Result<(), ParallelError> {
if let Stmt::Loop(lp) = stmt {
if let Some(info) = self.analysis.get(&lp.id) {
if info.parallelizable {
self.parallelize_loop(lp, info, loop_info, report)?;
}
}
}
Ok(())
}
fn parallelize_loop(
&self,
lp: &mut Loop,
info: &ParallelInfo,
loop_info: &mut [LoopMetadata],
report: &mut ParallelReport,
) -> Result<(), ParallelError> {
if let Some(meta) = loop_info.iter_mut().find(|m| m.id == lp.id) {
meta.parallel_chunk = Some(info.chunk_size);
}
if info.is_reduction {
self.parallelize_reduction(lp, info)?;
}
report.parallelized_loops.push(ParallelizedLoopInfo {
loop_id: lp.id,
chunk_size: info.chunk_size,
num_chunks: info.num_chunks,
strategy: info.strategy,
is_reduction: info.is_reduction,
});
Ok(())
}
fn parallelize_reduction(
&self,
lp: &mut Loop,
_info: &ParallelInfo,
) -> Result<(), ParallelError> {
lp.body.push(Stmt::Barrier(BarrierKind::ThreadGroup));
Ok(())
}
}
/// Default chunk size: `ceil(trip_count / worker_count)`, never less than 1.
///
/// A `worker_count` of 0 is treated as 1 instead of dividing by zero. An empty
/// loop (`trip_count == 0`) yields 1, so callers can safely divide by the
/// result.
fn compute_chunk_size(trip_count: usize, worker_count: usize) -> usize {
    trip_count.div_ceil(worker_count.max(1)).max(1)
}
/// Summary of what [`ParallelPass::parallelize`] did.
#[derive(Clone, Debug, Default)]
pub struct ParallelReport {
    /// One entry per parallelized loop, in the order the loops were visited.
    pub parallelized_loops: Vec<ParallelizedLoopInfo>,
    /// Loops that could not be parallelized, paired with a reason.
    // NOTE(review): nothing in this file populates this list — confirm
    // whether another pass is expected to fill it.
    pub failed_loops: Vec<(LoopId, String)>,
}

impl ParallelReport {
    /// Returns `true` if at least one loop was parallelized.
    pub fn any_parallelized(&self) -> bool {
        self.count() > 0
    }

    /// Number of parallelized loops.
    pub fn count(&self) -> usize {
        self.parallelized_loops.len()
    }
}

impl std::fmt::Display for ParallelReport {
    /// Renders a header, one line per parallelized loop, and a failed-loop
    /// section only when there are failures.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Parallelization Report")?;
        writeln!(f, "======================")?;
        writeln!(f, "Parallelized loops: {}", self.count())?;
        self.parallelized_loops.iter().try_for_each(|entry| {
            writeln!(
                f,
                " Loop {:?}: chunks={}, chunk_size={}, strategy={:?}, reduction={}",
                entry.loop_id, entry.num_chunks, entry.chunk_size, entry.strategy, entry.is_reduction
            )
        })?;
        if self.failed_loops.is_empty() {
            return Ok(());
        }
        writeln!(f, "\nFailed loops: {}", self.failed_loops.len())?;
        self.failed_loops
            .iter()
            .try_for_each(|(id, reason)| writeln!(f, " Loop {:?}: {}", id, reason))
    }
}

/// Record of a single loop that [`ParallelPass::parallelize`] transformed.
#[derive(Clone, Debug)]
pub struct ParallelizedLoopInfo {
    /// Id of the transformed loop.
    pub loop_id: LoopId,
    /// Iterations per chunk.
    pub chunk_size: usize,
    /// Number of chunks.
    pub num_chunks: usize,
    /// Scheduling strategy chosen for the loop.
    pub strategy: ParallelStrategy,
    /// Whether the loop was handled as a reduction.
    pub is_reduction: bool,
}
/// A half-open integer range `start..end` walked in increments of `step`.
///
/// A positive `step` walks upward and is non-empty only when `start < end`.
/// A negative `step` walks downward and is non-empty only when `start > end`.
/// A zero `step` is always empty.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Range {
    pub start: i64,
    pub end: i64,
    pub step: i64,
}

impl Range {
    /// Creates `start..end` with a step of 1.
    pub fn new(start: i64, end: i64) -> Self {
        Self {
            start,
            end,
            step: 1,
        }
    }

    /// Creates `start..end` with an explicit `step`.
    pub fn with_step(start: i64, end: i64, step: i64) -> Self {
        Self { start, end, step }
    }

    /// Number of iterations the range performs.
    ///
    /// Returns 0 for an empty range, including when `step` points away from
    /// `end` (e.g. `Range::new(10, 0)`). Previously such ranges produced a
    /// negative count that was cast to a huge `usize`. The arithmetic is done
    /// in `i128` so extreme bounds cannot overflow. Counts that do not fit in
    /// `usize` saturate to `usize::MAX`.
    pub fn len(&self) -> usize {
        let start = i128::from(self.start);
        let end = i128::from(self.end);
        let step = i128::from(self.step);
        let count = if step > 0 && end > start {
            (end - start + step - 1) / step
        } else if step < 0 && start > end {
            (start - end - step - 1) / -step
        } else {
            0
        };
        usize::try_from(count).unwrap_or(usize::MAX)
    }

    /// Returns `true` when the range performs no iterations.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Splits the range into at most `num_chunks` contiguous sub-ranges.
    ///
    /// Every chunk except possibly the last holds `ceil(len / num_chunks)`
    /// iterations. Fewer than `num_chunks` chunks are returned when that size
    /// exhausts the range early. Chunks keep this range's `step`, and each one
    /// starts where the previous one ended. Returns an empty vector when
    /// `num_chunks == 0` or the range is empty.
    pub fn chunk(&self, num_chunks: usize) -> Vec<Range> {
        if num_chunks == 0 || self.is_empty() {
            return vec![];
        }
        let total = self.len();
        let chunk_size = total.div_ceil(num_chunks);
        // Size by the real chunk count: `num_chunks` itself may be huge.
        let mut chunks = Vec::with_capacity(total.div_ceil(chunk_size));
        let mut current = self.start;
        // Track what is left rather than recomputing `total - i * chunk_size`,
        // which underflowed once an earlier chunk had been truncated
        // (e.g. 5 iterations split 4 ways).
        let mut remaining = total;
        while remaining > 0 {
            let chunk_iters = chunk_size.min(remaining);
            let chunk_end = current + (chunk_iters as i64) * self.step;
            chunks.push(Range {
                start: current,
                end: chunk_end,
                step: self.step,
            });
            current = chunk_end;
            remaining -= chunk_iters;
        }
        chunks
    }
}
/// A parallel `for` over a [`Range`], split among `config.worker_count` workers.
#[derive(Clone, Debug)]
pub struct ParFor {
    /// Iteration space.
    pub range: Range,
    /// Worker and scheduling settings.
    pub config: ParallelConfig,
}

impl ParFor {
    /// Creates a `ParFor` over `range` with [`ParallelConfig::default`].
    pub fn new(range: Range) -> Self {
        let config = ParallelConfig::default();
        Self { range, config }
    }

    /// Replaces the configuration, builder style.
    pub fn with_config(self, config: ParallelConfig) -> Self {
        Self { config, ..self }
    }

    /// Sub-ranges of `range` for each worker, as produced by [`Range::chunk`].
    pub fn chunk_assignments(&self) -> Vec<Range> {
        let workers = self.config.worker_count;
        self.range.chunk(workers)
    }
}
/// A parallel map over indices `0..size`.
#[derive(Clone, Debug)]
pub struct ParMap {
    /// Number of elements to map over.
    pub size: usize,
    /// Worker and scheduling settings.
    pub config: ParallelConfig,
}

impl ParMap {
    /// Creates a `ParMap` over `size` elements with [`ParallelConfig::default`].
    pub fn new(size: usize) -> Self {
        let config = ParallelConfig::default();
        Self { size, config }
    }

    /// Per-worker index ranges covering `0..size`.
    pub fn chunk_assignments(&self) -> Vec<Range> {
        let upper = self.size as i64;
        Range::new(0, upper).chunk(self.config.worker_count)
    }
}
/// A parallel reduction over indices `0..size` using `op`.
#[derive(Clone, Debug)]
pub struct ParReduce {
    /// Number of elements to reduce.
    pub size: usize,
    /// Combining operation.
    pub op: ReduceOp,
    /// Worker and scheduling settings.
    pub config: ParallelConfig,
}

impl ParReduce {
    /// Creates a reduction of `size` elements with `op` and
    /// [`ParallelConfig::default`].
    pub fn new(size: usize, op: ReduceOp) -> Self {
        let config = ParallelConfig::default();
        Self { size, op, config }
    }

    /// Sets `config.deterministic`, builder style.
    pub fn deterministic(self, det: bool) -> Self {
        let mut reduce = self;
        reduce.config.deterministic = det;
        reduce
    }

    /// Per-worker index ranges covering `0..size`.
    pub fn chunk_assignments(&self) -> Vec<Range> {
        let upper = self.size as i64;
        Range::new(0, upper).chunk(self.config.worker_count)
    }

    /// Starting accumulator value for `op`, expressed as `f64`.
    ///
    /// NOTE(review): `And` maps to `1.0` and `Or`/`Xor` to `0.0`, which fits
    /// logical/boolean semantics. For bitwise AND the identity would be
    /// all-ones, so confirm how `ReduceOp::And` is interpreted.
    pub fn identity(&self) -> f64 {
        match self.op {
            ReduceOp::Add | ReduceOp::Or | ReduceOp::Xor => 0.0,
            ReduceOp::Mul | ReduceOp::And => 1.0,
            ReduceOp::Min => f64::INFINITY,
            ReduceOp::Max => f64::NEG_INFINITY,
        }
    }
}
// Unit tests for the parallelization pass and the range/chunking helpers.
#[cfg(test)]
mod tests {
use super::*;
use crate::{AccessPattern, BinOp, Body, LoopType, MemRef, Op, Param, Value, ValueId};
use bhc_index::Idx;
use bhc_intern::Symbol;
use bhc_tensor_ir::BufferId;
/// Builds a kernel containing one loop `for i in 0..trip_count`.
///
/// The loop body loads `data[i]`, multiplies it by `2.0`, and stores the
/// result back to `data[i]`. The loop is marked `PARALLEL | INDEPENDENT`, and
/// its metadata has a static trip count. Returns the IR and the loop's id.
fn make_parallelizable_loop(trip_count: usize) -> (LoopIR, LoopId) {
let loop_id = LoopId::new(0);
let loop_var = ValueId::new(0);
// Memory reference indexed directly by the loop variable.
let mem_ref = MemRef {
buffer: BufferId::new(0),
index: Value::Var(loop_var, LoopType::Scalar(crate::ScalarType::I64)),
elem_ty: LoopType::Scalar(crate::ScalarType::F32),
access: AccessPattern::Sequential,
};
let mut body = Body::new();
let load_result = ValueId::new(1);
body.push(Stmt::Assign(load_result, Op::Load(mem_ref.clone())));
let mul_result = ValueId::new(2);
body.push(Stmt::Assign(
mul_result,
Op::Binary(
BinOp::Mul,
Value::Var(load_result, LoopType::Scalar(crate::ScalarType::F32)),
Value::float(2.0, 32),
),
));
body.push(Stmt::Store(
mem_ref,
Value::Var(mul_result, LoopType::Scalar(crate::ScalarType::F32)),
));
let lp = Loop {
id: loop_id,
var: loop_var,
lower: Value::i64(0),
upper: Value::i64(trip_count as i64),
step: Value::i64(1),
body,
attrs: LoopAttrs::PARALLEL | LoopAttrs::INDEPENDENT,
};
let mut outer_body = Body::new();
outer_body.push(Stmt::Loop(lp));
let ir = LoopIR {
name: Symbol::intern("test_kernel"),
params: vec![Param {
name: Symbol::intern("data"),
ty: LoopType::Ptr(Box::new(LoopType::Scalar(crate::ScalarType::F32))),
is_ptr: true,
}],
return_ty: LoopType::Void,
body: outer_body,
allocs: vec![],
// The pass reads the trip count from this metadata, not from the
// loop bounds.
loop_info: vec![LoopMetadata {
id: loop_id,
trip_count: TripCount::Static(trip_count),
vector_width: None,
parallel_chunk: None,
unroll_factor: None,
dependencies: Vec::new(),
}],
};
(ir, loop_id)
}
// 10000 iterations clears the default threshold (8 workers * 64 = 512).
#[test]
fn test_parallel_analysis() {
let (ir, loop_id) = make_parallelizable_loop(10000);
let mut pass = ParallelPass::new(ParallelConfig::default());
let analysis = pass.analyze(&ir);
let info = analysis.get(&loop_id).expect("loop should be analyzed");
assert!(info.parallelizable, "loop should be parallelizable");
assert!(info.chunk_size > 0, "should have positive chunk size");
}
// 100 iterations is below the default threshold (8 workers * 64 = 512).
#[test]
fn test_parallel_below_threshold() {
let (ir, loop_id) = make_parallelizable_loop(100);
let mut pass = ParallelPass::new(ParallelConfig::default());
let analysis = pass.analyze(&ir);
let info = analysis.get(&loop_id).expect("loop should be analyzed");
assert!(
!info.parallelizable,
"small loop should not be parallelizable"
);
}
// An evenly divisible range yields exactly `num_chunks` contiguous chunks
// that together cover every iteration.
#[test]
fn test_range_chunking() {
let range = Range::new(0, 1000);
let chunks = range.chunk(8);
assert_eq!(chunks.len(), 8);
let total_iters: usize = chunks.iter().map(|c| c.len()).sum();
assert_eq!(total_iters, 1000);
for i in 1..chunks.len() {
assert_eq!(chunks[i].start, chunks[i - 1].end);
}
}
// 103 iterations do not divide evenly by 8; every iteration must still be
// covered exactly once.
#[test]
fn test_range_chunking_uneven() {
let range = Range::new(0, 103); let chunks = range.chunk(8);
let total_iters: usize = chunks.iter().map(|c| c.len()).sum();
assert_eq!(total_iters, 103);
}
// 10000 / 8 divides evenly, so every chunk is within one iteration of the
// average.
#[test]
fn test_par_for_chunks() {
let par_for = ParFor::new(Range::new(0, 10000)).with_config(ParallelConfig {
worker_count: 8,
..Default::default()
});
let chunks = par_for.chunk_assignments();
assert_eq!(chunks.len(), 8);
let sizes: Vec<_> = chunks.iter().map(|c| c.len()).collect();
let avg = sizes.iter().sum::<usize>() / sizes.len();
for size in sizes {
assert!((size as i64 - avg as i64).abs() <= 1);
}
}
// Repeated calls must produce identical chunk boundaries.
#[test]
fn test_par_reduce_deterministic() {
let par_reduce = ParReduce::new(10000, ReduceOp::Add).deterministic(true);
assert!(par_reduce.config.deterministic);
let chunks1 = par_reduce.chunk_assignments();
let chunks2 = par_reduce.chunk_assignments();
for (c1, c2) in chunks1.iter().zip(chunks2.iter()) {
assert_eq!(c1.start, c2.start);
assert_eq!(c1.end, c2.end);
}
}
// Identity values for the arithmetic and min/max reductions.
#[test]
fn test_par_reduce_identity() {
assert_eq!(ParReduce::new(100, ReduceOp::Add).identity(), 0.0);
assert_eq!(ParReduce::new(100, ReduceOp::Mul).identity(), 1.0);
assert_eq!(ParReduce::new(100, ReduceOp::Min).identity(), f64::INFINITY);
assert_eq!(
ParReduce::new(100, ReduceOp::Max).identity(),
f64::NEG_INFINITY
);
}
// The Display output includes the loop count, chunk count, and strategy.
#[test]
fn test_parallel_report_display() {
let report = ParallelReport {
parallelized_loops: vec![ParallelizedLoopInfo {
loop_id: LoopId::new(0),
chunk_size: 1250,
num_chunks: 8,
strategy: ParallelStrategy::Static,
is_reduction: false,
}],
failed_loops: vec![],
};
let output = format!("{}", report);
assert!(output.contains("Parallelized loops: 1"));
assert!(output.contains("chunks=8"));
assert!(output.contains("Static"));
}
// `deterministic` selects Static scheduling; turning it off selects Dynamic.
#[test]
fn test_deterministic_vs_dynamic_strategy() {
let mut config = ParallelConfig {
deterministic: true,
..Default::default()
};
let (ir, loop_id) = make_parallelizable_loop(10000);
let mut pass_det = ParallelPass::new(config.clone());
let analysis = pass_det.analyze(&ir);
assert_eq!(analysis[&loop_id].strategy, ParallelStrategy::Static);
config.deterministic = false;
let mut pass_dyn = ParallelPass::new(config);
let analysis = pass_dyn.analyze(&ir);
assert_eq!(analysis[&loop_id].strategy, ParallelStrategy::Dynamic);
}
}