1pub mod dsl;
40pub mod handler;
41mod intrinsics;
42pub mod loops;
43pub mod ring_kernel;
44pub mod shared;
45mod stencil;
46mod transpiler;
47mod types;
48mod validation;
49
50pub use handler::{
51 generate_cuda_struct, generate_message_deser, generate_response_ser, ContextMethod,
52 HandlerCodegenConfig, HandlerParam, HandlerParamKind, HandlerReturnType, HandlerSignature,
53 MessageTypeInfo, MessageTypeRegistry,
54};
55pub use intrinsics::{GpuIntrinsic, IntrinsicRegistry, RingKernelIntrinsic, StencilIntrinsic};
56pub use loops::{LoopPattern, RangeInfo};
57pub use ring_kernel::{
58 generate_control_block_struct, generate_hlc_struct, generate_k2k_structs, RingKernelConfig,
59};
60pub use shared::{SharedArray, SharedMemoryConfig, SharedMemoryDecl, SharedTile};
61pub use stencil::{Grid, GridPos, StencilConfig, StencilLaunchConfig};
62pub use transpiler::{transpile_function, CudaTranspiler, SharedVarInfo};
63pub use types::{
64 get_slice_element_type, is_control_block_type, is_mutable_reference, is_ring_context_type,
65 ring_kernel_type_mapper, CudaType, RingKernelParamKind, TypeMapper,
66};
67pub use validation::{
68 is_simple_assignment, validate_function, validate_function_with_mode,
69 validate_stencil_signature, ValidationError, ValidationMode,
70};
71
72use thiserror::Error;
73
74#[derive(Error, Debug)]
76pub enum TranspileError {
77 #[error("Parse error: {0}")]
79 Parse(String),
80
81 #[error("Validation error: {0}")]
83 Validation(#[from] ValidationError),
84
85 #[error("Unsupported construct: {0}")]
87 Unsupported(String),
88
89 #[error("Type error: {0}")]
91 Type(String),
92}
93
94pub type Result<T> = std::result::Result<T, TranspileError>;
96
97pub fn transpile_stencil_kernel(func: &syn::ItemFn, config: &StencilConfig) -> Result<String> {
112 validate_function(func)?;
114
115 let mut transpiler = CudaTranspiler::new(config.clone());
117
118 transpiler.transpile_stencil(func)
120}
121
122pub fn transpile_device_function(func: &syn::ItemFn) -> Result<String> {
126 validate_function(func)?;
127 transpile_function(func)
128}
129
130pub fn transpile_global_kernel(func: &syn::ItemFn) -> Result<String> {
155 validate_function(func)?;
156 let mut transpiler = CudaTranspiler::new_generic();
157 transpiler.transpile_generic_kernel(func)
158}
159
160pub fn transpile_ring_kernel(handler: &syn::ItemFn, config: &RingKernelConfig) -> Result<String> {
185 validate_function_with_mode(handler, ValidationMode::Generic)?;
188
189 let mut transpiler = CudaTranspiler::with_mode(ValidationMode::Generic);
191
192 transpiler.transpile_ring_kernel(handler, config)
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Unwraps a transpilation result.
    ///
    /// On error the test fails with `context`, a colon, and the
    /// debug-printed result.
    fn expect_cuda<E: std::fmt::Debug>(
        outcome: std::result::Result<String, E>,
        context: &str,
    ) -> String {
        assert!(outcome.is_ok(), "{}: {:?}", context, outcome);
        outcome.unwrap()
    }

    /// Checks every `(needle, message)` pair in order and fails with
    /// `message` for the first needle missing from `generated`.
    fn assert_contains_all(generated: &str, expectations: &[(&str, &str)]) {
        for &(needle, message) in expectations {
            assert!(generated.contains(needle), "{}", message);
        }
    }

    #[test]
    fn test_simple_function_transpile() {
        let add_fn: syn::ItemFn = parse_quote! {
            fn add(a: f32, b: f32) -> f32 {
                a + b
            }
        };

        let cuda = expect_cuda(
            transpile_device_function(&add_fn),
            "Should transpile simple function",
        );

        assert_contains_all(
            &cuda,
            &[
                ("float", "Should contain CUDA float type"),
                ("a + b", "Should contain the expression"),
            ],
        );
    }

    #[test]
    fn test_global_kernel_transpile() {
        let kernel_fn: syn::ItemFn = parse_quote! {
            fn exchange_halos(buffer: &mut [f32], copies: &[u32], num_copies: i32) {
                let idx = block_idx_x() * block_dim_x() + thread_idx_x();
                if idx >= num_copies {
                    return;
                }
                let src = copies[idx * 2] as usize;
                let dst = copies[idx * 2 + 1] as usize;
                buffer[dst] = buffer[src];
            }
        };

        let cuda = expect_cuda(
            transpile_global_kernel(&kernel_fn),
            "Should transpile global kernel",
        );

        assert_contains_all(
            &cuda,
            &[
                ("extern \"C\" __global__", "Should be global kernel"),
                ("exchange_halos", "Should have kernel name"),
                ("blockIdx.x", "Should contain blockIdx.x"),
                ("blockDim.x", "Should contain blockDim.x"),
                ("threadIdx.x", "Should contain threadIdx.x"),
                ("return", "Should have early return"),
            ],
        );

        println!("Generated global kernel:\n{}", cuda);
    }

    #[test]
    fn test_stencil_kernel_transpile() {
        let stencil_fn: syn::ItemFn = parse_quote! {
            fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
                let curr = p[pos.idx()];
                let prev = p_prev[pos.idx()];
                let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
                p_prev[pos.idx()] = (2.0 * curr - prev + c2 * lap);
            }
        };

        let stencil_config = StencilConfig {
            id: "fdtd".to_string(),
            grid: Grid::Grid2D,
            tile_size: (16, 16),
            halo: 1,
        };

        let cuda = expect_cuda(
            transpile_stencil_kernel(&stencil_fn, &stencil_config),
            "Should transpile stencil kernel",
        );

        assert_contains_all(
            &cuda,
            &[
                ("__global__", "Should be a CUDA kernel"),
                ("threadIdx", "Should use thread indices"),
            ],
        );
    }

    #[test]
    fn test_ring_kernel_transpile() {
        let handler_fn: syn::ItemFn = parse_quote! {
            fn process(value: f32) -> f32 {
                let result = value * 2.0;
                result
            }
        };

        let ring_config = RingKernelConfig::new("processor")
            .with_block_size(128)
            .with_queue_capacity(1024)
            .with_hlc(true);

        let cuda = expect_cuda(
            transpile_ring_kernel(&handler_fn, &ring_config),
            "Should transpile ring kernel",
        );

        assert_contains_all(
            &cuda,
            &[
                // Control block struct and its fields.
                (
                    "struct __align__(128) ControlBlock",
                    "Should have ControlBlock struct",
                ),
                ("is_active", "Should have is_active field"),
                ("should_terminate", "Should have should_terminate field"),
                ("messages_processed", "Should have messages_processed field"),
                // Kernel signature.
                (
                    "extern \"C\" __global__ void ring_kernel_processor",
                    "Should have correct kernel name",
                ),
                (
                    "ControlBlock* __restrict__ control",
                    "Should have control block param",
                ),
                ("input_buffer", "Should have input buffer param"),
                ("output_buffer", "Should have output buffer param"),
                // Preamble.
                (
                    "int tid = threadIdx.x + blockIdx.x * blockDim.x",
                    "Should have thread id calculation",
                ),
                ("MSG_SIZE", "Should have message size constant"),
                ("hlc_physical", "Should have HLC variables"),
                ("hlc_logical", "Should have HLC logical counter"),
                // Message loop.
                ("while (true)", "Should have persistent loop"),
                (
                    "atomicAdd(&control->should_terminate, 0)",
                    "Should check termination",
                ),
                (
                    "atomicAdd(&control->is_active, 0)",
                    "Should check is_active",
                ),
                // Embedded handler body.
                ("// === USER HANDLER CODE ===", "Should have handler marker"),
                ("value * 2.0", "Should contain handler logic"),
                ("// === END HANDLER CODE ===", "Should have end marker"),
                // Epilogue.
                (
                    "atomicExch(&control->has_terminated, 1)",
                    "Should mark terminated",
                ),
            ],
        );

        println!("Generated ring kernel:\n{}", cuda);
    }

    #[test]
    fn test_ring_kernel_with_k2k() {
        let handler_fn: syn::ItemFn = parse_quote! {
            fn forward(msg: f32) -> f32 {
                msg
            }
        };

        let ring_config = RingKernelConfig::new("forwarder")
            .with_block_size(64)
            .with_k2k(true)
            .with_hlc(true);

        let cuda = expect_cuda(
            transpile_ring_kernel(&handler_fn, &ring_config),
            "Should transpile K2K kernel",
        );

        assert_contains_all(
            &cuda,
            &[
                ("K2KRoutingTable", "Should have K2K routing table"),
                ("K2KRoute", "Should have K2K route struct"),
                ("k2k_routes", "Should have k2k_routes param"),
                ("k2k_inbox", "Should have k2k_inbox param"),
                ("k2k_outbox", "Should have k2k_outbox param"),
            ],
        );

        println!("Generated K2K ring kernel:\n{}", cuda);
    }

    #[test]
    fn test_ring_kernel_config_defaults() {
        // Pins the values produced by `RingKernelConfig::default()`.
        let config = RingKernelConfig::default();

        assert_eq!(config.block_size, 128);
        assert_eq!(config.queue_capacity, 1024);
        assert!(config.enable_hlc);
        assert!(!config.enable_k2k);
    }

    #[test]
    fn test_ring_kernel_intrinsic_availability() {
        // One assert per name so a failure identifies the missing intrinsic.
        assert!(RingKernelIntrinsic::from_name("is_active").is_some());
        assert!(RingKernelIntrinsic::from_name("should_terminate").is_some());
        assert!(RingKernelIntrinsic::from_name("hlc_tick").is_some());
        assert!(RingKernelIntrinsic::from_name("enqueue_response").is_some());
        assert!(RingKernelIntrinsic::from_name("k2k_send").is_some());
        assert!(RingKernelIntrinsic::from_name("nanosleep").is_some());
    }

    #[test]
    fn test_handler_signature_parsing() {
        let handler_fn: syn::ItemFn = parse_quote! {
            fn handle(ctx: &RingContext, msg: &MyMessage) -> MyResponse {
                MyResponse { value: msg.value * 2.0 }
            }
        };

        let type_mapper = TypeMapper::new();
        let sig = HandlerSignature::parse(&handler_fn, &type_mapper).unwrap();

        assert_eq!(sig.name, "handle");
        assert!(sig.has_context);
        assert!(sig.message_param.is_some());
        assert!(sig.has_response());
    }

    #[test]
    fn test_handler_with_context_methods() {
        let handler_fn: syn::ItemFn = parse_quote! {
            fn process(ctx: &RingContext, value: f32) -> f32 {
                let tid = ctx.thread_id();
                let result = value * 2.0;
                ctx.sync_threads();
                result
            }
        };

        let ring_config = RingKernelConfig::new("with_context")
            .with_block_size(128)
            .with_hlc(true);

        let cuda = expect_cuda(
            transpile_ring_kernel(&handler_fn, &ring_config),
            "Should transpile handler with context",
        );

        assert_contains_all(
            &cuda,
            &[
                ("threadIdx.x", "ctx.thread_id() should become threadIdx.x"),
                (
                    "__syncthreads()",
                    "ctx.sync_threads() should become __syncthreads()",
                ),
            ],
        );

        println!("Generated handler with context:\n{}", cuda);
    }

    #[test]
    fn test_handler_with_message_param() {
        let handler_fn: syn::ItemFn = parse_quote! {
            fn process_msg(msg: &Message, scale: f32) -> f32 {
                msg.value * scale
            }
        };

        let ring_config = RingKernelConfig::new("msg_handler");

        let cuda = expect_cuda(
            transpile_ring_kernel(&handler_fn, &ring_config),
            "Should transpile handler with message",
        );

        assert_contains_all(&cuda, &[("Message", "Should reference Message type")]);

        println!("Generated message handler:\n{}", cuda);
    }

    #[test]
    fn test_context_method_mappings() {
        // Lookup by name.
        assert!(ContextMethod::from_name("thread_id").is_some());
        assert!(ContextMethod::from_name("sync_threads").is_some());
        assert!(ContextMethod::from_name("global_thread_id").is_some());
        assert!(ContextMethod::from_name("atomic_add").is_some());
        assert!(ContextMethod::from_name("lane_id").is_some());
        assert!(ContextMethod::from_name("warp_id").is_some());

        // CUDA text emitted for argument-free methods.
        assert_eq!(ContextMethod::ThreadId.to_cuda(&[]), "threadIdx.x");
        assert_eq!(ContextMethod::SyncThreads.to_cuda(&[]), "__syncthreads()");
        assert_eq!(
            ContextMethod::GlobalThreadId.to_cuda(&[]),
            "(blockIdx.x * blockDim.x + threadIdx.x)"
        );
    }

    #[test]
    fn test_message_type_registration() {
        let mut registry = MessageTypeRegistry::new();

        let input_msg = MessageTypeInfo {
            name: String::from("InputMsg"),
            size: 16,
            fields: vec![
                (String::from("id"), String::from("unsigned long long")),
                (String::from("value"), String::from("float")),
            ],
        };
        registry.register_message(input_msg);

        let output_msg = MessageTypeInfo {
            name: String::from("OutputMsg"),
            size: 8,
            fields: vec![(String::from("result"), String::from("float"))],
        };
        registry.register_response(output_msg);

        let structs = registry.generate_structs();
        assert!(structs.contains("struct InputMsg"));
        assert!(structs.contains("struct OutputMsg"));
        assert!(structs.contains("unsigned long long id"));
        assert!(structs.contains("float result"));
    }

    #[test]
    fn test_full_handler_integration() {
        let handler_fn: syn::ItemFn = parse_quote! {
            fn full_handler(ctx: &RingContext, msg: &Request) -> Response {
                let tid = ctx.global_thread_id();
                ctx.sync_threads();
                let result = msg.value * 2.0;
                Response { value: result, id: tid as u64 }
            }
        };

        let ring_config = RingKernelConfig::new("full")
            .with_block_size(256)
            .with_queue_capacity(2048)
            .with_hlc(true)
            .with_k2k(false);

        let cuda = expect_cuda(
            transpile_ring_kernel(&handler_fn, &ring_config),
            "Should transpile full handler",
        );

        assert_contains_all(
            &cuda,
            &[
                ("ring_kernel_full", "Kernel name"),
                ("ControlBlock", "ControlBlock struct"),
                ("while (true)", "Persistent loop"),
                ("threadIdx.x", "Thread index"),
                ("__syncthreads()", "Sync threads"),
                ("blockIdx.x * blockDim.x + threadIdx.x", "Global thread ID"),
                ("has_terminated", "Termination marking"),
            ],
        );

        println!("Full handler integration:\n{}", cuda);
    }

    #[test]
    fn test_k2k_handler_integration() {
        let handler_fn: syn::ItemFn = parse_quote! {
            fn k2k_handler(ctx: &RingContext, msg: &InputMsg) -> OutputMsg {
                let tid = ctx.global_thread_id();

                let result = msg.value * 2.0;

                OutputMsg { result: result, source_id: tid as u64 }
            }
        };

        let ring_config = RingKernelConfig::new("k2k_processor")
            .with_block_size(128)
            .with_queue_capacity(1024)
            .with_hlc(true)
            .with_k2k(true);

        let cuda = expect_cuda(
            transpile_ring_kernel(&handler_fn, &ring_config),
            "Should transpile K2K handler",
        );

        assert_contains_all(
            &cuda,
            &[
                ("K2KRoutingTable", "Should have K2KRoutingTable"),
                ("K2KRoute", "Should have K2KRoute struct"),
                ("K2KInboxHeader", "Should have K2KInboxHeader"),
                ("k2k_routes", "Should have k2k_routes param"),
                ("k2k_inbox", "Should have k2k_inbox param"),
                ("k2k_outbox", "Should have k2k_outbox param"),
                ("k2k_send", "Should have k2k_send function"),
                ("k2k_try_recv", "Should have k2k_try_recv function"),
                ("k2k_peek", "Should have k2k_peek function"),
                ("k2k_pending_count", "Should have k2k_pending_count function"),
            ],
        );

        println!("K2K handler integration:\n{}", cuda);
    }

    #[test]
    fn test_all_kernel_types_comparison() {
        // --- Stencil kernel ---
        let stencil_func: syn::ItemFn = parse_quote! {
            fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
                let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * p[pos.idx()];
                p_prev[pos.idx()] = 2.0 * p[pos.idx()] - p_prev[pos.idx()] + c2 * lap;
            }
        };
        let stencil_config = StencilConfig::new("fdtd")
            .with_tile_size(16, 16)
            .with_halo(1);

        let stencil_cuda = transpile_stencil_kernel(&stencil_func, &stencil_config).unwrap();
        assert!(
            !stencil_cuda.contains("GridPos"),
            "Stencil should remove GridPos"
        );
        assert_contains_all(
            &stencil_cuda,
            &[("buffer_width", "Stencil should have buffer_width")],
        );

        // --- Global kernel ---
        let global_func: syn::ItemFn = parse_quote! {
            fn saxpy(x: &[f32], y: &mut [f32], a: f32, n: i32) {
                let idx = block_idx_x() * block_dim_x() + thread_idx_x();
                if idx >= n { return; }
                y[idx as usize] = a * x[idx as usize] + y[idx as usize];
            }
        };

        let global_cuda = transpile_global_kernel(&global_func).unwrap();
        assert_contains_all(
            &global_cuda,
            &[
                ("__global__", "Global kernel marker"),
                ("blockIdx.x", "CUDA block index"),
            ],
        );

        // --- Ring kernel ---
        let ring_func: syn::ItemFn = parse_quote! {
            fn process(msg: f32) -> f32 {
                msg * 2.0
            }
        };
        let ring_config = RingKernelConfig::new("process")
            .with_block_size(128)
            .with_hlc(true);

        let ring_cuda = transpile_ring_kernel(&ring_func, &ring_config).unwrap();
        assert_contains_all(
            &ring_cuda,
            &[
                ("ControlBlock", "Ring kernel ControlBlock"),
                ("while (true)", "Persistent loop"),
                ("has_terminated", "Termination"),
            ],
        );

        println!("=== Stencil Kernel ===\n{}\n", stencil_cuda);
        println!("=== Global Kernel ===\n{}\n", global_cuda);
        println!("=== Ring Kernel ===\n{}\n", ring_cuda);
    }
}