1#![forbid(unsafe_code)]
3
4use derive_syn_parse::Parse;
5use fxhash::FxHashMap;
6use proc_macro::TokenStream;
7use proc_macro2::{Literal, Span as Span2, TokenStream as TokenStream2};
8use quote::{format_ident, quote, ToTokens};
9use semver::Version;
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11use std::{
12 fmt::{self, Debug},
13 str::FromStr,
14 sync::OnceLock,
15};
16use syn::{
17 parse::{Parse, ParseStream},
18 parse_macro_input,
19 punctuated::Punctuated,
20 token::{
21 And, Brace, Bracket, Colon, Comma, Const, Eq as SynEq, Fn, Gt, Lt, Mod, Mut, Paren, Pound,
22 Unsafe,
23 },
24 Attribute, Block, Error, Ident, LitInt, LitStr, Visibility,
25};
26
/// Crate-local `Result` whose error type defaults to [`syn::Error`], so parsing and
/// expansion helpers can propagate spanned errors with `?`.
type Result<T, E = Error> = std::result::Result<T, E>;
28
/// Parses a value of type `T` wrapped in square brackets, e.g. the `[global]` part of
/// `#[global]`.
#[derive(Parse, Debug)]
struct InsideBracket<T> {
    // Needed to drive the derive; the bracket token itself is never read afterwards.
    #[allow(unused)]
    #[bracket]
    bracket: Bracket,
    // The bracketed contents.
    #[inside(bracket)]
    value: T,
}
37
/// Parses a value of type `T` wrapped in curly braces, keeping the brace token so the
/// value can be re-emitted inside the same delimiters (see its `ToTokens` impl).
#[derive(Parse, Debug)]
struct InsideBrace<T> {
    #[brace]
    brace: Brace,
    // The braced contents.
    #[inside(brace)]
    value: T,
}
45
46impl<T: ToTokens> ToTokens for InsideBrace<T> {
47 fn to_tokens(&self, tokens: &mut TokenStream2) {
48 self.brace
49 .surround(tokens, |tokens| self.value.to_tokens(tokens));
50 }
51}
52
53#[proc_macro_attribute]
54pub fn module(attr: TokenStream, item: TokenStream) -> TokenStream {
55 if !attr.is_empty() {
56 return Error::new_spanned(&TokenStream2::from(attr), "unexpected tokens")
57 .into_compile_error()
58 .into();
59 }
60 let mut item = parse_macro_input!(item as ModuleItem);
61 let mut build = true;
62 let mut krnl = quote! { ::krnl };
63 let new_attr = Vec::with_capacity(item.attr.len());
64 for attr in std::mem::replace(&mut item.attr, new_attr) {
65 if attr.path.segments.len() == 1
66 && attr
67 .path
68 .segments
69 .first()
70 .map_or(false, |x| x.ident == "krnl")
71 {
72 let tokens = attr.tokens.clone().into();
73 let args = syn::parse_macro_input!(tokens as ModuleKrnlArgs);
74 for arg in args.args.iter() {
75 if let Some(krnl_crate) = arg.krnl_crate.as_ref() {
76 krnl = if krnl_crate.leading_colon.is_some()
77 || krnl_crate
78 .to_token_stream()
79 .to_string()
80 .starts_with("crate")
81 {
82 quote! {
83 #krnl_crate
84 }
85 } else {
86 quote! {
87 ::#krnl_crate
88 }
89 };
90 } else if let Some(ident) = &arg.ident {
91 if ident == "no_build" {
92 build = false;
93 } else {
94 return Error::new_spanned(
95 ident,
96 format!("unknown krnl arg `{ident}`, expected `crate` or `no_build`"),
97 )
98 .into_compile_error()
99 .into();
100 }
101 }
102 }
103 } else {
104 item.attr.push(attr);
105 }
106 }
107 {
108 let tokens = item.tokens;
109 item.tokens = quote! {
110 #[cfg(not(target_arch = "spirv"))]
111 #[doc(hidden)]
112 macro_rules! __krnl_module_arg {
113 (use crate as $i:ident) => {
114 use #krnl as $i;
115 };
116 }
117 #tokens
118 };
119 }
120 if build {
121 let source = item.tokens.to_string();
122 let ident = &item.ident;
123 let tokens = item.tokens;
124 item.tokens = quote! {
125 #[doc(hidden)]
126 mod __krnl_module_data {
127 #[allow(non_upper_case_globals)]
128 const __krnl_module_source: &'static str = #source;
129 }
130 #[cfg(not(krnlc))]
131 #[doc(hidden)]
132 macro_rules! __krnl_cache {
133 ($v:literal, $x:literal) => {
134 #[doc(hidden)]
135 macro_rules! __krnl_kernel {
136 ($k:ident) => {
137 Some(#krnl::macros::__krnl_cache!($v, #ident, $k, $x))
138 };
139 }
140 };
141 }
142 #[cfg(not(krnlc))]
143 include!(concat!(env!("CARGO_MANIFEST_DIR"), "/krnl-cache.rs"));
144 #[doc(hidden)]
145 #[cfg(krnlc)]
146 macro_rules! __krnl_kernel {
147 ($k:ident) => {
148 None
149 };
150 }
151 #tokens
152 };
153 } else {
154 let tokens = item.tokens;
155 item.tokens = quote! {
156 #[doc(hidden)]
157 macro_rules! __krnl_kernel {
158 ($k:ident) => {
159 None
160 };
161 }
162 #tokens
163 }
164 }
165 item.into_token_stream().into()
166}
167
/// Arguments of a module's `#[krnl(...)]` attribute: a parenthesized, comma-separated
/// list of [`ModuleKrnlArg`]s (a trailing comma is accepted by `parse_terminated`).
#[derive(Parse, Debug)]
struct ModuleKrnlArgs {
    #[allow(unused)]
    #[paren]
    paren: Paren,
    #[inside(paren)]
    #[call(Punctuated::parse_terminated)]
    args: Punctuated<ModuleKrnlArg, Comma>,
}
177
/// One `#[krnl(...)]` argument: either `crate = path` or a bare identifier.
///
/// Bare identifiers are validated by `module` (only `no_build` is accepted), not here.
#[derive(Parse, Debug)]
struct ModuleKrnlArg {
    // The `crate` keyword; when present, `= path` is required to follow.
    #[allow(unused)]
    crate_token: Option<syn::token::Crate>,
    #[allow(unused)]
    #[parse_if(crate_token.is_some())]
    eq: Option<SynEq>,
    // The path from `crate = path`.
    #[parse_if(crate_token.is_some())]
    krnl_crate: Option<syn::Path>,
    // A bare identifier argument, parsed only when there is no `crate` keyword.
    #[parse_if(crate_token.is_none())]
    ident: Option<Ident>,
}
190
/// A `mod name { ... }` item as seen by the `module` attribute. The module body is
/// kept as an opaque token stream.
#[derive(Parse, Debug)]
struct ModuleItem {
    // Outer attributes on the module, including any `#[krnl(...)]`.
    #[call(Attribute::parse_outer)]
    attr: Vec<Attribute>,
    vis: Visibility,
    mod_token: Mod,
    ident: Ident,
    #[brace]
    brace: Brace,
    // Everything between the module's braces.
    #[inside(brace)]
    tokens: TokenStream2,
}
203
204impl ToTokens for ModuleItem {
205 fn to_tokens(&self, tokens: &mut TokenStream2) {
206 for attr in self.attr.iter() {
207 attr.to_tokens(tokens);
208 }
209 self.vis.to_tokens(tokens);
210 self.mod_token.to_tokens(tokens);
211 self.ident.to_tokens(tokens);
212 self.brace
213 .surround(tokens, |tokens| self.tokens.to_tokens(tokens));
214 }
215}
216
217#[proc_macro_attribute]
218pub fn kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
219 if !attr.is_empty() {
220 return Error::new_spanned(&TokenStream2::from(attr), "unexpected tokens")
221 .into_compile_error()
222 .into();
223 }
224 match kernel_impl(item.into()) {
225 Ok(tokens) => tokens.into(),
226 Err(e) => e.into_compile_error().into(),
227 }
228}
229
/// A kernel function as written by the user:
/// `#[attrs] vis [unsafe] fn name[<const N: ty, ...>](args) { block }`.
#[derive(Parse, Debug)]
struct KernelItem {
    #[call(Attribute::parse_outer)]
    attrs: Vec<Attribute>,
    #[allow(unused)]
    vis: Visibility,
    // Present for `unsafe fn` kernels; `KernelMeta::desc` records `safe: false` then.
    unsafe_token: Option<Unsafe>,
    #[allow(unused)]
    fn_token: Fn,
    ident: Ident,
    // Specialization constants; only parsed when the next token is `<`.
    #[peek(Lt)]
    generics: Option<KernelGenerics>,
    #[allow(unused)]
    #[paren]
    paren: Paren,
    #[inside(paren)]
    #[call(Punctuated::parse_terminated)]
    args: Punctuated<KernelArg, Comma>,
    block: Block,
}
250
251impl KernelItem {
252 fn meta(&self) -> Result<KernelMeta> {
253 let mut meta = KernelMeta {
254 spec_metas: Vec::new(),
255 unsafe_token: self.unsafe_token,
256 ident: self.ident.clone(),
257 arg_metas: Vec::with_capacity(self.args.len()),
258 block: self.block.clone(),
259 itemwise: false,
260 arrays: FxHashMap::default(),
261 };
262 let mut spec_id = 0;
263 if let Some(generics) = self.generics.as_ref() {
264 meta.spec_metas = generics
265 .specs
266 .iter()
267 .map(|x| {
268 let meta = KernelSpecMeta {
269 ident: x.ident.clone(),
270 ty: x.ty.clone(),
271 id: spec_id,
272 thread_dim: None,
273 };
274 spec_id += 1;
275 meta
276 })
277 .collect();
278 }
279 let mut binding = 0;
280 for arg in self.args.iter() {
281 let mut arg_meta = arg.meta()?;
282 if arg_meta.kind.is_global() || arg_meta.kind.is_item() {
283 arg_meta.binding.replace(binding);
284 binding += 1;
285 }
286 meta.itemwise |= arg_meta.kind.is_item();
287 if let Some(len) = arg_meta.len.as_ref() {
288 meta.arrays
289 .entry(arg_meta.scalar_ty.scalar_type)
290 .or_default()
291 .push((arg.ident.clone(), len.clone()));
292 }
293 meta.arg_metas.push(arg_meta);
294 }
295 Ok(meta)
296 }
297}
298
/// The `<const NAME: ty, ...>` list of a kernel, i.e. its specialization constants.
#[derive(Debug)]
struct KernelGenerics {
    /// Specialization constants in declaration order.
    specs: Punctuated<KernelSpec, Comma>,
}
308
309impl Parse for KernelGenerics {
310 fn parse(input: ParseStream) -> Result<Self> {
311 input.parse::<Lt>()?;
312 let mut specs = Punctuated::new();
313 while input.peek(Const) {
314 specs.push(input.parse()?);
315 if input.peek(Comma) {
316 input.parse::<Comma>()?;
317 } else {
318 break;
319 }
320 }
321 input.parse::<Gt>()?;
322 Ok(Self { specs })
323 }
324}
325
/// One specialization constant in a kernel's generic list: `const NAME: scalar`.
#[derive(Parse, Debug)]
struct KernelSpec {
    #[allow(unused)]
    const_token: Const,
    ident: Ident,
    #[allow(unused)]
    colon: Colon,
    ty: KernelTypeScalar,
}
335
/// A resolved specialization constant, produced by `KernelItem::meta`.
#[derive(Debug)]
struct KernelSpecMeta {
    /// Name of the constant; `declare` binds a device-side local of the same name.
    ident: Ident,
    ty: KernelTypeScalar,
    /// Value of the SPIR-V `SpecId` decoration (declaration index, starting at 0).
    id: u32,
    /// When `Some`, the matching device-function parameter is marked `#[allow(unused)]`
    /// (see `KernelMeta::device_fn_def_args`); `KernelItem::meta` always sets `None`.
    thread_dim: Option<usize>,
}
343
344impl KernelSpecMeta {
345 fn declare(&self) -> TokenStream2 {
346 use ScalarType::*;
347 let scalar_type = self.ty.scalar_type;
348 let bits = scalar_type.size() * 8;
349 let signed = matches!(scalar_type, I8 | I16 | I32 | I64) as u32;
350 let float = matches!(scalar_type, F32 | F64);
351 let ty_string = if float {
352 format!("%ty = OpTypeFloat {bits}")
353 } else {
354 format!("%ty = OpTypeInt {bits} {signed}")
355 };
356 let spec_id_string = format!("OpDecorate %spec SpecId {}", self.id);
357 let ident = &self.ident;
358 quote! {
359 #[allow(non_snake_case)]
360 let #ident = unsafe {
361 let mut spec = Default::default();
362 ::core::arch::asm! {
363 #ty_string,
364 "%spec = OpSpecConstant %ty 0",
365 #spec_id_string,
366 "OpStore {spec} %spec",
367 spec = in(reg) &mut spec,
368 }
369 spec
370 };
371 }
372 }
373}
374
/// A scalar type name (`u8`, `f32`, ...) together with the identifier it was written as.
#[derive(Clone, Debug)]
struct KernelTypeScalar {
    /// The type identifier as written; re-emitted verbatim in generated code.
    ident: Ident,
    /// The resolved scalar type.
    scalar_type: ScalarType,
}
380
381impl Parse for KernelTypeScalar {
382 fn parse(input: ParseStream<'_>) -> Result<Self> {
383 let ident = input.parse()?;
384 if let Some(scalar_type) = ScalarType::iter().find(|x| ident == x.name()) {
385 Ok(Self { ident, scalar_type })
386 } else {
387 Err(Error::new(ident.span(), "expected scalar"))
388 }
389 }
390}
391
/// One kernel argument: `[#[kind]] name: Type`.
///
/// Exactly one of the type fields is parsed, selected by `kind`.
#[derive(Parse, Debug)]
struct KernelArg {
    kind: KernelArgKind,
    ident: Ident,
    #[allow(unused)]
    colon: Colon,
    // `#[global]`: `Name<T>`, where `KernelArg::meta` accepts `Slice` / `UnsafeSlice`.
    #[parse_if(kind.is_global())]
    slice_ty: Option<KernelTypeSlice>,
    // `#[item]`: `T` or `&mut T`.
    #[parse_if(kind.is_item())]
    item_ty: Option<KernelTypeItem>,
    // `#[group]`: `Name<T, LEN>`.
    #[parse_if(kind.is_group())]
    array_ty: Option<KernelTypeArray>,
    // No attribute: a scalar push constant.
    #[parse_if(kind.is_push())]
    push_ty: Option<KernelTypeScalar>,
}
407
408impl KernelArg {
409 fn meta(&self) -> Result<KernelArgMeta> {
410 let kind = self.kind;
411 let (scalar_ty, mutable, len) = if let Some(slice_ty) = self.slice_ty.as_ref() {
412 let slice_ty_ident = &slice_ty.ty;
413 let mutable = if slice_ty.ty == "Slice" {
414 false
415 } else if slice_ty.ty == "UnsafeSlice" {
416 true
417 } else if slice_ty.ty == "SliceMut" {
418 return Err(Error::new_spanned(slice_ty_ident, "try `UnsafeSlice`"));
419 } else {
420 return Err(Error::new_spanned(
421 slice_ty_ident,
422 "expected `Slice` or `UnsafeSlice`",
423 ));
424 };
425 (slice_ty.scalar_ty.clone(), mutable, None)
426 } else if let Some(array_ty) = self.array_ty.as_ref() {
427 let len = array_ty.len.to_token_stream();
428 (array_ty.scalar_ty.clone(), true, Some(len))
429 } else if let Some(item_ty) = self.item_ty.as_ref() {
430 (item_ty.scalar_ty.clone(), item_ty.mut_token.is_some(), None)
431 } else if let Some(push_ty) = self.push_ty.as_ref() {
432 (push_ty.clone(), false, None)
433 } else {
434 unreachable!("KernelArg::meta expected type!")
435 };
436 let meta = KernelArgMeta {
437 kind,
438 ident: self.ident.clone(),
439 scalar_ty,
440 mutable,
441 binding: None,
442 len,
443 };
444 Ok(meta)
445 }
446}
447
/// Resolved information about one kernel argument, produced by `KernelArg::meta`.
#[derive(Debug)]
struct KernelArgMeta {
    kind: KernelArgKind,
    ident: Ident,
    scalar_ty: KernelTypeScalar,
    /// Whether the argument is writable: `UnsafeSlice` globals, `&mut` items, and
    /// always for group arrays.
    mutable: bool,
    /// Descriptor binding index; assigned by `KernelItem::meta` only for `#[global]`
    /// and `#[item]` arguments.
    binding: Option<u32>,
    /// Length expression tokens; only set for `#[group]` arrays.
    len: Option<TokenStream2>,
}
457
458impl KernelArgMeta {
459 fn compute_def_tokens(&self) -> Option<TokenStream2> {
460 let ident = &self.ident;
461 let ty = &self.scalar_ty.ident;
462 if let Some(binding) = self.binding.as_ref() {
463 let set = LitInt::new("0", Span2::call_site());
464 let binding = LitInt::new(&binding.to_string(), Span2::call_site());
465 let mut_token = if self.mutable {
466 Some(Mut::default())
467 } else {
468 None
469 };
470 Some(quote! {
471 #[spirv(storage_buffer, descriptor_set = #set, binding = #binding)] #ident: &#mut_token [#ty; 1]
472 })
473 } else {
474 None
475 }
476 }
477 fn device_fn_def_tokens(&self) -> TokenStream2 {
478 let ident = &self.ident;
479 let ty = &self.scalar_ty.ident;
480 let mutable = self.mutable;
481 use KernelArgKind::*;
482 match self.kind {
483 Global => {
484 if mutable {
485 quote! {
486 #ident: ::krnl_core::buffer::UnsafeSlice<#ty>
487 }
488 } else {
489 quote! {
490 #ident: ::krnl_core::buffer::Slice<#ty>
491 }
492 }
493 }
494 Item => {
495 if mutable {
496 quote! {
497 #ident: &mut #ty
498 }
499 } else {
500 quote! {
501 #ident: #ty
502 }
503 }
504 }
505 Group => quote! {
506 #ident: ::krnl_core::buffer::UnsafeSlice<#ty>
507 },
508 Push => quote! {
509 #ident: #ty
510 },
511 }
512 }
513 fn device_slices(&self) -> TokenStream2 {
514 let ident = &self.ident;
515 let mutable = self.mutable;
516 use KernelArgKind::*;
517 match self.kind {
518 Global | Item => {
519 let offset = format_ident!("__krnl_offset_{ident}");
520 let len = format_ident!("__krnl_len_{ident}");
521 let slice_fn = if mutable {
522 quote! {
523 ::krnl_core::buffer::UnsafeSlice::from_unsafe_raw_parts
524 }
525 } else {
526 quote! {
527 ::krnl_core::buffer::Slice::from_raw_parts
528 }
529 };
530 quote! {
531 let #ident = unsafe {
532 #slice_fn(#ident, __krnl_push_consts.#offset as usize, __krnl_push_consts.#len as usize)
533 };
534 }
535 }
536 Group => {
537 let offset = format_ident!("__krnl_offset_{ident}");
538 let len = format_ident!("__krnl_len_{ident}");
539 let scalar_name = self.scalar_ty.scalar_type.name();
540 let array = format_ident!("__krnl_group_array_{scalar_name}");
541 quote! {
542 let #ident = {
543 unsafe {
544 ::krnl_core::buffer::UnsafeSlice::from_unsafe_raw_parts(#array, #offset, #len)
545 }
546 };
547 }
548 }
549 Push => TokenStream2::new(),
550 }
551 }
552 fn device_fn_call_tokens(&self) -> TokenStream2 {
553 let ident = &self.ident;
554 let mutable = self.mutable;
555 use KernelArgKind::*;
556 match self.kind {
557 Global | Group => quote! {
558 #ident
559 },
560 Item => {
561 if mutable {
562 quote! {
563 unsafe {
564 use ::krnl_core::buffer::UnsafeIndex;
565 #ident.unsafe_index_mut(__krnl_item_id as usize)
566 }
567 }
568 } else {
569 quote! {
570 #ident[__krnl_item_id as usize]
571 }
572 }
573 }
574 Push => quote! {
575 __krnl_push_consts.#ident
576 },
577 }
578 }
579}
580
/// The optional `#[ident]` marker in front of a kernel argument.
#[derive(Parse, Debug)]
struct KernelArgAttr {
    #[allow(unused)]
    pound: Option<Pound>,
    // The bracketed identifier; required after a `#`.
    #[parse_if(pound.is_some())]
    ident: Option<InsideBracket<Ident>>,
}
588
589impl KernelArgAttr {
590 fn kind(&self) -> Result<KernelArgKind> {
591 use KernelArgKind::*;
592 let ident = if let Some(ident) = self.ident.as_ref() {
593 &ident.value
594 } else {
595 return Ok(Push);
596 };
597 let kind = if ident == "global" {
598 Global
599 } else if ident == "item" {
600 Item
601 } else if ident == "group" {
602 Group
603 } else {
604 return Err(Error::new_spanned(
605 ident,
606 "expected `global`, `item`, or `group`",
607 ));
608 };
609 Ok(kind)
610 }
611}
612
/// How a kernel argument is passed to the device.
#[derive(Clone, Copy, derive_more::IsVariant, PartialEq, Eq, Hash, Debug)]
enum KernelArgKind {
    /// `#[global]`: a slice backed by a storage buffer.
    Global,
    /// `#[item]`: one element per item of an item-wise kernel, backed by a storage buffer.
    Item,
    /// `#[group]`: an array carved out of a `#[spirv(workgroup)]` array.
    Group,
    /// No marker: a scalar push constant.
    Push,
}
620
621impl Parse for KernelArgKind {
622 fn parse(input: ParseStream) -> Result<Self> {
623 KernelArgAttr::parse(input)?.kind()
624 }
625}
626
/// Type of an `#[item]` argument: `T` or `&mut T`.
#[derive(Parse, Debug)]
struct KernelTypeItem {
    #[allow(unused)]
    and: Option<And>,
    // Required after `&`; its presence marks the item as mutable.
    #[parse_if(and.is_some())]
    mut_token: Option<Mut>,
    scalar_ty: KernelTypeScalar,
}
635
/// Type of a `#[global]` argument: `Name<T>`. `Name` is validated by `KernelArg::meta`.
#[derive(Parse, Debug)]
struct KernelTypeSlice {
    ty: Ident,
    #[allow(unused)]
    lt: Lt,
    scalar_ty: KernelTypeScalar,
    #[allow(unused)]
    gt: Gt,
}
645
/// Type of a `#[group]` argument: `Name<T, LEN>`.
#[derive(Parse, Debug)]
struct KernelTypeArray {
    // The `Name` identifier is parsed but never checked.
    #[allow(unused)]
    ty: Ident,
    #[allow(unused)]
    lt: Lt,
    scalar_ty: KernelTypeScalar,
    #[allow(unused)]
    comma: Comma,
    len: KernelArrayLength,
    #[allow(unused)]
    gt: Gt,
}
659
/// Length of a `#[group]` array: a `{ block }`, an identifier, or an integer literal.
/// Exactly one field is `Some` after parsing.
#[derive(Debug)]
struct KernelArrayLength {
    block: Option<Block>,
    ident: Option<Ident>,
    lit: Option<LitInt>,
}
666
667impl Parse for KernelArrayLength {
668 fn parse(input: &syn::parse::ParseBuffer) -> Result<Self> {
669 if input.peek(Brace) {
670 Ok(Self {
671 block: Some(input.parse()?),
672 ident: None,
673 lit: None,
674 })
675 } else if input.peek(Ident) {
676 Ok(Self {
677 block: None,
678 ident: Some(input.parse()?),
679 lit: None,
680 })
681 } else {
682 Ok(Self {
683 block: None,
684 ident: None,
685 lit: Some(input.parse()?),
686 })
687 }
688 }
689}
690
691impl ToTokens for KernelArrayLength {
692 fn to_tokens(&self, tokens: &mut TokenStream2) {
693 if let Some(block) = self.block.as_ref() {
694 for stmt in block.stmts.iter() {
695 stmt.to_tokens(tokens);
696 }
697 } else if let Some(ident) = self.ident.as_ref() {
698 ident.to_tokens(tokens);
699 } else if let Some(lit) = self.lit.as_ref() {
700 lit.to_tokens(tokens);
701 }
702 }
703}
704
/// Everything needed to generate a kernel's device and host code, produced by
/// `KernelItem::meta`.
#[derive(Debug)]
struct KernelMeta {
    /// Specialization constants, with SpecIds assigned 0.. in declaration order.
    spec_metas: Vec<KernelSpecMeta>,
    ident: Ident,
    unsafe_token: Option<Unsafe>,
    arg_metas: Vec<KernelArgMeta>,
    /// `true` when any argument is `#[item]`.
    itemwise: bool,
    block: Block,
    /// `#[group]` arrays grouped by scalar type: `(argument name, length expression)`.
    arrays: FxHashMap<ScalarType, Vec<(Ident, TokenStream2)>>,
}
715
impl KernelMeta {
    /// Builds the serializable [`KernelDesc`] for this kernel.
    ///
    /// Records the kernel name, whether it is safe (no `unsafe` token), one `SpecDesc`
    /// per specialization constant, one `SliceDesc` per `#[global]`/`#[item]` argument and
    /// one `PushDesc` per push-constant argument; `#[group]` arrays are not recorded.
    /// Currently never returns `Err`.
    fn desc(&self) -> Result<KernelDesc> {
        let mut kernel_desc = KernelDesc {
            name: self.ident.to_string(),
            safe: self.unsafe_token.is_none(),
            ..KernelDesc::default()
        };
        for spec in self.spec_metas.iter() {
            kernel_desc.spec_descs.push(SpecDesc {
                name: spec.ident.to_string(),
                scalar_type: spec.ty.scalar_type,
            })
        }
        for arg_meta in self.arg_metas.iter() {
            let kind = arg_meta.kind;
            let scalar_type = arg_meta.scalar_ty.scalar_type;
            use KernelArgKind::*;
            match kind {
                Global | Item => {
                    kernel_desc.slice_descs.push(SliceDesc {
                        name: arg_meta.ident.to_string(),
                        scalar_type,
                        mutable: arg_meta.mutable,
                        item: kind.is_item(),
                    });
                }
                Group => (),
                Push => {
                    kernel_desc.push_descs.push(PushDesc {
                        name: arg_meta.ident.to_string(),
                        scalar_type,
                    });
                }
            }
        }
        // Largest scalars first (stable, so equal sizes keep declaration order). All
        // scalar sizes are powers of two, so fields laid out in this order need no
        // interior padding.
        kernel_desc
            .push_descs
            .sort_by_key(|x| -(x.scalar_type.size() as i32));
        Ok(kernel_desc)
    }
    /// Extra entry-point parameters: the storage-buffer parameter of every bound argument
    /// (`KernelArgMeta::compute_def_tokens`), followed by one `#[spirv(workgroup)]`
    /// `&mut [T; 1]` parameter per scalar type used by `#[group]` arrays.
    ///
    /// Workgroup parameters are named `__krnl_group_array_{scalar}_{id}`, with `id`
    /// counting from 1 in `self.arrays` iteration order. `device_arrays` iterates the same
    /// map in the same order, so its ids match these names.
    fn compute_def_args(&self) -> Punctuated<TokenStream2, Comma> {
        let mut id = 1;
        let arrays = self.arrays.keys().map(|scalar_type| {
            let scalar_name = scalar_type.name();
            let ident = format_ident!("__krnl_group_array_{scalar_name}_{id}");
            let ty = format_ident!("{scalar_name}");
            id += 1;
            quote! {
                #[spirv(workgroup)] #ident: &mut [#ty; 1]
            }
        });
        self.arg_metas
            .iter()
            .filter_map(|arg| arg.compute_def_tokens())
            .chain(arrays)
            .collect()
    }
    /// Concatenates the `let` binding of every specialization constant
    /// (`KernelSpecMeta::declare`), in declaration order.
    fn declare_specs(&self) -> TokenStream2 {
        self.spec_metas
            .iter()
            .flat_map(|spec| spec.declare())
            .collect()
    }
    /// `NAME: ty` parameter list for the specialization constants, each marked
    /// `#[allow(non_snake_case)]`.
    fn spec_def_args(&self) -> Punctuated<TokenStream2, Comma> {
        self.spec_metas
            .iter()
            .map(|spec| {
                let ident = &spec.ident;
                let ty = &spec.ty.ident;
                quote! {
                    #[allow(non_snake_case)]
                    #ident: #ty
                }
            })
            .collect()
    }
    /// The specialization constant identifiers, in declaration order.
    fn spec_args(&self) -> Vec<Ident> {
        self.spec_metas
            .iter()
            .map(|spec| spec.ident.clone())
            .collect()
    }
    /// Device-side setup for `#[group]` arrays.
    ///
    /// For each scalar type in `self.arrays` (ids counting from 1, same order as
    /// `compute_def_args`):
    /// - binds `__krnl_group_array_{scalar}` to the matching workgroup parameter;
    /// - lays the arrays of that type out back to back, emitting
    ///   `__krnl_offset_{array}` / `__krnl_len_{array}` locals. Each length is computed by
    ///   a local `const fn` wrapping the user's length expression, called with the
    ///   specialization constants;
    /// - passes the total length to `group_buffer_len` and `zero_group_buffer`.
    ///
    /// A `workgroup_memory_barrier` follows if any argument is `#[group]`.
    fn device_arrays(&self) -> TokenStream2 {
        let spec_def_args: Punctuated<_, Comma> = self
            .spec_def_args()
            .into_iter()
            .map(|arg| {
                quote! {
                    #[allow(unused)] #arg
                }
            })
            .collect();
        let spec_args: Punctuated<_, Comma> = self.spec_args().into_iter().collect();
        let group_barrier = if self.arg_metas.iter().any(|arg| arg.kind.is_group()) {
            quote! {
                unsafe {
                    ::krnl_core::spirv_std::arch::workgroup_memory_barrier();
                }
            }
        } else {
            TokenStream2::new()
        };
        // Must advance in the same order as the counter in `compute_def_args`.
        let mut id = 1;
        self.arrays
            .iter()
            .flat_map(|(scalar_type, arrays)| {
                let scalar_name = scalar_type.name();
                let ident = format_ident!("__krnl_group_array_{scalar_name}");
                let ident_with_id = format_ident!("{ident}_{id}");
                let id_lit = LitInt::new(&id.to_string(), Span2::call_site());
                id += 1;
                let len = format_ident!("{ident}_len");
                let offset = format_ident!("{ident}_offset");
                // Running offset: each array starts where the previous one ended.
                let array_offsets_lens: TokenStream2 = arrays
                    .iter()
                    .map(|(array, len_expr)| {
                        let array_offset = format_ident!("__krnl_offset_{array}");
                        let array_len = format_ident!("__krnl_len_{array}");
                        quote! {
                            let #array_offset = #offset;
                            let #array_len = {
                                const fn #array_len(#spec_def_args) -> usize {
                                    #len_expr
                                }
                                #array_len(#spec_args)
                            };
                            #offset += #array_len;
                        }
                    })
                    .collect();
                quote! {
                    let #ident = #ident_with_id;
                    let mut #offset = 0usize;
                    #array_offsets_lens
                    let #len = #offset;
                    unsafe {
                        ::krnl_core::kernel::__private::group_buffer_len(__krnl_kernel_data, #id_lit, #len);
                        ::krnl_core::kernel::__private::zero_group_buffer(&kernel, #ident, #len);
                    }
                }
            })
            .chain(group_barrier)
            .collect()
    }
    /// Host-side check of `#[group]` length expressions.
    ///
    /// For each group array, emits a `const _` item containing a `const fn` whose body is
    /// the length expression and whose parameters are the specialization constants. The
    /// function is referenced but never called, so presumably this only makes an invalid
    /// length expression fail to compile on the host as well.
    fn host_array_length_checks(&self) -> TokenStream2 {
        let mut spec_def_args = self.spec_def_args();
        for arg in spec_def_args.iter_mut() {
            *arg = quote! {
                #[allow(unused_variables, non_snake_case)]
                #arg
            };
        }
        self.arg_metas
            .iter()
            .flat_map(|arg| {
                if let Some(len) = arg.len.as_ref() {
                    quote! {
                        const _: () = {
                            #[allow(non_snake_case, clippy::too_many_arguments)]
                            const fn __krnl_array_len(#spec_def_args) -> usize {
                                #len
                            }
                            let _ = __krnl_array_len;
                        };
                    }
                } else {
                    TokenStream2::new()
                }
            })
            .collect()
    }
    /// Concatenates every argument's `KernelArgMeta::device_slices` statements.
    fn device_slices(&self) -> TokenStream2 {
        self.arg_metas
            .iter()
            .map(|arg| arg.device_slices())
            .collect()
    }
    /// Expression for the number of items, `a.len().max(b.len())...` over all `#[item]`
    /// arguments in order, or `0` when there are none.
    fn device_items(&self) -> TokenStream2 {
        let mut items = self
            .arg_metas
            .iter()
            .filter(|arg| arg.kind.is_item())
            .map(|arg| &arg.ident);
        if let Some(first) = items.next() {
            quote! {
                #first.len()
            }
            .into_iter()
            .chain(items.flat_map(|item| {
                quote! {
                    .max(#item.len())
                }
            }))
            .collect()
        } else {
            quote! {
                0
            }
        }
    }
    /// Parameter list of the user's device function: the specialization constants, then
    /// each argument's `KernelArgMeta::device_fn_def_tokens`.
    fn device_fn_def_args(&self) -> Punctuated<TokenStream2, Comma> {
        self.spec_metas
            .iter()
            .map(|x| {
                let ident = &x.ident;
                let ty = &x.ty.ident;
                let allow_unused = x.thread_dim.map(|_| {
                    quote! {
                        #[allow(unused)]
                    }
                });
                quote! {
                    #allow_unused
                    #[allow(non_snake_case)]
                    #ident: #ty
                }
            })
            .chain(self.arg_metas.iter().map(|arg| arg.device_fn_def_tokens()))
            .collect()
    }
    /// Call arguments matching `device_fn_def_args`: the specialization constant
    /// identifiers, then each argument's `KernelArgMeta::device_fn_call_tokens`.
    fn device_fn_call_args(&self) -> Punctuated<TokenStream2, Comma> {
        self.spec_metas
            .iter()
            .map(|spec| spec.ident.to_token_stream())
            .chain(self.arg_metas.iter().map(|arg| arg.device_fn_call_tokens()))
            .collect()
    }
    /// Host-side dispatch parameters, each followed by a comma: `Slice<T>` / `SliceMut<T>`
    /// (by mutability) for bound arguments, plain `T` for push constants. Group arrays
    /// are omitted.
    fn dispatch_args(&self) -> TokenStream2 {
        let mut tokens = TokenStream2::new();
        for arg in self.arg_metas.iter() {
            let ident = &arg.ident;
            let ty = &arg.scalar_ty.ident;
            if arg.binding.is_some() {
                let slice_ty = if arg.mutable {
                    format_ident!("SliceMut")
                } else {
                    format_ident!("Slice")
                };
                tokens.extend(quote! {
                    #ident: #slice_ty<#ty>,
                });
            } else if arg.kind.is_push() {
                tokens.extend(quote! {
                    #ident: #ty,
                });
            }
        }
        tokens
    }
    /// `ident.into(),` for every bound (`#[global]`/`#[item]`) argument, in argument order.
    fn dispatch_slice_args(&self) -> TokenStream2 {
        let mut tokens = TokenStream2::new();
        for arg in self.arg_metas.iter() {
            let ident = &arg.ident;
            if arg.binding.is_some() {
                tokens.extend(quote! {
                    #ident.into(),
                });
            }
        }
        tokens
    }
}
997
/// Scalar element types supported by kernel arguments and specialization constants.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
enum ScalarType {
    U8,
    I8,
    U16,
    I16,
    F16,
    BF16,
    U32,
    I32,
    F32,
    U64,
    I64,
    F64,
}
1013
1014impl ScalarType {
1015 fn iter() -> impl Iterator<Item = Self> {
1016 use ScalarType::*;
1017 [U8, I8, U16, I16, F16, BF16, U32, I32, F32, U64, I64, F64].into_iter()
1018 }
1019 fn name(&self) -> &'static str {
1020 use ScalarType::*;
1021 match self {
1022 U8 => "u8",
1023 I8 => "i8",
1024 U16 => "u16",
1025 I16 => "i16",
1026 F16 => "f16",
1027 BF16 => "bf16",
1028 U32 => "u32",
1029 I32 => "i32",
1030 F32 => "f32",
1031 U64 => "u64",
1032 I64 => "i64",
1033 F64 => "f64",
1034 }
1035 }
1036 fn as_str(&self) -> &'static str {
1037 use ScalarType::*;
1038 match self {
1039 U8 => "U8",
1040 I8 => "I8",
1041 U16 => "U16",
1042 I16 => "I16",
1043 F16 => "F16",
1044 BF16 => "BF16",
1045 U32 => "U32",
1046 I32 => "I32",
1047 F32 => "F32",
1048 U64 => "U64",
1049 I64 => "I64",
1050 F64 => "F64",
1051 }
1052 }
1053 fn size(&self) -> usize {
1054 use ScalarType::*;
1055 match self {
1056 U8 | I8 => 1,
1057 U16 | I16 | F16 | BF16 => 2,
1058 U32 | I32 | F32 => 4,
1059 U64 | I64 | F64 => 8,
1060 }
1061 }
1062}
1063
1064impl ToTokens for ScalarType {
1065 fn to_tokens(&self, tokens: &mut TokenStream2) {
1066 let ident = format_ident!("{self:?}");
1067 tokens.extend(quote! {
1068 ScalarType::#ident
1069 });
1070 }
1071}
1072
1073impl FromStr for ScalarType {
1074 type Err = ();
1075 fn from_str(input: &str) -> Result<Self, ()> {
1076 Self::iter()
1077 .find(|x| x.as_str() == input || x.name() == input)
1078 .ok_or(())
1079 }
1080}
1081
1082impl Serialize for ScalarType {
1083 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1084 where
1085 S: Serializer,
1086 {
1087 serializer.serialize_str(self.as_str())
1088 }
1089}
1090
1091impl<'de> Deserialize<'de> for ScalarType {
1092 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1093 where
1094 D: Deserializer<'de>,
1095 {
1096 use serde::de::Visitor;
1097
1098 struct ScalarTypeVisitor;
1099
1100 impl Visitor<'_> for ScalarTypeVisitor {
1101 type Value = ScalarType;
1102 fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1103 write!(formatter, "a scalar type")
1104 }
1105 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
1106 where
1107 E: serde::de::Error,
1108 {
1109 if let Ok(scalar_type) = ScalarType::from_str(v) {
1110 Ok(scalar_type)
1111 } else {
1112 Err(E::custom(format!("unknown ScalarType {v}")))
1113 }
1114 }
1115 }
1116 deserializer.deserialize_str(ScalarTypeVisitor)
1117 }
1118}
1119
/// Serializable description of a kernel, produced by `KernelMeta::desc` and embedded into
/// generated code via `KernelDesc::encode`.
#[derive(Default, Serialize, Deserialize, Debug)]
struct KernelDesc {
    name: String,
    // NOTE(review): `spirv` and `features` are skipped when serializing but still expected
    // when deserializing. With a non-self-describing format such as bincode the two
    // directions are not symmetric; confirm the deserialized data comes from elsewhere.
    #[serde(skip_serializing)]
    spirv: Vec<u32>,
    #[serde(skip_serializing)]
    features: Features,
    /// `false` for `unsafe fn` kernels.
    safe: bool,
    spec_descs: Vec<SpecDesc>,
    slice_descs: Vec<SliceDesc>,
    push_descs: Vec<PushDesc>,
}
1132
1133impl KernelDesc {
1134 fn encode(&self) -> Result<String> {
1135 let bytes = bincode2::serialize(self).map_err(|e| Error::new(Span2::call_site(), e))?;
1136 Ok(format!("__krnl_kernel_data_{}", hex::encode(bytes)))
1137 }
1138 fn push_const_fields(&self) -> Punctuated<TokenStream2, Comma> {
1139 let mut fields = Punctuated::new();
1140 let mut size = 0;
1141 for push_desc in self.push_descs.iter() {
1142 let ident = format_ident!("{}", push_desc.name);
1143 let ty = format_ident!("{}", push_desc.scalar_type.name());
1144 fields.push(quote! {
1145 #ident: #ty
1146 });
1147 size += push_desc.scalar_type.size();
1148 }
1149 for i in 0..4 {
1150 if size % 4 == 0 {
1151 break;
1152 }
1153 let ident = format_ident!("__krnl_pad{i}");
1154 fields.push(quote! {
1155 #ident: u8
1156 });
1157 size += 1;
1158 }
1159 for slice_desc in self.slice_descs.iter() {
1160 let offset_ident = format_ident!("__krnl_offset_{}", slice_desc.name);
1161 let len_ident = format_ident!("__krnl_len_{}", slice_desc.name);
1162 fields.push(quote! {
1163 #offset_ident: u32
1164 });
1165 fields.push(quote! {
1166 #len_ident: u32
1167 });
1168 }
1169 fields
1170 }
1171 fn dispatch_push_args(&self) -> Vec<Ident> {
1172 self.push_descs
1173 .iter()
1174 .map(|push| format_ident!("{}", push.name))
1175 .collect()
1176 }
1177}
1178
/// Bit set of device capabilities a kernel requires; deserialized transparently from a
/// plain `u32`.
#[derive(Default, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(transparent)]
struct Features {
    bits: u32,
}
1184
1185impl Features {
1186 pub const INT8: Self = Self::new(1);
1187 pub const INT16: Self = Self::new(1 << 1);
1188 pub const INT64: Self = Self::new(1 << 2);
1189 pub const FLOAT16: Self = Self::new(1 << 3);
1190 pub const FLOAT64: Self = Self::new(1 << 4);
1191 pub const BUFFER8: Self = Self::new(1 << 8);
1192 pub const BUFFER16: Self = Self::new(1 << 9);
1193 pub const PUSH_CONSTANT8: Self = Self::new(1 << 10);
1194 pub const PUSH_CONSTANT16: Self = Self::new(1 << 11);
1195 pub const SUBGROUP_BASIC: Self = Self::new(1 << 16);
1196 pub const SUBGROUP_VOTE: Self = Self::new(1 << 17);
1197 pub const SUBGROUP_ARITHMETIC: Self = Self::new(1 << 18);
1198 pub const SUBGROUP_BALLOT: Self = Self::new(1 << 19);
1199 pub const SUBGROUP_SHUFFLE: Self = Self::new(1 << 20);
1200 pub const SUBGROUP_SHUFFLE_RELATIVE: Self = Self::new(1 << 21);
1201 pub const SUBGROUP_CLUSTERED: Self = Self::new(1 << 22);
1202 pub const SUBGROUP_QUAD: Self = Self::new(1 << 23);
1203
1204 #[inline]
1205 const fn new(bits: u32) -> Self {
1206 Self { bits }
1207 }
1208 #[inline]
1234 pub const fn contains(self, other: Self) -> bool {
1235 (self.bits | other.bits) == self.bits
1236 }
1237 fn name_iter(self) -> impl Iterator<Item = &'static str> {
1244 macro_rules! features {
1245 ($($f:ident),*) => {
1246 [
1247 $(
1248 (stringify!($f), Self::$f)
1249 ),*
1250 ]
1251 };
1252 }
1253
1254 features!(
1255 INT8,
1256 INT16,
1257 INT64,
1258 FLOAT16,
1259 FLOAT64,
1260 BUFFER8,
1261 BUFFER16,
1262 PUSH_CONSTANT8,
1263 PUSH_CONSTANT16,
1264 SUBGROUP_BASIC,
1265 SUBGROUP_VOTE,
1266 SUBGROUP_ARITHMETIC,
1267 SUBGROUP_BALLOT,
1268 SUBGROUP_SHUFFLE,
1269 SUBGROUP_SHUFFLE_RELATIVE,
1270 SUBGROUP_CLUSTERED,
1271 SUBGROUP_QUAD
1272 )
1273 .into_iter()
1274 .filter_map(move |(name, features)| {
1275 if self.contains(features) {
1276 Some(name)
1277 } else {
1278 None
1279 }
1280 })
1281 }
1282}
1283
1284impl Debug for Features {
1285 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
1286 struct FeaturesStr<'a>(&'a str);
1287
1288 impl Debug for FeaturesStr<'_> {
1289 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1290 f.debug_struct(self.0).finish()
1291 }
1292 }
1293
1294 let alternate = f.alternate();
1295 let mut b = f.debug_tuple("Features");
1296 if alternate {
1297 for name in self.name_iter() {
1298 b.field(&FeaturesStr(name));
1299 }
1300 } else {
1301 b.field(&FeaturesStr(&itertools::join(self.name_iter(), "|")));
1302 }
1303 b.finish()
1304 }
1305}
1306
1307impl ToTokens for Features {
1308 fn to_tokens(&self, tokens: &mut TokenStream2) {
1309 let features = self
1310 .name_iter()
1311 .map(|name| Ident::new(name, Span2::call_site()));
1312 tokens.extend(quote! {
1313 Features::empty()
1314 #(.union(Features::#features))*
1315 });
1316 }
1317}
1318
/// Descriptor of one specialization constant: its name and scalar type.
#[derive(Serialize, Deserialize, Debug)]
struct SpecDesc {
    name: String,
    scalar_type: ScalarType,
}
1324
1325impl ToTokens for SpecDesc {
1326 fn to_tokens(&self, tokens: &mut TokenStream2) {
1327 let Self { name, scalar_type } = self;
1328 tokens.extend(quote! {
1329 SpecDesc {
1330 name: #name,
1331 scalar_type: #scalar_type,
1332 }
1333 });
1334 }
1335}
1336
/// Descriptor of one buffer-backed (`#[global]` or `#[item]`) kernel argument.
#[derive(Serialize, Deserialize, Debug)]
struct SliceDesc {
    name: String,
    scalar_type: ScalarType,
    /// Whether the kernel may write through this slice.
    mutable: bool,
    /// `true` for `#[item]` arguments, `false` for `#[global]`.
    item: bool,
}
1344
1345impl ToTokens for SliceDesc {
1346 fn to_tokens(&self, tokens: &mut TokenStream2) {
1347 let Self {
1348 name,
1349 scalar_type,
1350 mutable,
1351 item,
1352 } = self;
1353 tokens.extend(quote! {
1354 SliceDesc {
1355 name: #name,
1356 scalar_type: #scalar_type,
1357 mutable: #mutable,
1358 item: #item,
1359 }
1360 })
1361 }
1362}
1363
/// Descriptor of one scalar push-constant kernel argument.
#[derive(Serialize, Deserialize, Debug)]
struct PushDesc {
    name: String,
    scalar_type: ScalarType,
}
1369
1370impl ToTokens for PushDesc {
1371 fn to_tokens(&self, tokens: &mut TokenStream2) {
1372 let Self { name, scalar_type } = self;
1373 tokens.extend(quote! {
1374 PushDesc {
1375 name: #name,
1376 scalar_type: #scalar_type,
1377 }
1378 })
1379 }
1380}
1381
/// Expands a `#[kernel]` function into host-side and device-side code.
///
/// The returned tokens are, in order:
/// - a host module `pub mod <ident>` (`#[cfg(not(target_arch = "spirv"))]`) exposing
///   `KernelBuilder`, `builder()` and `Kernel` with its `dispatch` method;
/// - a device entry point `pub fn <ident>` (`#[cfg(target_arch = "spirv")]`) that wraps an
///   inner copy of the user's function, plus an optional push-constant struct;
/// - a `compile_error!` that fires when compiling for `spirv` without the `krnlc` cfg.
///
/// # Errors
/// Propagates errors from parsing `item_tokens`, from `meta()`, `desc()` and `encode()`,
/// and from the `syn::parse2` calls used to render the documentation strings.
fn kernel_impl(item_tokens: TokenStream2) -> Result<TokenStream2> {
    let item: KernelItem = syn::parse2(item_tokens.clone())?;
    let kernel_meta = item.meta()?;
    let kernel_desc = kernel_meta.desc()?;
    let item_attrs = &item.attrs;
    let unsafe_token = kernel_meta.unsafe_token;
    let ident = &kernel_meta.ident;
    // Device (SPIR-V) side: entry point + inner user function.
    let device_tokens = {
        // The encoded kernel description is turned into an identifier and used as the name
        // of the storage-buffer parameter of the entry point below.
        let kernel_data = format_ident!("{}", kernel_desc.encode()?);
        let block = &kernel_meta.block;
        let compute_def_args = kernel_meta.compute_def_args();
        let declare_specs = kernel_meta.declare_specs();
        // Spec-constant id for the thread count: `spec_descs.len()`, i.e. one past the user
        // spec constants (assuming those use ids `0..len` — confirm in `declare_specs`).
        let threads_spec_id =
            Literal::u32_unsuffixed(kernel_desc.spec_descs.len().try_into().unwrap());
        let items = kernel_meta.device_items();
        let device_arrays = kernel_meta.device_arrays();
        let device_slices = kernel_meta.device_slices();
        let device_fn_def_args = kernel_meta.device_fn_def_args();
        let device_fn_call_args = kernel_meta.device_fn_call_args();
        let push_consts_ident = format_ident!("__krnl_{ident}PushConsts");
        // A `#[repr(C)]` push-constant struct and its entry-point argument are only generated
        // when the kernel has push constants or slices; otherwise both are empty.
        let (push_struct_tokens, push_consts_arg) =
            if !kernel_desc.push_descs.is_empty() || !kernel_desc.slice_descs.is_empty() {
                let push_const_fields = kernel_desc.push_const_fields();
                let push_struct_tokens = quote! {
                    #[cfg(target_arch = "spirv")]
                    #[automatically_derived]
                    #[repr(C)]
                    pub struct #push_consts_ident {
                        #push_const_fields
                    }
                };
                let push_consts_arg = quote! {
                    #[spirv(push_constant)]
                    __krnl_push_consts: &#push_consts_ident,
                };
                (push_struct_tokens, push_consts_arg)
            } else {
                (TokenStream2::new(), TokenStream2::new())
            };
        // Call of the inner user function; `#unsafe_token` turns the block into an `unsafe`
        // block when the kernel is declared `unsafe`.
        let mut device_fn_call = quote! {
            #unsafe_token {
                #ident (
                    kernel,
                    #device_fn_call_args
                );
            }
        };
        // Itemwise kernels loop from `global_id` in steps of `global_threads` and call the
        // user function once per item with an `ItemKernel`. The item `kernel` shadows the
        // outer one only inside the inner block, so the increment reads the outer `Kernel`.
        if kernel_meta.itemwise {
            device_fn_call = quote! {
                let __krnl_items = #items;
                let mut __krnl_item_id = kernel.global_id();
                while __krnl_item_id < __krnl_items {
                    {
                        let kernel = unsafe {
                            ::krnl_core::kernel::__private::ItemKernelArgs {
                                item_id: __krnl_item_id as u32,
                                items: __krnl_items as u32,
                            }.into_item_kernel()
                        };
                        #device_fn_call
                    }
                    __krnl_item_id += kernel.global_threads();
                }
            };
        }
        // Type of the `kernel` argument seen by the user function.
        let kernel_type = if kernel_meta.itemwise {
            quote! { ItemKernel }
        } else {
            quote! {
                Kernel
            }
        };
        // The entry point is declared with `threads(1)`; the thread count it reports comes
        // from the `__krnl_threads` spec constant (default 1). NOTE(review): presumably the
        // real workgroup size is patched from that spec constant downstream — confirm.
        // Only the `.x` components of the UVec3 builtins are forwarded to `KernelArgs`.
        quote! {
            #push_struct_tokens
            #[cfg(target_arch = "spirv")]
            #[::krnl_core::spirv_std::spirv(compute(threads(1)))]
            #[allow(unused)]
            pub fn #ident(
                #push_consts_arg
                #[spirv(global_invocation_id)]
                __krnl_global_id: ::krnl_core::spirv_std::glam::UVec3,
                #[spirv(num_workgroups)]
                __krnl_groups: ::krnl_core::spirv_std::glam::UVec3,
                #[spirv(workgroup_id)]
                __krnl_group_id: ::krnl_core::spirv_std::glam::UVec3,
                #[spirv(num_subgroups)]
                __krnl_subgroups: u32,
                #[spirv(subgroup_id)]
                __krnl_subgroup_id: u32,
                #[spirv(subgroup_local_invocation_id)]
                __krnl_subgroup_thread_id: u32,
                #[spirv(spec_constant(id = #threads_spec_id, default = 1))] __krnl_threads: u32,
                #[spirv(local_invocation_id)]
                __krnl_thread_id: ::krnl_core::spirv_std::glam::UVec3,
                #[spirv(storage_buffer, descriptor_set = 1, binding = 0)]
                #kernel_data: &mut [u32],
                #compute_def_args
            ) {
                #(#item_attrs)*
                #unsafe_token fn #ident(
                    #[allow(unused)]
                    kernel: ::krnl_core::kernel::#kernel_type,
                    #device_fn_def_args
                ) #block
                {
                    let __krnl_kernel_data = #kernel_data;
                    unsafe {
                        ::krnl_core::kernel::__private::kernel_data(__krnl_kernel_data);
                    }
                    #declare_specs
                    let mut kernel = unsafe {
                        ::krnl_core::kernel::__private::KernelArgs {
                            global_id: __krnl_global_id.x,
                            groups: __krnl_groups.x,
                            group_id: __krnl_group_id.x,
                            subgroups: __krnl_subgroups,
                            subgroup_id: __krnl_subgroup_id,
                            subgroup_thread_id: __krnl_subgroup_thread_id,
                            threads: __krnl_threads,
                            thread_id: __krnl_thread_id.x,
                        }.into_kernel()
                    };
                    #device_arrays
                    #device_slices
                    #device_fn_call
                }
            }
        }
    };
    // Host side: a module named after the kernel exposing the builder / kernel API.
    let host_tokens = {
        let spec_descs = &kernel_desc.spec_descs;
        let slice_descs = &kernel_desc.slice_descs;
        let push_descs = &kernel_desc.push_descs;
        let dispatch_args = kernel_meta.dispatch_args();
        let dispatch_slice_args = kernel_meta.dispatch_slice_args();
        let dispatch_push_args = kernel_desc.dispatch_push_args();
        let safe = unsafe_token.is_none();
        let safety = if safe {
            quote! {
                Safety::Safe
            }
        } else {
            quote! {
                Safety::Unsafe
            }
        };
        let host_array_length_checks = kernel_meta.host_array_length_checks();
        // Kernels with spec constants get a `KernelBuilder<S = Specialized<false>>` type
        // parameter; `build` is only implemented for `Specialized<true>`, which `specialize`
        // returns. Without spec constants, `specialized` is empty and no parameter is added.
        let specialize = !kernel_desc.spec_descs.is_empty();
        let specialized = [format_ident!("S")];
        let specialized = if specialize {
            specialized.as_ref()
        } else {
            &[]
        };
        let kernel_builder_phantom_data = if specialize {
            quote! { S }
        } else {
            quote! { () }
        };
        let kernel_builder_build_generics = if specialize {
            quote! {
                <Specialized<true>>
            }
        } else {
            TokenStream2::new()
        };
        let kernel_builder_specialize_fn = if specialize {
            let spec_def_args = kernel_meta.spec_def_args();
            let spec_args = kernel_meta.spec_args();
            quote! {
                #[allow(clippy::too_many_arguments, non_snake_case)]
                pub fn specialize(mut self, #spec_def_args) -> KernelBuilder<Specialized<true>> {
                    KernelBuilder {
                        inner: self.inner.specialize(&[#(#spec_args.into()),*]),
                        _m: PhantomData,
                    }
                }
            }
        } else {
            TokenStream2::new()
        };
        // Non-itemwise kernels get a `Kernel<G = WithGroups<false>>` type parameter;
        // `dispatch` is only implemented for `WithGroups<true>`, which `with_global_threads`
        // and `with_groups` return. Itemwise kernels can dispatch directly.
        let needs_groups = !kernel_meta.itemwise;
        let with_groups = [format_ident!("G")];
        let with_groups = if needs_groups {
            with_groups.as_ref()
        } else {
            &[]
        };
        let kernel_phantom_data = if needs_groups {
            quote! { G }
        } else {
            quote! { () }
        };
        let kernel_dispatch_generics = if needs_groups {
            quote! { <WithGroups<true>> }
        } else {
            TokenStream2::new()
        };
        // Module docs: the original `#[kernel]` item, pretty-printed (skipped under doctest).
        let input_docs = {
            let input_tokens_string = prettyplease::unparse(&syn::parse2(quote! {
                #[kernel]
                #item_tokens
            })?);
            let input_doc_string = format!("```\n{input_tokens_string}\n```");
            quote! {
                #![cfg_attr(not(doctest), doc = #input_doc_string)]
            }
        };
        // On nightly only, the pretty-printed device expansion is attached as the docs of a
        // doc-only `expansion` submodule.
        let expansion = if rustversion::cfg!(nightly) {
            let expansion_tokens_string =
                prettyplease::unparse(&syn::parse2(device_tokens.clone())?);
            let expansion_doc_string = format!("```\n{expansion_tokens_string}\n```");
            quote! {
                #[cfg(all(doc, not(doctest)))]
                mod expansion {
                    #![doc = #expansion_doc_string]
                }
            }
        } else {
            TokenStream2::new()
        };
        // `builder()` evaluates `validate_kernel(..)` in a `const` and memoizes the resulting
        // `KernelBuilderBase` (or error string) in a `OnceLock`; a `None` description is
        // reported as "Kernel `<module path>` not compiled!".
        quote! {
            #[cfg(not(target_arch = "spirv"))]
            #(#item_attrs)*
            #[automatically_derived]
            pub mod #ident {
                #input_docs
                #expansion
                __krnl_module_arg!(use crate as __krnl);
                use __krnl::{
                    anyhow::{self, Result},
                    krnl_core::half::{f16, bf16},
                    buffer::{Slice, SliceMut},
                    device::{Device, Features},
                    scalar::ScalarType,
                    kernel::__private::{
                        Kernel as KernelBase,
                        KernelBuilder as KernelBuilderBase,
                        Specialized,
                        WithGroups,
                        KernelDesc,
                        SliceDesc,
                        SpecDesc,
                        PushDesc,
                        Safety,
                        validate_kernel
                    },
                    anyhow::format_err,
                };
                use ::std::{sync::OnceLock, marker::PhantomData};
                #[cfg(not(krnlc))]
                #[doc(hidden)]
                use __krnl::macros::__krnl_cache;
                #[cfg(doc)]
                use __krnl::{kernel, device::{DeviceInfo, error::DeviceLost}};

                #host_array_length_checks

                pub struct KernelBuilder #(<#specialized = Specialized<false>>)* {
                    #[doc(hidden)]
                    inner: KernelBuilderBase,
                    #[doc(hidden)]
                    _m: PhantomData<#kernel_builder_phantom_data>,
                }

                pub fn builder() -> Result<KernelBuilder> {
                    static BUILDER: OnceLock<Result<KernelBuilderBase, String>> = OnceLock::new();
                    let builder = BUILDER.get_or_init(|| {
                        const DESC: Option<KernelDesc> = validate_kernel(__krnl_kernel!(#ident), #safety, &[#(#spec_descs),*], &[#(#slice_descs),*], &[#(#push_descs),*]);
                        if let Some(desc) = DESC.as_ref() {
                            KernelBuilderBase::from_desc(desc.clone())
                        } else {
                            Err(format!("Kernel `{}` not compiled!", ::std::module_path!()))
                        }
                    });
                    match builder {
                        Ok(inner) => Ok(KernelBuilder {
                            inner: inner.clone(),
                            _m: PhantomData,
                        }),
                        Err(err) => Err(format_err!("{err}")),
                    }
                }

                impl #(<#specialized>)* KernelBuilder #(<#specialized>)* {
                    pub fn with_threads(self, threads: u32) -> Self {
                        Self {
                            inner: self.inner.with_threads(threads),
                            _m: PhantomData,
                        }
                    }
                    #kernel_builder_specialize_fn
                    #[doc(hidden)]
                    #[inline]
                    pub fn __features(&self) -> Features {
                        self.inner.features()
                    }
                }

                impl KernelBuilder #kernel_builder_build_generics {
                    pub fn build(&self, device: Device) -> Result<Kernel> {
                        Ok(Kernel {
                            inner: self.inner.build(device)?,
                            _m: PhantomData,
                        })
                    }
                }

                pub struct Kernel #(<#with_groups = WithGroups<false>>)* {
                    #[doc(hidden)]
                    inner: KernelBase,
                    #[doc(hidden)]
                    _m: PhantomData<#kernel_phantom_data>,
                }

                impl #(<#with_groups>)* Kernel #(<#with_groups>)* {
                    pub fn threads(&self) -> u32 {
                        self.inner.threads()
                    }
                    pub fn with_global_threads(self, global_threads: u32) -> Kernel #kernel_dispatch_generics {
                        Kernel {
                            inner: self.inner.with_global_threads(global_threads),
                            _m: PhantomData,
                        }
                    }
                    pub fn with_groups(self, groups: u32) -> Kernel #kernel_dispatch_generics {
                        Kernel {
                            inner: self.inner.with_groups(groups),
                            _m: PhantomData,
                        }
                    }
                }

                impl Kernel #kernel_dispatch_generics {
                    pub #unsafe_token fn dispatch(&self, #dispatch_args) -> Result<()> {
                        unsafe { self.inner.dispatch(&[#dispatch_slice_args], &[#(#dispatch_push_args.into()),*]) }
                    }
                }
            }
        }
    };
    // Both halves are always emitted; their own `cfg`s select host vs. device. Compiling for
    // spirv without the `krnlc` cfg is rejected outright.
    let tokens = quote! {
        #host_tokens
        #device_tokens
        #[cfg(all(target_arch = "spirv", not(krnlc)))]
        compile_error!("kernel cannot be used without krnlc!");
    };
    Ok(tokens)
}
1770
1771#[doc(hidden)]
1772#[proc_macro]
1773pub fn __krnl_cache(input: TokenStream) -> TokenStream {
1774 match __krnl_cache_impl(input.into()) {
1775 Ok(tokens) => tokens,
1776 Err(err) => err.into_compile_error(),
1777 }
1778 .into()
1779}
1780
1781#[derive(Parse)]
1782struct KrnlCacheInput {
1783 version: LitStr,
1784 __comma1: Comma,
1785 module: Ident,
1786 _comma2: Comma,
1787 kernel: Ident,
1788 _comma3: Comma,
1789 data: LitStr,
1790}
1791
1792fn __krnl_cache_impl(input: TokenStream2) -> Result<TokenStream2> {
1793 use flate2::{
1794 read::{GzDecoder, GzEncoder},
1795 Compression,
1796 };
1797 use std::io::Read;
1798 use syn::LitByteStr;
1799 use zero85::FromZ85;
1800
1801 static CACHE: OnceLock<std::result::Result<KrnlcCache, String>> = OnceLock::new();
1802
1803 let input = syn::parse2::<KrnlCacheInput>(input)?;
1804 let span = input.module.span();
1805 let cache = CACHE
1806 .get_or_init(|| {
1807 let version = env!("CARGO_PKG_VERSION");
1808 let krnlc_version = input.version.value();
1809 if !krnlc_version_compatible(&krnlc_version, version) {
1810 return Err(format!(
1811 "Cache created by krnlc {krnlc_version} is not compatible with krnl {version}!"
1812 ));
1813 }
1814 let data = input.data.value();
1815 let decoded_len = data.split_ascii_whitespace().map(|x| x.len() * 4 / 5).sum();
1816 let mut bytes = Vec::with_capacity(decoded_len);
1817 for data in data.split_ascii_whitespace() {
1818 let decoded = data.from_z85().map_err(|e| e.to_string())?;
1819 bytes.extend_from_slice(&decoded);
1820 }
1821 let cache =
1822 bincode2::deserialize_from::<_, KrnlcCache>(GzDecoder::new(bytes.as_slice()))
1823 .map_err(|e| e.to_string())?;
1824 assert_eq!(krnlc_version, cache.version);
1825 Ok(cache)
1826 })
1827 .as_ref()
1828 .map_err(|e| Error::new(input.version.span(), e))?;
1829 let kernels = cache
1830 .kernels
1831 .iter()
1832 .filter(|kernel| {
1833 let name = &kernel.name;
1834 let mut iter = name.rsplit("::");
1835 if input.kernel != iter.next().unwrap() {
1836 return false;
1837 }
1838 iter.any(|x| input.module == x)
1839 })
1840 .map(|kernel| {
1841 let KernelDesc {
1842 name,
1843 spirv,
1844 safe,
1845 features,
1846 spec_descs,
1847 slice_descs,
1848 push_descs,
1849 } = kernel;
1850 let mut bytes = Vec::new();
1851 GzEncoder::new(bytemuck::cast_slice(spirv), Compression::best())
1852 .read_to_end(&mut bytes)
1853 .unwrap();
1854 let spirv = LitByteStr::new(&bytes, span);
1855 quote! {
1856 KernelDesc::from_args(KernelDescArgs {
1857 name: #name,
1858 spirv: #spirv,
1859 features: #features,
1860 safe: #safe,
1861 spec_descs: &[#(#spec_descs),*],
1862 slice_descs: &[#(#slice_descs),*],
1863 push_descs: &[#(#push_descs),*],
1864 })
1865 }
1866 });
1867 let tokens = quote! {
1868 {
1869 __krnl_module_arg!(use crate as __krnl);
1870 use __krnl::{
1871 device::Features,
1872 kernel::__private::{find_kernel, KernelDesc, KernelDescArgs, Safety, SpecDesc, SliceDesc, PushDesc},
1873 };
1874
1875 find_kernel(std::module_path!(), &[#(#kernels),*])
1876 }
1877 };
1878 Ok(tokens)
1879}
1880
/// Contents of the cache produced by krnlc, as decoded by `__krnl_cache_impl`.
///
/// Deserialized with bincode, which reads fields positionally. Field order therefore
/// presumably must match the layout krnlc writes; do not reorder.
#[derive(Deserialize)]
struct KrnlcCache {
    // krnlc version recorded inside the cache; checked against the version literal passed to
    // `__krnl_cache!`.
    #[allow(unused)]
    version: String,
    // Descriptions (including SPIR-V) of every kernel compiled by krnlc.
    kernels: Vec<KernelDesc>,
}
1887
1888fn krnlc_version_compatible(krnlc_version: &str, version: &str) -> bool {
1889 let krnlc_version = Version::parse(krnlc_version).unwrap();
1890 let version = Version::parse(version).unwrap();
1891 if !krnlc_version.pre.is_empty() || !version.pre.is_empty() {
1892 krnlc_version == version
1893 } else if version.major == 0 && version.minor == 0 {
1894 krnlc_version.major == 0 && krnlc_version.minor == 0 && krnlc_version.patch == version.patch
1895 } else if version.major == 0 {
1896 krnlc_version.major == 0 && krnlc_version.minor == version.minor
1897 } else {
1898 krnlc_version.major == version.major && krnlc_version.minor == version.minor
1899 }
1900}
1901
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of `krnlc_version_compatible`.
    /// Each row is `(krnlc_version, krnl_version, expected)`.
    #[test]
    fn krnlc_version_semver() {
        const CASES: &[(&str, &str, bool)] = &[
            ("0.0.1", "0.0.1", true),
            ("0.0.1", "0.0.2", false),
            ("0.0.2", "0.0.1", false),
            ("0.0.2-alpha", "0.0.2", false),
            ("0.0.2", "0.0.2-alpha", false),
            ("0.0.2", "0.1.0", false),
            ("0.1.1-alpha", "0.1.0", false),
            ("0.1.1", "0.1.0-alpha", false),
            ("0.1.1", "0.1.0", true),
            ("0.1.0", "0.1.1", true),
            ("0.1.1-alpha", "0.1.1-alpha", true),
            ("0.1.0-alpha", "0.1.1-alpha", false),
            ("0.1.1", "0.2.0", false),
        ];
        for &(krnlc_version, version, expected) in CASES {
            assert_eq!(
                krnlc_version_compatible(krnlc_version, version),
                expected,
                "krnlc {krnlc_version} vs krnl {version}"
            );
        }
    }
}