1use darling::{ast, FromDeriveInput, FromField, FromMeta};
77use proc_macro::TokenStream;
78use quote::{format_ident, quote};
79use syn::{parse_macro_input, DeriveInput, ItemFn};
80
81#[derive(Debug, FromDeriveInput)]
83#[darling(attributes(message, ring_message), supports(struct_named))]
84struct RingMessageArgs {
85 ident: syn::Ident,
86 generics: syn::Generics,
87 data: ast::Data<(), RingMessageField>,
88 #[darling(default)]
92 type_id: Option<u64>,
93 #[darling(default)]
96 domain: Option<String>,
97 #[darling(default)]
100 k2k_routable: bool,
101 #[darling(default)]
104 category: Option<String>,
105}
106
107#[derive(Debug, FromField)]
109#[darling(attributes(message))]
110struct RingMessageField {
111 ident: Option<syn::Ident>,
112 #[allow(dead_code)]
113 ty: syn::Type,
114 #[darling(default)]
116 id: bool,
117 #[darling(default)]
119 correlation: bool,
120 #[darling(default)]
122 priority: bool,
123}
124
125#[proc_macro_derive(RingMessage, attributes(message, ring_message))]
181pub fn derive_ring_message(input: TokenStream) -> TokenStream {
182 let input = parse_macro_input!(input as DeriveInput);
183
184 let args = match RingMessageArgs::from_derive_input(&input) {
185 Ok(args) => args,
186 Err(e) => return e.write_errors().into(),
187 };
188
189 let name = &args.ident;
190 let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
191
192 let base_type_id = args.type_id.unwrap_or_else(|| {
194 use std::collections::hash_map::DefaultHasher;
195 use std::hash::{Hash, Hasher};
196 let mut hasher = DefaultHasher::new();
197 name.to_string().hash(&mut hasher);
198 if args.domain.is_some() {
200 hasher.finish() % 100
201 } else {
202 hasher.finish()
203 }
204 });
205
206 let fields = match &args.data {
208 ast::Data::Struct(fields) => fields,
209 _ => panic!("RingMessage can only be derived for structs"),
210 };
211
212 let mut id_field: Option<&syn::Ident> = None;
213 let mut correlation_field: Option<&syn::Ident> = None;
214 let mut priority_field: Option<&syn::Ident> = None;
215
216 for field in fields.iter() {
217 if field.id {
218 id_field = field.ident.as_ref();
219 }
220 if field.correlation {
221 correlation_field = field.ident.as_ref();
222 }
223 if field.priority {
224 priority_field = field.ident.as_ref();
225 }
226 }
227
228 let message_id_impl = if let Some(field) = id_field {
230 quote! { self.#field }
231 } else {
232 quote! { ::ringkernel_core::message::MessageId::new(0) }
233 };
234
235 let correlation_id_impl = if let Some(field) = correlation_field {
237 quote! { self.#field }
238 } else {
239 quote! { ::ringkernel_core::message::CorrelationId::none() }
240 };
241
242 let priority_impl = if let Some(field) = priority_field {
244 quote! { self.#field }
245 } else {
246 quote! { ::ringkernel_core::message::Priority::Normal }
247 };
248
249 let message_type_impl = if let Some(ref domain_str) = args.domain {
251 quote! {
253 ::ringkernel_core::domain::Domain::from_str(#domain_str)
254 .unwrap_or(::ringkernel_core::domain::Domain::General)
255 .base_type_id() + #base_type_id
256 }
257 } else {
258 quote! { #base_type_id }
260 };
261
262 let domain_impl = if let Some(ref domain_str) = args.domain {
264 quote! {
265 impl #impl_generics ::ringkernel_core::domain::DomainMessage for #name #ty_generics #where_clause {
266 fn domain() -> ::ringkernel_core::domain::Domain {
267 ::ringkernel_core::domain::Domain::from_str(#domain_str)
268 .unwrap_or(::ringkernel_core::domain::Domain::General)
269 }
270 }
271 }
272 } else {
273 quote! {}
274 };
275
276 let k2k_registration = if args.k2k_routable {
278 let registration_name = format_ident!(
279 "__K2K_MESSAGE_REGISTRATION_{}",
280 name.to_string().to_uppercase()
281 );
282 let type_name_str = name.to_string();
283 let category_tokens = match &args.category {
284 Some(cat) => quote! { ::std::option::Option::Some(#cat) },
285 None => quote! { ::std::option::Option::None },
286 };
287
288 quote! {
289 #[allow(non_upper_case_globals)]
290 #[::inventory::submit]
291 static #registration_name: ::ringkernel_core::k2k::K2KMessageRegistration =
292 ::ringkernel_core::k2k::K2KMessageRegistration {
293 type_id: {
294 #base_type_id
297 },
298 type_name: #type_name_str,
299 k2k_routable: true,
300 category: #category_tokens,
301 };
302 }
303 } else {
304 quote! {}
305 };
306
307 let expanded = quote! {
308 impl #impl_generics ::ringkernel_core::message::RingMessage for #name #ty_generics #where_clause {
309 fn message_type() -> u64 {
310 #message_type_impl
311 }
312
313 fn message_id(&self) -> ::ringkernel_core::message::MessageId {
314 #message_id_impl
315 }
316
317 fn correlation_id(&self) -> ::ringkernel_core::message::CorrelationId {
318 #correlation_id_impl
319 }
320
321 fn priority(&self) -> ::ringkernel_core::message::Priority {
322 #priority_impl
323 }
324
325 fn serialize(&self) -> Vec<u8> {
326 ::rkyv::to_bytes::<_, 4096>(self)
329 .map(|v| v.to_vec())
330 .unwrap_or_default()
331 }
332
333 fn deserialize(bytes: &[u8]) -> ::ringkernel_core::error::Result<Self>
334 where
335 Self: Sized,
336 {
337 use ::rkyv::Deserialize as _;
338 let archived = unsafe { ::rkyv::archived_root::<Self>(bytes) };
339 let deserialized: Self = archived.deserialize(&mut ::rkyv::Infallible)
340 .map_err(|_| ::ringkernel_core::error::RingKernelError::DeserializationError(
341 "rkyv deserialization failed".to_string()
342 ))?;
343 Ok(deserialized)
344 }
345
346 fn size_hint(&self) -> usize {
347 ::std::mem::size_of::<Self>()
348 }
349 }
350
351 #domain_impl
352
353 #k2k_registration
354 };
355
356 TokenStream::from(expanded)
357}
358
359#[derive(Debug, FromMeta)]
361struct RingKernelArgs {
362 id: String,
364 #[darling(default)]
366 mode: Option<String>,
367 #[darling(default)]
369 grid_size: Option<u32>,
370 #[darling(default)]
372 block_size: Option<u32>,
373 #[darling(default)]
375 publishes_to: Option<String>,
376}
377
378#[proc_macro_attribute]
398pub fn ring_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
399 let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
400 Ok(v) => v,
401 Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
402 };
403
404 let args = match RingKernelArgs::from_list(&args) {
405 Ok(v) => v,
406 Err(e) => return TokenStream::from(e.write_errors()),
407 };
408
409 let input = parse_macro_input!(item as ItemFn);
410
411 let kernel_id = &args.id;
412 let fn_name = &input.sig.ident;
413 let fn_vis = &input.vis;
414 let fn_block = &input.block;
415 let fn_attrs = &input.attrs;
416
417 let inputs = &input.sig.inputs;
419 let output = &input.sig.output;
420
421 let (_ctx_arg, msg_arg) = if inputs.len() >= 2 {
423 let ctx = inputs.first();
424 let msg = inputs.iter().nth(1);
425 (ctx, msg)
426 } else {
427 (None, None)
428 };
429
430 let msg_type = msg_arg
432 .map(|arg| {
433 if let syn::FnArg::Typed(pat_type) = arg {
434 pat_type.ty.clone()
435 } else {
436 syn::parse_quote!(())
437 }
438 })
439 .unwrap_or_else(|| syn::parse_quote!(()));
440
441 let mode = args.mode.as_deref().unwrap_or("persistent");
443 let mode_expr = if mode == "event_driven" {
444 quote! { ::ringkernel_core::types::KernelMode::EventDriven }
445 } else {
446 quote! { ::ringkernel_core::types::KernelMode::Persistent }
447 };
448
449 let grid_size = args.grid_size.unwrap_or(1);
451 let block_size = args.block_size.unwrap_or(256);
452
453 let publishes_to_targets: Vec<String> = args
455 .publishes_to
456 .as_ref()
457 .map(|s| s.split(',').map(|t| t.trim().to_string()).collect())
458 .unwrap_or_default();
459
460 let registration_name = format_ident!(
462 "__RINGKERNEL_REGISTRATION_{}",
463 fn_name.to_string().to_uppercase()
464 );
465 let handler_name = format_ident!("{}_handler", fn_name);
466
467 let expanded = quote! {
469 #(#fn_attrs)*
471 #fn_vis async fn #fn_name #inputs #output #fn_block
472
473 #fn_vis fn #handler_name(
475 ctx: &mut ::ringkernel_core::RingContext<'_>,
476 envelope: ::ringkernel_core::message::MessageEnvelope,
477 ) -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = ::ringkernel_core::error::Result<::ringkernel_core::message::MessageEnvelope>> + Send + '_>> {
478 Box::pin(async move {
479 let msg: #msg_type = ::ringkernel_core::message::RingMessage::deserialize(&envelope.payload)?;
481
482 let response = #fn_name(ctx, msg).await;
484
485 let response_payload = ::ringkernel_core::message::RingMessage::serialize(&response);
487 let response_header = ::ringkernel_core::message::MessageHeader::new(
488 <_ as ::ringkernel_core::message::RingMessage>::message_type(),
489 envelope.header.dest_kernel,
490 envelope.header.source_kernel,
491 response_payload.len(),
492 ctx.now(),
493 ).with_correlation(envelope.header.correlation_id);
494
495 Ok(::ringkernel_core::message::MessageEnvelope {
496 header: response_header,
497 payload: response_payload,
498 })
499 })
500 }
501
502 #[allow(non_upper_case_globals)]
504 #[::inventory::submit]
505 static #registration_name: ::ringkernel_core::__private::KernelRegistration = ::ringkernel_core::__private::KernelRegistration {
506 id: #kernel_id,
507 mode: #mode_expr,
508 grid_size: #grid_size,
509 block_size: #block_size,
510 publishes_to: &[#(#publishes_to_targets),*],
511 };
512 };
513
514 TokenStream::from(expanded)
515}
516
517#[proc_macro_derive(GpuType)]
521pub fn derive_gpu_type(input: TokenStream) -> TokenStream {
522 let input = parse_macro_input!(input as DeriveInput);
523 let name = &input.ident;
524 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
525
526 let expanded = quote! {
528 const _: fn() = || {
530 fn assert_copy<T: Copy>() {}
531 assert_copy::<#name #ty_generics>();
532 };
533
534 unsafe impl #impl_generics ::bytemuck::Pod for #name #ty_generics #where_clause {}
536 unsafe impl #impl_generics ::bytemuck::Zeroable for #name #ty_generics #where_clause {}
537 };
538
539 TokenStream::from(expanded)
540}
541
542#[derive(Debug, FromMeta)]
548struct StencilKernelArgs {
549 id: String,
551 #[darling(default)]
553 grid: Option<String>,
554 #[darling(default)]
556 tile_size: Option<u32>,
557 #[darling(default)]
559 tile_width: Option<u32>,
560 #[darling(default)]
562 tile_height: Option<u32>,
563 #[darling(default)]
565 halo: Option<u32>,
566}
567
568#[proc_macro_attribute]
609pub fn stencil_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
610 let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
611 Ok(v) => v,
612 Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
613 };
614
615 let args = match StencilKernelArgs::from_list(&args) {
616 Ok(v) => v,
617 Err(e) => return TokenStream::from(e.write_errors()),
618 };
619
620 let input = parse_macro_input!(item as ItemFn);
621
622 stencil_kernel_impl(args, input)
624}
625
626fn stencil_kernel_impl(args: StencilKernelArgs, input: ItemFn) -> TokenStream {
627 let kernel_id = &args.id;
628 let fn_name = &input.sig.ident;
629 let fn_vis = &input.vis;
630 let fn_block = &input.block;
631 let fn_inputs = &input.sig.inputs;
632 let fn_output = &input.sig.output;
633 let fn_attrs = &input.attrs;
634
635 let grid = args.grid.as_deref().unwrap_or("2d");
637 let tile_width = args
638 .tile_width
639 .unwrap_or_else(|| args.tile_size.unwrap_or(16));
640 let tile_height = args
641 .tile_height
642 .unwrap_or_else(|| args.tile_size.unwrap_or(16));
643 let halo = args.halo.unwrap_or(1);
644
645 let cuda_const_name = format_ident!("{}_CUDA_SOURCE", fn_name.to_string().to_uppercase());
647
648 let registration_name = format_ident!(
650 "__STENCIL_KERNEL_REGISTRATION_{}",
651 fn_name.to_string().to_uppercase()
652 );
653
654 #[cfg(feature = "cuda-codegen")]
656 let cuda_source_code = {
657 use ringkernel_cuda_codegen::{transpile_stencil_kernel, Grid, StencilConfig};
658
659 let grid_type = match grid {
660 "1d" => Grid::Grid1D,
661 "2d" => Grid::Grid2D,
662 "3d" => Grid::Grid3D,
663 _ => Grid::Grid2D,
664 };
665
666 let config = StencilConfig::new(kernel_id.clone())
667 .with_grid(grid_type)
668 .with_tile_size(tile_width as usize, tile_height as usize)
669 .with_halo(halo as usize);
670
671 match transpile_stencil_kernel(&input, &config) {
672 Ok(cuda) => cuda,
673 Err(e) => {
674 return TokenStream::from(
675 syn::Error::new_spanned(
676 &input.sig.ident,
677 format!("CUDA transpilation failed: {}", e),
678 )
679 .to_compile_error(),
680 );
681 }
682 }
683 };
684
685 #[cfg(not(feature = "cuda-codegen"))]
686 let cuda_source_code = format!(
687 "// CUDA codegen not enabled. Enable 'cuda-codegen' feature.\n// Kernel: {}\n",
688 kernel_id
689 );
690
691 let expanded = quote! {
693 #(#fn_attrs)*
695 #fn_vis fn #fn_name #fn_inputs #fn_output #fn_block
696
697 #fn_vis const #cuda_const_name: &str = #cuda_source_code;
699
700 #[allow(non_upper_case_globals)]
702 #[::inventory::submit]
703 static #registration_name: ::ringkernel_core::__private::StencilKernelRegistration =
704 ::ringkernel_core::__private::StencilKernelRegistration {
705 id: #kernel_id,
706 grid: #grid,
707 tile_width: #tile_width,
708 tile_height: #tile_height,
709 halo: #halo,
710 cuda_source: #cuda_source_code,
711 };
712 };
713
714 TokenStream::from(expanded)
715}
716
/// Backends a `#[gpu_kernel]` can be generated for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GpuBackend {
    /// Spelled `cuda` in attribute arguments.
    Cuda,
    /// Spelled `metal` in attribute arguments.
    Metal,
    /// Spelled `wgpu` or `webgpu` in attribute arguments.
    Wgpu,
    /// Spelled `cpu` in attribute arguments.
    Cpu,
}

impl GpuBackend {
    /// Parses a backend name, ignoring case. Returns `None` for unknown names.
    fn from_str(s: &str) -> Option<Self> {
        // Accepted spellings, all lower-case.
        const SPELLINGS: &[(&str, GpuBackend)] = &[
            ("cuda", GpuBackend::Cuda),
            ("metal", GpuBackend::Metal),
            ("wgpu", GpuBackend::Wgpu),
            ("webgpu", GpuBackend::Wgpu),
            ("cpu", GpuBackend::Cpu),
        ];

        let wanted = s.to_lowercase();
        SPELLINGS
            .iter()
            .find(|(spelling, _)| *spelling == wanted)
            .map(|&(_, backend)| backend)
    }

    /// Canonical lower-case name of the backend.
    fn as_str(&self) -> &'static str {
        match *self {
            GpuBackend::Cuda => "cuda",
            GpuBackend::Metal => "metal",
            GpuBackend::Wgpu => "wgpu",
            GpuBackend::Cpu => "cpu",
        }
    }
}
754
755#[derive(Debug, Clone, Copy, PartialEq, Eq)]
757enum GpuCapability {
758 Float64,
760 Int64,
762 Atomic64,
764 CooperativeGroups,
766 Subgroups,
768 SharedMemory,
770 DynamicParallelism,
772 Float16,
774}
775
776impl GpuCapability {
777 fn from_str(s: &str) -> Option<Self> {
778 match s.to_lowercase().as_str() {
779 "f64" | "float64" => Some(Self::Float64),
780 "i64" | "int64" => Some(Self::Int64),
781 "atomic64" => Some(Self::Atomic64),
782 "cooperative_groups" | "cooperativegroups" | "grid_sync" => {
783 Some(Self::CooperativeGroups)
784 }
785 "subgroups" | "warp" | "simd" => Some(Self::Subgroups),
786 "shared_memory" | "sharedmemory" | "threadgroup" => Some(Self::SharedMemory),
787 "dynamic_parallelism" | "dynamicparallelism" => Some(Self::DynamicParallelism),
788 "f16" | "float16" | "half" => Some(Self::Float16),
789 _ => None,
790 }
791 }
792
793 fn as_str(&self) -> &'static str {
794 match self {
795 Self::Float64 => "f64",
796 Self::Int64 => "i64",
797 Self::Atomic64 => "atomic64",
798 Self::CooperativeGroups => "cooperative_groups",
799 Self::Subgroups => "subgroups",
800 Self::SharedMemory => "shared_memory",
801 Self::DynamicParallelism => "dynamic_parallelism",
802 Self::Float16 => "f16",
803 }
804 }
805
806 fn supported_by(&self, backend: GpuBackend) -> bool {
808 match (self, backend) {
809 (_, GpuBackend::Cuda) => true,
811
812 (Self::Float64, GpuBackend::Metal) => false,
814 (Self::CooperativeGroups, GpuBackend::Metal) => false,
815 (Self::DynamicParallelism, GpuBackend::Metal) => false,
816 (_, GpuBackend::Metal) => true,
817
818 (Self::Float64, GpuBackend::Wgpu) => false,
820 (Self::Int64, GpuBackend::Wgpu) => false,
821 (Self::Atomic64, GpuBackend::Wgpu) => false, (Self::CooperativeGroups, GpuBackend::Wgpu) => false,
823 (Self::DynamicParallelism, GpuBackend::Wgpu) => false,
824 (Self::Subgroups, GpuBackend::Wgpu) => true, (_, GpuBackend::Wgpu) => true,
826
827 (_, GpuBackend::Cpu) => true,
829 }
830 }
831}
832
833#[derive(Debug)]
835struct GpuKernelArgs {
836 id: Option<String>,
838 backends: Vec<GpuBackend>,
840 fallback: Vec<GpuBackend>,
842 requires: Vec<GpuCapability>,
844 block_size: Option<u32>,
846}
847
848impl Default for GpuKernelArgs {
849 fn default() -> Self {
850 Self {
851 id: None,
852 backends: vec![GpuBackend::Cuda, GpuBackend::Metal, GpuBackend::Wgpu],
853 fallback: vec![
854 GpuBackend::Cuda,
855 GpuBackend::Metal,
856 GpuBackend::Wgpu,
857 GpuBackend::Cpu,
858 ],
859 requires: Vec::new(),
860 block_size: None,
861 }
862 }
863}
864
865impl GpuKernelArgs {
866 fn parse(attr: proc_macro2::TokenStream) -> Result<Self, darling::Error> {
867 let mut args = Self::default();
868 let attr_str = attr.to_string();
869
870 if let Some(start) = attr_str.find("backends") {
872 if let Some(bracket_start) = attr_str[start..].find('[') {
873 if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
874 let backends_str =
875 &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
876 args.backends = backends_str
877 .split(',')
878 .filter_map(|s| GpuBackend::from_str(s.trim()))
879 .collect();
880 }
881 }
882 }
883
884 if let Some(start) = attr_str.find("fallback") {
886 if let Some(bracket_start) = attr_str[start..].find('[') {
887 if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
888 let fallback_str =
889 &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
890 args.fallback = fallback_str
891 .split(',')
892 .filter_map(|s| GpuBackend::from_str(s.trim()))
893 .collect();
894 }
895 }
896 }
897
898 if let Some(start) = attr_str.find("requires") {
900 if let Some(bracket_start) = attr_str[start..].find('[') {
901 if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
902 let requires_str =
903 &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
904 args.requires = requires_str
905 .split(',')
906 .filter_map(|s| GpuCapability::from_str(s.trim()))
907 .collect();
908 }
909 }
910 }
911
912 if let Some(start) = attr_str.find("id") {
914 if let Some(quote_start) = attr_str[start..].find('"') {
915 if let Some(quote_end) = attr_str[start + quote_start + 1..].find('"') {
916 args.id = Some(
917 attr_str[start + quote_start + 1..start + quote_start + 1 + quote_end]
918 .to_string(),
919 );
920 }
921 }
922 }
923
924 if let Some(start) = attr_str.find("block_size") {
926 if let Some(eq) = attr_str[start..].find('=') {
927 let rest = &attr_str[start + eq + 1..];
928 let num_end = rest
929 .find(|c: char| !c.is_numeric() && c != ' ')
930 .unwrap_or(rest.len());
931 if let Ok(n) = rest[..num_end].trim().parse() {
932 args.block_size = Some(n);
933 }
934 }
935 }
936
937 Ok(args)
938 }
939
940 fn validate_capabilities(&self) -> Result<(), String> {
942 for cap in &self.requires {
943 let mut supported_by_any = false;
944 for backend in &self.backends {
945 if cap.supported_by(*backend) {
946 supported_by_any = true;
947 break;
948 }
949 }
950 if !supported_by_any {
951 return Err(format!(
952 "Capability '{}' is not supported by any of the specified backends: {:?}",
953 cap.as_str(),
954 self.backends.iter().map(|b| b.as_str()).collect::<Vec<_>>()
955 ));
956 }
957 }
958 Ok(())
959 }
960
961 fn compatible_backends(&self) -> Vec<GpuBackend> {
963 self.backends
964 .iter()
965 .filter(|backend| self.requires.iter().all(|cap| cap.supported_by(**backend)))
966 .copied()
967 .collect()
968 }
969}
970
971#[proc_macro_attribute]
1028pub fn gpu_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
1029 let attr2: proc_macro2::TokenStream = attr.into();
1030 let args = match GpuKernelArgs::parse(attr2) {
1031 Ok(args) => args,
1032 Err(e) => return TokenStream::from(e.write_errors()),
1033 };
1034
1035 let input = parse_macro_input!(item as ItemFn);
1036
1037 if let Err(msg) = args.validate_capabilities() {
1039 return TokenStream::from(
1040 syn::Error::new_spanned(&input.sig.ident, msg).to_compile_error(),
1041 );
1042 }
1043
1044 gpu_kernel_impl(args, input)
1045}
1046
1047fn gpu_kernel_impl(args: GpuKernelArgs, input: ItemFn) -> TokenStream {
1048 let fn_name = &input.sig.ident;
1049 let fn_vis = &input.vis;
1050 let fn_block = &input.block;
1051 let fn_inputs = &input.sig.inputs;
1052 let fn_output = &input.sig.output;
1053 let fn_attrs = &input.attrs;
1054
1055 let kernel_id = args.id.clone().unwrap_or_else(|| fn_name.to_string());
1056 let block_size = args.block_size.unwrap_or(256);
1057
1058 let compatible_backends = args.compatible_backends();
1060
1061 let mut source_constants = Vec::new();
1063
1064 for backend in &compatible_backends {
1065 let const_name = format_ident!(
1066 "{}_{}",
1067 fn_name.to_string().to_uppercase(),
1068 backend.as_str().to_uppercase()
1069 );
1070
1071 let backend_str = backend.as_str();
1072
1073 let source_placeholder = format!(
1076 "// {} source for kernel '{}'\n// Generated by ringkernel-derive\n// Capabilities: {:?}\n",
1077 backend_str.to_uppercase(),
1078 kernel_id,
1079 args.requires.iter().map(|c| c.as_str()).collect::<Vec<_>>()
1080 );
1081
1082 source_constants.push(quote! {
1083 #fn_vis const #const_name: &str = #source_placeholder;
1085 });
1086 }
1087
1088 let capability_strs: Vec<_> = args.requires.iter().map(|c| c.as_str()).collect();
1090 let backend_strs: Vec<_> = compatible_backends.iter().map(|b| b.as_str()).collect();
1091 let fallback_strs: Vec<_> = args.fallback.iter().map(|b| b.as_str()).collect();
1092
1093 let registration_name = format_ident!(
1095 "__GPU_KERNEL_REGISTRATION_{}",
1096 fn_name.to_string().to_uppercase()
1097 );
1098
1099 let info_name = format_ident!("{}_INFO", fn_name.to_string().to_uppercase());
1101
1102 let expanded = quote! {
1104 #(#fn_attrs)*
1106 #fn_vis fn #fn_name #fn_inputs #fn_output #fn_block
1107
1108 #(#source_constants)*
1110
1111 #fn_vis mod #info_name {
1113 pub const ID: &str = #kernel_id;
1115
1116 pub const BLOCK_SIZE: u32 = #block_size;
1118
1119 pub const CAPABILITIES: &[&str] = &[#(#capability_strs),*];
1121
1122 pub const BACKENDS: &[&str] = &[#(#backend_strs),*];
1124
1125 pub const FALLBACK_ORDER: &[&str] = &[#(#fallback_strs),*];
1127 }
1128
1129 #[allow(non_upper_case_globals)]
1131 #[::inventory::submit]
1132 static #registration_name: ::ringkernel_core::__private::GpuKernelRegistration =
1133 ::ringkernel_core::__private::GpuKernelRegistration {
1134 id: #kernel_id,
1135 block_size: #block_size,
1136 capabilities: &[#(#capability_strs),*],
1137 backends: &[#(#backend_strs),*],
1138 fallback_order: &[#(#fallback_strs),*],
1139 };
1140 };
1141
1142 TokenStream::from(expanded)
1143}
1144
1145#[derive(Debug, FromDeriveInput)]
1151#[darling(attributes(state), supports(struct_named))]
1152struct ControlBlockStateArgs {
1153 ident: syn::Ident,
1154 generics: syn::Generics,
1155 #[darling(default)]
1157 version: Option<u32>,
1158}
1159
1160#[proc_macro_derive(ControlBlockState, attributes(state))]
1209pub fn derive_control_block_state(input: TokenStream) -> TokenStream {
1210 let input = parse_macro_input!(input as DeriveInput);
1211
1212 let args = match ControlBlockStateArgs::from_derive_input(&input) {
1213 Ok(args) => args,
1214 Err(e) => return e.write_errors().into(),
1215 };
1216
1217 let name = &args.ident;
1218 let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
1219 let version = args.version.unwrap_or(1);
1220
1221 let expanded = quote! {
1222 const _: () = {
1224 assert!(
1225 ::std::mem::size_of::<#name #ty_generics>() <= 24,
1226 "ControlBlockState types must fit in 24 bytes (ControlBlock._reserved size)"
1227 );
1228 };
1229
1230 const _: fn() = || {
1232 fn assert_copy<T: Copy>() {}
1233 assert_copy::<#name #ty_generics>();
1234 };
1235
1236 unsafe impl #impl_generics ::bytemuck::Zeroable for #name #ty_generics #where_clause {}
1239 unsafe impl #impl_generics ::bytemuck::Pod for #name #ty_generics #where_clause {}
1240
1241 impl #impl_generics ::ringkernel_core::state::EmbeddedState for #name #ty_generics #where_clause {
1243 const VERSION: u32 = #version;
1244
1245 fn is_embedded() -> bool {
1246 true
1247 }
1248 }
1249 };
1250
1251 TokenStream::from(expanded)
1252}