use proc_macro2::{Ident, TokenStream};
use quote::quote;
use crate::kernel_ir::KernelType;
use crate::kernel_ir::expr::{BinOpKind, KernelExpr};
use super::{LoweringContext, lower_expr};
/// Lowers a value-producing `&&` / `||` expression with short-circuit evaluation.
///
/// The generated code:
/// 1. evaluates `lhs`,
/// 2. allocates a predicate register and moves the `lhs` result into it,
/// 3. emits a predicated branch on `lhs` to a fresh `LOGICAL_DONE` label,
/// 4. evaluates `rhs` and moves its result into the same register,
/// 5. emits the `LOGICAL_DONE` label.
///
/// For `&&` the branch is emitted with `negate: true`, for `||` with
/// `negate: false` (assuming `negate` inverts the predicate test, this skips
/// `rhs` exactly when `lhs` already decides the result).
///
/// Returns the identifier of the result register together with the generated
/// tokens.
///
/// # Errors
///
/// Propagates any error from `lower_expr`, and returns an error when either
/// operand does not have type `bool`.
///
/// # Panics
///
/// Panics if `op` is neither `BinOpKind::And` nor `BinOpKind::Or`. This is
/// checked unconditionally: the `debug_assert!` alone disappears when debug
/// assertions are disabled, which previously let any other operator be lowered
/// silently as if it were `||`.
pub fn lower_logical_expr(
    ctx: &mut LoweringContext,
    op: &BinOpKind,
    lhs: &KernelExpr,
    rhs: &KernelExpr,
) -> syn::Result<(Ident, TokenStream)> {
    debug_assert!(
        op.is_logical(),
        "lower_logical_expr called with non-logical op: {op:?}"
    );

    // Polarity of the short-circuit branch; fail loudly on anything else, in
    // line with `lower_logical_if`.
    let negate_bra = match op {
        BinOpKind::And => true,
        BinOpKind::Or => false,
        _ => unreachable!(
            "lower_logical_expr called with non-logical op: {:?}",
            op
        ),
    };

    let (lhs_reg, lhs_ty, lhs_tokens) = lower_expr(ctx, lhs)?;
    ensure_bool(&lhs_ty, lhs, op)?;

    // Result register and join label are requested before `rhs` is lowered so
    // the order of `ctx` requests stays the same as before.
    let p_out = ctx.fresh_reg();
    let done_label = ctx.fresh_label("LOGICAL_DONE");

    let (rhs_reg, rhs_ty, rhs_tokens) = lower_expr(ctx, rhs)?;
    ensure_bool(&rhs_ty, rhs, op)?;

    let tokens = quote! {
        #lhs_tokens
        let #p_out = alloc.alloc(PtxType::Pred);
        kernel.push(PtxInstruction::Mov {
            dst: #p_out,
            src: Operand::Reg(#lhs_reg),
            ty: PtxType::Pred,
        });
        kernel.push(PtxInstruction::Control(ControlOp::BraPred {
            pred: #lhs_reg,
            target: #done_label.to_string(),
            negate: #negate_bra,
        }));
        #rhs_tokens
        kernel.push(PtxInstruction::Mov {
            dst: #p_out,
            src: Operand::Reg(#rhs_reg),
            ty: PtxType::Pred,
        });
        kernel.push(PtxInstruction::Label(#done_label.to_string()));
    };
    Ok((p_out, tokens))
}
pub fn lower_logical_if(
ctx: &mut LoweringContext,
op: &BinOpKind,
lhs: &KernelExpr,
rhs: &KernelExpr,
skip_label: &str,
) -> syn::Result<TokenStream> {
debug_assert!(
op.is_logical(),
"lower_logical_if called with non-logical op: {op:?}"
);
let (lhs_reg, lhs_ty, lhs_tokens) = lower_expr(ctx, lhs)?;
ensure_bool(&lhs_ty, lhs, op)?;
let (rhs_reg, rhs_ty, rhs_tokens) = lower_expr(ctx, rhs)?;
ensure_bool(&rhs_ty, rhs, op)?;
let skip = skip_label.to_string();
let tokens = match op {
BinOpKind::And => {
quote! {
#lhs_tokens
kernel.push(PtxInstruction::Control(ControlOp::BraPred {
pred: #lhs_reg,
target: #skip.to_string(),
negate: true,
}));
#rhs_tokens
kernel.push(PtxInstruction::Control(ControlOp::BraPred {
pred: #rhs_reg,
target: #skip.to_string(),
negate: true,
}));
}
}
BinOpKind::Or => {
let take_label = ctx.fresh_label("LOGICAL_OR_TAKE");
quote! {
#lhs_tokens
kernel.push(PtxInstruction::Control(ControlOp::BraPred {
pred: #lhs_reg,
target: #take_label.to_string(),
negate: false,
}));
#rhs_tokens
kernel.push(PtxInstruction::Control(ControlOp::BraPred {
pred: #rhs_reg,
target: #skip.to_string(),
negate: true,
}));
kernel.push(PtxInstruction::Label(#take_label.to_string()));
}
}
_ => unreachable!("lower_logical_if guarded by is_logical()"),
};
Ok(tokens)
}
fn ensure_bool(ty: &KernelType, expr: &KernelExpr, op: &BinOpKind) -> syn::Result<()> {
if *ty != KernelType::Bool {
return Err(syn::Error::new(
expr_span(expr),
format!(
"logical operator {} requires bool operands, got {}",
op_display(op),
ty.display_name()
),
));
}
Ok(())
}
/// Returns the source-level spelling of a logical operator for diagnostics.
///
/// `And` maps to `"&&"`, `Or` maps to `"||"`; any other operator yields the
/// placeholder `"<?>"`.
fn op_display(op: &BinOpKind) -> &'static str {
    if matches!(op, BinOpKind::And) {
        "&&"
    } else if matches!(op, BinOpKind::Or) {
        "||"
    } else {
        "<?>"
    }
}
/// Returns the source span stored in a kernel expression.
///
/// Every variant handled here carries its span directly, so the lookup is a
/// single pattern that binds the `span` field of whichever variant is present.
fn expr_span(expr: &KernelExpr) -> proc_macro2::Span {
    let span = match expr {
        KernelExpr::BinOp { span, .. }
        | KernelExpr::UnaryOp { span, .. }
        | KernelExpr::Index { span, .. }
        | KernelExpr::BuiltinCall { span, .. }
        | KernelExpr::Cast { span, .. }
        | KernelExpr::LitInt(_, _, span)
        | KernelExpr::LitFloat(_, _, span)
        | KernelExpr::LitBool(_, span)
        | KernelExpr::Var(_, span)
        | KernelExpr::Paren(_, span) => span,
    };
    *span
}