Skip to main content

rill_core/math/vector/
traits.rs

1use core::fmt;
2use core::ops::{Add, Div, Mul, Neg, Rem, Sub};
3
4use crate::Scalar;
5use crate::Transcendental;
6
7/// Core trait for vector types (basic operations).
8///
9/// Parameterised by element type `T: Scalar` and lane width `N`.
10pub trait Vector<T: Scalar, const N: usize>:
11    Copy
12    + Clone
13    + Send
14    + Sync
15    + 'static
16    + Default
17    + PartialEq
18    + fmt::Debug
19    + Add<Output = Self>
20    + Sub<Output = Self>
21    + Mul<Output = Self>
22    + Div<Output = Self>
23    + Rem<Output = Self>
24    + Neg<Output = Self>
25{
26    /// Construct a vector with all lanes set to the same value.
27    fn splat(value: T) -> Self;
28    /// Load a vector from a slice (panics if slice is too short).
29    fn load(slice: &[T]) -> Self;
30    /// Store the vector lanes into a slice.
31    fn store(&self, slice: &mut [T]);
32    /// Extract the value at the given lane index.
33    fn extract(&self, index: usize) -> T;
34    /// Return a new vector with the value at the given lane replaced.
35    fn insert(&self, index: usize, value: T) -> Self;
36
37    /// Lane-wise addition.
38    fn add(&self, other: &Self) -> Self;
39    /// Lane-wise subtraction.
40    fn sub(&self, other: &Self) -> Self;
41    /// Lane-wise multiplication.
42    fn mul(&self, other: &Self) -> Self;
43    /// Lane-wise division.
44    fn div(&self, other: &Self) -> Self;
45    /// Lane-wise remainder.
46    fn rem(&self, other: &Self) -> Self;
47    /// Lane-wise negation.
48    fn neg(&self) -> Self;
49
50    /// Lane-wise absolute value.
51    fn abs(&self) -> Self;
52    /// Lane-wise minimum.
53    fn min(&self, other: &Self) -> Self;
54    /// Lane-wise maximum.
55    fn max(&self, other: &Self) -> Self;
56    /// Lane-wise clamp to the inclusive range `[min, max]`.
57    fn clamp(&self, min: &Self, max: &Self) -> Self;
58}
59
60/// Trait for vector types with transcendental operations.
61///
62/// Only available for `T: Transcendental` (f32, f64).
63pub trait VectorTranscendental<T: Transcendental, const N: usize>: Vector<T, N> {
64    /// Lane-wise square root.
65    fn sqrt(&self) -> Self;
66    /// Lane-wise exponential (e^x).
67    fn exp(&self) -> Self;
68    /// Lane-wise natural logarithm.
69    fn ln(&self) -> Self;
70    /// Lane-wise sine (input in radians).
71    fn sin(&self) -> Self;
72    /// Lane-wise cosine (input in radians).
73    fn cos(&self) -> Self;
74    /// Lane-wise tangent (input in radians).
75    fn tan(&self) -> Self;
76}
77
78/// Scalar-vector arithmetic operations.
79///
80/// Each method broadcasts the scalar across all lanes.
81/// Blanket-implemented for all [`Vector`] types.
82pub trait VectorScalarOps<T: Scalar, const N: usize>: Vector<T, N> {
83    /// Add a scalar to every lane.
84    fn add_scalar(&self, scalar: T) -> Self {
85        self.add(&Self::splat(scalar))
86    }
87    /// Subtract a scalar from every lane.
88    fn sub_scalar(&self, scalar: T) -> Self {
89        self.sub(&Self::splat(scalar))
90    }
91    /// Multiply every lane by a scalar.
92    fn mul_scalar(&self, scalar: T) -> Self {
93        self.mul(&Self::splat(scalar))
94    }
95    /// Divide every lane by a scalar.
96    fn div_scalar(&self, scalar: T) -> Self {
97        self.div(&Self::splat(scalar))
98    }
99    /// Compute the remainder of every lane divided by a scalar.
100    fn rem_scalar(&self, scalar: T) -> Self {
101        self.rem(&Self::splat(scalar))
102    }
103}
104
105/// Blanket implementation: every [`Vector`] gets scalar ops for free.
106impl<T: Scalar, const N: usize, V: Vector<T, N>> VectorScalarOps<T, N> for V {}
107
108/// Blanket implementation: every [`Vector`] gets reduce ops for free.
109///
110/// Uses element-wise extraction and accumulation. SIMD types may override
111/// individual methods with shuffle-based reductions for better performance.
112impl<T: Scalar, const N: usize, V: Vector<T, N>> VectorReduce<T, N> for V {}
113
114/// Horizontal reduction operations (vector → scalar).
115pub trait VectorReduce<T: Scalar, const N: usize>: Vector<T, N> {
116    /// Sum of all lanes.
117    fn horizontal_sum(&self) -> T {
118        let mut sum = T::ZERO;
119        for i in 0..N {
120            sum += self.extract(i);
121        }
122        sum
123    }
124    /// Product of all lanes.
125    fn horizontal_product(&self) -> T {
126        let mut prod = T::ONE;
127        for i in 0..N {
128            prod *= self.extract(i);
129        }
130        prod
131    }
132    /// Minimum value across all lanes.
133    fn horizontal_min(&self) -> T {
134        let mut min = self.extract(0);
135        for i in 1..N {
136            min = min.min(self.extract(i));
137        }
138        min
139    }
140    /// Maximum value across all lanes.
141    fn horizontal_max(&self) -> T {
142        let mut max = self.extract(0);
143        for i in 1..N {
144            max = max.max(self.extract(i));
145        }
146        max
147    }
148    /// Arithmetic mean of all lanes.
149    fn horizontal_mean(&self) -> T {
150        let sum = self.horizontal_sum();
151        sum / T::from_usize(N)
152    }
153}
154
155/// Vector comparison and masking operations.
156///
157/// Produces a bitmask (or SIMD mask) from lane-wise comparisons,
158/// and allows selecting between two vectors based on a mask.
159pub trait VectorMask<T: Scalar, const N: usize> {
160    /// The mask type (e.g. `i32` bitmask or SIMD mask register).
161    type Mask;
162
163    /// Lane-wise equality comparison.
164    fn eq(&self, other: &Self) -> Self::Mask;
165    /// Lane-wise inequality comparison.
166    fn ne(&self, other: &Self) -> Self::Mask;
167    /// Lane-wise greater-than comparison.
168    fn gt(&self, other: &Self) -> Self::Mask;
169    /// Lane-wise greater-or-equal comparison.
170    fn ge(&self, other: &Self) -> Self::Mask;
171    /// Lane-wise less-than comparison.
172    fn lt(&self, other: &Self) -> Self::Mask;
173    /// Lane-wise less-or-equal comparison.
174    fn le(&self, other: &Self) -> Self::Mask;
175    /// Select lanes from `self` (where mask is truthy) or `other`.
176    fn select(&self, other: &Self, mask: Self::Mask) -> Self;
177    /// Returns true if all mask lanes are set.
178    fn all(mask: &Self::Mask) -> bool;
179}