krnl_macros/
lib.rs

1//! Macros for [krnl](https://docs.rs/krnl).
2#![forbid(unsafe_code)]
3
4use derive_syn_parse::Parse;
5use fxhash::FxHashMap;
6use proc_macro::TokenStream;
7use proc_macro2::{Literal, Span as Span2, TokenStream as TokenStream2};
8use quote::{format_ident, quote, ToTokens};
9use semver::Version;
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11use std::{
12    fmt::{self, Debug},
13    str::FromStr,
14    sync::OnceLock,
15};
16use syn::{
17    parse::{Parse, ParseStream},
18    parse_macro_input,
19    punctuated::Punctuated,
20    token::{
21        And, Brace, Bracket, Colon, Comma, Const, Eq as SynEq, Fn, Gt, Lt, Mod, Mut, Paren, Pound,
22        Unsafe,
23    },
24    Attribute, Block, Error, Ident, LitInt, LitStr, Visibility,
25};
26
/// Crate-local `Result` alias whose error type defaults to `syn::Error`.
type Result<T, E = Error> = std::result::Result<T, E>;
28
/// Parses a `T` that is enclosed in square brackets, i.e. `[ T ]`.
#[derive(Parse, Debug)]
struct InsideBracket<T> {
    // The bracket token is required by the derive to parse the delimiters;
    // `allow(unused)` silences the warning for it.
    #[allow(unused)]
    #[bracket]
    bracket: Bracket,
    // The value parsed from the tokens inside the brackets.
    #[inside(bracket)]
    value: T,
}
37
/// Parses a `T` that is enclosed in curly braces, i.e. `{ T }`.
#[derive(Parse, Debug)]
struct InsideBrace<T> {
    // The brace token; kept so the value can be re-emitted inside the same
    // delimiters by the `ToTokens` impl.
    #[brace]
    brace: Brace,
    // The value parsed from the tokens inside the braces.
    #[inside(brace)]
    value: T,
}
45
46impl<T: ToTokens> ToTokens for InsideBrace<T> {
47    fn to_tokens(&self, tokens: &mut TokenStream2) {
48        self.brace
49            .surround(tokens, |tokens| self.value.to_tokens(tokens));
50    }
51}
52
53#[proc_macro_attribute]
54pub fn module(attr: TokenStream, item: TokenStream) -> TokenStream {
55    if !attr.is_empty() {
56        return Error::new_spanned(&TokenStream2::from(attr), "unexpected tokens")
57            .into_compile_error()
58            .into();
59    }
60    let mut item = parse_macro_input!(item as ModuleItem);
61    let mut build = true;
62    let mut krnl = quote! { ::krnl };
63    let new_attr = Vec::with_capacity(item.attr.len());
64    for attr in std::mem::replace(&mut item.attr, new_attr) {
65        if attr.path.segments.len() == 1
66            && attr
67                .path
68                .segments
69                .first()
70                .map_or(false, |x| x.ident == "krnl")
71        {
72            let tokens = attr.tokens.clone().into();
73            let args = syn::parse_macro_input!(tokens as ModuleKrnlArgs);
74            for arg in args.args.iter() {
75                if let Some(krnl_crate) = arg.krnl_crate.as_ref() {
76                    krnl = if krnl_crate.leading_colon.is_some()
77                        || krnl_crate
78                            .to_token_stream()
79                            .to_string()
80                            .starts_with("crate")
81                    {
82                        quote! {
83                            #krnl_crate
84                        }
85                    } else {
86                        quote! {
87                            ::#krnl_crate
88                        }
89                    };
90                } else if let Some(ident) = &arg.ident {
91                    if ident == "no_build" {
92                        build = false;
93                    } else {
94                        return Error::new_spanned(
95                            ident,
96                            format!("unknown krnl arg `{ident}`, expected `crate` or `no_build`"),
97                        )
98                        .into_compile_error()
99                        .into();
100                    }
101                }
102            }
103        } else {
104            item.attr.push(attr);
105        }
106    }
107    {
108        let tokens = item.tokens;
109        item.tokens = quote! {
110            #[cfg(not(target_arch = "spirv"))]
111            #[doc(hidden)]
112            macro_rules! __krnl_module_arg {
113                (use crate as $i:ident) => {
114                    use #krnl as $i;
115                };
116            }
117            #tokens
118        };
119    }
120    if build {
121        let source = item.tokens.to_string();
122        let ident = &item.ident;
123        let tokens = item.tokens;
124        item.tokens = quote! {
125            #[doc(hidden)]
126            mod __krnl_module_data {
127                #[allow(non_upper_case_globals)]
128                const __krnl_module_source: &'static str = #source;
129            }
130            #[cfg(not(krnlc))]
131            #[doc(hidden)]
132            macro_rules! __krnl_cache {
133                ($v:literal, $x:literal) => {
134                    #[doc(hidden)]
135                    macro_rules! __krnl_kernel {
136                        ($k:ident) => {
137                            Some(#krnl::macros::__krnl_cache!($v, #ident, $k, $x))
138                        };
139                    }
140                };
141            }
142            #[cfg(not(krnlc))]
143            include!(concat!(env!("CARGO_MANIFEST_DIR"), "/krnl-cache.rs"));
144            #[doc(hidden)]
145            #[cfg(krnlc)]
146            macro_rules! __krnl_kernel {
147                ($k:ident) => {
148                    None
149                };
150            }
151            #tokens
152        };
153    } else {
154        let tokens = item.tokens;
155        item.tokens = quote! {
156            #[doc(hidden)]
157            macro_rules! __krnl_kernel {
158                ($k:ident) => {
159                    None
160                };
161            }
162            #tokens
163        }
164    }
165    item.into_token_stream().into()
166}
167
/// The parenthesized, comma-separated argument list of a `#[krnl(..)]`
/// attribute on a module.
#[derive(Parse, Debug)]
struct ModuleKrnlArgs {
    // Required by the derive to parse the parentheses.
    #[allow(unused)]
    #[paren]
    paren: Paren,
    // The individual arguments; a trailing comma is accepted because
    // `Punctuated::parse_terminated` is used.
    #[inside(paren)]
    #[call(Punctuated::parse_terminated)]
    args: Punctuated<ModuleKrnlArg, Comma>,
}
177
/// A single argument of `#[krnl(..)]`: either `crate = <path>` or a bare
/// identifier.
///
/// If the argument starts with the `crate` keyword, `eq` and `krnl_crate` are
/// parsed and `ident` is `None`; otherwise only `ident` is parsed.
#[derive(Parse, Debug)]
struct ModuleKrnlArg {
    // The leading `crate` keyword, if present.
    #[allow(unused)]
    crate_token: Option<syn::token::Crate>,
    // The `=` following `crate`.
    #[allow(unused)]
    #[parse_if(crate_token.is_some())]
    eq: Option<SynEq>,
    // The path given after `crate =`.
    #[parse_if(crate_token.is_some())]
    krnl_crate: Option<syn::Path>,
    // A bare identifier argument (parsed only when `crate` is absent).
    #[parse_if(crate_token.is_none())]
    ident: Option<Ident>,
}
190
/// An inline module item: `#[attrs] vis mod ident { tokens }`.
///
/// The body is kept as an unparsed token stream.
#[derive(Parse, Debug)]
struct ModuleItem {
    // Outer attributes preceding the module.
    #[call(Attribute::parse_outer)]
    attr: Vec<Attribute>,
    // Visibility of the module.
    vis: Visibility,
    // The `mod` keyword.
    mod_token: Mod,
    // Name of the module.
    ident: Ident,
    // Braces delimiting the module body.
    #[brace]
    brace: Brace,
    // Raw tokens of the module body.
    #[inside(brace)]
    tokens: TokenStream2,
}
203
204impl ToTokens for ModuleItem {
205    fn to_tokens(&self, tokens: &mut TokenStream2) {
206        for attr in self.attr.iter() {
207            attr.to_tokens(tokens);
208        }
209        self.vis.to_tokens(tokens);
210        self.mod_token.to_tokens(tokens);
211        self.ident.to_tokens(tokens);
212        self.brace
213            .surround(tokens, |tokens| self.tokens.to_tokens(tokens));
214    }
215}
216
217#[proc_macro_attribute]
218pub fn kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
219    if !attr.is_empty() {
220        return Error::new_spanned(&TokenStream2::from(attr), "unexpected tokens")
221            .into_compile_error()
222            .into();
223    }
224    match kernel_impl(item.into()) {
225        Ok(tokens) => tokens.into(),
226        Err(e) => e.into_compile_error().into(),
227    }
228}
229
/// A kernel function item:
/// `#[attrs] vis [unsafe] fn ident [<specs>] (args) { block }`.
#[derive(Parse, Debug)]
struct KernelItem {
    // Outer attributes preceding the function.
    #[call(Attribute::parse_outer)]
    attrs: Vec<Attribute>,
    // Visibility of the function.
    #[allow(unused)]
    vis: Visibility,
    // Optional `unsafe` keyword.
    unsafe_token: Option<Unsafe>,
    // The `fn` keyword.
    #[allow(unused)]
    fn_token: Fn,
    // Name of the kernel.
    ident: Ident,
    // Generic parameter list; parsed only when the next token is `<`.
    #[peek(Lt)]
    generics: Option<KernelGenerics>,
    // Parentheses delimiting the argument list.
    #[allow(unused)]
    #[paren]
    paren: Paren,
    // Kernel arguments; a trailing comma is accepted.
    #[inside(paren)]
    #[call(Punctuated::parse_terminated)]
    args: Punctuated<KernelArg, Comma>,
    // The function body.
    block: Block,
}
250
251impl KernelItem {
252    fn meta(&self) -> Result<KernelMeta> {
253        let mut meta = KernelMeta {
254            spec_metas: Vec::new(),
255            unsafe_token: self.unsafe_token,
256            ident: self.ident.clone(),
257            arg_metas: Vec::with_capacity(self.args.len()),
258            block: self.block.clone(),
259            itemwise: false,
260            arrays: FxHashMap::default(),
261        };
262        let mut spec_id = 0;
263        if let Some(generics) = self.generics.as_ref() {
264            meta.spec_metas = generics
265                .specs
266                .iter()
267                .map(|x| {
268                    let meta = KernelSpecMeta {
269                        ident: x.ident.clone(),
270                        ty: x.ty.clone(),
271                        id: spec_id,
272                        thread_dim: None,
273                    };
274                    spec_id += 1;
275                    meta
276                })
277                .collect();
278        }
279        let mut binding = 0;
280        for arg in self.args.iter() {
281            let mut arg_meta = arg.meta()?;
282            if arg_meta.kind.is_global() || arg_meta.kind.is_item() {
283                arg_meta.binding.replace(binding);
284                binding += 1;
285            }
286            meta.itemwise |= arg_meta.kind.is_item();
287            if let Some(len) = arg_meta.len.as_ref() {
288                meta.arrays
289                    .entry(arg_meta.scalar_ty.scalar_type)
290                    .or_default()
291                    .push((arg.ident.clone(), len.clone()));
292            }
293            meta.arg_metas.push(arg_meta);
294        }
295        Ok(meta)
296    }
297}
298
/// The `<const NAME: ty, ..>` generic parameter list of a kernel function.
///
/// `Parse` is implemented by hand rather than derived, so the `<` and `>`
/// delimiters are consumed during parsing and are not stored as fields.
#[derive(Debug)]
struct KernelGenerics {
    // TODO: doesn't support trailing comma
    // NOTE(review): the hand-written `Parse` impl appears to accept a trailing
    // comma before `>`; confirm and drop the TODO if so.
    specs: Punctuated<KernelSpec, Comma>,
}
308
309impl Parse for KernelGenerics {
310    fn parse(input: ParseStream) -> Result<Self> {
311        input.parse::<Lt>()?;
312        let mut specs = Punctuated::new();
313        while input.peek(Const) {
314            specs.push(input.parse()?);
315            if input.peek(Comma) {
316                input.parse::<Comma>()?;
317            } else {
318                break;
319            }
320        }
321        input.parse::<Gt>()?;
322        Ok(Self { specs })
323    }
324}
325
/// A single generic parameter of a kernel: `const IDENT: scalar`.
#[derive(Parse, Debug)]
struct KernelSpec {
    // The `const` keyword.
    #[allow(unused)]
    const_token: Const,
    // Name of the constant.
    ident: Ident,
    // The `:` separating name and type.
    #[allow(unused)]
    colon: Colon,
    // Scalar type of the constant.
    ty: KernelTypeScalar,
}
335
/// Metadata for one `const` generic parameter of a kernel.
#[derive(Debug)]
struct KernelSpecMeta {
    // Name of the constant.
    ident: Ident,
    // Scalar type of the constant.
    ty: KernelTypeScalar,
    // Value used in the `SpecId` decoration emitted by `declare`.
    id: u32,
    // NOTE(review): presumably the thread dimension this constant controls,
    // if any; when set, the generated device fn parameter is additionally
    // marked `#[allow(unused)]` — confirm the meaning of the index.
    thread_dim: Option<usize>,
}
343
344impl KernelSpecMeta {
345    fn declare(&self) -> TokenStream2 {
346        use ScalarType::*;
347        let scalar_type = self.ty.scalar_type;
348        let bits = scalar_type.size() * 8;
349        let signed = matches!(scalar_type, I8 | I16 | I32 | I64) as u32;
350        let float = matches!(scalar_type, F32 | F64);
351        let ty_string = if float {
352            format!("%ty = OpTypeFloat {bits}")
353        } else {
354            format!("%ty = OpTypeInt {bits} {signed}")
355        };
356        let spec_id_string = format!("OpDecorate %spec SpecId {}", self.id);
357        let ident = &self.ident;
358        quote! {
359            #[allow(non_snake_case)]
360            let #ident = unsafe {
361                let mut spec = Default::default();
362                ::core::arch::asm! {
363                    #ty_string,
364                    "%spec = OpSpecConstant %ty 0",
365                    #spec_id_string,
366                    "OpStore {spec} %spec",
367                    spec = in(reg) &mut spec,
368                }
369                spec
370            };
371        }
372    }
373}
374
/// A scalar type as written in the kernel signature.
#[derive(Clone, Debug)]
struct KernelTypeScalar {
    // The identifier as written in the source (e.g. `f32`).
    ident: Ident,
    // The `ScalarType` whose name matches `ident`.
    scalar_type: ScalarType,
}
380
381impl Parse for KernelTypeScalar {
382    fn parse(input: ParseStream<'_>) -> Result<Self> {
383        let ident = input.parse()?;
384        if let Some(scalar_type) = ScalarType::iter().find(|x| ident == x.name()) {
385            Ok(Self { ident, scalar_type })
386        } else {
387            Err(Error::new(ident.span(), "expected scalar"))
388        }
389    }
390}
391
/// One kernel argument: `[#[kind]] ident: type`.
///
/// Exactly one of the `*_ty` fields is parsed, selected by `kind`.
#[derive(Parse, Debug)]
struct KernelArg {
    // Argument kind, derived from the optional `#[..]` attribute.
    kind: KernelArgKind,
    // Name of the argument.
    ident: Ident,
    // The `:` separating name and type.
    #[allow(unused)]
    colon: Colon,
    // Type of a `global` argument.
    #[parse_if(kind.is_global())]
    slice_ty: Option<KernelTypeSlice>,
    // Type of an `item` argument.
    #[parse_if(kind.is_item())]
    item_ty: Option<KernelTypeItem>,
    // Type of a `group` argument.
    #[parse_if(kind.is_group())]
    array_ty: Option<KernelTypeArray>,
    // Type of a push argument (no attribute).
    #[parse_if(kind.is_push())]
    push_ty: Option<KernelTypeScalar>,
}
407
408impl KernelArg {
409    fn meta(&self) -> Result<KernelArgMeta> {
410        let kind = self.kind;
411        let (scalar_ty, mutable, len) = if let Some(slice_ty) = self.slice_ty.as_ref() {
412            let slice_ty_ident = &slice_ty.ty;
413            let mutable = if slice_ty.ty == "Slice" {
414                false
415            } else if slice_ty.ty == "UnsafeSlice" {
416                true
417            } else if slice_ty.ty == "SliceMut" {
418                return Err(Error::new_spanned(slice_ty_ident, "try `UnsafeSlice`"));
419            } else {
420                return Err(Error::new_spanned(
421                    slice_ty_ident,
422                    "expected `Slice` or `UnsafeSlice`",
423                ));
424            };
425            (slice_ty.scalar_ty.clone(), mutable, None)
426        } else if let Some(array_ty) = self.array_ty.as_ref() {
427            let len = array_ty.len.to_token_stream();
428            (array_ty.scalar_ty.clone(), true, Some(len))
429        } else if let Some(item_ty) = self.item_ty.as_ref() {
430            (item_ty.scalar_ty.clone(), item_ty.mut_token.is_some(), None)
431        } else if let Some(push_ty) = self.push_ty.as_ref() {
432            (push_ty.clone(), false, None)
433        } else {
434            unreachable!("KernelArg::meta expected type!")
435        };
436        let meta = KernelArgMeta {
437            kind,
438            ident: self.ident.clone(),
439            scalar_ty,
440            mutable,
441            binding: None,
442            len,
443        };
444        Ok(meta)
445    }
446}
447
/// Metadata for one kernel argument, used to generate device and host code.
#[derive(Debug)]
struct KernelArgMeta {
    // How the argument is passed (global, item, group, or push).
    kind: KernelArgKind,
    // Name of the argument.
    ident: Ident,
    // Scalar element type of the argument.
    scalar_ty: KernelTypeScalar,
    // Whether the argument is mutable.
    mutable: bool,
    // Buffer binding index; set for global and item arguments.
    binding: Option<u32>,
    // Length expression tokens; set for arguments declared with an array type.
    len: Option<TokenStream2>,
}
457
458impl KernelArgMeta {
459    fn compute_def_tokens(&self) -> Option<TokenStream2> {
460        let ident = &self.ident;
461        let ty = &self.scalar_ty.ident;
462        if let Some(binding) = self.binding.as_ref() {
463            let set = LitInt::new("0", Span2::call_site());
464            let binding = LitInt::new(&binding.to_string(), Span2::call_site());
465            let mut_token = if self.mutable {
466                Some(Mut::default())
467            } else {
468                None
469            };
470            Some(quote! {
471                #[spirv(storage_buffer, descriptor_set = #set, binding = #binding)] #ident: &#mut_token [#ty; 1]
472            })
473        } else {
474            None
475        }
476    }
477    fn device_fn_def_tokens(&self) -> TokenStream2 {
478        let ident = &self.ident;
479        let ty = &self.scalar_ty.ident;
480        let mutable = self.mutable;
481        use KernelArgKind::*;
482        match self.kind {
483            Global => {
484                if mutable {
485                    quote! {
486                        #ident: ::krnl_core::buffer::UnsafeSlice<#ty>
487                    }
488                } else {
489                    quote! {
490                        #ident: ::krnl_core::buffer::Slice<#ty>
491                    }
492                }
493            }
494            Item => {
495                if mutable {
496                    quote! {
497                        #ident: &mut #ty
498                    }
499                } else {
500                    quote! {
501                        #ident: #ty
502                    }
503                }
504            }
505            Group => quote! {
506                #ident: ::krnl_core::buffer::UnsafeSlice<#ty>
507            },
508            Push => quote! {
509                #ident: #ty
510            },
511        }
512    }
513    fn device_slices(&self) -> TokenStream2 {
514        let ident = &self.ident;
515        let mutable = self.mutable;
516        use KernelArgKind::*;
517        match self.kind {
518            Global | Item => {
519                let offset = format_ident!("__krnl_offset_{ident}");
520                let len = format_ident!("__krnl_len_{ident}");
521                let slice_fn = if mutable {
522                    quote! {
523                        ::krnl_core::buffer::UnsafeSlice::from_unsafe_raw_parts
524                    }
525                } else {
526                    quote! {
527                        ::krnl_core::buffer::Slice::from_raw_parts
528                    }
529                };
530                quote! {
531                    let #ident = unsafe {
532                        #slice_fn(#ident, __krnl_push_consts.#offset as usize, __krnl_push_consts.#len as usize)
533                    };
534                }
535            }
536            Group => {
537                let offset = format_ident!("__krnl_offset_{ident}");
538                let len = format_ident!("__krnl_len_{ident}");
539                let scalar_name = self.scalar_ty.scalar_type.name();
540                let array = format_ident!("__krnl_group_array_{scalar_name}");
541                quote! {
542                    let #ident = {
543                        unsafe {
544                            ::krnl_core::buffer::UnsafeSlice::from_unsafe_raw_parts(#array, #offset, #len)
545                        }
546                    };
547                }
548            }
549            Push => TokenStream2::new(),
550        }
551    }
552    fn device_fn_call_tokens(&self) -> TokenStream2 {
553        let ident = &self.ident;
554        let mutable = self.mutable;
555        use KernelArgKind::*;
556        match self.kind {
557            Global | Group => quote! {
558                #ident
559            },
560            Item => {
561                if mutable {
562                    quote! {
563                        unsafe {
564                            use ::krnl_core::buffer::UnsafeIndex;
565                            #ident.unsafe_index_mut(__krnl_item_id as usize)
566                        }
567                    }
568                } else {
569                    quote! {
570                        #ident[__krnl_item_id as usize]
571                    }
572                }
573            }
574            Push => quote! {
575                __krnl_push_consts.#ident
576            },
577        }
578    }
579}
580
/// The optional `#[ident]` attribute preceding a kernel argument.
#[derive(Parse, Debug)]
struct KernelArgAttr {
    // The leading `#`, if an attribute is present.
    #[allow(unused)]
    pound: Option<Pound>,
    // The bracketed identifier; parsed only when `#` was found.
    #[parse_if(pound.is_some())]
    ident: Option<InsideBracket<Ident>>,
}
588
589impl KernelArgAttr {
590    fn kind(&self) -> Result<KernelArgKind> {
591        use KernelArgKind::*;
592        let ident = if let Some(ident) = self.ident.as_ref() {
593            &ident.value
594        } else {
595            return Ok(Push);
596        };
597        let kind = if ident == "global" {
598            Global
599        } else if ident == "item" {
600            Item
601        } else if ident == "group" {
602            Group
603        } else {
604            return Err(Error::new_spanned(
605                ident,
606                "expected `global`, `item`, or `group`",
607            ));
608        };
609        Ok(kind)
610    }
611}
612
/// How a kernel argument is passed, as selected by its attribute.
#[derive(Clone, Copy, derive_more::IsVariant, PartialEq, Eq, Hash, Debug)]
enum KernelArgKind {
    /// `#[global]`: a slice argument.
    Global,
    /// `#[item]`: a per-item scalar argument.
    Item,
    /// `#[group]`: an array argument.
    Group,
    /// No attribute: a scalar passed via push constants.
    Push,
}
620
621impl Parse for KernelArgKind {
622    fn parse(input: ParseStream) -> Result<Self> {
623        KernelArgAttr::parse(input)?.kind()
624    }
625}
626
/// Type of an item argument: `scalar`, `&scalar`, or `&mut scalar`.
#[derive(Parse, Debug)]
struct KernelTypeItem {
    // Optional leading `&`.
    #[allow(unused)]
    and: Option<And>,
    // Optional `mut`; only parsed after `&`.
    #[parse_if(and.is_some())]
    mut_token: Option<Mut>,
    // The scalar element type.
    scalar_ty: KernelTypeScalar,
}
635
/// Type of a global argument: `Ident<scalar>`.
///
/// Any identifier is accepted here; its name is validated later in
/// `KernelArg::meta`.
#[derive(Parse, Debug)]
struct KernelTypeSlice {
    // The slice type name.
    ty: Ident,
    // Opening `<`.
    #[allow(unused)]
    lt: Lt,
    // The scalar element type.
    scalar_ty: KernelTypeScalar,
    // Closing `>`.
    #[allow(unused)]
    gt: Gt,
}
645
/// Type of a group argument: `Ident<scalar, len>`.
#[derive(Parse, Debug)]
struct KernelTypeArray {
    // The array type name; parsed but not otherwise inspected here.
    #[allow(unused)]
    ty: Ident,
    // Opening `<`.
    #[allow(unused)]
    lt: Lt,
    // The scalar element type.
    scalar_ty: KernelTypeScalar,
    // The `,` separating element type and length.
    #[allow(unused)]
    comma: Comma,
    // The array length.
    len: KernelArrayLength,
    // Closing `>`.
    #[allow(unused)]
    gt: Gt,
}
659
/// The length of a group array: a braced block, an identifier, or an integer
/// literal. The `Parse` impl sets exactly one of the fields.
#[derive(Debug)]
struct KernelArrayLength {
    // Length given as `{ .. }`.
    block: Option<Block>,
    // Length given as an identifier.
    ident: Option<Ident>,
    // Length given as an integer literal.
    lit: Option<LitInt>,
}
666
667impl Parse for KernelArrayLength {
668    fn parse(input: &syn::parse::ParseBuffer) -> Result<Self> {
669        if input.peek(Brace) {
670            Ok(Self {
671                block: Some(input.parse()?),
672                ident: None,
673                lit: None,
674            })
675        } else if input.peek(Ident) {
676            Ok(Self {
677                block: None,
678                ident: Some(input.parse()?),
679                lit: None,
680            })
681        } else {
682            Ok(Self {
683                block: None,
684                ident: None,
685                lit: Some(input.parse()?),
686            })
687        }
688    }
689}
690
691impl ToTokens for KernelArrayLength {
692    fn to_tokens(&self, tokens: &mut TokenStream2) {
693        if let Some(block) = self.block.as_ref() {
694            for stmt in block.stmts.iter() {
695                stmt.to_tokens(tokens);
696            }
697        } else if let Some(ident) = self.ident.as_ref() {
698            ident.to_tokens(tokens);
699        } else if let Some(lit) = self.lit.as_ref() {
700            lit.to_tokens(tokens);
701        }
702    }
703}
704
/// Everything known about a kernel after parsing, used to generate code.
#[derive(Debug)]
struct KernelMeta {
    /// Metadata for each `const` generic parameter, in declaration order.
    spec_metas: Vec<KernelSpecMeta>,
    /// Name of the kernel function.
    ident: Ident,
    /// The `unsafe` keyword, if the kernel was declared unsafe.
    unsafe_token: Option<Unsafe>,
    /// Metadata for each argument, in declaration order.
    arg_metas: Vec<KernelArgMeta>,
    /// True when at least one argument is an item argument.
    itemwise: bool,
    /// The kernel body.
    block: Block,
    /// Arguments that have a length, grouped by scalar type; each entry is the
    /// argument name and its length expression.
    arrays: FxHashMap<ScalarType, Vec<(Ident, TokenStream2)>>,
}
715
impl KernelMeta {
    /// Builds the `KernelDesc` for this kernel: its name, safety, spec
    /// constants, slice arguments (global and item), and push constants.
    ///
    /// Group arguments do not appear in the descriptor. Push constants are
    /// ordered by descending scalar size; the sort is stable, so arguments of
    /// equal size keep their declaration order.
    fn desc(&self) -> Result<KernelDesc> {
        let mut kernel_desc = KernelDesc {
            name: self.ident.to_string(),
            safe: self.unsafe_token.is_none(),
            ..KernelDesc::default()
        };
        for spec in self.spec_metas.iter() {
            kernel_desc.spec_descs.push(SpecDesc {
                name: spec.ident.to_string(),
                scalar_type: spec.ty.scalar_type,
            })
        }
        for arg_meta in self.arg_metas.iter() {
            let kind = arg_meta.kind;
            let scalar_type = arg_meta.scalar_ty.scalar_type;
            use KernelArgKind::*;
            match kind {
                Global | Item => {
                    kernel_desc.slice_descs.push(SliceDesc {
                        name: arg_meta.ident.to_string(),
                        scalar_type,
                        mutable: arg_meta.mutable,
                        item: kind.is_item(),
                    });
                }
                Group => (),
                Push => {
                    kernel_desc.push_descs.push(PushDesc {
                        name: arg_meta.ident.to_string(),
                        scalar_type,
                    });
                }
            }
        }
        // Negating the size sorts the largest scalars first.
        kernel_desc
            .push_descs
            .sort_by_key(|x| -(x.scalar_type.size() as i32));
        Ok(kernel_desc)
    }
    /// Parameters of the compute entry point: one storage buffer per bound
    /// argument, followed by one workgroup array per scalar type in `arrays`.
    ///
    /// Workgroup arrays are named `__krnl_group_array_{scalar}_{id}` with ids
    /// counting up from 1 in the iteration order of `self.arrays`;
    /// `device_arrays` numbers them the same way.
    fn compute_def_args(&self) -> Punctuated<TokenStream2, Comma> {
        let mut id = 1;
        let arrays = self.arrays.keys().map(|scalar_type| {
            let scalar_name = scalar_type.name();
            let ident = format_ident!("__krnl_group_array_{scalar_name}_{id}");
            let ty = format_ident!("{scalar_name}");
            id += 1;
            quote! {
                #[spirv(workgroup)] #ident: &mut [#ty; 1]
            }
        });
        self.arg_metas
            .iter()
            .filter_map(|arg| arg.compute_def_tokens())
            .chain(arrays)
            .collect()
    }
    // A disabled `threads` helper (a commented-out spec-constant declaration
    // for `__krnl_threads`) used to sit here.
    /// Concatenates the declaration statements of all spec constants.
    fn declare_specs(&self) -> TokenStream2 {
        self.spec_metas
            .iter()
            .flat_map(|spec| spec.declare())
            .collect()
    }
    /// Spec constants as function parameters (`IDENT: ty`), each marked
    /// `#[allow(non_snake_case)]`.
    fn spec_def_args(&self) -> Punctuated<TokenStream2, Comma> {
        self.spec_metas
            .iter()
            .map(|spec| {
                let ident = &spec.ident;
                let ty = &spec.ty.ident;
                quote! {
                    #[allow(non_snake_case)]
                    #ident: #ty
                }
            })
            .collect()
    }
    /// Names of the spec constants, in declaration order.
    fn spec_args(&self) -> Vec<Ident> {
        self.spec_metas
            .iter()
            .map(|spec| spec.ident.clone())
            .collect()
    }
    /// Device-side setup for group arrays.
    ///
    /// For each scalar type in `arrays` this binds the numbered workgroup
    /// array to `__krnl_group_array_{scalar}`, computes an offset and length
    /// for every array of that type (each length is evaluated through a local
    /// `const fn` taking the spec constants), and emits calls that report the
    /// total length and zero the buffer. If the kernel has any group argument,
    /// a workgroup memory barrier is appended.
    ///
    /// The `__krnl_offset_{name}` / `__krnl_len_{name}` locals defined here
    /// are the ones `KernelArgMeta::device_slices` refers to for group
    /// arguments.
    fn device_arrays(&self) -> TokenStream2 {
        let spec_def_args: Punctuated<_, Comma> = self
            .spec_def_args()
            .into_iter()
            .map(|arg| {
                quote! {
                    #[allow(unused)] #arg
                }
            })
            .collect();
        let spec_args: Punctuated<_, Comma> = self.spec_args().into_iter().collect();
        let group_barrier = if self.arg_metas.iter().any(|arg| arg.kind.is_group()) {
            quote! {
                unsafe {
                     ::krnl_core::spirv_std::arch::workgroup_memory_barrier();
                }
            }
        } else {
            TokenStream2::new()
        };
        // Ids must match the numbering used by `compute_def_args`: both start
        // at 1 and follow the iteration order of `self.arrays`.
        let mut id = 1;
        self.arrays
            .iter()
            .flat_map(|(scalar_type, arrays)| {
                let scalar_name = scalar_type.name();
                let ident = format_ident!("__krnl_group_array_{scalar_name}");
                let ident_with_id = format_ident!("{ident}_{id}");
                let id_lit = LitInt::new(&id.to_string(), Span2::call_site());
                id += 1;
                let len = format_ident!("{ident}_len");
                let offset = format_ident!("{ident}_offset");
                // Lay the arrays of this scalar type out back to back: each
                // one starts at the running offset, which then advances by the
                // array's length.
                let array_offsets_lens: TokenStream2 = arrays
                    .iter()
                    .map(|(array, len_expr)| {
                        let array_offset = format_ident!("__krnl_offset_{array}");
                        let array_len = format_ident!("__krnl_len_{array}");
                        quote! {
                            let #array_offset = #offset;
                            let #array_len = {
                                const fn #array_len(#spec_def_args) -> usize {
                                    #len_expr
                                }
                                #array_len(#spec_args)
                            };
                            #offset += #array_len;
                        }
                    })
                    .collect();
                // NOTE(review): the emitted code refers to `__krnl_kernel_data`
                // and `kernel`, which are presumably defined by the code
                // surrounding this expansion — confirm against the caller.
                quote! {
                    let #ident = #ident_with_id;
                    let mut #offset = 0usize;
                    #array_offsets_lens
                    let #len = #offset;
                    unsafe {
                        ::krnl_core::kernel::__private::group_buffer_len(__krnl_kernel_data, #id_lit, #len);
                        ::krnl_core::kernel::__private::zero_group_buffer(&kernel, #ident, #len);
                    }
                }
            })
            .chain(group_barrier)
            .collect()
    }
    /// Host-side compile checks for array lengths: for every argument with a
    /// length, emits an anonymous `const` item containing a `const fn` of the
    /// spec constants whose body is the length expression.
    fn host_array_length_checks(&self) -> TokenStream2 {
        let mut spec_def_args = self.spec_def_args();
        for arg in spec_def_args.iter_mut() {
            *arg = quote! {
                #[allow(unused_variables, non_snake_case)]
                #arg
            };
        }
        self.arg_metas
            .iter()
            .flat_map(|arg| {
                if let Some(len) = arg.len.as_ref() {
                    quote! {
                        const _: () = {
                            #[allow(non_snake_case, clippy::too_many_arguments)]
                            const fn __krnl_array_len(#spec_def_args) -> usize {
                                #len
                            }
                            let _ = __krnl_array_len;
                        };
                    }
                } else {
                    TokenStream2::new()
                }
            })
            .collect()
    }
    /// Concatenates the slice-rebinding statements of all arguments.
    fn device_slices(&self) -> TokenStream2 {
        self.arg_metas
            .iter()
            .map(|arg| arg.device_slices())
            .collect()
    }
    /// Expression for the number of items: the maximum `len()` over all item
    /// arguments (`a.len().max(b.len())...`), or `0` when there are none.
    fn device_items(&self) -> TokenStream2 {
        let mut items = self
            .arg_metas
            .iter()
            .filter(|arg| arg.kind.is_item())
            .map(|arg| &arg.ident);
        if let Some(first) = items.next() {
            quote! {
                #first.len()
            }
            .into_iter()
            .chain(items.flat_map(|item| {
                quote! {
                    .max(#item.len())
                }
            }))
            .collect()
        } else {
            quote! {
                0
            }
        }
    }
    /// Parameters of the generated device function: the spec constants
    /// followed by the kernel arguments. Spec constants with a `thread_dim`
    /// are additionally marked `#[allow(unused)]`.
    fn device_fn_def_args(&self) -> Punctuated<TokenStream2, Comma> {
        self.spec_metas
            .iter()
            .map(|x| {
                let ident = &x.ident;
                let ty = &x.ty.ident;
                let allow_unused = x.thread_dim.map(|_| {
                    quote! {
                        #[allow(unused)]
                    }
                });
                quote! {
                    #allow_unused
                    #[allow(non_snake_case)]
                    #ident: #ty
                }
            })
            .chain(self.arg_metas.iter().map(|arg| arg.device_fn_def_tokens()))
            .collect()
    }
    /// Arguments for calling the generated device function, in the same order
    /// as `device_fn_def_args`.
    fn device_fn_call_args(&self) -> Punctuated<TokenStream2, Comma> {
        self.spec_metas
            .iter()
            .map(|spec| spec.ident.to_token_stream())
            .chain(self.arg_metas.iter().map(|arg| arg.device_fn_call_tokens()))
            .collect()
    }
    /// Host-side dispatch parameters, each followed by a comma: bound
    /// arguments become `Slice<ty>` or `SliceMut<ty>`, push arguments keep
    /// their scalar type. Group arguments are omitted.
    fn dispatch_args(&self) -> TokenStream2 {
        let mut tokens = TokenStream2::new();
        for arg in self.arg_metas.iter() {
            let ident = &arg.ident;
            let ty = &arg.scalar_ty.ident;
            if arg.binding.is_some() {
                let slice_ty = if arg.mutable {
                    format_ident!("SliceMut")
                } else {
                    format_ident!("Slice")
                };
                tokens.extend(quote! {
                    #ident: #slice_ty<#ty>,
                });
            } else if arg.kind.is_push() {
                tokens.extend(quote! {
                    #ident: #ty,
                });
            }
        }
        tokens
    }
    /// `name.into(),` for every bound argument, in declaration order.
    fn dispatch_slice_args(&self) -> TokenStream2 {
        let mut tokens = TokenStream2::new();
        for arg in self.arg_metas.iter() {
            let ident = &arg.ident;
            if arg.binding.is_some() {
                tokens.extend(quote! {
                    #ident.into(),
                });
            }
        }
        tokens
    }
}
997
/// Scalar types that can appear in spec, slice and push descriptors.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
enum ScalarType {
    U8,
    I8,
    U16,
    I16,
    F16,
    BF16,
    U32,
    I32,
    F32,
    U64,
    I64,
    F64,
}

impl ScalarType {
    /// Iterates over every variant in declaration order.
    fn iter() -> impl Iterator<Item = Self> {
        [
            Self::U8,
            Self::I8,
            Self::U16,
            Self::I16,
            Self::F16,
            Self::BF16,
            Self::U32,
            Self::I32,
            Self::F32,
            Self::U64,
            Self::I64,
            Self::F64,
        ]
        .into_iter()
    }
    /// Lowercase spelling (`"u8"`, `"bf16"`, ...), as used for type identifiers in
    /// generated code.
    fn name(&self) -> &'static str {
        match *self {
            Self::U8 => "u8",
            Self::I8 => "i8",
            Self::U16 => "u16",
            Self::I16 => "i16",
            Self::F16 => "f16",
            Self::BF16 => "bf16",
            Self::U32 => "u32",
            Self::I32 => "i32",
            Self::F32 => "f32",
            Self::U64 => "u64",
            Self::I64 => "i64",
            Self::F64 => "f64",
        }
    }
    /// Uppercase spelling, identical to the variant name (`"U8"`, `"BF16"`, ...).
    fn as_str(&self) -> &'static str {
        match *self {
            Self::U8 => "U8",
            Self::I8 => "I8",
            Self::U16 => "U16",
            Self::I16 => "I16",
            Self::F16 => "F16",
            Self::BF16 => "BF16",
            Self::U32 => "U32",
            Self::I32 => "I32",
            Self::F32 => "F32",
            Self::U64 => "U64",
            Self::I64 => "I64",
            Self::F64 => "F64",
        }
    }
    /// Size of the type in bytes.
    fn size(&self) -> usize {
        match *self {
            Self::U8 | Self::I8 => 1,
            Self::U16 | Self::I16 | Self::F16 | Self::BF16 => 2,
            Self::U32 | Self::I32 | Self::F32 => 4,
            Self::U64 | Self::I64 | Self::F64 => 8,
        }
    }
}
1063
1064impl ToTokens for ScalarType {
1065    fn to_tokens(&self, tokens: &mut TokenStream2) {
1066        let ident = format_ident!("{self:?}");
1067        tokens.extend(quote! {
1068            ScalarType::#ident
1069        });
1070    }
1071}
1072
1073impl FromStr for ScalarType {
1074    type Err = ();
1075    fn from_str(input: &str) -> Result<Self, ()> {
1076        Self::iter()
1077            .find(|x| x.as_str() == input || x.name() == input)
1078            .ok_or(())
1079    }
1080}
1081
1082impl Serialize for ScalarType {
1083    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1084    where
1085        S: Serializer,
1086    {
1087        serializer.serialize_str(self.as_str())
1088    }
1089}
1090
1091impl<'de> Deserialize<'de> for ScalarType {
1092    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1093    where
1094        D: Deserializer<'de>,
1095    {
1096        use serde::de::Visitor;
1097
1098        struct ScalarTypeVisitor;
1099
1100        impl Visitor<'_> for ScalarTypeVisitor {
1101            type Value = ScalarType;
1102            fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1103                write!(formatter, "a scalar type")
1104            }
1105            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
1106            where
1107                E: serde::de::Error,
1108            {
1109                if let Ok(scalar_type) = ScalarType::from_str(v) {
1110                    Ok(scalar_type)
1111                } else {
1112                    Err(E::custom(format!("unknown ScalarType {v}")))
1113                }
1114            }
1115        }
1116        deserializer.deserialize_str(ScalarTypeVisitor)
1117    }
1118}
1119
/// Descriptor of a single kernel.
///
/// `encode()` serializes this struct with bincode. `spirv` and `features` are marked
/// `skip_serializing`, so they are left out when serializing but are still read when
/// deserializing.
///
/// NOTE(review): bincode encodes fields positionally, so the field order here is part of
/// the data format — confirm against the producer before reordering.
#[derive(Default, Serialize, Deserialize, Debug)]
struct KernelDesc {
    /// Kernel name; `__krnl_cache_impl` splits it on `::` to match module and kernel idents.
    name: String,
    /// SPIR-V words; `__krnl_cache_impl` gzip-compresses them into a byte string literal.
    #[serde(skip_serializing)]
    spirv: Vec<u32>,
    /// Feature flags of the kernel (presumably the features it requires — confirm).
    #[serde(skip_serializing)]
    features: Features,
    /// Safety flag (presumably `true` when the kernel is declared without `unsafe` — confirm).
    safe: bool,
    /// Descriptors of the spec constants.
    spec_descs: Vec<SpecDesc>,
    /// Descriptors of the slice arguments.
    slice_descs: Vec<SliceDesc>,
    /// Descriptors of the push constant arguments.
    push_descs: Vec<PushDesc>,
}
1132
1133impl KernelDesc {
1134    fn encode(&self) -> Result<String> {
1135        let bytes = bincode2::serialize(self).map_err(|e| Error::new(Span2::call_site(), e))?;
1136        Ok(format!("__krnl_kernel_data_{}", hex::encode(bytes)))
1137    }
1138    fn push_const_fields(&self) -> Punctuated<TokenStream2, Comma> {
1139        let mut fields = Punctuated::new();
1140        let mut size = 0;
1141        for push_desc in self.push_descs.iter() {
1142            let ident = format_ident!("{}", push_desc.name);
1143            let ty = format_ident!("{}", push_desc.scalar_type.name());
1144            fields.push(quote! {
1145               #ident: #ty
1146            });
1147            size += push_desc.scalar_type.size();
1148        }
1149        for i in 0..4 {
1150            if size % 4 == 0 {
1151                break;
1152            }
1153            let ident = format_ident!("__krnl_pad{i}");
1154            fields.push(quote! {
1155               #ident: u8
1156            });
1157            size += 1;
1158        }
1159        for slice_desc in self.slice_descs.iter() {
1160            let offset_ident = format_ident!("__krnl_offset_{}", slice_desc.name);
1161            let len_ident = format_ident!("__krnl_len_{}", slice_desc.name);
1162            fields.push(quote! {
1163                #offset_ident: u32
1164            });
1165            fields.push(quote! {
1166                #len_ident: u32
1167            });
1168        }
1169        fields
1170    }
1171    fn dispatch_push_args(&self) -> Vec<Ident> {
1172        self.push_descs
1173            .iter()
1174            .map(|push| format_ident!("{}", push.name))
1175            .collect()
1176    }
1177}
1178
/// Set of feature flags stored as bits of a `u32`.
///
/// `#[serde(transparent)]` makes it deserialize directly from the inner `u32`.
#[derive(Default, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(transparent)]
struct Features {
    /// One bit per flag; see the associated constants for the bit assignments.
    bits: u32,
}
1184
1185impl Features {
1186    pub const INT8: Self = Self::new(1);
1187    pub const INT16: Self = Self::new(1 << 1);
1188    pub const INT64: Self = Self::new(1 << 2);
1189    pub const FLOAT16: Self = Self::new(1 << 3);
1190    pub const FLOAT64: Self = Self::new(1 << 4);
1191    pub const BUFFER8: Self = Self::new(1 << 8);
1192    pub const BUFFER16: Self = Self::new(1 << 9);
1193    pub const PUSH_CONSTANT8: Self = Self::new(1 << 10);
1194    pub const PUSH_CONSTANT16: Self = Self::new(1 << 11);
1195    pub const SUBGROUP_BASIC: Self = Self::new(1 << 16);
1196    pub const SUBGROUP_VOTE: Self = Self::new(1 << 17);
1197    pub const SUBGROUP_ARITHMETIC: Self = Self::new(1 << 18);
1198    pub const SUBGROUP_BALLOT: Self = Self::new(1 << 19);
1199    pub const SUBGROUP_SHUFFLE: Self = Self::new(1 << 20);
1200    pub const SUBGROUP_SHUFFLE_RELATIVE: Self = Self::new(1 << 21);
1201    pub const SUBGROUP_CLUSTERED: Self = Self::new(1 << 22);
1202    pub const SUBGROUP_QUAD: Self = Self::new(1 << 23);
1203
1204    #[inline]
1205    const fn new(bits: u32) -> Self {
1206        Self { bits }
1207    }
1208    /*
1209    #[inline]
1210    pub const fn empty() -> Self {
1211        Self { bits: 0 }
1212    }
1213    #[inline]
1214    pub const fn all() -> Self {
1215        Self::empty()
1216            .union(Self::INT8)
1217            .union(Self::INT16)
1218            .union(Self::FLOAT16)
1219            .union(Self::INT64)
1220            .union(Self::FLOAT64)
1221            .union(Self::BUFFER8)
1222            .union(Self::BUFFER16)
1223            .union(Self::SUBGROUP_BASIC)
1224            .union(Self::SUBGROUP_VOTE)
1225            .union(Self::SUBGROUP_ARITHMETIC)
1226            .union(Self::SUBGROUP_BALLOT)
1227            .union(Self::SUBGROUP_SHUFFLE)
1228            .union(Self::SUBGROUP_SHUFFLE_RELATIVE)
1229            .union(Self::SUBGROUP_CLUSTERED)
1230            .union(Self::SUBGROUP_QUAD)
1231    }
1232    */
1233    #[inline]
1234    pub const fn contains(self, other: Self) -> bool {
1235        (self.bits | other.bits) == self.bits
1236    }
1237    /*
1238    #[inline]
1239    pub const fn union(self, other: Self) -> Self {
1240        Self::new(self.bits | other.bits)
1241    }
1242    */
1243    fn name_iter(self) -> impl Iterator<Item = &'static str> {
1244        macro_rules! features {
1245            ($($f:ident),*) => {
1246                [
1247                    $(
1248                        (stringify!($f), Self::$f)
1249                    ),*
1250                ]
1251            };
1252        }
1253
1254        features!(
1255            INT8,
1256            INT16,
1257            INT64,
1258            FLOAT16,
1259            FLOAT64,
1260            BUFFER8,
1261            BUFFER16,
1262            PUSH_CONSTANT8,
1263            PUSH_CONSTANT16,
1264            SUBGROUP_BASIC,
1265            SUBGROUP_VOTE,
1266            SUBGROUP_ARITHMETIC,
1267            SUBGROUP_BALLOT,
1268            SUBGROUP_SHUFFLE,
1269            SUBGROUP_SHUFFLE_RELATIVE,
1270            SUBGROUP_CLUSTERED,
1271            SUBGROUP_QUAD
1272        )
1273        .into_iter()
1274        .filter_map(move |(name, features)| {
1275            if self.contains(features) {
1276                Some(name)
1277            } else {
1278                None
1279            }
1280        })
1281    }
1282}
1283
1284impl Debug for Features {
1285    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
1286        struct FeaturesStr<'a>(&'a str);
1287
1288        impl Debug for FeaturesStr<'_> {
1289            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1290                f.debug_struct(self.0).finish()
1291            }
1292        }
1293
1294        let alternate = f.alternate();
1295        let mut b = f.debug_tuple("Features");
1296        if alternate {
1297            for name in self.name_iter() {
1298                b.field(&FeaturesStr(name));
1299            }
1300        } else {
1301            b.field(&FeaturesStr(&itertools::join(self.name_iter(), "|")));
1302        }
1303        b.finish()
1304    }
1305}
1306
1307impl ToTokens for Features {
1308    fn to_tokens(&self, tokens: &mut TokenStream2) {
1309        let features = self
1310            .name_iter()
1311            .map(|name| Ident::new(name, Span2::call_site()));
1312        tokens.extend(quote! {
1313            Features::empty()
1314                #(.union(Features::#features))*
1315        });
1316    }
1317}
1318
/// Descriptor of one spec constant: its name and scalar type.
#[derive(Serialize, Deserialize, Debug)]
struct SpecDesc {
    /// Name of the spec constant.
    name: String,
    /// Scalar type of the spec constant.
    scalar_type: ScalarType,
}
1324
1325impl ToTokens for SpecDesc {
1326    fn to_tokens(&self, tokens: &mut TokenStream2) {
1327        let Self { name, scalar_type } = self;
1328        tokens.extend(quote! {
1329            SpecDesc {
1330                name: #name,
1331                scalar_type: #scalar_type,
1332            }
1333        });
1334    }
1335}
1336
/// Descriptor of one slice argument.
#[derive(Serialize, Deserialize, Debug)]
struct SliceDesc {
    /// Name of the slice; also used to name the `__krnl_offset_*` / `__krnl_len_*`
    /// push constant fields.
    name: String,
    /// Scalar type of the slice elements.
    scalar_type: ScalarType,
    /// Mutability flag (presumably `true` for mutable slices — confirm).
    mutable: bool,
    /// Item flag (presumably `true` for per-item arguments of item kernels — confirm).
    item: bool,
}
1344
1345impl ToTokens for SliceDesc {
1346    fn to_tokens(&self, tokens: &mut TokenStream2) {
1347        let Self {
1348            name,
1349            scalar_type,
1350            mutable,
1351            item,
1352        } = self;
1353        tokens.extend(quote! {
1354            SliceDesc {
1355                name: #name,
1356                scalar_type: #scalar_type,
1357                mutable: #mutable,
1358                item: #item,
1359            }
1360        })
1361    }
1362}
1363
/// Descriptor of one push constant argument: its name and scalar type.
#[derive(Serialize, Deserialize, Debug)]
struct PushDesc {
    /// Name of the push constant; used as a field name in the generated push constant struct.
    name: String,
    /// Scalar type of the push constant; its size feeds the padding computation.
    scalar_type: ScalarType,
}
1369
1370impl ToTokens for PushDesc {
1371    fn to_tokens(&self, tokens: &mut TokenStream2) {
1372        let Self { name, scalar_type } = self;
1373        tokens.extend(quote! {
1374            PushDesc {
1375                name: #name,
1376                scalar_type: #scalar_type,
1377            }
1378        })
1379    }
1380}
1381
/// Expands a kernel item into its device and host halves.
///
/// The output consists of, in order:
/// 1. a host module (`#[cfg(not(target_arch = "spirv"))] pub mod <ident>`) with
///    `KernelBuilder`, `builder()`, `Kernel` and `dispatch`,
/// 2. the device entry point (`#[cfg(target_arch = "spirv")] pub fn <ident>`), which wraps
///    the user's function and calls it,
/// 3. a `compile_error!` that fires when targeting spirv without the `krnlc` cfg.
///
/// # Errors
/// Fails when the item does not parse as a `KernelItem`, when `meta()`, `desc()` or
/// `encode()` fail, or when the tokens handed to `prettyplease` do not parse.
///
/// # Panics
/// Panics if the number of spec descriptors does not fit in a `u32`.
fn kernel_impl(item_tokens: TokenStream2) -> Result<TokenStream2> {
    let item: KernelItem = syn::parse2(item_tokens.clone())?;
    let kernel_meta = item.meta()?;
    let kernel_desc = kernel_meta.desc()?;
    let item_attrs = &item.attrs;
    let unsafe_token = kernel_meta.unsafe_token;
    let ident = &kernel_meta.ident;
    // Device half: the compute entry point compiled only for the spirv target.
    let device_tokens = {
        // The encoded descriptor is used as the *name* of the storage buffer parameter
        // at descriptor set 1, binding 0 (presumably so the descriptor can be recovered
        // from the compiled module — confirm against krnlc).
        let kernel_data = format_ident!("{}", kernel_desc.encode()?);
        let block = &kernel_meta.block;
        let compute_def_args = kernel_meta.compute_def_args();
        let declare_specs = kernel_meta.declare_specs();
        // The thread count is passed as a spec constant whose id is the number of
        // spec descriptors.
        let threads_spec_id =
            Literal::u32_unsuffixed(kernel_desc.spec_descs.len().try_into().unwrap());
        let items = kernel_meta.device_items();
        let device_arrays = kernel_meta.device_arrays();
        let device_slices = kernel_meta.device_slices();
        let device_fn_def_args = kernel_meta.device_fn_def_args();
        let device_fn_call_args = kernel_meta.device_fn_call_args();
        let push_consts_ident = format_ident!("__krnl_{ident}PushConsts");
        // The push constant struct and its entry point parameter are only generated
        // when there is at least one push or slice descriptor.
        let (push_struct_tokens, push_consts_arg) =
            if !kernel_desc.push_descs.is_empty() || !kernel_desc.slice_descs.is_empty() {
                let push_const_fields = kernel_desc.push_const_fields();
                let push_struct_tokens = quote! {
                    #[cfg(target_arch = "spirv")]
                    #[automatically_derived]
                    #[repr(C)]
                    pub struct #push_consts_ident {
                        #push_const_fields
                    }
                };
                let push_consts_arg = quote! {
                    #[spirv(push_constant)]
                    __krnl_push_consts: &#push_consts_ident,
                };
                (push_struct_tokens, push_consts_arg)
            } else {
                (TokenStream2::new(), TokenStream2::new())
            };
        let mut device_fn_call = quote! {
            #unsafe_token {
                #ident (
                    kernel,
                    #device_fn_call_args
                );
            }
        };
        // Item kernels wrap the call in a loop: each thread starts at its global id and
        // advances by the global thread count until the item count is reached. The inner
        // `kernel` binding shadows the outer one with a per-item kernel value.
        if kernel_meta.itemwise {
            device_fn_call = quote! {
                let __krnl_items = #items;
                let mut __krnl_item_id = kernel.global_id();
                while __krnl_item_id < __krnl_items {
                    {
                        let kernel = unsafe {
                            ::krnl_core::kernel::__private::ItemKernelArgs {
                                item_id: __krnl_item_id as u32,
                                items: __krnl_items as u32,
                            }.into_item_kernel()
                        };
                        #device_fn_call
                    }
                    __krnl_item_id += kernel.global_threads();
                }
            };
        }
        let kernel_type = if kernel_meta.itemwise {
            quote! { ItemKernel }
        } else {
            quote! {
                Kernel
            }
        };
        quote! {
            #push_struct_tokens
            #[cfg(target_arch = "spirv")]
            #[::krnl_core::spirv_std::spirv(compute(threads(1)))]
            #[allow(unused)]
            pub fn #ident(
                #push_consts_arg
                #[spirv(global_invocation_id)]
                __krnl_global_id: ::krnl_core::spirv_std::glam::UVec3,
                #[spirv(num_workgroups)]
                __krnl_groups: ::krnl_core::spirv_std::glam::UVec3,
                #[spirv(workgroup_id)]
                __krnl_group_id: ::krnl_core::spirv_std::glam::UVec3,
                #[spirv(num_subgroups)]
                __krnl_subgroups: u32,
                #[spirv(subgroup_id)]
                __krnl_subgroup_id: u32,
                #[spirv(subgroup_local_invocation_id)]
                __krnl_subgroup_thread_id: u32,
                #[spirv(spec_constant(id = #threads_spec_id, default = 1))] __krnl_threads: u32,
                #[spirv(local_invocation_id)]
                __krnl_thread_id: ::krnl_core::spirv_std::glam::UVec3,
                #[spirv(storage_buffer, descriptor_set = 1, binding = 0)]
                #kernel_data: &mut [u32],
                #compute_def_args
            ) {
                #(#item_attrs)*
                #unsafe_token fn #ident(
                    #[allow(unused)]
                    kernel: ::krnl_core::kernel::#kernel_type,
                    #device_fn_def_args
                ) #block
                {
                    let __krnl_kernel_data = #kernel_data;
                    unsafe {
                        ::krnl_core::kernel::__private::kernel_data(__krnl_kernel_data);
                    }
                    #declare_specs
                    let mut kernel = unsafe {
                        ::krnl_core::kernel::__private::KernelArgs {
                            global_id: __krnl_global_id.x,
                            groups: __krnl_groups.x,
                            group_id: __krnl_group_id.x,
                            subgroups: __krnl_subgroups,
                            subgroup_id: __krnl_subgroup_id,
                            subgroup_thread_id: __krnl_subgroup_thread_id,
                            threads: __krnl_threads,
                            thread_id: __krnl_thread_id.x,
                        }.into_kernel()
                    };
                    #device_arrays
                    #device_slices
                    #device_fn_call
                }
            }
        }
    };
    // Host half: a module named after the kernel, compiled for every non-spirv target.
    let host_tokens = {
        let spec_descs = &kernel_desc.spec_descs;
        let slice_descs = &kernel_desc.slice_descs;
        let push_descs = &kernel_desc.push_descs;
        let dispatch_args = kernel_meta.dispatch_args();
        let dispatch_slice_args = kernel_meta.dispatch_slice_args();
        let dispatch_push_args = kernel_desc.dispatch_push_args();
        let safe = unsafe_token.is_none();
        let safety = if safe {
            quote! {
                Safety::Safe
            }
        } else {
            quote! {
                Safety::Unsafe
            }
        };
        let host_array_length_checks = kernel_meta.host_array_length_checks();
        // With spec constants, `KernelBuilder` gets a type parameter `S` (defaulting to
        // `Specialized<false>`) and `build` is only implemented for `Specialized<true>`,
        // which `specialize()` returns. Without them the builder is not generic.
        let specialize = !kernel_desc.spec_descs.is_empty();
        let specialized = [format_ident!("S")];
        let specialized = if specialize {
            specialized.as_ref()
        } else {
            &[]
        };
        let kernel_builder_phantom_data = if specialize {
            quote! { S }
        } else {
            quote! { () }
        };
        let kernel_builder_build_generics = if specialize {
            quote! {
                <Specialized<true>>
            }
        } else {
            TokenStream2::new()
        };
        let kernel_builder_specialize_fn = if specialize {
            let spec_def_args = kernel_meta.spec_def_args();
            let spec_args = kernel_meta.spec_args();
            quote! {
                /// Specializes the kernel.
                #[allow(clippy::too_many_arguments, non_snake_case)]
                pub fn specialize(mut self, #spec_def_args) -> KernelBuilder<Specialized<true>> {
                    KernelBuilder {
                        inner: self.inner.specialize(&[#(#spec_args.into()),*]),
                        _m: PhantomData,
                    }
                }
            }
        } else {
            TokenStream2::new()
        };
        // Same pattern for groups: non-item kernels get a type parameter `G` (defaulting
        // to `WithGroups<false>`) and `dispatch` is only implemented for
        // `WithGroups<true>`. Item kernels are not generic over groups.
        let needs_groups = !kernel_meta.itemwise;
        let with_groups = [format_ident!("G")];
        let with_groups = if needs_groups {
            with_groups.as_ref()
        } else {
            &[]
        };
        let kernel_phantom_data = if needs_groups {
            quote! { G }
        } else {
            quote! { () }
        };
        let kernel_dispatch_generics = if needs_groups {
            quote! { <WithGroups<true>> }
        } else {
            TokenStream2::new()
        };
        // Module-level docs: the pretty-printed kernel source, attached unless running doctests.
        let input_docs = {
            let input_tokens_string = prettyplease::unparse(&syn::parse2(quote! {
                #[kernel]
                #item_tokens
            })?);
            let input_doc_string = format!("```\n{input_tokens_string}\n```");
            quote! {
                #![cfg_attr(not(doctest), doc = #input_doc_string)]
            }
        };
        // On nightly only: a doc-only `expansion` module showing the pretty-printed device tokens.
        let expansion = if rustversion::cfg!(nightly) {
            let expansion_tokens_string =
                prettyplease::unparse(&syn::parse2(device_tokens.clone())?);
            let expansion_doc_string = format!("```\n{expansion_tokens_string}\n```");
            quote! {
                #[cfg(all(doc, not(doctest)))]
                mod expansion {
                    #![doc = #expansion_doc_string]
                }
            }
        } else {
            TokenStream2::new()
        };
        quote! {
            #[cfg(not(target_arch = "spirv"))]
            #(#item_attrs)*
            #[automatically_derived]
            pub mod #ident {
                #input_docs
                #expansion
                __krnl_module_arg!(use crate as __krnl);
                use __krnl::{
                    anyhow::{self, Result},
                    krnl_core::half::{f16, bf16},
                    buffer::{Slice, SliceMut},
                    device::{Device, Features},
                    scalar::ScalarType,
                    kernel::__private::{
                        Kernel as KernelBase,
                        KernelBuilder as KernelBuilderBase,
                        Specialized,
                        WithGroups,
                        KernelDesc,
                        SliceDesc,
                        SpecDesc,
                        PushDesc,
                        Safety,
                        validate_kernel
                    },
                    anyhow::format_err,
                };
                use ::std::{sync::OnceLock, marker::PhantomData};
                #[cfg(not(krnlc))]
                #[doc(hidden)]
                use __krnl::macros::__krnl_cache;
                #[cfg(doc)]
                use __krnl::{kernel, device::{DeviceInfo, error::DeviceLost}};

                #host_array_length_checks

                /// Builder for creating a [`Kernel`].
                ///
                /// See [`builder()`](builder).
                pub struct KernelBuilder #(<#specialized = Specialized<false>>)* {
                    #[doc(hidden)]
                    inner: KernelBuilderBase,
                    #[doc(hidden)]
                    _m: PhantomData<#kernel_builder_phantom_data>,
                }

                /// Creates a builder.
                ///
                /// The builder is lazily created on first call.
                ///
                /// # Errors
                /// - The kernel wasn't compiled (with `#[krnl(no_build)]` applied to `#[module]`).
                pub fn builder() -> Result<KernelBuilder> {
                    static BUILDER: OnceLock<Result<KernelBuilderBase, String>> = OnceLock::new();
                    let builder = BUILDER.get_or_init(|| {
                        const DESC: Option<KernelDesc> = validate_kernel(__krnl_kernel!(#ident), #safety, &[#(#spec_descs),*], &[#(#slice_descs),*], &[#(#push_descs),*]);
                        if let Some(desc) = DESC.as_ref() {
                            KernelBuilderBase::from_desc(desc.clone())
                        } else {
                            Err(format!("Kernel `{}` not compiled!", ::std::module_path!()))
                        }
                    });
                    match builder {
                        Ok(inner) => Ok(KernelBuilder {
                            inner: inner.clone(),
                            _m: PhantomData,
                        }),
                        Err(err) => Err(format_err!("{err}")),
                    }
                }

                impl #(<#specialized>)* KernelBuilder #(<#specialized>)* {
                    /// Threads per group.
                    ///
                    /// Defaults to [`DeviceInfo::default_threads()`](DeviceInfo::default_threads).
                    pub fn with_threads(self, threads: u32) -> Self {
                        Self {
                            inner: self.inner.with_threads(threads),
                            _m: PhantomData,
                        }
                    }
                    #kernel_builder_specialize_fn
                    #[doc(hidden)]
                    #[inline]
                    pub fn __features(&self) -> Features {
                        self.inner.features()
                    }
                }

                impl KernelBuilder #kernel_builder_build_generics {
                    /// Builds the kernel for `device`.
                    ///
                    /// The kernel is cached, so subsequent calls to `.build()` with identical
                    /// builders (ie threads and spec constants) may avoid recompiling.
                    ///
                    /// # Errors
                    /// - `device` doesn't have required features.
                    /// - The kernel is not supported on `device`.
                    /// - [`DeviceLost`].
                    pub fn build(&self, device: Device) -> Result<Kernel> {
                        Ok(Kernel {
                            inner:  self.inner.build(device)?,
                            _m: PhantomData,
                        })
                    }
                }

                /// Kernel.
                pub struct Kernel #(<#with_groups = WithGroups<false>>)* {
                    #[doc(hidden)]
                    inner: KernelBase,
                    #[doc(hidden)]
                    _m: PhantomData<#kernel_phantom_data>,
                }

                impl #(<#with_groups>)* Kernel #(<#with_groups>)* {
                    /// Threads per group.
                    pub fn threads(&self) -> u32 {
                        self.inner.threads()
                    }
                    /// Global threads to dispatch.
                    ///
                    /// Implicitly declares groups by rounding up to the next multiple of threads.
                    pub fn with_global_threads(self, global_threads: u32) -> Kernel #kernel_dispatch_generics {
                        Kernel {
                            inner: self.inner.with_global_threads(global_threads),
                            _m: PhantomData,
                        }
                    }
                    /// Groups to dispatch.
                    ///
                    /// For item kernels, if not provided, is inferred based on item arguments.
                    pub fn with_groups(self, groups: u32) -> Kernel #kernel_dispatch_generics {
                        Kernel {
                            inner: self.inner.with_groups(groups),
                            _m: PhantomData,
                        }
                    }
                }

                impl Kernel #kernel_dispatch_generics {
                    /// Dispatches the kernel.
                    ///
                    /// - Waits for immutable access to slice arguments.
                    /// - Waits for mutable access to mutable slice arguments.
                    /// - Blocks until the kernel is queued.
                    ///
                    /// # Errors
                    /// - [`DeviceLost`].
                    /// - The kernel could not be queued.
                    pub #unsafe_token fn dispatch(&self, #dispatch_args) -> Result<()> {
                        unsafe { self.inner.dispatch(&[#dispatch_slice_args], &[#(#dispatch_push_args.into()),*]) }
                    }
                }
            }
        }
    };
    // Host module first, then the device entry point, then a guard that rejects
    // spirv builds not driven by krnlc.
    let tokens = quote! {
        #host_tokens
        #device_tokens
        #[cfg(all(target_arch = "spirv", not(krnlc)))]
        compile_error!("kernel cannot be used without krnlc!");
    };
    Ok(tokens)
}
1770
1771#[doc(hidden)]
1772#[proc_macro]
1773pub fn __krnl_cache(input: TokenStream) -> TokenStream {
1774    match __krnl_cache_impl(input.into()) {
1775        Ok(tokens) => tokens,
1776        Err(err) => err.into_compile_error(),
1777    }
1778    .into()
1779}
1780
/// Arguments of `__krnl_cache!`, parsed in field order:
/// `"<version>", <module>, <kernel>, "<data>"`.
///
/// NOTE(review): `__comma1` has two leading underscores while `_comma2` / `_comma3` have
/// one. The separators are not read in the visible code, so this looks like a harmless
/// naming inconsistency — confirm before renaming.
#[derive(Parse)]
struct KrnlCacheInput {
    // krnlc version that produced the cache; checked against this crate's version.
    version: LitStr,
    __comma1: Comma,
    // Module identifier, matched against the path segments of cached kernel names.
    module: Ident,
    _comma2: Comma,
    // Kernel identifier, matched against the last segment of cached kernel names.
    kernel: Ident,
    _comma3: Comma,
    // Cache payload: whitespace-separated Z85 chunks that decode to a gzip-compressed,
    // bincode-encoded `KrnlcCache`.
    data: LitStr,
}
1791
1792fn __krnl_cache_impl(input: TokenStream2) -> Result<TokenStream2> {
1793    use flate2::{
1794        read::{GzDecoder, GzEncoder},
1795        Compression,
1796    };
1797    use std::io::Read;
1798    use syn::LitByteStr;
1799    use zero85::FromZ85;
1800
1801    static CACHE: OnceLock<std::result::Result<KrnlcCache, String>> = OnceLock::new();
1802
1803    let input = syn::parse2::<KrnlCacheInput>(input)?;
1804    let span = input.module.span();
1805    let cache = CACHE
1806        .get_or_init(|| {
1807            let version = env!("CARGO_PKG_VERSION");
1808            let krnlc_version = input.version.value();
1809            if !krnlc_version_compatible(&krnlc_version, version) {
1810                return Err(format!(
1811                    "Cache created by krnlc {krnlc_version} is not compatible with krnl {version}!"
1812                ));
1813            }
1814            let data = input.data.value();
1815            let decoded_len = data.split_ascii_whitespace().map(|x| x.len() * 4 / 5).sum();
1816            let mut bytes = Vec::with_capacity(decoded_len);
1817            for data in data.split_ascii_whitespace() {
1818                let decoded = data.from_z85().map_err(|e| e.to_string())?;
1819                bytes.extend_from_slice(&decoded);
1820            }
1821            let cache =
1822                bincode2::deserialize_from::<_, KrnlcCache>(GzDecoder::new(bytes.as_slice()))
1823                    .map_err(|e| e.to_string())?;
1824            assert_eq!(krnlc_version, cache.version);
1825            Ok(cache)
1826        })
1827        .as_ref()
1828        .map_err(|e| Error::new(input.version.span(), e))?;
1829    let kernels = cache
1830        .kernels
1831        .iter()
1832        .filter(|kernel| {
1833            let name = &kernel.name;
1834            let mut iter = name.rsplit("::");
1835            if input.kernel != iter.next().unwrap() {
1836                return false;
1837            }
1838            iter.any(|x| input.module == x)
1839        })
1840        .map(|kernel| {
1841            let KernelDesc {
1842                name,
1843                spirv,
1844                safe,
1845                features,
1846                spec_descs,
1847                slice_descs,
1848                push_descs,
1849            } = kernel;
1850            let mut bytes = Vec::new();
1851            GzEncoder::new(bytemuck::cast_slice(spirv), Compression::best())
1852                .read_to_end(&mut bytes)
1853                .unwrap();
1854            let spirv = LitByteStr::new(&bytes, span);
1855            quote! {
1856                KernelDesc::from_args(KernelDescArgs {
1857                    name: #name,
1858                    spirv: #spirv,
1859                    features: #features,
1860                    safe: #safe,
1861                    spec_descs: &[#(#spec_descs),*],
1862                    slice_descs: &[#(#slice_descs),*],
1863                    push_descs: &[#(#push_descs),*],
1864                })
1865            }
1866        });
1867    let tokens = quote! {
1868        {
1869            __krnl_module_arg!(use crate as __krnl);
1870            use __krnl::{
1871                device::Features,
1872                kernel::__private::{find_kernel, KernelDesc, KernelDescArgs, Safety, SpecDesc, SliceDesc, PushDesc},
1873            };
1874
1875            find_kernel(std::module_path!(), &[#(#kernels),*])
1876        }
1877    };
1878    Ok(tokens)
1879}
1880
1881#[derive(Deserialize)]
1882struct KrnlcCache {
1883    #[allow(unused)]
1884    version: String,
1885    kernels: Vec<KernelDesc>,
1886}
1887
1888fn krnlc_version_compatible(krnlc_version: &str, version: &str) -> bool {
1889    let krnlc_version = Version::parse(krnlc_version).unwrap();
1890    let version = Version::parse(version).unwrap();
1891    if !krnlc_version.pre.is_empty() || !version.pre.is_empty() {
1892        krnlc_version == version
1893    } else if version.major == 0 && version.minor == 0 {
1894        krnlc_version.major == 0 && krnlc_version.minor == 0 && krnlc_version.patch == version.patch
1895    } else if version.major == 0 {
1896        krnlc_version.major == 0 && krnlc_version.minor == version.minor
1897    } else {
1898        krnlc_version.major == version.major && krnlc_version.minor == version.minor
1899    }
1900}
1901
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the krnlc / krnl compatibility rules over a table of version pairs.
    #[test]
    fn krnlc_version_semver() {
        // (krnlc version, krnl version, expected compatibility)
        const CASES: &[(&str, &str, bool)] = &[
            ("0.0.1", "0.0.1", true),
            ("0.0.1", "0.0.2", false),
            ("0.0.2", "0.0.1", false),
            ("0.0.2-alpha", "0.0.2", false),
            ("0.0.2", "0.0.2-alpha", false),
            ("0.0.2", "0.1.0", false),
            ("0.1.1-alpha", "0.1.0", false),
            ("0.1.1", "0.1.0-alpha", false),
            ("0.1.1", "0.1.0", true),
            ("0.1.0", "0.1.1", true),
            ("0.1.1-alpha", "0.1.1-alpha", true),
            ("0.1.0-alpha", "0.1.1-alpha", false),
            ("0.1.1", "0.2.0", false),
        ];
        for &(krnlc, krnl, expected) in CASES {
            assert_eq!(
                krnlc_version_compatible(krnlc, krnl),
                expected,
                "krnlc {krnlc} vs krnl {krnl}"
            );
        }
    }
}