// ringkernel_wgpu_codegen/ring_kernel.rs
/// Configuration for generating a WGSL ring kernel.
///
/// Construct with [`RingKernelConfig::new`] and adjust with the `with_*`
/// builder methods, each of which consumes and returns the config.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RingKernelConfig {
    /// Kernel name, stored exactly as passed to [`RingKernelConfig::new`].
    pub name: String,
    /// Number of invocations along x in one workgroup; rendered by
    /// [`RingKernelConfig::workgroup_size_annotation`]. Default: 256.
    pub workgroup_size: u32,
    /// When `true`, `generate_control_block_struct` appends the HLC timestamp
    /// fields to the `ControlBlock` struct. Default: `false`.
    pub enable_hlc: bool,
    /// K2K feature flag. Default: `false`.
    pub enable_k2k: bool,
    /// Maximum number of messages per dispatch. Default: 1024.
    pub max_messages_per_dispatch: u32,
}

impl RingKernelConfig {
    /// Default for [`RingKernelConfig::workgroup_size`].
    const DEFAULT_WORKGROUP_SIZE: u32 = 256;
    /// Default for [`RingKernelConfig::max_messages_per_dispatch`].
    const DEFAULT_MAX_MESSAGES_PER_DISPATCH: u32 = 1024;

    /// Creates a configuration named `name` with the defaults:
    /// workgroup size 256, HLC and K2K disabled, 1024 messages per dispatch.
    #[must_use]
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            workgroup_size: Self::DEFAULT_WORKGROUP_SIZE,
            enable_hlc: false,
            enable_k2k: false,
            max_messages_per_dispatch: Self::DEFAULT_MAX_MESSAGES_PER_DISPATCH,
        }
    }

    /// Sets the number of invocations per workgroup along x.
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero. WGSL requires every `@workgroup_size`
    /// dimension to be at least 1, so a zero here would otherwise only
    /// surface later as a shader-creation error far from the cause.
    // `#[must_use]`: the builder takes `self` by value, so ignoring the
    // result silently discards the change.
    #[must_use]
    pub fn with_workgroup_size(mut self, size: u32) -> Self {
        assert!(size > 0, "workgroup size must be at least 1");
        self.workgroup_size = size;
        self
    }

    /// Enables or disables the HLC fields in the generated control block.
    #[must_use]
    pub fn with_hlc(mut self, enable: bool) -> Self {
        self.enable_hlc = enable;
        self
    }

    /// Sets the K2K feature flag.
    #[must_use]
    pub fn with_k2k(mut self, enable: bool) -> Self {
        self.enable_k2k = enable;
        self
    }

    /// Sets `max_messages_per_dispatch`. The value is stored unvalidated.
    #[must_use]
    pub fn with_max_messages(mut self, max: u32) -> Self {
        self.max_messages_per_dispatch = max;
        self
    }

    /// Returns the WGSL attribute for a one-dimensional workgroup,
    /// e.g. `@workgroup_size(256, 1, 1)`.
    #[must_use]
    pub fn workgroup_size_annotation(&self) -> String {
        format!("@workgroup_size({}, 1, 1)", self.workgroup_size)
    }
}
63
64pub fn generate_control_block_struct(config: &RingKernelConfig) -> String {
66 let mut fields = vec![
67 " is_active: atomic<u32>,".to_string(),
68 " should_terminate: atomic<u32>,".to_string(),
69 " has_terminated: atomic<u32>,".to_string(),
70 " // 64-bit counters as lo/hi pairs".to_string(),
71 " messages_processed_lo: atomic<u32>,".to_string(),
72 " messages_processed_hi: atomic<u32>,".to_string(),
73 " messages_pending_lo: atomic<u32>,".to_string(),
74 " messages_pending_hi: atomic<u32>,".to_string(),
75 ];
76
77 if config.enable_hlc {
78 fields.push(" // HLC timestamp".to_string());
79 fields.push(" hlc_physical_lo: atomic<u32>,".to_string());
80 fields.push(" hlc_physical_hi: atomic<u32>,".to_string());
81 fields.push(" hlc_logical: atomic<u32>,".to_string());
82 }
83
84 format!("struct ControlBlock {{\n{}\n}}", fields.join("\n"))
85}
86
/// Returns WGSL helper functions that emulate 64-bit counters using a pair of
/// `atomic<u32>` words: `lo` holds the low 32 bits and `hi` the high 32 bits.
///
/// The emitted functions are:
/// - `read_u64(lo, hi) -> vec2<u32>`: loads both words; `.x` is the low word
///   and `.y` is the high word.
/// - `atomic_inc_u64(lo, hi)`: adds 1 to `lo`. If the previous low value was
///   `0xFFFFFFFFu`, the add wrapped, so 1 is carried into `hi`.
/// - `atomic_add_u64(lo, hi, addend)`: adds `addend` to `lo` and carries 1
///   into `hi` when `old_lo + addend` exceeds 32 bits. The overflow test is
///   written as `old_lo > 0xFFFFFFFFu - addend` so the check itself cannot
///   wrap.
/// - `compare_u64(a, b) -> i32`: returns 1, -1 or 0. It compares the high
///   words (`.y`) first, then the low words (`.x`).
///
/// Each word is updated atomically, but the 64-bit value as a whole is not:
/// - the low-word add and the carry into `hi` are two separate atomic
///   operations;
/// - `read_u64` issues two independent loads.
///
/// A read that races with a carry can therefore observe a torn
/// (inconsistent) 64-bit value.
///
/// The returned text begins and ends with a newline.
///
/// NOTE(review): core WGSL restricts pointer parameters in the `storage`
/// address space and pointers to struct members unless the
/// `unrestricted_pointer_parameters` extension is available. Confirm the
/// targeted wgpu/naga version accepts calls such as
/// `atomic_inc_u64(&control.messages_processed_lo, &control.messages_processed_hi)`.
pub fn generate_u64_helpers() -> &'static str {
    r#"
// 64-bit operations using lo/hi u32 pairs
fn read_u64(lo: ptr<storage, atomic<u32>, read_write>, hi: ptr<storage, atomic<u32>, read_write>) -> vec2<u32> {
 return vec2<u32>(atomicLoad(lo), atomicLoad(hi));
}

fn atomic_inc_u64(lo: ptr<storage, atomic<u32>, read_write>, hi: ptr<storage, atomic<u32>, read_write>) {
 let old_lo = atomicAdd(lo, 1u);
 if (old_lo == 0xFFFFFFFFu) {
 atomicAdd(hi, 1u);
 }
}

fn atomic_add_u64(lo: ptr<storage, atomic<u32>, read_write>, hi: ptr<storage, atomic<u32>, read_write>, addend: u32) {
 let old_lo = atomicAdd(lo, addend);
 if (old_lo > 0xFFFFFFFFu - addend) {
 atomicAdd(hi, 1u);
 }
}

fn compare_u64(a: vec2<u32>, b: vec2<u32>) -> i32 {
 if (a.y > b.y) { return 1; }
 if (a.y < b.y) { return -1; }
 if (a.x > b.x) { return 1; }
 if (a.x < b.x) { return -1; }
 return 0;
}
"#
}
118
/// Returns the WGSL resource declarations used by a ring kernel.
///
/// All three are `read_write` storage bindings in bind group 0:
/// - binding 0: `control`, of type `ControlBlock`;
/// - binding 1: `input_queue`, a runtime-sized `array<u32>`;
/// - binding 2: `output_queue`, a runtime-sized `array<u32>`.
///
/// Each declaration is on its own line, with no trailing newline after the last.
pub fn generate_ring_kernel_bindings() -> &'static str {
    concat!(
        "@group(0) @binding(0) var<storage, read_write> control: ControlBlock;\n",
        "@group(0) @binding(1) var<storage, read_write> input_queue: array<u32>;\n",
        "@group(0) @binding(2) var<storage, read_write> output_queue: array<u32>;"
    )
}
125
/// Returns the WGSL statements that open a ring kernel's entry-point body.
///
/// The snippet:
/// 1. returns immediately when `control.is_active` is `0u`;
/// 2. otherwise, if `control.should_terminate` is non-zero, has the
///    invocation with `local_invocation_id.x == 0u` store `1u` into
///    `control.has_terminated`, and then every invocation returns.
///
/// The text refers to the `control` storage variable and to
/// `local_invocation_id`, neither of which it declares. The surrounding
/// shader must provide both; `local_invocation_id` is presumably a
/// `@builtin(local_invocation_id)` entry-point parameter (confirm against
/// the kernel template).
///
/// NOTE(review): WGSL uniformity analysis treats `atomicLoad` results as
/// non-uniform. These early `return`s may therefore make a later
/// `workgroupBarrier()` in the same entry point fail validation; verify this
/// if kernel bodies use barriers.
///
/// The returned text has no trailing newline.
pub fn generate_ring_kernel_preamble() -> &'static str {
    r#" // Check if kernel is active
 if (atomicLoad(&control.is_active) == 0u) {
 return;
 }

 // Check for termination request
 if (atomicLoad(&control.should_terminate) != 0u) {
 if (local_invocation_id.x == 0u) {
 atomicStore(&control.has_terminated, 1u);
 }
 return;
 }"#
}
141
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ring_kernel_config() {
        let config = RingKernelConfig::new("processor")
            .with_workgroup_size(128)
            .with_hlc(true);

        assert_eq!(config.name, "processor");
        assert_eq!(config.workgroup_size, 128);
        assert!(config.enable_hlc);
        assert!(!config.enable_k2k);
    }

    #[test]
    fn test_ring_kernel_config_defaults() {
        let config = RingKernelConfig::new("defaults");

        assert_eq!(config.workgroup_size, 256);
        assert!(!config.enable_hlc);
        assert!(!config.enable_k2k);
        assert_eq!(config.max_messages_per_dispatch, 1024);
    }

    #[test]
    fn test_k2k_and_max_messages_builders() {
        let config = RingKernelConfig::new("k2k")
            .with_k2k(true)
            .with_max_messages(32);

        assert!(config.enable_k2k);
        assert_eq!(config.max_messages_per_dispatch, 32);
    }

    #[test]
    fn test_control_block_generation() {
        let config = RingKernelConfig::new("test").with_hlc(true);
        let wgsl = generate_control_block_struct(&config);

        assert!(wgsl.contains("is_active: atomic<u32>"));
        assert!(wgsl.contains("should_terminate: atomic<u32>"));
        assert!(wgsl.contains("hlc_physical_lo: atomic<u32>"));
    }

    #[test]
    fn test_control_block_without_hlc() {
        let wgsl = generate_control_block_struct(&RingKernelConfig::new("test"));

        assert!(wgsl.starts_with("struct ControlBlock {"));
        assert!(wgsl.ends_with('}'));
        assert!(wgsl.contains("messages_pending_hi: atomic<u32>"));
        // HLC members must only appear when explicitly enabled.
        assert!(!wgsl.contains("hlc_"));
    }

    #[test]
    fn test_workgroup_size_annotation() {
        let config = RingKernelConfig::new("test").with_workgroup_size(64);
        assert_eq!(
            config.workgroup_size_annotation(),
            "@workgroup_size(64, 1, 1)"
        );
    }

    #[test]
    fn test_u64_helpers_define_all_functions() {
        let wgsl = generate_u64_helpers();
        for name in [
            "fn read_u64(",
            "fn atomic_inc_u64(",
            "fn atomic_add_u64(",
            "fn compare_u64(",
        ] {
            assert!(wgsl.contains(name), "missing {}", name);
        }
    }

    #[test]
    fn test_bindings_reference_generated_control_block() {
        let bindings = generate_ring_kernel_bindings();
        assert_eq!(bindings.lines().count(), 3);
        for binding in ["@binding(0)", "@binding(1)", "@binding(2)"] {
            assert!(bindings.contains(binding), "missing {}", binding);
        }

        // The `control` binding must name the struct the generator emits.
        let control_block = generate_control_block_struct(&RingKernelConfig::new("test"));
        assert!(bindings.contains("control: ControlBlock;"));
        assert!(control_block.starts_with("struct ControlBlock"));
    }

    #[test]
    fn test_preamble_uses_declared_control_fields() {
        let preamble = generate_ring_kernel_preamble();
        let control_block = generate_control_block_struct(&RingKernelConfig::new("test"));

        // Every control field the preamble touches must exist in the struct.
        for field in ["is_active", "should_terminate", "has_terminated"] {
            assert!(preamble.contains(format!("&control.{}", field).as_str()));
            assert!(control_block.contains(format!("{}: atomic<u32>", field).as_str()));
        }
    }
}