use proc_macro2::{Ident, TokenStream};
use quote::quote;
use crate::kernel_ir::KernelType;
use crate::kernel_ir::expr::BinOpKind;
use super::LoweringContext;
#[allow(dead_code)] pub fn lower_binop(
ctx: &mut LoweringContext,
op: &BinOpKind,
lhs_reg: &Ident,
rhs_reg: &Ident,
ty: &KernelType,
) -> (Ident, TokenStream) {
let dst = ctx.fresh_reg();
let ptx_ty = ctx.ptx_type_tokens(ty);
let arith_variant = match op {
BinOpKind::Add => quote! {
ArithOp::Add {
dst: #dst,
lhs: Operand::Reg(#lhs_reg),
rhs: Operand::Reg(#rhs_reg),
ty: PtxType::#ptx_ty,
}
},
BinOpKind::Sub => quote! {
ArithOp::Sub {
dst: #dst,
lhs: Operand::Reg(#lhs_reg),
rhs: Operand::Reg(#rhs_reg),
ty: PtxType::#ptx_ty,
}
},
BinOpKind::Mul => quote! {
ArithOp::Mul {
dst: #dst,
lhs: Operand::Reg(#lhs_reg),
rhs: Operand::Reg(#rhs_reg),
ty: PtxType::#ptx_ty,
}
},
BinOpKind::Div => quote! {
ArithOp::Div {
dst: #dst,
lhs: Operand::Reg(#lhs_reg),
rhs: Operand::Reg(#rhs_reg),
ty: PtxType::#ptx_ty,
}
},
BinOpKind::Rem => quote! {
ArithOp::Rem {
dst: #dst,
lhs: Operand::Reg(#lhs_reg),
rhs: Operand::Reg(#rhs_reg),
ty: PtxType::#ptx_ty,
}
},
_ => panic!("lower_binop called with non-arithmetic op: {op:?}"),
};
let tokens = quote! {
let #dst = alloc.alloc(PtxType::#ptx_ty);
kernel.push(PtxInstruction::Arith(#arith_variant));
};
(dst, tokens)
}
/// Lowers a unary negation of `src_reg` into builder-call tokens.
///
/// Allocates a fresh destination register name from `ctx`. It emits one
/// statement that allocates that register with the PTX type derived from `ty`,
/// and one that pushes a `PtxInstruction::Arith(ArithOp::Neg { .. })` reading
/// `src_reg`.
///
/// The emitted tokens are unhygienic. They refer to `alloc`, `kernel`,
/// `PtxType`, `PtxInstruction`, `ArithOp` and `Operand`, which must be in scope
/// at the splice site.
///
/// Returns the destination register identifier and the generated statements.
///
/// NOTE(review): nothing here rejects unsigned `ty`. Presumably the type
/// checker refuses `-` on unsigned integers upstream — confirm.
// NOTE(review): `dead_code` is allowed without a stated reason — confirm it is
// still needed.
#[allow(dead_code)]
pub fn lower_neg(
    ctx: &mut LoweringContext,
    src_reg: &Ident,
    ty: &KernelType,
) -> (Ident, TokenStream) {
    let out_reg = ctx.fresh_reg();
    let ptx_ty = ctx.ptx_type_tokens(ty);

    let neg_op = quote! {
        ArithOp::Neg {
            dst: #out_reg,
            src: Operand::Reg(#src_reg),
            ty: PtxType::#ptx_ty,
        }
    };
    let tokens = quote! {
        let #out_reg = alloc.alloc(PtxType::#ptx_ty);
        kernel.push(PtxInstruction::Arith(#neg_op));
    };
    (out_reg, tokens)
}
#[allow(dead_code)] pub fn lower_bitop(
ctx: &mut LoweringContext,
op: &BinOpKind,
lhs_reg: &Ident,
rhs_reg: &Ident,
ty: &KernelType,
) -> (Ident, TokenStream) {
let dst = ctx.fresh_reg();
let ptx_ty = ctx.ptx_type_tokens(ty);
let arith_variant = match op {
BinOpKind::BitAnd => quote! {
ArithOp::And {
dst: #dst,
lhs: Operand::Reg(#lhs_reg),
rhs: Operand::Reg(#rhs_reg),
ty: PtxType::#ptx_ty,
}
},
BinOpKind::BitOr => quote! {
ArithOp::Or {
dst: #dst,
lhs: Operand::Reg(#lhs_reg),
rhs: Operand::Reg(#rhs_reg),
ty: PtxType::#ptx_ty,
}
},
BinOpKind::BitXor => quote! {
ArithOp::Xor {
dst: #dst,
lhs: Operand::Reg(#lhs_reg),
rhs: Operand::Reg(#rhs_reg),
ty: PtxType::#ptx_ty,
}
},
BinOpKind::Shl => quote! {
ArithOp::Shl {
dst: #dst,
lhs: Operand::Reg(#lhs_reg),
rhs: Operand::Reg(#rhs_reg),
ty: PtxType::#ptx_ty,
}
},
BinOpKind::Shr => quote! {
ArithOp::Shr {
dst: #dst,
lhs: Operand::Reg(#lhs_reg),
rhs: Operand::Reg(#rhs_reg),
ty: PtxType::#ptx_ty,
}
},
_ => panic!("lower_bitop called with non-bitwise op: {op:?}"),
};
let tokens = quote! {
let #dst = alloc.alloc(PtxType::#ptx_ty);
kernel.push(PtxInstruction::Arith(#arith_variant));
};
(dst, tokens)
}
/// Lowers a unary `!` of `src_reg` into builder-call tokens.
///
/// Allocates a fresh destination register name from `ctx`. It emits one
/// statement that allocates that register with the PTX type derived from `ty`,
/// and one that pushes a `PtxInstruction::Arith(ArithOp::Not { .. })` reading
/// `src_reg`.
///
/// The same code path serves integers and booleans. The test suite expects
/// `KernelType::Bool` to map to `PtxType::Pred` here.
///
/// The emitted tokens are unhygienic. They refer to `alloc`, `kernel`,
/// `PtxType`, `PtxInstruction`, `ArithOp` and `Operand`, which must be in scope
/// at the splice site.
///
/// Returns the destination register identifier and the generated statements.
// NOTE(review): `dead_code` is allowed without a stated reason — confirm it is
// still needed.
#[allow(dead_code)]
pub fn lower_not(
    ctx: &mut LoweringContext,
    src_reg: &Ident,
    ty: &KernelType,
) -> (Ident, TokenStream) {
    let out_reg = ctx.fresh_reg();
    let ptx_ty = ctx.ptx_type_tokens(ty);

    let not_op = quote! {
        ArithOp::Not {
            dst: #out_reg,
            src: Operand::Reg(#src_reg),
            ty: PtxType::#ptx_ty,
        }
    };
    let tokens = quote! {
        let #out_reg = alloc.alloc(PtxType::#ptx_ty);
        kernel.push(PtxInstruction::Arith(#not_op));
    };
    (out_reg, tokens)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a call-site register identifier, as every test below needs.
    fn reg(name: &str) -> Ident {
        Ident::new(name, proc_macro2::Span::call_site())
    }

    #[test]
    fn lower_add_produces_arith_op() {
        let mut ctx = LoweringContext::new();
        let (lhs, rhs) = (reg("_kaio_r0"), reg("_kaio_r1"));
        let (dst, tokens) = lower_binop(&mut ctx, &BinOpKind::Add, &lhs, &rhs, &KernelType::F32);
        let code = tokens.to_string();
        assert!(code.contains("ArithOp :: Add"));
        assert!(code.contains("Operand :: Reg"));
        assert!(code.contains("PtxType :: F32"));
        assert!(dst.to_string().starts_with("_kaio_r"));
    }

    #[test]
    fn lower_binop_declares_returned_register() {
        // The register handed back to the caller must be the one the emitted
        // `let` binds; otherwise later lowering would reference an undeclared name.
        let mut ctx = LoweringContext::new();
        let (lhs, rhs) = (reg("_kaio_r0"), reg("_kaio_r1"));
        let (dst, tokens) = lower_binop(&mut ctx, &BinOpKind::Add, &lhs, &rhs, &KernelType::F32);
        let code = tokens.to_string();
        assert!(
            code.contains(&format!("let {dst} =")),
            "returned register {dst} must be declared in emitted tokens, got: {code}"
        );
    }

    #[test]
    fn lower_mul_produces_mul_not_mulwide() {
        let mut ctx = LoweringContext::new();
        let (lhs, rhs) = (reg("_kaio_r0"), reg("_kaio_r1"));
        let (_dst, tokens) = lower_binop(&mut ctx, &BinOpKind::Mul, &lhs, &rhs, &KernelType::U32);
        let code = tokens.to_string();
        assert!(code.contains("ArithOp :: Mul"));
        assert!(!code.contains("MulWide"));
    }

    #[test]
    #[should_panic(expected = "non-arithmetic op")]
    fn lower_binop_rejects_bitwise_op() {
        let mut ctx = LoweringContext::new();
        let (lhs, rhs) = (reg("_kaio_r0"), reg("_kaio_r1"));
        let _ = lower_binop(&mut ctx, &BinOpKind::BitAnd, &lhs, &rhs, &KernelType::U32);
    }

    #[test]
    fn lower_neg_produces_neg_op() {
        let mut ctx = LoweringContext::new();
        let src = reg("_kaio_r0");
        let (_dst, tokens) = lower_neg(&mut ctx, &src, &KernelType::F32);
        let code = tokens.to_string();
        assert!(code.contains("ArithOp :: Neg"));
        assert!(code.contains("PtxType :: F32"));
    }

    #[test]
    fn lower_bitop_and_produces_and_variant() {
        let mut ctx = LoweringContext::new();
        let (lhs, rhs) = (reg("_kaio_r0"), reg("_kaio_r1"));
        let (_dst, tokens) =
            lower_bitop(&mut ctx, &BinOpKind::BitAnd, &lhs, &rhs, &KernelType::U32);
        let code = tokens.to_string();
        assert!(code.contains("ArithOp :: And"));
        assert!(code.contains("PtxType :: U32"));
    }

    #[test]
    #[should_panic(expected = "non-bitwise op")]
    fn lower_bitop_rejects_arithmetic_op() {
        let mut ctx = LoweringContext::new();
        let (lhs, rhs) = (reg("_kaio_r0"), reg("_kaio_r1"));
        let _ = lower_bitop(&mut ctx, &BinOpKind::Add, &lhs, &rhs, &KernelType::U32);
    }

    #[test]
    fn lower_bitop_shr_preserves_signedness_u32() {
        let mut ctx = LoweringContext::new();
        let (lhs, rhs) = (reg("_kaio_r0"), reg("_kaio_r1"));
        let (_dst, tokens) = lower_bitop(&mut ctx, &BinOpKind::Shr, &lhs, &rhs, &KernelType::U32);
        let code = tokens.to_string();
        assert!(code.contains("ArithOp :: Shr"));
        assert!(
            code.contains("PtxType :: U32"),
            "u32 >> n must carry U32 through to ArithOp, got: {code}"
        );
    }

    #[test]
    fn lower_bitop_shr_preserves_signedness_i32() {
        let mut ctx = LoweringContext::new();
        let (lhs, rhs) = (reg("_kaio_r0"), reg("_kaio_r1"));
        let (_dst, tokens) = lower_bitop(&mut ctx, &BinOpKind::Shr, &lhs, &rhs, &KernelType::I32);
        let code = tokens.to_string();
        assert!(code.contains("ArithOp :: Shr"));
        assert!(
            code.contains("PtxType :: S32"),
            "i32 >> n must carry S32 through to ArithOp (arithmetic shift), got: {code}"
        );
    }

    #[test]
    fn lower_bitop_shl_typeless_on_signedness() {
        let mut ctx = LoweringContext::new();
        let (lhs, rhs) = (reg("_kaio_r0"), reg("_kaio_r1"));
        let (_dst, tokens) = lower_bitop(&mut ctx, &BinOpKind::Shl, &lhs, &rhs, &KernelType::I32);
        let code = tokens.to_string();
        assert!(code.contains("ArithOp :: Shl"));
    }

    #[test]
    fn lower_not_emits_not_variant() {
        let mut ctx = LoweringContext::new();
        let src = reg("_kaio_r0");
        let (_dst, tokens) = lower_not(&mut ctx, &src, &KernelType::U32);
        let code = tokens.to_string();
        assert!(code.contains("ArithOp :: Not"));
        assert!(code.contains("PtxType :: U32"));
    }

    #[test]
    fn lower_not_pred_for_bool() {
        let mut ctx = LoweringContext::new();
        let src = reg("_kaio_p0");
        let (_dst, tokens) = lower_not(&mut ctx, &src, &KernelType::Bool);
        let code = tokens.to_string();
        assert!(code.contains("ArithOp :: Not"));
        assert!(
            code.contains("PtxType :: Pred"),
            "!bool must dispatch to Pred type, got: {code}"
        );
    }

    #[test]
    fn fresh_regs_are_unique() {
        let mut ctx = LoweringContext::new();
        let (lhs, rhs) = (reg("_kaio_r0"), reg("_kaio_r1"));
        let (r1, _) = lower_binop(&mut ctx, &BinOpKind::Add, &lhs, &rhs, &KernelType::F32);
        let (r2, _) = lower_binop(&mut ctx, &BinOpKind::Sub, &lhs, &rhs, &KernelType::F32);
        assert_ne!(r1.to_string(), r2.to_string());
    }
}