Skip to main content

cuda_rust_wasm/runtime/
half.rs

1//! Half-precision (fp16) floating-point support
2//!
3//! Provides a software `Half` type that emulates IEEE 754 half-precision
4//! (binary16) arithmetic, mirroring CUDA's `__half` type. All operations
5//! go through f32 internally, which matches the behavior of CUDA's
6//! half-precision on hardware without native fp16 ALUs.
7
8use std::fmt;
9use std::ops::{Add, Sub, Mul, Div, Neg};
10
/// Software IEEE 754 binary16 ("half") value, stored as its raw 16-bit encoding.
///
/// Encoding: 1 sign bit, 5 exponent bits (bias 15), 10 fraction bits.
/// The largest finite magnitude is 65504, the smallest positive normal is
/// 2⁻¹⁴ ≈ 6.1×10⁻⁵, and precision is roughly 3 significant decimal digits.
///
/// `PartialEq`, `Eq` and `Hash` are derived, so they compare the raw bits:
/// `+0.0` and `-0.0` are *unequal*, while a NaN is equal to another NaN with
/// the identical bit pattern.
#[derive(Clone, Copy, Hash, PartialEq, Eq)]
pub struct Half {
    /// Raw binary16 bit pattern (sign | exponent | fraction).
    bits: u16,
}
19
20impl Half {
21    /// Zero constant
22    pub const ZERO: Self = Self { bits: 0x0000 };
23    /// One constant
24    pub const ONE: Self = Self { bits: 0x3C00 };
25    /// Negative one
26    pub const NEG_ONE: Self = Self { bits: 0xBC00 };
27    /// Positive infinity
28    pub const INFINITY: Self = Self { bits: 0x7C00 };
29    /// Negative infinity
30    pub const NEG_INFINITY: Self = Self { bits: 0xFC00 };
31    /// Not a number (NaN)
32    pub const NAN: Self = Self { bits: 0x7E00 };
33    /// Maximum finite value (65504)
34    pub const MAX: Self = Self { bits: 0x7BFF };
35    /// Minimum positive normal value
36    pub const MIN_POSITIVE: Self = Self { bits: 0x0400 };
37    /// Machine epsilon (2⁻¹⁰ ≈ 9.77×10⁻⁴)
38    pub const EPSILON: Self = Self { bits: 0x1400 };
39
40    /// Create from raw bits
41    pub const fn from_bits(bits: u16) -> Self {
42        Self { bits }
43    }
44
45    /// Get the raw bits
46    pub const fn to_bits(self) -> u16 {
47        self.bits
48    }
49
50    /// Convert from f32 to half-precision
51    pub fn from_f32(value: f32) -> Self {
52        Self { bits: f32_to_f16(value) }
53    }
54
55    /// Convert to f32
56    pub fn to_f32(self) -> f32 {
57        f16_to_f32(self.bits)
58    }
59
60    /// Convert from f64 to half-precision
61    pub fn from_f64(value: f64) -> Self {
62        Self::from_f32(value as f32)
63    }
64
65    /// Convert to f64
66    pub fn to_f64(self) -> f64 {
67        self.to_f32() as f64
68    }
69
70    /// Check if NaN
71    pub fn is_nan(self) -> bool {
72        (self.bits & 0x7C00) == 0x7C00 && (self.bits & 0x03FF) != 0
73    }
74
75    /// Check if infinite
76    pub fn is_infinite(self) -> bool {
77        (self.bits & 0x7FFF) == 0x7C00
78    }
79
80    /// Check if finite
81    pub fn is_finite(self) -> bool {
82        (self.bits & 0x7C00) != 0x7C00
83    }
84
85    /// Check if normal (not zero, denormal, infinity, or NaN)
86    pub fn is_normal(self) -> bool {
87        let exp = self.bits & 0x7C00;
88        exp != 0 && exp != 0x7C00
89    }
90
91    /// Check if zero (positive or negative)
92    pub fn is_zero(self) -> bool {
93        (self.bits & 0x7FFF) == 0
94    }
95
96    /// Check if sign bit is set
97    pub fn is_sign_negative(self) -> bool {
98        (self.bits & 0x8000) != 0
99    }
100
101    /// Absolute value
102    pub fn abs(self) -> Self {
103        Self { bits: self.bits & 0x7FFF }
104    }
105
106    /// Fused multiply-add: a * b + c
107    pub fn fma(a: Self, b: Self, c: Self) -> Self {
108        Self::from_f32(a.to_f32().mul_add(b.to_f32(), c.to_f32()))
109    }
110
111    /// Square root
112    pub fn sqrt(self) -> Self {
113        Self::from_f32(self.to_f32().sqrt())
114    }
115
116    /// Reciprocal (1/x)
117    pub fn recip(self) -> Self {
118        Self::from_f32(1.0 / self.to_f32())
119    }
120
121    /// Minimum of two values
122    pub fn min(self, other: Self) -> Self {
123        Self::from_f32(self.to_f32().min(other.to_f32()))
124    }
125
126    /// Maximum of two values
127    pub fn max(self, other: Self) -> Self {
128        Self::from_f32(self.to_f32().max(other.to_f32()))
129    }
130
131    /// Clamp between min and max
132    pub fn clamp(self, min: Self, max: Self) -> Self {
133        Self::from_f32(self.to_f32().clamp(min.to_f32(), max.to_f32()))
134    }
135}
136
137// -- Conversion functions (IEEE 754 bit manipulation) -------------------------
138
/// Convert f32 to f16 bits using IEEE 754 round-to-nearest, ties-to-even
/// (the rounding used by CUDA's `__float2half`).
///
/// - NaN stays NaN: the result is quieted (top fraction bit set), and the sign
///   and the top fraction bits of the payload are kept.
/// - Finite magnitudes of 65520 or more overflow to ±infinity; anything below
///   that rounds to at most 65504.
/// - Results below 2⁻¹⁴ are encoded as binary16 subnormals. Magnitudes at or
///   below 2⁻²⁵ (half the smallest subnormal) round to a signed zero.
fn f32_to_f16(value: f32) -> u16 {
    /// Shifts `value` right by `shift` bits (1..=31), rounding the discarded
    /// bits to nearest with ties to even.
    fn shift_right_round_even(value: u32, shift: u32) -> u32 {
        let truncated = value >> shift;
        let remainder = value & ((1u32 << shift) - 1);
        let halfway = 1u32 << (shift - 1);
        if remainder > halfway || (remainder == halfway && truncated & 1 == 1) {
            truncated + 1
        } else {
            truncated
        }
    }

    let bits = value.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let mantissa = bits & 0x007F_FFFF;

    if exp == 0xFF {
        return if mantissa == 0 {
            sign | 0x7C00 // Infinity
        } else {
            sign | 0x7E00 | (mantissa >> 13) as u16 // quiet NaN
        };
    }

    let unbiased_exp = exp - 127;

    if unbiased_exp > 15 {
        // >= 2^16: far past 65520, overflow -> infinity
        return sign | 0x7C00;
    }

    if unbiased_exp < -25 {
        // < 2^-25 (this also covers f32 zeros and subnormals): rounds to zero
        return sign;
    }

    if unbiased_exp < -14 {
        // Subnormal result: value = significand * 2^(e - 23) and
        // f16 subnormal = m * 2^-24, so m = significand >> (-1 - e).
        // shift is in 14..=24. Rounding up to 0x400 correctly yields
        // the encoding of MIN_POSITIVE.
        let significand = mantissa | 0x0080_0000;
        let shift = (-1 - unbiased_exp) as u32;
        return sign | shift_right_round_even(significand, shift) as u16;
    }

    // Normal: keep the rebiased exponent and full mantissa in one word, then
    // round off the low 13 bits. A carry out of the fraction bumps the exponent,
    // and a carry out of exponent 30 produces 0x7C00 (infinity), as IEEE requires.
    let combined = (((unbiased_exp + 15) as u32) << 23) | mantissa;
    sign | shift_right_round_even(combined, 13) as u16
}
179
/// Convert f16 bits to f32.
///
/// Every binary16 value is exactly representable as an `f32`, so this never
/// rounds. The exponent is rebiased from 15 to 127. Subnormal inputs are
/// rescaled by 2⁻²⁴. Infinity/NaN keep their fraction bits, shifted to the top
/// of the `f32` mantissa.
fn f16_to_f32(bits: u16) -> f32 {
    let sign_bit = u32::from(bits & 0x8000) << 16;
    let exponent_field = u32::from(bits >> 10) & 0x1F;
    let fraction = u32::from(bits & 0x03FF);

    match (exponent_field, fraction) {
        // ±infinity (fraction 0) or NaN (fraction != 0)
        (0x1F, _) => f32::from_bits(sign_bit | 0x7F80_0000 | (fraction << 13)),
        // ±0
        (0, 0) => f32::from_bits(sign_bit),
        // Subnormal: value = fraction * 2^-24. Both factors and the product are
        // exact in f32 (the product is at least 2^-24, well inside f32's normal range).
        (0, _) => {
            let magnitude = fraction as f32 * (1.0 / 16_777_216.0);
            f32::from_bits(sign_bit | magnitude.to_bits())
        }
        // Normal: exponent - 15 + 127 = exponent + 112
        (_, _) => f32::from_bits(sign_bit | ((exponent_field + 112) << 23) | (fraction << 13)),
    }
}
215
216// -- Operator implementations -------------------------------------------------
217
218impl Add for Half {
219    type Output = Self;
220    fn add(self, rhs: Self) -> Self {
221        Self::from_f32(self.to_f32() + rhs.to_f32())
222    }
223}
224
225impl Sub for Half {
226    type Output = Self;
227    fn sub(self, rhs: Self) -> Self {
228        Self::from_f32(self.to_f32() - rhs.to_f32())
229    }
230}
231
232impl Mul for Half {
233    type Output = Self;
234    fn mul(self, rhs: Self) -> Self {
235        Self::from_f32(self.to_f32() * rhs.to_f32())
236    }
237}
238
239impl Div for Half {
240    type Output = Self;
241    fn div(self, rhs: Self) -> Self {
242        Self::from_f32(self.to_f32() / rhs.to_f32())
243    }
244}
245
246impl Neg for Half {
247    type Output = Self;
248    fn neg(self) -> Self {
249        Self { bits: self.bits ^ 0x8000 }
250    }
251}
252
253impl PartialOrd for Half {
254    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
255        self.to_f32().partial_cmp(&other.to_f32())
256    }
257}
258
259impl Default for Half {
260    fn default() -> Self {
261        Self::ZERO
262    }
263}
264
265impl fmt::Debug for Half {
266    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
267        write!(f, "Half({})", self.to_f32())
268    }
269}
270
271impl fmt::Display for Half {
272    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
273        write!(f, "{}", self.to_f32())
274    }
275}
276
277impl From<f32> for Half {
278    fn from(v: f32) -> Self {
279        Self::from_f32(v)
280    }
281}
282
283impl From<f64> for Half {
284    fn from(v: f64) -> Self {
285        Self::from_f64(v)
286    }
287}
288
289impl From<Half> for f32 {
290    fn from(v: Half) -> Self {
291        v.to_f32()
292    }
293}
294
295impl From<Half> for f64 {
296    fn from(v: Half) -> Self {
297        v.to_f64()
298    }
299}
300
301/// Convert a slice of f32 to half-precision
302pub fn f32_to_half_slice(src: &[f32]) -> Vec<Half> {
303    src.iter().map(|&v| Half::from_f32(v)).collect()
304}
305
306/// Convert a slice of half-precision to f32
307pub fn half_to_f32_slice(src: &[Half]) -> Vec<f32> {
308    src.iter().map(|v| v.to_f32()).collect()
309}
310
311/// Dot product in half-precision (accumulated in f32 for precision)
312pub fn half_dot(a: &[Half], b: &[Half]) -> Half {
313    let acc: f32 = a.iter()
314        .zip(b.iter())
315        .map(|(x, y)| x.to_f32() * y.to_f32())
316        .sum();
317    Half::from_f32(acc)
318}
319
320/// GEMV (General Matrix-Vector multiply) in half-precision
321pub fn half_gemv(
322    m: usize,
323    n: usize,
324    alpha: Half,
325    a: &[Half],      // m x n matrix (row-major)
326    x: &[Half],      // n-element vector
327    beta: Half,
328    y: &mut [Half],   // m-element vector
329) {
330    let alpha_f = alpha.to_f32();
331    let beta_f = beta.to_f32();
332
333    for i in 0..m {
334        let mut sum: f32 = 0.0;
335        for j in 0..n {
336            sum += a[i * n + j].to_f32() * x[j].to_f32();
337        }
338        let result = alpha_f * sum + beta_f * y[i].to_f32();
339        y[i] = Half::from_f32(result);
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// One unit in the 10-bit binary16 fraction, relative to the value. Any
    /// f32 -> binary16 conversion of a value in the normal range is within this
    /// relative error (truncation or rounding).
    const HALF_REL_TOL: f32 = 1.0 / 1024.0;

    #[test]
    fn test_half_zero() {
        assert_eq!(Half::ZERO.to_f32(), 0.0);
        assert!(Half::ZERO.is_zero());
    }

    #[test]
    fn test_half_one() {
        assert_eq!(Half::ONE.to_f32(), 1.0);
    }

    #[test]
    fn test_half_constants() {
        assert_eq!(Half::NEG_ONE.to_f32(), -1.0);
        assert_eq!(Half::EPSILON.to_f32(), 1.0 / 1024.0);
        assert_eq!(Half::MIN_POSITIVE.to_f32(), 1.0 / 16384.0);
        assert!(Half::MIN_POSITIVE.is_normal());
        assert_eq!(Half::MAX.to_f32(), 65504.0);
        assert_eq!(Half::NEG_INFINITY.to_f32(), f32::NEG_INFINITY);
    }

    #[test]
    fn test_half_roundtrip() {
        // Relative (not absolute) tolerance, so a small value such as 0.001
        // cannot "pass" by collapsing to zero.
        let values = [0.0f32, 1.0, -1.0, 0.5, 100.0, -100.0, 0.001];
        for &v in &values {
            let back = Half::from_f32(v).to_f32();
            assert!(
                (back - v).abs() <= v.abs() * HALF_REL_TOL,
                "Roundtrip failed for {}: got {}",
                v,
                back
            );
        }
    }

    #[test]
    fn test_half_finite_normal_bits_roundtrip() {
        // Every normal binary16 value must survive f16 -> f32 -> f16 unchanged.
        for bits in 0x0400u16..0x7C00 {
            for &sign in &[0x0000u16, 0x8000] {
                let h = Half::from_bits(bits | sign);
                assert_eq!(
                    Half::from_f32(h.to_f32()).to_bits(),
                    h.to_bits(),
                    "bits {:#06x}",
                    h.to_bits()
                );
            }
        }
    }

    #[test]
    fn test_half_subnormal_to_f32() {
        assert_eq!(Half::from_bits(0x0001).to_f32(), f32::from_bits(0x3380_0000)); // 2^-24
        assert_eq!(Half::from_bits(0x03FF).to_f32(), f32::from_bits(0x387F_C000)); // 1023 * 2^-24
        assert!(!Half::from_bits(0x0001).is_normal());
        assert!(Half::from_bits(0x0001).is_finite());
    }

    #[test]
    fn test_half_negative_zero() {
        let neg_zero = -Half::ZERO;
        assert!(neg_zero.is_zero());
        assert!(neg_zero.is_sign_negative());
        assert_eq!(neg_zero.to_bits(), 0x8000);
        assert_eq!(Half::from_f32(-0.0).to_bits(), 0x8000);
    }

    #[test]
    fn test_half_infinity() {
        assert!(Half::INFINITY.is_infinite());
        assert!(!Half::INFINITY.is_finite());
        assert!(Half::NEG_INFINITY.is_infinite());
    }

    #[test]
    fn test_half_nan() {
        assert!(Half::NAN.is_nan());
        assert!(!Half::NAN.is_finite());
        assert!(!Half::NAN.is_normal());
    }

    #[test]
    fn test_half_nan_roundtrip() {
        let h = Half::from_f32(f32::NAN);
        assert!(h.is_nan());
        assert!(h.to_f32().is_nan());
        assert!(Half::from_f32(Half::NAN.to_f32()).is_nan());
    }

    #[test]
    fn test_half_arithmetic() {
        let a = Half::from_f32(2.0);
        let b = Half::from_f32(3.0);

        assert_eq!((a + b).to_f32(), 5.0);
        assert_eq!((b - a).to_f32(), 1.0);
        assert_eq!((a * b).to_f32(), 6.0);
        assert_eq!((b / a).to_f32(), 1.5);
    }

    #[test]
    fn test_half_negation() {
        let a = Half::from_f32(5.0);
        assert_eq!((-a).to_f32(), -5.0);
        assert_eq!((-(-a)).to_f32(), 5.0);
    }

    #[test]
    fn test_half_comparison() {
        let a = Half::from_f32(1.0);
        let b = Half::from_f32(2.0);

        assert!(a < b);
        assert!(b > a);
        assert!(a <= a);
        assert!(a >= a);
    }

    #[test]
    fn test_half_abs() {
        let neg = Half::from_f32(-3.5);
        assert_eq!(neg.abs().to_f32(), 3.5);
    }

    #[test]
    fn test_half_fma() {
        let a = Half::from_f32(2.0);
        let b = Half::from_f32(3.0);
        let c = Half::from_f32(1.0);

        assert_eq!(Half::fma(a, b, c).to_f32(), 7.0);
    }

    #[test]
    fn test_half_sqrt() {
        let a = Half::from_f32(4.0);
        assert_eq!(a.sqrt().to_f32(), 2.0);
    }

    #[test]
    fn test_half_min_max() {
        let a = Half::from_f32(1.0);
        let b = Half::from_f32(3.0);

        assert_eq!(a.min(b).to_f32(), 1.0);
        assert_eq!(a.max(b).to_f32(), 3.0);
    }

    #[test]
    fn test_half_clamp() {
        let v = Half::from_f32(5.0);
        let lo = Half::from_f32(0.0);
        let hi = Half::from_f32(3.0);

        assert_eq!(v.clamp(lo, hi).to_f32(), 3.0);
    }

    #[test]
    fn test_half_overflow() {
        let big = Half::from_f32(100000.0);
        assert!(big.is_infinite());
        assert!(Half::from_f32(-100000.0).is_sign_negative());
    }

    #[test]
    fn test_half_underflow() {
        // 1e-10 is far below half the smallest subnormal (2^-25 ≈ 3e-8), so the
        // result must be a zero carrying the input's sign, not merely "not normal"
        // (which NaN or infinity would also satisfy).
        let tiny = Half::from_f32(1e-10);
        assert!(tiny.is_zero());
        assert!(!tiny.is_sign_negative());
        assert!(Half::from_f32(-1e-10).is_sign_negative());
    }

    #[test]
    fn test_f32_to_half_slice() {
        let src = vec![1.0f32, 2.0, 3.0];
        let halves = f32_to_half_slice(&src);
        let back = half_to_f32_slice(&halves);
        assert_eq!(back, src);
    }

    #[test]
    fn test_half_dot_product() {
        let a = f32_to_half_slice(&[1.0, 2.0, 3.0]);
        let b = f32_to_half_slice(&[4.0, 5.0, 6.0]);

        // 1*4 + 2*5 + 3*6 = 32, exactly representable
        assert_eq!(half_dot(&a, &b).to_f32(), 32.0);
    }

    #[test]
    fn test_half_gemv() {
        // 2x3 matrix [[1,2,3],[4,5,6]]
        let a = f32_to_half_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let x = f32_to_half_slice(&[1.0, 1.0, 1.0]);
        let mut y = f32_to_half_slice(&[0.0, 0.0]);

        half_gemv(2, 3, Half::ONE, &a, &x, Half::ZERO, &mut y);

        assert_eq!(y[0].to_f32(), 6.0); // 1+2+3
        assert_eq!(y[1].to_f32(), 15.0); // 4+5+6
    }

    #[test]
    fn test_half_display() {
        let h = Half::from_f32(3.14);
        let s = format!("{}", h);
        assert!(s.starts_with("3.1"));
    }

    #[test]
    fn test_half_from_f64() {
        let h = Half::from_f64(2.5);
        assert_eq!(h.to_f64(), 2.5);
    }

    #[test]
    fn test_half_recip() {
        let a = Half::from_f32(4.0);
        assert_eq!(a.recip().to_f32(), 0.25);
    }

    #[test]
    fn test_half_max_value() {
        let max = Half::MAX;
        assert_eq!(max.to_f32(), 65504.0);
        assert!(max.is_finite());
    }

    #[test]
    fn test_half_is_sign_negative() {
        assert!(!Half::from_f32(1.0).is_sign_negative());
        assert!(Half::from_f32(-1.0).is_sign_negative());
        assert!(!Half::ZERO.is_sign_negative());
    }
}