echidna 0.9.0

A high-performance automatic differentiation library for Rust
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
//! GPU acceleration for batched tape evaluation.
//!
//! Provides two backends:
//! - **wgpu** (`gpu-wgpu` feature): cross-platform (Metal, Vulkan, DX12), f32 only
//! - **CUDA** (`gpu-cuda` feature): NVIDIA only, f32 + f64
//!
//! # Context Contract
//!
//! Both [`WgpuContext`] and [`CudaContext`] implement the [`GpuBackend`] trait,
//! which defines the shared f32 operation set:
//!
//! - `new() -> Option<Self>` — acquire a GPU device (inherent, not in trait)
//! - [`upload_tape`](GpuBackend::upload_tape) — upload tape to device
//! - [`forward_batch`](GpuBackend::forward_batch) — batched forward evaluation
//! - [`gradient_batch`](GpuBackend::gradient_batch) — batched gradient (forward + reverse)
//! - [`sparse_jacobian`](GpuBackend::sparse_jacobian) — GPU-accelerated sparse Jacobian
//! - [`hvp_batch`](GpuBackend::hvp_batch) — batched Hessian-vector product
//! - [`sparse_hessian`](GpuBackend::sparse_hessian) — GPU-accelerated sparse Hessian
//! - [`taylor_forward_2nd_batch`](GpuBackend::taylor_forward_2nd_batch) — batched second-order Taylor forward propagation (requires `stde`)
//!
//! CUDA additionally provides f64 methods as inherent methods on [`CudaContext`].
//!
//! # GPU-Accelerated STDE (requires `stde`)
//!
//! The [`stde_gpu`] module provides GPU-accelerated versions of the CPU STDE
//! functions. These use batched second-order Taylor forward propagation to
//! evaluate many directions in parallel:
//!
//! - [`stde_gpu::laplacian_gpu`] — Hutchinson trace estimator on GPU
//! - [`stde_gpu::hessian_diagonal_gpu`] — exact Hessian diagonal via basis pushforwards
//! - [`stde_gpu::laplacian_with_control_gpu`] — variance-reduced Laplacian with diagonal control variate
//!
//! The Taylor kernel propagates `(c0, c1, c2)` triples through the tape for
//! each batch element, where c2 = v^T H v / 2. All 44 opcodes are supported.

use crate::bytecode_tape::BytecodeTape;
use crate::opcode::OpCode;

#[cfg(feature = "gpu-wgpu")]
pub mod wgpu_backend;

#[cfg(feature = "gpu-cuda")]
pub mod cuda_backend;

#[cfg(feature = "stde")]
pub mod stde_gpu;

#[cfg(feature = "stde")]
pub mod taylor_codegen;

#[cfg(feature = "gpu-wgpu")]
pub use wgpu_backend::{WgpuContext, WgpuTapeBuffers};

#[cfg(feature = "gpu-cuda")]
pub use cuda_backend::{CudaContext, CudaTapeBuffers};

/// Common interface for GPU backends (f32 operations).
///
/// Both [`WgpuContext`] and [`CudaContext`] implement this trait for the f32
/// operation set. CUDA additionally provides f64 methods as inherent methods
/// on [`CudaContext`] directly.
///
/// # Associated Type
///
/// [`TapeBuffers`](GpuBackend::TapeBuffers) is the backend-specific opaque
/// handle returned by [`upload_tape`](GpuBackend::upload_tape) and passed to
/// all dispatch methods. It holds GPU-resident buffers and is not cloneable.
///
/// # Implementing a New Backend
///
/// A backend must implement every method that has no default body
/// (`taylor_forward_2nd_batch` has a default delegating to
/// `taylor_forward_kth_batch`). Construction (`new()`) is not part of the
/// trait — backends may have different initialization requirements.
pub trait GpuBackend {
    /// Backend-specific uploaded tape handle.
    type TapeBuffers;

    /// Upload a tape to the GPU.
    ///
    /// The returned handle is used for all subsequent operations and holds
    /// GPU-resident buffers for the tape's opcodes, arguments, and constants.
    fn upload_tape(&self, data: &GpuTapeData) -> Self::TapeBuffers;

    /// Number of declared outputs on the uploaded tape.
    ///
    /// Used by estimators like `stde_gpu::laplacian_gpu` to enforce
    /// single-output assumptions whose coefficient layout depends on
    /// the tape's output count.
    fn num_outputs(&self, tape: &Self::TapeBuffers) -> u32;

    /// Batched forward evaluation.
    ///
    /// `inputs` is `[f32; batch_size * num_inputs]` (row-major, one point per row).
    /// Returns output values `[f32; batch_size * num_outputs]`.
    fn forward_batch(
        &self,
        tape: &Self::TapeBuffers,
        inputs: &[f32],
        batch_size: u32,
    ) -> Result<Vec<f32>, GpuError>;

    /// Batched gradient (forward + reverse sweep).
    ///
    /// Returns `(outputs, gradients)` where outputs is
    /// `[f32; batch_size * num_outputs]` and gradients is
    /// `[f32; batch_size * num_inputs]`.
    fn gradient_batch(
        &self,
        tape: &Self::TapeBuffers,
        inputs: &[f32],
        batch_size: u32,
    ) -> Result<(Vec<f32>, Vec<f32>), GpuError>;

    /// GPU-accelerated sparse Jacobian.
    ///
    /// CPU detects sparsity and computes coloring; GPU dispatches colored
    /// tangent sweeps. Returns `(output_values, pattern, jacobian_values)`.
    fn sparse_jacobian(
        &self,
        tape: &Self::TapeBuffers,
        tape_cpu: &mut BytecodeTape<f32>,
        x: &[f32],
    ) -> Result<(Vec<f32>, crate::sparse::JacobianSparsityPattern, Vec<f32>), GpuError>;

    /// Batched Hessian-vector product (forward-over-reverse).
    ///
    /// `tangent_dirs` is `[f32; batch_size * num_inputs]` — one direction per
    /// batch element. Returns `(gradients, hvps)` each
    /// `[f32; batch_size * num_inputs]`.
    fn hvp_batch(
        &self,
        tape: &Self::TapeBuffers,
        x: &[f32],
        tangent_dirs: &[f32],
        batch_size: u32,
    ) -> Result<(Vec<f32>, Vec<f32>), GpuError>;

    /// GPU-accelerated sparse Hessian.
    ///
    /// CPU detects Hessian sparsity and computes distance-2 coloring; GPU
    /// dispatches HVP sweeps. Returns `(value, gradient, pattern, hessian_values)`.
    fn sparse_hessian(
        &self,
        tape: &Self::TapeBuffers,
        tape_cpu: &mut BytecodeTape<f32>,
        x: &[f32],
    ) -> Result<(f32, Vec<f32>, crate::sparse::SparsityPattern, Vec<f32>), GpuError>;

    /// Batched second-order Taylor forward propagation.
    ///
    /// Each batch element pushes one direction through the tape, producing
    /// a Taylor jet with 3 coefficients (c0=value, c1=first derivative,
    /// c2=second derivative / 2).
    ///
    /// `primal_inputs` is `[f32; batch_size * num_inputs]` — primals for each element.
    /// `direction_seeds` is `[f32; batch_size * num_inputs]` — c1 seeds for each element.
    ///
    /// Returns `TaylorBatchResult` with `values`, `c1s`, `c2s` each of size
    /// `[f32; batch_size * num_outputs]`.
    ///
    /// The default implementation delegates to
    /// [`taylor_forward_kth_batch`](GpuBackend::taylor_forward_kth_batch) with `order = 3`.
    #[cfg(feature = "stde")]
    fn taylor_forward_2nd_batch(
        &self,
        tape: &Self::TapeBuffers,
        primal_inputs: &[f32],
        direction_seeds: &[f32],
        batch_size: u32,
    ) -> Result<TaylorBatchResult<f32>, GpuError> {
        let kth =
            self.taylor_forward_kth_batch(tape, primal_inputs, direction_seeds, batch_size, 3)?;
        // order = 3 means `coefficients` must contain exactly c0, c1, c2;
        // a shorter result is a backend bug, hence `expect` with the invariant.
        let mut coeffs = kth.coefficients.into_iter();
        Ok(TaylorBatchResult {
            values: coeffs.next().expect("order-3 Taylor result missing c0"),
            c1s: coeffs.next().expect("order-3 Taylor result missing c1"),
            c2s: coeffs.next().expect("order-3 Taylor result missing c2"),
        })
    }

    /// Batched K-th order Taylor forward propagation.
    ///
    /// Supports `order` in 1..=5. Each batch element pushes one direction through
    /// the tape, producing K Taylor coefficients (c0, c1, ..., c_{K-1}).
    ///
    /// `primal_inputs` is `[f32; batch_size * num_inputs]` — primals for each element.
    /// `direction_seeds` is `[f32; batch_size * num_inputs]` — c1 seeds for each element.
    ///
    /// Returns `TaylorKthBatchResult` with `coefficients[k]` of size
    /// `[f32; batch_size * num_outputs]` for each k in 0..order.
    #[cfg(feature = "stde")]
    fn taylor_forward_kth_batch(
        &self,
        tape: &Self::TapeBuffers,
        primal_inputs: &[f32],
        direction_seeds: &[f32],
        batch_size: u32,
        order: usize,
    ) -> Result<TaylorKthBatchResult<f32>, GpuError>;
}

/// Result of a batched second-order Taylor forward propagation.
///
/// Each field has `batch_size * num_outputs` elements (row-major: one row per batch element).
/// The Taylor convention is `c[k] = f^(k)(t₀) / k!`, so:
/// - `values[i]` = f(x) (primal value)
/// - `c1s[i]` = directional first derivative
/// - `c2s[i]` = directional second derivative / 2
#[derive(Debug, Clone, PartialEq)]
pub struct TaylorBatchResult<F> {
    /// Primal output values `[batch_size * num_outputs]`.
    pub values: Vec<F>,
    /// First-order Taylor coefficients `[batch_size * num_outputs]`.
    pub c1s: Vec<F>,
    /// Second-order Taylor coefficients `[batch_size * num_outputs]`.
    pub c2s: Vec<F>,
}

/// Result of a batched K-th order Taylor forward propagation.
///
/// `coefficients[k]` has `batch_size * num_outputs` elements for coefficient index k.
/// The Taylor convention is `c[k] = f^(k)(t₀) / k!`.
#[cfg(feature = "stde")]
#[derive(Debug, Clone, PartialEq)]
pub struct TaylorKthBatchResult<F> {
    /// Taylor coefficients: `coefficients[k]` is the k-th order coefficient vector
    /// with `batch_size * num_outputs` elements.
    pub coefficients: Vec<Vec<F>>,
    /// The Taylor order (number of coefficients per output).
    pub order: usize,
}

/// Error type for GPU operations.
///
/// `Clone`/`PartialEq`/`Eq` are derived so callers and tests can propagate,
/// cache, and compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GpuError {
    /// No suitable GPU device found.
    NoDevice,
    /// Shader or kernel compilation failed.
    ShaderCompilation(String),
    /// GPU ran out of memory.
    OutOfMemory,
    /// Tape contains custom ops which cannot run on GPU.
    CustomOpsNotSupported,
    /// Backend-specific error.
    Other(String),
}

impl std::fmt::Display for GpuError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Variants without a payload emit a fixed string; payload-carrying
        // variants interpolate their message.
        match self {
            GpuError::NoDevice => f.write_str("no suitable GPU device found"),
            GpuError::OutOfMemory => f.write_str("GPU out of memory"),
            GpuError::CustomOpsNotSupported => {
                f.write_str("tape contains custom ops which cannot run on GPU")
            }
            GpuError::ShaderCompilation(msg) => write!(f, "shader compilation failed: {msg}"),
            GpuError::Other(msg) => write!(f, "GPU error: {msg}"),
        }
    }
}

impl std::error::Error for GpuError {}

crate::assert_send_sync!(GpuError);

/// Flattened tape representation for GPU upload.
///
/// All arrays are the same length (`num_ops`). The GPU shader walks index 0..num_ops
/// sequentially, executing each opcode on the per-thread values buffer.
///
/// Created via [`GpuTapeData::from_tape`] (f32) or [`GpuTapeData::from_tape_f64_lossy`] (f64→f32).
#[derive(Debug, Clone)]
pub struct GpuTapeData {
    /// OpCode discriminants as u32 (one per tape entry).
    pub opcodes: Vec<u32>,
    /// First argument index for each operation.
    pub arg0: Vec<u32>,
    /// Second argument index for each operation.
    pub arg1: Vec<u32>,
    /// Initial values buffer (constants and zeros, f32).
    pub constants: Vec<f32>,
    /// Total number of tape entries.
    pub num_ops: u32,
    /// Number of input variables.
    pub num_inputs: u32,
    /// Total entries in the values buffer (inputs + constants + intermediates).
    pub num_variables: u32,
    /// Primary output index.
    pub output_index: u32,
    /// All output indices (for multi-output tapes).
    pub output_indices: Vec<u32>,
}

impl GpuTapeData {
    /// Assemble `GpuTapeData` from a tape's structural arrays and
    /// already-converted f32 constants.
    fn build_from_tape<F: crate::float::Float>(
        tape: &BytecodeTape<F>,
        constants: Vec<f32>,
    ) -> Self {
        let raw_ops = tape.opcodes_slice();
        let arg_pairs = tape.arg_indices_slice();
        let op_count = raw_ops.len();

        // Split the per-op [arg0, arg1] pairs into two parallel arrays.
        let (arg0, arg1): (Vec<u32>, Vec<u32>) =
            arg_pairs.iter().map(|pair| (pair[0], pair[1])).unzip();

        GpuTapeData {
            opcodes: raw_ops.iter().map(|&op| op as u32).collect(),
            arg0,
            arg1,
            constants,
            // SAFETY(u32 cast): op_count is the number of tape opcodes. Exceeding
            // u32::MAX (~4.3B) would require ~17 GB of opcode storage alone,
            // which is impractical.
            num_ops: op_count as u32,
            // SAFETY(u32 cast): num_inputs, num_variables, and output_index are
            // bounded by tape size (same order as num_ops), which cannot
            // practically reach u32::MAX.
            num_inputs: tape.num_inputs() as u32,
            num_variables: tape.num_variables_count() as u32,
            output_index: tape.output_index() as u32,
            output_indices: tape.all_output_indices().to_vec(),
        }
    }

    /// Convert a `BytecodeTape<f32>` to GPU-uploadable format.
    ///
    /// Returns `Err(CustomOpsNotSupported)` if the tape contains custom ops,
    /// since custom Rust closures cannot execute on GPU hardware.
    pub fn from_tape(tape: &BytecodeTape<f32>) -> Result<Self, GpuError> {
        if tape.has_custom_ops() {
            Err(GpuError::CustomOpsNotSupported)
        } else {
            Ok(Self::build_from_tape(tape, tape.values_slice().to_vec()))
        }
    }

    /// Convert a `BytecodeTape<f64>` to GPU-uploadable f32 format.
    ///
    /// All f64 values are cast to f32, which loses precision. The method name
    /// makes this explicit — use the CUDA backend for native f64 support.
    ///
    /// Returns `Err(CustomOpsNotSupported)` if the tape contains custom ops.
    pub fn from_tape_f64_lossy(tape: &BytecodeTape<f64>) -> Result<Self, GpuError> {
        if tape.has_custom_ops() {
            return Err(GpuError::CustomOpsNotSupported);
        }
        let lossy: Vec<f32> = tape.values_slice().iter().map(|&v| v as f32).collect();
        Ok(Self::build_from_tape(tape, lossy))
    }
}

/// Metadata for the tape, uploaded as a uniform buffer to GPU shaders.
///
/// Layout matches the WGSL `TapeMeta` struct: five u32 payload fields plus
/// three u32 of padding = 8 × u32 = 32 bytes total.
#[cfg(feature = "gpu-wgpu")]
#[repr(C)]
#[derive(Clone, Copy, Debug, bytemuck::Pod, bytemuck::Zeroable)]
pub struct TapeMeta {
    /// Number of opcodes in the tape.
    pub num_ops: u32,
    /// Number of input variables.
    pub num_inputs: u32,
    /// Number of intermediate variables (working slots).
    /// NOTE(review): presumably fed from `GpuTapeData::num_variables`, which is
    /// documented as the *total* values-buffer size (inputs + constants +
    /// intermediates) — confirm against `wgpu_backend` and align the wording.
    pub num_variables: u32,
    /// Number of outputs.
    pub num_outputs: u32,
    /// Number of evaluation points in the batch.
    pub batch_size: u32,
    /// Padding so the struct is exactly 32 bytes.
    pub _pad: [u32; 3],
}

/// Map an [`OpCode`] to the integer constant used in WGSL/CUDA shaders.
///
/// This is a plain discriminant cast: `OpCode` is `#[repr(u8)]`, so the GPU
/// constant is the discriminant widened to u32.
#[inline]
#[must_use]
pub fn opcode_to_gpu(opcode: OpCode) -> u32 {
    opcode as u32
}

/// Default maximum buffer size for WebGPU (128 MiB).
///
/// WebGPU's `maxBufferSize` limit is 256 MiB, but we use 128 MiB as a
/// conservative default to avoid hitting device-specific limits.
///
/// Gated on `stde` because its only consumer here is the Taylor chunked
/// dispatcher ([`taylor_forward_2nd_batch_chunked`]) and its tests.
#[cfg(feature = "stde")]
pub const WGPU_MAX_BUFFER_BYTES: u64 = 128 * 1024 * 1024;

/// Maximum workgroup dispatches per dimension in WebGPU (65535 = 2^16 - 1).
#[cfg(feature = "stde")]
const MAX_WORKGROUPS_PER_DIM: u64 = 65535;

/// Workgroup size used by the Taylor forward shader.
/// NOTE(review): must agree with the shader's `@workgroup_size` — verify
/// against `taylor_codegen` / `wgpu_backend` if either changes.
#[cfg(feature = "stde")]
const TAYLOR_WORKGROUP_SIZE: u64 = 256;

/// Chunked batched second-order Taylor forward propagation.
///
/// Splits a large batch into chunks that fit within GPU buffer size limits,
/// dispatches each chunk, and concatenates results. This avoids hitting WebGPU's
/// 128 MiB buffer limit or workgroup dispatch limits.
///
/// # Arguments
///
/// - `backend`: any `GpuBackend` implementation
/// - `tape`: uploaded tape buffers
/// - `primal_inputs`: `[f32; batch_size * num_inputs]` — primals for each element
/// - `direction_seeds`: `[f32; batch_size * num_inputs]` — c1 seeds for each element
/// - `batch_size`: total number of batch elements
/// - `num_inputs`: number of input variables per element
/// - `num_variables`: total tape variable slots (inputs + constants + intermediates)
/// - `max_buffer_bytes`: maximum GPU buffer size in bytes (use [`WGPU_MAX_BUFFER_BYTES`])
///
/// # Errors
///
/// Returns `GpuError::Other` if `max_buffer_bytes` is too small for even a single
/// element, if `num_variables` is zero or too large for u32 shader indexing, or
/// if the input slice lengths do not match `batch_size * num_inputs`.
#[cfg(feature = "stde")]
#[allow(clippy::too_many_arguments)]
pub fn taylor_forward_2nd_batch_chunked<B: GpuBackend>(
    backend: &B,
    tape: &B::TapeBuffers,
    primal_inputs: &[f32],
    direction_seeds: &[f32],
    batch_size: u32,
    num_inputs: u32,
    num_variables: u32,
    max_buffer_bytes: u64,
) -> Result<TaylorBatchResult<f32>, GpuError> {
    if batch_size == 0 {
        return Ok(TaylorBatchResult {
            values: vec![],
            c1s: vec![],
            c2s: vec![],
        });
    }

    // Validate input lengths up front so a mismatch yields a descriptive error
    // instead of a slice-index panic in the chunk loop below.
    let expected_len = (batch_size as usize) * (num_inputs as usize);
    if primal_inputs.len() != expected_len || direction_seeds.len() != expected_len {
        return Err(GpuError::Other(format!(
            "input length mismatch: expected {expected_len} \
             (batch_size {batch_size} * num_inputs {num_inputs}), \
             got {} primal inputs and {} direction seeds",
            primal_inputs.len(),
            direction_seeds.len()
        )));
    }

    // The largest buffer is the jets working buffer: batch_size * num_variables * 3 * 4 bytes
    let bytes_per_element = (num_variables as u64) * 3 * 4;
    if bytes_per_element == 0 {
        return Err(GpuError::Other("num_variables is zero".into()));
    }

    let mut chunk_size = max_buffer_bytes / bytes_per_element;
    if chunk_size == 0 {
        return Err(GpuError::Other(format!(
            "max_buffer_bytes ({max_buffer_bytes}) too small for a single element \
             ({bytes_per_element} bytes per element)"
        )));
    }

    // Also cap at workgroup dispatch limit: 65535 workgroups * 256 threads
    let dispatch_limit = MAX_WORKGROUPS_PER_DIM * TAYLOR_WORKGROUP_SIZE;
    chunk_size = chunk_size.min(dispatch_limit);

    // Cap chunk_size so that the WGSL u32 index `bid * nv * K` cannot overflow.
    // K=3 for second-order Taylor jets. `nv_k` is nonzero here: the
    // bytes_per_element check above already rejected num_variables == 0.
    let nv_k = (num_variables as u64) * 3;
    chunk_size = chunk_size.min(u32::MAX as u64 / nv_k);

    // If even a single element would overflow u32 indexing, the cap collapses
    // to zero; erroring here prevents an infinite zero-progress chunk loop.
    if chunk_size == 0 {
        return Err(GpuError::Other(format!(
            "num_variables ({num_variables}) too large: a single element would \
             overflow the shader's u32 indexing"
        )));
    }

    // Both caps above keep chunk_size <= u32::MAX, so the cast is lossless.
    let chunk_size = chunk_size as u32;

    // If everything fits in one chunk, dispatch directly
    if batch_size <= chunk_size {
        return backend.taylor_forward_2nd_batch(tape, primal_inputs, direction_seeds, batch_size);
    }

    // Multi-chunk dispatch: slice inputs per chunk and concatenate results.
    let ni = num_inputs as usize;
    let mut all_values = Vec::new();
    let mut all_c1s = Vec::new();
    let mut all_c2s = Vec::new();

    let mut offset = 0u32;
    while offset < batch_size {
        let this_chunk = chunk_size.min(batch_size - offset);
        let start = (offset as usize) * ni;
        let end = start + (this_chunk as usize) * ni;

        let chunk_result = backend.taylor_forward_2nd_batch(
            tape,
            &primal_inputs[start..end],
            &direction_seeds[start..end],
            this_chunk,
        )?;

        all_values.extend(chunk_result.values);
        all_c1s.extend(chunk_result.c1s);
        all_c2s.extend(chunk_result.c2s);

        offset += this_chunk;
    }

    Ok(TaylorBatchResult {
        values: all_values,
        c1s: all_c1s,
        c2s: all_c2s,
    })
}

#[cfg(all(test, feature = "stde"))]
mod tests {
    use super::*;

    /// Mirror of the chunk_size computation in `taylor_forward_2nd_batch_chunked`,
    /// extracted so u32 overflow safety can be tested without a GPU backend.
    fn compute_chunk_size(num_variables: u32, max_buffer_bytes: u64) -> Option<u32> {
        let nv = num_variables as u64;
        // Jet working-buffer footprint per batch element: nv slots * 3 coeffs * 4 bytes.
        let bytes_per_element = nv * 3 * 4;
        if bytes_per_element == 0 {
            return None;
        }
        let by_buffer = max_buffer_bytes / bytes_per_element;
        if by_buffer == 0 {
            return None;
        }
        // Dispatch cap: 65535 workgroups of 256 threads, one element per thread.
        let by_dispatch = MAX_WORKGROUPS_PER_DIM * TAYLOR_WORKGROUP_SIZE;
        // Index cap: `bid * nv * 3` must fit in a WGSL u32 (nv > 0 here).
        let by_index = u32::MAX as u64 / (nv * 3);
        Some(by_buffer.min(by_dispatch).min(by_index) as u32)
    }

    #[test]
    fn chunking_caps_for_large_num_variables() {
        // With 500,000 variables and K=3, bid * nv * K must stay within u32.
        let nv = 500_000u64;
        let chunk = u64::from(compute_chunk_size(nv as u32, u64::MAX).unwrap());
        let product = chunk * nv * 3;
        assert!(
            product <= u64::from(u32::MAX),
            "chunk_size * nv * K = {product} exceeds u32::MAX"
        );
    }

    #[test]
    fn chunking_caps_for_very_large_num_variables() {
        // With very large num_variables, chunk_size should be very small.
        let nv = 1_000_000u64;
        let chunk = u64::from(compute_chunk_size(nv as u32, u64::MAX).unwrap());
        let product = chunk * nv * 3;
        assert!(
            product <= u64::from(u32::MAX),
            "chunk_size * nv * K = {product} exceeds u32::MAX"
        );
    }

    #[test]
    fn chunking_with_small_buffer() {
        // A 1-byte buffer cannot hold even one element's jets.
        assert!(
            compute_chunk_size(1000, 1).is_none(),
            "should fail with buffer too small"
        );
    }

    #[test]
    fn chunking_single_variable() {
        let chunk = compute_chunk_size(1, WGPU_MAX_BUFFER_BYTES).unwrap();
        assert!(chunk > 0, "should handle single variable");
        assert!(u64::from(chunk) * 3 <= u64::from(u32::MAX));
    }

    #[test]
    fn chunking_zero_variables() {
        assert!(
            compute_chunk_size(0, WGPU_MAX_BUFFER_BYTES).is_none(),
            "should fail with zero variables"
        );
    }
}