use quote::quote;
use super::{BuiltinKind, OriginalParam, OriginalParamKind, is_slice_reference};
/// Token fragments collected by `generate_variant_body` for one set of active
/// kernel parameters and later spliced together by `gen_dispatch_loop`.
struct VariantBody {
/// Statements emitted first in the dispatch code; for binding parameters they
/// rebind the parameter name to the slice obtained from the buffer.
extractions: Vec<proc_macro2::TokenStream>,
/// Declarations of the zero-initialised per-workgroup shared variables,
/// emitted inside the per-workgroup closure.
workgroup_inits: Vec<proc_macro2::TokenStream>,
/// `true` when at least one parameter is a workgroup variable; selects the
/// `dispatch_workgroup_threads` form of the dispatch loop.
has_workgroup: bool,
/// Statements capturing the pointer (as `usize`) and length of every mutable
/// binding slice, emitted right after `extractions`.
mut_ptr_setup: Vec<proc_macro2::TokenStream>,
/// Statements capturing the address (as `usize`) of every workgroup shared
/// variable, emitted right after `workgroup_inits`.
wg_ptr_setup: Vec<proc_macro2::TokenStream>,
/// Statements rebuilding each mutable slice from the pointer/length captured
/// by `mut_ptr_setup`, emitted inside the dispatch closure before the call.
thread_restore: Vec<proc_macro2::TokenStream>,
/// One argument expression per active parameter, in parameter order, used to
/// call the kernel function.
threaded_call_args: Vec<proc_macro2::TokenStream>,
}
fn generate_variant_body(
active_params: &[&OriginalParam],
workgroup_size: [u32; 3],
) -> VariantBody {
let wg_x = workgroup_size[0];
let wg_y = workgroup_size[1];
let wg_z = workgroup_size[2];
let mut body = VariantBody {
extractions: vec![],
workgroup_inits: vec![],
has_workgroup: false,
mut_ptr_setup: vec![],
wg_ptr_setup: vec![],
thread_restore: vec![],
threaded_call_args: vec![],
};
for param in active_params {
let name = ¶m.name;
match ¶m.kind {
OriginalParamKind::Builtin(kind) => {
let arg = match kind {
BuiltinKind::GlobalInvocationId => {
quote! {
khal_std::glamx::UVec3::new(
__wg_x * #wg_x + __tx,
__wg_y * #wg_y + __ty,
__wg_z * #wg_z + __tz,
)
}
}
BuiltinKind::LocalInvocationId => {
quote! {
khal_std::glamx::UVec3::new(__tx, __ty, __tz)
}
}
BuiltinKind::WorkgroupId => {
quote! {
khal_std::glamx::UVec3::new(__wg_x, __wg_y, __wg_z)
}
}
BuiltinKind::NumWorkgroups => {
quote! {
khal_std::glamx::UVec3::new(__grid[0], __grid[1], __grid[2])
}
}
BuiltinKind::LocalInvocationIndex => {
quote! {
(__tz * #wg_x * #wg_y + __ty * #wg_x + __tx)
}
}
_ => {
quote! { Default::default() }
}
};
body.threaded_call_args.push(arg);
}
OriginalParamKind::Binding {
is_uniform: _,
is_mutable,
} => {
let is_slice = is_slice_reference(¶m.ty);
if *is_mutable {
let buf_name = syn::Ident::new(&format!("__cpu_{}_buf", name), name.span());
body.extractions.push(quote! {
let mut #buf_name = #name.as_gpu_slice_mut();
let #name = #buf_name.unwrap_slice();
});
body.threaded_call_args.push(if is_slice {
quote! { #name }
} else {
quote! { &mut #name[0] }
});
let ptr_name = syn::Ident::new(&format!("__{}_ptr", name), name.span());
let len_name = syn::Ident::new(&format!("__{}_len", name), name.span());
body.mut_ptr_setup.push(quote! {
let #ptr_name = #name.as_mut_ptr() as usize;
let #len_name = #name.len();
});
body.thread_restore.push(quote! {
let #name = unsafe { core::slice::from_raw_parts_mut(#ptr_name as *mut _, #len_name) };
});
} else {
let slice_name = syn::Ident::new(&format!("__cpu_{}_slice", name), name.span());
body.extractions.push(quote! {
let #slice_name = #name.as_gpu_slice();
let #name = #slice_name.unwrap_slice();
});
body.threaded_call_args.push(if is_slice {
quote! { #name }
} else {
quote! { &#name[0] }
});
}
}
OriginalParamKind::PushConstant => {
body.threaded_call_args.push(quote! { &#name });
}
OriginalParamKind::Workgroup => {
body.has_workgroup = true;
let inner_ty = if let syn::Type::Reference(ref_type) = ¶m.ty {
&*ref_type.elem
} else {
¶m.ty
};
let shared_name = syn::Ident::new(&format!("__cpu_{}_shared", name), name.span());
let ptr_name = syn::Ident::new(&format!("__{}_wg_ptr", name), name.span());
body.workgroup_inits.push(quote! {
let mut #shared_name: #inner_ty = unsafe { core::mem::zeroed() };
});
body.wg_ptr_setup.push(quote! {
let #ptr_name = core::ptr::addr_of_mut!(#shared_name) as usize;
});
body.threaded_call_args
.push(quote! { unsafe { &mut *(#ptr_name as *mut _) } });
}
}
}
body
}
fn gen_dispatch_loop(
body: &VariantBody,
workgroup_size: [u32; 3],
func_ident: &syn::Ident,
) -> proc_macro2::TokenStream {
let wg_x = workgroup_size[0];
let wg_y = workgroup_size[1];
let wg_z = workgroup_size[2];
let extractions = &body.extractions;
let workgroup_inits = &body.workgroup_inits;
let mut_ptr_setup = &body.mut_ptr_setup;
let thread_restore = &body.thread_restore;
let threaded_call_args = &body.threaded_call_args;
let func = func_ident;
if body.has_workgroup {
let wg_ptr_setup = &body.wg_ptr_setup;
quote! {
#(#extractions)*
#(#mut_ptr_setup)*
let __total_threads = (#wg_x as usize) * (#wg_y as usize) * (#wg_z as usize);
khal_std::arch::cpu::dispatch_workgroups(
(__grid[0] as usize) * (__grid[1] as usize) * (__grid[2] as usize),
|__wg_flat| {
let __wg_x = __wg_flat % __grid[0];
let __wg_y = (__wg_flat / __grid[0]) % __grid[1];
let __wg_z = __wg_flat / (__grid[0] * __grid[1]);
#(#workgroup_inits)*
#(#wg_ptr_setup)*
khal_std::arch::cpu::dispatch_workgroup_threads(__total_threads, |__flat| {
let __flat = __flat;
let __tx = __flat % #wg_x;
let __ty = (__flat / #wg_x) % #wg_y;
let __tz = __flat / (#wg_x * #wg_y);
#(#thread_restore)*
#func(#(#threaded_call_args),*);
});
},
);
}
} else {
quote! {
#(#extractions)*
#(#mut_ptr_setup)*
khal_std::arch::cpu::dispatch_workgroups(
(__grid[0] as usize) * (__grid[1] as usize) * (__grid[2] as usize),
|__wg_flat| {
let __wg_x = __wg_flat % __grid[0];
let __wg_y = (__wg_flat / __grid[0]) % __grid[1];
let __wg_z = __wg_flat / (__grid[0] * __grid[1]);
#(#thread_restore)*
for __tz in 0..#wg_z { for __ty in 0..#wg_y { for __tx in 0..#wg_x {
#func(#(#threaded_call_args),*);
}}}
},
);
}
}
}
/// Returns `true` when the two attribute lists have the same length and every pair of
/// attributes renders to the same token string.
///
/// Attributes are compared through their stringified tokens, so the check is purely
/// textual and order-sensitive.
fn cfg_attr_sets_match(lhs: &[syn::Attribute], rhs: &[syn::Attribute]) -> bool {
    lhs.len() == rhs.len()
        && lhs
            .iter()
            .zip(rhs)
            .all(|(a, b)| quote!(#a).to_string() == quote!(#b).to_string())
}

/// Generates the `#[cfg(feature = "cpu")]` block that runs the kernel on the CPU when
/// `__pass.is_cpu()` is true and then returns `Ok(())`.
///
/// The emitted code expects `__pass` and `__dispatch_grid` to be in scope and defines
/// `__wg_size` and `__grid` before the dispatch statements.
///
/// When no parameter carries cfg attributes, a single dispatch over all parameters is
/// emitted. Otherwise one dispatch block is emitted per distinct cfg attribute set (in
/// first-seen order), each guarded by that set and using the parameters that are either
/// unconditional or carry exactly that set.
///
/// NOTE(review): in the cfg-variant case, if every variant is compiled out the emitted
/// code still returns `Ok(())` without dispatching anything -- confirm that callers
/// always enable exactly one variant.
pub(super) fn generate_cpu_dispatch_block(
    original_params: &[OriginalParam],
    workgroup_size: [u32; 3],
    func_ident: &syn::Ident,
) -> proc_macro2::TokenStream {
    let [wg_x, wg_y, wg_z] = workgroup_size;

    // Distinct non-empty cfg attribute sets, borrowed from the parameters.
    let mut cfg_variant_sets: Vec<&[syn::Attribute]> = Vec::new();
    for param in original_params {
        let attrs: &[syn::Attribute] = &param.cfg_attrs;
        if attrs.is_empty() {
            continue;
        }
        let already_known = cfg_variant_sets
            .iter()
            .any(|known| cfg_attr_sets_match(known, attrs));
        if !already_known {
            cfg_variant_sets.push(attrs);
        }
    }

    // Builds the dispatch statements for one concrete parameter list.
    let dispatch_for = |params: &[&OriginalParam]| {
        let body = generate_variant_body(params, workgroup_size);
        gen_dispatch_loop(&body, workgroup_size, func_ident)
    };

    let dispatch_body = if cfg_variant_sets.is_empty() {
        let all_params: Vec<&OriginalParam> = original_params.iter().collect();
        dispatch_for(&all_params)
    } else {
        let variant_blocks: Vec<proc_macro2::TokenStream> = cfg_variant_sets
            .iter()
            .map(|&cfg_attrs| {
                let active_params: Vec<&OriginalParam> = original_params
                    .iter()
                    .filter(|p| {
                        p.cfg_attrs.is_empty() || cfg_attr_sets_match(&p.cfg_attrs, cfg_attrs)
                    })
                    .collect();
                let dispatch_loop = dispatch_for(&active_params);
                quote! {
                    #(#cfg_attrs)*
                    {
                        #dispatch_loop
                    }
                }
            })
            .collect();
        quote! { #(#variant_blocks)* }
    };

    quote! {
        #[cfg(feature = "cpu")]
        {
            if __pass.is_cpu() {
                let __wg_size = [#wg_x, #wg_y, #wg_z];
                let __grid = __dispatch_grid.resolve_to_workgroup_counts(&__wg_size);
                #dispatch_body
                return Ok(());
            }
        }
    }
}