1use crate::{
20 AccessPattern, Alloc, BinOp, Body, Loop, LoopAttrs, LoopIR, LoopId, LoopMetadata, LoopType,
21 MemRef, Op, Param, ScalarType, Stmt, TargetArch, TripCount, Value, ValueId,
22};
23use bhc_index::Idx;
24use bhc_intern::Symbol;
25use bhc_tensor_ir::{
26 BufferId, Kernel, KernelBody, LoopNest as TensorLoopNest, ReduceOp as TensorReduceOp, TensorOp,
27 TensorRef,
28};
29use rustc_hash::FxHashMap;
30use thiserror::Error;
31
/// Errors produced while lowering Tensor IR kernels to Loop IR.
#[derive(Clone, Debug, Error)]
pub enum LowerError {
    /// A tensor operation that this lowering pass does not handle.
    #[error("unsupported tensor operation: {op}")]
    UnsupportedOp {
        /// Human-readable description of the rejected operation.
        op: String,
    },

    /// Two shapes that were required to agree did not.
    #[error("shape mismatch: expected {expected:?}, got {got:?}")]
    ShapeMismatch {
        /// The shape that was required.
        expected: Vec<usize>,
        /// The shape that was actually found.
        got: Vec<usize>,
    },

    /// The kernel cannot be lowered as structured (for example, it has no outputs).
    #[error("invalid kernel structure: {reason}")]
    InvalidKernel {
        /// Explanation of what is wrong with the kernel.
        reason: String,
    },
}
58
/// Options controlling how kernels are lowered to Loop IR.
///
/// The vectorization/parallelization knobs are consulted when a fused op list is
/// lowered into a generated loop nest. Explicit Tensor IR loop nests take their
/// attributes from each loop's own `parallel`/`vectorize` fields instead.
#[derive(Clone, Debug)]
pub struct LowerConfig {
    /// Target architecture. It is not consulted by the lowering code in this file.
    pub target: TargetArch,
    /// Whether the innermost generated loop may be marked `VECTORIZE`.
    pub enable_vectorization: bool,
    /// Whether the outermost generated loop may be marked `PARALLEL`.
    pub enable_parallelization: bool,
    /// Minimum innermost extent (inclusive) required for the `VECTORIZE` attribute.
    pub vectorize_threshold: usize,
    /// Minimum outermost extent (inclusive) required for the `PARALLEL` attribute.
    pub parallelize_threshold: usize,
}
73
74impl Default for LowerConfig {
75 fn default() -> Self {
76 Self {
77 target: TargetArch::default(),
78 enable_vectorization: true,
79 enable_parallelization: true,
80 vectorize_threshold: 4,
81 parallelize_threshold: 1024,
82 }
83 }
84}
85
/// Mutable state threaded through the lowering of a single kernel.
struct LowerContext {
    /// Lowering options supplied by the caller.
    config: LowerConfig,
    /// Next raw index handed out by `fresh_value`.
    next_value: u32,
    /// Next raw index handed out by `fresh_loop`.
    next_loop: u32,
    /// Maps a tensor id (`TensorRef::id.index()` widened to `u64`) to the value id
    /// allocated for that tensor's parameter.
    ///
    /// NOTE(review): this map is written when parameters are created but is not read
    /// by the code visible here.
    tensor_values: FxHashMap<u64, ValueId>,
    /// Allocations that become `LoopIR::allocs`.
    ///
    /// NOTE(review): nothing visible here pushes to it.
    allocations: Vec<Alloc>,
    /// One metadata record per generated loop; becomes `LoopIR::loop_info`.
    loop_metadata: Vec<LoopMetadata>,
    /// Kernel parameters: inputs first, then outputs.
    params: Vec<Param>,
}
103
104impl LowerContext {
105 fn new(config: LowerConfig) -> Self {
106 Self {
107 config,
108 next_value: 0,
109 next_loop: 0,
110 tensor_values: FxHashMap::default(),
111 allocations: Vec::new(),
112 loop_metadata: Vec::new(),
113 params: Vec::new(),
114 }
115 }
116
117 fn fresh_value(&mut self) -> ValueId {
118 let id = ValueId::new(self.next_value as usize);
119 self.next_value += 1;
120 id
121 }
122
123 fn fresh_loop(&mut self) -> LoopId {
124 let id = LoopId::new(self.next_loop as usize);
125 self.next_loop += 1;
126 id
127 }
128}
129
130pub fn lower_kernels(kernels: &[Kernel], config: LowerConfig) -> Result<Vec<LoopIR>, LowerError> {
141 kernels
142 .iter()
143 .map(|k| lower_kernel(k, config.clone()))
144 .collect()
145}
146
147pub fn lower_kernel(kernel: &Kernel, config: LowerConfig) -> Result<LoopIR, LowerError> {
149 let mut ctx = LowerContext::new(config);
150
151 for (i, input) in kernel.inputs.iter().enumerate() {
153 let param = tensor_ref_to_param(input, i, &mut ctx);
154 ctx.params.push(param);
155 }
156
157 for (i, output) in kernel.outputs.iter().enumerate() {
159 let param = tensor_ref_to_param(output, kernel.inputs.len() + i, &mut ctx);
160 ctx.params.push(param);
161 }
162
163 let body = match &kernel.body {
165 KernelBody::Fused(ops) => lower_fused_ops(ops, kernel, &mut ctx)?,
166 KernelBody::LoopNest(nest) => lower_tensor_loop_nest(nest, &mut ctx)?,
167 };
168
169 Ok(LoopIR {
170 name: kernel.name,
171 params: ctx.params,
172 return_ty: LoopType::Void,
173 body,
174 allocs: ctx.allocations,
175 loop_info: ctx.loop_metadata,
176 })
177}
178
179fn tensor_ref_to_param(tensor: &TensorRef, index: usize, ctx: &mut LowerContext) -> Param {
181 let elem_ty = ScalarType::from_dtype(tensor.meta.dtype);
182 let value_id = ctx.fresh_value();
183
184 ctx.tensor_values.insert(tensor.id.index() as u64, value_id);
186
187 Param {
188 name: Symbol::intern(&format!("tensor_{}", index)),
189 ty: LoopType::Ptr(Box::new(LoopType::Scalar(elem_ty))),
190 is_ptr: true,
191 }
192}
193
194fn lower_fused_ops(
196 ops: &[TensorOp],
197 kernel: &Kernel,
198 ctx: &mut LowerContext,
199) -> Result<Body, LowerError> {
200 let output_shape: Vec<usize> = if let Some(output) = kernel.outputs.first() {
205 output
206 .meta
207 .shape
208 .dims()
209 .iter()
210 .map(|d| d.static_value().unwrap_or(0))
211 .collect()
212 } else {
213 return Err(LowerError::InvalidKernel {
214 reason: "kernel has no outputs".to_string(),
215 });
216 };
217
218 let (body, loop_vars) = generate_loop_nest(&output_shape, ctx)?;
220
221 let inner_stmts = lower_fused_ops_body(ops, &loop_vars, kernel, ctx)?;
223
224 let mut result_body = body;
226 insert_inner_stmts(&mut result_body, inner_stmts);
227
228 Ok(result_body)
229}
230
231fn generate_loop_nest(
234 shape: &[usize],
235 ctx: &mut LowerContext,
236) -> Result<(Body, Vec<ValueId>), LowerError> {
237 let mut loop_vars = Vec::with_capacity(shape.len());
238 let mut loops = Vec::with_capacity(shape.len());
239
240 for (dim_idx, &dim_size) in shape.iter().enumerate() {
241 let loop_id = ctx.fresh_loop();
242 let loop_var = ctx.fresh_value();
243 loop_vars.push(loop_var);
244
245 let mut attrs = LoopAttrs::INDEPENDENT;
247
248 if ctx.config.enable_parallelization
250 && dim_idx == 0
251 && dim_size >= ctx.config.parallelize_threshold
252 {
253 attrs |= LoopAttrs::PARALLEL;
254 }
255
256 if ctx.config.enable_vectorization
258 && dim_idx == shape.len() - 1
259 && dim_size >= ctx.config.vectorize_threshold
260 {
261 attrs |= LoopAttrs::VECTORIZE;
262 }
263
264 ctx.loop_metadata.push(LoopMetadata {
266 id: loop_id,
267 trip_count: TripCount::Static(dim_size),
268 vector_width: None, parallel_chunk: None, unroll_factor: None,
271 dependencies: Vec::new(),
272 });
273
274 loops.push(Loop {
275 id: loop_id,
276 var: loop_var,
277 lower: Value::i64(0),
278 upper: Value::i64(dim_size as i64),
279 step: Value::i64(1),
280 body: Body::new(),
281 attrs,
282 });
283 }
284
285 let mut body = Body::new();
287 if loops.is_empty() {
288 return Ok((body, loop_vars));
289 }
290
291 let mut current_loop = loops.pop().unwrap();
293 while let Some(mut outer) = loops.pop() {
294 outer.body.push(Stmt::Loop(current_loop));
295 current_loop = outer;
296 }
297
298 body.push(Stmt::Loop(current_loop));
299 Ok((body, loop_vars))
300}
301
302fn lower_fused_ops_body(
304 ops: &[TensorOp],
305 loop_vars: &[ValueId],
306 _kernel: &Kernel,
307 ctx: &mut LowerContext,
308) -> Result<Vec<Stmt>, LowerError> {
309 let mut stmts = Vec::new();
310
311 for op in ops {
312 lower_tensor_op(op, loop_vars, &mut stmts, ctx)?;
313 }
314
315 Ok(stmts)
316}
317
/// Lowers a single tensor op into scalar statements appended to `stmts`.
///
/// Supported ops load their input element(s) at the current loop indices via
/// `load_tensor_element`. `Map` and `ZipWith` then assign a result to a fresh value,
/// and `ReduceAll` is delegated to `lower_reduction`.
///
/// NOTE(review): this is placeholder lowering:
/// * `Map` ignores its function and always emits `UnOp::Neg`.
/// * `ZipWith` ignores its function and always emits `BinOp::Add`.
/// * `Broadcast`, `Reshape` and `Transpose` ignore their shape/permutation. They load
///   the input at the unmodified indices and discard the loaded value.
/// * No result is stored back to an output buffer.
///
/// # Errors
///
/// Returns `LowerError::UnsupportedOp` for any other op variant. The `op` text is
/// the `Debug` form of the variant's discriminant. Errors from loading and
/// reduction are propagated.
fn lower_tensor_op(
    op: &TensorOp,
    loop_vars: &[ValueId],
    stmts: &mut Vec<Stmt>,
    ctx: &mut LowerContext,
) -> Result<(), LowerError> {
    match op {
        TensorOp::Map(_func, input) => {
            let input_val = load_tensor_element(input, loop_vars, stmts, ctx)?;

            // Placeholder: `_func` is not consulted; negation is always emitted.
            let result = ctx.fresh_value();
            stmts.push(Stmt::Assign(result, Op::Unary(crate::UnOp::Neg, input_val)));

            Ok(())
        }

        TensorOp::ZipWith(_func, a, b) => {
            let a_val = load_tensor_element(a, loop_vars, stmts, ctx)?;
            let b_val = load_tensor_element(b, loop_vars, stmts, ctx)?;

            // Placeholder: `_func` is not consulted; addition is always emitted.
            let result = ctx.fresh_value();
            stmts.push(Stmt::Assign(result, Op::Binary(BinOp::Add, a_val, b_val)));

            Ok(())
        }

        TensorOp::ReduceAll(reduce_op, input) => {
            lower_reduction(reduce_op, input, loop_vars, stmts, ctx)
        }

        // The three view-like ops below only emit a load of their input; the target
        // shape/permutation is not applied to the indices.
        TensorOp::Broadcast(_shape, input) => {
            let _ = load_tensor_element(input, loop_vars, stmts, ctx)?;
            Ok(())
        }

        TensorOp::Reshape(_shape, input) => {
            let _ = load_tensor_element(input, loop_vars, stmts, ctx)?;
            Ok(())
        }

        TensorOp::Transpose(_perm, input) => {
            let _ = load_tensor_element(input, loop_vars, stmts, ctx)?;
            Ok(())
        }

        _ => Err(LowerError::UnsupportedOp {
            op: format!("{:?}", std::mem::discriminant(op)),
        }),
    }
}
376
377fn load_tensor_element(
379 tensor: &TensorRef,
380 loop_vars: &[ValueId],
381 stmts: &mut Vec<Stmt>,
382 ctx: &mut LowerContext,
383) -> Result<Value, LowerError> {
384 let elem_ty = ScalarType::from_dtype(tensor.meta.dtype);
385
386 let index = compute_linear_index(tensor, loop_vars)?;
388
389 let buffer_id = tensor
391 .meta
392 .alias
393 .unwrap_or(BufferId::new(tensor.id.index()));
394
395 let mem_ref = MemRef {
397 buffer: buffer_id,
398 index,
399 elem_ty: LoopType::Scalar(elem_ty),
400 access: compute_access_pattern(tensor),
401 };
402
403 let result = ctx.fresh_value();
405 stmts.push(Stmt::Assign(result, Op::Load(mem_ref)));
406
407 Ok(Value::Var(result, LoopType::Scalar(elem_ty)))
408}
409
410fn compute_linear_index(_tensor: &TensorRef, loop_vars: &[ValueId]) -> Result<Value, LowerError> {
412 if loop_vars.is_empty() {
416 return Ok(Value::i64(0));
417 }
418
419 let first_var = loop_vars[0];
422 Ok(Value::Var(first_var, LoopType::Scalar(ScalarType::I64)))
423}
424
425fn compute_access_pattern(tensor: &TensorRef) -> AccessPattern {
427 let strides = tensor.meta.strides.values();
428
429 if strides.last() == Some(&1) {
431 AccessPattern::Sequential
432 } else if let Some(&stride) = strides.last() {
433 AccessPattern::Strided(stride)
434 } else {
435 AccessPattern::Random
436 }
437}
438
/// Emits statements that combine one element of `input` into a reduction accumulator.
///
/// The combining operator is `Add`/`Mul`/`FMin`/`FMax` for `Sum`/`Prod`/`Min`/`Max`.
/// Any other reduce op falls back to `Add`.
///
/// NOTE(review): this lowering is incomplete:
/// * `_init_val` is computed and then discarded, and `acc` is never assigned. The
///   emitted binary op therefore reads a value that nothing defines.
/// * `new_acc` is never written back to `acc`, so successive elements do not
///   accumulate.
/// * Initial values are always built with `Value::float`, even for integer dtypes.
/// * When reached from fused lowering, these statements are placed in the innermost
///   loop body, so the accumulator logic is emitted per iteration.
///
/// # Errors
///
/// Propagates any error from loading the input element.
fn lower_reduction(
    reduce_op: &TensorReduceOp,
    input: &TensorRef,
    loop_vars: &[ValueId],
    stmts: &mut Vec<Stmt>,
    ctx: &mut LowerContext,
) -> Result<(), LowerError> {
    let elem_ty = ScalarType::from_dtype(input.meta.dtype);
    // Bit width of the element type, as expected by `Value::float`.
    let bits = elem_ty.size_bytes() as u8 * 8;

    // Identity element for each reduction; currently unused (see NOTE above).
    let _init_val = match reduce_op {
        TensorReduceOp::Sum => Value::float(0.0, bits),
        TensorReduceOp::Prod => Value::float(1.0, bits),
        TensorReduceOp::Min => Value::float(f64::INFINITY, bits),
        TensorReduceOp::Max => Value::float(f64::NEG_INFINITY, bits),
        _ => Value::float(0.0, bits),
    };

    stmts.push(Stmt::Comment(format!(
        "reduction accumulator for {:?}",
        reduce_op
    )));

    // Accumulator value id: allocated here but never assigned.
    let acc = ctx.fresh_value();

    let input_val = load_tensor_element(input, loop_vars, stmts, ctx)?;

    let bin_op = match reduce_op {
        TensorReduceOp::Sum => BinOp::Add,
        TensorReduceOp::Prod => BinOp::Mul,
        TensorReduceOp::Min => BinOp::FMin,
        TensorReduceOp::Max => BinOp::FMax,
        _ => BinOp::Add,
    };

    // new_acc = acc <op> input; the result is not fed back into `acc`.
    let new_acc = ctx.fresh_value();
    stmts.push(Stmt::Assign(
        new_acc,
        Op::Binary(
            bin_op,
            Value::Var(acc, LoopType::Scalar(elem_ty)),
            input_val,
        ),
    ));

    Ok(())
}
492
493fn lower_tensor_loop_nest(
495 nest: &TensorLoopNest,
496 ctx: &mut LowerContext,
497) -> Result<Body, LowerError> {
498 let mut loops = Vec::new();
501
502 for loop_spec in &nest.loops {
503 let loop_id = ctx.fresh_loop();
504 let loop_var = ctx.fresh_value();
505
506 let mut attrs = LoopAttrs::empty();
507 if loop_spec.parallel {
508 attrs |= LoopAttrs::PARALLEL;
509 }
510 if loop_spec.vectorize.is_some() {
511 attrs |= LoopAttrs::VECTORIZE;
512 }
513
514 let trip_count = loop_spec
515 .upper
516 .static_value()
517 .map(TripCount::Static)
518 .unwrap_or(TripCount::Dynamic);
519
520 let upper_bound = loop_spec.upper.static_value().unwrap_or(0) as i64;
521
522 ctx.loop_metadata.push(LoopMetadata {
523 id: loop_id,
524 trip_count,
525 vector_width: None,
526 parallel_chunk: None,
527 unroll_factor: None,
528 dependencies: Vec::new(),
529 });
530
531 loops.push(Loop {
532 id: loop_id,
533 var: loop_var,
534 lower: Value::i64(loop_spec.lower),
535 upper: Value::i64(upper_bound),
536 step: Value::i64(loop_spec.step),
537 body: Body::new(),
538 attrs,
539 });
540 }
541
542 let mut body = Body::new();
544 if loops.is_empty() {
545 return Ok(body);
546 }
547
548 let mut current_loop = loops.pop().unwrap();
549 while let Some(mut outer) = loops.pop() {
550 outer.body.push(Stmt::Loop(current_loop));
551 current_loop = outer;
552 }
553
554 body.push(Stmt::Loop(current_loop));
555 Ok(body)
556}
557
558fn insert_inner_stmts(body: &mut Body, stmts: Vec<Stmt>) {
560 fn find_innermost_and_insert(body: &mut Body, stmts: Vec<Stmt>) {
561 if let Some(Stmt::Loop(ref mut lp)) = body.stmts.last_mut() {
562 if lp.body.stmts.is_empty() || !matches!(lp.body.stmts.last(), Some(Stmt::Loop(_))) {
563 lp.body.stmts.extend(stmts);
565 } else {
566 find_innermost_and_insert(&mut lp.body, stmts);
568 }
569 } else {
570 body.stmts.extend(stmts);
572 }
573 }
574
575 find_innermost_and_insert(body, stmts);
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use bhc_span::Span;
    use bhc_tensor_ir::{
        DType, FusionInfo, KernelId, Layout, MapFn, Shape, Strides, TensorId, TensorMeta,
    };

    /// Builds a kernel that maps `f` over a contiguous 1-D `Float32` tensor of 1024
    /// elements and writes a second tensor of the same shape.
    fn make_test_kernel() -> Kernel {
        let meta = TensorMeta {
            dtype: DType::Float32,
            shape: Shape::from_static([1024]),
            strides: Strides::new([1]),
            layout: Layout::Contiguous,
            alias: None,
        };

        let input = TensorRef {
            id: TensorId::new(0),
            meta: meta.clone(),
        };

        let output = TensorRef {
            id: TensorId::new(1),
            meta,
        };

        let map_fn = MapFn {
            name: Symbol::intern("f"),
            span: Span::DUMMY,
        };

        Kernel {
            id: KernelId::new(0),
            name: Symbol::intern("test_kernel"),
            inputs: vec![input.clone()],
            outputs: vec![output],
            body: KernelBody::Fused(vec![TensorOp::Map(map_fn, input)]),
            allocs: vec![],
            fusion_info: FusionInfo {
                original_ops: vec![],
                decisions: vec![],
                complete: true,
            },
        }
    }

    /// Returns the outermost loop of a lowered kernel.
    ///
    /// Fails the test if the body does not start with a loop, so attribute checks
    /// can never pass vacuously.
    fn outermost_loop(loop_ir: &LoopIR) -> &Loop {
        match loop_ir.body.stmts.first() {
            Some(Stmt::Loop(lp)) => lp,
            _ => panic!("expected the lowered body to start with a loop"),
        }
    }

    #[test]
    fn test_lower_simple_kernel() {
        let kernel = make_test_kernel();
        let loop_ir = lower_kernel(&kernel, LowerConfig::default())
            .expect("lowering a simple map kernel should succeed");

        assert_eq!(loop_ir.name.as_str(), "test_kernel");
        // One input parameter plus one output parameter.
        assert_eq!(loop_ir.params.len(), 2);
        assert!(!loop_ir.body.stmts.is_empty());
    }

    #[test]
    fn test_lower_generates_loop_nest() {
        let loop_ir = lower_kernel(&make_test_kernel(), LowerConfig::default()).unwrap();

        assert!(matches!(loop_ir.body.stmts.first(), Some(Stmt::Loop(_))));
        // A 1-D output yields exactly one loop with a static trip count of 1024.
        assert_eq!(loop_ir.loop_info.len(), 1);
        assert!(matches!(
            loop_ir.loop_info[0].trip_count,
            TripCount::Static(1024)
        ));
    }

    #[test]
    fn test_lower_marks_vectorizable() {
        let config = LowerConfig {
            enable_vectorization: true,
            vectorize_threshold: 4,
            ..Default::default()
        };

        let loop_ir = lower_kernel(&make_test_kernel(), config).unwrap();

        // The single loop is also the innermost one, and 1024 >= 4.
        assert!(outermost_loop(&loop_ir).attrs.contains(LoopAttrs::VECTORIZE));
    }

    #[test]
    fn test_vectorization_can_be_disabled() {
        let config = LowerConfig {
            enable_vectorization: false,
            ..Default::default()
        };

        let loop_ir = lower_kernel(&make_test_kernel(), config).unwrap();

        assert!(!outermost_loop(&loop_ir).attrs.contains(LoopAttrs::VECTORIZE));
    }

    #[test]
    fn test_lower_marks_parallel() {
        // The default parallelize_threshold is 1024, which the extent exactly reaches.
        let loop_ir = lower_kernel(&make_test_kernel(), LowerConfig::default()).unwrap();

        assert!(outermost_loop(&loop_ir).attrs.contains(LoopAttrs::PARALLEL));
    }

    #[test]
    fn test_kernel_without_outputs_is_rejected() {
        let mut kernel = make_test_kernel();
        kernel.outputs.clear();

        match lower_kernel(&kernel, LowerConfig::default()) {
            Err(LowerError::InvalidKernel { .. }) => {}
            Err(other) => panic!("unexpected error: {}", other),
            Ok(_) => panic!("expected InvalidKernel for a kernel without outputs"),
        }
    }

    #[test]
    fn test_lower_kernels_lowers_each_kernel() {
        let kernels = vec![make_test_kernel(), make_test_kernel()];

        let lowered = lower_kernels(&kernels, LowerConfig::default())
            .expect("both kernels should lower");

        assert_eq!(lowered.len(), 2);
    }

    #[test]
    fn test_sequential_access_pattern() {
        let meta = TensorMeta {
            dtype: DType::Float32,
            shape: Shape::from_static([1024]),
            strides: Strides::new([1]),
            layout: Layout::Contiguous,
            alias: None,
        };

        let tensor = TensorRef {
            id: TensorId::new(0),
            meta,
        };

        let pattern = compute_access_pattern(&tensor);
        assert_eq!(pattern, AccessPattern::Sequential);
    }

    #[test]
    fn test_strided_access_pattern() {
        let meta = TensorMeta {
            dtype: DType::Float32,
            shape: Shape::from_static([1024]),
            strides: Strides::new([4]),
            layout: Layout::Strided,
            alias: None,
        };

        let tensor = TensorRef {
            id: TensorId::new(0),
            meta,
        };

        let pattern = compute_access_pattern(&tensor);
        assert_eq!(pattern, AccessPattern::Strided(4));
    }
}