1pub mod dsl;
40pub mod handler;
41mod intrinsics;
42pub mod loops;
43pub mod persistent_fdtd;
44pub mod reduction_intrinsics;
45pub mod ring_kernel;
46pub mod shared;
47mod stencil;
48mod transpiler;
49mod types;
50mod validation;
51
52pub use handler::{
53 generate_cuda_struct, generate_message_deser, generate_response_ser, ContextMethod,
54 HandlerCodegenConfig, HandlerParam, HandlerParamKind, HandlerReturnType, HandlerSignature,
55 MessageTypeInfo, MessageTypeRegistry,
56};
57pub use intrinsics::{GpuIntrinsic, IntrinsicRegistry, RingKernelIntrinsic, StencilIntrinsic};
58pub use loops::{LoopPattern, RangeInfo};
59pub use persistent_fdtd::{generate_persistent_fdtd_kernel, PersistentFdtdConfig};
60pub use reduction_intrinsics::{
61 generate_inline_block_reduce, generate_inline_grid_reduce,
62 generate_inline_reduce_and_broadcast, generate_reduction_helpers, transpile_reduction_call,
63 ReductionCodegenConfig, ReductionIntrinsic, ReductionOp as CodegenReductionOp,
64};
65pub use ring_kernel::{
66 generate_control_block_struct, generate_hlc_struct, generate_k2k_structs,
67 KernelReductionConfig, RingKernelConfig,
68};
69pub use shared::{SharedArray, SharedMemoryConfig, SharedMemoryDecl, SharedTile};
70pub use stencil::{Grid, GridPos, StencilConfig, StencilLaunchConfig};
71pub use transpiler::{transpile_function, CudaTranspiler, SharedVarInfo};
72pub use types::{
73 get_slice_element_type, is_control_block_type, is_mutable_reference, is_ring_context_type,
74 ring_kernel_type_mapper, CudaType, RingKernelParamKind, TypeMapper,
75};
76pub use validation::{
77 is_simple_assignment, validate_function, validate_function_with_mode,
78 validate_stencil_signature, ValidationError, ValidationMode,
79};
80
81use thiserror::Error;
82
83#[derive(Error, Debug)]
85pub enum TranspileError {
86 #[error("Parse error: {0}")]
88 Parse(String),
89
90 #[error("Validation error: {0}")]
92 Validation(#[from] ValidationError),
93
94 #[error("Unsupported construct: {0}")]
96 Unsupported(String),
97
98 #[error("Type error: {0}")]
100 Type(String),
101}
102
103pub type Result<T> = std::result::Result<T, TranspileError>;
105
106pub fn transpile_stencil_kernel(func: &syn::ItemFn, config: &StencilConfig) -> Result<String> {
121 validate_function(func)?;
123
124 let mut transpiler = CudaTranspiler::new(config.clone());
126
127 transpiler.transpile_stencil(func)
129}
130
131pub fn transpile_device_function(func: &syn::ItemFn) -> Result<String> {
135 validate_function(func)?;
136 transpile_function(func)
137}
138
139pub fn transpile_global_kernel(func: &syn::ItemFn) -> Result<String> {
164 validate_function(func)?;
165 let mut transpiler = CudaTranspiler::new_generic();
166 transpiler.transpile_generic_kernel(func)
167}
168
169pub fn transpile_ring_kernel(handler: &syn::ItemFn, config: &RingKernelConfig) -> Result<String> {
194 validate_function_with_mode(handler, ValidationMode::Generic)?;
197
198 let mut transpiler = CudaTranspiler::with_mode(ValidationMode::Generic);
200
201 transpiler.transpile_ring_kernel(handler, config)
203}
204
#[cfg(test)]
mod tests {
    //! Smoke tests for the public transpilation entry points.
    //!
    //! Each test feeds a small Rust function through one of the `transpile_*`
    //! functions and checks the generated CUDA for key substrings. The
    //! `println!` calls let the output be inspected with
    //! `cargo test -- --nocapture`.

    use super::*;
    use syn::parse_quote;

    #[test]
    fn test_simple_function_transpile() {
        let func: syn::ItemFn = parse_quote! {
            fn add(a: f32, b: f32) -> f32 {
                a + b
            }
        };

        let cuda = transpile_device_function(&func)
            .unwrap_or_else(|err| panic!("Should transpile simple function: {:?}", err));

        assert!(cuda.contains("float"), "Should contain CUDA float type");
        assert!(cuda.contains("a + b"), "Should contain the expression");
    }

    #[test]
    fn test_global_kernel_transpile() {
        let func: syn::ItemFn = parse_quote! {
            fn exchange_halos(buffer: &mut [f32], copies: &[u32], num_copies: i32) {
                let idx = block_idx_x() * block_dim_x() + thread_idx_x();
                if idx >= num_copies {
                    return;
                }
                let src = copies[idx * 2] as usize;
                let dst = copies[idx * 2 + 1] as usize;
                buffer[dst] = buffer[src];
            }
        };

        let cuda = transpile_global_kernel(&func)
            .unwrap_or_else(|err| panic!("Should transpile global kernel: {:?}", err));

        assert!(
            cuda.contains("extern \"C\" __global__"),
            "Should be global kernel"
        );
        assert!(cuda.contains("exchange_halos"), "Should have kernel name");
        // Thread-index intrinsics are lowered to CUDA built-ins.
        assert!(cuda.contains("blockIdx.x"), "Should contain blockIdx.x");
        assert!(cuda.contains("blockDim.x"), "Should contain blockDim.x");
        assert!(cuda.contains("threadIdx.x"), "Should contain threadIdx.x");
        assert!(cuda.contains("return"), "Should have early return");

        println!("Generated global kernel:\n{}", cuda);
    }

    #[test]
    fn test_stencil_kernel_transpile() {
        let func: syn::ItemFn = parse_quote! {
            fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
                let curr = p[pos.idx()];
                let prev = p_prev[pos.idx()];
                let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
                p_prev[pos.idx()] = (2.0 * curr - prev + c2 * lap);
            }
        };

        let config = StencilConfig {
            id: "fdtd".to_string(),
            grid: Grid::Grid2D,
            tile_size: (16, 16),
            halo: 1,
        };

        let cuda = transpile_stencil_kernel(&func, &config)
            .unwrap_or_else(|err| panic!("Should transpile stencil kernel: {:?}", err));

        assert!(cuda.contains("__global__"), "Should be a CUDA kernel");
        assert!(cuda.contains("threadIdx"), "Should use thread indices");
    }

    #[test]
    fn test_ring_kernel_transpile() {
        // Minimal handler: takes a value and returns it doubled.
        let handler: syn::ItemFn = parse_quote! {
            fn process(value: f32) -> f32 {
                let result = value * 2.0;
                result
            }
        };

        let config = RingKernelConfig::new("processor")
            .with_block_size(128)
            .with_queue_capacity(1024)
            .with_hlc(true);

        let cuda = transpile_ring_kernel(&handler, &config)
            .unwrap_or_else(|err| panic!("Should transpile ring kernel: {:?}", err));

        // Control block struct and its fields.
        assert!(
            cuda.contains("struct __align__(128) ControlBlock"),
            "Should have ControlBlock struct"
        );
        assert!(cuda.contains("is_active"), "Should have is_active field");
        assert!(
            cuda.contains("should_terminate"),
            "Should have should_terminate field"
        );
        assert!(
            cuda.contains("messages_processed"),
            "Should have messages_processed field"
        );

        // Kernel signature.
        assert!(
            cuda.contains("extern \"C\" __global__ void ring_kernel_processor"),
            "Should have correct kernel name"
        );
        assert!(
            cuda.contains("ControlBlock* __restrict__ control"),
            "Should have control block param"
        );
        assert!(
            cuda.contains("input_buffer"),
            "Should have input buffer param"
        );
        assert!(
            cuda.contains("output_buffer"),
            "Should have output buffer param"
        );

        // Preamble.
        assert!(
            cuda.contains("int tid = threadIdx.x + blockIdx.x * blockDim.x"),
            "Should have thread id calculation"
        );
        assert!(
            cuda.contains("MSG_SIZE"),
            "Should have message size constant"
        );
        assert!(cuda.contains("hlc_physical"), "Should have HLC variables");
        assert!(
            cuda.contains("hlc_logical"),
            "Should have HLC logical counter"
        );

        // Persistent message loop and its control-block checks.
        assert!(cuda.contains("while (true)"), "Should have persistent loop");
        assert!(
            cuda.contains("atomicAdd(&control->should_terminate, 0)"),
            "Should check termination"
        );
        assert!(
            cuda.contains("atomicAdd(&control->is_active, 0)"),
            "Should check is_active"
        );

        // User handler body is spliced between markers.
        assert!(
            cuda.contains("// === USER HANDLER CODE ==="),
            "Should have handler marker"
        );
        assert!(cuda.contains("value * 2.0"), "Should contain handler logic");
        assert!(
            cuda.contains("// === END HANDLER CODE ==="),
            "Should have end marker"
        );

        // Epilogue.
        assert!(
            cuda.contains("atomicExch(&control->has_terminated, 1)"),
            "Should mark terminated"
        );

        println!("Generated ring kernel:\n{}", cuda);
    }

    #[test]
    fn test_ring_kernel_with_k2k() {
        let handler: syn::ItemFn = parse_quote! {
            fn forward(msg: f32) -> f32 {
                msg
            }
        };

        let config = RingKernelConfig::new("forwarder")
            .with_block_size(64)
            .with_k2k(true)
            .with_hlc(true);

        let cuda = transpile_ring_kernel(&handler, &config)
            .unwrap_or_else(|err| panic!("Should transpile K2K kernel: {:?}", err));

        // K2K structs.
        assert!(
            cuda.contains("K2KRoutingTable"),
            "Should have K2K routing table"
        );
        assert!(cuda.contains("K2KRoute"), "Should have K2K route struct");

        // K2K kernel parameters.
        assert!(cuda.contains("k2k_routes"), "Should have k2k_routes param");
        assert!(cuda.contains("k2k_inbox"), "Should have k2k_inbox param");
        assert!(cuda.contains("k2k_outbox"), "Should have k2k_outbox param");

        println!("Generated K2K ring kernel:\n{}", cuda);
    }

    #[test]
    fn test_ring_kernel_config_defaults() {
        let config = RingKernelConfig::default();
        assert_eq!(config.block_size, 128);
        assert_eq!(config.queue_capacity, 1024);
        assert!(config.enable_hlc);
        assert!(!config.enable_k2k);
    }

    #[test]
    fn test_ring_kernel_intrinsic_availability() {
        assert!(RingKernelIntrinsic::from_name("is_active").is_some());
        assert!(RingKernelIntrinsic::from_name("should_terminate").is_some());
        assert!(RingKernelIntrinsic::from_name("hlc_tick").is_some());
        assert!(RingKernelIntrinsic::from_name("enqueue_response").is_some());
        assert!(RingKernelIntrinsic::from_name("k2k_send").is_some());
        assert!(RingKernelIntrinsic::from_name("nanosleep").is_some());
    }

    #[test]
    fn test_handler_signature_parsing() {
        let func: syn::ItemFn = parse_quote! {
            fn handle(ctx: &RingContext, msg: &MyMessage) -> MyResponse {
                MyResponse { value: msg.value * 2.0 }
            }
        };

        let mapper = TypeMapper::new();
        let sig = HandlerSignature::parse(&func, &mapper)
            .expect("handler signature should parse");

        assert_eq!(sig.name, "handle");
        assert!(sig.has_context);
        assert!(sig.message_param.is_some());
        assert!(sig.has_response());
    }

    #[test]
    fn test_handler_with_context_methods() {
        // Handler that calls RingContext methods, which must be lowered to CUDA.
        let handler: syn::ItemFn = parse_quote! {
            fn process(ctx: &RingContext, value: f32) -> f32 {
                let tid = ctx.thread_id();
                let result = value * 2.0;
                ctx.sync_threads();
                result
            }
        };

        let config = RingKernelConfig::new("with_context")
            .with_block_size(128)
            .with_hlc(true);

        let cuda = transpile_ring_kernel(&handler, &config)
            .unwrap_or_else(|err| panic!("Should transpile handler with context: {:?}", err));

        assert!(
            cuda.contains("threadIdx.x"),
            "ctx.thread_id() should become threadIdx.x"
        );
        assert!(
            cuda.contains("__syncthreads()"),
            "ctx.sync_threads() should become __syncthreads()"
        );

        println!("Generated handler with context:\n{}", cuda);
    }

    #[test]
    fn test_handler_with_message_param() {
        let handler: syn::ItemFn = parse_quote! {
            fn process_msg(msg: &Message, scale: f32) -> f32 {
                msg.value * scale
            }
        };

        let config = RingKernelConfig::new("msg_handler");
        let cuda = transpile_ring_kernel(&handler, &config)
            .unwrap_or_else(|err| panic!("Should transpile handler with message: {:?}", err));

        assert!(cuda.contains("Message"), "Should reference Message type");

        println!("Generated message handler:\n{}", cuda);
    }

    #[test]
    fn test_context_method_mappings() {
        // Every method name below must be recognised.
        assert!(ContextMethod::from_name("thread_id").is_some());
        assert!(ContextMethod::from_name("sync_threads").is_some());
        assert!(ContextMethod::from_name("global_thread_id").is_some());
        assert!(ContextMethod::from_name("atomic_add").is_some());
        assert!(ContextMethod::from_name("lane_id").is_some());
        assert!(ContextMethod::from_name("warp_id").is_some());

        // ...and lowered to the expected CUDA expression.
        assert_eq!(ContextMethod::ThreadId.to_cuda(&[]), "threadIdx.x");
        assert_eq!(ContextMethod::SyncThreads.to_cuda(&[]), "__syncthreads()");
        assert_eq!(
            ContextMethod::GlobalThreadId.to_cuda(&[]),
            "(blockIdx.x * blockDim.x + threadIdx.x)"
        );
    }

    #[test]
    fn test_message_type_registration() {
        let mut registry = MessageTypeRegistry::new();

        registry.register_message(MessageTypeInfo {
            name: "InputMsg".to_string(),
            size: 16,
            fields: vec![
                ("id".to_string(), "unsigned long long".to_string()),
                ("value".to_string(), "float".to_string()),
            ],
        });

        registry.register_response(MessageTypeInfo {
            name: "OutputMsg".to_string(),
            size: 8,
            fields: vec![("result".to_string(), "float".to_string())],
        });

        let structs = registry.generate_structs();
        assert!(structs.contains("struct InputMsg"));
        assert!(structs.contains("struct OutputMsg"));
        assert!(structs.contains("unsigned long long id"));
        assert!(structs.contains("float result"));
    }

    #[test]
    fn test_full_handler_integration() {
        // Handler using context methods, a message reference and a struct response.
        let handler: syn::ItemFn = parse_quote! {
            fn full_handler(ctx: &RingContext, msg: &Request) -> Response {
                let tid = ctx.global_thread_id();
                ctx.sync_threads();
                let result = msg.value * 2.0;
                Response { value: result, id: tid as u64 }
            }
        };

        let config = RingKernelConfig::new("full")
            .with_block_size(256)
            .with_queue_capacity(2048)
            .with_hlc(true)
            .with_k2k(false);

        let cuda = transpile_ring_kernel(&handler, &config)
            .unwrap_or_else(|err| panic!("Should transpile full handler: {:?}", err));

        assert!(cuda.contains("ring_kernel_full"), "Kernel name");
        assert!(cuda.contains("ControlBlock"), "ControlBlock struct");
        assert!(cuda.contains("while (true)"), "Persistent loop");
        assert!(cuda.contains("threadIdx.x"), "Thread index");
        assert!(cuda.contains("__syncthreads()"), "Sync threads");
        assert!(
            cuda.contains("blockIdx.x * blockDim.x + threadIdx.x"),
            "Global thread ID"
        );
        assert!(cuda.contains("has_terminated"), "Termination marking");

        println!("Full handler integration:\n{}", cuda);
    }

    #[test]
    fn test_k2k_handler_integration() {
        // Handler compiled with K2K enabled; K2K helpers must be emitted.
        let handler: syn::ItemFn = parse_quote! {
            fn k2k_handler(ctx: &RingContext, msg: &InputMsg) -> OutputMsg {
                let tid = ctx.global_thread_id();

                let result = msg.value * 2.0;

                OutputMsg { result: result, source_id: tid as u64 }
            }
        };

        let config = RingKernelConfig::new("k2k_processor")
            .with_block_size(128)
            .with_queue_capacity(1024)
            .with_hlc(true)
            .with_k2k(true);

        let cuda = transpile_ring_kernel(&handler, &config)
            .unwrap_or_else(|err| panic!("Should transpile K2K handler: {:?}", err));

        // K2K structs.
        assert!(
            cuda.contains("K2KRoutingTable"),
            "Should have K2KRoutingTable"
        );
        assert!(cuda.contains("K2KRoute"), "Should have K2KRoute struct");
        assert!(
            cuda.contains("K2KInboxHeader"),
            "Should have K2KInboxHeader"
        );
        // K2K kernel parameters.
        assert!(cuda.contains("k2k_routes"), "Should have k2k_routes param");
        assert!(cuda.contains("k2k_inbox"), "Should have k2k_inbox param");
        assert!(cuda.contains("k2k_outbox"), "Should have k2k_outbox param");
        // K2K helper functions.
        assert!(cuda.contains("k2k_send"), "Should have k2k_send function");
        assert!(
            cuda.contains("k2k_try_recv"),
            "Should have k2k_try_recv function"
        );
        assert!(cuda.contains("k2k_peek"), "Should have k2k_peek function");
        assert!(
            cuda.contains("k2k_pending_count"),
            "Should have k2k_pending_count function"
        );

        println!("K2K handler integration:\n{}", cuda);
    }

    #[test]
    fn test_all_kernel_types_comparison() {
        // Stencil kernel: GridPos is replaced by generated indexing.
        let stencil_func: syn::ItemFn = parse_quote! {
            fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
                let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * p[pos.idx()];
                p_prev[pos.idx()] = 2.0 * p[pos.idx()] - p_prev[pos.idx()] + c2 * lap;
            }
        };

        let stencil_config = StencilConfig::new("fdtd")
            .with_tile_size(16, 16)
            .with_halo(1);

        let stencil_cuda = transpile_stencil_kernel(&stencil_func, &stencil_config)
            .expect("stencil kernel should transpile");
        assert!(
            !stencil_cuda.contains("GridPos"),
            "Stencil should remove GridPos"
        );
        assert!(
            stencil_cuda.contains("buffer_width"),
            "Stencil should have buffer_width"
        );

        // Global kernel: explicit thread-index intrinsics.
        let global_func: syn::ItemFn = parse_quote! {
            fn saxpy(x: &[f32], y: &mut [f32], a: f32, n: i32) {
                let idx = block_idx_x() * block_dim_x() + thread_idx_x();
                if idx >= n { return; }
                y[idx as usize] = a * x[idx as usize] + y[idx as usize];
            }
        };

        let global_cuda =
            transpile_global_kernel(&global_func).expect("global kernel should transpile");
        assert!(global_cuda.contains("__global__"), "Global kernel marker");
        assert!(global_cuda.contains("blockIdx.x"), "CUDA block index");

        // Ring kernel: persistent loop driven by a control block.
        let ring_func: syn::ItemFn = parse_quote! {
            fn process(msg: f32) -> f32 {
                msg * 2.0
            }
        };

        let ring_config = RingKernelConfig::new("process")
            .with_block_size(128)
            .with_hlc(true);

        let ring_cuda = transpile_ring_kernel(&ring_func, &ring_config)
            .expect("ring kernel should transpile");
        assert!(
            ring_cuda.contains("ControlBlock"),
            "Ring kernel ControlBlock"
        );
        assert!(ring_cuda.contains("while (true)"), "Persistent loop");
        assert!(ring_cuda.contains("has_terminated"), "Termination");

        println!("=== Stencil Kernel ===\n{}\n", stencil_cuda);
        println!("=== Global Kernel ===\n{}\n", global_cuda);
        println!("=== Ring Kernel ===\n{}\n", ring_cuda);
    }
}