Skip to main content

ggmath/
scalar.rs

1use core::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Neg, Not, Rem, Shl, Shr, Sub};
2
3use crate::{Aligned, Alignment, Length, SupportedLength, Unaligned, Vector};
4
5/// A trait for types that can be stored in vectors.
6///
7/// All scalars must implement the [`Copy`] trait, and the
8/// [`ScalarBackend<N, A>`] trait which controls the internal representation and
9/// function implementations of the scalar's math types.
10///
11/// For simple implementations there is the [`ScalarDefault`] trait which
12/// provides a default implementation for [`ScalarBackend`].
13///
14/// # Example
15///
16/// ```
17/// use ggmath::{Scalar, ScalarDefault, Vec2, Vec3, Vec4, vec2, vec3, vec4};
18///
19/// #[derive(Debug, Clone, Copy)]
20/// struct Foo(f32);
21///
22/// impl Scalar for Foo {}
23///
24/// impl ScalarDefault for Foo {}
25///
26/// let v2: Vec2<Foo> = vec2!(Foo(1.0), Foo(2.0));
27/// let v3: Vec3<Foo> = vec3!(Foo(1.0), Foo(2.0), Foo(3.0));
28/// let v4: Vec4<Foo> = vec4!(Foo(1.0), Foo(2.0), Foo(3.0), Foo(4.0));
29///
30/// println!("{v2:?}, {v3:?}, {v4:?}");
31/// ```
32pub trait Scalar:
33    Copy
34    + ScalarBackend<2, Aligned>
35    + ScalarBackend<3, Aligned>
36    + ScalarBackend<4, Aligned>
37    + ScalarBackend<2, Unaligned>
38    + ScalarBackend<3, Unaligned>
39    + ScalarBackend<4, Unaligned>
40{
41}
42
43/// A trait to control the implementation of math types.
44///
45/// More specifically, this trait controls the internal representation and
46/// function implementations of math types with `N` as their length, `Self` as
47/// their scalar type, and `A` as their alignment.
48///
49/// This trait is generic over `N` and `A` (length and alignment). This means
/// that it can be implemented either separately for each length and alignment,
51/// or using one implementation that is generic over length and alignment.
52///
53/// The [`ScalarDefault`] trait offers a default implementation for
54/// [`ScalarBackend`] that is useful when SIMD optimizations are unneeded or
55/// impossible.
56///
57/// # Safety
58///
59/// The [`ScalarBackend::VectorRepr`] type must respect the memory-layout
60/// requirements of [`Vector<N, Self, A>`].
61///
62/// # SIMD
63///
64/// SIMD optimizations can be made the hard way or the easy way.
65///
66/// The hard way is using intrinsics. For each target architecture you want to
67/// support you'd need to:
68///
69/// - Implement `ScalarBackend<2, Aligned>` using intrinsics.
70/// - Implement `ScalarBackend<3, Aligned>` using intrinsics.
71/// - Implement `ScalarBackend<4, Aligned>` using intrinsics.
72/// - Write an empty implementation for `ScalarBackend<..., Unaligned>`.
73///
74/// The easy way is using existing math types.
75///
76/// For example, if your scalar type is a wrapper around `f32`, you could use
77/// `Vector<N, f32, A>` as the internal representation for
78/// `Vector<N, { your scalar }, A>`, then convert between the two in
79/// function implementations.
80///
81/// # Example
82///
/// Let's define a custom scalar type that is a wrapper around `f32`:
84///
85/// ```
86/// use ggmath::Scalar;
87///
88/// #[repr(transparent)]
89/// #[derive(Clone, Copy)]
90/// struct Foo(f32);
91///
92/// impl Scalar for Foo {}
93///
94/// // This needs to be replaced with a manual implementation.
95/// impl ggmath::ScalarDefault for Foo {}
96/// ```
97///
/// Without the `ScalarDefault` implementation this fails to compile because
/// `ScalarBackend` isn't implemented. Let's implement it manually, using
/// `Vector<N, f32, A>` as our `VectorRepr`:
100///
101/// ```
102/// # use ggmath::Scalar;
103/// #
104/// # #[repr(transparent)]
105/// # #[derive(Clone, Copy)]
106/// # struct Foo(f32);
107/// #
108/// # impl Scalar for Foo {}
109/// #
110/// use ggmath::{Alignment, Length, ScalarBackend, SupportedLength, Vector};
111///
112/// // SAFETY: Because `Foo` is a wrapper around `f32`, any internal
113/// // representation that `Vector<N, f32, A>` may use is also a valid
114/// // representation for `Vector<N, Foo, A>`.
115/// unsafe impl<const N: usize, A: Alignment> ScalarBackend<N, A> for Foo
116/// where
117///     Length<N>: SupportedLength,
118/// {
119///     type VectorRepr = Vector<N, f32, A>;
120/// }
121/// ```
122///
123/// Now whenever `f32` vectors have SIMD alignment, our vectors have the same
124/// alignment too.
125///
/// Let's implement addition for `Foo` by adding up the internal `f32`s:
127///
128/// ```
129/// # #[repr(transparent)]
130/// # #[derive(Clone, Copy)]
131/// # struct Foo(f32);
132/// #
133/// use core::ops::Add;
134///
135/// impl Add for Foo {
136///     type Output = Self;
137///
138///     #[inline]
139///     fn add(self, rhs: Self) -> Self::Output {
140///         Self(self.0 + rhs.0)
141///     }
142/// }
143/// ```
144///
/// An implementation of vector addition that is consistent with `Foo` addition
146/// should add up the internal `f32` vectors just like `Foo` addition adds up
147/// the internal `f32`s.
148///
149/// To implement optimized vector addition we need functions for converting
150/// between `Foo` vectors and `f32` vectors. The builtin functions for
151/// conversions are [`Vector::repr`] and [`Vector::from_repr`]. The latter is an
152/// unsafe function because the internal representation of a vector might have
/// fewer memory safety guarantees than the outer vector.
154///
/// Let's make an extension method for `f32` vectors that converts them to `Foo`
156/// vectors:
157///
158/// ```
159/// # use ggmath::Scalar;
160/// #
161/// # #[repr(transparent)]
162/// # #[derive(Clone, Copy)]
163/// # struct Foo(f32);
164/// #
165/// # impl Scalar for Foo {}
166/// #
167/// # use ggmath::{Alignment, Length, ScalarBackend, SupportedLength, Vector};
168/// #
169/// # unsafe impl<const N: usize, A: Alignment> ScalarBackend<N, A> for Foo
170/// # where
171/// #     Length<N>: SupportedLength,
172/// # {
173/// #     type VectorRepr = Vector<N, f32, A>;
174/// # }
175/// #
176/// trait ToFoo {
177///     type Output;
178///
179///     fn to_foo(self) -> Self::Output;
180/// }
181///
182/// impl<const N: usize, A: Alignment> ToFoo for Vector<N, f32, A>
183/// where
184///     Length<N>: SupportedLength,
185/// {
186///     type Output = Vector<N, Foo, A>;
187///
188///     #[inline]
189///     fn to_foo(self) -> Self::Output {
190///         // SAFETY: Any value of `f32` is a valid value of `Foo`, so any
191///         // value of an `f32` vector is a valid value of a `Foo` vector.
192///         unsafe { Vector::from_repr(self) }
193///     }
194/// }
195/// ```
196///
/// Now that everything is ready, let's implement [`ScalarBackend::vec_add`] which
198/// controls the implementation for vector addition:
199///
200/// ```
201/// # use ggmath::Scalar;
202/// #
203/// # #[repr(transparent)]
204/// # #[derive(Clone, Copy)]
205/// # struct Foo(f32);
206/// #
207/// # impl Scalar for Foo {}
208/// #
209/// # use ggmath::{Alignment, Length, ScalarBackend, SupportedLength, Vector};
210/// #
211/// # trait ToFoo {
212/// #     type Output;
213/// #
214/// #     fn to_foo(self) -> Self::Output;
215/// # }
216/// #
217/// # impl<const N: usize, A: Alignment> ToFoo for Vector<N, f32, A>
218/// # where
219/// #     Length<N>: SupportedLength,
220/// # {
221/// #     type Output = Vector<N, Foo, A>;
222/// #
223/// #     #[inline]
224/// #     fn to_foo(self) -> Self::Output {
225/// #         // SAFETY: Any value of `f32` is a valid value of `Foo`, so any
226/// #         // value of an `f32` vector is a valid value of a `Foo` vector.
227/// #         unsafe { Vector::from_repr(self) }
228/// #     }
229/// # }
230/// #
231/// // SAFETY: Because `Foo` is a wrapper around `f32`, any internal
232/// // representation that `Vector<N, f32, A>` may use is also a valid
233/// // representation for `Vector<N, Foo, A>`.
234/// unsafe impl<const N: usize, A: Alignment> ScalarBackend<N, A> for Foo
235/// where
236///     Length<N>: SupportedLength,
237/// {
238///     type VectorRepr = Vector<N, f32, A>;
239///     
240///     #[inline]
241///     fn vec_add(vec: Vector<N, Self, A>, rhs: Vector<N, Self, A>) -> Vector<N, Self, A> {
242///         (vec.repr() + rhs.repr()).to_foo()
243///     }
244/// }
245/// ```
246///
247/// Now `Foo` vector addition has whatever SIMD optimizations `f32` vectors
248/// have. This pattern can be expanded for all operators and for any
249/// extension-trait `Foo` vectors implement.
250#[diagnostic::on_unimplemented(
251    message = "`{Self}` is missing an implementation for `ScalarBackend`",
252    note = "consider implementing `ScalarDefault` for `{Self}`"
253)]
254pub unsafe trait ScalarBackend<const N: usize, A: Alignment>
255where
256    Length<N>: SupportedLength,
257{
258    /// The internal representation of [`Vector<N, Self, A>`].
259    ///
260    /// This type must respect the memory layout requirements of
261    /// [`Vector<N, Self, A>`].
262    type VectorRepr: Copy;
263
264    /// Overridable implementation of the `==` operator for vectors.
265    #[inline]
266    fn vec_eq(vec: &Vector<N, Self, A>, other: &Vector<N, Self, A>) -> bool
267    where
268        Self: Scalar + PartialEq,
269    {
270        (0..N).all(|i| vec[i] == other[i])
271    }
272
273    /// Overridable implementation of the `!=` operator for vectors.
274    #[inline]
275    fn vec_ne(vec: &Vector<N, Self, A>, other: &Vector<N, Self, A>) -> bool
276    where
277        Self: Scalar + PartialEq,
278    {
279        !Self::vec_eq(vec, other)
280    }
281
282    /// Overridable implementation of the `-` operator for vectors.
283    #[inline]
284    fn vec_neg(vec: Vector<N, Self, A>) -> Vector<N, Self, A>
285    where
286        Self: Scalar + Neg<Output = Self>,
287    {
288        vec.map(Self::neg)
289    }
290
291    /// Overridable implementation of the `!` operator for vectors.
292    #[inline]
293    fn vec_not(vec: Vector<N, Self, A>) -> Vector<N, Self, A>
294    where
295        Self: Scalar + Not<Output = Self>,
296    {
297        vec.map(Self::not)
298    }
299
300    /// Overridable implementation of the `+` operator for vectors.
301    #[inline]
302    fn vec_add(vec: Vector<N, Self, A>, rhs: Vector<N, Self, A>) -> Vector<N, Self, A>
303    where
304        Self: Scalar + Add<Output = Self>,
305    {
306        Vector::from_fn(|i| vec[i] + rhs[i])
307    }
308
309    /// Overridable implementation of the `-` operator for vectors.
310    #[inline]
311    fn vec_sub(vec: Vector<N, Self, A>, rhs: Vector<N, Self, A>) -> Vector<N, Self, A>
312    where
313        Self: Scalar + Sub<Output = Self>,
314    {
315        Vector::from_fn(|i| vec[i] - rhs[i])
316    }
317
318    /// Overridable implementation of the `*` operator for vectors.
319    #[inline]
320    fn vec_mul(vec: Vector<N, Self, A>, rhs: Vector<N, Self, A>) -> Vector<N, Self, A>
321    where
322        Self: Scalar + Mul<Output = Self>,
323    {
324        Vector::from_fn(|i| vec[i] * rhs[i])
325    }
326
327    /// Overridable implementation of the `/` operator for vectors.
328    #[inline]
329    fn vec_div(vec: Vector<N, Self, A>, rhs: Vector<N, Self, A>) -> Vector<N, Self, A>
330    where
331        Self: Scalar + Div<Output = Self>,
332    {
333        Vector::from_fn(|i| vec[i] / rhs[i])
334    }
335
336    /// Overridable implementation of the `%` operator for vectors.    
337    #[inline]
338    fn vec_rem(vec: Vector<N, Self, A>, rhs: Vector<N, Self, A>) -> Vector<N, Self, A>
339    where
340        Self: Scalar + Rem<Output = Self>,
341    {
342        Vector::from_fn(|i| vec[i] % rhs[i])
343    }
344
345    /// Overridable implementation of the `<<` operator for vectors.
346    #[inline]
347    fn vec_shl(vec: Vector<N, Self, A>, rhs: Vector<N, Self, A>) -> Vector<N, Self, A>
348    where
349        Self: Scalar + Shl<Output = Self>,
350    {
351        Vector::from_fn(|i| vec[i] << rhs[i])
352    }
353
354    /// Overridable implementation of the `>>` operator for vectors.
355    #[inline]
356    fn vec_shr(vec: Vector<N, Self, A>, rhs: Vector<N, Self, A>) -> Vector<N, Self, A>
357    where
358        Self: Scalar + Shr<Output = Self>,
359    {
360        Vector::from_fn(|i| vec[i] >> rhs[i])
361    }
362
363    /// Overridable implementation of the `&` operator for vectors.
364    #[inline]
365    fn vec_bitand(vec: Vector<N, Self, A>, rhs: Vector<N, Self, A>) -> Vector<N, Self, A>
366    where
367        Self: Scalar + BitAnd<Output = Self>,
368    {
369        Vector::from_fn(|i| vec[i] & rhs[i])
370    }
371
372    /// Overridable implementation of the `|` operator for vectors.
373    #[inline]
374    fn vec_bitor(vec: Vector<N, Self, A>, rhs: Vector<N, Self, A>) -> Vector<N, Self, A>
375    where
376        Self: Scalar + BitOr<Output = Self>,
377    {
378        Vector::from_fn(|i| vec[i] | rhs[i])
379    }
380
381    /// Overridable implementation of the `^` operator for vectors.
382    #[inline]
383    fn vec_bitxor(vec: Vector<N, Self, A>, rhs: Vector<N, Self, A>) -> Vector<N, Self, A>
384    where
385        Self: Scalar + BitXor<Output = Self>,
386    {
387        Vector::from_fn(|i| vec[i] ^ rhs[i])
388    }
389}
390
/// A default implementation for [`ScalarBackend`].
///
/// This is a marker trait with no items. A type that implements it together
/// with [`Scalar`] receives its [`ScalarBackend`] implementations from a
/// blanket implementation, for every supported length and every alignment.
///
/// This trait is for simple implementations of the [`Scalar`] trait which don't
/// require any SIMD optimizations.
///
/// Don't use this trait as a generic bound because types that implement
/// [`ScalarDefault`] today might silently switch to manually implementing
/// [`ScalarBackend`] in the future.
///
/// # Example
///
/// ```
/// use ggmath::{Scalar, ScalarDefault};
///
/// #[derive(Debug, Clone, Copy)]
/// struct Foo(f32);
///
/// impl Scalar for Foo {}
///
/// impl ScalarDefault for Foo {}
///
/// // later we can swap this for a manual implementation of `ScalarBackend` to
/// // add SIMD optimizations.
/// ```
// NOTE(review): the message below names `ScalarBackend` rather than this
// trait. Presumably that is deliberate: a type without any backend fails the
// `ScalarDefault` bound of the blanket `ScalarBackend` impl, so this is the
// diagnostic the user ends up seeing. Confirm before changing the wording.
#[diagnostic::on_unimplemented(
    message = "`{Self}` is missing an implementation for `ScalarBackend`",
    note = "consider implementing `ScalarDefault` for `{Self}`"
)]
pub trait ScalarDefault {}
420
421unsafe impl<const N: usize, T, A: Alignment> ScalarBackend<N, A> for T
422where
423    T: Scalar + ScalarDefault,
424    Length<N>: SupportedLength,
425{
426    type VectorRepr = Vector<N, T, Unaligned>;
427}
428
429impl Scalar for f32 {}
430impl Scalar for f64 {}
431impl Scalar for i8 {}
432impl Scalar for i16 {}
433impl Scalar for i32 {}
434impl Scalar for i64 {}
435impl Scalar for i128 {}
436impl Scalar for isize {}
437impl Scalar for u8 {}
438impl Scalar for u16 {}
439impl Scalar for u32 {}
440impl Scalar for u64 {}
441impl Scalar for u128 {}
442impl Scalar for usize {}
443impl Scalar for bool {}