// ringkernel_wgpu_codegen/ring_kernel.rs

//! Ring kernel generation for WGSL.
//!
//! Generates WGSL compute shaders for ring kernel message processing.
//! Note: WebGPU does not support true persistent kernels, so ring kernels
//! are emulated using host-driven dispatch loops.
6
/// Configuration for ring kernel generation.
///
/// Create one with [`RingKernelConfig::new`] and refine it with the chained
/// `with_*` methods. Every `with_*` method takes the configuration by value
/// and returns the updated value, so the result must be kept.
#[derive(Debug, Clone)]
pub struct RingKernelConfig {
    /// Kernel name.
    pub name: String,
    /// Workgroup size (number of threads).
    ///
    /// The value is not validated here; it is emitted verbatim by
    /// [`RingKernelConfig::workgroup_size_annotation`].
    pub workgroup_size: u32,
    /// Enable hybrid logical clock support.
    pub enable_hlc: bool,
    /// Enable kernel-to-kernel messaging (NOT SUPPORTED in WGPU).
    pub enable_k2k: bool,
    /// Maximum messages per dispatch.
    pub max_messages_per_dispatch: u32,
}

impl RingKernelConfig {
    /// Create a new ring kernel configuration.
    ///
    /// Defaults: workgroup size 256, HLC disabled, K2K disabled, and at most
    /// 1024 messages per dispatch.
    #[must_use]
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            workgroup_size: 256,
            enable_hlc: false,
            enable_k2k: false,
            max_messages_per_dispatch: 1024,
        }
    }

    /// Set workgroup size.
    #[must_use = "builder methods return the updated configuration instead of mutating in place"]
    pub fn with_workgroup_size(mut self, size: u32) -> Self {
        self.workgroup_size = size;
        self
    }

    /// Enable HLC support.
    #[must_use = "builder methods return the updated configuration instead of mutating in place"]
    pub fn with_hlc(mut self, enable: bool) -> Self {
        self.enable_hlc = enable;
        self
    }

    /// Enable K2K support (will error during transpilation - not supported in WGPU).
    #[must_use = "builder methods return the updated configuration instead of mutating in place"]
    pub fn with_k2k(mut self, enable: bool) -> Self {
        self.enable_k2k = enable;
        self
    }

    /// Set maximum messages per dispatch.
    #[must_use = "builder methods return the updated configuration instead of mutating in place"]
    pub fn with_max_messages(mut self, max: u32) -> Self {
        self.max_messages_per_dispatch = max;
        self
    }

    /// Get the workgroup size annotation.
    ///
    /// Returns `@workgroup_size(N, 1, 1)` where `N` is `workgroup_size`; the
    /// second and third dimensions are always 1.
    #[must_use]
    pub fn workgroup_size_annotation(&self) -> String {
        format!("@workgroup_size({}, 1, 1)", self.workgroup_size)
    }
}
63
64/// Generate the WGSL ControlBlock struct definition.
65pub fn generate_control_block_struct(config: &RingKernelConfig) -> String {
66    let mut fields = vec![
67        "    is_active: atomic<u32>,".to_string(),
68        "    should_terminate: atomic<u32>,".to_string(),
69        "    has_terminated: atomic<u32>,".to_string(),
70        "    // 64-bit counters as lo/hi pairs".to_string(),
71        "    messages_processed_lo: atomic<u32>,".to_string(),
72        "    messages_processed_hi: atomic<u32>,".to_string(),
73        "    messages_pending_lo: atomic<u32>,".to_string(),
74        "    messages_pending_hi: atomic<u32>,".to_string(),
75    ];
76
77    if config.enable_hlc {
78        fields.push("    // HLC timestamp".to_string());
79        fields.push("    hlc_physical_lo: atomic<u32>,".to_string());
80        fields.push("    hlc_physical_hi: atomic<u32>,".to_string());
81        fields.push("    hlc_logical: atomic<u32>,".to_string());
82    }
83
84    format!("struct ControlBlock {{\n{}\n}}", fields.join("\n"))
85}
86
/// Generate the 64-bit helper functions.
///
/// Returns WGSL source defining four helpers that treat a 64-bit value as a
/// low/high pair of `u32` words:
///
/// - `read_u64` loads both words and returns them as `vec2<u32>(lo, hi)`.
/// - `atomic_inc_u64` adds 1 to the low word and carries into the high word
///   when the previous low word was `0xFFFFFFFF`.
/// - `atomic_add_u64` adds a `u32` addend to the low word and carries into
///   the high word when that addition wrapped.
/// - `compare_u64` orders two `vec2<u32>` values by high word (`y`) first,
///   then low word (`x`), returning `1`, `-1` or `0`.
///
/// Each helper touches the low and high words with separate atomic
/// operations, so the pair is not read or updated as a single atomic unit.
///
/// NOTE(review): the helpers take `ptr<storage, ...>` parameters. WGSL only
/// permits storage-space pointer parameters on user-declared functions when
/// the `unrestricted_pointer_parameters` language feature is available;
/// confirm the target shader compiler accepts this.
pub fn generate_u64_helpers() -> &'static str {
    concat!(
        "\n// 64-bit operations using lo/hi u32 pairs\n",
        // Non-tearing-safe read of both halves.
        r#"fn read_u64(lo: ptr<storage, atomic<u32>, read_write>, hi: ptr<storage, atomic<u32>, read_write>) -> vec2<u32> {
    return vec2<u32>(atomicLoad(lo), atomicLoad(hi));
}

"#,
        // Increment with carry.
        r#"fn atomic_inc_u64(lo: ptr<storage, atomic<u32>, read_write>, hi: ptr<storage, atomic<u32>, read_write>) {
    let old_lo = atomicAdd(lo, 1u);
    if (old_lo == 0xFFFFFFFFu) {
        atomicAdd(hi, 1u);
    }
}

"#,
        // Add a 32-bit value with carry.
        r#"fn atomic_add_u64(lo: ptr<storage, atomic<u32>, read_write>, hi: ptr<storage, atomic<u32>, read_write>, addend: u32) {
    let old_lo = atomicAdd(lo, addend);
    if (old_lo > 0xFFFFFFFFu - addend) {
        atomicAdd(hi, 1u);
    }
}

"#,
        // Three-way comparison, high word first.
        r#"fn compare_u64(a: vec2<u32>, b: vec2<u32>) -> i32 {
    if (a.y > b.y) { return 1; }
    if (a.y < b.y) { return -1; }
    if (a.x > b.x) { return 1; }
    if (a.x < b.x) { return -1; }
    return 0;
}
"#,
    )
}
118
/// Generate standard ring kernel bindings.
///
/// Returns three WGSL declarations, all in bind group 0 and all
/// `var<storage, read_write>`:
///
/// - binding 0: `control`, of type `ControlBlock`
/// - binding 1: `input_queue`, an `array<u32>`
/// - binding 2: `output_queue`, an `array<u32>`
///
/// The returned text has no trailing newline.
pub fn generate_ring_kernel_bindings() -> &'static str {
    concat!(
        "@group(0) @binding(0) var<storage, read_write> control: ControlBlock;\n",
        "@group(0) @binding(1) var<storage, read_write> input_queue: array<u32>;\n",
        "@group(0) @binding(2) var<storage, read_write> output_queue: array<u32>;",
    )
}
125
/// Generate the ring kernel preamble (activation/termination checks).
///
/// Returns a WGSL statement sequence, indented by four spaces, that:
///
/// 1. returns immediately when `control.is_active` loads as zero;
/// 2. when `control.should_terminate` loads as non-zero, has the invocation
///    with `local_invocation_id.x == 0` store 1 into
///    `control.has_terminated`, then returns.
///
/// The snippet refers to `control` and `local_invocation_id`, so both must be
/// in scope wherever it is inserted. The returned text has no trailing
/// newline.
pub fn generate_ring_kernel_preamble() -> &'static str {
    concat!(
        "    // Check if kernel is active\n",
        "    if (atomicLoad(&control.is_active) == 0u) {\n",
        "        return;\n",
        "    }\n",
        "\n",
        "    // Check for termination request\n",
        "    if (atomicLoad(&control.should_terminate) != 0u) {\n",
        "        if (local_invocation_id.x == 0u) {\n",
        "            atomicStore(&control.has_terminated, 1u);\n",
        "        }\n",
        "        return;\n",
        "    }",
    )
}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder calls override only the fields they name.
    #[test]
    fn test_ring_kernel_config() {
        let config = RingKernelConfig::new("processor")
            .with_workgroup_size(128)
            .with_hlc(true);

        assert_eq!(config.name, "processor");
        assert_eq!(config.workgroup_size, 128);
        assert!(config.enable_hlc);
        assert!(!config.enable_k2k);
    }

    /// `new` applies the documented defaults.
    #[test]
    fn test_ring_kernel_config_defaults() {
        let config = RingKernelConfig::new("defaults");

        assert_eq!(config.name, "defaults");
        assert_eq!(config.workgroup_size, 256);
        assert!(!config.enable_hlc);
        assert!(!config.enable_k2k);
        assert_eq!(config.max_messages_per_dispatch, 1024);
    }

    /// The K2K and max-message builders set their respective fields.
    #[test]
    fn test_ring_kernel_config_k2k_and_max_messages() {
        let config = RingKernelConfig::new("test")
            .with_k2k(true)
            .with_max_messages(32);

        assert!(config.enable_k2k);
        assert_eq!(config.max_messages_per_dispatch, 32);
        // Untouched fields keep their defaults.
        assert_eq!(config.workgroup_size, 256);
        assert!(!config.enable_hlc);
    }

    /// With HLC enabled the control block carries the HLC fields.
    #[test]
    fn test_control_block_generation() {
        let config = RingKernelConfig::new("test").with_hlc(true);
        let wgsl = generate_control_block_struct(&config);

        assert!(wgsl.contains("is_active: atomic<u32>"));
        assert!(wgsl.contains("should_terminate: atomic<u32>"));
        assert!(wgsl.contains("hlc_physical_lo: atomic<u32>"));
        assert!(wgsl.contains("hlc_physical_hi: atomic<u32>"));
        assert!(wgsl.contains("hlc_logical: atomic<u32>"));
    }

    /// With HLC disabled the control block omits every HLC field.
    #[test]
    fn test_control_block_without_hlc() {
        let config = RingKernelConfig::new("test");
        let wgsl = generate_control_block_struct(&config);

        assert!(wgsl.starts_with("struct ControlBlock {\n"));
        assert!(wgsl.ends_with("\n}"));
        assert!(wgsl.contains("messages_pending_hi: atomic<u32>"));
        assert!(!wgsl.contains("hlc_"));
    }

    #[test]
    fn test_workgroup_size_annotation() {
        let config = RingKernelConfig::new("test").with_workgroup_size(64);
        assert_eq!(
            config.workgroup_size_annotation(),
            "@workgroup_size(64, 1, 1)"
        );
    }

    /// Every 64-bit helper is present in the generated WGSL.
    #[test]
    fn test_u64_helpers_define_all_functions() {
        let wgsl = generate_u64_helpers();
        let expected = ["read_u64", "atomic_inc_u64", "atomic_add_u64", "compare_u64"];

        for name in expected.iter() {
            assert!(
                wgsl.contains(&format!("fn {}(", name)),
                "missing helper `{}`",
                name
            );
        }
    }

    /// The bindings declare the control block and both queues in group 0.
    #[test]
    fn test_ring_kernel_bindings() {
        let wgsl = generate_ring_kernel_bindings();

        assert_eq!(wgsl.lines().count(), 3);
        assert!(wgsl.contains("@group(0) @binding(0) var<storage, read_write> control: ControlBlock;"));
        assert!(wgsl.contains("@group(0) @binding(1) var<storage, read_write> input_queue: array<u32>;"));
        assert!(wgsl.contains("@group(0) @binding(2) var<storage, read_write> output_queue: array<u32>;"));
    }

    /// The preamble checks activation before termination.
    #[test]
    fn test_ring_kernel_preamble() {
        let wgsl = generate_ring_kernel_preamble();

        let active_check = wgsl
            .find("atomicLoad(&control.is_active) == 0u")
            .expect("preamble must check is_active");
        let terminate_check = wgsl
            .find("atomicLoad(&control.should_terminate) != 0u")
            .expect("preamble must check should_terminate");

        assert!(active_check < terminate_check);
        assert!(wgsl.contains("atomicStore(&control.has_terminated, 1u);"));
    }
}