Skip to main content

ariadnetor_core/
scalar.rs

//! Scalar trait for tensor element types.
//!
//! Provides the `Scalar` trait, which unifies real and complex floating-point
//! types. This auxiliary trait works around E0592 (duplicate definitions from
//! overlapping inherent impls), which otherwise prevents writing separate
//! generic impls for `DenseTensorData<T: Float>` and
//! `DenseTensorData<Complex<T: Float>>`.
7
8use num_complex::Complex;
9use num_traits::{One, Zero};
10
11use crate::backend::DispatchScalar;
12
13mod sealed {
14    pub trait Sealed {}
15    impl Sealed for f32 {}
16    impl Sealed for f64 {}
17    impl Sealed for num_complex::Complex<f32> {}
18    impl Sealed for num_complex::Complex<f64> {}
19}
20
21/// Scalar type for tensor elements (sealed trait).
22///
23/// See ADR-0003 for the design rationale (E0592 avoidance via sealed pattern).
24pub trait Scalar:
25    sealed::Sealed
26    + Clone
27    + Copy
28    + 'static
29    + Send
30    + Sync
31    + Zero
32    + One
33    + std::ops::Add<Output = Self>
34    + std::ops::Mul<Output = Self>
35    + std::ops::Mul<Self::Real, Output = Self>
36    + DispatchScalar
37{
38    /// The real part type. Always `Scalar + Float` — supports both tensor
39    /// element storage and floating-point math (`sqrt`, `exp`, etc.).
40    type Real: Scalar + num_traits::Float;
41    /// The complex type having this scalar's real type as its components.
42    type Complex: Scalar;
43    /// Absolute value (modulus), as the real type.
44    fn abs(self) -> Self::Real;
45    /// Real part.
46    fn re(self) -> Self::Real;
47    /// Imaginary part (always zero for real scalars).
48    fn im(self) -> Self::Real;
49    /// Multiply by a real factor.
50    fn scale_real(self, factor: Self::Real) -> Self;
51    /// Complex conjugate (identity for real scalars).
52    fn conj(self) -> Self;
53    /// Widen into the corresponding complex type.
54    fn into_complex(self) -> Self::Complex;
55    /// Build from real and imaginary parts; for real scalars the
56    /// imaginary part is ignored.
57    fn from_real_imag(re: Self::Real, im: Self::Real) -> Self;
58}
59
60impl Scalar for f32 {
61    type Real = f32;
62    type Complex = Complex<f32>;
63    #[inline]
64    fn abs(self) -> Self::Real {
65        self.abs()
66    }
67    #[inline]
68    fn re(self) -> Self::Real {
69        self
70    }
71    #[inline]
72    fn im(self) -> Self::Real {
73        0.0
74    }
75    #[inline]
76    fn scale_real(self, factor: Self::Real) -> Self {
77        self * factor
78    }
79    #[inline]
80    fn conj(self) -> Self {
81        self
82    }
83    #[inline]
84    fn into_complex(self) -> Self::Complex {
85        Complex::new(self, 0.0)
86    }
87    #[inline]
88    fn from_real_imag(re: Self::Real, im: Self::Real) -> Self {
89        let _ = im;
90        re
91    }
92}
93
94impl Scalar for f64 {
95    type Real = f64;
96    type Complex = Complex<f64>;
97    #[inline]
98    fn abs(self) -> Self::Real {
99        self.abs()
100    }
101    #[inline]
102    fn re(self) -> Self::Real {
103        self
104    }
105    #[inline]
106    fn im(self) -> Self::Real {
107        0.0
108    }
109    #[inline]
110    fn scale_real(self, factor: Self::Real) -> Self {
111        self * factor
112    }
113    #[inline]
114    fn conj(self) -> Self {
115        self
116    }
117    #[inline]
118    fn into_complex(self) -> Self::Complex {
119        Complex::new(self, 0.0)
120    }
121    #[inline]
122    fn from_real_imag(re: Self::Real, im: Self::Real) -> Self {
123        let _ = im;
124        re
125    }
126}
127
128impl Scalar for Complex<f32> {
129    type Real = f32;
130    type Complex = Complex<f32>;
131    #[inline]
132    fn abs(self) -> Self::Real {
133        self.norm()
134    }
135    #[inline]
136    fn re(self) -> Self::Real {
137        self.re
138    }
139    #[inline]
140    fn im(self) -> Self::Real {
141        self.im
142    }
143    #[inline]
144    fn scale_real(self, factor: Self::Real) -> Self {
145        Complex::new(self.re * factor, self.im * factor)
146    }
147    #[inline]
148    fn conj(self) -> Self {
149        Complex::conj(&self)
150    }
151    #[inline]
152    fn into_complex(self) -> Self::Complex {
153        self
154    }
155    #[inline]
156    fn from_real_imag(re: Self::Real, im: Self::Real) -> Self {
157        Complex::new(re, im)
158    }
159}
160
161impl Scalar for Complex<f64> {
162    type Real = f64;
163    type Complex = Complex<f64>;
164    #[inline]
165    fn abs(self) -> Self::Real {
166        self.norm()
167    }
168    #[inline]
169    fn re(self) -> Self::Real {
170        self.re
171    }
172    #[inline]
173    fn im(self) -> Self::Real {
174        self.im
175    }
176    #[inline]
177    fn scale_real(self, factor: Self::Real) -> Self {
178        Complex::new(self.re * factor, self.im * factor)
179    }
180    #[inline]
181    fn conj(self) -> Self {
182        Complex::conj(&self)
183    }
184    #[inline]
185    fn into_complex(self) -> Self::Complex {
186        self
187    }
188    #[inline]
189    fn from_real_imag(re: Self::Real, im: Self::Real) -> Self {
190        Complex::new(re, im)
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Verify Scalar trait algebraic laws for any implementing type.
    ///
    /// Uses fully-qualified `Scalar::method(x)` calls so the trait impls are
    /// exercised rather than same-named inherent methods (e.g. `f64::abs`
    /// shadows `Scalar::abs` under method-call syntax).
    ///
    /// `x` must be non-zero: the `abs` check below is strict.
    fn assert_scalar_laws<S>(x: S, factor: S::Real)
    where
        S: Scalar + PartialEq + std::fmt::Debug,
        S::Real: PartialEq + std::fmt::Debug,
    {
        // abs is positive for non-zero input
        assert!(Scalar::abs(x) > S::Real::zero());
        // conj is an involution
        assert_eq!(Scalar::conj(Scalar::conj(x)), x);
        // conj preserves re and negates im. For real scalars both sides of the
        // im check are 0.0 (0.0 - 0.0 == 0.0); for complex scalars it is a
        // genuine sign check.
        assert_eq!(Scalar::re(Scalar::conj(x)), Scalar::re(x));
        assert_eq!(Scalar::im(Scalar::conj(x)), S::Real::zero() - Scalar::im(x));
        // scale_real identity
        assert_eq!(Scalar::scale_real(x, S::Real::one()), x);
        // scale_real with a non-trivial factor scales each component
        let scaled = Scalar::scale_real(x, factor);
        assert_eq!(Scalar::re(scaled), Scalar::re(x) * factor);
        assert_eq!(Scalar::im(scaled), Scalar::im(x) * factor);
        // re/im round-trip
        assert_eq!(S::from_real_imag(Scalar::re(x), Scalar::im(x)), x);
    }

    #[test]
    fn test_scalar_laws() {
        assert_scalar_laws(2.5f32, 3.0);
        assert_scalar_laws(2.5f64, 3.0);
        assert_scalar_laws(Complex::new(3.0f32, 4.0), 2.0);
        assert_scalar_laws(Complex::new(3.0f64, 4.0), 2.0);
    }

    #[test]
    fn test_scalar_laws_negative_inputs() {
        // Negative components exercise sign handling in abs/conj/scale_real.
        assert_scalar_laws(-2.5f32, 3.0);
        assert_scalar_laws(-2.5f64, 3.0);
        assert_scalar_laws(Complex::new(-1.0f32, -2.0), 0.5);
        assert_scalar_laws(Complex::new(-1.0f64, -2.0), 0.5);
    }

    #[test]
    fn test_abs_values() {
        // Real abs must be the magnitude, not the identity.
        assert_eq!(Scalar::abs(-2.5f32), 2.5);
        assert_eq!(Scalar::abs(-3.0f64), 3.0);
        // Complex abs is the modulus: |3 + 4i| = |-3 + 4i| = 5.
        assert!((Scalar::abs(Complex::new(3.0f32, 4.0)) - 5.0).abs() < 1e-6);
        assert!((Scalar::abs(Complex::new(-3.0f64, 4.0)) - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_into_complex_f32() {
        let x = 2.5f32;
        let z = x.into_complex();
        assert_eq!(z, Complex::new(2.5f32, 0.0));
    }

    #[test]
    fn test_into_complex_f64() {
        let x = 3.0f64;
        let z = x.into_complex();
        assert_eq!(z, Complex::new(3.0f64, 0.0));
    }

    #[test]
    fn test_into_complex_negative_real() {
        assert_eq!((-1.5f64).into_complex(), Complex::new(-1.5, 0.0));
        assert_eq!((-1.5f32).into_complex(), Complex::new(-1.5, 0.0));
    }

    #[test]
    fn test_into_complex_already_complex() {
        let z = Complex::new(1.0f64, 2.0);
        assert_eq!(z.into_complex(), z);

        let z32 = Complex::new(1.0f32, 2.0);
        assert_eq!(z32.into_complex(), z32);
    }

    #[test]
    fn test_re_im_f64() {
        let x = 3.5f64;
        assert_eq!(x.re(), 3.5);
        assert_eq!(x.im(), 0.0);
    }

    #[test]
    fn test_re_im_f32() {
        let x = 2.5f32;
        assert_eq!(x.re(), 2.5);
        assert_eq!(x.im(), 0.0);
    }

    #[test]
    fn test_re_im_complex_f64() {
        let z = Complex::new(3.0f64, 4.0);
        assert_eq!(z.re(), 3.0);
        assert_eq!(z.im(), 4.0);
    }

    #[test]
    fn test_re_im_complex_f32() {
        let z = Complex::new(1.0f32, -2.0);
        assert_eq!(z.re(), 1.0);
        assert_eq!(z.im(), -2.0);
    }

    #[test]
    fn test_from_real_imag_complex_f64() {
        let z = Complex::<f64>::from_real_imag(3.0, 4.0);
        assert_eq!(z, Complex::new(3.0, 4.0));
    }

    #[test]
    fn test_from_real_imag_complex_f32() {
        let z = Complex::<f32>::from_real_imag(1.0, -2.0);
        assert_eq!(z, Complex::new(1.0, -2.0));
    }

    #[test]
    fn test_from_real_imag_f64() {
        let x = f64::from_real_imag(3.0, 999.0);
        assert_eq!(x, 3.0);
    }

    #[test]
    fn test_from_real_imag_f32() {
        let x = f32::from_real_imag(2.5, 999.0);
        assert_eq!(x, 2.5);
    }
}