// math_utils/traits.rs

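//! Abstract traits: algebraic structures (semigroups, monoids, groups, rings, fields),
//! modules and vector/affine/projective spaces, angles, and scalar functions, together
//! with their implementations for the primitive integer and float types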
2
3use either::Either;
4#[cfg(feature = "derive_serdes")]
5use serde;
6
7use crate::{approx, num};
8use crate::types::{
9  Addition, Affinity, LinearIso, Multiplication, NonZero, NonNegative, Normalized,
10  Projectivity, Sign
11};
12
13/// Projective completion (homogeneous coordinates)
14pub trait ProjectiveSpace <S : Field> : AffineSpace <S> {
15  /// Affine subspace
16  type Patch : AffineSpace <S>;
17  /// Construct an augmented matrix for the affinity
18  fn homography (affinity :
19    Affinity <S, Self::Patch, Self::Patch,
20      <<Self::Patch as AffineSpace <S>>::Translation as Module <S>>::LinearEndo>
21  ) -> Projectivity <S, Self, Self, <Self::Translation as Module <S>>::LinearEndo> where
22    <Self::Translation as Module <S>>::LinearEndo :
23      From <<<Self::Patch as AffineSpace <S>>::Translation as Module <S>>::LinearEndo>
24  {
25    Projectivity::new (LinearIso::new ((*affinity.linear_iso).into()).unwrap())
26  }
27  /// Return the projective completion (homogeneous coordinates) of the affine point or
28  /// vector
29  fn homogeneous (point_or_vector :
30    Either <Self::Patch, <Self::Patch as AffineSpace <S>>::Translation>
31  ) -> Self where Self : From <(Self::Patch, S)> {
32    match point_or_vector {
33      Either::Left  (point)  => (point, S::one()).into(),
34      Either::Right (vector) => (Self::Patch::from_vector (vector), S::zero()).into()
35    }
36  }
37}
38
39/// `AffineSpace` with translations in a (Euclidean) real inner product space
40pub trait EuclideanSpace <S : Real> : AffineSpace <S> + MetricSpace <S> { }
41
42/// Space of `Point`s (positions) and `Vector`s (displacements)
43pub trait AffineSpace <S : Field> : Point <Self::Translation> {
44  type Translation : VectorSpace <S> + GroupAction <Addition, Self>;
45}
46
47/// Point types convertible to and from a vector type, with difference function that
48/// follows from the free and transitive group action
49pub trait Point <V> : Sized + std::ops::Sub <Self, Output=V> where
50  V : AdditiveGroup + GroupAction <Addition, Self>
51{
52  fn to_vector (self) -> V;
53  fn from_vector (vector : V) -> Self;
54  fn origin() -> Self {
55    Self::from_vector (V::zero())
56  }
57}
58
59/// Set of points with distance function
60pub trait MetricSpace <S : Field> : NormedVectorSpace <S> {
61  fn distance_squared (self, other : Self) -> NonNegative <S> {
62    (self - other).norm_squared()
63  }
64  fn distance (self, other : Self) -> NonNegative <S> where S : Sqrt {
65    self.distance_squared (other).sqrt()
66  }
67}
68
69/// `VectorSpace` with vector length/magnitude function
70pub trait NormedVectorSpace <S : Field> : InnerProductSpace <S> {
71  type Unit : Into <Self>;
72  // required
73  fn norm_squared (self) -> NonNegative <S>;
74  /// Infinity norm or uniform norm
75  fn norm_max (self) -> NonNegative <S> where S : SignedExt;
76  #[must_use]
77  fn normalize (self) -> Self::Unit where S : Sqrt;
78  // provided
79  fn norm (self) -> NonNegative <S> where S : Sqrt {
80    self.norm_squared().sqrt()
81  }
82  /// Returns zero vector if input is the zero vector, otherwise a normalized vector of
83  /// each axis sign.
84  fn unit_sigvec (self) -> Self where S : SignedExt + Sqrt {
85    let v = self.sigvec();
86    if v.is_zero() {
87      v
88    } else {
89      v.normalize().into()
90    }
91  }
92}
93
94/// Bilinear form on a `VectorSpace`
95pub trait InnerProductSpace <S : Field> : VectorSpace <S> + Dot <S> {
96  fn inner_product (self, other : Self) -> S {
97    self.dot (other)
98  }
99  fn outer_product (self, other : Self) -> Self::LinearEndo;
100  fn orthogonal (self) -> Self where S : approx::AbsDiffEq <Epsilon = S>;
101}
102
103/// Module with scalars taken from a `Field`
104pub trait VectorSpace <S : Field> : Module <S> + std::ops::Div <S, Output=Self> + Copy {
105  type NonZero;
106  // required
107  #[must_use]
108  fn map <F> (self, f : F) -> Self where F : FnMut (S) -> S;
109  // provided
110  /// Map `signum_or_zero` over each element of the given vector
111  fn sigvec (self) -> Self where S : SignedExt {
112    self.map (SignedExt::signum_or_zero)
113  }
114  /// Linear interpolation
115  fn interpolate (a : Self, b : Self, t : Normalized <S>) -> Self {
116    a + (b - a) * *t
117  }
118  /// Linear extrapolation
119  fn extrapolate (a : Self, b : Self, t : S) -> Self {
120    a + (b - a) * t
121  }
122}
123
124/// Additional `Vector2` methods
125pub trait Vector2Ext <S> {
126  /// Exterior or wedge product: $(a,b) \wedge (c,d) = ad - bc$
127  fn exterior_product (self, rhs : Self) -> S;
128}
129
130/// Scalar product (bilinear form) on a `Module`
131pub trait Dot <S : Ring> : Module <S> {
132  fn dot (self, other : Self) -> S;
133  /// Self dot product
134  fn self_dot (self) -> NonNegative <S> where S : OrderedRing {
135    NonNegative::unchecked (self.dot (self))
136  }
137}
138
139/// Additive group combined with scalar multiplication
140pub trait Module <S : Ring> : AdditiveGroup + std::ops::Mul <S, Output=Self> + Copy {
141  /// Linear endomorphism represented by a square matrix type
142  type LinearEndo : LinearMap <S, Self, Self> + Ring;
143}
144
145/// Module homomorphism
146pub trait LinearMap <S, V, W> : std::ops::Mul <V, Output=W> + Copy where
147  S : Ring,
148  V : Module <S>,
149  W : Module <S>
150{
151  fn determinant (self) -> S;
152  fn transpose   (self) -> Self;
153}
154
155/// Additional matrix methods
156pub trait Matrix <S> : Copy {
157  type Rows;
158  type Submatrix;
159  /// Returns the row-major matrix
160  fn rows (self) -> Self::Rows;
161  /// Returns the submatrix formed by removing the given (col,row)
162  fn submatrix (self, i : usize, j : usize) -> Self::Submatrix;
163  /// Returns the matrix formed by filling additional row and column with zeros
164  fn fill_zeros (submatrix : Self::Submatrix) -> Self where S : num::Zero;
165  /// Computes the determinant of the (i,j)-submatrix
166  #[inline]
167  fn minor <V, W> (self, i : usize, j : usize) -> S where
168    S : Ring,
169    V : Module <S>,
170    W : Module <S>,
171    Self::Submatrix : LinearMap <S, V, W>
172  {
173    self.submatrix (i, j).determinant()
174  }
175}
176
177/// Householder transformation
178pub trait ElementaryReflector {
179  type Vector;
180  /// Construct a Householder transformation
181  fn elementary_reflector (v : Self::Vector, index : usize) -> Self;
182}
183
184/// `OrderedField` with special functions
185pub trait Real : OrderedField + Exp + Powf + Sqrt + Trig {
186  // TODO: more constants
187  fn pi() -> Self;
188  fn frac_pi_3() -> Self;
189  fn sqrt_3() -> Self;
190  fn frac_1_sqrt_3() -> Self;
191}
192
193/// A `Field` with an ordering
194pub trait OrderedField : Field + OrderedRing { }
195
196/// A (commutative) `Ring` where $1 \neq 0$ and all non-zero elements are invertible
197pub trait Field : Ring + MultiplicativeGroup + Powi + Rational {
198  fn half() -> Self {
199    Self::one() / Self::two()
200  }
201}
202
203/// Some additional methods for fractional parts of numbers
204pub trait Rational {
205  fn floor (self) -> Self;
206  fn ceil (self) -> Self;
207  fn trunc (self) -> Self;
208  fn fract (self) -> Self;
209}
210
211/// Interface for a group with identity represented by `one`, operation defined by `*`
212/// and `/`
213pub trait MultiplicativeGroup : MultiplicativeMonoid +
214  std::ops::Div <Self, Output=Self> + std::ops::DivAssign + num::Inv <Output=Self>
215{ }
216
217/// Ring of integers
218pub trait Integer : OrderedRing + SignedExt + num::PrimInt { }
219
220/// A Ring with an ordering
221pub trait OrderedRing : Ring + SignedExt + MinMax + PartialOrd + std::ops::Rem { }
222
223/// Interface for a ring as a combination of an additive group and a distributive
224/// multiplication operation.
225///
226/// This is basically the base "scalar" trait (a [`Module`] is defined over a `Ring`).
227/// Here we also add the `Debug` trait as a constraint because basically all scalars
228/// should implement `Debug`, and this avoids us needing to explicitly give a `Debug`
229/// constraint when using functions or macros like `assert_eq` that require it.
230pub trait Ring : AdditiveGroup + MultiplicativeMonoid
231  + num::MulAdd <Self, Self, Output=Self> + num::MulAddAssign <Self, Self>
232  + std::fmt::Debug
233{
234  fn two() -> Self {
235    Self::one() + Self::one()
236  }
237  fn three() -> Self {
238    Self::one() + Self::one() + Self::one()
239  }
240  fn four() -> Self {
241    Self::two() * Self::two()
242  }
243  fn five() -> Self {
244    Self::two() + Self::three()
245  }
246  fn six() -> Self {
247    Self::two() * Self::three()
248  }
249  fn seven() -> Self {
250    Self::three() + Self::four()
251  }
252  fn eight() -> Self {
253    Self::two() * Self::two() * Self::two()
254  }
255  fn nine() -> Self {
256    Self::three() * Self::three()
257  }
258  fn ten() -> Self {
259    Self::two() * Self::five()
260  }
261}
262
263/// Interface for a group with identity represented by `zero`, and operation defined by
264/// `+` and `-`
265pub trait AdditiveGroup : AdditiveMonoid +
266  std::ops::Sub <Self, Output=Self> + std::ops::SubAssign + std::ops::Neg <Output=Self>
267{ }
268
269/// (Right) action of a group on a set
270pub trait GroupAction <G, X> : MonoidAction <G, X> where G : Group <Self> { }
271
272/// Monoid with inverses
273pub trait Group <G> : Monoid <G> {
274  fn inverse (elem : G) -> G;
275}
276
277/// Set with identity represented by `one` and (associative) binary operation defined by
278/// `*`
279pub trait MultiplicativeMonoid : Copy + PartialEq +
280  std::ops::Mul <Self, Output=Self> + std::ops::MulAssign + num::One
281{
282  fn squared (self) -> Self {
283    self * self
284  }
285  fn cubed (self) -> Self {
286    self * self * self
287  }
288}
289
290/// Set with identity represented by `zero` and (associative) binary operation defined
291/// by `+`
292pub trait AdditiveMonoid : Sized + PartialEq + std::iter::Sum
293  + std::ops::Add <Self, Output=Self> + std::ops::AddAssign + num::Zero
294{ }
295
296/// (Right) action of a monoid on a set
297pub trait MonoidAction <M, X> : SemigroupAction <M, X> where M : Monoid <Self> { }
298
299/// Semigroup with identity element
300pub trait Monoid <M> : Semigroup <M> {
301  /// Identity element
302  fn identity() -> M;
303}
304
305/// (Right) action of a semigroup on a set
306pub trait SemigroupAction <S, X> : Sized where S : Semigroup <Self> {
307  fn action (self, x : X) -> X;
308}
309
310/// Set with associative binary operation
311pub trait Semigroup <S> : PartialEq {
312  /// Associative operation
313  fn operation (a : S, b : S) -> S;
314}
315
316/// Interface for angle units
317pub trait Angle <S : OrderedField> : Clone + Copy + PartialEq + PartialOrd + Sized +
318  AdditiveGroup + std::ops::Div <Self, Output=S> + std::ops::Mul <S, Output=Self> +
319  std::ops::Div <S, Output=Self> + std::ops::Rem <Self, Output=Self>
320{
321  // required
322  /// Full rotation
323  fn full_turn() -> Self;
324  // provided
325  /// Half rotation
326  fn half_turn() -> Self {
327    Self::full_turn() / S::two()
328  }
329  /// Restrict to `(-half_turn, half_turn]`
330  fn wrap_signed (self) -> Self {
331    if self > Self::half_turn() || self <= -Self::half_turn() {
332      let out = (self + Self::half_turn()).wrap_unsigned() - Self::half_turn();
333      if out == -Self::half_turn() {
334        Self::half_turn()
335      } else {
336        out
337      }
338    } else {
339      self
340    }
341  }
342  /// Restrict to `[0, full_turn)`
343  fn wrap_unsigned (self) -> Self {
344    if self >= Self::full_turn() {
345      self % Self::full_turn()
346    } else if self < Self::zero() {
347      self + Self::full_turn() * ((self / Self::full_turn()).trunc().abs() + S::one())
348    } else {
349      self
350    }
351  }
352}
353
/// Integer power with an unsigned exponent
pub trait Pow {
  /// `self` raised to the power `exp`
  fn pow (self, exp : u32) -> Self;
}
/// Integer power with a signed exponent
pub trait Powi {
  /// `self` raised to the power `n`
  fn powi (self, n : i32) -> Self;
}
/// Power with an exponent of the same (possibly fractional) type
pub trait Powf {
  /// `self` raised to the power `n`
  fn powf (self, n : Self) -> Self;
}
/// Exponential function
pub trait Exp {
  /// $e^{self}$
  fn exp (self) -> Self;
}
/// Square root function
pub trait Sqrt {
  /// Square root of `self`
  fn sqrt (self) -> Self;
}
/// Cube root function
pub trait Cbrt {
  /// Cube root of `self`
  fn cbrt (self) -> Self;
}
/// Trigonometric functions
pub trait Trig : Sized {
  /// Sine
  fn sin (self) -> Self;
  /// Sine and cosine together, as `(sin, cos)`
  fn sin_cos (self) -> (Self, Self);
  /// Cosine
  fn cos (self) -> Self;
  /// Tangent
  fn tan (self) -> Self;
  /// Arcsine
  fn asin (self) -> Self;
  /// Arccosine
  fn acos (self) -> Self;
  /// Arctangent
  fn atan (self) -> Self;
  /// Four-quadrant arctangent of `self` (y) and `other` (x)
  fn atan2 (self, other : Self) -> Self;
}

/// Uniform `min`, `max`, and `clamp` for both `Ord` types and the floats.
///
/// `f32` and `f64` do not implement `Ord`, so this trait gives a common interface.
/// Results are not necessarily consistent with those from `Ord`.
pub trait MinMax {
  /// The lesser of `self` and `other`
  fn min (self, other : Self) -> Self;
  /// The greater of `self` and `other`
  fn max (self, other : Self) -> Self;
  /// `self` restricted to the interval `[min, max]`
  fn clamp (self, min : Self, max : Self) -> Self;
}
398
399/// Function returning number representing sign of self
400pub trait SignedExt : num::Signed {
401  #[inline]
402  fn sign (self) -> Sign {
403    if self.is_zero() {
404      Sign::Zero
405    } else if self.is_positive() {
406      Sign::Positive
407    } else {
408      debug_assert!(self.is_negative());
409      Sign::Negative
410    }
411  }
412  /// Maps `0.0` to `0.0`, otherwise equal to `S::signum` (which would otherwise map
413  /// `+0.0 -> 1.0` and `-0.0 -> -1.0`)
414  #[inline]
415  fn signum_or_zero (self) -> Self where Self : num::Zero {
416    if self.is_zero() {
417      Self::zero()
418    } else {
419      self.signum()
420    }
421  }
422
423  #[inline]
424  fn signum_or_zero_approx (self) -> Self where
425    Self : OrderedRing + approx::AbsDiffEq <Epsilon = Self>
426  {
427    let one = Self::one();
428    if self.abs() < Self::default_epsilon() * (one + one + one + one) {
429      Self::zero()
430    } else {
431      self.signum()
432    }
433  }
434}
435
/// Adds serde `Serialize` and `DeserializeOwned` constraints when the
/// `derive_serdes` feature is enabled; without the feature it is an empty marker
/// trait.
///
/// This makes it easier to conditionally add these constraints to type definitions
/// when the `derive_serdes` feature is enabled.
#[cfg(not(feature = "derive_serdes"))]
pub trait MaybeSerDes { }
// Doc comments attach only to the next item, so the documentation is repeated here
// for the feature-enabled variant.
/// Adds serde `Serialize` and `DeserializeOwned` constraints when the
/// `derive_serdes` feature is enabled; without the feature it is an empty marker
/// trait.
///
/// This makes it easier to conditionally add these constraints to type definitions
/// when the `derive_serdes` feature is enabled.
#[cfg(feature = "derive_serdes")]
pub trait MaybeSerDes : serde::Serialize + serde::de::DeserializeOwned { }
444
445impl <S, T> MetricSpace <S> for T where S : Field, T : NormedVectorSpace <S> { }
446impl <T> OrderedField for T where T : Field + OrderedRing { }
447impl <T> Field        for T where
448  T : Ring + MultiplicativeGroup + Powi + Rational { }
449impl <T> Integer      for T where T : OrderedRing + num::PrimInt { }
450impl <T> OrderedRing  for T where
451  T : Ring + SignedExt + MinMax + PartialOrd + std::ops::Rem
452{ }
453impl <T> Ring         for T where
454  T : AdditiveGroup + MultiplicativeMonoid
455    + num::MulAdd<Self, Self, Output=Self> + num::MulAddAssign <Self, Self>
456    + std::fmt::Debug
457{ }
458impl <T> AdditiveGroup for T where
459  T : AdditiveMonoid + std::ops::Sub <Self, Output=Self> + std::ops::SubAssign
460    + std::ops::Neg <Output=Self>
461{ }
462impl <T> MultiplicativeGroup for T where
463  T : MultiplicativeMonoid + std::ops::Div <Self, Output=Self> + std::ops::DivAssign
464    + num::Inv <Output=Self>
465{ }
466impl <T> GroupAction <Multiplication, T> for T where
467  T              : MonoidAction <Multiplication, T>,
468  Multiplication : Group <T>
469{ }
470impl <T> MonoidAction <Multiplication, T> for T where
471  T              : SemigroupAction <Multiplication, T>,
472  Multiplication : Monoid <T>
473{ }
474impl <T> SemigroupAction <Multiplication, T> for T where Multiplication : Semigroup <T> {
475  #[expect(clippy::renamed_function_params)]
476  fn action (self, g : Self) -> Self {
477    Multiplication::operation (g, self)
478  }
479}
480impl <T> GroupAction <Addition, T> for T where
481  T        : MonoidAction <Addition, T>,
482  Addition : Group <T>
483{ }
484impl <T> MonoidAction <Addition, T> for T where
485  T        : SemigroupAction <Addition, T>,
486  Addition : Monoid <T>
487{ }
488impl <T> SemigroupAction <Addition, T> for T where Addition : Semigroup <T> {
489  #[expect(clippy::renamed_function_params)]
490  fn action (self, g : Self) -> Self {
491    Addition::operation (g, self)
492  }
493}
494impl <T> AdditiveMonoid for T where
495  T : Sized + PartialEq + std::iter::Sum + std::ops::Add <Self, Output=Self>
496    + std::ops::AddAssign + num::Zero
497{ }
498impl <T> MultiplicativeMonoid for T where
499  T : Copy + PartialEq + std::ops::Mul <Self, Output=Self> + std::ops::MulAssign
500    + num::One
501{ }
502impl <T : num::Signed> SignedExt for T { }
503
504impl <
505  #[cfg(not(feature = "derive_serdes"))]
506  T,
507  #[cfg(feature = "derive_serdes")]
508  T : serde::Serialize + serde::de::DeserializeOwned
509> MaybeSerDes for T { }
510
511macro impl_integer ($type:ty) {
512  impl MinMax         for $type {
513    fn min (self, other : Self) -> Self {
514      Ord::min (self, other)
515    }
516    fn max (self, other : Self) -> Self {
517      Ord::max (self, other)
518    }
519    fn clamp (self, min : Self, max : Self) -> Self {
520      Ord::clamp (self, min, max)
521    }
522  }
523  impl Pow for $type {
524    fn pow (self, exp : u32) -> Self {
525      self.pow (exp)
526    }
527  }
528}
529impl_integer!(i8);
530impl_integer!(i16);
531impl_integer!(i32);
532impl_integer!(i64);
533
534macro impl_real_float ($type:ident) {
535  impl AffineSpace <Self> for $type {
536    type Translation = $type;
537  }
538  impl Point <Self> for $type {
539    fn to_vector (self) -> $type {
540      self
541    }
542    fn from_vector (vector : $type) -> Self {
543      vector
544    }
545  }
546  impl VectorSpace <Self> for $type {
547    type NonZero = NonZero <$type>;
548    fn map <F> (self, mut f : F) -> Self where F : FnMut (Self) -> Self {
549      f (self)
550    }
551  }
552  impl Module <Self> for $type {
553    type LinearEndo = Self;
554  }
555  impl LinearMap <Self, Self, Self> for $type {
556    fn determinant (self) -> Self {
557      self
558    }
559    fn transpose (self) -> Self {
560      self
561    }
562  }
563  impl Real for $type {
564    fn pi() -> Self {
565      std::$type::consts::PI
566    }
567    fn frac_pi_3() -> Self {
568      std::$type::consts::FRAC_PI_3
569    }
570    #[expect(clippy::excessive_precision)]
571    #[allow(clippy::allow_attributes)]
572    #[allow(clippy::cast_possible_truncation)]
573    #[allow(trivial_numeric_casts)]
574    fn sqrt_3() -> Self {
575      1.732050807568877293527446341505872366942805253810380628055f64 as $type
576    }
577    #[expect(clippy::excessive_precision)]
578    #[allow(clippy::allow_attributes)]
579    #[allow(clippy::cast_possible_truncation)]
580    #[allow(trivial_numeric_casts)]
581    fn frac_1_sqrt_3() -> Self {
582      (1.0f64 / 1.732050807568877293527446341505872366942805253810380628055f64) as $type
583    }
584  }
585  impl Rational for $type {
586    fn floor (self) -> Self {
587      self.floor()
588    }
589    fn ceil (self) -> Self {
590      self.ceil()
591    }
592    fn trunc (self) -> Self {
593      self.trunc()
594    }
595    fn fract (self) -> Self {
596      self.fract()
597    }
598  }
599  impl MinMax for $type {
600    fn min (self, other : Self) -> Self {
601      self.min (other)
602    }
603    fn max (self, other : Self) -> Self {
604      self.max (other)
605    }
606    fn clamp (self, min : Self, max : Self) -> Self {
607      self.clamp (min, max)
608    }
609  }
610  impl Powi for $type {
611    fn powi (self, n : i32) -> Self {
612      self.powi (n)
613    }
614  }
615  impl Powf for $type {
616    fn powf (self, n : Self) -> Self {
617      self.powf (n)
618    }
619  }
620  impl Exp for $type {
621    fn exp (self) -> Self {
622      self.exp()
623    }
624  }
625  impl Sqrt for $type {
626    fn sqrt (self) -> Self {
627      self.sqrt()
628    }
629  }
630  impl Cbrt for $type {
631    fn cbrt (self) -> Self {
632      self.cbrt()
633    }
634  }
635  impl Trig for $type {
636    fn sin (self) -> Self {
637      self.sin()
638    }
639    fn sin_cos (self) -> (Self, Self) {
640      self.sin_cos()
641    }
642    fn cos (self) -> Self {
643      self.cos()
644    }
645    fn tan (self) -> Self {
646      self.tan()
647    }
648    fn asin (self) -> Self {
649      self.asin()
650    }
651    fn acos (self) -> Self {
652      self.acos()
653    }
654    fn atan (self) -> Self {
655      self.atan()
656    }
657    fn atan2 (self, other : Self) -> Self {
658      self.atan2 (other)
659    }
660  }
661}
662
663impl_real_float!(f32);
664impl_real_float!(f64);