1use bytemuck::{Pod, Zeroable};
9use wgpu::util::DeviceExt;
10
/// Error codes produced by the GPU numeric shader (the raw `error_code`
/// field of [`NumericOutput`]). Discriminants mirror the `SUCCESS` /
/// `RANGE_ERROR` / … constants declared in the WGSL source.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuErrorCode {
    Success = 0,
    RangeError = 1,
    NotPositive = 2,
    Overflow = 3,
    DivByZero = 4,
    PrecisionError = 5,
}

impl From<u32> for GpuErrorCode {
    /// Decodes a raw shader error code.
    ///
    /// Any out-of-range value collapses to `RangeError`, matching the
    /// fallback the shader's `default` switch arm uses.
    fn from(v: u32) -> Self {
        // Lookup table indexed by discriminant; misses fall back to RangeError.
        const TABLE: [GpuErrorCode; 6] = [
            GpuErrorCode::Success,
            GpuErrorCode::RangeError,
            GpuErrorCode::NotPositive,
            GpuErrorCode::Overflow,
            GpuErrorCode::DivByZero,
            GpuErrorCode::PrecisionError,
        ];
        TABLE
            .get(v as usize)
            .copied()
            .unwrap_or(GpuErrorCode::RangeError)
    }
}
42
/// Selects which validation/calculation the shader applies to a rule.
/// Stored in `NumericRule::rule_kind` as a `u32`; discriminants match the
/// `RANGE_CHECK` … `INTEREST_CALC` constants in the WGSL source.
#[repr(u32)]
#[derive(Debug, Clone, Copy)]
pub enum NumericRuleKind {
    /// Error if `value` is outside `[param_a, param_b]`.
    RangeCheck = 0,
    /// Error unless `value > 0` (or `value >= 0` when `param_a == 1`).
    MustBePositive = 1,
    /// Error unless `value` is a multiple of a divisor chosen by `param_a`
    /// (0 -> 100, 1 -> 10, >= 2 -> 1).
    MaxPrecision = 2,
    /// Clamps `value` into `[param_a, param_b]`; reports RangeError if clamping changed it.
    Clamp = 3,
    /// Error unless `value` lies in `[0, 10000]` (basis points).
    Percentage = 4,
    /// calc = value * param_a / 10000 (`param_a` is a rate in basis points).
    TaxCalc = 5,
    /// calc = value * param_a / 1000; DivByZero error when `param_a == 0`.
    FxConvert = 6,
    /// calc = value * param_a * param_b / 3_650_000 (see INTEREST_DENOM in the shader).
    InterestCalc = 7,
}
64
/// One work item uploaded to the GPU. `#[repr(C)]` + `Pod` so a slice of
/// rules can be uploaded with `bytemuck::cast_slice`; the four 32-bit
/// fields match the WGSL `NumericRule` struct field-for-field (16 bytes).
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
pub struct NumericRule {
    /// Input value the rule is applied to (fixed-point; units depend on the rule).
    pub value: i32,
    /// Discriminant of [`NumericRuleKind`] cast to `u32`.
    pub rule_kind: u32,
    /// First rule parameter (min bound, zero-allowed flag, rate, … per `rule_kind`).
    pub param_a: i32,
    /// Second rule parameter (max bound, second rate factor, … per `rule_kind`).
    pub param_b: i32,
}
76
/// Per-rule result read back from the GPU. Layout mirrors the WGSL
/// `NumericOutput` struct; the two private pad words keep the 16-byte
/// stride identical on both sides so `cast_slice` on the readback bytes is valid.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
pub struct NumericOutput {
    /// Raw [`GpuErrorCode`] discriminant; convert with `GpuErrorCode::from`.
    pub error_code: u32,
    /// Computed value for calculation rules (Clamp/TaxCalc/FxConvert/InterestCalc);
    /// stays 0 for pure validation rules.
    pub calc_value: i32,
    // Explicit padding matching `_pad0`/`_pad1` in the WGSL struct.
    _pad0: u32,
    _pad1: u32,
}
88
/// Owns the wgpu device/queue and the compiled compute pipeline used to
/// evaluate batches of [`NumericRule`]s on the GPU (see [`GpuNumericEngine::run`]).
pub struct GpuNumericEngine {
    device: wgpu::Device,
    queue: wgpu::Queue,
    pipeline: wgpu::ComputePipeline,
}
97
98impl GpuNumericEngine {
99 pub async fn new() -> Self {
105 let instance = wgpu::Instance::default();
106 let adapter = instance
107 .request_adapter(&wgpu::RequestAdapterOptions::default())
108 .await
109 .expect("no GPU adapter");
110 let (device, queue) = adapter
111 .request_device(&wgpu::DeviceDescriptor::default(), None)
112 .await
113 .expect("device creation failed");
114
115 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
116 label: Some("numeric_shader"),
117 source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()),
118 });
119 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
120 label: Some("numeric_pipeline"),
121 layout: None,
122 module: &shader,
123 entry_point: "main",
124 compilation_options: Default::default(),
125 cache: None,
126 });
127
128 Self { device, queue, pipeline }
129 }
130
131 pub fn run(&self, rules: &[NumericRule]) -> Vec<NumericOutput> {
134 if rules.is_empty() {
135 return vec![];
136 }
137 let device = &self.device;
138 let queue = &self.queue;
139 let n = rules.len() as u64;
140
141 let rule_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
142 label: Some("numeric_rules"),
143 contents: bytemuck::cast_slice(rules),
144 usage: wgpu::BufferUsages::STORAGE,
145 });
146
147 let out_stride = std::mem::size_of::<NumericOutput>() as u64;
148 let out_size = n * out_stride;
149 let output_buf = device.create_buffer(&wgpu::BufferDescriptor {
150 label: Some("numeric_output"),
151 size: out_size,
152 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
153 mapped_at_creation: false,
154 });
155 let readback_buf = device.create_buffer(&wgpu::BufferDescriptor {
156 label: Some("numeric_readback"),
157 size: out_size,
158 usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
159 mapped_at_creation: false,
160 });
161
162 let bgl = self.pipeline.get_bind_group_layout(0);
163 let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
164 label: Some("numeric_bg"),
165 layout: &bgl,
166 entries: &[
167 wgpu::BindGroupEntry { binding: 0, resource: rule_buf.as_entire_binding() },
168 wgpu::BindGroupEntry { binding: 1, resource: output_buf.as_entire_binding() },
169 ],
170 });
171
172 let mut encoder = device.create_command_encoder(&Default::default());
173 {
174 let mut pass = encoder.begin_compute_pass(&Default::default());
175 pass.set_pipeline(&self.pipeline);
176 pass.set_bind_group(0, &bg, &[]);
177 pass.dispatch_workgroups((n as u32).div_ceil(64), 1, 1);
178 }
179 encoder.copy_buffer_to_buffer(&output_buf, 0, &readback_buf, 0, out_size);
180 queue.submit(Some(encoder.finish()));
181
182 let slice = readback_buf.slice(..);
183 slice.map_async(wgpu::MapMode::Read, |_| {});
184 device.poll(wgpu::Maintain::Wait);
185
186 let data = slice.get_mapped_range();
187 let result = bytemuck::cast_slice::<u8, NumericOutput>(&data).to_vec();
188 drop(data);
189 readback_buf.unmap();
190 result
191 }
192}
193
/// WGSL compute shader implementing every [`NumericRuleKind`] rule, one
/// invocation per rule (`@workgroup_size(64)`).
///
/// Wide arithmetic is emulated with 32-bit limbs (`umul32` for 32x32->64,
/// `div_u64_u32` / `div_u96_u32` for long division) because, as the comment
/// inside notes, i64 is not relied upon in WGSL here.
///
/// NOTE(review): both division helpers return only the low 32 quotient bits
/// (`q_lo`), so a quotient exceeding 32 bits is silently truncated and the
/// `OVERFLOW` error code is never emitted — confirm rule inputs are bounded
/// so that TaxCalc/FxConvert/InterestCalc quotients fit in an i32.
const SHADER_SRC: &str = r#"
const SUCCESS : u32 = 0u;
const RANGE_ERROR : u32 = 1u;
const NOT_POSITIVE : u32 = 2u;
const OVERFLOW : u32 = 3u;
const DIV_BY_ZERO : u32 = 4u;
const PRECISION_ERROR: u32 = 5u;

const RANGE_CHECK : u32 = 0u;
const MUST_BE_POS : u32 = 1u;
const MAX_PRECISION : u32 = 2u;
const CLAMP : u32 = 3u;
const PERCENTAGE : u32 = 4u;
const TAX_CALC : u32 = 5u;
const FX_CONVERT : u32 = 6u;
const INTEREST_CALC : u32 = 7u;

struct NumericRule { value: i32, rule_kind: u32, param_a: i32, param_b: i32 }
struct NumericOutput { error_code: u32, calc_value: i32, _pad0: u32, _pad1: u32 }

@group(0) @binding(0) var<storage, read> rules : array<NumericRule>;
@group(0) @binding(1) var<storage, read_write> output : array<NumericOutput>;

// Portable wide math (Metal / Naga do not support i64 in WGSL here).
fn umul32(a: u32, b: u32) -> vec2<u32> {
    let a0 = a & 0xFFFFu;
    let a1 = a >> 16u;
    let b0 = b & 0xFFFFu;
    let b1 = b >> 16u;
    let p00 = a0 * b0;
    let p01 = a0 * b1;
    let p10 = a1 * b0;
    let p11 = a1 * b1;
    let mid = (p00 >> 16u) + (p01 & 0xFFFFu) + (p10 & 0xFFFFu);
    let lo = (p00 & 0xFFFFu) | ((mid & 0xFFFFu) << 16u);
    let hi = p11 + (p01 >> 16u) + (p10 >> 16u) + (mid >> 16u);
    return vec2(lo, hi);
}

fn abs_i32_u32(x: i32) -> u32 {
    return u32(select(x, -x, x < 0));
}

fn div_u64_u32(n_lo: u32, n_hi: u32, den: u32) -> u32 {
    if n_hi == 0u {
        return n_lo / den;
    }
    var r_lo = 0u;
    var r_hi = 0u;
    var q_lo = 0u;
    var q_hi = 0u;
    for (var i = 0u; i < 64u; i++) {
        let k = 63u - i;
        let bit = select((n_hi >> (k - 32u)) & 1u, (n_lo >> k) & 1u, k < 32u);
        let nl = (r_lo << 1u) | bit;
        let nh = (r_hi << 1u) | (r_lo >> 31u);
        r_lo = nl;
        r_hi = nh;
        q_lo = (q_lo << 1u) | (q_hi >> 31u);
        q_hi = q_hi << 1u;
        if r_hi > 0u || r_lo >= den {
            if r_lo >= den {
                r_lo = r_lo - den;
            } else {
                r_hi = r_hi - 1u;
                r_lo = r_lo - den;
            }
            q_lo = q_lo | 1u;
        }
    }
    return q_lo;
}

fn mul64xu32(lo: u32, hi: u32, m: u32) -> vec3<u32> {
    let t = umul32(lo, m);
    let u = umul32(hi, m);
    let mid = t.y + u.x;
    let c1 = u32(mid < t.y);
    return vec3(t.x, mid, u.y + c1);
}

fn sub96_u32(a: vec3<u32>, b: vec3<u32>) -> vec3<u32> {
    let l = a.x - b.x;
    let c0 = u32(a.x < b.x);
    let m = a.y - b.y - c0;
    let c1 = u32(a.y < b.y) | u32(a.y == b.y && c0 == 1u);
    let h = a.z - b.z - c1;
    return vec3(l, m, h);
}

fn ge96_u32(a: vec3<u32>, b: vec3<u32>) -> bool {
    if a.z != b.z { return a.z > b.z; }
    if a.y != b.y { return a.y > b.y; }
    return a.x >= b.x;
}

fn div_u96_u32(l: u32, m: u32, h: u32, den: u32) -> u32 {
    var r = vec3(0u, 0u, 0u);
    var q_lo = 0u;
    var q_hi = 0u;
    for (var i = 0u; i < 96u; i++) {
        let k = 95u - i;
        var bit = 0u;
        if k < 32u { bit = (l >> k) & 1u; }
        else if k < 64u { bit = (m >> (k - 32u)) & 1u; }
        else { bit = (h >> (k - 64u)) & 1u; }
        let nx = (r.x << 1u) | bit;
        let ny = (r.y << 1u) | (r.x >> 31u);
        let nz = (r.z << 1u) | (r.y >> 31u);
        r = vec3(nx, ny, nz);
        q_lo = (q_lo << 1u) | (q_hi >> 31u);
        q_hi = q_hi << 1u;
        let d3 = vec3(den, 0u, 0u);
        if ge96_u32(r, d3) {
            r = sub96_u32(r, d3);
            q_lo = q_lo | 1u;
        }
    }
    return q_lo;
}

fn signed_mul_div2(a: i32, b: i32, den: u32) -> i32 {
    let neg = (a < 0) != (b < 0);
    let ua = abs_i32_u32(a);
    let ub = abs_i32_u32(b);
    let p = umul32(ua, ub);
    let q = div_u64_u32(p.x, p.y, den);
    return select(i32(q), -i32(q), neg);
}

fn signed_mul_div3_nonneg(a: i32, b: i32, c: i32, den: u32) -> i32 {
    let ua = abs_i32_u32(a);
    let ub = abs_i32_u32(b);
    let uc = abs_i32_u32(c);
    let p = umul32(ua, ub);
    let t = mul64xu32(p.x, p.y, uc);
    return i32(div_u96_u32(t.x, t.y, t.z, den));
}

const INTEREST_DENOM: u32 = 3650000u;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= arrayLength(&rules) { return; }

    let r = rules[idx];
    var err : u32 = SUCCESS;
    var calc : i32 = 0;

    switch r.rule_kind {

    case RANGE_CHECK: {
        if r.value < r.param_a || r.value > r.param_b {
            err = RANGE_ERROR;
        }
    }

    case MUST_BE_POS: {
        let ok = select(r.value > 0, r.value >= 0, r.param_a == 1);
        if !ok { err = NOT_POSITIVE; }
    }

    case MAX_PRECISION: {
        let divisor = select(select(100, 10, r.param_a == 1), 1, r.param_a >= 2);
        if r.value % divisor != 0 { err = PRECISION_ERROR; }
    }

    case CLAMP: {
        calc = clamp(r.value, r.param_a, r.param_b);
        if r.value != calc { err = RANGE_ERROR; }
    }

    case PERCENTAGE: {
        if r.value < 0 || r.value > 10000 { err = RANGE_ERROR; }
    }

    case TAX_CALC: {
        calc = signed_mul_div2(r.value, r.param_a, 10000u);
    }

    case FX_CONVERT: {
        if r.param_a == 0 { err = DIV_BY_ZERO; }
        else {
            calc = signed_mul_div2(r.value, r.param_a, 1000u);
        }
    }

    case INTEREST_CALC: {
        calc = signed_mul_div3_nonneg(r.value, r.param_a, r.param_b, INTEREST_DENOM);
    }

    default: { err = RANGE_ERROR; }
    }

    output[idx] = NumericOutput(err, calc, 0u, 0u);
}
"#;