Skip to main content

fory_core/
float16.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! IEEE 754 half-precision (binary16) floating-point type.
19//!
20//! This module provides a `float16` type that represents IEEE 754 binary16
21//! format (16-bit floating point). The type is a transparent wrapper around
22//! `u16` and provides IEEE-compliant conversions to/from `f32`, classification
23//! methods, and arithmetic operations.
24
25use std::cmp::Ordering;
26use std::fmt;
27use std::hash::{Hash, Hasher};
28use std::ops::{Add, Div, Mul, Neg, Sub};
29
/// IEEE 754 binary16 (half-precision) floating-point type.
///
/// This is a 16-bit floating-point format with:
/// - 1 sign bit
/// - 5 exponent bits (bias = 15)
/// - 10 mantissa bits (with implicit leading 1 for normalized values)
///
/// Special values:
/// - ±0: exponent = 0, mantissa = 0
/// - ±Inf: exponent = 31, mantissa = 0
/// - NaN: exponent = 31, mantissa ≠ 0
/// - Subnormals: exponent = 0, mantissa ≠ 0
///
/// The value is stored as its raw `u16` bit pattern; `#[repr(transparent)]`
/// gives it the same layout as `u16`. `float16::default()` is `+0.0`.
#[repr(transparent)]
#[derive(Copy, Clone, Default)]
#[allow(non_camel_case_types)] // the public type name is intentionally lowercase
pub struct float16(u16);

// Bit layout of a binary16 value.
const SIGN_MASK: u16 = 0x8000;
const EXP_MASK: u16 = 0x7C00;
const MANTISSA_MASK: u16 = 0x03FF;
const EXP_SHIFT: u32 = 10;
/// Exponent bias of binary16.
const EXP_BIAS: i32 = 15;
/// Biased exponent reserved for infinities and NaNs.
const MAX_EXP: i32 = 31;

// Layout of binary32 (`f32`), the format all conversions go through.
const F32_EXP_BIAS: i32 = 127;
const F32_MANTISSA_BITS: u32 = 23;
/// Low `f32` mantissa bits discarded when narrowing to binary16 (23 - 10 = 13).
const MANTISSA_SHIFT: u32 = F32_MANTISSA_BITS - EXP_SHIFT;

// Special bit patterns
const INFINITY_BITS: u16 = 0x7C00;
const NEG_INFINITY_BITS: u16 = 0xFC00;
const QUIET_NAN_BITS: u16 = 0x7E00;
/// Most significant mantissa bit; set in every quiet NaN.
const QUIET_BIT: u16 = 0x0200;

/// Shifts `value` right by `shift` bits (`shift` must be in `1..=31`),
/// rounding the discarded bits to nearest with ties to even.
#[inline]
fn round_shift_right_even(value: u32, shift: u32) -> u32 {
    let truncated = value >> shift;
    // Highest discarded bit: set means "at or above the halfway point".
    let round_bit = 1u32 << (shift - 1);
    // Any lower discarded bit set means strictly above halfway (not a tie).
    let sticky = (value & (round_bit - 1)) != 0;
    let round_up = (value & round_bit) != 0 && (sticky || (truncated & 1) != 0);
    if round_up {
        truncated + 1
    } else {
        truncated
    }
}

impl float16 {
    // ============ Construction ============

    /// Create a `float16` from raw bits.
    ///
    /// This is a const function that performs no validation; every `u16` is a
    /// valid binary16 bit pattern.
    #[inline(always)]
    pub const fn from_bits(bits: u16) -> Self {
        Self(bits)
    }

    /// Extract the raw bit representation.
    #[inline(always)]
    pub const fn to_bits(self) -> u16 {
        self.0
    }

    // ============ Constants ============

    /// Positive zero (+0.0).
    pub const ZERO: Self = Self(0x0000);

    /// Negative zero (-0.0).
    pub const NEG_ZERO: Self = Self(0x8000);

    /// Positive infinity.
    pub const INFINITY: Self = Self(INFINITY_BITS);

    /// Negative infinity.
    pub const NEG_INFINITY: Self = Self(NEG_INFINITY_BITS);

    /// Quiet NaN (canonical).
    pub const NAN: Self = Self(QUIET_NAN_BITS);

    /// Maximum finite value (65504.0).
    pub const MAX: Self = Self(0x7BFF);

    /// Minimum positive normal value (2^-14 ≈ 6.104e-5).
    pub const MIN_POSITIVE: Self = Self(0x0400);

    /// Minimum positive subnormal value (2^-24 ≈ 5.96e-8).
    pub const MIN_POSITIVE_SUBNORMAL: Self = Self(0x0001);

    // ============ IEEE 754 Conversion ============

    /// Convert `f32` to `float16` using IEEE 754 round-to-nearest, ties-to-even.
    ///
    /// Handles:
    /// - NaN → NaN: the sign and the upper 10 bits of the `f32` payload are
    ///   kept and the quiet bit is forced on, so the result is always a quiet NaN
    /// - ±Inf → ±Inf
    /// - ±0 → ±0 (preserves sign)
    /// - Overflow (magnitudes that round to 65520 or more) → ±Inf
    /// - Underflow → subnormal or ±0 (sign preserved)
    /// - Normal values → rounded to nearest representable value
    pub fn from_f32(value: f32) -> Self {
        let bits = value.to_bits();
        // Sign bit already moved into its binary16 position.
        let sign = ((bits >> 16) as u16) & SIGN_MASK;
        let exp = ((bits >> F32_MANTISSA_BITS) & 0xFF) as i32;
        let mantissa = bits & 0x007F_FFFF;

        // Exponent 255: infinity or NaN.
        if exp == 0xFF {
            if mantissa == 0 {
                return Self(sign | INFINITY_BITS);
            }
            // Keep the upper 10 payload bits; forcing the quiet bit guarantees
            // a NaN even when those 10 bits are all zero.
            let payload = ((mantissa >> MANTISSA_SHIFT) as u16) & MANTISSA_MASK;
            return Self(sign | INFINITY_BITS | QUIET_BIT | payload);
        }

        // ±0. (f32 subnormals are handled by the underflow branch below.)
        if exp == 0 && mantissa == 0 {
            return Self(sign);
        }

        // Re-bias the exponent from binary32 (127) to binary16 (15).
        let exp16 = exp - F32_EXP_BIAS + EXP_BIAS;

        if exp16 >= MAX_EXP {
            return Self(sign | INFINITY_BITS);
        }

        if exp16 <= 0 {
            // Subnormal or zero result. Below 2^-25 (exp16 < -10) the value is
            // under half of the smallest subnormal, so it rounds to ±0.
            if exp16 < -10 {
                return Self(sign);
            }
            // Restore the implicit leading 1, then shift into units of 2^-24
            // (the subnormal spacing) with correct rounding.
            let full_mantissa = (1u32 << F32_MANTISSA_BITS) | mantissa;
            let shift = MANTISSA_SHIFT + (1 - exp16) as u32;
            // A carry out of the 10-bit field (0x3FF + 1 = 0x400) lands exactly
            // on MIN_POSITIVE's bit pattern, which is the correct result.
            let rounded = round_shift_right_even(full_mantissa, shift);
            return Self(sign | rounded as u16);
        }

        // Normal result: drop the low 13 mantissa bits with correct rounding.
        let rounded = round_shift_right_even(mantissa, MANTISSA_SHIFT);
        // Adding (rather than OR-ing) the unmasked mantissa propagates a rounding
        // carry (0x3FF -> 0x400) into the exponent; a carry into exponent 31
        // produces exactly INFINITY_BITS with a zero mantissa.
        let magnitude = ((exp16 as u32) << EXP_SHIFT) + rounded;
        Self(sign | (magnitude as u16))
    }

    /// Convert `float16` to `f32` (exact conversion).
    ///
    /// All `float16` values are exactly representable in `f32`. NaN payloads
    /// are widened by shifting them into the top of the `f32` mantissa.
    pub fn to_f32(self) -> f32 {
        let sign = ((self.0 & SIGN_MASK) as u32) << 16;
        let exp = ((self.0 & EXP_MASK) >> EXP_SHIFT) as i32;
        let mantissa = (self.0 & MANTISSA_MASK) as u32;

        let magnitude = if exp == MAX_EXP {
            // Infinity (mantissa 0) or NaN (payload kept, so still non-zero).
            0x7F80_0000 | (mantissa << MANTISSA_SHIFT)
        } else if exp == 0 {
            if mantissa == 0 {
                0
            } else {
                // Subnormal: shift left until bit 10 (the implicit-1 position)
                // is set, lowering the exponent by one per shift.
                let shift = mantissa.leading_zeros() - (31 - EXP_SHIFT);
                let exp32 = F32_EXP_BIAS + (1 - EXP_BIAS) - shift as i32;
                let normalized = (mantissa << shift) & (MANTISSA_MASK as u32);
                ((exp32 as u32) << F32_MANTISSA_BITS) | (normalized << MANTISSA_SHIFT)
            }
        } else {
            // Normal: re-bias the exponent and widen the mantissa 10 -> 23 bits.
            let exp32 = exp - EXP_BIAS + F32_EXP_BIAS;
            ((exp32 as u32) << F32_MANTISSA_BITS) | (mantissa << MANTISSA_SHIFT)
        };

        f32::from_bits(sign | magnitude)
    }

    // ============ Classification ============

    /// Returns `true` if this value is NaN.
    #[inline]
    pub const fn is_nan(self) -> bool {
        (self.0 & EXP_MASK) == EXP_MASK && (self.0 & MANTISSA_MASK) != 0
    }

    /// Returns `true` if this value is positive or negative infinity.
    #[inline]
    pub const fn is_infinite(self) -> bool {
        (self.0 & EXP_MASK) == EXP_MASK && (self.0 & MANTISSA_MASK) == 0
    }

    /// Returns `true` if this value is finite (not NaN or infinity).
    #[inline]
    pub const fn is_finite(self) -> bool {
        (self.0 & EXP_MASK) != EXP_MASK
    }

    /// Returns `true` if this value is a normal number (not zero, subnormal, infinite, or NaN).
    #[inline]
    pub const fn is_normal(self) -> bool {
        let exp = self.0 & EXP_MASK;
        exp != 0 && exp != EXP_MASK
    }

    /// Returns `true` if this value is subnormal.
    #[inline]
    pub const fn is_subnormal(self) -> bool {
        (self.0 & EXP_MASK) == 0 && (self.0 & MANTISSA_MASK) != 0
    }

    /// Returns `true` if this value is ±0.
    #[inline]
    pub const fn is_zero(self) -> bool {
        (self.0 & !SIGN_MASK) == 0
    }

    /// Returns `true` if the sign bit is set (negative, including -0 and negative NaNs).
    #[inline]
    pub const fn is_sign_negative(self) -> bool {
        (self.0 & SIGN_MASK) != 0
    }

    /// Returns `true` if the sign bit is not set (positive, including +0 and positive NaNs).
    #[inline]
    pub const fn is_sign_positive(self) -> bool {
        (self.0 & SIGN_MASK) == 0
    }

    // ============ IEEE Value Comparison (separate from bitwise ==) ============

    /// IEEE-754 numeric equality: NaN != NaN, +0 == -0.
    #[inline]
    pub const fn eq_value(self, other: Self) -> bool {
        if self.is_nan() || other.is_nan() {
            false
        } else if self.is_zero() && other.is_zero() {
            true // +0 == -0
        } else {
            self.0 == other.0
        }
    }

    /// IEEE-754 partial comparison: returns `None` if either value is NaN;
    /// +0 and -0 compare `Equal`.
    #[inline]
    pub fn partial_cmp_value(self, other: Self) -> Option<Ordering> {
        self.to_f32().partial_cmp(&other.to_f32())
    }

    /// Total ordering comparison (including NaN), per IEEE 754 `totalOrder`:
    /// -NaN < -Inf < … < -0 < +0 < … < +Inf < +NaN, with NaNs ordered by payload.
    ///
    /// This matches `f32::total_cmp` applied to the widened values, but is
    /// computed directly on the bit pattern. No conversion is needed, and the
    /// result never depends on NaN payloads surviving an `f32` value.
    #[inline]
    pub fn total_cmp(self, other: Self) -> Ordering {
        // Same trick as `f32::total_cmp`: for negative values flip every bit
        // except the sign, so that plain two's-complement `i16` order matches
        // the IEEE total order.
        let key = |h: Self| {
            let bits = h.0 as i16;
            bits ^ ((((bits >> 15) as u16) >> 1) as i16)
        };
        key(self).cmp(&key(other))
    }

    // ============ Arithmetic (explicit methods) ============
    //
    // Each operation is evaluated in `f32` and rounded once back to binary16.
    // f32's 24-bit significand is wide enough (>= 2 * 11 + 2 bits) for this
    // double rounding to equal a correctly rounded binary16 result for
    // + - * /.

    /// Add two `float16` values (via f32).
    #[inline]
    #[allow(clippy::should_implement_trait)]
    pub fn add(self, rhs: Self) -> Self {
        Self::from_f32(self.to_f32() + rhs.to_f32())
    }

    /// Subtract two `float16` values (via f32).
    #[inline]
    #[allow(clippy::should_implement_trait)]
    pub fn sub(self, rhs: Self) -> Self {
        Self::from_f32(self.to_f32() - rhs.to_f32())
    }

    /// Multiply two `float16` values (via f32).
    #[inline]
    #[allow(clippy::should_implement_trait)]
    pub fn mul(self, rhs: Self) -> Self {
        Self::from_f32(self.to_f32() * rhs.to_f32())
    }

    /// Divide two `float16` values (via f32).
    #[inline]
    #[allow(clippy::should_implement_trait)]
    pub fn div(self, rhs: Self) -> Self {
        Self::from_f32(self.to_f32() / rhs.to_f32())
    }

    /// Negate this `float16` value by flipping the sign bit (also for NaN and zero).
    #[inline]
    #[allow(clippy::should_implement_trait)]
    pub const fn neg(self) -> Self {
        Self(self.0 ^ SIGN_MASK)
    }

    /// Absolute value: clears the sign bit (also for NaN).
    #[inline]
    pub const fn abs(self) -> Self {
        Self(self.0 & !SIGN_MASK)
    }
}
392
393// ============ Trait Implementations ============
394
395// Policy A: Bitwise equality and hashing (allows use in HashMap)
396impl PartialEq for float16 {
397    #[inline]
398    fn eq(&self, other: &Self) -> bool {
399        self.0 == other.0
400    }
401}
402
403impl Eq for float16 {}
404
405impl Hash for float16 {
406    #[inline]
407    fn hash<H: Hasher>(&self, state: &mut H) {
408        self.0.hash(state);
409    }
410}
411
412// IEEE partial ordering (NaN breaks total order)
413impl PartialOrd for float16 {
414    #[inline]
415    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
416        self.to_f32().partial_cmp(&other.to_f32())
417    }
418}
419
420// Arithmetic operator traits
421impl Add for float16 {
422    type Output = Self;
423    #[inline]
424    fn add(self, rhs: Self) -> Self {
425        Self::add(self, rhs)
426    }
427}
428
429impl Sub for float16 {
430    type Output = Self;
431    #[inline]
432    fn sub(self, rhs: Self) -> Self {
433        Self::sub(self, rhs)
434    }
435}
436
437impl Mul for float16 {
438    type Output = Self;
439    #[inline]
440    fn mul(self, rhs: Self) -> Self {
441        Self::mul(self, rhs)
442    }
443}
444
445impl Div for float16 {
446    type Output = Self;
447    #[inline]
448    fn div(self, rhs: Self) -> Self {
449        Self::div(self, rhs)
450    }
451}
452
453impl Neg for float16 {
454    type Output = Self;
455    #[inline]
456    fn neg(self) -> Self {
457        Self::neg(self)
458    }
459}
460
461// Display and Debug
462impl fmt::Display for float16 {
463    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
464        write!(f, "{}", self.to_f32())
465    }
466}
467
468impl fmt::Debug for float16 {
469    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
470        write!(f, "float16({})", self.to_f32())
471    }
472}
473
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zero() {
        assert_eq!(float16::ZERO.to_bits(), 0x0000);
        assert!(float16::ZERO.is_zero());
        assert!(!float16::ZERO.is_sign_negative());

        assert_eq!(float16::NEG_ZERO.to_bits(), 0x8000);
        assert!(float16::NEG_ZERO.is_zero());
        assert!(float16::NEG_ZERO.is_sign_negative());
    }

    #[test]
    fn test_infinity() {
        assert_eq!(float16::INFINITY.to_bits(), 0x7C00);
        assert!(float16::INFINITY.is_infinite());
        assert!(!float16::INFINITY.is_nan());

        assert_eq!(float16::NEG_INFINITY.to_bits(), 0xFC00);
        assert!(float16::NEG_INFINITY.is_infinite());
        assert!(float16::NEG_INFINITY.is_sign_negative());
    }

    #[test]
    fn test_nan() {
        assert!(float16::NAN.is_nan());
        assert!(!float16::NAN.is_infinite());
        assert!(!float16::NAN.is_finite());
    }

    #[test]
    fn test_nan_payload() {
        // The upper 10 bits of the f32 payload survive; the quiet bit is forced on.
        assert_eq!(float16::from_f32(f32::from_bits(0x7FC0_2000)).to_bits(), 0x7E01);
        // The sign of a NaN is preserved.
        assert_eq!(float16::from_f32(f32::from_bits(0xFFC0_0000)).to_bits(), 0xFE00);
        // Widening puts the payload back at the top of the f32 mantissa.
        assert_eq!(float16::from_bits(0x7E01).to_f32().to_bits(), 0x7FC0_2000);
    }

    #[test]
    fn test_special_values_conversion() {
        // Infinity
        assert_eq!(float16::from_f32(f32::INFINITY), float16::INFINITY);
        assert_eq!(float16::from_f32(f32::NEG_INFINITY), float16::NEG_INFINITY);
        assert_eq!(float16::INFINITY.to_f32(), f32::INFINITY);
        assert_eq!(float16::NEG_INFINITY.to_f32(), f32::NEG_INFINITY);

        // Zero
        assert_eq!(float16::from_f32(0.0), float16::ZERO);
        assert_eq!(float16::from_f32(-0.0), float16::NEG_ZERO);
        assert_eq!(float16::ZERO.to_f32(), 0.0);

        // NaN
        assert!(float16::from_f32(f32::NAN).is_nan());
        assert!(float16::NAN.to_f32().is_nan());
    }

    #[test]
    fn test_max_min_values() {
        // Max finite value: 65504.0
        let max_f32 = 65504.0f32;
        assert_eq!(float16::from_f32(max_f32), float16::MAX);
        assert_eq!(float16::MAX.to_f32(), max_f32);

        // Min positive normal: 2^-14
        let min_normal = 2.0f32.powi(-14);
        assert_eq!(float16::from_f32(min_normal), float16::MIN_POSITIVE);
        assert_eq!(float16::MIN_POSITIVE.to_f32(), min_normal);

        // Min positive subnormal: 2^-24
        let min_subnormal = 2.0f32.powi(-24);
        let h = float16::from_f32(min_subnormal);
        assert_eq!(h, float16::MIN_POSITIVE_SUBNORMAL);
        assert!(h.is_subnormal());
    }

    #[test]
    fn test_subnormal_to_f32() {
        let sub = 2.0f32.powi(-24);
        assert_eq!(float16::MIN_POSITIVE_SUBNORMAL.to_f32(), sub);
        // Largest subnormal.
        assert_eq!(float16::from_bits(0x03FF).to_f32(), 1023.0 * sub);
        // Negative subnormal.
        assert_eq!(float16::from_bits(0x8200).to_f32(), -512.0 * sub);
    }

    #[test]
    fn test_overflow() {
        // Values larger than max should overflow to infinity
        let too_large = 70000.0f32;
        assert_eq!(float16::from_f32(too_large), float16::INFINITY);
        assert_eq!(float16::from_f32(-too_large), float16::NEG_INFINITY);

        // 65520 is the midpoint between MAX (65504, odd mantissa) and 2^16;
        // ties-to-even sends it to infinity, anything just below stays MAX.
        assert_eq!(float16::from_f32(65519.0), float16::MAX);
        assert_eq!(float16::from_f32(65520.0), float16::INFINITY);
        assert_eq!(float16::from_f32(-65520.0), float16::NEG_INFINITY);
    }

    #[test]
    fn test_underflow() {
        // 2^-30 is far below half of the smallest subnormal (2^-25): exactly ±0.
        let very_small = 2.0f32.powi(-30);
        assert_eq!(float16::from_f32(very_small), float16::ZERO);
        assert_eq!(float16::from_f32(-very_small), float16::NEG_ZERO);
        // f32 subnormals flush to zero as well.
        assert_eq!(float16::from_f32(f32::from_bits(0x0000_0001)), float16::ZERO);
    }

    #[test]
    fn test_rounding() {
        // Exactly representable values are unchanged.
        assert_eq!(float16::from_f32(1.0).to_f32(), 1.0);
        assert_eq!(float16::from_f32(1.5).to_f32(), 1.5);

        // Near 1.0 the spacing (ulp) is 2^-10, so 1 + 2^-11 is an exact tie.
        let ulp = 2.0f32.powi(-10);
        let half_ulp = 2.0f32.powi(-11);
        let tiny = 2.0f32.powi(-20);
        // Tie between 1.0 (even) and 1 + ulp (odd): rounds to even (down).
        assert_eq!(float16::from_f32(1.0 + half_ulp).to_f32(), 1.0);
        // Tie between 1 + ulp (odd) and 1 + 2 ulp (even): rounds to even (up).
        assert_eq!(float16::from_f32(1.0 + 3.0 * half_ulp).to_f32(), 1.0 + 2.0 * ulp);
        // Just above / below a tie round to the nearer neighbour.
        assert_eq!(float16::from_f32(1.0 + half_ulp + tiny).to_f32(), 1.0 + ulp);
        assert_eq!(float16::from_f32(1.0 + half_ulp - tiny).to_f32(), 1.0);
        // A full mantissa that rounds up carries into the next binade.
        assert_eq!(float16::from_f32(2.0 - 2.0f32.powi(-12)).to_f32(), 2.0);

        // The subnormal range (spacing 2^-24) follows the same rule.
        let sub = 2.0f32.powi(-24);
        assert_eq!(float16::from_f32(0.5 * sub), float16::ZERO);
        assert_eq!(float16::from_f32(1.5 * sub).to_bits(), 0x0002);
        assert_eq!(float16::from_f32(2.5 * sub).to_bits(), 0x0002);
        assert_eq!(float16::from_f32(-(0.5 * sub + 2.0f32.powi(-30))).to_bits(), 0x8001);
        // Rounding up the largest subnormal yields the smallest normal.
        let just_below_min_normal = 2.0f32.powi(-14) - 2.0f32.powi(-26);
        assert_eq!(float16::from_f32(just_below_min_normal), float16::MIN_POSITIVE);
    }

    #[test]
    fn test_round_trip_all_bit_patterns() {
        // Widening is exact, so narrowing it back must reproduce every
        // non-NaN bit pattern; NaNs must stay NaN.
        for bits in 0..=u16::MAX {
            let h = float16::from_bits(bits);
            let widened = h.to_f32();
            if h.is_nan() {
                assert!(widened.is_nan(), "bits {:#06x}", bits);
                assert!(float16::from_f32(widened).is_nan(), "bits {:#06x}", bits);
            } else {
                assert_eq!(float16::from_f32(widened).to_bits(), bits, "bits {:#06x}", bits);
            }
        }
    }

    #[test]
    fn test_arithmetic() {
        let a = float16::from_f32(1.5);
        let b = float16::from_f32(2.5);

        assert_eq!((a + b).to_f32(), 4.0);
        assert_eq!((b - a).to_f32(), 1.0);
        assert_eq!((a * b).to_f32(), 3.75);
        assert_eq!((b / float16::from_f32(0.5)).to_f32(), 5.0);
        assert_eq!((-a).to_f32(), -1.5);
        assert_eq!(a.abs().to_f32(), 1.5);
        assert_eq!((-a).abs().to_f32(), 1.5);
    }

    #[test]
    fn test_comparison() {
        let a = float16::from_f32(1.0);
        let b = float16::from_f32(2.0);
        let nan = float16::NAN;

        // Bitwise equality
        assert_eq!(a, a);
        assert_ne!(a, b);

        // IEEE value equality
        assert!(a.eq_value(a));
        assert!(!a.eq_value(b));
        assert!(!nan.eq_value(nan)); // NaN != NaN

        // +0 == -0 in IEEE
        assert!(float16::ZERO.eq_value(float16::NEG_ZERO));

        // Partial ordering
        assert_eq!(a.partial_cmp_value(b), Some(Ordering::Less));
        assert_eq!(b.partial_cmp_value(a), Some(Ordering::Greater));
        assert_eq!(a.partial_cmp_value(a), Some(Ordering::Equal));
        assert_eq!(nan.partial_cmp_value(a), None);
    }

    #[test]
    fn test_total_cmp() {
        // Strictly ascending in IEEE 754 totalOrder.
        let ascending = [
            float16::from_bits(0xFE00), // negative quiet NaN
            float16::NEG_INFINITY,
            -float16::MAX,
            float16::from_f32(-1.0),
            -float16::MIN_POSITIVE_SUBNORMAL,
            float16::NEG_ZERO,
            float16::ZERO,
            float16::MIN_POSITIVE_SUBNORMAL,
            float16::from_f32(1.0),
            float16::MAX,
            float16::INFINITY,
            float16::NAN,
            float16::from_bits(0x7FFF), // quiet NaN with the largest payload
        ];
        for (i, a) in ascending.iter().enumerate() {
            for (j, b) in ascending.iter().enumerate() {
                assert_eq!(a.total_cmp(*b), i.cmp(&j), "{:?} vs {:?}", a, b);
            }
        }
    }

    #[test]
    fn test_classification() {
        assert!(float16::from_f32(1.0).is_normal());
        assert!(float16::from_f32(1.0).is_finite());
        assert!(!float16::from_f32(1.0).is_zero());
        assert!(!float16::from_f32(1.0).is_subnormal());

        assert!(float16::MIN_POSITIVE_SUBNORMAL.is_subnormal());
        assert!(!float16::MIN_POSITIVE_SUBNORMAL.is_normal());
    }
}