1pub mod dsl;
40pub mod handler;
41mod intrinsics;
42pub mod loops;
43pub mod persistent_fdtd;
44pub mod ring_kernel;
45pub mod shared;
46mod stencil;
47mod transpiler;
48mod types;
49mod validation;
50
51pub use handler::{
52 generate_cuda_struct, generate_message_deser, generate_response_ser, ContextMethod,
53 HandlerCodegenConfig, HandlerParam, HandlerParamKind, HandlerReturnType, HandlerSignature,
54 MessageTypeInfo, MessageTypeRegistry,
55};
56pub use intrinsics::{GpuIntrinsic, IntrinsicRegistry, RingKernelIntrinsic, StencilIntrinsic};
57pub use loops::{LoopPattern, RangeInfo};
58pub use persistent_fdtd::{generate_persistent_fdtd_kernel, PersistentFdtdConfig};
59pub use ring_kernel::{
60 generate_control_block_struct, generate_hlc_struct, generate_k2k_structs, RingKernelConfig,
61};
62pub use shared::{SharedArray, SharedMemoryConfig, SharedMemoryDecl, SharedTile};
63pub use stencil::{Grid, GridPos, StencilConfig, StencilLaunchConfig};
64pub use transpiler::{transpile_function, CudaTranspiler, SharedVarInfo};
65pub use types::{
66 get_slice_element_type, is_control_block_type, is_mutable_reference, is_ring_context_type,
67 ring_kernel_type_mapper, CudaType, RingKernelParamKind, TypeMapper,
68};
69pub use validation::{
70 is_simple_assignment, validate_function, validate_function_with_mode,
71 validate_stencil_signature, ValidationError, ValidationMode,
72};
73
74use thiserror::Error;
75
76#[derive(Error, Debug)]
78pub enum TranspileError {
79 #[error("Parse error: {0}")]
81 Parse(String),
82
83 #[error("Validation error: {0}")]
85 Validation(#[from] ValidationError),
86
87 #[error("Unsupported construct: {0}")]
89 Unsupported(String),
90
91 #[error("Type error: {0}")]
93 Type(String),
94}
95
96pub type Result<T> = std::result::Result<T, TranspileError>;
98
99pub fn transpile_stencil_kernel(func: &syn::ItemFn, config: &StencilConfig) -> Result<String> {
114 validate_function(func)?;
116
117 let mut transpiler = CudaTranspiler::new(config.clone());
119
120 transpiler.transpile_stencil(func)
122}
123
/// Transpiles a plain Rust function into CUDA source (no stencil or ring
/// scaffolding).
///
/// `func` is validated with [`validate_function`] before being handed to
/// [`transpile_function`]. In this module's tests, `fn add(a: f32, b: f32) -> f32`
/// produces output containing `float` and the expression `a + b`.
///
/// # Errors
///
/// Returns an error if validation rejects `func` or if transpilation fails.
pub fn transpile_device_function(func: &syn::ItemFn) -> Result<String> {
    // Validation failures short-circuit before any code is generated.
    validate_function(func)?;
    transpile_function(func)
}
131
132pub fn transpile_global_kernel(func: &syn::ItemFn) -> Result<String> {
157 validate_function(func)?;
158 let mut transpiler = CudaTranspiler::new_generic();
159 transpiler.transpile_generic_kernel(func)
160}
161
162pub fn transpile_ring_kernel(handler: &syn::ItemFn, config: &RingKernelConfig) -> Result<String> {
187 validate_function_with_mode(handler, ValidationMode::Generic)?;
190
191 let mut transpiler = CudaTranspiler::with_mode(ValidationMode::Generic);
193
194 transpiler.transpile_ring_kernel(handler, config)
196}
197
// Tests for the public transpilation entry points and the re-exported helpers.
//
// Most tests parse a small Rust function with `parse_quote!`, transpile it,
// and assert that specific substrings appear in the generated CUDA source.
// They check for the presence of fragments, not for exact output.
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    // A trivial scalar function transpiles as a device function: `f32` maps to
    // `float` and the arithmetic expression is carried through.
    #[test]
    fn test_simple_function_transpile() {
        let func: syn::ItemFn = parse_quote! {
            fn add(a: f32, b: f32) -> f32 {
                a + b
            }
        };

        let result = transpile_device_function(&func);
        assert!(
            result.is_ok(),
            "Should transpile simple function: {:?}",
            result
        );

        let cuda = result.unwrap();
        assert!(cuda.contains("float"), "Should contain CUDA float type");
        assert!(cuda.contains("a + b"), "Should contain the expression");
    }

    // A generic kernel using index intrinsics and an early `return` becomes an
    // `extern "C" __global__` kernel with CUDA built-in index variables.
    #[test]
    fn test_global_kernel_transpile() {
        let func: syn::ItemFn = parse_quote! {
            fn exchange_halos(buffer: &mut [f32], copies: &[u32], num_copies: i32) {
                let idx = block_idx_x() * block_dim_x() + thread_idx_x();
                if idx >= num_copies {
                    return;
                }
                let src = copies[idx * 2] as usize;
                let dst = copies[idx * 2 + 1] as usize;
                buffer[dst] = buffer[src];
            }
        };

        let result = transpile_global_kernel(&func);
        assert!(
            result.is_ok(),
            "Should transpile global kernel: {:?}",
            result
        );

        let cuda = result.unwrap();
        assert!(
            cuda.contains("extern \"C\" __global__"),
            "Should be global kernel"
        );
        assert!(cuda.contains("exchange_halos"), "Should have kernel name");
        assert!(cuda.contains("blockIdx.x"), "Should contain blockIdx.x");
        assert!(cuda.contains("blockDim.x"), "Should contain blockDim.x");
        assert!(cuda.contains("threadIdx.x"), "Should contain threadIdx.x");
        assert!(cuda.contains("return"), "Should have early return");

        println!("Generated global kernel:\n{}", cuda);
    }

    // A 2D FDTD-style stencil built with an explicit `StencilConfig` literal
    // transpiles to a `__global__` kernel that uses thread indices.
    #[test]
    fn test_stencil_kernel_transpile() {
        let func: syn::ItemFn = parse_quote! {
            fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
                let curr = p[pos.idx()];
                let prev = p_prev[pos.idx()];
                let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
                p_prev[pos.idx()] = (2.0 * curr - prev + c2 * lap);
            }
        };

        let config = StencilConfig {
            id: "fdtd".to_string(),
            grid: Grid::Grid2D,
            tile_size: (16, 16),
            halo: 1,
        };

        let result = transpile_stencil_kernel(&func, &config);
        assert!(
            result.is_ok(),
            "Should transpile stencil kernel: {:?}",
            result
        );

        let cuda = result.unwrap();
        assert!(cuda.contains("__global__"), "Should be a CUDA kernel");
        assert!(cuda.contains("threadIdx"), "Should use thread indices");
    }

    // End-to-end check of the ring-kernel template: control block, kernel
    // signature, preamble, persistent loop, embedded handler and epilogue.
    #[test]
    fn test_ring_kernel_transpile() {
        let handler: syn::ItemFn = parse_quote! {
            fn process(value: f32) -> f32 {
                let result = value * 2.0;
                result
            }
        };

        let config = RingKernelConfig::new("processor")
            .with_block_size(128)
            .with_queue_capacity(1024)
            .with_hlc(true);

        let result = transpile_ring_kernel(&handler, &config);
        assert!(result.is_ok(), "Should transpile ring kernel: {:?}", result);

        let cuda = result.unwrap();

        // Control block struct and its status/counter fields.
        assert!(
            cuda.contains("struct __align__(128) ControlBlock"),
            "Should have ControlBlock struct"
        );
        assert!(cuda.contains("is_active"), "Should have is_active field");
        assert!(
            cuda.contains("should_terminate"),
            "Should have should_terminate field"
        );
        assert!(
            cuda.contains("messages_processed"),
            "Should have messages_processed field"
        );

        // Kernel signature: name derived from the config name, plus control
        // block and input/output buffer parameters.
        assert!(
            cuda.contains("extern \"C\" __global__ void ring_kernel_processor"),
            "Should have correct kernel name"
        );
        assert!(
            cuda.contains("ControlBlock* __restrict__ control"),
            "Should have control block param"
        );
        assert!(
            cuda.contains("input_buffer"),
            "Should have input buffer param"
        );
        assert!(
            cuda.contains("output_buffer"),
            "Should have output buffer param"
        );

        // Preamble: global thread id, message size constant, HLC state.
        assert!(
            cuda.contains("int tid = threadIdx.x + blockIdx.x * blockDim.x"),
            "Should have thread id calculation"
        );
        assert!(
            cuda.contains("MSG_SIZE"),
            "Should have message size constant"
        );
        assert!(cuda.contains("hlc_physical"), "Should have HLC variables");
        assert!(
            cuda.contains("hlc_logical"),
            "Should have HLC logical counter"
        );

        // Persistent loop; control flags are read atomically via
        // `atomicAdd(..., 0)`, which returns the current value unchanged.
        assert!(cuda.contains("while (true)"), "Should have persistent loop");
        assert!(
            cuda.contains("atomicAdd(&control->should_terminate, 0)"),
            "Should check termination"
        );
        assert!(
            cuda.contains("atomicAdd(&control->is_active, 0)"),
            "Should check is_active"
        );

        // The user handler body is embedded between fixed marker comments.
        assert!(
            cuda.contains("// === USER HANDLER CODE ==="),
            "Should have handler marker"
        );
        assert!(cuda.contains("value * 2.0"), "Should contain handler logic");
        assert!(
            cuda.contains("// === END HANDLER CODE ==="),
            "Should have end marker"
        );

        // Epilogue: the kernel records that it has terminated.
        assert!(
            cuda.contains("atomicExch(&control->has_terminated, 1)"),
            "Should mark terminated"
        );

        println!("Generated ring kernel:\n{}", cuda);
    }

    // Enabling K2K adds the routing-table structs and the k2k_* kernel params.
    #[test]
    fn test_ring_kernel_with_k2k() {
        let handler: syn::ItemFn = parse_quote! {
            fn forward(msg: f32) -> f32 {
                msg
            }
        };

        let config = RingKernelConfig::new("forwarder")
            .with_block_size(64)
            .with_k2k(true)
            .with_hlc(true);

        let result = transpile_ring_kernel(&handler, &config);
        assert!(result.is_ok(), "Should transpile K2K kernel: {:?}", result);

        let cuda = result.unwrap();

        // K2K struct definitions.
        assert!(
            cuda.contains("K2KRoutingTable"),
            "Should have K2K routing table"
        );
        assert!(cuda.contains("K2KRoute"), "Should have K2K route struct");

        // K2K kernel parameters.
        assert!(cuda.contains("k2k_routes"), "Should have k2k_routes param");
        assert!(cuda.contains("k2k_inbox"), "Should have k2k_inbox param");
        assert!(cuda.contains("k2k_outbox"), "Should have k2k_outbox param");

        println!("Generated K2K ring kernel:\n{}", cuda);
    }

    // Pins the documented defaults of `RingKernelConfig::default()`.
    #[test]
    fn test_ring_kernel_config_defaults() {
        let config = RingKernelConfig::default();
        assert_eq!(config.block_size, 128);
        assert_eq!(config.queue_capacity, 1024);
        assert!(config.enable_hlc);
        assert!(!config.enable_k2k);
    }

    // Every ring-kernel intrinsic name used by handlers must resolve.
    #[test]
    fn test_ring_kernel_intrinsic_availability() {
        assert!(RingKernelIntrinsic::from_name("is_active").is_some());
        assert!(RingKernelIntrinsic::from_name("should_terminate").is_some());
        assert!(RingKernelIntrinsic::from_name("hlc_tick").is_some());
        assert!(RingKernelIntrinsic::from_name("enqueue_response").is_some());
        assert!(RingKernelIntrinsic::from_name("k2k_send").is_some());
        assert!(RingKernelIntrinsic::from_name("nanosleep").is_some());
    }

    // `HandlerSignature::parse` recognises the context param, the message
    // param and a non-unit response type.
    #[test]
    fn test_handler_signature_parsing() {
        let func: syn::ItemFn = parse_quote! {
            fn handle(ctx: &RingContext, msg: &MyMessage) -> MyResponse {
                MyResponse { value: msg.value * 2.0 }
            }
        };

        let mapper = TypeMapper::new();
        let sig = HandlerSignature::parse(&func, &mapper).unwrap();

        assert_eq!(sig.name, "handle");
        assert!(sig.has_context);
        assert!(sig.message_param.is_some());
        assert!(sig.has_response());
    }

    // `ctx.*` method calls in a handler are lowered to CUDA built-ins.
    #[test]
    fn test_handler_with_context_methods() {
        let handler: syn::ItemFn = parse_quote! {
            fn process(ctx: &RingContext, value: f32) -> f32 {
                let tid = ctx.thread_id();
                let result = value * 2.0;
                ctx.sync_threads();
                result
            }
        };

        let config = RingKernelConfig::new("with_context")
            .with_block_size(128)
            .with_hlc(true);

        let result = transpile_ring_kernel(&handler, &config);
        assert!(
            result.is_ok(),
            "Should transpile handler with context: {:?}",
            result
        );

        let cuda = result.unwrap();

        // NOTE(review): the ring-kernel preamble likely already contains
        // `threadIdx.x` (see `test_ring_kernel_transpile`), so this assertion
        // probably passes even if `ctx.thread_id()` were not lowered. Consider
        // asserting on the lowered handler statement instead.
        assert!(
            cuda.contains("threadIdx.x"),
            "ctx.thread_id() should become threadIdx.x"
        );
        assert!(
            cuda.contains("__syncthreads()"),
            "ctx.sync_threads() should become __syncthreads()"
        );

        println!("Generated handler with context:\n{}", cuda);
    }

    // A handler taking a message reference plus a scalar still transpiles,
    // and the message type name appears in the output.
    #[test]
    fn test_handler_with_message_param() {
        let handler: syn::ItemFn = parse_quote! {
            fn process_msg(msg: &Message, scale: f32) -> f32 {
                msg.value * scale
            }
        };

        let config = RingKernelConfig::new("msg_handler");
        let result = transpile_ring_kernel(&handler, &config);
        assert!(
            result.is_ok(),
            "Should transpile handler with message: {:?}",
            result
        );

        let cuda = result.unwrap();
        assert!(cuda.contains("Message"), "Should reference Message type");

        println!("Generated message handler:\n{}", cuda);
    }

    // Name lookup for context methods, and the exact CUDA text for a few.
    #[test]
    fn test_context_method_mappings() {
        assert!(ContextMethod::from_name("thread_id").is_some());
        assert!(ContextMethod::from_name("sync_threads").is_some());
        assert!(ContextMethod::from_name("global_thread_id").is_some());
        assert!(ContextMethod::from_name("atomic_add").is_some());
        assert!(ContextMethod::from_name("lane_id").is_some());
        assert!(ContextMethod::from_name("warp_id").is_some());

        assert_eq!(ContextMethod::ThreadId.to_cuda(&[]), "threadIdx.x");
        assert_eq!(ContextMethod::SyncThreads.to_cuda(&[]), "__syncthreads()");
        assert_eq!(
            ContextMethod::GlobalThreadId.to_cuda(&[]),
            "(blockIdx.x * blockDim.x + threadIdx.x)"
        );
    }

    // Registered message/response types are emitted as CUDA structs with
    // their declared field types.
    #[test]
    fn test_message_type_registration() {
        let mut registry = MessageTypeRegistry::new();

        registry.register_message(MessageTypeInfo {
            name: "InputMsg".to_string(),
            size: 16,
            fields: vec![
                ("id".to_string(), "unsigned long long".to_string()),
                ("value".to_string(), "float".to_string()),
            ],
        });

        registry.register_response(MessageTypeInfo {
            name: "OutputMsg".to_string(),
            size: 8,
            fields: vec![("result".to_string(), "float".to_string())],
        });

        let structs = registry.generate_structs();
        assert!(structs.contains("struct InputMsg"));
        assert!(structs.contains("struct OutputMsg"));
        assert!(structs.contains("unsigned long long id"));
        assert!(structs.contains("float result"));
    }

    // Full handler (context + message + struct response) with K2K disabled.
    #[test]
    fn test_full_handler_integration() {
        let handler: syn::ItemFn = parse_quote! {
            fn full_handler(ctx: &RingContext, msg: &Request) -> Response {
                let tid = ctx.global_thread_id();
                ctx.sync_threads();
                let result = msg.value * 2.0;
                Response { value: result, id: tid as u64 }
            }
        };

        let config = RingKernelConfig::new("full")
            .with_block_size(256)
            .with_queue_capacity(2048)
            .with_hlc(true)
            .with_k2k(false);

        let result = transpile_ring_kernel(&handler, &config);
        assert!(
            result.is_ok(),
            "Should transpile full handler: {:?}",
            result
        );

        let cuda = result.unwrap();

        assert!(cuda.contains("ring_kernel_full"), "Kernel name");
        assert!(cuda.contains("ControlBlock"), "ControlBlock struct");
        assert!(cuda.contains("while (true)"), "Persistent loop");
        assert!(cuda.contains("threadIdx.x"), "Thread index");
        assert!(cuda.contains("__syncthreads()"), "Sync threads");
        // Operand order matches `ContextMethod::GlobalThreadId`, which differs
        // from the preamble's `threadIdx.x + blockIdx.x * blockDim.x`.
        assert!(
            cuda.contains("blockIdx.x * blockDim.x + threadIdx.x"),
            "Global thread ID"
        );
        assert!(cuda.contains("has_terminated"), "Termination marking");

        println!("Full handler integration:\n{}", cuda);
    }

    // Full handler with K2K enabled: structs, params and K2K helper functions.
    #[test]
    fn test_k2k_handler_integration() {
        let handler: syn::ItemFn = parse_quote! {
            fn k2k_handler(ctx: &RingContext, msg: &InputMsg) -> OutputMsg {
                let tid = ctx.global_thread_id();

                let result = msg.value * 2.0;

                OutputMsg { result: result, source_id: tid as u64 }
            }
        };

        let config = RingKernelConfig::new("k2k_processor")
            .with_block_size(128)
            .with_queue_capacity(1024)
            .with_hlc(true)
            .with_k2k(true);

        let result = transpile_ring_kernel(&handler, &config);
        assert!(result.is_ok(), "Should transpile K2K handler: {:?}", result);

        let cuda = result.unwrap();

        // K2K struct definitions.
        assert!(
            cuda.contains("K2KRoutingTable"),
            "Should have K2KRoutingTable"
        );
        assert!(cuda.contains("K2KRoute"), "Should have K2KRoute struct");
        assert!(
            cuda.contains("K2KInboxHeader"),
            "Should have K2KInboxHeader"
        );
        // K2K kernel parameters.
        assert!(cuda.contains("k2k_routes"), "Should have k2k_routes param");
        assert!(cuda.contains("k2k_inbox"), "Should have k2k_inbox param");
        assert!(cuda.contains("k2k_outbox"), "Should have k2k_outbox param");
        // K2K helper functions.
        assert!(cuda.contains("k2k_send"), "Should have k2k_send function");
        assert!(
            cuda.contains("k2k_try_recv"),
            "Should have k2k_try_recv function"
        );
        assert!(cuda.contains("k2k_peek"), "Should have k2k_peek function");
        assert!(
            cuda.contains("k2k_pending_count"),
            "Should have k2k_pending_count function"
        );

        println!("K2K handler integration:\n{}", cuda);
    }

    // Smoke test transpiling one kernel of each kind (stencil, global, ring)
    // and checking a distinguishing feature of each.
    #[test]
    fn test_all_kernel_types_comparison() {
        let stencil_func: syn::ItemFn = parse_quote! {
            fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
                let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * p[pos.idx()];
                p_prev[pos.idx()] = 2.0 * p[pos.idx()] - p_prev[pos.idx()] + c2 * lap;
            }
        };

        let stencil_config = StencilConfig::new("fdtd")
            .with_tile_size(16, 16)
            .with_halo(1);

        let stencil_cuda = transpile_stencil_kernel(&stencil_func, &stencil_config).unwrap();
        assert!(
            !stencil_cuda.contains("GridPos"),
            "Stencil should remove GridPos"
        );
        assert!(
            stencil_cuda.contains("buffer_width"),
            "Stencil should have buffer_width"
        );

        let global_func: syn::ItemFn = parse_quote! {
            fn saxpy(x: &[f32], y: &mut [f32], a: f32, n: i32) {
                let idx = block_idx_x() * block_dim_x() + thread_idx_x();
                if idx >= n { return; }
                y[idx as usize] = a * x[idx as usize] + y[idx as usize];
            }
        };

        let global_cuda = transpile_global_kernel(&global_func).unwrap();
        assert!(global_cuda.contains("__global__"), "Global kernel marker");
        assert!(global_cuda.contains("blockIdx.x"), "CUDA block index");

        let ring_func: syn::ItemFn = parse_quote! {
            fn process(msg: f32) -> f32 {
                msg * 2.0
            }
        };

        let ring_config = RingKernelConfig::new("process")
            .with_block_size(128)
            .with_hlc(true);

        let ring_cuda = transpile_ring_kernel(&ring_func, &ring_config).unwrap();
        assert!(
            ring_cuda.contains("ControlBlock"),
            "Ring kernel ControlBlock"
        );
        assert!(ring_cuda.contains("while (true)"), "Persistent loop");
        assert!(ring_cuda.contains("has_terminated"), "Termination");

        println!("=== Stencil Kernel ===\n{}\n", stencil_cuda);
        println!("=== Global Kernel ===\n{}\n", global_cuda);
        println!("=== Ring Kernel ===\n{}\n", ring_cuda);
    }
}