Skip to main content

colr_types/
math.rs

//! Compile-time 3x3 matrix arithmetic ([`Mat3`], [`DMat3`]) and a floating-point math abstraction ([`Float`]).
2
3#[cfg(all(not(feature = "std"), not(feature = "libm")))]
4compile_error!(
5    "color requires either the `std` or `libm` feature. \
6     Add color = {{ features = [\"std\"] }} to your Cargo.toml."
7);
8
/// Floating-point scalar operations for color math.
///
/// Implemented for `f32` and `f64`. Transcendental functions come from the
/// primitives' inherent standard-library methods under the `std` feature, and
/// from the `libm` crate when only `libm` is enabled (`std` takes precedence
/// when both features are on).
///
/// The trait only requires basic arithmetic plus the operations below, so
/// another float type (e.g. a future `f16`) can be supported by adding one
/// more implementation.
pub trait Float:
    Copy
    + 'static
    + PartialOrd
    + core::ops::Add<Output = Self>
    + core::ops::Sub<Output = Self>
    + core::ops::Mul<Output = Self>
    + core::ops::Div<Output = Self>
    + core::ops::Neg<Output = Self>
{
    /// Additive identity (`0.0`).
    const ZERO: Self;
    /// Multiplicative identity (`1.0`).
    const ONE: Self;
    /// Smallest positive *normal* value (same meaning as [`f32::MIN_POSITIVE`]);
    /// smaller, subnormal positive values exist.
    const MIN_POSITIVE: Self;

    /// Convert an `f64` value to `Self`.
    ///
    /// Lets generic code spell constants as `f64` literals. For `f32` this is
    /// a narrowing conversion, so values that are not exactly representable
    /// in `f32` are rounded.
    fn from_f64(v: f64) -> Self;

    /// `self` raised to the floating-point power `exp`.
    fn powf(self, exp: Self) -> Self;
    /// Cube root (defined for negative inputs).
    fn cbrt(self) -> Self;
    /// Square root.
    fn sqrt(self) -> Self;
    /// Natural logarithm.
    fn ln(self) -> Self;
    /// Natural exponential, `e^self`.
    fn exp(self) -> Self;
    /// Sine of `self` in radians.
    fn sin(self) -> Self;
    /// Cosine of `self` in radians.
    fn cos(self) -> Self;
    /// Four-quadrant arctangent of `self / x`: `self` is the y coordinate and
    /// `x` the x coordinate. The result is in radians.
    fn atan2(self, x: Self) -> Self;
    /// Round to the nearest integer, with ties rounded away from zero.
    fn round(self) -> Self;
    /// `self` raised to the integer power `n`.
    fn powi(self, n: i32) -> Self;
    /// Absolute value.
    fn abs(self) -> Self;
    /// Euclidean remainder: a non-negative remainder of `self / rhs`,
    /// normally in `[0, |rhs|)` (rounding can produce exactly `|rhs|`).
    fn rem_euclid(self, rhs: Self) -> Self;
    /// Clamp `self` to `[min, max]`.
    ///
    /// Callers must pass `min <= max` with neither bound NaN; the `std`
    /// backend panics otherwise.
    fn clamp(self, min: Self, max: Self) -> Self;
    /// Greater of `self` and `other`.
    fn max(self, other: Self) -> Self;
    /// Lesser of `self` and `other`.
    fn min(self, other: Self) -> Self;
}
68
// Implements `Float` for a primitive float type by delegating to that type's
// own inherent `std` methods. The calls are spelled `<$t>::method(..)`:
// inherent items take precedence over trait items, so each one resolves to
// the primitive's method instead of recursing back into `Float`.
#[cfg(feature = "std")]
macro_rules! impl_float_std {
    ($t:ty, $zero:expr, $one:expr, $min_pos:expr, $from_f64:expr) => {
        impl Float for $t {
            const ZERO: $t = $zero;
            const ONE: $t = $one;
            const MIN_POSITIVE: $t = $min_pos;

            #[inline(always)]
            fn from_f64(v: f64) -> $t {
                let convert = $from_f64;
                convert(v)
            }

            #[inline(always)]
            fn powf(self, exp: $t) -> $t {
                <$t>::powf(self, exp)
            }

            #[inline(always)]
            fn cbrt(self) -> $t {
                <$t>::cbrt(self)
            }

            #[inline(always)]
            fn sqrt(self) -> $t {
                <$t>::sqrt(self)
            }

            #[inline(always)]
            fn ln(self) -> $t {
                <$t>::ln(self)
            }

            #[inline(always)]
            fn exp(self) -> $t {
                <$t>::exp(self)
            }

            #[inline(always)]
            fn sin(self) -> $t {
                <$t>::sin(self)
            }

            #[inline(always)]
            fn cos(self) -> $t {
                <$t>::cos(self)
            }

            #[inline(always)]
            fn atan2(self, x: $t) -> $t {
                <$t>::atan2(self, x)
            }

            #[inline(always)]
            fn round(self) -> $t {
                <$t>::round(self)
            }

            #[inline(always)]
            fn powi(self, n: i32) -> $t {
                <$t>::powi(self, n)
            }

            #[inline(always)]
            fn abs(self) -> $t {
                <$t>::abs(self)
            }

            #[inline(always)]
            fn rem_euclid(self, rhs: $t) -> $t {
                <$t>::rem_euclid(self, rhs)
            }

            #[inline(always)]
            fn clamp(self, min: $t, max: $t) -> $t {
                <$t>::clamp(self, min, max)
            }

            #[inline(always)]
            fn max(self, other: $t) -> $t {
                <$t>::max(self, other)
            }

            #[inline(always)]
            fn min(self, other: $t) -> $t {
                <$t>::min(self, other)
            }
        }
    };
}
96
// Implements `Float` for a primitive float type on top of the `libm` crate
// (used when `std` is disabled). The first five arguments mirror
// `impl_float_std!`; the remaining ones name the libm routine for each
// operation:
//   $powf .. $round — one libm function per trait method of the same meaning
//   $powi_base      — libm `pow`-style function used to emulate `powi`
//   $powi_cast      — type the `i32` exponent is cast to before that call
//   $fabs           — libm absolute value
// Operations libm does not provide (rem_euclid, clamp, max, min) are written
// out here and are meant to behave like the `std` methods of the same name,
// so results do not depend on which backend feature is enabled.
#[cfg(all(not(feature = "std"), feature = "libm"))]
macro_rules! impl_float_libm {
    ($t:ty, $zero:expr, $one:expr, $min_pos:expr, $from_f64:expr,
     $powf:path, $cbrt:path, $sqrt:path, $ln:path, $exp:path,
     $sin:path, $cos:path, $atan2:path, $round:path,
     $powi_base:path, $powi_cast:ty,
     $fabs:path) => {
        impl Float for $t {
            const ZERO: $t = $zero;
            const ONE: $t = $one;
            const MIN_POSITIVE: $t = $min_pos;
            #[inline(always)] fn from_f64(v: f64) -> $t { ($from_f64)(v) }
            #[inline(always)] fn powf(self, exp: $t) -> $t { $powf(self, exp) }
            #[inline(always)] fn cbrt(self) -> $t { $cbrt(self) }
            #[inline(always)] fn sqrt(self) -> $t { $sqrt(self) }
            #[inline(always)] fn ln(self) -> $t { $ln(self) }
            #[inline(always)] fn exp(self) -> $t { $exp(self) }
            #[inline(always)] fn sin(self) -> $t { $sin(self) }
            #[inline(always)] fn cos(self) -> $t { $cos(self) }
            #[inline(always)] fn atan2(self, x: $t) -> $t { $atan2(self, x) }
            #[inline(always)] fn round(self) -> $t { $round(self) }
            // libm has no integer power; emulate it with the float power.
            #[inline(always)] fn powi(self, n: i32) -> $t { $powi_base(self, n as $powi_cast) }
            #[inline(always)] fn abs(self) -> $t { $fabs(self) }
            #[inline(always)] fn rem_euclid(self, rhs: $t) -> $t {
                // Same formula as std: lift a negative truncated remainder by |rhs|.
                let r = self % rhs;
                if r < $zero { r + $fabs(rhs) } else { r }
            }
            #[inline(always)] fn clamp(self, min: $t, max: $t) -> $t {
                // NOTE: unlike std's `clamp`, this does not panic when
                // `min > max` or a bound is NaN. A NaN `self` is returned as-is.
                if self < min { min } else if self > max { max } else { self }
            }
            #[inline(always)] fn max(self, other: $t) -> $t {
                // Match std: when exactly one operand is NaN, return the other.
                if self >= other || other.is_nan() { self } else { other }
            }
            #[inline(always)] fn min(self, other: $t) -> $t {
                // Match std: when exactly one operand is NaN, return the other.
                if self <= other || other.is_nan() { self } else { other }
            }
        }
    };
}
133
// `std` backend: every method forwards to the primitive's inherent method.
#[cfg(feature = "std")]
impl_float_std!(f32, 0.0_f32, 1.0_f32, f32::MIN_POSITIVE, |x: f64| x as f32);
#[cfg(feature = "std")]
impl_float_std!(f64, 0.0_f64, 1.0_f64, f64::MIN_POSITIVE, |x: f64| x);

// `libm` backend for `no_std` builds.
#[cfg(all(not(feature = "std"), feature = "libm"))]
impl_float_libm!(
    // type, zero, one, smallest normal, f64 -> Self conversion
    f32, 0.0_f32, 1.0_f32, f32::MIN_POSITIVE, |x: f64| x as f32,
    // powf, cbrt, sqrt, ln, exp
    libm::powf, libm::cbrtf, libm::sqrtf, libm::logf, libm::expf,
    // sin, cos, atan2, round
    libm::sinf, libm::cosf, libm::atan2f, libm::roundf,
    // powi: float power with the exponent cast to f32
    libm::powf, f32,
    // abs
    libm::fabsf
);
#[cfg(all(not(feature = "std"), feature = "libm"))]
impl_float_libm!(
    // type, zero, one, smallest normal, f64 -> Self conversion
    f64, 0.0_f64, 1.0_f64, f64::MIN_POSITIVE, |x: f64| x,
    // powf, cbrt, sqrt, ln, exp
    libm::pow, libm::cbrt, libm::sqrt, libm::log, libm::exp,
    // sin, cos, atan2, round
    libm::sin, libm::cos, libm::atan2, libm::round,
    // powi: float power with the exponent cast to f64
    libm::pow, f64,
    // abs
    libm::fabs
);
151
macro_rules! impl_mat3 {
    ($name:ident, $scalar:ty, $zero:expr) => {
        // `///` comments are string literals, so a `$scalar` written inside
        // one is never substituted. Docs that need to name the scalar type use
        // `#[doc = concat!(.., stringify!($scalar), ..)]` instead.
        #[doc = concat!(
            "Column-major 3x3 matrix of `", stringify!($scalar), "` with `const fn` arithmetic.\n\n",
            "Each column is stored as `[", stringify!($scalar), "; 4]`: indices 0..=2 hold the ",
            "three rows and index 3 is padding, set to zero by every constructor on this type. ",
            "`#[repr(C, align(16))]` places every column on a 16-byte boundary for SIMD loads. ",
            "All methods are `const fn`, so matrices can be derived at compile time.\n\n",
            "The fields are public, and the derived `PartialEq` also compares the padding lane."
        )]
        #[repr(C, align(16))]
        #[derive(Debug, Clone, Copy, PartialEq)]
        pub struct $name {
            /// Column 0: rows 0..=2, padding at index 3.
            pub col0: [$scalar; 4],
            /// Column 1: rows 0..=2, padding at index 3.
            pub col1: [$scalar; 4],
            /// Column 2: rows 0..=2, padding at index 3.
            pub col2: [$scalar; 4],
        }

        impl $name {
            /// Matrix product `a * b`. `const fn`.
            ///
            /// Applying the product to a vector is equivalent to applying `b`
            /// first and then `a`.
            pub const fn mul(a: &Self, b: &Self) -> Self {
                // Row `$r` of `a` dotted with the column vector `$c`.
                macro_rules! e {
                    ($r:expr, $c:expr) => {
                        a.col0[$r] * $c[0] + a.col1[$r] * $c[1] + a.col2[$r] * $c[2]
                    };
                }
                Self {
                    col0: [e!(0, b.col0), e!(1, b.col0), e!(2, b.col0), $zero],
                    col1: [e!(0, b.col1), e!(1, b.col1), e!(2, b.col1), $zero],
                    col2: [e!(0, b.col2), e!(1, b.col2), e!(2, b.col2), $zero],
                }
            }

            /// Invert via the adjugate (Cramer's rule). `const fn`.
            ///
            /// No singularity check is performed. A zero determinant yields
            /// infinite or `NaN` entries, and a nearly singular matrix yields
            /// very large, imprecise entries. Matrices built from three
            /// linearly independent RGB primaries always have a non-zero
            /// determinant.
            pub const fn invert(m: &Self) -> Self {
                let [a, b, c, _] = m.col0;
                let [d, e, f, _] = m.col1;
                let [g, h, i, _] = m.col2;
                // 2x2 minors shared by the determinant and the adjugate; each
                // is computed once, with the same operation order as before.
                let m0 = e * i - f * h;
                let m1 = d * i - f * g;
                let m2 = d * h - e * g;
                let det = a * m0 - b * m1 + c * m2;
                Self {
                    col0: [m0 / det, -(b * i - c * h) / det, (b * f - c * e) / det, $zero],
                    col1: [-m1 / det, (a * i - c * g) / det, -(a * f - c * d) / det, $zero],
                    col2: [m2 / det, -(a * h - b * g) / det, (a * e - b * d) / det, $zero],
                }
            }

            #[doc = concat!(
                "Construct from a row-major `[[", stringify!($scalar), "; 3]; 3]`, ",
                "zeroing the padding lanes. `const fn`."
            )]
            pub const fn from_rows(m: [[$scalar; 3]; 3]) -> Self {
                Self {
                    col0: [m[0][0], m[1][0], m[2][0], $zero],
                    col1: [m[0][1], m[1][1], m[2][1], $zero],
                    col2: [m[0][2], m[1][2], m[2][2], $zero],
                }
            }

            #[doc = concat!(
                "Convert to a row-major `[[", stringify!($scalar), "; 3]; 3]`, ",
                "dropping the padding lanes. `const fn`."
            )]
            pub const fn to_rows(self) -> [[$scalar; 3]; 3] {
                [
                    [self.col0[0], self.col1[0], self.col2[0]],
                    [self.col0[1], self.col1[1], self.col2[1]],
                    [self.col0[2], self.col1[2], self.col2[2]],
                ]
            }

            #[doc = concat!(
                "Multiply the column vector `v: [", stringify!($scalar), "; 3]` by this ",
                "matrix (`M * v`). `const fn`."
            )]
            #[inline(always)]
            pub const fn apply(&self, v: [$scalar; 3]) -> [$scalar; 3] {
                [
                    self.col0[0] * v[0] + self.col1[0] * v[1] + self.col2[0] * v[2],
                    self.col0[1] * v[0] + self.col1[1] * v[1] + self.col2[1] * v[2],
                    self.col0[2] * v[0] + self.col1[2] * v[1] + self.col2[2] * v[2],
                ]
            }

            #[doc = concat!(
                "Row 1 as `[", stringify!($scalar), "; 3]` weights `[w_r, w_g, w_b]`; for an ",
                "RGB-to-XYZ matrix this is the Y (luminance) row. `const fn`."
            )]
            #[inline(always)]
            pub const fn luminance_weights(&self) -> [$scalar; 3] {
                [self.col0[1], self.col1[1], self.col2[1]]
            }
        }
    };
}

// Single-precision matrix (columns are 16 bytes).
impl_mat3!(Mat3, f32, 0.0f32);
// Double-precision matrix (columns are 32 bytes).
impl_mat3!(DMat3, f64, 0.0f64);
253
254impl Mat3 {
255    /// Apply to a `glam::Vec4`. Lane 3 preserved.
256    #[cfg(feature = "glam")]
257    #[inline]
258    pub fn apply_glam(&self, v: glam::Vec4) -> glam::Vec4 {
259        let m = glam::Mat3::from_cols(
260            glam::Vec3::from_array([self.col0[0], self.col0[1], self.col0[2]]),
261            glam::Vec3::from_array([self.col1[0], self.col1[1], self.col1[2]]),
262            glam::Vec3::from_array([self.col2[0], self.col2[1], self.col2[2]]),
263        );
264        (m * v.truncate()).extend(v.w)
265    }
266}