Skip to main content

cbor_core/
float.rs

1use crate::{
2    DataType, Error, Result,
3    codec::{Argument, Head, Major},
4};
5
6// IEEE 754 half-precision conversion functions
7
/// Widen an IEEE 754 half-precision bit pattern to an `f64`.
///
/// The conversion works purely on the bit fields: signed zeros keep their
/// sign, subnormals are renormalised, infinities stay infinite and the ten
/// NaN payload bits are moved to the top of the f64 significand.
const fn f16_to_f64(bits: u16) -> f64 {
    let raw = bits as u64;
    let sign_field = ((raw >> 15) & 1) << 63;
    let exponent = (raw >> 10) & 0x1f;
    let mantissa = raw & 0x03ff;

    let (exp_field, frac_field) = match (exponent, mantissa) {
        // Signed zero.
        (0, 0) => (0, 0),
        // Subnormal: shift the leading one out of the 10-bit field and
        // lower the exponent by the same amount.
        (0, _) => {
            let lead = mantissa.leading_zeros() - (64 - 10);
            let frac = (mantissa << (lead + 1)) & 0x03ff;
            let biased = 1023 - 15 - lead as u64;
            (biased << 52, frac << 42)
        }
        // Infinity (mantissa == 0) or NaN (payload kept in the top bits).
        (0x1f, _) => (0x7ff0_0000_0000_0000, mantissa << 42),
        // Normal number: re-bias the exponent from 15 to 1023.
        _ => ((exponent + (1023 - 15)) << 52, mantissa << 42),
    };

    f64::from_bits(sign_field | exp_field | frac_field)
}
32
/// Widen an IEEE 754 half-precision bit pattern to an `f32`.
///
/// Bit-level counterpart of the f16 → f64 conversion: zeros keep their sign,
/// subnormals are renormalised, and NaN payload bits are shifted into the
/// top of the f32 significand.
const fn f16_to_f32(bits: u16) -> f32 {
    let wide = bits as u32;
    let sign_field = (wide & 0x8000) << 16;
    let exponent = (wide >> 10) & 0x1f;
    let fraction = wide & 0x03ff;

    // Infinity or NaN: all-ones exponent, payload moved up by 13 bits.
    if exponent == 0x1f {
        return f32::from_bits(sign_field | 0x7f80_0000 | (fraction << 13));
    }

    // Normal number: re-bias the exponent from 15 to 127.
    if exponent != 0 {
        let biased = exponent + (127 - 15);
        return f32::from_bits(sign_field | (biased << 23) | (fraction << 13));
    }

    // Signed zero.
    if fraction == 0 {
        return f32::from_bits(sign_field);
    }

    // Subnormal: normalise by shifting the leading one out of the field.
    let lead = fraction.leading_zeros() - (32 - 10);
    let normalised = (fraction << (lead + 1)) & 0x03ff;
    let biased = 127 - 15 - lead;
    f32::from_bits(sign_field | (biased << 23) | (normalised << 13))
}
57
/// Convert f64 to f16 with round-to-nearest-even.
///
/// Returns the raw half-precision bit pattern. The conversion is performed
/// on the bit representation only:
///
/// * f64 zeros and f64 subnormals become a signed zero,
/// * infinities stay infinite,
/// * NaNs keep the top ten payload bits; if those are all zero the lowest
///   payload bit is set so the result is still a NaN,
/// * magnitudes too large for f16 become a signed infinity,
/// * magnitudes below the f16 normal range are rounded into the subnormal
///   range (possibly down to zero, or up to the smallest normal).
const fn f64_to_f16(value: f64) -> u16 {
    let bits = value.to_bits();
    let sign_bit = ((bits >> 48) & 0x8000) as u16; // 1 bit, already in f16 position
    let exp = ((bits >> 52) & 0x7ff) as i32; // 11 bits, biased by 1023
    let sig = bits & 0x000f_ffff_ffff_ffff; // 52 bits

    match exp {
        // f64 zero or subnormal: far below anything f16 can represent.
        0 => return sign_bit,

        0x7ff => {
            if sig == 0 {
                return sign_bit | 0x7c00;
            }
            // Keep the top ten payload bits, but never collapse a NaN into
            // infinity (`u16::max` is not usable in a const fn here).
            let sig16 = (sig >> 42) as u16;
            return sign_bit | 0x7c00 | if sig16 == 0 { 1 } else { sig16 };
        }

        _ => (),
    }

    // Re-bias the exponent from 1023 (f64) to 15 (f16).
    let exp16 = exp - 1008;

    // Too large for f16: overflow to infinity.
    if exp16 >= 0x1f {
        return sign_bit | 0x7c00;
    }

    if exp16 <= 0 {
        // Subnormal range: make the implicit leading one explicit and shift
        // the 53-bit significand down to the f16 subnormal scale.
        let full_sig = sig | 0x0010_0000_0000_0000;
        let shift = (1 - exp16) as u64 + 42;

        // `full_sig` is below 2^53, so for `shift >= 54` the halfway point
        // 2^(shift - 1) is already larger than the whole significand and the
        // value rounds to zero. Returning here also keeps the shifts below
        // in range for u64.
        if shift >= 54 {
            return sign_bit;
        }

        let shifted = full_sig >> shift;
        let remainder = full_sig & ((1_u64 << shift) - 1);
        let halfway = 1_u64 << (shift - 1);
        let round_up = remainder > halfway || (remainder == halfway && (shifted & 1) != 0);
        // A carry out of the subnormal field lands exactly on 0x0400, the
        // smallest normal number, which is the correct result.
        return sign_bit | ((shifted as u16) + round_up as u16);
    }

    // Normal range: keep the top ten significand bits and round on the
    // 42 bits that are dropped.
    let sig10 = (sig >> 42) as u16;
    let remainder = sig & 0x3ff_ffff_ffff;
    let halfway = 0x200_0000_0000_u64;
    let round_up = remainder > halfway || (remainder == halfway && (sig10 & 1) != 0);
    let sig16 = sig10 + round_up as u16;

    if sig16 >= 0x0400 {
        // Rounding carried out of the significand: bump the exponent. For
        // the largest exponent this yields 0x7c00, i.e. infinity.
        sign_bit | (((exp16 as u16) + 1) << 10)
    } else {
        sign_bit | ((exp16 as u16) << 10) | sig16
    }
}
118
/// Rebuild an f32 NaN as an f64 NaN using integer operations only.
///
/// The sign bit is moved to bit 63, the 23 significand bits are placed at
/// the top of the f64 significand, and the f64 exponent field is forced to
/// all ones. No floating-point cast is involved, so the payload bits are
/// carried over exactly.
const fn f32_nan_to_f64(bits: u32) -> f64 {
    let wide = bits as u64;
    let sign_field = (wide >> 31) << 63;
    let payload_field = (wide & 0x007f_ffff) << 29;
    f64::from_bits(0x7ff0_0000_0000_0000 | sign_field | payload_field)
}
125
/// f16, f32 or f64 as bits
///
/// Each variant carries the raw bit pattern of a floating-point value at the
/// given width. Because the comparison traits are derived, equality and
/// hashing are bitwise, and ordering compares the variant first (`F16` <
/// `F32` < `F64`) and then the raw bits — it is not a numeric ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) enum Inner {
    /// Raw bits of a half-precision (16-bit) value.
    F16(u16),
    /// Raw bits of a single-precision (32-bit) value.
    F32(u32),
    /// Raw bits of a double-precision (64-bit) value.
    F64(u64),
}
133
134impl Inner {
135    const fn new(x: f64) -> Self {
136        if x.is_finite() {
137            let bits16 = f64_to_f16(x);
138
139            if f16_to_f64(bits16).to_bits() == x.to_bits() {
140                Inner::F16(bits16)
141            } else if ((x as f32) as f64).to_bits() == x.to_bits() {
142                Inner::F32((x as f32).to_bits())
143            } else {
144                Inner::F64(x.to_bits())
145            }
146        } else {
147            let bits64 = x.to_bits();
148            let sign_bit = bits64 & 0x8000_0000_0000_0000;
149
150            if (bits64 & 0x3ff_ffff_ffff) == 0 {
151                let bits = (bits64 >> 42) & 0x7fff | (sign_bit >> 48);
152                Self::F16(bits as u16)
153            } else if (bits64 & 0x1fff_ffff) == 0 {
154                let bits = (bits64 >> 29) & 0x7fff_ffff | (sign_bit >> 32);
155                Self::F32(bits as u32)
156            } else {
157                Self::F64(bits64)
158            }
159        }
160    }
161}
162
/// A floating-point value stored in its shortest CBOR encoding form.
///
/// Internally stores the raw bits as either f16, f32, or f64,
/// preserving NaN payloads and the exact CBOR encoding.
/// Two `Float` values are equal iff they encode to the same CBOR bytes.
///
/// Equality, ordering and hashing are derived from the stored width and raw
/// bits, so the ordering follows the representation rather than the numeric
/// value (NOTE(review): confirm this is the ordering callers expect).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Float(pub(crate) Inner);
170
171impl Float {
172    /// Return the [`DataType`] indicating the storage width (f16, f32, or f64).
173    #[must_use]
174    pub const fn data_type(&self) -> DataType {
175        match self.0 {
176            Inner::F16(_) => DataType::Float16,
177            Inner::F32(_) => DataType::Float32,
178            Inner::F64(_) => DataType::Float64,
179        }
180    }
181
182    pub(crate) fn cbor_head(&self) -> Head {
183        match self.0 {
184            Inner::F16(bits) => Head::new(Major::SimpleOrFloat, Argument::U16(bits)),
185            Inner::F32(bits) => Head::new(Major::SimpleOrFloat, Argument::U32(bits)),
186            Inner::F64(bits) => Head::new(Major::SimpleOrFloat, Argument::U64(bits)),
187        }
188    }
189
190    pub(crate) const fn from_u16(bits: u16) -> Self {
191        Self(Inner::F16(bits))
192    }
193
194    pub(crate) const fn from_u32(bits: u32) -> Result<Self> {
195        let float = Self(Inner::F32(bits));
196        if matches!(Inner::new(float.to_f64()), Inner::F32(_)) {
197            Ok(float)
198        } else {
199            Err(Error::NonDeterministic)
200        }
201    }
202
203    pub(crate) const fn from_u64(bits: u64) -> Result<Self> {
204        let float = Self(Inner::F64(bits));
205        if matches!(Inner::new(float.to_f64()), Inner::F64(_)) {
206            Ok(float)
207        } else {
208            Err(Error::NonDeterministic)
209        }
210    }
211
212    /// Convert to f64 (NaN payloads are preserved).
213    #[must_use]
214    pub const fn to_f64(self) -> f64 {
215        match self.0 {
216            Inner::F16(bits) => f16_to_f64(bits),
217            Inner::F32(bits) => {
218                let f = f32::from_bits(bits);
219                if f.is_nan() { f32_nan_to_f64(bits) } else { f as f64 }
220            }
221            Inner::F64(bits) => f64::from_bits(bits),
222        }
223    }
224
225    /// Convert to `f32`.
226    ///
227    /// Returns `Err(Precision)` for f64-width values.
228    pub const fn to_f32(self) -> Result<f32> {
229        match self.0 {
230            Inner::F16(bits) => Ok(f16_to_f32(bits)),
231            Inner::F32(bits) => Ok(f32::from_bits(bits)),
232            Inner::F64(_) => Err(Error::Precision),
233        }
234    }
235}
236
237// --- From floating-point types ---
238
239impl From<f64> for Float {
240    fn from(value: f64) -> Self {
241        Self(Inner::new(value))
242    }
243}
244
245impl From<f32> for Float {
246    fn from(value: f32) -> Self {
247        if value.is_nan() {
248            // NaN-safe: bit manipulation to avoid hardware canonicalization
249            Self(Inner::new(f32_nan_to_f64(value.to_bits())))
250        } else {
251            Self(Inner::new(value as f64))
252        }
253    }
254}
255
256// --- From integer types (lossless conversion to f64, like std) ---
257
258macro_rules! try_from {
259    ($type:ty) => {
260        impl From<$type> for Float {
261            fn from(value: $type) -> Self {
262                Self::from(value as f64)
263            }
264        }
265    };
266}
267
268try_from!(u8);
269try_from!(u16);
270try_from!(u32);
271
272try_from!(i8);
273try_from!(i16);
274try_from!(i32);
275
276impl From<bool> for Float {
277    fn from(value: bool) -> Self {
278        Self(if value { Inner::new(1.0) } else { Inner::new(0.0) })
279    }
280}
281
// Unit tests for the bit-level f16 <-> f32/f64 conversion helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// True when the f16 bit pattern is a NaN: with the sign masked off, the
    /// value is above the infinity pattern `0x7c00`.
    fn f16_is_nan(bits: u16) -> bool {
        (bits & 0x7fff) > 0x7c00
    }

    // =====================================================================
    // f16 → f64 conversion
    // =====================================================================

    #[test]
    fn to_f64_zero() {
        assert_eq!(f16_to_f64(0x0000), 0.0);
        assert!(f16_to_f64(0x0000).is_sign_positive());
    }

    #[test]
    fn to_f64_neg_zero() {
        // Compare bit patterns: `-0.0 == 0.0` under float equality, so a
        // plain comparison could not detect a lost sign.
        let v = f16_to_f64(0x8000);
        assert_eq!(v.to_bits(), (-0.0_f64).to_bits());
    }

    #[test]
    fn to_f64_one() {
        assert_eq!(f16_to_f64(0x3c00), 1.0);
    }

    #[test]
    fn to_f64_neg_one() {
        assert_eq!(f16_to_f64(0xbc00), -1.0);
    }

    #[test]
    fn to_f64_max_normal() {
        assert_eq!(f16_to_f64(0x7bff), 65504.0);
    }

    #[test]
    fn to_f64_min_positive_normal() {
        assert_eq!(f16_to_f64(0x0400), 0.00006103515625);
    }

    #[test]
    fn to_f64_min_positive_subnormal() {
        assert_eq!(f16_to_f64(0x0001), 5.960464477539063e-8);
    }

    #[test]
    fn to_f64_max_subnormal() {
        assert_eq!(f16_to_f64(0x03ff), 0.00006097555160522461);
    }

    #[test]
    fn to_f64_infinity() {
        assert_eq!(f16_to_f64(0x7c00), f64::INFINITY);
    }

    #[test]
    fn to_f64_neg_infinity() {
        assert_eq!(f16_to_f64(0xfc00), f64::NEG_INFINITY);
    }

    #[test]
    fn to_f64_nan() {
        assert!(f16_to_f64(0x7e00).is_nan());
    }

    #[test]
    fn to_f64_nan_preserves_payload() {
        // f16 sig bit 0 → f64 sig bit shifted left by 42
        let bits = f16_to_f64(0x7c01).to_bits();
        assert_eq!(bits, 0x7ff0_0400_0000_0000);
    }

    #[test]
    fn to_f64_two() {
        assert_eq!(f16_to_f64(0x4000), 2.0);
    }

    #[test]
    fn to_f64_one_point_five() {
        assert_eq!(f16_to_f64(0x3e00), 1.5);
    }

    // =====================================================================
    // f16 → f32 conversion
    // =====================================================================

    #[test]
    fn to_f32_zero() {
        assert_eq!(f16_to_f32(0x0000), 0.0_f32);
        assert!(f16_to_f32(0x0000).is_sign_positive());
    }

    #[test]
    fn to_f32_neg_zero() {
        // Bitwise comparison so the sign of zero is actually checked.
        assert_eq!(f16_to_f32(0x8000).to_bits(), (-0.0_f32).to_bits());
    }

    #[test]
    fn to_f32_one() {
        assert_eq!(f16_to_f32(0x3c00), 1.0_f32);
    }

    #[test]
    fn to_f32_neg_one() {
        assert_eq!(f16_to_f32(0xbc00), -1.0_f32);
    }

    #[test]
    fn to_f32_two() {
        assert_eq!(f16_to_f32(0x4000), 2.0_f32);
    }

    #[test]
    fn to_f32_one_point_five() {
        assert_eq!(f16_to_f32(0x3e00), 1.5_f32);
    }

    #[test]
    fn to_f32_max_normal() {
        assert_eq!(f16_to_f32(0x7bff), 65504.0_f32);
    }

    #[test]
    fn to_f32_min_positive_normal() {
        assert_eq!(f16_to_f32(0x0400), 0.000061035156_f32);
    }

    #[test]
    fn to_f32_min_positive_subnormal() {
        assert_eq!(f16_to_f32(0x0001), 5.9604645e-8_f32);
    }

    #[test]
    fn to_f32_max_subnormal() {
        assert_eq!(f16_to_f32(0x03ff), 0.00006097555_f32);
    }

    #[test]
    fn to_f32_infinity() {
        assert_eq!(f16_to_f32(0x7c00), f32::INFINITY);
    }

    #[test]
    fn to_f32_neg_infinity() {
        assert_eq!(f16_to_f32(0xfc00), f32::NEG_INFINITY);
    }

    #[test]
    fn to_f32_nan() {
        assert!(f16_to_f32(0x7e00).is_nan());
    }

    #[test]
    fn to_f32_nan_preserves_payload() {
        let bits = f16_to_f32(0x7c01).to_bits();
        // f16 sig bit 0 → f32 sig bit shifted left by 13
        assert_eq!(bits, 0x7f80_2000);
    }

    #[test]
    fn to_f32_agrees_with_f16_to_f64() {
        // Every non-NaN f16 → f32 must equal f16 → f64 cast to f32
        // The loop covers all non-negative bit patterns and mirrors each one
        // into its negative counterpart by setting the sign bit.
        for bits in 0..=0x7fff_u16 {
            if f16_is_nan(bits) {
                continue;
            }
            let via_f32 = f16_to_f32(bits);
            let via_f64 = f16_to_f64(bits) as f32;
            assert_eq!(via_f32.to_bits(), via_f64.to_bits(), "mismatch for bits 0x{bits:04x}");

            let neg = bits | 0x8000;
            let via_f32n = f16_to_f32(neg);
            let via_f64n = f16_to_f64(neg) as f32;
            assert_eq!(via_f32n.to_bits(), via_f64n.to_bits(), "mismatch for bits 0x{neg:04x}");
        }
    }

    // =====================================================================
    // f64 → f16 conversion (round-to-nearest-even)
    // =====================================================================

    #[test]
    fn from_f64_zero() {
        assert_eq!(f64_to_f16(0.0), 0x0000);
    }

    #[test]
    fn from_f64_neg_zero() {
        assert_eq!(f64_to_f16(-0.0), 0x8000);
    }

    #[test]
    fn from_f64_one() {
        assert_eq!(f64_to_f16(1.0), 0x3c00);
    }

    #[test]
    fn from_f64_neg_one() {
        assert_eq!(f64_to_f16(-1.0), 0xbc00);
    }

    #[test]
    fn from_f64_max_normal() {
        assert_eq!(f64_to_f16(65504.0), 0x7bff);
    }

    #[test]
    fn from_f64_overflow_to_infinity() {
        assert_eq!(f64_to_f16(65520.0), 0x7c00);
    }

    #[test]
    fn from_f64_infinity() {
        assert_eq!(f64_to_f16(f64::INFINITY), 0x7c00);
    }

    #[test]
    fn from_f64_neg_infinity() {
        assert_eq!(f64_to_f16(f64::NEG_INFINITY), 0xfc00);
    }

    #[test]
    fn from_f64_nan() {
        // Only NaN-ness is checked here, not the exact payload.
        assert!(f16_is_nan(f64_to_f16(f64::NAN)));
    }

    #[test]
    fn from_f64_min_positive_subnormal() {
        assert_eq!(f64_to_f16(5.960464477539063e-8), 0x0001);
    }

    #[test]
    fn from_f64_min_positive_normal() {
        assert_eq!(f64_to_f16(0.00006103515625), 0x0400);
    }

    // =====================================================================
    // Round-to-nearest-even: critical boundary tests
    // =====================================================================

    #[test]
    fn rounding_exactly_halfway_rounds_to_even_down() {
        // Only bit 41 of the f64 significand is set: exactly halfway between
        // f16 0x3c00 and 0x3c01, so the tie goes to the even 0x3c00.
        let halfway = f64::from_bits(0x3FF0_0200_0000_0000);
        assert_eq!(f64_to_f16(halfway), 0x3c00);
    }

    #[test]
    fn rounding_exactly_halfway_rounds_to_even_up() {
        // Bits 42 and 41 set: exactly halfway between f16 0x3c01 and 0x3c02,
        // so the tie goes to the even 0x3c02.
        let halfway = f64::from_bits(0x3FF0_0600_0000_0000);
        assert_eq!(f64_to_f16(halfway), 0x3c02);
    }

    #[test]
    fn rounding_just_below_halfway_rounds_down() {
        let below = f64::from_bits(0x3FF0_01FF_FFFF_FFFF);
        assert_eq!(f64_to_f16(below), 0x3c00);
    }

    #[test]
    fn rounding_just_above_halfway_rounds_up() {
        let above = f64::from_bits(0x3FF0_0200_0000_0001);
        assert_eq!(f64_to_f16(above), 0x3c01);
    }

    #[test]
    fn rounding_subnormal_halfway_rounds_to_even() {
        // 1.5 × the smallest subnormal is a tie between 0x0001 and 0x0002.
        let val = 1.5 * 5.960464477539063e-8;
        assert_eq!(f64_to_f16(val), 0x0002);
    }

    #[test]
    fn rounding_subnormal_halfway_even_down() {
        // 2.5 × the smallest subnormal is a tie between 0x0002 and 0x0003.
        let val = 2.5 * 5.960464477539063e-8;
        assert_eq!(f64_to_f16(val), 0x0002);
    }

    #[test]
    fn rounding_normal_to_subnormal_boundary() {
        let min_normal = 0.00006103515625_f64;
        assert_eq!(f64_to_f16(min_normal), 0x0400);

        // One f64 ulp below the smallest normal still rounds up to it.
        let below = f64::from_bits(min_normal.to_bits() - 1);
        assert_eq!(f64_to_f16(below), 0x0400);
    }

    #[test]
    fn rounding_overflow_at_max() {
        // 65520 is the tie point above the largest finite f16 (65504);
        // anything below it stays finite, the tie itself becomes infinity.
        assert_eq!(f64_to_f16(65504.0), 0x7bff);
        assert_eq!(f64_to_f16(65519.99), 0x7bff);
        assert_eq!(f64_to_f16(65520.0), 0x7c00);
    }

    #[test]
    fn rounding_tiny_to_zero() {
        // The sign must survive even when the magnitude underflows to zero.
        assert_eq!(f64_to_f16(1e-30), 0x0000);
        assert_eq!(f64_to_f16(-1e-30), 0x8000);
    }

    #[test]
    fn rounding_tiny_to_min_subnormal() {
        // Exactly half the smallest subnormal is a tie that goes to zero.
        let half_min: f64 = 0.5 * 5.960464477539063e-8;
        assert_eq!(f64_to_f16(half_min), 0x0000);

        // One f64 ulp above the tie rounds up to the smallest subnormal.
        let above = f64::from_bits(half_min.to_bits() + 1);
        assert_eq!(f64_to_f16(above), 0x0001);
    }

    // =====================================================================
    // Roundtrip: f64 → f16 → f64
    // =====================================================================

    #[test]
    fn roundtrip_all_exact_f16_values() {
        // Exhaustive over every non-NaN f16 bit pattern, both signs.
        for bits in 0..=0x7fff_u16 {
            if f16_is_nan(bits) {
                continue;
            }
            let f = f16_to_f64(bits);
            let h2 = f64_to_f16(f);
            assert_eq!(bits, h2, "roundtrip failed for bits 0x{bits:04x}");

            // Also check negative
            let neg_bits = bits | 0x8000;
            let fn_ = f16_to_f64(neg_bits);
            let hn2 = f64_to_f16(fn_);
            assert_eq!(neg_bits, hn2, "roundtrip failed for bits 0x{neg_bits:04x}");
        }
    }
}
613}