Skip to main content

bhc_loop_ir/
lower.rs

//! # Lowering from Tensor IR to Loop IR
//!
//! This module converts Tensor IR kernels — fused operation sequences or
//! explicit loop nests — into Loop IR functions whose iteration structure is
//! explicit and suitable for vectorization and code generation.
//!
//! ## Pipeline Position
//!
//! ```text
//! Tensor IR (fused kernels) → [lower.rs] → Loop IR → [vectorize.rs] → Vectorized Loop IR
//! ```
//!
//! ## Key Transformations
//!
//! 1. **Kernel to Function**: each kernel becomes one Loop IR function.
//! 2. **Shape to Loops**: tensor shapes become loop nests.
//! 3. **Operations to Statements**: tensor ops become scalar statements.
//! 4. **Access Patterns**: memory access patterns are computed from strides.
18
19use crate::{
20    AccessPattern, Alloc, BinOp, Body, Loop, LoopAttrs, LoopIR, LoopId, LoopMetadata, LoopType,
21    MemRef, Op, Param, ScalarType, Stmt, TargetArch, TripCount, Value, ValueId,
22};
23use bhc_index::Idx;
24use bhc_intern::Symbol;
25use bhc_tensor_ir::{
26    BufferId, Kernel, KernelBody, LoopNest as TensorLoopNest, ReduceOp as TensorReduceOp, TensorOp,
27    TensorRef,
28};
29use rustc_hash::FxHashMap;
30use thiserror::Error;
31
32/// Errors that can occur during lowering.
33#[derive(Clone, Debug, Error)]
34pub enum LowerError {
35    /// Unsupported tensor operation.
36    #[error("unsupported tensor operation: {op}")]
37    UnsupportedOp {
38        /// Description of the unsupported operation.
39        op: String,
40    },
41
42    /// Shape mismatch during lowering.
43    #[error("shape mismatch: expected {expected:?}, got {got:?}")]
44    ShapeMismatch {
45        /// Expected shape.
46        expected: Vec<usize>,
47        /// Actual shape.
48        got: Vec<usize>,
49    },
50
51    /// Invalid kernel structure.
52    #[error("invalid kernel structure: {reason}")]
53    InvalidKernel {
54        /// Reason for the error.
55        reason: String,
56    },
57}
58
59/// Configuration for lowering.
60#[derive(Clone, Debug)]
61pub struct LowerConfig {
62    /// Target architecture for vectorization hints.
63    pub target: TargetArch,
64    /// Whether to mark loops as potentially vectorizable.
65    pub enable_vectorization: bool,
66    /// Whether to mark loops as potentially parallelizable.
67    pub enable_parallelization: bool,
68    /// Minimum trip count for vectorization.
69    pub vectorize_threshold: usize,
70    /// Minimum trip count for parallelization.
71    pub parallelize_threshold: usize,
72}
73
74impl Default for LowerConfig {
75    fn default() -> Self {
76        Self {
77            target: TargetArch::default(),
78            enable_vectorization: true,
79            enable_parallelization: true,
80            vectorize_threshold: 4,
81            parallelize_threshold: 1024,
82        }
83    }
84}
85
86/// Context for lowering a single kernel.
87struct LowerContext {
88    /// Configuration.
89    config: LowerConfig,
90    /// Next value ID.
91    next_value: u32,
92    /// Next loop ID.
93    next_loop: u32,
94    /// Mapping from tensor refs to value IDs.
95    tensor_values: FxHashMap<u64, ValueId>,
96    /// Allocations for the lowered function.
97    allocations: Vec<Alloc>,
98    /// Loop metadata accumulated during lowering.
99    loop_metadata: Vec<LoopMetadata>,
100    /// Parameters for the lowered function.
101    params: Vec<Param>,
102}
103
104impl LowerContext {
105    fn new(config: LowerConfig) -> Self {
106        Self {
107            config,
108            next_value: 0,
109            next_loop: 0,
110            tensor_values: FxHashMap::default(),
111            allocations: Vec::new(),
112            loop_metadata: Vec::new(),
113            params: Vec::new(),
114        }
115    }
116
117    fn fresh_value(&mut self) -> ValueId {
118        let id = ValueId::new(self.next_value as usize);
119        self.next_value += 1;
120        id
121    }
122
123    fn fresh_loop(&mut self) -> LoopId {
124        let id = LoopId::new(self.next_loop as usize);
125        self.next_loop += 1;
126        id
127    }
128}
129
130/// Lower a collection of fused kernels to Loop IR.
131///
132/// # Arguments
133///
134/// * `kernels` - The fused kernels from Tensor IR
135/// * `config` - Lowering configuration
136///
137/// # Returns
138///
139/// A vector of lowered Loop IR functions.
140pub fn lower_kernels(kernels: &[Kernel], config: LowerConfig) -> Result<Vec<LoopIR>, LowerError> {
141    kernels
142        .iter()
143        .map(|k| lower_kernel(k, config.clone()))
144        .collect()
145}
146
147/// Lower a single kernel to Loop IR.
148pub fn lower_kernel(kernel: &Kernel, config: LowerConfig) -> Result<LoopIR, LowerError> {
149    let mut ctx = LowerContext::new(config);
150
151    // Add input tensors as parameters
152    for (i, input) in kernel.inputs.iter().enumerate() {
153        let param = tensor_ref_to_param(input, i, &mut ctx);
154        ctx.params.push(param);
155    }
156
157    // Add output tensors as parameters (mutable)
158    for (i, output) in kernel.outputs.iter().enumerate() {
159        let param = tensor_ref_to_param(output, kernel.inputs.len() + i, &mut ctx);
160        ctx.params.push(param);
161    }
162
163    // Lower the kernel body
164    let body = match &kernel.body {
165        KernelBody::Fused(ops) => lower_fused_ops(ops, kernel, &mut ctx)?,
166        KernelBody::LoopNest(nest) => lower_tensor_loop_nest(nest, &mut ctx)?,
167    };
168
169    Ok(LoopIR {
170        name: kernel.name,
171        params: ctx.params,
172        return_ty: LoopType::Void,
173        body,
174        allocs: ctx.allocations,
175        loop_info: ctx.loop_metadata,
176    })
177}
178
179/// Convert a tensor reference to a function parameter.
180fn tensor_ref_to_param(tensor: &TensorRef, index: usize, ctx: &mut LowerContext) -> Param {
181    let elem_ty = ScalarType::from_dtype(tensor.meta.dtype);
182    let value_id = ctx.fresh_value();
183
184    // Register the tensor -> value mapping
185    ctx.tensor_values.insert(tensor.id.index() as u64, value_id);
186
187    Param {
188        name: Symbol::intern(&format!("tensor_{}", index)),
189        ty: LoopType::Ptr(Box::new(LoopType::Scalar(elem_ty))),
190        is_ptr: true,
191    }
192}
193
194/// Lower fused tensor operations to a loop body.
195fn lower_fused_ops(
196    ops: &[TensorOp],
197    kernel: &Kernel,
198    ctx: &mut LowerContext,
199) -> Result<Body, LowerError> {
200    // For fused ops, we generate a single loop nest over the output shape
201    // The innermost loop body contains all the fused operations
202
203    // Get the output shape from the first output tensor
204    let output_shape: Vec<usize> = if let Some(output) = kernel.outputs.first() {
205        output
206            .meta
207            .shape
208            .dims()
209            .iter()
210            .map(|d| d.static_value().unwrap_or(0))
211            .collect()
212    } else {
213        return Err(LowerError::InvalidKernel {
214            reason: "kernel has no outputs".to_string(),
215        });
216    };
217
218    // Generate loop nest for the output shape
219    let (body, loop_vars) = generate_loop_nest(&output_shape, ctx)?;
220
221    // Generate the inner loop body with fused operations
222    let inner_stmts = lower_fused_ops_body(ops, &loop_vars, kernel, ctx)?;
223
224    // Insert the inner statements into the innermost loop
225    let mut result_body = body;
226    insert_inner_stmts(&mut result_body, inner_stmts);
227
228    Ok(result_body)
229}
230
231/// Generate a loop nest for the given shape.
232/// Returns the body with nested loops and the loop variables.
233fn generate_loop_nest(
234    shape: &[usize],
235    ctx: &mut LowerContext,
236) -> Result<(Body, Vec<ValueId>), LowerError> {
237    let mut loop_vars = Vec::with_capacity(shape.len());
238    let mut loops = Vec::with_capacity(shape.len());
239
240    for (dim_idx, &dim_size) in shape.iter().enumerate() {
241        let loop_id = ctx.fresh_loop();
242        let loop_var = ctx.fresh_value();
243        loop_vars.push(loop_var);
244
245        // Determine loop attributes based on config
246        let mut attrs = LoopAttrs::INDEPENDENT;
247
248        // Mark outer loops as potentially parallel
249        if ctx.config.enable_parallelization
250            && dim_idx == 0
251            && dim_size >= ctx.config.parallelize_threshold
252        {
253            attrs |= LoopAttrs::PARALLEL;
254        }
255
256        // Mark innermost loop as potentially vectorizable
257        if ctx.config.enable_vectorization
258            && dim_idx == shape.len() - 1
259            && dim_size >= ctx.config.vectorize_threshold
260        {
261            attrs |= LoopAttrs::VECTORIZE;
262        }
263
264        // Create loop metadata
265        ctx.loop_metadata.push(LoopMetadata {
266            id: loop_id,
267            trip_count: TripCount::Static(dim_size),
268            vector_width: None,   // Will be filled by vectorization pass
269            parallel_chunk: None, // Will be filled by parallelization pass
270            unroll_factor: None,
271            dependencies: Vec::new(),
272        });
273
274        loops.push(Loop {
275            id: loop_id,
276            var: loop_var,
277            lower: Value::i64(0),
278            upper: Value::i64(dim_size as i64),
279            step: Value::i64(1),
280            body: Body::new(),
281            attrs,
282        });
283    }
284
285    // Build nested structure: outermost loop contains next loop, etc.
286    let mut body = Body::new();
287    if loops.is_empty() {
288        return Ok((body, loop_vars));
289    }
290
291    // Nest loops from innermost to outermost
292    let mut current_loop = loops.pop().unwrap();
293    while let Some(mut outer) = loops.pop() {
294        outer.body.push(Stmt::Loop(current_loop));
295        current_loop = outer;
296    }
297
298    body.push(Stmt::Loop(current_loop));
299    Ok((body, loop_vars))
300}
301
302/// Lower fused operations to statements.
303fn lower_fused_ops_body(
304    ops: &[TensorOp],
305    loop_vars: &[ValueId],
306    _kernel: &Kernel,
307    ctx: &mut LowerContext,
308) -> Result<Vec<Stmt>, LowerError> {
309    let mut stmts = Vec::new();
310
311    for op in ops {
312        lower_tensor_op(op, loop_vars, &mut stmts, ctx)?;
313    }
314
315    Ok(stmts)
316}
317
318/// Lower a single tensor operation.
319fn lower_tensor_op(
320    op: &TensorOp,
321    loop_vars: &[ValueId],
322    stmts: &mut Vec<Stmt>,
323    ctx: &mut LowerContext,
324) -> Result<(), LowerError> {
325    match op {
326        TensorOp::Map(_func, input) => {
327            // Load input element
328            let input_val = load_tensor_element(input, loop_vars, stmts, ctx)?;
329
330            // Apply function (simplified: assume unary arithmetic)
331            let result = ctx.fresh_value();
332            stmts.push(Stmt::Assign(result, Op::Unary(crate::UnOp::Neg, input_val)));
333
334            Ok(())
335        }
336
337        TensorOp::ZipWith(_func, a, b) => {
338            // Load both input elements
339            let a_val = load_tensor_element(a, loop_vars, stmts, ctx)?;
340            let b_val = load_tensor_element(b, loop_vars, stmts, ctx)?;
341
342            // Apply binary function
343            let result = ctx.fresh_value();
344            stmts.push(Stmt::Assign(result, Op::Binary(BinOp::Add, a_val, b_val)));
345
346            Ok(())
347        }
348
349        TensorOp::ReduceAll(reduce_op, input) => {
350            lower_reduction(reduce_op, input, loop_vars, stmts, ctx)
351        }
352
353        TensorOp::Broadcast(_shape, input) => {
354            // Broadcast is handled by adjusting memory access patterns
355            let _ = load_tensor_element(input, loop_vars, stmts, ctx)?;
356            Ok(())
357        }
358
359        TensorOp::Reshape(_shape, input) => {
360            // Reshape is metadata-only for contiguous tensors
361            let _ = load_tensor_element(input, loop_vars, stmts, ctx)?;
362            Ok(())
363        }
364
365        TensorOp::Transpose(_perm, input) => {
366            // Transpose adjusts strides
367            let _ = load_tensor_element(input, loop_vars, stmts, ctx)?;
368            Ok(())
369        }
370
371        _ => Err(LowerError::UnsupportedOp {
372            op: format!("{:?}", std::mem::discriminant(op)),
373        }),
374    }
375}
376
377/// Load a tensor element at the current loop indices.
378fn load_tensor_element(
379    tensor: &TensorRef,
380    loop_vars: &[ValueId],
381    stmts: &mut Vec<Stmt>,
382    ctx: &mut LowerContext,
383) -> Result<Value, LowerError> {
384    let elem_ty = ScalarType::from_dtype(tensor.meta.dtype);
385
386    // Compute linear index from loop variables and strides
387    let index = compute_linear_index(tensor, loop_vars)?;
388
389    // Get buffer ID from tensor metadata
390    let buffer_id = tensor
391        .meta
392        .alias
393        .unwrap_or(BufferId::new(tensor.id.index()));
394
395    // Create memory reference
396    let mem_ref = MemRef {
397        buffer: buffer_id,
398        index,
399        elem_ty: LoopType::Scalar(elem_ty),
400        access: compute_access_pattern(tensor),
401    };
402
403    // Generate load
404    let result = ctx.fresh_value();
405    stmts.push(Stmt::Assign(result, Op::Load(mem_ref)));
406
407    Ok(Value::Var(result, LoopType::Scalar(elem_ty)))
408}
409
410/// Compute linear index from loop variables and tensor strides.
411fn compute_linear_index(_tensor: &TensorRef, loop_vars: &[ValueId]) -> Result<Value, LowerError> {
412    // For a tensor with shape [N, M, K] and strides [s0, s1, s2],
413    // the linear index is: i*s0 + j*s1 + k*s2
414
415    if loop_vars.is_empty() {
416        return Ok(Value::i64(0));
417    }
418
419    // For now, return a simple index using the first loop var
420    // In a full implementation, we'd build the affine index expression
421    let first_var = loop_vars[0];
422    Ok(Value::Var(first_var, LoopType::Scalar(ScalarType::I64)))
423}
424
425/// Compute the memory access pattern for a tensor.
426fn compute_access_pattern(tensor: &TensorRef) -> AccessPattern {
427    let strides = tensor.meta.strides.values();
428
429    // If innermost stride is 1, access is sequential
430    if strides.last() == Some(&1) {
431        AccessPattern::Sequential
432    } else if let Some(&stride) = strides.last() {
433        AccessPattern::Strided(stride)
434    } else {
435        AccessPattern::Random
436    }
437}
438
439/// Lower a reduction operation.
440fn lower_reduction(
441    reduce_op: &TensorReduceOp,
442    input: &TensorRef,
443    loop_vars: &[ValueId],
444    stmts: &mut Vec<Stmt>,
445    ctx: &mut LowerContext,
446) -> Result<(), LowerError> {
447    let elem_ty = ScalarType::from_dtype(input.meta.dtype);
448    let bits = elem_ty.size_bytes() as u8 * 8;
449
450    // Initialize accumulator (comment as placeholder)
451    let _init_val = match reduce_op {
452        TensorReduceOp::Sum => Value::float(0.0, bits),
453        TensorReduceOp::Prod => Value::float(1.0, bits),
454        TensorReduceOp::Min => Value::float(f64::INFINITY, bits),
455        TensorReduceOp::Max => Value::float(f64::NEG_INFINITY, bits),
456        _ => Value::float(0.0, bits),
457    };
458
459    // Add accumulator initialization (will be at function start)
460    stmts.push(Stmt::Comment(format!(
461        "reduction accumulator for {:?}",
462        reduce_op
463    )));
464
465    // Initialize acc value
466    let acc = ctx.fresh_value();
467
468    // Load input element
469    let input_val = load_tensor_element(input, loop_vars, stmts, ctx)?;
470
471    // Update accumulator
472    let bin_op = match reduce_op {
473        TensorReduceOp::Sum => BinOp::Add,
474        TensorReduceOp::Prod => BinOp::Mul,
475        TensorReduceOp::Min => BinOp::FMin,
476        TensorReduceOp::Max => BinOp::FMax,
477        _ => BinOp::Add,
478    };
479
480    let new_acc = ctx.fresh_value();
481    stmts.push(Stmt::Assign(
482        new_acc,
483        Op::Binary(
484            bin_op,
485            Value::Var(acc, LoopType::Scalar(elem_ty)),
486            input_val,
487        ),
488    ));
489
490    Ok(())
491}
492
493/// Lower a Tensor IR loop nest to Loop IR.
494fn lower_tensor_loop_nest(
495    nest: &TensorLoopNest,
496    ctx: &mut LowerContext,
497) -> Result<Body, LowerError> {
498    // TensorLoopNest already has explicit loop structure
499    // Convert it to Loop IR format
500    let mut loops = Vec::new();
501
502    for loop_spec in &nest.loops {
503        let loop_id = ctx.fresh_loop();
504        let loop_var = ctx.fresh_value();
505
506        let mut attrs = LoopAttrs::empty();
507        if loop_spec.parallel {
508            attrs |= LoopAttrs::PARALLEL;
509        }
510        if loop_spec.vectorize.is_some() {
511            attrs |= LoopAttrs::VECTORIZE;
512        }
513
514        let trip_count = loop_spec
515            .upper
516            .static_value()
517            .map(TripCount::Static)
518            .unwrap_or(TripCount::Dynamic);
519
520        let upper_bound = loop_spec.upper.static_value().unwrap_or(0) as i64;
521
522        ctx.loop_metadata.push(LoopMetadata {
523            id: loop_id,
524            trip_count,
525            vector_width: None,
526            parallel_chunk: None,
527            unroll_factor: None,
528            dependencies: Vec::new(),
529        });
530
531        loops.push(Loop {
532            id: loop_id,
533            var: loop_var,
534            lower: Value::i64(loop_spec.lower),
535            upper: Value::i64(upper_bound),
536            step: Value::i64(loop_spec.step),
537            body: Body::new(),
538            attrs,
539        });
540    }
541
542    // Build nested structure
543    let mut body = Body::new();
544    if loops.is_empty() {
545        return Ok(body);
546    }
547
548    let mut current_loop = loops.pop().unwrap();
549    while let Some(mut outer) = loops.pop() {
550        outer.body.push(Stmt::Loop(current_loop));
551        current_loop = outer;
552    }
553
554    body.push(Stmt::Loop(current_loop));
555    Ok(body)
556}
557
558/// Insert statements into the innermost loop body.
559fn insert_inner_stmts(body: &mut Body, stmts: Vec<Stmt>) {
560    fn find_innermost_and_insert(body: &mut Body, stmts: Vec<Stmt>) {
561        if let Some(Stmt::Loop(ref mut lp)) = body.stmts.last_mut() {
562            if lp.body.stmts.is_empty() || !matches!(lp.body.stmts.last(), Some(Stmt::Loop(_))) {
563                // This is the innermost loop
564                lp.body.stmts.extend(stmts);
565            } else {
566                // Recurse into nested loop
567                find_innermost_and_insert(&mut lp.body, stmts);
568            }
569        } else {
570            // No loops, just add statements directly
571            body.stmts.extend(stmts);
572        }
573    }
574
575    find_innermost_and_insert(body, stmts);
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use bhc_span::Span;
    use bhc_tensor_ir::{
        DType, FusionInfo, KernelId, Layout, MapFn, Shape, Strides, TensorId, TensorMeta,
    };

    /// A 1-D, 1024-element `f32` tensor with the given id, strides and layout.
    fn tensor(id: TensorId, strides: Strides, layout: Layout) -> TensorRef {
        TensorRef {
            id,
            meta: TensorMeta {
                dtype: DType::Float32,
                shape: Shape::from_static([1024]),
                strides,
                layout,
                alias: None,
            },
        }
    }

    /// A fused kernel that maps a contiguous 1024-element input to an output.
    fn make_test_kernel() -> Kernel {
        let input = tensor(TensorId::new(0), Strides::new([1]), Layout::Contiguous);
        let output = tensor(TensorId::new(1), Strides::new([1]), Layout::Contiguous);

        let map_fn = MapFn {
            name: Symbol::intern("f"),
            span: Span::DUMMY,
        };

        Kernel {
            id: KernelId::new(0),
            name: Symbol::intern("test_kernel"),
            inputs: vec![input.clone()],
            outputs: vec![output],
            body: KernelBody::Fused(vec![TensorOp::Map(map_fn, input)]),
            allocs: vec![],
            fusion_info: FusionInfo {
                original_ops: vec![],
                decisions: vec![],
                complete: true,
            },
        }
    }

    /// The loop at the top of the lowered body; fails the test if there is none.
    fn first_loop(loop_ir: &LoopIR) -> &Loop {
        match loop_ir.body.stmts.first() {
            Some(Stmt::Loop(lp)) => lp,
            _ => panic!("expected the lowered body to start with a loop"),
        }
    }

    #[test]
    fn test_lower_simple_kernel() {
        let kernel = make_test_kernel();

        let loop_ir = lower_kernel(&kernel, LowerConfig::default())
            .expect("lowering a static 1-D map kernel should succeed");

        assert_eq!(loop_ir.name.as_str(), "test_kernel");
        assert_eq!(loop_ir.params.len(), 2); // input + output
        assert!(!loop_ir.body.stmts.is_empty());
    }

    #[test]
    fn test_lower_generates_loop_nest() {
        let kernel = make_test_kernel();

        let loop_ir = lower_kernel(&kernel, LowerConfig::default()).expect("lowering succeeds");

        assert!(matches!(loop_ir.body.stmts.first(), Some(Stmt::Loop(_))));
    }

    #[test]
    fn test_lower_marks_vectorizable() {
        let kernel = make_test_kernel();
        let config = LowerConfig {
            enable_vectorization: true,
            vectorize_threshold: 4,
            ..LowerConfig::default()
        };

        let loop_ir = lower_kernel(&kernel, config).expect("lowering succeeds");

        assert!(first_loop(&loop_ir).attrs.contains(LoopAttrs::VECTORIZE));
    }

    #[test]
    fn test_lower_respects_vectorize_threshold() {
        let kernel = make_test_kernel();
        let config = LowerConfig {
            vectorize_threshold: 2048,
            ..LowerConfig::default()
        };

        let loop_ir = lower_kernel(&kernel, config).expect("lowering succeeds");

        assert!(!first_loop(&loop_ir).attrs.contains(LoopAttrs::VECTORIZE));
    }

    #[test]
    fn test_lower_marks_parallel_outer_loop() {
        let kernel = make_test_kernel();

        // Extent 1024 meets the default parallelize threshold of 1024.
        let loop_ir = lower_kernel(&kernel, LowerConfig::default()).expect("lowering succeeds");
        assert!(first_loop(&loop_ir).attrs.contains(LoopAttrs::PARALLEL));

        let disabled = LowerConfig {
            enable_parallelization: false,
            ..LowerConfig::default()
        };
        let loop_ir = lower_kernel(&kernel, disabled).expect("lowering succeeds");
        assert!(!first_loop(&loop_ir).attrs.contains(LoopAttrs::PARALLEL));
    }

    #[test]
    fn test_lower_rejects_kernel_without_outputs() {
        let mut kernel = make_test_kernel();
        kernel.outputs.clear();

        let result = lower_kernel(&kernel, LowerConfig::default());

        assert!(matches!(result, Err(LowerError::InvalidKernel { .. })));
    }

    #[test]
    fn test_lower_kernels_empty() {
        let lowered = lower_kernels(&[], LowerConfig::default()).expect("nothing to fail on");
        assert!(lowered.is_empty());
    }

    #[test]
    fn test_sequential_access_pattern() {
        let contiguous = tensor(TensorId::new(0), Strides::new([1]), Layout::Contiguous);
        assert_eq!(compute_access_pattern(&contiguous), AccessPattern::Sequential);
    }

    #[test]
    fn test_strided_access_pattern() {
        let strided = tensor(TensorId::new(0), Strides::new([4]), Layout::Strided);
        assert_eq!(compute_access_pattern(&strided), AccessPattern::Strided(4));
    }
}