Skip to main content

rill_core/math/vector/
macros.rs

1//! Macros for convenient construction of vector expressions.
2//!
3//! This module provides macros that simplify working with the vector eDSL,
4//! allowing expressions in natural mathematical notation.
5//!
6//! ## Examples
7//! ```
8//! use rill_core::vector::prelude::*;
9//! use rill_core::vector::macros::*;
10//!
11//! let a = ScalarVector4::splat(1.0);
12//! let b = ScalarVector4::splat(2.0);
13//! let c = a + b; // regular vector operation
14//! assert_eq!(c, ScalarVector4::splat(3.0));
15//!
16//! // Apply expression to the entire slice
17//! let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
18//! let mut output = [0.0f32; 8];
19//! vec_map!(&input, &mut output, |x| x * 2.0 + 1.0);
20//! // output = [3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0]
21//! ```
22//!
23//! ## Available macros
24//! - [`vec_map!`] – applies a vector expression to the entire slice.
25
26use crate::math::vector::scalar::ScalarVector4;
27use crate::math::vector::traits::Vector;
28
/// Applies a vector expression to every element of a slice, four lanes at a time.
///
/// `vec_map!(input, output, |x| expr)` binds `x` to a `ScalarVector4` loaded
/// from each consecutive group of four `input` elements, evaluates `expr`, and
/// stores the resulting vector into the matching positions of `output`.
///
/// A trailing group of fewer than four elements is padded with
/// `Default::default()` before `expr` is evaluated; only the lanes that are
/// backed by real input are written back to `output`.
///
/// The expansion is a plain block of type `()`. It deliberately contains no
/// `return`, so it never leaves the calling function early (an empty input is
/// simply a no-op) and it can be used inside functions of any return type.
///
/// # Panics
///
/// Panics if `input` and `output` have different lengths.
#[macro_export]
macro_rules! vec_map {
    ($input:expr, $output:expr, |$x:ident| $($body:tt)*) => {{
        use $crate::math::vector::traits::Vector;
        use $crate::math::vector::scalar::ScalarVector4;
        const N: usize = 4;
        let input: &[_] = $input;
        let output: &mut [_] = $output;
        assert_eq!(input.len(), output.len(), "input and output slices must have equal length");

        let closure = |$x: ScalarVector4<_>| -> ScalarVector4<_> { $($body)* };

        // Split both slices at the last multiple of N: the heads hold only
        // complete groups, the tails hold the (possibly empty) leftover lanes.
        let split = input.len() - input.len() % N;
        let (head_in, tail_in) = input.split_at(split);
        let (head_out, tail_out) = output.split_at_mut(split);

        // Complete groups: load, evaluate, store.
        for (src, dst) in head_in.chunks_exact(N).zip(head_out.chunks_exact_mut(N)) {
            let loaded = <ScalarVector4<_>>::load(src);
            closure(loaded).store(dst);
        }

        // Leftover lanes: pad to a full vector, evaluate once, and copy back
        // only the lanes that correspond to real input elements.
        if !tail_in.is_empty() {
            let mut padded = [Default::default(); N];
            padded[..tail_in.len()].copy_from_slice(tail_in);
            let result = closure(<ScalarVector4<_>>::load(&padded[..]));
            for (lane, slot) in tail_out.iter_mut().enumerate() {
                *slot = result.extract(lane);
            }
        }
    }};
}
73
74pub use crate::vec_map;
75
76// -----------------------------------------------------------------------------
77// Tests
78// -----------------------------------------------------------------------------
79
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::vector::scalar::ScalarVector4;

    /// Two complete groups of four `f32` lanes: each element becomes `2x + 1`.
    #[test]
    fn test_vec_map_f32() {
        let src = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut dst = [0.0f32; 8];

        vec_map!(&src, &mut dst, |x| x * 2.0 + 1.0);

        let expected = [3.0f32, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0];
        assert_eq!(dst, expected);
    }

    /// One complete group of `f64` lanes: each element becomes `3x - 1`.
    #[test]
    fn test_vec_map_f64() {
        let src = [1.0f64, 2.0, 3.0, 4.0];
        let mut dst = [0.0f64; 4];

        vec_map!(&src, &mut dst, |x| x * 3.0 - 1.0);

        let expected = [2.0f64, 5.0, 8.0, 11.0];
        assert_eq!(dst, expected);
    }

    /// Zero-length slices are accepted; the call must not panic.
    #[test]
    fn test_vec_map_empty() {
        let src: [f32; 0] = [];
        let mut dst: [f32; 0] = [];
        vec_map!(&src, &mut dst, |x| x * 2.0);
    }

    /// Three elements exercise only the partial-group (remainder) path.
    #[test]
    fn test_vec_map_remainder() {
        let src = [1.0f32, 2.0, 3.0];
        let mut dst = [0.0f32; 3];

        vec_map!(&src, &mut dst, |x| x + 10.0);

        let expected = [11.0f32, 12.0, 13.0];
        assert_eq!(dst, expected);
    }
}