Skip to main content

baracuda_types/
numeric.rs

//! Numeric types used across the CUDA stack.
//!
//! These are thin `#[repr(transparent)]` / `#[repr(C)]` wrappers chosen to
//! match the layout NVIDIA's headers use for `__half`, `__nv_bfloat16`,
//! `cuFloatComplex`, and `cuDoubleComplex`. All conversion methods return
//! the same bit patterns the CUDA runtime itself would produce for typical
//! inputs; exact agreement with NVIDIA's rounding on edge cases is tested
//! in the integration suite against `half` and CUDA itself.
//!
//! If you already depend on `half` / `num-complex`, enable the `half-crate`
//! / `num-complex-crate` features for zero-cost `From`/`Into` adapters.
12
13use core::cmp::Ordering;
14use core::fmt;
15
/// IEEE 754 binary16 ("half-precision", `__half` in CUDA).
///
/// Layout: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits.
///
/// `PartialEq`, `Eq` and `Hash` are derived on the raw bits, so `+0.0` and
/// `-0.0` compare unequal and two NaNs are equal only if their bit patterns match.
#[derive(Copy, Clone, Default, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct Half(pub u16);

impl Half {
    /// Positive zero (`+0.0`).
    pub const ZERO: Self = Self(0x0000);
    /// Negative zero (`-0.0`).
    pub const NEG_ZERO: Self = Self(0x8000);
    /// `1.0`.
    pub const ONE: Self = Self(0x3C00);
    /// `-1.0`.
    pub const NEG_ONE: Self = Self(0xBC00);
    /// Positive infinity.
    pub const INFINITY: Self = Self(0x7C00);
    /// Negative infinity.
    pub const NEG_INFINITY: Self = Self(0xFC00);
    /// A quiet NaN (quiet bit set, zero payload).
    pub const NAN: Self = Self(0x7E00);
    /// Smallest positive normal value, 2^-14.
    pub const MIN_POSITIVE: Self = Self(0x0400);
    /// Largest finite value, 65504.
    pub const MAX: Self = Self(0x7BFF);
    /// Most negative finite value, -65504.
    pub const MIN: Self = Self(0xFBFF);
    /// Difference between `1.0` and the next representable value, 2^-10.
    pub const EPSILON: Self = Self(0x1400);

    /// Reinterprets raw binary16 bits as a `Half` (no numeric conversion).
    #[inline]
    pub const fn from_bits(bits: u16) -> Self {
        Self(bits)
    }

    /// Returns the raw binary16 bit pattern.
    #[inline]
    pub const fn to_bits(self) -> u16 {
        self.0
    }

    /// Returns `true` for any NaN (all-ones exponent, non-zero mantissa).
    #[inline]
    pub const fn is_nan(self) -> bool {
        (self.0 & 0x7FFF) > 0x7C00
    }

    /// Returns `true` for `+inf` and `-inf`.
    #[inline]
    pub const fn is_infinite(self) -> bool {
        (self.0 & 0x7FFF) == 0x7C00
    }

    /// Returns `true` if the value is neither infinite nor NaN.
    #[inline]
    pub const fn is_finite(self) -> bool {
        (self.0 & 0x7C00) != 0x7C00
    }

    /// Returns `true` if the sign bit is set (including `-0.0` and negative NaNs).
    #[inline]
    pub const fn is_sign_negative(self) -> bool {
        (self.0 & 0x8000) != 0
    }

    /// Converts an `f32` to the nearest `Half`, rounding ties to even.
    ///
    /// * NaN becomes a quiet NaN with the input's sign; the upper 10 bits of the
    ///   `f32` payload are kept and the quiet bit is forced on.
    /// * Magnitudes of 65520 or more (at or past the midpoint above [`Half::MAX`])
    ///   become signed infinity.
    /// * Magnitudes strictly between 2^-25 and 2^-14 round to the nearest
    ///   subnormal (a carry out of the mantissa yields [`Half::MIN_POSITIVE`]).
    /// * Magnitudes of 2^-25 or less become signed zero (2^-25 itself is a tie
    ///   between zero and the smallest subnormal and rounds to the even one, zero).
    pub fn from_f32(f: f32) -> Self {
        let bits = f.to_bits();
        let sign = ((bits >> 16) & 0x8000) as u16;
        let exp_raw = ((bits >> 23) & 0xFF) as i32;
        let mant = bits & 0x007F_FFFF;

        // NaN or infinity.
        if exp_raw == 0xFF {
            return if mant != 0 {
                // Quiet NaN carrying the top of the payload.
                Self(sign | 0x7E00 | ((mant >> 13) as u16))
            } else {
                Self(sign | 0x7C00)
            };
        }

        let e_unbiased = exp_raw - 127; // true exponent
        let e_half = e_unbiased + 15; // biased binary16 exponent

        if e_half >= 0x1F {
            // True exponent >= 16: |f| >= 65536, always beyond the finite range.
            return Self(sign | 0x7C00);
        }

        if e_half >= 1 {
            // Normal result: keep the top 10 mantissa bits, use bit 12 as the
            // guard bit and bits 0..12 as the sticky bits.
            let trunc = (mant >> 13) as u16;
            let guard = (mant >> 12) & 1;
            let sticky = mant & 0x0FFF;
            let round_up = guard == 1 && (sticky != 0 || (trunc & 1) == 1);
            let base = sign | ((e_half as u16) << 10) | trunc;
            // A mantissa carry propagates into the exponent; from 0x7BFF it yields
            // 0x7C00 (infinity), which is the correct overflow result. The maximum
            // `base` is 0xFBFF, so the addition cannot overflow u16.
            return Self(base + round_up as u16);
        }

        // Below 2^-25 the value is less than half of the smallest subnormal
        // (2^-24) and always rounds to signed zero. This also covers f32 zeros
        // and f32 subnormals (biased exponent 0 => e_unbiased = -127).
        if e_unbiased < -25 {
            return Self(sign);
        }

        // Subnormal result: true exponent in [-25, -15]. Restore the implicit
        // leading 1 and shift so that one unit of the result equals 2^-24.
        // `shift` is in 14..=24, so every shift below is in range for u32.
        let mant_full = mant | 0x0080_0000; // 24-bit significand with leading 1
        let shift = (-14 - e_unbiased) as u32 + 13;
        let trunc = (mant_full >> shift) as u16;
        let guard = (mant_full >> (shift - 1)) & 1;
        let sticky = mant_full & ((1u32 << (shift - 1)) - 1);
        let round_up = guard == 1 && (sticky != 0 || (trunc & 1) == 1);
        // `trunc` is at most 0x3FF; a carry to 0x400 is exactly MIN_POSITIVE.
        Self(sign | (trunc + round_up as u16))
    }

    /// Exact conversion to `f32` (every `Half`, including subnormals, is
    /// representable as an `f32`). NaN payloads are carried over into the
    /// upper mantissa bits.
    pub fn to_f32(self) -> f32 {
        let h = self.0 as u32;
        let sign = (h & 0x8000) << 16;
        let exp = (h >> 10) & 0x1F;
        let mant = h & 0x03FF;

        let bits = if exp == 0 {
            if mant == 0 {
                // Signed zero.
                sign
            } else {
                // Subnormal: shift the leading 1 up to the implicit-bit position,
                // decrementing the exponent once per shift.
                let mut m = mant;
                let mut e: i32 = 1;
                while (m & 0x0400) == 0 {
                    m <<= 1;
                    e -= 1;
                }
                m &= 0x03FF;
                let exp_f32 = (e + 127 - 15) as u32;
                sign | (exp_f32 << 23) | (m << 13)
            }
        } else if exp == 0x1F {
            // Infinity or NaN.
            sign | 0x7F80_0000 | (mant << 13)
        } else {
            // Normal: rebias the exponent from 15 to 127.
            let exp_f32 = exp + 127 - 15;
            sign | (exp_f32 << 23) | (mant << 13)
        };

        f32::from_bits(bits)
    }
}
155
156impl fmt::Debug for Half {
157    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158        write!(f, "Half({})", self.to_f32())
159    }
160}
161
162impl fmt::Display for Half {
163    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164        fmt::Display::fmt(&self.to_f32(), f)
165    }
166}
167
168impl PartialOrd for Half {
169    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
170        self.to_f32().partial_cmp(&other.to_f32())
171    }
172}
173
174impl From<Half> for f32 {
175    #[inline]
176    fn from(h: Half) -> f32 {
177        h.to_f32()
178    }
179}
180
181impl From<Half> for f64 {
182    #[inline]
183    fn from(h: Half) -> f64 {
184        h.to_f32() as f64
185    }
186}
187
188impl From<f32> for Half {
189    #[inline]
190    fn from(f: f32) -> Self {
191        Self::from_f32(f)
192    }
193}
194
/// Brain Floating Point 16 (`__nv_bfloat16` in CUDA): the upper 16 bits of an
/// IEEE 754 `f32`.
///
/// Layout: 1 sign bit, 8 exponent bits (bias 127), 7 mantissa bits.
///
/// `PartialEq`, `Eq` and `Hash` are derived on the raw bits, so `+0.0` and
/// `-0.0` compare unequal.
#[derive(Copy, Clone, Default, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct BFloat16(pub u16);

impl BFloat16 {
    /// Positive zero (`+0.0`).
    pub const ZERO: Self = Self(0x0000);
    /// Negative zero (`-0.0`).
    pub const NEG_ZERO: Self = Self(0x8000);
    /// `1.0`.
    pub const ONE: Self = Self(0x3F80);
    /// `-1.0`.
    pub const NEG_ONE: Self = Self(0xBF80);
    /// Positive infinity.
    pub const INFINITY: Self = Self(0x7F80);
    /// Negative infinity.
    pub const NEG_INFINITY: Self = Self(0xFF80);
    /// Canonical quiet NaN; [`BFloat16::from_f32`] returns this for every NaN input.
    pub const NAN: Self = Self(0x7FC0);
    /// Smallest positive normal value, 2^-126.
    pub const MIN_POSITIVE: Self = Self(0x0080);
    /// Largest finite value (about 3.39e38).
    pub const MAX: Self = Self(0x7F7F);
    /// Most negative finite value.
    pub const MIN: Self = Self(0xFF7F);
    /// Difference between `1.0` and the next representable value, 2^-7.
    pub const EPSILON: Self = Self(0x3C00);

    /// Reinterprets raw bfloat16 bits as a `BFloat16` (no numeric conversion).
    #[inline]
    pub const fn from_bits(bits: u16) -> Self {
        Self(bits)
    }

    /// Returns the raw bfloat16 bit pattern.
    #[inline]
    pub const fn to_bits(self) -> u16 {
        self.0
    }

    /// Returns `true` for any NaN (all-ones exponent, non-zero mantissa).
    #[inline]
    pub const fn is_nan(self) -> bool {
        (self.0 & 0x7FFF) > 0x7F80
    }

    /// Returns `true` for `+inf` and `-inf`.
    #[inline]
    pub const fn is_infinite(self) -> bool {
        (self.0 & 0x7FFF) == 0x7F80
    }

    /// Returns `true` if the value is neither infinite nor NaN
    /// (mirrors `Half::is_finite`).
    #[inline]
    pub const fn is_finite(self) -> bool {
        (self.0 & 0x7F80) != 0x7F80
    }

    /// Returns `true` if the sign bit is set (including `-0.0` and negative NaNs).
    #[inline]
    pub const fn is_sign_negative(self) -> bool {
        (self.0 & 0x8000) != 0
    }

    /// Round-to-nearest-even conversion from `f32` (bfloat16 truncation plus rounding).
    ///
    /// Every NaN input, of either sign and with any payload, maps to [`BFloat16::NAN`].
    /// Finite values that round past [`BFloat16::MAX`] become signed infinity.
    pub fn from_f32(f: f32) -> Self {
        if f.is_nan() {
            return Self::NAN;
        }
        let bits = f.to_bits();
        // Round half to even: adding 0x7FFF + lsb carries into the upper half
        // exactly when the discarded low 16 bits exceed one half, or equal one
        // half and the kept LSB is odd. NaN is excluded above, so the addition
        // cannot wrap.
        let lsb = (bits >> 16) & 1;
        let rounded = bits.wrapping_add(0x7FFF + lsb);
        Self((rounded >> 16) as u16)
    }

    /// Exact conversion to `f32`: the bits become the upper half of the `f32`.
    #[inline]
    pub fn to_f32(self) -> f32 {
        f32::from_bits((self.0 as u32) << 16)
    }
}
256
257impl fmt::Debug for BFloat16 {
258    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259        write!(f, "BFloat16({})", self.to_f32())
260    }
261}
262
263impl fmt::Display for BFloat16 {
264    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265        fmt::Display::fmt(&self.to_f32(), f)
266    }
267}
268
269impl PartialOrd for BFloat16 {
270    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
271        self.to_f32().partial_cmp(&other.to_f32())
272    }
273}
274
275impl From<BFloat16> for f32 {
276    #[inline]
277    fn from(b: BFloat16) -> f32 {
278        b.to_f32()
279    }
280}
281
282impl From<f32> for BFloat16 {
283    #[inline]
284    fn from(f: f32) -> Self {
285        Self::from_f32(f)
286    }
287}
288
/// Single-precision complex number (`cuFloatComplex`).
///
/// `#[repr(C)]` with fields in `re`, `im` order gives the same size (8 bytes)
/// and field order as CUDA's `float2`.
///
/// NOTE(review): CUDA's `vector_types.h` declares `float2` with 8-byte
/// alignment, while `#[repr(C)]` here only guarantees `f32` alignment. Confirm
/// that no FFI boundary depends on the stricter alignment.
#[derive(Copy, Clone, Debug, Default, PartialEq)]
#[repr(C)]
pub struct Complex32 {
    /// Real part.
    pub re: f32,
    /// Imaginary part.
    pub im: f32,
}

impl Complex32 {
    /// Additive identity, `0 + 0i`.
    pub const ZERO: Self = Self::new(0.0, 0.0);
    /// Multiplicative identity, `1 + 0i`.
    pub const ONE: Self = Self::new(1.0, 0.0);
    /// Imaginary unit, `0 + 1i`.
    pub const I: Self = Self::new(0.0, 1.0);

    /// Builds `re + im·i`.
    #[inline]
    pub const fn new(re: f32, im: f32) -> Self {
        Self { re, im }
    }

    /// Squared magnitude `re² + im²` (no square root, no FMA).
    #[inline]
    pub fn norm_sqr(self) -> f32 {
        let Self { re, im } = self;
        re * re + im * im
    }

    /// Complex conjugate `re - im·i` (only the imaginary sign flips).
    #[inline]
    pub fn conj(self) -> Self {
        Self::new(self.re, -self.im)
    }
}
326
/// Double-precision complex number (`cuDoubleComplex`).
///
/// `#[repr(C)]` with fields in `re`, `im` order gives the same size (16 bytes)
/// and field order as CUDA's `double2`.
///
/// NOTE(review): CUDA's `vector_types.h` declares `double2` with 16-byte
/// alignment, while `#[repr(C)]` here only guarantees `f64` alignment. Confirm
/// that no FFI boundary depends on the stricter alignment.
#[derive(Copy, Clone, Debug, Default, PartialEq)]
#[repr(C)]
pub struct Complex64 {
    /// Real part.
    pub re: f64,
    /// Imaginary part.
    pub im: f64,
}

impl Complex64 {
    /// Additive identity, `0 + 0i`.
    pub const ZERO: Self = Self::new(0.0, 0.0);
    /// Multiplicative identity, `1 + 0i`.
    pub const ONE: Self = Self::new(1.0, 0.0);
    /// Imaginary unit, `0 + 1i`.
    pub const I: Self = Self::new(0.0, 1.0);

    /// Builds `re + im·i`.
    #[inline]
    pub const fn new(re: f64, im: f64) -> Self {
        Self { re, im }
    }

    /// Squared magnitude `re² + im²` (no square root, no FMA).
    #[inline]
    pub fn norm_sqr(self) -> f64 {
        let Self { re, im } = self;
        re * re + im * im
    }

    /// Complex conjugate `re - im·i` (only the imaginary sign flips).
    #[inline]
    pub fn conj(self) -> Self {
        Self::new(self.re, -self.im)
    }
}
362
#[cfg(feature = "half-crate")]
mod half_adapters {
    //! Bit-exact conversions between this crate's 16-bit float wrappers and the
    //! `half` crate's `f16` / `bf16`. No rounding happens; the raw bits are copied.

    use super::{BFloat16, Half};

    impl From<half::f16> for Half {
        #[inline]
        fn from(value: half::f16) -> Self {
            Self::from_bits(value.to_bits())
        }
    }

    impl From<Half> for half::f16 {
        #[inline]
        fn from(value: Half) -> Self {
            Self::from_bits(value.to_bits())
        }
    }

    impl From<half::bf16> for BFloat16 {
        #[inline]
        fn from(value: half::bf16) -> Self {
            Self::from_bits(value.to_bits())
        }
    }

    impl From<BFloat16> for half::bf16 {
        #[inline]
        fn from(value: BFloat16) -> Self {
            Self::from_bits(value.to_bits())
        }
    }
}
395
#[cfg(feature = "num-complex-crate")]
mod num_complex_adapters {
    //! Field-for-field conversions between this crate's complex types and
    //! `num_complex::Complex<f32>` / `Complex<f64>`.

    use super::{Complex32, Complex64};

    impl From<num_complex::Complex<f32>> for Complex32 {
        #[inline]
        fn from(value: num_complex::Complex<f32>) -> Self {
            Self::new(value.re, value.im)
        }
    }

    impl From<Complex32> for num_complex::Complex<f32> {
        #[inline]
        fn from(value: Complex32) -> Self {
            Self {
                re: value.re,
                im: value.im,
            }
        }
    }

    impl From<num_complex::Complex<f64>> for Complex64 {
        #[inline]
        fn from(value: num_complex::Complex<f64>) -> Self {
            Self::new(value.re, value.im)
        }
    }

    impl From<Complex64> for num_complex::Complex<f64> {
        #[inline]
        fn from(value: Complex64) -> Self {
            Self {
                re: value.re,
                im: value.im,
            }
        }
    }
}
428
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn half_constants_roundtrip() {
        assert_eq!(Half::ZERO.to_f32().to_bits(), 0.0f32.to_bits());
        assert_eq!(Half::NEG_ZERO.to_f32().to_bits(), (-0.0f32).to_bits());
        assert_eq!(Half::ONE.to_f32(), 1.0);
        assert_eq!(Half::NEG_ONE.to_f32(), -1.0);
        assert_eq!(Half::MAX.to_f32(), 65504.0);
        assert_eq!(Half::MIN.to_f32(), -65504.0);
        assert_eq!(Half::MIN_POSITIVE.to_f32(), 6.103_515_625e-5); // 2^-14
        assert_eq!(Half::EPSILON.to_f32(), 0.000_976_562_5); // 2^-10
        // Check the sign of the infinities, not just that they are infinite.
        assert_eq!(Half::INFINITY.to_f32(), f32::INFINITY);
        assert_eq!(Half::NEG_INFINITY.to_f32(), f32::NEG_INFINITY);
        assert!(Half::NAN.to_f32().is_nan());
    }

    #[test]
    fn half_roundtrip_exact_values() {
        // Every value here is exactly representable in binary16, so the
        // roundtrip must be bit-exact rather than approximately equal.
        for v in [0.0f32, 1.0, -1.0, 0.5, -0.5, 2.0, 65504.0, -65504.0] {
            let h = Half::from_f32(v);
            assert_eq!(h.to_f32(), v, "{v} (half bits = {:#06x})", h.to_bits());
        }
    }

    #[test]
    fn half_roundtrip_inexact_value_is_within_half_ulp() {
        let v = 1e-4f32;
        let back = Half::from_f32(v).to_f32();
        // For normal halves, round-to-nearest error is at most half an ulp,
        // i.e. a relative error of at most 2^-11.
        assert!((back - v).abs() <= v / 2048.0, "{v} -> {back}");
    }

    #[test]
    fn half_every_non_nan_value_roundtrips_through_f32() {
        for bits in 0..=u16::MAX {
            let h = Half::from_bits(bits);
            let back = Half::from_f32(h.to_f32());
            if h.is_nan() {
                assert!(back.is_nan(), "{bits:#06x}");
            } else {
                assert_eq!(back.to_bits(), bits, "{bits:#06x}");
            }
        }
    }

    #[test]
    fn half_rounds_ties_to_even() {
        // 1 + 2^-11 is halfway between 0x3C00 and 0x3C01; the even one wins.
        assert_eq!(Half::from_f32(f32::from_bits(0x3F80_1000)).to_bits(), 0x3C00);
        // 1 + 3 * 2^-11 is halfway between 0x3C01 and 0x3C02; the even one wins.
        assert_eq!(Half::from_f32(f32::from_bits(0x3F80_3000)).to_bits(), 0x3C02);
        // 65519 rounds down to MAX; the 65520 tie rounds up to infinity.
        assert_eq!(Half::from_f32(65519.0).to_bits(), Half::MAX.to_bits());
        assert_eq!(Half::from_f32(65520.0).to_bits(), Half::INFINITY.to_bits());
    }

    #[test]
    fn half_overflow_to_infinity() {
        assert_eq!(Half::from_f32(1e30).to_bits(), Half::INFINITY.to_bits());
        assert_eq!(
            Half::from_f32(-1e30).to_bits(),
            Half::NEG_INFINITY.to_bits()
        );
    }

    #[test]
    fn half_underflow_to_zero() {
        assert_eq!(Half::from_f32(1e-30).to_bits(), 0);
        assert_eq!(Half::from_f32(-1e-30).to_bits(), 0x8000);
        assert_eq!(Half::from_f32(-0.0).to_bits(), 0x8000);
    }

    #[test]
    fn half_nan_stays_nan() {
        assert!(Half::from_f32(f32::NAN).is_nan());
        assert!(Half::from_f32(-f32::NAN).is_nan());
    }

    #[test]
    fn bfloat_constants_roundtrip() {
        assert_eq!(BFloat16::ZERO.to_f32(), 0.0);
        assert_eq!(BFloat16::ONE.to_f32(), 1.0);
        assert_eq!(BFloat16::NEG_ONE.to_f32(), -1.0);
        assert_eq!(BFloat16::EPSILON.to_f32(), 0.007_812_5); // 2^-7
        assert_eq!(BFloat16::INFINITY.to_f32(), f32::INFINITY);
        assert_eq!(BFloat16::NEG_INFINITY.to_f32(), f32::NEG_INFINITY);
        assert!(BFloat16::NAN.to_f32().is_nan());
    }

    #[test]
    fn bfloat_truncates_top_16_bits() {
        // A value whose low 16 f32 bits are zero round-trips exactly.
        let v: f32 = 1.5; // 0x3FC0_0000
        let b = BFloat16::from_f32(v);
        assert_eq!(b.to_bits(), 0x3FC0);
        assert_eq!(b.to_f32(), 1.5);
    }

    #[test]
    fn bfloat_rounds_ties_to_even() {
        assert_eq!(BFloat16::from_f32(f32::from_bits(0x3F80_8000)).to_bits(), 0x3F80);
        assert_eq!(BFloat16::from_f32(f32::from_bits(0x3F81_8000)).to_bits(), 0x3F82);
        assert_eq!(BFloat16::from_f32(f32::from_bits(0x3F80_8001)).to_bits(), 0x3F81);
        // f32::MAX lies beyond the midpoint above BFloat16::MAX.
        assert_eq!(
            BFloat16::from_f32(f32::MAX).to_bits(),
            BFloat16::INFINITY.to_bits()
        );
    }

    #[test]
    fn bfloat_nan_stays_nan() {
        assert!(BFloat16::from_f32(f32::NAN).is_nan());
    }

    #[test]
    fn complex_basic_ops() {
        let z = Complex32::new(3.0, 4.0);
        assert_eq!(z.norm_sqr(), 25.0);
        assert_eq!(z.conj(), Complex32::new(3.0, -4.0));
        let w = Complex64::new(-1.5, 2.0);
        assert_eq!(w.norm_sqr(), 6.25);
        assert_eq!(w.conj(), Complex64::new(-1.5, -2.0));
    }

    #[test]
    fn complex_layout_is_two_floats() {
        use core::mem::{align_of, size_of};
        assert_eq!(size_of::<Complex32>(), 8);
        assert_eq!(size_of::<Complex64>(), 16);
        assert!(align_of::<Complex32>() >= align_of::<f32>());
        assert!(align_of::<Complex64>() >= align_of::<f64>());
    }
}