// ringkernel_wgpu_codegen/u64_workarounds.rs
//
// Core WGSL has no 64-bit integer type, so the generated shader code represents a
// 64-bit unsigned value either as a `vec2<u32>` (`.x` = low word, `.y` = high word)
// or as a pair of `atomic<u32>` words (lo, hi). The helpers below emit WGSL source
// implementing basic operations on those representations.

/// Generator for WGSL helper functions that emulate 64-bit unsigned arithmetic
/// on `vec2<u32>` values (`.x` = low 32 bits, `.y` = high 32 bits) and on
/// lo/hi pairs of `atomic<u32>` storage words.
pub struct U64Helpers;

impl U64Helpers {
    /// Returns the WGSL source of the core helpers, separated by blank lines:
    /// `read_u64`, `write_u64`, `atomic_inc_u64`, `atomic_add_u64`,
    /// `compare_u64`, `add_u64` and `sub_u64`.
    ///
    /// `mul_u64_u32` is not part of this bundle; emit
    /// [`U64Helpers::generate_mul_u64_u32`] separately when it is needed.
    pub fn generate_all() -> String {
        [
            Self::generate_read_u64(),
            Self::generate_write_u64(),
            Self::generate_atomic_inc_u64(),
            Self::generate_atomic_add_u64(),
            Self::generate_compare_u64(),
            Self::generate_add_u64(),
            Self::generate_sub_u64(),
        ]
        .join("\n\n")
    }

    /// WGSL `read_u64(lo, hi) -> vec2<u32>`: loads the low and high words.
    ///
    /// The two words are loaded with two separate `atomicLoad`s, not as one atomic
    /// unit, so a concurrent writer may cause a torn (mixed old/new) value.
    ///
    /// NOTE(review): plain WGSL only allows `function`/`private` pointer parameters
    /// unless the `unrestricted_pointer_parameters` extension is available — confirm
    /// the target toolchain accepts `ptr<storage, ...>` parameters.
    pub fn generate_read_u64() -> &'static str {
        r#"// Read a 64-bit value from lo/hi atomic pair
fn read_u64(lo: ptr<storage, atomic<u32>, read_write>, hi: ptr<storage, atomic<u32>, read_write>) -> vec2<u32> {
    return vec2<u32>(atomicLoad(lo), atomicLoad(hi));
}"#
    }

    /// WGSL `write_u64(lo, hi, value)`: stores `value.x` into `lo` and `value.y`
    /// into `hi` with two separate `atomicStore`s (not a single atomic update).
    pub fn generate_write_u64() -> &'static str {
        r#"// Write a 64-bit value to lo/hi atomic pair
fn write_u64(lo: ptr<storage, atomic<u32>, read_write>, hi: ptr<storage, atomic<u32>, read_write>, value: vec2<u32>) {
    atomicStore(lo, value.x);
    atomicStore(hi, value.y);
}"#
    }

    /// WGSL `atomic_inc_u64(lo, hi)`: adds 1 to the low word and, if the low word
    /// was `0xFFFFFFFF` before the add (i.e. it wrapped to 0), carries 1 into the
    /// high word.
    ///
    /// Each word update is atomic, but the pair is not: between the two adds a
    /// reader can observe the wrapped low word with the old high word.
    pub fn generate_atomic_inc_u64() -> &'static str {
        r#"// Atomically increment a 64-bit value
fn atomic_inc_u64(lo: ptr<storage, atomic<u32>, read_write>, hi: ptr<storage, atomic<u32>, read_write>) {
    let old_lo = atomicAdd(lo, 1u);
    if (old_lo == 0xFFFFFFFFu) {
        atomicAdd(hi, 1u);
    }
}"#
    }

    /// WGSL `atomic_add_u64(lo, hi, addend)`: adds a 32-bit `addend` to the low
    /// word and carries 1 into the high word when `old_lo + addend` overflows
    /// 32 bits (tested as `old_lo > 0xFFFFFFFF - addend`, which cannot overflow).
    ///
    /// Same pairing caveat as `atomic_inc_u64`: the carry is a second atomic op.
    pub fn generate_atomic_add_u64() -> &'static str {
        r#"// Atomically add to a 64-bit value
fn atomic_add_u64(lo: ptr<storage, atomic<u32>, read_write>, hi: ptr<storage, atomic<u32>, read_write>, addend: u32) {
    let old_lo = atomicAdd(lo, addend);
    if (old_lo > 0xFFFFFFFFu - addend) {
        atomicAdd(hi, 1u);
    }
}"#
    }

    /// WGSL `compare_u64(a, b) -> i32`: returns -1, 0 or 1 for `a < b`, `a == b`,
    /// `a > b`, comparing the high words first and the low words on a tie.
    pub fn generate_compare_u64() -> &'static str {
        r#"// Compare two 64-bit values: returns -1 if a < b, 0 if a == b, 1 if a > b
fn compare_u64(a: vec2<u32>, b: vec2<u32>) -> i32 {
    if (a.y > b.y) { return 1; }
    if (a.y < b.y) { return -1; }
    if (a.x > b.x) { return 1; }
    if (a.x < b.x) { return -1; }
    return 0;
}"#
    }

    /// WGSL `add_u64(a, b) -> vec2<u32>`: `a + b` modulo 2^64. Relies on u32
    /// addition wrapping; a carry out of the low word is detected by `lo < a.x`.
    pub fn generate_add_u64() -> &'static str {
        r#"// Add two 64-bit values (non-atomic)
fn add_u64(a: vec2<u32>, b: vec2<u32>) -> vec2<u32> {
    let lo = a.x + b.x;
    let carry = select(0u, 1u, lo < a.x);
    let hi = a.y + b.y + carry;
    return vec2<u32>(lo, hi);
}"#
    }

    /// WGSL `sub_u64(a, b) -> vec2<u32>`: `a - b` modulo 2^64. Relies on u32
    /// subtraction wrapping; a borrow from the high word is needed when `a.x < b.x`.
    pub fn generate_sub_u64() -> &'static str {
        r#"// Subtract two 64-bit values (non-atomic): a - b
fn sub_u64(a: vec2<u32>, b: vec2<u32>) -> vec2<u32> {
    let borrow = select(0u, 1u, a.x < b.x);
    let lo = a.x - b.x;
    let hi = a.y - b.y - borrow;
    return vec2<u32>(lo, hi);
}"#
    }

    /// WGSL `mul_u64_u32(a, b) -> vec2<u32>`: `a * b` modulo 2^64.
    ///
    /// Uses u32 multiplication wrapping (as `add_u64` relies on u32 addition
    /// wrapping):
    /// - the low word is `a.x * b` truncated;
    /// - the high word is `(a.y * b) mod 2^32` plus the upper 32 bits of the full
    ///   product `a.x * b`.
    ///
    /// That upper half is rebuilt from 16-bit limbs. Every partial product fits in
    /// u32, and the middle-column sum stays below `3 * 2^16`, so no intermediate
    /// value overflows.
    pub fn generate_mul_u64_u32() -> &'static str {
        r#"// Multiply 64-bit value by 32-bit value (wrapping modulo 2^64)
fn mul_u64_u32(a: vec2<u32>, b: u32) -> vec2<u32> {
    // Low word: u32 multiplication wraps, keeping exactly the low 32 bits of a.x * b
    let lo = a.x * b;

    // Upper 32 bits of the full product a.x * b, from 16-bit limbs (no overflow)
    let a0 = a.x & 0xFFFFu;
    let a1 = a.x >> 16u;
    let b0 = b & 0xFFFFu;
    let b1 = b >> 16u;
    let p00 = a0 * b0;
    let p01 = a0 * b1;
    let p10 = a1 * b0;
    let p11 = a1 * b1;
    let mid = (p00 >> 16u) + (p01 & 0xFFFFu) + (p10 & 0xFFFFu);
    let mul_hi = p11 + (p01 >> 16u) + (p10 >> 16u) + (mid >> 16u);

    // High word: (a.y * b) mod 2^32 plus the overflow of the low-word product
    let hi = a.y * b + mul_hi;

    return vec2<u32>(lo, hi);
}"#
    }
}

/// Formats `value` as a WGSL `vec2<u32>` constructor: the low 32 bits become
/// the first component and the high 32 bits the second, each with a `u` suffix.
pub fn u64_to_vec2_literal(value: u64) -> String {
    // Casting to u32 keeps only the low 32 bits.
    let low_word = value as u32;
    let high_word = (value >> 32) as u32;
    format!("vec2<u32>({}u, {}u)", low_word, high_word)
}

/// Formats `value` as a WGSL `vec2<i32>` constructor from its two's-complement
/// bit pattern: the low 32 bits and the high 32 bits are each reinterpreted as
/// `i32` and printed without a type suffix.
pub fn i64_to_vec2_literal(value: i64) -> String {
    let bits = value as u64;
    // u64 -> u32 truncates to the low word; u32 -> i32 reinterprets the bits.
    let low = bits as u32 as i32;
    let high = (bits >> 32) as u32 as i32;
    format!("vec2<i32>({}, {})", low, high)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_u64_to_vec2_literal() {
        assert_eq!(u64_to_vec2_literal(0), "vec2<u32>(0u, 0u)");
        assert_eq!(u64_to_vec2_literal(1), "vec2<u32>(1u, 0u)");
        assert_eq!(u64_to_vec2_literal(0x1_0000_0000), "vec2<u32>(0u, 1u)");
        assert_eq!(
            u64_to_vec2_literal(0xFFFF_FFFF_FFFF_FFFF),
            "vec2<u32>(4294967295u, 4294967295u)"
        );
    }

    /// Covers the signed formatter, including the sign-bit boundaries where the
    /// low and high words are reinterpreted as negative `i32`s.
    #[test]
    fn test_i64_to_vec2_literal() {
        assert_eq!(i64_to_vec2_literal(0), "vec2<i32>(0, 0)");
        assert_eq!(i64_to_vec2_literal(1), "vec2<i32>(1, 0)");
        assert_eq!(i64_to_vec2_literal(-1), "vec2<i32>(-1, -1)");
        assert_eq!(i64_to_vec2_literal(0x1_0000_0000), "vec2<i32>(0, 1)");
        assert_eq!(i64_to_vec2_literal(i64::MIN), "vec2<i32>(0, -2147483648)");
    }

    #[test]
    fn test_generate_all_helpers() {
        let helpers = U64Helpers::generate_all();
        for name in [
            "fn read_u64",
            "fn write_u64",
            "fn atomic_inc_u64",
            "fn atomic_add_u64",
            "fn compare_u64",
            "fn add_u64",
            "fn sub_u64",
        ] {
            assert!(helpers.contains(name), "missing helper: {}", name);
        }
    }

    #[test]
    fn test_generate_mul_u64_u32() {
        assert!(U64Helpers::generate_mul_u64_u32().contains("fn mul_u64_u32"));
    }
}