Skip to main content

jugar_web/
simd.rs

1//! SIMD-accelerated operations using trueno.
2//!
3//! This module provides SIMD-optimized computations for game physics and
4//! particle systems using the trueno compute library.
5//!
6//! ## Backend Selection
7//!
8//! Trueno automatically selects the best available backend at runtime:
9//! - **Tier 1**: WebGPU compute shaders (if available)
10//! - **Tier 2**: WASM SIMD 128-bit
11//! - **Tier 3**: Scalar fallback
12//!
13//! ## Usage
14//!
15//! ```rust,ignore
16//! use jugar_web::simd::{SimdVec2, batch_distance_squared};
17//!
18//! // Create SIMD-optimized 2D vectors
19//! let positions = vec![SimdVec2::new(10.0, 20.0), SimdVec2::new(30.0, 40.0)];
20//! let target = SimdVec2::new(0.0, 0.0);
21//!
22//! // Batch compute distances (SIMD accelerated)
//! let distances = batch_distance_squared(&positions, &target);
24//! ```
25
// `suboptimal_flops`: plain `a * b + c * d` is kept over `mul_add` because it reads more clearly in dot products.
27#![allow(clippy::suboptimal_flops, clippy::missing_const_for_fn)]
28
29use trueno::{Backend, Vector};
30
31/// A SIMD-optimized 2D vector using trueno.
32#[derive(Debug, Clone)]
33pub struct SimdVec2 {
34    data: Vector<f32>,
35}
36
37impl SimdVec2 {
38    /// Creates a new 2D vector.
39    #[must_use]
40    pub fn new(x: f32, y: f32) -> Self {
41        Self {
42            data: Vector::from_slice(&[x, y]),
43        }
44    }
45
46    /// Returns the X component.
47    #[must_use]
48    pub fn x(&self) -> f32 {
49        self.data.as_slice().first().copied().unwrap_or(0.0)
50    }
51
52    /// Returns the Y component.
53    #[must_use]
54    pub fn y(&self) -> f32 {
55        self.data.as_slice().get(1).copied().unwrap_or(0.0)
56    }
57
58    /// Computes the squared magnitude (avoids sqrt for performance).
59    #[must_use]
60    pub fn magnitude_squared(&self) -> f32 {
61        let x = self.x();
62        let y = self.y();
63        x * x + y * y
64    }
65
66    /// Computes the magnitude (length).
67    #[must_use]
68    pub fn magnitude(&self) -> f32 {
69        self.magnitude_squared().sqrt()
70    }
71
72    /// Adds another vector using SIMD-accelerated operation.
73    #[must_use]
74    pub fn add(&self, other: &Self) -> Self {
75        let result = self.data.add(&other.data).unwrap_or_else(|_| {
76            // Fallback for mismatched sizes
77            Vector::from_slice(&[self.x() + other.x(), self.y() + other.y()])
78        });
79        Self { data: result }
80    }
81
82    /// Subtracts another vector using SIMD-accelerated operation.
83    #[must_use]
84    pub fn sub(&self, other: &Self) -> Self {
85        let result = self.data.sub(&other.data).unwrap_or_else(|_| {
86            // Fallback for mismatched sizes
87            Vector::from_slice(&[self.x() - other.x(), self.y() - other.y()])
88        });
89        Self { data: result }
90    }
91
92    /// Multiplies by a scalar using SIMD-accelerated operation.
93    #[must_use]
94    pub fn scale(&self, scalar: f32) -> Self {
95        let result = self.data.scale(scalar).unwrap_or_else(|_| {
96            // Fallback
97            Vector::from_slice(&[self.x() * scalar, self.y() * scalar])
98        });
99        Self { data: result }
100    }
101
102    /// Computes dot product with another vector.
103    #[must_use]
104    pub fn dot(&self, other: &Self) -> f32 {
105        self.x() * other.x() + self.y() * other.y()
106    }
107}
108
109impl Default for SimdVec2 {
110    fn default() -> Self {
111        Self::new(0.0, 0.0)
112    }
113}
114
115/// Batch-computes squared distances from a set of positions to a target.
116///
117/// This uses SIMD acceleration when available for improved performance
118/// with large numbers of positions (e.g., particle systems).
119///
120/// # Arguments
121///
122/// * `positions` - Slice of position vectors
123/// * `target` - Target position to measure distance from
124///
125/// # Returns
126///
127/// Vector of squared distances (one per position)
128#[must_use]
129pub fn batch_distance_squared(positions: &[SimdVec2], target: &SimdVec2) -> Vec<f32> {
130    positions
131        .iter()
132        .map(|pos| {
133            let diff = pos.sub(target);
134            diff.magnitude_squared()
135        })
136        .collect()
137}
138
139/// SIMD-accelerated particle position update.
140///
141/// Updates particle positions based on velocities using batch operations.
142///
143/// # Arguments
144///
145/// * `positions` - Mutable slice of position X/Y pairs (interleaved)
146/// * `velocities` - Slice of velocity X/Y pairs (interleaved)
147/// * `dt` - Delta time in seconds
148///
149/// # Safety
150///
151/// Positions and velocities slices must have equal length.
152pub fn batch_update_positions(positions: &mut [f32], velocities: &[f32], dt: f32) {
153    if positions.len() != velocities.len() {
154        return;
155    }
156
157    // Create trueno vectors for batch operation
158    let pos_vec = Vector::from_slice(positions);
159    let vel_vec = Vector::from_slice(velocities);
160
161    // SIMD-accelerated: positions += velocities * dt
162    if let Ok(scaled_vel) = vel_vec.scale(dt) {
163        if let Ok(new_pos) = pos_vec.add(&scaled_vel) {
164            // Copy results back
165            let result_slice = new_pos.as_slice();
166            for (i, val) in result_slice.iter().enumerate() {
167                if i < positions.len() {
168                    positions[i] = *val;
169                }
170            }
171        }
172    }
173}
174
175/// SIMD-accelerated batch particle physics update.
176///
177/// Updates positions, applies gravity, and returns updated velocities.
178///
179/// # Arguments
180///
181/// * `positions_x` - Particle X positions
182/// * `positions_y` - Particle Y positions
183/// * `velocities_x` - Particle X velocities
184/// * `velocities_y` - Particle Y velocities
185/// * `gravity` - Gravity acceleration
186/// * `dt` - Delta time in seconds
187pub fn batch_particle_update(
188    positions_x: &mut [f32],
189    positions_y: &mut [f32],
190    velocities_x: &[f32],
191    velocities_y: &mut [f32],
192    gravity: f32,
193    dt: f32,
194) {
195    let n = positions_x.len();
196    if n == 0 || positions_y.len() != n || velocities_x.len() != n || velocities_y.len() != n {
197        return;
198    }
199
200    // Create trueno vectors
201    let pos_x = Vector::from_slice(positions_x);
202    let pos_y = Vector::from_slice(positions_y);
203    let vel_x = Vector::from_slice(velocities_x);
204    let vel_y = Vector::from_slice(velocities_y);
205
206    // SIMD-accelerated position update: pos += vel * dt
207    if let Ok(scaled_vx) = vel_x.scale(dt) {
208        if let Ok(new_pos_x) = pos_x.add(&scaled_vx) {
209            for (i, &val) in new_pos_x.as_slice().iter().enumerate() {
210                if i < positions_x.len() {
211                    positions_x[i] = val;
212                }
213            }
214        }
215    }
216
217    if let Ok(scaled_vy) = vel_y.scale(dt) {
218        if let Ok(new_pos_y) = pos_y.add(&scaled_vy) {
219            for (i, &val) in new_pos_y.as_slice().iter().enumerate() {
220                if i < positions_y.len() {
221                    positions_y[i] = val;
222                }
223            }
224        }
225    }
226
227    // Apply gravity to Y velocities
228    let gravity_delta = gravity * dt;
229    for vy in velocities_y.iter_mut() {
230        *vy += gravity_delta;
231    }
232}
233
234/// Computes collision detection for a ball against multiple paddles.
235///
236/// Returns the index of the first paddle that collides, if any.
237///
238/// # Arguments
239///
240/// * `ball_x` - Ball X position
241/// * `ball_y` - Ball Y position
242/// * `ball_radius` - Ball radius
243/// * `paddle_xs` - Paddle X positions
244/// * `paddle_ys` - Paddle Y positions
245/// * `paddle_heights` - Paddle heights
246/// * `paddle_widths` - Paddle widths
247///
248/// # Returns
249///
250/// Index of colliding paddle, or None if no collision
251#[must_use]
252#[allow(clippy::too_many_arguments)]
253pub fn check_paddle_collisions(
254    ball_x: f32,
255    ball_y: f32,
256    ball_radius: f32,
257    paddle_xs: &[f32],
258    paddle_ys: &[f32],
259    paddle_heights: &[f32],
260    paddle_widths: &[f32],
261) -> Option<usize> {
262    let n = paddle_xs.len();
263    if n == 0 || paddle_ys.len() != n || paddle_heights.len() != n || paddle_widths.len() != n {
264        return None;
265    }
266
267    // Create trueno vector for batch ball X subtraction
268    let ball_x_vec = Vector::from_slice(&vec![ball_x; n]);
269    let paddle_x_vec = Vector::from_slice(paddle_xs);
270
271    // SIMD-accelerated: compute all X distances at once
272    let x_distances = ball_x_vec.sub(&paddle_x_vec).ok()?;
273    let x_dist_slice = x_distances.as_slice();
274
275    // Check each paddle for collision
276    for i in 0..n {
277        let x_dist = x_dist_slice.get(i).copied().unwrap_or(f32::MAX).abs();
278        let half_width = paddle_widths.get(i).copied().unwrap_or(0.0) / 2.0;
279        let half_height = paddle_heights.get(i).copied().unwrap_or(0.0) / 2.0;
280
281        // X axis collision check
282        if x_dist < half_width + ball_radius {
283            // Y axis collision check
284            let y_dist = (ball_y - paddle_ys.get(i).copied().unwrap_or(0.0)).abs();
285            if y_dist < half_height + ball_radius {
286                return Some(i);
287            }
288        }
289    }
290
291    None
292}
293
/// Coarse category of the compute backend that trueno dispatches work to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ComputeBackend {
    /// Plain scalar CPU code with no vector instructions.
    CpuScalar,
    /// CPU vector instructions (SSE/AVX/NEON).
    CpuSimd,
    /// WebAssembly 128-bit SIMD (SIMD128).
    WasmSimd,
    /// GPU compute (WebGPU/Vulkan/Metal).
    Gpu,
}
306
307impl core::fmt::Display for ComputeBackend {
308    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
309        match self {
310            Self::CpuScalar => write!(f, "CPU Scalar"),
311            Self::CpuSimd => write!(f, "CPU SIMD"),
312            Self::WasmSimd => write!(f, "WASM SIMD128"),
313            Self::Gpu => write!(f, "GPU Compute"),
314        }
315    }
316}
317
318/// Converts trueno Backend to our ComputeBackend enum.
319#[must_use]
320pub fn trueno_backend_to_compute_backend(backend: Backend) -> ComputeBackend {
321    match backend {
322        Backend::Scalar => ComputeBackend::CpuScalar,
323        Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 | Backend::NEON => {
324            ComputeBackend::CpuSimd
325        }
326        Backend::WasmSIMD => ComputeBackend::WasmSimd,
327        Backend::GPU => ComputeBackend::Gpu,
328        Backend::Auto => ComputeBackend::CpuSimd, // Auto typically selects SIMD
329    }
330}
331
332/// Detects the best available compute backend.
333///
334/// This queries trueno's runtime backend selection to determine
335/// what SIMD capabilities are available.
336#[must_use]
337pub fn detect_compute_backend() -> ComputeBackend {
338    let test_vec = Vector::<f32>::from_slice(&[1.0, 2.0, 3.0, 4.0]);
339    trueno_backend_to_compute_backend(test_vec.backend())
340}
341
342/// Benchmark result for SIMD operations.
343#[derive(Debug, Clone)]
344pub struct SimdBenchmark {
345    /// Operation name
346    pub operation: String,
347    /// Number of elements processed
348    pub element_count: usize,
349    /// Backend used
350    pub backend: ComputeBackend,
351    /// Whether SIMD acceleration was applied
352    pub simd_accelerated: bool,
353}
354
355impl SimdBenchmark {
356    /// Creates a new benchmark result.
357    #[must_use]
358    pub fn new(operation: &str, element_count: usize) -> Self {
359        Self {
360            operation: operation.to_string(),
361            element_count,
362            backend: detect_compute_backend(),
363            simd_accelerated: true,
364        }
365    }
366}
367
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::float_cmp)]
mod tests {
    use super::*;

    /// Default tolerance for approximate `f32` comparisons.
    const EPS: f32 = 0.001;

    // Shared two-paddle fixture for the collision tests: a left paddle centred
    // at (50, 300) and a right paddle centred at (750, 300), each 20 wide and
    // 100 tall.
    const PADDLE_XS: [f32; 2] = [50.0, 750.0];
    const PADDLE_YS: [f32; 2] = [300.0, 300.0];
    const PADDLE_HEIGHTS: [f32; 2] = [100.0, 100.0];
    const PADDLE_WIDTHS: [f32; 2] = [20.0, 20.0];

    /// Asserts that `actual` is strictly within `tol` of `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected} (±{tol}), got {actual}"
        );
    }

    /// Runs `check_paddle_collisions` against the shared two-paddle fixture.
    fn collide_with_fixture(ball_x: f32, ball_y: f32, ball_radius: f32) -> Option<usize> {
        check_paddle_collisions(
            ball_x,
            ball_y,
            ball_radius,
            &PADDLE_XS,
            &PADDLE_YS,
            &PADDLE_HEIGHTS,
            &PADDLE_WIDTHS,
        )
    }

    // =========================================================================
    // SimdVec2 Tests
    // =========================================================================

    #[test]
    fn test_simd_vec2_new() {
        let v = SimdVec2::new(3.0, 4.0);
        assert_close(v.x(), 3.0, EPS);
        assert_close(v.y(), 4.0, EPS);
    }

    #[test]
    fn test_simd_vec2_default() {
        let v = SimdVec2::default();
        assert_close(v.x(), 0.0, EPS);
        assert_close(v.y(), 0.0, EPS);
    }

    #[test]
    fn test_simd_vec2_magnitude_squared() {
        assert_close(SimdVec2::new(3.0, 4.0).magnitude_squared(), 25.0, EPS);
    }

    #[test]
    fn test_simd_vec2_magnitude() {
        assert_close(SimdVec2::new(3.0, 4.0).magnitude(), 5.0, EPS);
    }

    #[test]
    fn test_simd_vec2_add() {
        let c = SimdVec2::new(1.0, 2.0).add(&SimdVec2::new(3.0, 4.0));
        assert_close(c.x(), 4.0, EPS);
        assert_close(c.y(), 6.0, EPS);
    }

    #[test]
    fn test_simd_vec2_sub() {
        let c = SimdVec2::new(5.0, 7.0).sub(&SimdVec2::new(2.0, 3.0));
        assert_close(c.x(), 3.0, EPS);
        assert_close(c.y(), 4.0, EPS);
    }

    #[test]
    fn test_simd_vec2_scale() {
        let s = SimdVec2::new(2.0, 3.0).scale(2.0);
        assert_close(s.x(), 4.0, EPS);
        assert_close(s.y(), 6.0, EPS);
    }

    #[test]
    fn test_simd_vec2_dot() {
        let a = SimdVec2::new(1.0, 2.0);
        let b = SimdVec2::new(3.0, 4.0);
        assert_close(a.dot(&b), 11.0, EPS);
    }

    // =========================================================================
    // Batch Operations Tests
    // =========================================================================

    #[test]
    fn test_batch_distance_squared() {
        let positions = [
            SimdVec2::new(3.0, 0.0),
            SimdVec2::new(0.0, 4.0),
            SimdVec2::new(3.0, 4.0),
        ];
        let target = SimdVec2::new(0.0, 0.0);

        let distances = batch_distance_squared(&positions, &target);

        assert_eq!(distances.len(), 3);
        assert_close(distances[0], 9.0, EPS); // 3^2
        assert_close(distances[1], 16.0, EPS); // 4^2
        assert_close(distances[2], 25.0, EPS); // 3^2 + 4^2
    }

    #[test]
    fn test_batch_distance_squared_empty() {
        let target = SimdVec2::new(0.0, 0.0);
        assert!(batch_distance_squared(&[], &target).is_empty());
    }

    #[test]
    fn test_batch_update_positions() {
        let mut positions: [f32; 4] = [0.0, 0.0, 10.0, 10.0]; // Two 2D positions
        let velocities: [f32; 4] = [100.0, 200.0, -50.0, -100.0];

        batch_update_positions(&mut positions, &velocities, 0.1);

        assert_close(positions[0], 10.0, EPS);
        assert_close(positions[1], 20.0, EPS);
        assert_close(positions[2], 5.0, EPS);
        assert_close(positions[3], 0.0, EPS);
    }

    #[test]
    fn test_batch_update_positions_mismatched_lengths() {
        let mut positions: [f32; 2] = [0.0, 0.0];
        let velocities: [f32; 1] = [100.0]; // Mismatched length

        // Should not panic, just skip
        batch_update_positions(&mut positions, &velocities, 0.1);

        // Every position unchanged
        assert_eq!(positions, [0.0, 0.0]);
    }

    #[test]
    fn test_batch_particle_update() {
        let mut pos_x: [f32; 2] = [0.0, 10.0];
        let mut pos_y: [f32; 2] = [100.0, 200.0];
        let vel_x: [f32; 2] = [50.0, -25.0];
        let mut vel_y: [f32; 2] = [0.0, 10.0];

        batch_particle_update(&mut pos_x, &mut pos_y, &vel_x, &mut vel_y, 100.0, 0.1);

        // X positions: pos + vel * dt
        assert_close(pos_x[0], 5.0, 0.01);
        assert_close(pos_x[1], 7.5, 0.01);

        // Y positions use the pre-gravity velocities
        assert_close(pos_y[0], 100.0, 0.01);
        assert_close(pos_y[1], 201.0, 0.01);

        // Gravity applied to Y velocities: vel + gravity * dt
        assert_close(vel_y[0], 10.0, 0.01);
        assert_close(vel_y[1], 20.0, 0.01);
    }

    #[test]
    fn test_batch_particle_update_mismatched_lengths() {
        let mut pos_x: [f32; 2] = [1.0, 2.0];
        let mut pos_y: [f32; 1] = [3.0]; // Mismatched length
        let vel_x: [f32; 2] = [1.0, 1.0];
        let mut vel_y: [f32; 2] = [1.0, 1.0];

        batch_particle_update(&mut pos_x, &mut pos_y, &vel_x, &mut vel_y, 9.8, 1.0);

        assert_eq!(pos_x, [1.0, 2.0]);
        assert_eq!(pos_y, [3.0]);
        assert_eq!(vel_y, [1.0, 1.0]);
    }

    // =========================================================================
    // Collision Detection Tests
    // =========================================================================

    #[test]
    fn test_check_paddle_collisions_hit() {
        // Ball near left paddle
        assert_eq!(collide_with_fixture(55.0, 300.0, 10.0), Some(0));
    }

    #[test]
    fn test_check_paddle_collisions_hit_right_paddle() {
        assert_eq!(collide_with_fixture(745.0, 300.0, 10.0), Some(1));
    }

    #[test]
    fn test_check_paddle_collisions_miss() {
        // Ball in center, no collision
        assert!(collide_with_fixture(400.0, 300.0, 10.0).is_none());
    }

    #[test]
    fn test_check_paddle_collisions_touching_edge_is_miss() {
        // |70 - 50| == 20 == half_width + radius: overlap must be strict.
        assert!(collide_with_fixture(70.0, 300.0, 10.0).is_none());
    }

    #[test]
    fn test_check_paddle_collisions_y_miss() {
        // X overlaps the left paddle, but the ball is far below it.
        assert!(collide_with_fixture(55.0, 400.0, 10.0).is_none());
    }

    #[test]
    fn test_check_paddle_collisions_first_match_wins() {
        let result = check_paddle_collisions(
            102.0,
            300.0,
            5.0,
            &[100.0, 105.0],
            &PADDLE_YS,
            &PADDLE_HEIGHTS,
            &PADDLE_WIDTHS,
        );
        assert_eq!(result, Some(0));
    }

    #[test]
    fn test_check_paddle_collisions_mismatched_lengths() {
        let result = check_paddle_collisions(
            55.0,
            300.0,
            10.0,
            &PADDLE_XS,
            &PADDLE_YS[..1],
            &PADDLE_HEIGHTS,
            &PADDLE_WIDTHS,
        );
        assert!(result.is_none());
    }

    #[test]
    fn test_check_paddle_collisions_empty() {
        let result = check_paddle_collisions(400.0, 300.0, 10.0, &[], &[], &[], &[]);

        assert!(result.is_none());
    }

    // =========================================================================
    // Backend Detection Tests
    // =========================================================================

    #[test]
    fn test_detect_compute_backend() {
        let backend = detect_compute_backend();
        // Should be one of the valid backends
        assert!(matches!(
            backend,
            ComputeBackend::CpuScalar
                | ComputeBackend::CpuSimd
                | ComputeBackend::WasmSimd
                | ComputeBackend::Gpu
        ));
    }

    #[test]
    fn test_compute_backend_display() {
        assert_eq!(format!("{}", ComputeBackend::CpuScalar), "CPU Scalar");
        assert_eq!(format!("{}", ComputeBackend::CpuSimd), "CPU SIMD");
        assert_eq!(format!("{}", ComputeBackend::WasmSimd), "WASM SIMD128");
        assert_eq!(format!("{}", ComputeBackend::Gpu), "GPU Compute");
    }

    #[test]
    fn test_simd_benchmark_new() {
        let bench = SimdBenchmark::new("particle_update", 1000);

        assert_eq!(bench.operation, "particle_update");
        assert_eq!(bench.element_count, 1000);
        assert_eq!(bench.backend, detect_compute_backend());
        // `simd_accelerated` must hold whenever a SIMD-capable backend was detected.
        if bench.backend != ComputeBackend::CpuScalar {
            assert!(bench.simd_accelerated);
        }
    }

    #[test]
    fn test_trueno_backend_to_compute_backend() {
        assert_eq!(
            trueno_backend_to_compute_backend(Backend::Scalar),
            ComputeBackend::CpuScalar
        );
        for simd in [
            Backend::SSE2,
            Backend::AVX,
            Backend::AVX2,
            Backend::AVX512,
            Backend::NEON,
            Backend::Auto,
        ] {
            assert_eq!(
                trueno_backend_to_compute_backend(simd),
                ComputeBackend::CpuSimd
            );
        }
        assert_eq!(
            trueno_backend_to_compute_backend(Backend::WasmSIMD),
            ComputeBackend::WasmSimd
        );
        assert_eq!(
            trueno_backend_to_compute_backend(Backend::GPU),
            ComputeBackend::Gpu
        );
    }
}