Skip to main content

ringkernel_wgpu_codegen/
u64_workarounds.rs

//! 64-bit integer workarounds for WGSL.
//!
//! Core WGSL has no 64-bit integer types and therefore no 64-bit atomics. This module
//! provides helper functions that emulate 64-bit operations using (lo, hi) `u32` pairs.
5
/// Namespace for generators of WGSL helper functions that emulate 64-bit
/// integer operations with `(lo, hi)` `u32` pairs.
///
/// This is a zero-sized unit struct; all functionality is exposed as associated
/// functions. Common traits are derived so the type behaves like any other
/// public value type (printable, copyable, default-constructible).
#[derive(Debug, Clone, Copy, Default)]
pub struct U64Helpers;
8
9impl U64Helpers {
10    /// Generate the complete set of 64-bit helper functions.
11    pub fn generate_all() -> String {
12        [
13            Self::generate_read_u64(),
14            Self::generate_write_u64(),
15            Self::generate_atomic_inc_u64(),
16            Self::generate_atomic_add_u64(),
17            Self::generate_compare_u64(),
18            Self::generate_add_u64(),
19            Self::generate_sub_u64(),
20        ]
21        .join("\n\n")
22    }
23
24    /// Generate read_u64 function.
25    pub fn generate_read_u64() -> &'static str {
26        r#"// Read a 64-bit value from lo/hi atomic pair
27fn read_u64(lo: ptr<storage, atomic<u32>, read_write>, hi: ptr<storage, atomic<u32>, read_write>) -> vec2<u32> {
28    return vec2<u32>(atomicLoad(lo), atomicLoad(hi));
29}"#
30    }
31
32    /// Generate write_u64 function.
33    pub fn generate_write_u64() -> &'static str {
34        r#"// Write a 64-bit value to lo/hi atomic pair
35fn write_u64(lo: ptr<storage, atomic<u32>, read_write>, hi: ptr<storage, atomic<u32>, read_write>, value: vec2<u32>) {
36    atomicStore(lo, value.x);
37    atomicStore(hi, value.y);
38}"#
39    }
40
41    /// Generate atomic_inc_u64 function.
42    pub fn generate_atomic_inc_u64() -> &'static str {
43        r#"// Atomically increment a 64-bit value
44fn atomic_inc_u64(lo: ptr<storage, atomic<u32>, read_write>, hi: ptr<storage, atomic<u32>, read_write>) {
45    let old_lo = atomicAdd(lo, 1u);
46    if (old_lo == 0xFFFFFFFFu) {
47        atomicAdd(hi, 1u);
48    }
49}"#
50    }
51
52    /// Generate atomic_add_u64 function.
53    pub fn generate_atomic_add_u64() -> &'static str {
54        r#"// Atomically add to a 64-bit value
55fn atomic_add_u64(lo: ptr<storage, atomic<u32>, read_write>, hi: ptr<storage, atomic<u32>, read_write>, addend: u32) {
56    let old_lo = atomicAdd(lo, addend);
57    if (old_lo > 0xFFFFFFFFu - addend) {
58        atomicAdd(hi, 1u);
59    }
60}"#
61    }
62
63    /// Generate compare_u64 function.
64    pub fn generate_compare_u64() -> &'static str {
65        r#"// Compare two 64-bit values: returns -1 if a < b, 0 if a == b, 1 if a > b
66fn compare_u64(a: vec2<u32>, b: vec2<u32>) -> i32 {
67    if (a.y > b.y) { return 1; }
68    if (a.y < b.y) { return -1; }
69    if (a.x > b.x) { return 1; }
70    if (a.x < b.x) { return -1; }
71    return 0;
72}"#
73    }
74
75    /// Generate add_u64 function (non-atomic).
76    pub fn generate_add_u64() -> &'static str {
77        r#"// Add two 64-bit values (non-atomic)
78fn add_u64(a: vec2<u32>, b: vec2<u32>) -> vec2<u32> {
79    let lo = a.x + b.x;
80    let carry = select(0u, 1u, lo < a.x);
81    let hi = a.y + b.y + carry;
82    return vec2<u32>(lo, hi);
83}"#
84    }
85
86    /// Generate sub_u64 function (non-atomic).
87    pub fn generate_sub_u64() -> &'static str {
88        r#"// Subtract two 64-bit values (non-atomic): a - b
89fn sub_u64(a: vec2<u32>, b: vec2<u32>) -> vec2<u32> {
90    let borrow = select(0u, 1u, a.x < b.x);
91    let lo = a.x - b.x;
92    let hi = a.y - b.y - borrow;
93    return vec2<u32>(lo, hi);
94}"#
95    }
96
97    /// Generate mul_u64_u32 function (multiply 64-bit by 32-bit).
98    pub fn generate_mul_u64_u32() -> &'static str {
99        r#"// Multiply 64-bit value by 32-bit value
100fn mul_u64_u32(a: vec2<u32>, b: u32) -> vec2<u32> {
101    // Split into 16-bit parts to avoid overflow
102    let a_lo_lo = a.x & 0xFFFFu;
103    let a_lo_hi = a.x >> 16u;
104    let a_hi_lo = a.y & 0xFFFFu;
105
106    let b_lo = b & 0xFFFFu;
107    let b_hi = b >> 16u;
108
109    // Partial products
110    let p0 = a_lo_lo * b_lo;
111    let p1 = a_lo_lo * b_hi + a_lo_hi * b_lo;
112    let p2 = a_lo_hi * b_hi + a_hi_lo * b_lo;
113
114    // Combine
115    let lo = p0 + (p1 << 16u);
116    let carry1 = select(0u, 1u, lo < p0);
117    let hi = (p1 >> 16u) + p2 + carry1;
118
119    return vec2<u32>(lo, hi);
120}"#
121    }
122}
123
/// Format a Rust `u64` as a WGSL `vec2<u32>` constructor expression.
///
/// The low 32 bits become the `x` component and the high 32 bits the `y`
/// component, matching the `(lo, hi)` layout used by the WGSL helpers. Both
/// components are written with the `u` suffix.
pub fn u64_to_vec2_literal(value: u64) -> String {
    // `as u32` keeps only the low 32 bits, which is exactly the truncation we want.
    let low_word = value as u32;
    let high_word = (value >> 32) as u32;
    format!("vec2<u32>({}u, {}u)", low_word, high_word)
}
130
/// Format a Rust `i64` as a WGSL `vec2<i32>` constructor expression.
///
/// The value's two's-complement bit pattern is split into its low 32 bits (`x`)
/// and high 32 bits (`y`). Each half is reinterpreted as an `i32`, so for example
/// `-1` becomes `vec2<i32>(-1, -1)` and `i64::MIN` becomes
/// `vec2<i32>(0, -2147483648)`. No literal suffix is emitted.
pub fn i64_to_vec2_literal(value: i64) -> String {
    let bits = value as u64;
    // Truncate to each 32-bit half, then reinterpret the bits as signed.
    let low_word = bits as u32 as i32;
    let high_word = (bits >> 32) as u32 as i32;
    format!("vec2<i32>({}, {})", low_word, high_word)
}
137
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_u64_to_vec2_literal() {
        assert_eq!(u64_to_vec2_literal(0), "vec2<u32>(0u, 0u)");
        assert_eq!(u64_to_vec2_literal(1), "vec2<u32>(1u, 0u)");
        assert_eq!(u64_to_vec2_literal(0x1_0000_0000), "vec2<u32>(0u, 1u)");
        assert_eq!(
            u64_to_vec2_literal(0xFFFF_FFFF_FFFF_FFFF),
            "vec2<u32>(4294967295u, 4294967295u)"
        );
    }

    /// Each 32-bit half of the two's-complement pattern is reinterpreted as `i32`.
    #[test]
    fn test_i64_to_vec2_literal() {
        assert_eq!(i64_to_vec2_literal(0), "vec2<i32>(0, 0)");
        assert_eq!(i64_to_vec2_literal(1), "vec2<i32>(1, 0)");
        assert_eq!(i64_to_vec2_literal(-1), "vec2<i32>(-1, -1)");
        assert_eq!(i64_to_vec2_literal(0x1_0000_0000), "vec2<i32>(0, 1)");
        assert_eq!(i64_to_vec2_literal(i64::MIN), "vec2<i32>(0, -2147483648)");
    }

    #[test]
    fn test_generate_all_helpers() {
        let helpers = U64Helpers::generate_all();
        assert!(helpers.contains("fn read_u64"));
        assert!(helpers.contains("fn write_u64"));
        assert!(helpers.contains("fn atomic_inc_u64"));
        assert!(helpers.contains("fn atomic_add_u64"));
        assert!(helpers.contains("fn compare_u64"));
        assert!(helpers.contains("fn add_u64"));
        assert!(helpers.contains("fn sub_u64"));
    }

    /// Helpers are emitted back to back, separated by a single blank line.
    #[test]
    fn test_generate_all_separates_helpers_with_blank_line() {
        let helpers = U64Helpers::generate_all();
        assert!(helpers.starts_with(U64Helpers::generate_read_u64()));
        let first_two = format!(
            "{}\n\n{}",
            U64Helpers::generate_read_u64(),
            U64Helpers::generate_write_u64()
        );
        assert!(helpers.contains(&first_two));
    }
}