rill_core/math/vector/traits.rs
1use core::fmt;
2use core::ops::{Add, Div, Mul, Neg, Rem, Sub};
3
4use crate::Scalar;
5use crate::Transcendental;
6
7/// Core trait for vector types (basic operations).
8///
9/// Parameterised by element type `T: Scalar` and lane width `N`.
10pub trait Vector<T: Scalar, const N: usize>:
11 Copy
12 + Clone
13 + Send
14 + Sync
15 + 'static
16 + Default
17 + PartialEq
18 + fmt::Debug
19 + Add<Output = Self>
20 + Sub<Output = Self>
21 + Mul<Output = Self>
22 + Div<Output = Self>
23 + Rem<Output = Self>
24 + Neg<Output = Self>
25{
26 /// Construct a vector with all lanes set to the same value.
27 fn splat(value: T) -> Self;
28 /// Load a vector from a slice (panics if slice is too short).
29 fn load(slice: &[T]) -> Self;
30 /// Store the vector lanes into a slice.
31 fn store(&self, slice: &mut [T]);
32 /// Extract the value at the given lane index.
33 fn extract(&self, index: usize) -> T;
34 /// Return a new vector with the value at the given lane replaced.
35 fn insert(&self, index: usize, value: T) -> Self;
36
37 /// Lane-wise addition.
38 fn add(&self, other: &Self) -> Self;
39 /// Lane-wise subtraction.
40 fn sub(&self, other: &Self) -> Self;
41 /// Lane-wise multiplication.
42 fn mul(&self, other: &Self) -> Self;
43 /// Lane-wise division.
44 fn div(&self, other: &Self) -> Self;
45 /// Lane-wise remainder.
46 fn rem(&self, other: &Self) -> Self;
47 /// Lane-wise negation.
48 fn neg(&self) -> Self;
49
50 /// Lane-wise absolute value.
51 fn abs(&self) -> Self;
52 /// Lane-wise minimum.
53 fn min(&self, other: &Self) -> Self;
54 /// Lane-wise maximum.
55 fn max(&self, other: &Self) -> Self;
56 /// Lane-wise clamp to the inclusive range `[min, max]`.
57 fn clamp(&self, min: &Self, max: &Self) -> Self;
58}
59
60/// Trait for vector types with transcendental operations.
61///
62/// Only available for `T: Transcendental` (f32, f64).
63pub trait VectorTranscendental<T: Transcendental, const N: usize>: Vector<T, N> {
64 /// Lane-wise square root.
65 fn sqrt(&self) -> Self;
66 /// Lane-wise exponential (e^x).
67 fn exp(&self) -> Self;
68 /// Lane-wise natural logarithm.
69 fn ln(&self) -> Self;
70 /// Lane-wise sine (input in radians).
71 fn sin(&self) -> Self;
72 /// Lane-wise cosine (input in radians).
73 fn cos(&self) -> Self;
74 /// Lane-wise tangent (input in radians).
75 fn tan(&self) -> Self;
76}
77
78/// Scalar-vector arithmetic operations.
79///
80/// Each method broadcasts the scalar across all lanes.
81/// Blanket-implemented for all [`Vector`] types.
82pub trait VectorScalarOps<T: Scalar, const N: usize>: Vector<T, N> {
83 /// Add a scalar to every lane.
84 fn add_scalar(&self, scalar: T) -> Self {
85 self.add(&Self::splat(scalar))
86 }
87 /// Subtract a scalar from every lane.
88 fn sub_scalar(&self, scalar: T) -> Self {
89 self.sub(&Self::splat(scalar))
90 }
91 /// Multiply every lane by a scalar.
92 fn mul_scalar(&self, scalar: T) -> Self {
93 self.mul(&Self::splat(scalar))
94 }
95 /// Divide every lane by a scalar.
96 fn div_scalar(&self, scalar: T) -> Self {
97 self.div(&Self::splat(scalar))
98 }
99 /// Compute the remainder of every lane divided by a scalar.
100 fn rem_scalar(&self, scalar: T) -> Self {
101 self.rem(&Self::splat(scalar))
102 }
103}
104
105/// Blanket implementation: every [`Vector`] gets scalar ops for free.
106impl<T: Scalar, const N: usize, V: Vector<T, N>> VectorScalarOps<T, N> for V {}
107
108/// Blanket implementation: every [`Vector`] gets reduce ops for free.
109///
110/// Uses element-wise extraction and accumulation. SIMD types may override
111/// individual methods with shuffle-based reductions for better performance.
112impl<T: Scalar, const N: usize, V: Vector<T, N>> VectorReduce<T, N> for V {}
113
114/// Horizontal reduction operations (vector → scalar).
115pub trait VectorReduce<T: Scalar, const N: usize>: Vector<T, N> {
116 /// Sum of all lanes.
117 fn horizontal_sum(&self) -> T {
118 let mut sum = T::ZERO;
119 for i in 0..N {
120 sum += self.extract(i);
121 }
122 sum
123 }
124 /// Product of all lanes.
125 fn horizontal_product(&self) -> T {
126 let mut prod = T::ONE;
127 for i in 0..N {
128 prod *= self.extract(i);
129 }
130 prod
131 }
132 /// Minimum value across all lanes.
133 fn horizontal_min(&self) -> T {
134 let mut min = self.extract(0);
135 for i in 1..N {
136 min = min.min(self.extract(i));
137 }
138 min
139 }
140 /// Maximum value across all lanes.
141 fn horizontal_max(&self) -> T {
142 let mut max = self.extract(0);
143 for i in 1..N {
144 max = max.max(self.extract(i));
145 }
146 max
147 }
148 /// Arithmetic mean of all lanes.
149 fn horizontal_mean(&self) -> T {
150 let sum = self.horizontal_sum();
151 sum / T::from_usize(N)
152 }
153}
154
155/// Vector comparison and masking operations.
156///
157/// Produces a bitmask (or SIMD mask) from lane-wise comparisons,
158/// and allows selecting between two vectors based on a mask.
159pub trait VectorMask<T: Scalar, const N: usize> {
160 /// The mask type (e.g. `i32` bitmask or SIMD mask register).
161 type Mask;
162
163 /// Lane-wise equality comparison.
164 fn eq(&self, other: &Self) -> Self::Mask;
165 /// Lane-wise inequality comparison.
166 fn ne(&self, other: &Self) -> Self::Mask;
167 /// Lane-wise greater-than comparison.
168 fn gt(&self, other: &Self) -> Self::Mask;
169 /// Lane-wise greater-or-equal comparison.
170 fn ge(&self, other: &Self) -> Self::Mask;
171 /// Lane-wise less-than comparison.
172 fn lt(&self, other: &Self) -> Self::Mask;
173 /// Lane-wise less-or-equal comparison.
174 fn le(&self, other: &Self) -> Self::Mask;
175 /// Select lanes from `self` (where mask is truthy) or `other`.
176 fn select(&self, other: &Self, mask: Self::Mask) -> Self;
177 /// Returns true if all mask lanes are set.
178 fn all(mask: &Self::Mask) -> bool;
179}