#![allow(dead_code)]
use proc_macro2::TokenStream;
use quote::quote;
use crate::kernel_ir::{KernelParam, KernelType};
use super::LoweringContext;
/// Lowers every kernel parameter, in slice order, into one token stream.
///
/// The output holds each parameter's generated declaration/load code
/// concatenated in the same order as `params`. Each parameter is also
/// recorded in `ctx` by the per-parameter lowering routines.
///
/// # Errors
/// Returns the first error produced while lowering a parameter. Parameters
/// after the failing one are not lowered.
pub fn lower_params(
    ctx: &mut LoweringContext,
    params: &[KernelParam],
) -> syn::Result<TokenStream> {
    // `Result: FromIterator` stops at the first `Err`, matching a `?` loop.
    params
        .iter()
        .map(|param| lower_one_param(ctx, param))
        .collect()
}
/// Lowers a single kernel parameter by dispatching on its declared type.
///
/// - `SliceRef(elem)` / `SliceMutRef(elem)` are lowered as pointer
///   parameters with element type `elem`. The full slice type is also passed
///   along so it can be recorded in the context.
/// - Every other type is lowered as a scalar parameter.
///
/// # Errors
/// Propagates any error from the pointer/scalar lowering routine.
fn lower_one_param(ctx: &mut LoweringContext, param: &KernelParam) -> syn::Result<TokenStream> {
    match &param.ty {
        KernelType::SliceRef(elem_ty) | KernelType::SliceMutRef(elem_ty) => {
            lower_pointer_param(ctx, &param.name, elem_ty, &param.ty)
        }
        scalar_ty => lower_scalar_param(ctx, &param.name, scalar_ty),
    }
}
/// Lowers a slice parameter into a PTX pointer parameter.
///
/// The generated code does three things:
/// 1. Declares a pointer parameter named `"<name>_ptr"` whose element type is
///    `elem_ty`.
/// 2. Loads that parameter with `LdParam` into a freshly allocated `U64`
///    register.
/// 3. Applies `CvtaToGlobal` to that register, writing the result into a
///    second freshly allocated `U64` register.
///
/// Effects on `ctx`:
/// - `ctx.locals[name]` becomes `(register from step 2, full_ty)`.
/// - `ctx.global_addrs[name]` becomes the register from step 3.
fn lower_pointer_param(
    ctx: &mut LoweringContext,
    name: &str,
    elem_ty: &KernelType,
    full_ty: &KernelType,
) -> syn::Result<TokenStream> {
    let ptx_elem_ty = ctx.ptx_type_tokens(elem_ty);
    let ptr_param_name = format!("{name}_ptr");
    let loaded_reg = ctx.fresh_reg();
    let converted_reg = ctx.fresh_reg();

    // `quote!` only borrows the interpolated values, so both registers can be
    // moved into the context maps afterwards without cloning.
    let tokens = quote! {
        kernel.add_param(PtxParam::pointer(#ptr_param_name, PtxType::#ptx_elem_ty));
        let #loaded_reg = alloc.alloc(PtxType::U64);
        kernel.push(PtxInstruction::Memory(MemoryOp::LdParam {
            dst: #loaded_reg,
            param_name: #ptr_param_name.to_string(),
            ty: PtxType::U64,
        }));
        let #converted_reg = alloc.alloc(PtxType::U64);
        kernel.push(PtxInstruction::Memory(MemoryOp::CvtaToGlobal {
            dst: #converted_reg,
            src: #loaded_reg,
        }));
    };

    ctx.locals
        .insert(name.to_string(), (loaded_reg, full_ty.clone()));
    ctx.global_addrs.insert(name.to_string(), converted_reg);

    Ok(tokens)
}
/// Lowers a non-slice parameter into a PTX scalar parameter.
///
/// The generated code declares a scalar parameter called `name` with the PTX
/// type for `ty`. It then allocates a register of that same type and loads
/// the parameter into it with `LdParam`.
///
/// Effect on `ctx`: `ctx.locals[name]` becomes `(that register, ty)`.
fn lower_scalar_param(
    ctx: &mut LoweringContext,
    name: &str,
    ty: &KernelType,
) -> syn::Result<TokenStream> {
    let ptx_ty = ctx.ptx_type_tokens(ty);
    let value_reg = ctx.fresh_reg();

    // Build the tokens first (borrowing `value_reg`), then hand the register
    // over to the context.
    let tokens = quote! {
        kernel.add_param(PtxParam::scalar(#name, PtxType::#ptx_ty));
        let #value_reg = alloc.alloc(PtxType::#ptx_ty);
        kernel.push(PtxInstruction::Memory(MemoryOp::LdParam {
            dst: #value_reg,
            param_name: #name.to_string(),
            ty: PtxType::#ptx_ty,
        }));
    };

    ctx.locals.insert(name.to_string(), (value_reg, ty.clone()));

    Ok(tokens)
}
// Unit tests for parameter lowering: they inspect the stringified generated
// tokens and the registrations made in `LoweringContext`.
#[cfg(test)]
mod tests {
    use super::*;
    use proc_macro2::Span;

    /// A `u32` parameter becomes a scalar PTX param loaded via `LdParam`.
    #[test]
    fn lower_scalar_u32() {
        let mut ctx = LoweringContext::new();
        let param = KernelParam {
            name: "n".to_string(),
            ty: KernelType::U32,
            span: Span::call_site(),
        };
        let tokens = lower_params(&mut ctx, &[param]).unwrap();
        let code = tokens.to_string();
        assert!(code.contains("PtxParam :: scalar"));
        assert!(code.contains("\"n\""));
        assert!(code.contains("PtxType :: U32"));
        assert!(code.contains("MemoryOp :: LdParam"));
        assert!(ctx.locals.contains_key("n"));
        assert_eq!(ctx.locals["n"].1, KernelType::U32);
    }

    /// A `&[f32]` parameter becomes a pointer param named `a_ptr`. It is
    /// loaded as `U64`, converted with `CvtaToGlobal`, and registered both as
    /// a local and as a global address.
    #[test]
    fn lower_pointer_f32() {
        let mut ctx = LoweringContext::new();
        let param = KernelParam {
            name: "a".to_string(),
            ty: KernelType::SliceRef(Box::new(KernelType::F32)),
            span: Span::call_site(),
        };
        let tokens = lower_params(&mut ctx, &[param]).unwrap();
        let code = tokens.to_string();
        assert!(code.contains("PtxParam :: pointer"));
        assert!(code.contains("\"a_ptr\""));
        assert!(code.contains("PtxType :: F32"));
        assert!(code.contains("PtxType :: U64"));
        assert!(code.contains("MemoryOp :: LdParam"));
        assert!(code.contains("MemoryOp :: CvtaToGlobal"));
        assert!(ctx.locals.contains_key("a"));
        assert!(ctx.locals["a"].1.is_slice());
        assert!(ctx.global_addrs.contains_key("a"));
    }

    /// Lowering a mixed parameter list registers every parameter as a local.
    /// Only the slice parameters get a global address.
    #[test]
    fn lower_vector_add_params() {
        let mut ctx = LoweringContext::new();
        let params = vec![
            KernelParam {
                name: "a".to_string(),
                ty: KernelType::SliceRef(Box::new(KernelType::F32)),
                span: Span::call_site(),
            },
            KernelParam {
                name: "b".to_string(),
                ty: KernelType::SliceRef(Box::new(KernelType::F32)),
                span: Span::call_site(),
            },
            KernelParam {
                name: "out".to_string(),
                ty: KernelType::SliceMutRef(Box::new(KernelType::F32)),
                span: Span::call_site(),
            },
            KernelParam {
                name: "n".to_string(),
                ty: KernelType::U32,
                span: Span::call_site(),
            },
        ];
        let _tokens = lower_params(&mut ctx, &params).unwrap();
        assert_eq!(ctx.locals.len(), 4);
        assert!(ctx.locals["a"].1.is_slice());
        assert!(ctx.locals["out"].1.is_mut_slice());
        assert_eq!(ctx.locals["n"].1, KernelType::U32);
        assert!(ctx.global_addrs.contains_key("a"));
        assert!(ctx.global_addrs.contains_key("b"));
        assert!(ctx.global_addrs.contains_key("out"));
        assert!(!ctx.global_addrs.contains_key("n"));
    }
}