1use crate::{ArgLength, DataType, Error, Result};
2
/// Widens an IEEE 754 binary16 bit pattern to the `f64` with the same value.
///
/// The conversion is purely bit-level: the sign bit is moved to bit 63, the
/// exponent is re-biased and the 10 significand bits are placed in the top of
/// the 52-bit `f64` significand. Signed zeros and infinities map to their
/// `f64` counterparts, subnormal inputs are renormalised, and a NaN keeps its
/// sign and its 10 payload bits.
const fn f16_to_f64(bits: u16) -> f64 {
    // Difference between the binary64 exponent bias and the binary16 one.
    const BIAS_DELTA: u64 = 1023 - 15;
    // Distance between the low end of the two significand fields.
    const SIG_SHIFT: u32 = 52 - 10;

    let half = bits as u64;
    let sign_field = ((half >> 15) & 1) << 63;
    let exponent = (half >> 10) & 0x1f;
    let fraction = half & 0x03ff;

    let magnitude = match (exponent, fraction) {
        // Zero: only the sign bit survives.
        (0, 0) => 0,
        // Subnormal: shift the leading one out of the 10-bit field so it
        // becomes the implicit bit, lowering the exponent by the same amount.
        (0, _) => {
            let lead = fraction.leading_zeros() - (64 - 10);
            let normalised = (fraction << (lead + 1)) & 0x03ff;
            ((BIAS_DELTA - lead as u64) << 52) | (normalised << SIG_SHIFT)
        }
        // Infinity (fraction == 0) or NaN (fraction carried over as payload).
        (0x1f, _) => 0x7ff0_0000_0000_0000 | (fraction << SIG_SHIFT),
        // Normal number: re-bias the exponent, widen the significand.
        _ => ((exponent + BIAS_DELTA) << 52) | (fraction << SIG_SHIFT),
    };

    f64::from_bits(sign_field | magnitude)
}
29
/// Widens an IEEE 754 binary16 bit pattern to the `f32` with the same value.
///
/// Works on the bit fields directly: the sign bit is moved to bit 31, the
/// exponent is re-biased and the 10 significand bits are placed in the top of
/// the 23-bit `f32` significand. Subnormal inputs are renormalised, and a NaN
/// keeps its sign and its 10 payload bits.
const fn f16_to_f32(bits: u16) -> f32 {
    // Difference between the binary32 exponent bias and the binary16 one.
    const BIAS_DELTA: u32 = 127 - 15;
    // Distance between the low end of the two significand fields.
    const SIG_SHIFT: u32 = 23 - 10;

    let half = bits as u32;
    let sign_field = ((half >> 15) & 1) << 31;
    let exponent = (half >> 10) & 0x1f;
    let fraction = half & 0x03ff;

    let magnitude = match (exponent, fraction) {
        // Zero: only the sign bit survives.
        (0, 0) => 0,
        // Subnormal: move the leading one out of the 10-bit field so it
        // becomes the implicit bit, lowering the exponent by the same amount.
        (0, _) => {
            let lead = fraction.leading_zeros() - (32 - 10);
            let normalised = (fraction << (lead + 1)) & 0x03ff;
            ((BIAS_DELTA - lead) << 23) | (normalised << SIG_SHIFT)
        }
        // Infinity (fraction == 0) or NaN (fraction carried over as payload).
        (0x1f, _) => 0x7f80_0000 | (fraction << SIG_SHIFT),
        // Normal number: re-bias the exponent, widen the significand.
        _ => ((exponent + BIAS_DELTA) << 23) | (fraction << SIG_SHIFT),
    };

    f32::from_bits(sign_field | magnitude)
}
54
/// Narrows an `f64` to an IEEE 754 binary16 bit pattern, rounding to nearest
/// with ties going to the even significand.
///
/// Special cases:
/// * Zero and `f64` subnormals (exponent field 0) become a signed zero.
/// * Infinity stays infinity with the same sign.
/// * NaN keeps its sign and the top 10 bits of its significand; if those bits
///   are all zero the payload is forced to 1 so the result is still a NaN
///   rather than an infinity.
/// * Finite values too large for binary16 become a signed infinity, and
///   values below half of the smallest binary16 subnormal become a signed
///   zero.
const fn f64_to_f16(value: f64) -> u16 {
    let bits = value.to_bits();
    let sign_bit = ((bits >> 48) & 0x8000) as u16;
    let exp = ((bits >> 52) & 0x7ff) as i32;
    let sig = bits & 0x000f_ffff_ffff_ffff;

    // Zero or an f64 subnormal: both are far below the binary16 range.
    if exp == 0 {
        return sign_bit;
    }

    // Infinity or NaN.
    if exp == 0x7ff {
        if sig == 0 {
            return sign_bit | 0x7c00;
        }
        let sig16 = (sig >> 42) as u16;
        // Never let a NaN collapse into the infinity bit pattern.
        return sign_bit | 0x7c00 | if sig16 == 0 { 1 } else { sig16 };
    }

    // Re-bias the exponent: 1008 = 1023 (binary64 bias) - 15 (binary16 bias).
    let exp16 = exp - 1008;

    // Too large for any finite binary16 value.
    if exp16 >= 0x1f {
        return sign_bit | 0x7c00;
    }

    // Result is a binary16 subnormal (or rounds up to the smallest normal).
    if exp16 <= 0 {
        // Significand with the implicit leading one made explicit (bit 52).
        let full_sig = sig | 0x0010_0000_0000_0000;
        let shift = (1 - exp16) as u64 + 42;

        // `full_sig` only has 53 bits, so at a shift of 64 or more the value
        // is far below half of the smallest subnormal and rounds to zero.
        // Returning early also keeps the shifts below within range.
        if shift >= 64 {
            return sign_bit;
        }

        let shifted = full_sig >> shift;
        let remainder = full_sig & ((1_u64 << shift) - 1);
        let halfway = 1_u64 << (shift - 1);
        let round_up = remainder > halfway || (remainder == halfway && (shifted & 1) != 0);
        // A carry out of the 10-bit field yields 0x0400, which is exactly the
        // encoding of the smallest normal number.
        let sig16 = (shifted as u16) + round_up as u16;
        return sign_bit | sig16;
    }

    // Normal number: keep the top 10 significand bits, round on the other 42.
    let sig10 = (sig >> 42) as u16;
    let remainder = sig & 0x3ff_ffff_ffff;
    let halfway = 0x200_0000_0000_u64;
    let round_up = remainder > halfway || (remainder == halfway && (sig10 & 1) != 0);
    let sig16 = sig10 + round_up as u16;

    if sig16 >= 0x0400 {
        // The significand overflowed: bump the exponent and clear the
        // significand. From the largest exponent this produces infinity.
        sign_bit | (((exp16 as u16) + 1) << 10)
    } else {
        sign_bit | ((exp16 as u16) << 10) | sig16
    }
}
115
/// Builds an `f64` NaN from the bit pattern of an `f32` NaN by moving the bits
/// directly.
///
/// The sign bit is carried over to bit 63 and the 23 significand bits are
/// placed in the top 23 bits of the `f64` significand. The exponent field of
/// `bits` is not inspected; the result always has an all-ones `f64` exponent.
const fn f32_nan_to_f64(bits: u32) -> f64 {
    let wide = bits as u64;
    // Bit 31 of the f32 pattern becomes bit 63 of the f64 pattern.
    let sign = (wide >> 31) << 63;
    // 52 - 23 = 29: align the top of the two significand fields.
    let payload = (wide & 0x007f_ffff) << 29;
    f64::from_bits(0x7ff0_0000_0000_0000 | sign | payload)
}
122
/// Raw storage for a floating-point value at one of three IEEE 754 widths.
///
/// Each variant holds the bit pattern as an unsigned integer rather than as a
/// float, which is what makes the derived `Eq`, `Ord` and `Hash` possible.
/// The derived ordering compares the variant first (in declaration order) and
/// then the raw bits, so it is not a numeric ordering of the values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) enum Inner {
    /// Bit pattern of a 16-bit (half precision) float.
    F16(u16),
    /// Bit pattern of a 32-bit (single precision) float.
    F32(u32),
    /// Bit pattern of a 64-bit (double precision) float.
    F64(u64),
}
130
131impl Inner {
132 const fn new(x: f64) -> Self {
133 if x.is_finite() {
134 let bits16 = f64_to_f16(x);
135
136 if f16_to_f64(bits16).to_bits() == x.to_bits() {
137 Inner::F16(bits16)
138 } else if ((x as f32) as f64).to_bits() == x.to_bits() {
139 Inner::F32((x as f32).to_bits())
140 } else {
141 Inner::F64(x.to_bits())
142 }
143 } else {
144 let bits64 = x.to_bits();
145 let sign_bit = bits64 & 0x8000_0000_0000_0000;
146
147 if (bits64 & 0x3ff_ffff_ffff) == 0 {
148 let bits = (bits64 >> 42) & 0x7fff | (sign_bit >> 48);
149 Self::F16(bits as u16)
150 } else if (bits64 & 0x1fff_ffff) == 0 {
151 let bits = (bits64 >> 29) & 0x7fff_ffff | (sign_bit >> 32);
152 Self::F32(bits as u32)
153 } else {
154 Self::F64(bits64)
155 }
156 }
157 }
158}
159
/// A floating-point value held as a 16-, 32- or 64-bit IEEE 754 bit pattern.
///
/// Equality, ordering and hashing are derived from the stored width and raw
/// bits, not from the numeric value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Float(pub(crate) Inner);
167
168impl Float {
169 #[must_use]
171 pub const fn data_type(&self) -> DataType {
172 match self.0 {
173 Inner::F16(_) => DataType::Float16,
174 Inner::F32(_) => DataType::Float32,
175 Inner::F64(_) => DataType::Float64,
176 }
177 }
178
179 pub(crate) const fn cbor_argument(&self) -> (u8, u64) {
180 match self.0 {
181 Inner::F16(arg) => (ArgLength::U16, arg as u64),
182 Inner::F32(arg) => (ArgLength::U32, arg as u64),
183 Inner::F64(arg) => (ArgLength::U64, arg),
184 }
185 }
186
187 pub(crate) const fn from_u16(bits: u16) -> Self {
188 Self(Inner::F16(bits))
189 }
190
191 pub(crate) const fn from_u32(bits: u32) -> Result<Self> {
192 let float = Self(Inner::F32(bits));
193 if matches!(Inner::new(float.to_f64()), Inner::F32(_)) {
194 Ok(float)
195 } else {
196 Err(Error::NonDeterministic)
197 }
198 }
199
200 pub(crate) const fn from_u64(bits: u64) -> Result<Self> {
201 let float = Self(Inner::F64(bits));
202 if matches!(Inner::new(float.to_f64()), Inner::F64(_)) {
203 Ok(float)
204 } else {
205 Err(Error::NonDeterministic)
206 }
207 }
208
209 #[must_use]
211 pub const fn to_f64(self) -> f64 {
212 match self.0 {
213 Inner::F16(bits) => f16_to_f64(bits),
214 Inner::F32(bits) => {
215 let f = f32::from_bits(bits);
216 if f.is_nan() { f32_nan_to_f64(bits) } else { f as f64 }
217 }
218 Inner::F64(bits) => f64::from_bits(bits),
219 }
220 }
221
222 pub const fn to_f32(self) -> Result<f32> {
226 match self.0 {
227 Inner::F16(bits) => Ok(f16_to_f32(bits)),
228 Inner::F32(bits) => Ok(f32::from_bits(bits)),
229 Inner::F64(_) => Err(Error::Precision),
230 }
231 }
232}
233
234impl From<f64> for Float {
237 fn from(value: f64) -> Self {
238 Self(Inner::new(value))
239 }
240}
241
242impl From<f32> for Float {
243 fn from(value: f32) -> Self {
244 if value.is_nan() {
245 Self(Inner::new(f32_nan_to_f64(value.to_bits())))
247 } else {
248 Self(Inner::new(value as f64))
249 }
250 }
251}
252
253impl From<u8> for Float {
256 fn from(value: u8) -> Self {
257 Self::from(value as f64)
258 }
259}
260
261impl From<u16> for Float {
262 fn from(value: u16) -> Self {
263 Self::from(value as f64)
264 }
265}
266
267impl From<u32> for Float {
268 fn from(value: u32) -> Self {
269 Self::from(value as f64)
270 }
271}
272
273impl From<i8> for Float {
274 fn from(value: i8) -> Self {
275 Self::from(value as f64)
276 }
277}
278
279impl From<i16> for Float {
280 fn from(value: i16) -> Self {
281 Self::from(value as f64)
282 }
283}
284
285impl From<i32> for Float {
286 fn from(value: i32) -> Self {
287 Self::from(value as f64)
288 }
289}
290
291impl From<bool> for Float {
292 fn from(value: bool) -> Self {
293 Self(if value { Inner::new(1.0) } else { Inner::new(0.0) })
294 }
295}
296
// Unit tests for the bit-level conversion helpers: `f16_to_f64`,
// `f16_to_f32` and `f64_to_f16`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `bits` is a 16-bit NaN pattern: with the sign bit
    /// masked off, anything above the infinity pattern `0x7c00` has an
    /// all-ones exponent and a non-zero significand.
    fn f16_is_nan(bits: u16) -> bool {
        (bits & 0x7fff) > 0x7c00
    }

    // ---------------------------------------------------------------------
    // f16_to_f64: widening of selected bit patterns
    // ---------------------------------------------------------------------

    #[test]
    fn to_f64_zero() {
        assert_eq!(f16_to_f64(0x0000), 0.0);
        // `0.0 == -0.0`, so the sign has to be checked separately.
        assert!(f16_to_f64(0x0000).is_sign_positive());
    }

    #[test]
    fn to_f64_neg_zero() {
        let v = f16_to_f64(0x8000);
        // Compare bit patterns so that -0.0 is told apart from +0.0.
        assert_eq!(v.to_bits(), (-0.0_f64).to_bits());
    }

    #[test]
    fn to_f64_one() {
        assert_eq!(f16_to_f64(0x3c00), 1.0);
    }

    #[test]
    fn to_f64_neg_one() {
        assert_eq!(f16_to_f64(0xbc00), -1.0);
    }

    #[test]
    fn to_f64_max_normal() {
        assert_eq!(f16_to_f64(0x7bff), 65504.0);
    }

    #[test]
    fn to_f64_min_positive_normal() {
        // 2^-14
        assert_eq!(f16_to_f64(0x0400), 0.00006103515625);
    }

    #[test]
    fn to_f64_min_positive_subnormal() {
        // 2^-24
        assert_eq!(f16_to_f64(0x0001), 5.960464477539063e-8);
    }

    #[test]
    fn to_f64_max_subnormal() {
        // 1023 * 2^-24
        assert_eq!(f16_to_f64(0x03ff), 0.00006097555160522461);
    }

    #[test]
    fn to_f64_infinity() {
        assert_eq!(f16_to_f64(0x7c00), f64::INFINITY);
    }

    #[test]
    fn to_f64_neg_infinity() {
        assert_eq!(f16_to_f64(0xfc00), f64::NEG_INFINITY);
    }

    #[test]
    fn to_f64_nan() {
        assert!(f16_to_f64(0x7e00).is_nan());
    }

    #[test]
    fn to_f64_nan_preserves_payload() {
        let bits = f16_to_f64(0x7c01).to_bits();
        // Payload bit 0 of the 10-bit field lands at bit 42 of the f64.
        assert_eq!(bits, 0x7ff0_0400_0000_0000);
    }

    #[test]
    fn to_f64_two() {
        assert_eq!(f16_to_f64(0x4000), 2.0);
    }

    #[test]
    fn to_f64_one_point_five() {
        assert_eq!(f16_to_f64(0x3e00), 1.5);
    }

    // ---------------------------------------------------------------------
    // f16_to_f32: same cases as above, widened to 32 bits
    // ---------------------------------------------------------------------

    #[test]
    fn to_f32_zero() {
        assert_eq!(f16_to_f32(0x0000), 0.0_f32);
        // `0.0 == -0.0`, so the sign has to be checked separately.
        assert!(f16_to_f32(0x0000).is_sign_positive());
    }

    #[test]
    fn to_f32_neg_zero() {
        assert_eq!(f16_to_f32(0x8000).to_bits(), (-0.0_f32).to_bits());
    }

    #[test]
    fn to_f32_one() {
        assert_eq!(f16_to_f32(0x3c00), 1.0_f32);
    }

    #[test]
    fn to_f32_neg_one() {
        assert_eq!(f16_to_f32(0xbc00), -1.0_f32);
    }

    #[test]
    fn to_f32_two() {
        assert_eq!(f16_to_f32(0x4000), 2.0_f32);
    }

    #[test]
    fn to_f32_one_point_five() {
        assert_eq!(f16_to_f32(0x3e00), 1.5_f32);
    }

    #[test]
    fn to_f32_max_normal() {
        assert_eq!(f16_to_f32(0x7bff), 65504.0_f32);
    }

    #[test]
    fn to_f32_min_positive_normal() {
        // 2^-14; the literal is the shortest decimal that parses to it.
        assert_eq!(f16_to_f32(0x0400), 0.000061035156_f32);
    }

    #[test]
    fn to_f32_min_positive_subnormal() {
        // 2^-24
        assert_eq!(f16_to_f32(0x0001), 5.9604645e-8_f32);
    }

    #[test]
    fn to_f32_max_subnormal() {
        // 1023 * 2^-24
        assert_eq!(f16_to_f32(0x03ff), 0.00006097555_f32);
    }

    #[test]
    fn to_f32_infinity() {
        assert_eq!(f16_to_f32(0x7c00), f32::INFINITY);
    }

    #[test]
    fn to_f32_neg_infinity() {
        assert_eq!(f16_to_f32(0xfc00), f32::NEG_INFINITY);
    }

    #[test]
    fn to_f32_nan() {
        assert!(f16_to_f32(0x7e00).is_nan());
    }

    #[test]
    fn to_f32_nan_preserves_payload() {
        let bits = f16_to_f32(0x7c01).to_bits();
        // Payload bit 0 of the 10-bit field lands at bit 13 of the f32.
        assert_eq!(bits, 0x7f80_2000);
    }

    /// Exhaustively checks every non-NaN pattern of both signs: widening
    /// straight to `f32` must match widening to `f64` and casting down.
    #[test]
    fn to_f32_agrees_with_f16_to_f64() {
        for bits in 0..=0x7fff_u16 {
            // NaNs are skipped; only non-NaN values are compared here.
            if f16_is_nan(bits) {
                continue;
            }
            let via_f32 = f16_to_f32(bits);
            let via_f64 = f16_to_f64(bits) as f32;
            assert_eq!(via_f32.to_bits(), via_f64.to_bits(), "mismatch for bits 0x{bits:04x}");

            // Same check for the negative counterpart.
            let neg = bits | 0x8000;
            let via_f32n = f16_to_f32(neg);
            let via_f64n = f16_to_f64(neg) as f32;
            assert_eq!(via_f32n.to_bits(), via_f64n.to_bits(), "mismatch for bits 0x{neg:04x}");
        }
    }

    // ---------------------------------------------------------------------
    // f64_to_f16: exactly representable inputs and special values
    // ---------------------------------------------------------------------

    #[test]
    fn from_f64_zero() {
        assert_eq!(f64_to_f16(0.0), 0x0000);
    }

    #[test]
    fn from_f64_neg_zero() {
        assert_eq!(f64_to_f16(-0.0), 0x8000);
    }

    #[test]
    fn from_f64_one() {
        assert_eq!(f64_to_f16(1.0), 0x3c00);
    }

    #[test]
    fn from_f64_neg_one() {
        assert_eq!(f64_to_f16(-1.0), 0xbc00);
    }

    #[test]
    fn from_f64_max_normal() {
        assert_eq!(f64_to_f16(65504.0), 0x7bff);
    }

    #[test]
    fn from_f64_overflow_to_infinity() {
        // 65520 is the tie between 65504 and the next (unrepresentable)
        // step; it rounds up and overflows to infinity.
        assert_eq!(f64_to_f16(65520.0), 0x7c00);
    }

    #[test]
    fn from_f64_infinity() {
        assert_eq!(f64_to_f16(f64::INFINITY), 0x7c00);
    }

    #[test]
    fn from_f64_neg_infinity() {
        assert_eq!(f64_to_f16(f64::NEG_INFINITY), 0xfc00);
    }

    #[test]
    fn from_f64_nan() {
        assert!(f16_is_nan(f64_to_f16(f64::NAN)));
    }

    #[test]
    fn from_f64_min_positive_subnormal() {
        assert_eq!(f64_to_f16(5.960464477539063e-8), 0x0001);
    }

    #[test]
    fn from_f64_min_positive_normal() {
        assert_eq!(f64_to_f16(0.00006103515625), 0x0400);
    }

    // ---------------------------------------------------------------------
    // f64_to_f16: rounding (nearest, ties to even)
    // ---------------------------------------------------------------------

    #[test]
    fn rounding_exactly_halfway_rounds_to_even_down() {
        // Tie between 0x3c00 (even significand) and 0x3c01: stays at 0x3c00.
        let halfway = f64::from_bits(0x3FF0_0200_0000_0000);
        assert_eq!(f64_to_f16(halfway), 0x3c00);
    }

    #[test]
    fn rounding_exactly_halfway_rounds_to_even_up() {
        // Tie between 0x3c01 (odd significand) and 0x3c02: goes up to 0x3c02.
        let halfway = f64::from_bits(0x3FF0_0600_0000_0000);
        assert_eq!(f64_to_f16(halfway), 0x3c02);
    }

    #[test]
    fn rounding_just_below_halfway_rounds_down() {
        let below = f64::from_bits(0x3FF0_01FF_FFFF_FFFF);
        assert_eq!(f64_to_f16(below), 0x3c00);
    }

    #[test]
    fn rounding_just_above_halfway_rounds_up() {
        let above = f64::from_bits(0x3FF0_0200_0000_0001);
        assert_eq!(f64_to_f16(above), 0x3c01);
    }

    #[test]
    fn rounding_subnormal_halfway_rounds_to_even() {
        // 1.5 subnormal steps: tie between 0x0001 and 0x0002, even wins.
        let val = 1.5 * 5.960464477539063e-8;
        assert_eq!(f64_to_f16(val), 0x0002);
    }

    #[test]
    fn rounding_subnormal_halfway_even_down() {
        // 2.5 subnormal steps: tie between 0x0002 and 0x0003, even wins.
        let val = 2.5 * 5.960464477539063e-8;
        assert_eq!(f64_to_f16(val), 0x0002);
    }

    #[test]
    fn rounding_normal_to_subnormal_boundary() {
        let min_normal = 0.00006103515625_f64;
        assert_eq!(f64_to_f16(min_normal), 0x0400);

        // One f64 ULP below the smallest normal falls in the subnormal range
        // but rounds back up to the smallest normal.
        let below = f64::from_bits(min_normal.to_bits() - 1);
        assert_eq!(f64_to_f16(below), 0x0400);
    }

    #[test]
    fn rounding_overflow_at_max() {
        assert_eq!(f64_to_f16(65504.0), 0x7bff);
        // Just below the tie: still rounds down to the largest finite value.
        assert_eq!(f64_to_f16(65519.99), 0x7bff);
        // At the tie: rounds up, which overflows to infinity.
        assert_eq!(f64_to_f16(65520.0), 0x7c00);
    }

    #[test]
    fn rounding_tiny_to_zero() {
        // Far below the smallest subnormal; the sign is kept on the zero.
        assert_eq!(f64_to_f16(1e-30), 0x0000);
        assert_eq!(f64_to_f16(-1e-30), 0x8000);
    }

    #[test]
    fn rounding_tiny_to_min_subnormal() {
        // Exactly half of the smallest subnormal: a tie between zero and
        // 0x0001, which goes to the even side (zero).
        let half_min: f64 = 0.5 * 5.960464477539063e-8;
        assert_eq!(f64_to_f16(half_min), 0x0000);

        // One f64 ULP above the tie rounds up to the smallest subnormal.
        let above = f64::from_bits(half_min.to_bits() + 1);
        assert_eq!(f64_to_f16(above), 0x0001);
    }

    // ---------------------------------------------------------------------
    // Round trip: f16 -> f64 -> f16
    // ---------------------------------------------------------------------

    /// Exhaustively checks that every non-NaN 16-bit pattern, of either
    /// sign, survives widening to `f64` and narrowing back.
    #[test]
    fn roundtrip_all_exact_f16_values() {
        for bits in 0..=0x7fff_u16 {
            if f16_is_nan(bits) {
                continue;
            }
            let f = f16_to_f64(bits);
            let h2 = f64_to_f16(f);
            assert_eq!(bits, h2, "roundtrip failed for bits 0x{bits:04x}");

            // Same check for the negative counterpart.
            let neg_bits = bits | 0x8000;
            let fn_ = f16_to_f64(neg_bits);
            let hn2 = f64_to_f16(fn_);
            assert_eq!(neg_bits, hn2, "roundtrip failed for bits 0x{neg_bits:04x}");
        }
    }
}