ringkernel_cuda_codegen/
lib.rs

//! CUDA code generation from a Rust DSL for RingKernel GPU kernels.
2//!
3//! This crate provides transpilation from a restricted Rust DSL to CUDA C code,
4//! enabling developers to write GPU kernels in Rust without directly writing CUDA.
5//!
6//! # Overview
7//!
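//! The transpiler accepts a restricted subset of Rust. It was designed around
//! stencil/grid operations, but the same subset also drives generic `__global__`
//! kernels and persistent ring kernels. Supported constructs include:
//!
//! - Primitive types: `f32`, `f64`, `i32`, `u32`, `i64`, `u64`, `bool`
//! - Array slices: `&[T]`, `&mut [T]`
//! - Arithmetic and comparison operators
//! - Let bindings and if/else expressions
//! - Stencil intrinsics via `GridPos` context
//! - CUDA index intrinsics such as `thread_idx_x()` and `block_idx_x()` (generic kernels)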
15//!
//! # Example
//!
//! ```ignore
//! use ringkernel_cuda_codegen::{transpile_stencil_kernel, Grid, StencilConfig};
//! use syn::parse_quote;
//!
//! // The transpiler operates on a parsed `syn::ItemFn`, not on raw source text.
//! let func: syn::ItemFn = parse_quote! {
//!     fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
//!         let curr = p[pos.idx()];
//!         let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
//!         p_prev[pos.idx()] = 2.0 * curr - p_prev[pos.idx()] + c2 * lap;
//!     }
//! };
//!
//! let config = StencilConfig {
//!     id: "fdtd".to_string(),
//!     grid: Grid::Grid2D,
//!     tile_size: (16, 16),
//!     halo: 1,
//! };
//!
//! let cuda_code = transpile_stencil_kernel(&func, &config)?;
//! ```
38
39pub mod dsl;
40pub mod handler;
41mod intrinsics;
42pub mod loops;
43pub mod persistent_fdtd;
44pub mod reduction_intrinsics;
45pub mod ring_kernel;
46pub mod shared;
47mod stencil;
48mod transpiler;
49mod types;
50mod validation;
51
52pub use handler::{
53    generate_cuda_struct, generate_message_deser, generate_response_ser, ContextMethod,
54    HandlerCodegenConfig, HandlerParam, HandlerParamKind, HandlerReturnType, HandlerSignature,
55    MessageTypeInfo, MessageTypeRegistry,
56};
57pub use intrinsics::{GpuIntrinsic, IntrinsicRegistry, RingKernelIntrinsic, StencilIntrinsic};
58pub use loops::{LoopPattern, RangeInfo};
59pub use persistent_fdtd::{generate_persistent_fdtd_kernel, PersistentFdtdConfig};
60pub use reduction_intrinsics::{
61    generate_inline_block_reduce, generate_inline_grid_reduce,
62    generate_inline_reduce_and_broadcast, generate_reduction_helpers, transpile_reduction_call,
63    ReductionCodegenConfig, ReductionIntrinsic, ReductionOp as CodegenReductionOp,
64};
65pub use ring_kernel::{
66    generate_control_block_struct, generate_hlc_struct, generate_k2k_structs,
67    KernelReductionConfig, RingKernelConfig,
68};
69pub use shared::{SharedArray, SharedMemoryConfig, SharedMemoryDecl, SharedTile};
70pub use stencil::{Grid, GridPos, StencilConfig, StencilLaunchConfig};
71pub use transpiler::{transpile_function, CudaTranspiler, SharedVarInfo};
72pub use types::{
73    get_slice_element_type, is_control_block_type, is_mutable_reference, is_ring_context_type,
74    ring_kernel_type_mapper, CudaType, RingKernelParamKind, TypeMapper,
75};
76pub use validation::{
77    is_simple_assignment, validate_function, validate_function_with_mode,
78    validate_stencil_signature, ValidationError, ValidationMode,
79};
80
81use thiserror::Error;
82
/// Errors that can occur during transpilation.
///
/// All public `transpile_*` entry points in this crate return this type via the
/// crate-level [`Result`] alias. The `#[error(...)]` attributes (from
/// `thiserror`) define the `Display` text of each variant.
#[derive(Error, Debug)]
pub enum TranspileError {
    /// Failed to parse Rust code.
    ///
    /// The `String` payload is interpolated into the displayed message.
    #[error("Parse error: {0}")]
    Parse(String),

    /// DSL constraint violation.
    ///
    /// `#[from]` generates `From<ValidationError> for TranspileError`, so a
    /// `ValidationError` can be propagated with `?` from functions returning
    /// this crate's `Result` (presumably how the validators' errors reach
    /// callers of the `transpile_*` functions — confirm validator return types).
    #[error("Validation error: {0}")]
    Validation(#[from] ValidationError),

    /// Unsupported Rust construct.
    ///
    /// The `String` payload is interpolated into the displayed message.
    #[error("Unsupported construct: {0}")]
    Unsupported(String),

    /// Type mapping failure.
    ///
    /// The `String` payload is interpolated into the displayed message.
    #[error("Type error: {0}")]
    Type(String),
}
102
103/// Result type for transpilation operations.
104pub type Result<T> = std::result::Result<T, TranspileError>;
105
106/// Transpile a Rust stencil kernel function to CUDA C code.
107///
108/// This is the main entry point for code generation. It takes a parsed
109/// Rust function and stencil configuration, validates the DSL constraints,
110/// and generates equivalent CUDA C code.
111///
112/// # Arguments
113///
114/// * `func` - The parsed Rust function (from syn)
115/// * `config` - Stencil kernel configuration
116///
117/// # Returns
118///
119/// The generated CUDA C source code as a string.
120pub fn transpile_stencil_kernel(func: &syn::ItemFn, config: &StencilConfig) -> Result<String> {
121    // Validate DSL constraints
122    validate_function(func)?;
123
124    // Create transpiler with stencil config
125    let mut transpiler = CudaTranspiler::new(config.clone());
126
127    // Generate CUDA code
128    transpiler.transpile_stencil(func)
129}
130
131/// Transpile a Rust function to a CUDA `__device__` function.
132///
133/// This generates a device-callable function (not a kernel) from Rust code.
134pub fn transpile_device_function(func: &syn::ItemFn) -> Result<String> {
135    validate_function(func)?;
136    transpile_function(func)
137}
138
139/// Transpile a Rust function to a CUDA `__global__` kernel.
140///
141/// This generates an externally-callable kernel without stencil-specific patterns.
142/// Use DSL functions like `thread_idx_x()`, `block_idx_x()` to access CUDA indices.
143///
144/// # Example
145///
146/// ```ignore
147/// use ringkernel_cuda_codegen::transpile_global_kernel;
148/// use syn::parse_quote;
149///
150/// let func: syn::ItemFn = parse_quote! {
151///     fn exchange_halos(buffer: &mut [f32], copies: &[u32], num_copies: i32) {
152///         let idx = block_idx_x() * block_dim_x() + thread_idx_x();
153///         if idx >= num_copies { return; }
154///         let src = copies[idx * 2] as usize;
155///         let dst = copies[idx * 2 + 1] as usize;
156///         buffer[dst] = buffer[src];
157///     }
158/// };
159///
160/// let cuda = transpile_global_kernel(&func)?;
161/// // Generates: extern "C" __global__ void exchange_halos(...) { ... }
162/// ```
163pub fn transpile_global_kernel(func: &syn::ItemFn) -> Result<String> {
164    validate_function(func)?;
165    let mut transpiler = CudaTranspiler::new_generic();
166    transpiler.transpile_generic_kernel(func)
167}
168
169/// Transpile a Rust handler function to a persistent ring kernel.
170///
171/// Ring kernels are persistent GPU kernels that process messages in a loop.
172/// The handler function is embedded within the message processing loop.
173///
174/// # Example
175///
176/// ```ignore
177/// use ringkernel_cuda_codegen::{transpile_ring_kernel, RingKernelConfig};
178/// use syn::parse_quote;
179///
180/// let handler: syn::ItemFn = parse_quote! {
181///     fn process_message(value: f32) -> f32 {
182///         value * 2.0
183///     }
184/// };
185///
186/// let config = RingKernelConfig::new("processor")
187///     .with_block_size(128)
188///     .with_hlc(true);
189///
190/// let cuda = transpile_ring_kernel(&handler, &config)?;
191/// // Generates a persistent kernel with message loop wrapping the handler
192/// ```
193pub fn transpile_ring_kernel(handler: &syn::ItemFn, config: &RingKernelConfig) -> Result<String> {
194    // Validate handler with generic mode (loops allowed but not required)
195    // The persistent loop is generated by the ring kernel wrapper, not the handler
196    validate_function_with_mode(handler, ValidationMode::Generic)?;
197
198    // Create transpiler in generic mode for the handler
199    let mut transpiler = CudaTranspiler::with_mode(ValidationMode::Generic);
200
201    // Transpile the handler into a ring kernel wrapper
202    transpiler.transpile_ring_kernel(handler, config)
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Checks that `generated` contains each `(needle, reason)` needle in turn,
    /// panicking with `reason` (reported at the caller's location) on the first
    /// one that is missing.
    #[track_caller]
    fn assert_contains_all(generated: &str, checks: &[(&str, &str)]) {
        for &(needle, reason) in checks {
            assert!(generated.contains(needle), "{}", reason);
        }
    }

    #[test]
    fn test_simple_function_transpile() {
        let add_fn: syn::ItemFn = parse_quote! {
            fn add(a: f32, b: f32) -> f32 {
                a + b
            }
        };

        let outcome = transpile_device_function(&add_fn);
        assert!(
            outcome.is_ok(),
            "Should transpile simple function: {:?}",
            outcome
        );
        let generated = outcome.unwrap();

        assert_contains_all(
            &generated,
            &[
                ("float", "Should contain CUDA float type"),
                ("a + b", "Should contain the expression"),
            ],
        );
    }

    #[test]
    fn test_global_kernel_transpile() {
        let kernel_fn: syn::ItemFn = parse_quote! {
            fn exchange_halos(buffer: &mut [f32], copies: &[u32], num_copies: i32) {
                let idx = block_idx_x() * block_dim_x() + thread_idx_x();
                if idx >= num_copies {
                    return;
                }
                let src = copies[idx * 2] as usize;
                let dst = copies[idx * 2 + 1] as usize;
                buffer[dst] = buffer[src];
            }
        };

        let outcome = transpile_global_kernel(&kernel_fn);
        assert!(
            outcome.is_ok(),
            "Should transpile global kernel: {:?}",
            outcome
        );
        let generated = outcome.unwrap();

        assert_contains_all(
            &generated,
            &[
                ("extern \"C\" __global__", "Should be global kernel"),
                ("exchange_halos", "Should have kernel name"),
                ("blockIdx.x", "Should contain blockIdx.x"),
                ("blockDim.x", "Should contain blockDim.x"),
                ("threadIdx.x", "Should contain threadIdx.x"),
                ("return", "Should have early return"),
            ],
        );

        println!("Generated global kernel:\n{}", generated);
    }

    #[test]
    fn test_stencil_kernel_transpile() {
        let stencil_fn: syn::ItemFn = parse_quote! {
            fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
                let curr = p[pos.idx()];
                let prev = p_prev[pos.idx()];
                let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
                p_prev[pos.idx()] = (2.0 * curr - prev + c2 * lap);
            }
        };

        let stencil_cfg = StencilConfig {
            id: "fdtd".to_string(),
            grid: Grid::Grid2D,
            tile_size: (16, 16),
            halo: 1,
        };

        let outcome = transpile_stencil_kernel(&stencil_fn, &stencil_cfg);
        assert!(
            outcome.is_ok(),
            "Should transpile stencil kernel: {:?}",
            outcome
        );
        let generated = outcome.unwrap();

        assert_contains_all(
            &generated,
            &[
                ("__global__", "Should be a CUDA kernel"),
                ("threadIdx", "Should use thread indices"),
            ],
        );
    }

    #[test]
    fn test_ring_kernel_transpile() {
        // A simple handler that doubles the input value.
        let handler_fn: syn::ItemFn = parse_quote! {
            fn process(value: f32) -> f32 {
                let result = value * 2.0;
                result
            }
        };

        let ring_cfg = RingKernelConfig::new("processor")
            .with_block_size(128)
            .with_queue_capacity(1024)
            .with_hlc(true);

        let outcome = transpile_ring_kernel(&handler_fn, &ring_cfg);
        assert!(outcome.is_ok(), "Should transpile ring kernel: {:?}", outcome);
        let generated = outcome.unwrap();

        // Struct definitions.
        assert_contains_all(
            &generated,
            &[
                ("struct __align__(128) ControlBlock", "Should have ControlBlock struct"),
                ("is_active", "Should have is_active field"),
                ("should_terminate", "Should have should_terminate field"),
                ("messages_processed", "Should have messages_processed field"),
            ],
        );

        // Kernel signature.
        assert_contains_all(
            &generated,
            &[
                (
                    "extern \"C\" __global__ void ring_kernel_processor",
                    "Should have correct kernel name",
                ),
                ("ControlBlock* __restrict__ control", "Should have control block param"),
                ("input_buffer", "Should have input buffer param"),
                ("output_buffer", "Should have output buffer param"),
            ],
        );

        // Preamble.
        assert_contains_all(
            &generated,
            &[
                (
                    "int tid = threadIdx.x + blockIdx.x * blockDim.x",
                    "Should have thread id calculation",
                ),
                ("MSG_SIZE", "Should have message size constant"),
                ("hlc_physical", "Should have HLC variables"),
                ("hlc_logical", "Should have HLC logical counter"),
            ],
        );

        // Persistent message loop.
        assert_contains_all(
            &generated,
            &[
                ("while (true)", "Should have persistent loop"),
                ("atomicAdd(&control->should_terminate, 0)", "Should check termination"),
                ("atomicAdd(&control->is_active, 0)", "Should check is_active"),
            ],
        );

        // Embedded handler body.
        assert_contains_all(
            &generated,
            &[
                ("// === USER HANDLER CODE ===", "Should have handler marker"),
                ("value * 2.0", "Should contain handler logic"),
                ("// === END HANDLER CODE ===", "Should have end marker"),
            ],
        );

        // Epilogue.
        assert_contains_all(
            &generated,
            &[("atomicExch(&control->has_terminated, 1)", "Should mark terminated")],
        );

        println!("Generated ring kernel:\n{}", generated);
    }

    #[test]
    fn test_ring_kernel_with_k2k() {
        let handler_fn: syn::ItemFn = parse_quote! {
            fn forward(msg: f32) -> f32 {
                msg
            }
        };

        let ring_cfg = RingKernelConfig::new("forwarder")
            .with_block_size(64)
            .with_k2k(true)
            .with_hlc(true);

        let outcome = transpile_ring_kernel(&handler_fn, &ring_cfg);
        assert!(outcome.is_ok(), "Should transpile K2K kernel: {:?}", outcome);
        let generated = outcome.unwrap();

        // K2K structs first, then the K2K parameters in the kernel signature.
        assert_contains_all(
            &generated,
            &[
                ("K2KRoutingTable", "Should have K2K routing table"),
                ("K2KRoute", "Should have K2K route struct"),
                ("k2k_routes", "Should have k2k_routes param"),
                ("k2k_inbox", "Should have k2k_inbox param"),
                ("k2k_outbox", "Should have k2k_outbox param"),
            ],
        );

        println!("Generated K2K ring kernel:\n{}", generated);
    }

    #[test]
    fn test_ring_kernel_config_defaults() {
        let config = RingKernelConfig::default();
        assert_eq!(config.block_size, 128);
        assert_eq!(config.queue_capacity, 1024);
        assert!(config.enable_hlc);
        assert!(!config.enable_k2k);
    }

    #[test]
    fn test_ring_kernel_intrinsic_availability() {
        // Ring kernel intrinsics must be reachable through the crate re-exports.
        assert!(RingKernelIntrinsic::from_name("is_active").is_some());
        assert!(RingKernelIntrinsic::from_name("should_terminate").is_some());
        assert!(RingKernelIntrinsic::from_name("hlc_tick").is_some());
        assert!(RingKernelIntrinsic::from_name("enqueue_response").is_some());
        assert!(RingKernelIntrinsic::from_name("k2k_send").is_some());
        assert!(RingKernelIntrinsic::from_name("nanosleep").is_some());
    }

    #[test]
    fn test_handler_signature_parsing() {
        let handler_fn: syn::ItemFn = parse_quote! {
            fn handle(ctx: &RingContext, msg: &MyMessage) -> MyResponse {
                MyResponse { value: msg.value * 2.0 }
            }
        };

        let mapper = TypeMapper::new();
        let sig = HandlerSignature::parse(&handler_fn, &mapper).unwrap();

        assert_eq!(sig.name, "handle");
        assert!(sig.has_context);
        assert!(sig.message_param.is_some());
        assert!(sig.has_response());
    }

    #[test]
    fn test_handler_with_context_methods() {
        // Handler that calls context methods, which should be inlined as CUDA.
        let handler_fn: syn::ItemFn = parse_quote! {
            fn process(ctx: &RingContext, value: f32) -> f32 {
                let tid = ctx.thread_id();
                let result = value * 2.0;
                ctx.sync_threads();
                result
            }
        };

        let ring_cfg = RingKernelConfig::new("with_context")
            .with_block_size(128)
            .with_hlc(true);

        let outcome = transpile_ring_kernel(&handler_fn, &ring_cfg);
        assert!(
            outcome.is_ok(),
            "Should transpile handler with context: {:?}",
            outcome
        );
        let generated = outcome.unwrap();

        assert_contains_all(
            &generated,
            &[
                ("threadIdx.x", "ctx.thread_id() should become threadIdx.x"),
                ("__syncthreads()", "ctx.sync_threads() should become __syncthreads()"),
            ],
        );

        println!("Generated handler with context:\n{}", generated);
    }

    #[test]
    fn test_handler_with_message_param() {
        let handler_fn: syn::ItemFn = parse_quote! {
            fn process_msg(msg: &Message, scale: f32) -> f32 {
                msg.value * scale
            }
        };

        let ring_cfg = RingKernelConfig::new("msg_handler");
        let outcome = transpile_ring_kernel(&handler_fn, &ring_cfg);
        assert!(
            outcome.is_ok(),
            "Should transpile handler with message: {:?}",
            outcome
        );
        let generated = outcome.unwrap();

        // The message type name should surface somewhere in the generated code.
        assert_contains_all(&generated, &[("Message", "Should reference Message type")]);

        println!("Generated message handler:\n{}", generated);
    }

    #[test]
    fn test_context_method_mappings() {
        // Name lookup for context methods.
        assert!(ContextMethod::from_name("thread_id").is_some());
        assert!(ContextMethod::from_name("sync_threads").is_some());
        assert!(ContextMethod::from_name("global_thread_id").is_some());
        assert!(ContextMethod::from_name("atomic_add").is_some());
        assert!(ContextMethod::from_name("lane_id").is_some());
        assert!(ContextMethod::from_name("warp_id").is_some());

        // CUDA lowering of argument-free context methods.
        let expected_lowerings = [
            (ContextMethod::ThreadId, "threadIdx.x"),
            (ContextMethod::SyncThreads, "__syncthreads()"),
            (
                ContextMethod::GlobalThreadId,
                "(blockIdx.x * blockDim.x + threadIdx.x)",
            ),
        ];
        for (method, cuda_expr) in expected_lowerings {
            assert_eq!(method.to_cuda(&[]), cuda_expr);
        }
    }

    #[test]
    fn test_message_type_registration() {
        let mut registry = MessageTypeRegistry::new();

        let request_info = MessageTypeInfo {
            name: "InputMsg".to_string(),
            size: 16,
            fields: vec![
                ("id".to_string(), "unsigned long long".to_string()),
                ("value".to_string(), "float".to_string()),
            ],
        };
        registry.register_message(request_info);

        let response_info = MessageTypeInfo {
            name: "OutputMsg".to_string(),
            size: 8,
            fields: vec![("result".to_string(), "float".to_string())],
        };
        registry.register_response(response_info);

        let structs = registry.generate_structs();
        assert!(structs.contains("struct InputMsg"));
        assert!(structs.contains("struct OutputMsg"));
        assert!(structs.contains("unsigned long long id"));
        assert!(structs.contains("float result"));
    }

    #[test]
    fn test_full_handler_integration() {
        // Complete handler exercising context methods, message input and response output.
        let handler_fn: syn::ItemFn = parse_quote! {
            fn full_handler(ctx: &RingContext, msg: &Request) -> Response {
                let tid = ctx.global_thread_id();
                ctx.sync_threads();
                let result = msg.value * 2.0;
                Response { value: result, id: tid as u64 }
            }
        };

        let ring_cfg = RingKernelConfig::new("full")
            .with_block_size(256)
            .with_queue_capacity(2048)
            .with_hlc(true)
            .with_k2k(false);

        let outcome = transpile_ring_kernel(&handler_fn, &ring_cfg);
        assert!(
            outcome.is_ok(),
            "Should transpile full handler: {:?}",
            outcome
        );
        let generated = outcome.unwrap();

        assert_contains_all(
            &generated,
            &[
                ("ring_kernel_full", "Kernel name"),
                ("ControlBlock", "ControlBlock struct"),
                ("while (true)", "Persistent loop"),
                ("threadIdx.x", "Thread index"),
                ("__syncthreads()", "Sync threads"),
                ("blockIdx.x * blockDim.x + threadIdx.x", "Global thread ID"),
                ("has_terminated", "Termination marking"),
            ],
        );

        println!("Full handler integration:\n{}", generated);
    }

    #[test]
    fn test_k2k_handler_integration() {
        // Handler compiled with K2K messaging enabled.
        let handler_fn: syn::ItemFn = parse_quote! {
            fn k2k_handler(ctx: &RingContext, msg: &InputMsg) -> OutputMsg {
                let tid = ctx.global_thread_id();

                // Process incoming message
                let result = msg.value * 2.0;

                // Build response
                OutputMsg { result: result, source_id: tid as u64 }
            }
        };

        let ring_cfg = RingKernelConfig::new("k2k_processor")
            .with_block_size(128)
            .with_queue_capacity(1024)
            .with_hlc(true)
            .with_k2k(true);

        let outcome = transpile_ring_kernel(&handler_fn, &ring_cfg);
        assert!(outcome.is_ok(), "Should transpile K2K handler: {:?}", outcome);
        let generated = outcome.unwrap();

        assert_contains_all(
            &generated,
            &[
                ("K2KRoutingTable", "Should have K2KRoutingTable"),
                ("K2KRoute", "Should have K2KRoute struct"),
                ("K2KInboxHeader", "Should have K2KInboxHeader"),
                ("k2k_routes", "Should have k2k_routes param"),
                ("k2k_inbox", "Should have k2k_inbox param"),
                ("k2k_outbox", "Should have k2k_outbox param"),
                ("k2k_send", "Should have k2k_send function"),
                ("k2k_try_recv", "Should have k2k_try_recv function"),
                ("k2k_peek", "Should have k2k_peek function"),
                ("k2k_pending_count", "Should have k2k_pending_count function"),
            ],
        );

        println!("K2K handler integration:\n{}", generated);
    }

    #[test]
    fn test_all_kernel_types_comparison() {
        // --- Stencil kernel ---
        let stencil_fn: syn::ItemFn = parse_quote! {
            fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
                let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * p[pos.idx()];
                p_prev[pos.idx()] = 2.0 * p[pos.idx()] - p_prev[pos.idx()] + c2 * lap;
            }
        };

        let stencil_cfg = StencilConfig::new("fdtd")
            .with_tile_size(16, 16)
            .with_halo(1);

        let stencil_src = transpile_stencil_kernel(&stencil_fn, &stencil_cfg).unwrap();
        assert!(
            !stencil_src.contains("GridPos"),
            "Stencil should remove GridPos"
        );
        assert_contains_all(
            &stencil_src,
            &[("buffer_width", "Stencil should have buffer_width")],
        );

        // --- Generic global kernel ---
        let saxpy_fn: syn::ItemFn = parse_quote! {
            fn saxpy(x: &[f32], y: &mut [f32], a: f32, n: i32) {
                let idx = block_idx_x() * block_dim_x() + thread_idx_x();
                if idx >= n { return; }
                y[idx as usize] = a * x[idx as usize] + y[idx as usize];
            }
        };

        let global_src = transpile_global_kernel(&saxpy_fn).unwrap();
        assert_contains_all(
            &global_src,
            &[
                ("__global__", "Global kernel marker"),
                ("blockIdx.x", "CUDA block index"),
            ],
        );

        // --- Persistent ring kernel ---
        let ring_fn: syn::ItemFn = parse_quote! {
            fn process(msg: f32) -> f32 {
                msg * 2.0
            }
        };

        let ring_cfg = RingKernelConfig::new("process")
            .with_block_size(128)
            .with_hlc(true);

        let ring_src = transpile_ring_kernel(&ring_fn, &ring_cfg).unwrap();
        assert_contains_all(
            &ring_src,
            &[
                ("ControlBlock", "Ring kernel ControlBlock"),
                ("while (true)", "Persistent loop"),
                ("has_terminated", "Termination"),
            ],
        );

        println!("=== Stencil Kernel ===\n{}\n", stencil_src);
        println!("=== Global Kernel ===\n{}\n", global_src);
        println!("=== Ring Kernel ===\n{}\n", ring_src);
    }
}
725}