baracuda_types/
numeric.rs1use core::cmp::Ordering;
14use core::fmt;
15
/// IEEE 754 binary16 ("half precision") float stored as its raw bit pattern.
///
/// Bit layout, most significant first: 1 sign bit, 5 exponent bits (bias 15),
/// 10 mantissa bits.
///
/// The derived `PartialEq`/`Eq`/`Hash` operate on the raw `u16`, so they
/// compare bit patterns: `ZERO != NEG_ZERO`, and a NaN is equal to any NaN
/// with the same bits.
#[derive(Copy, Clone, Default, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct Half(pub u16);

impl Half {
    /// Positive zero.
    pub const ZERO: Self = Self(0x0000);
    /// Negative zero (only the sign bit set).
    pub const NEG_ZERO: Self = Self(0x8000);
    /// The value `1.0`.
    pub const ONE: Self = Self(0x3C00);
    /// The value `-1.0`.
    pub const NEG_ONE: Self = Self(0xBC00);
    /// Positive infinity (exponent all ones, mantissa zero).
    pub const INFINITY: Self = Self(0x7C00);
    /// Negative infinity.
    pub const NEG_INFINITY: Self = Self(0xFC00);
    /// A quiet NaN with positive sign and empty payload.
    pub const NAN: Self = Self(0x7E00);
    /// Smallest positive normal value, 2^-14.
    pub const MIN_POSITIVE: Self = Self(0x0400);
    /// Largest finite value, 65504.
    pub const MAX: Self = Self(0x7BFF);
    /// Most negative finite value, -65504.
    pub const MIN: Self = Self(0xFBFF);
    /// Gap between `1.0` and the next larger representable value, 2^-10.
    pub const EPSILON: Self = Self(0x1400);

    /// Wraps a raw binary16 bit pattern without any validation.
    #[inline]
    pub const fn from_bits(bits: u16) -> Self {
        Self(bits)
    }

    /// Returns the raw binary16 bit pattern.
    #[inline]
    pub const fn to_bits(self) -> u16 {
        self.0
    }

    /// Returns `true` if the value is any NaN (exponent all ones, mantissa non-zero).
    #[inline]
    pub const fn is_nan(self) -> bool {
        (self.0 & 0x7FFF) > 0x7C00
    }

    /// Returns `true` if the value is positive or negative infinity.
    #[inline]
    pub const fn is_infinite(self) -> bool {
        (self.0 & 0x7FFF) == 0x7C00
    }

    /// Returns `true` if the value is neither infinite nor NaN.
    #[inline]
    pub const fn is_finite(self) -> bool {
        (self.0 & 0x7C00) != 0x7C00
    }

    /// Returns `true` if the sign bit is set (this includes `-0.0` and
    /// NaNs whose sign bit is set).
    #[inline]
    pub const fn is_sign_negative(self) -> bool {
        (self.0 & 0x8000) != 0
    }

    /// Shifts `value` right by `shift` bits, rounding to nearest with ties
    /// going to the even result.
    ///
    /// `shift` must be in `1..=31`; `from_f32` only passes 13..=24, for which
    /// the result is at most `0x400` and therefore fits in a `u16`.
    #[inline]
    fn round_shift_nearest_even(value: u32, shift: u32) -> u16 {
        let kept = value >> shift;
        let half_ulp = 1u32 << (shift - 1);
        // The discarded low bits: the guard bit plus everything below it.
        let remainder = value & ((half_ulp << 1) - 1);
        let round_up = remainder > half_ulp || (remainder == half_ulp && (kept & 1) == 1);
        (kept + u32::from(round_up)) as u16
    }

    /// Converts an `f32` to the nearest binary16 value (round to nearest,
    /// ties to even).
    ///
    /// * NaN maps to a quiet NaN that keeps the sign and the top mantissa bits.
    /// * Magnitudes too large for binary16 (including values that round up
    ///   past `MAX`) become a signed infinity.
    /// * Magnitudes in the subnormal range are rounded to the nearest
    ///   subnormal; magnitudes of at most 2^-25 become a signed zero.
    pub fn from_f32(f: f32) -> Self {
        let bits = f.to_bits();
        let sign = ((bits >> 16) & 0x8000) as u16;
        let exp_raw = ((bits >> 23) & 0xFF) as i32;
        let mant = bits & 0x007F_FFFF;

        // Infinity or NaN in the source.
        if exp_raw == 0xFF {
            return if mant != 0 {
                // Force the quiet bit so the result stays a NaN even when the
                // surviving payload bits are all zero.
                Self(sign | 0x7E00 | ((mant >> 13) as u16))
            } else {
                Self(sign | 0x7C00)
            };
        }

        let e_unbiased = exp_raw - 127;
        let e_half = e_unbiased + 15;

        // Exponent does not fit in five bits: overflow to infinity.
        if e_half >= 0x1F {
            return Self(sign | 0x7C00);
        }

        // Normal binary16 result: drop the low 13 mantissa bits with rounding.
        if e_half >= 1 {
            let rounded = Self::round_shift_nearest_even(mant, 13);
            // `rounded` may be 0x400; the addition then carries into the
            // exponent field, which is exactly the desired result (up to and
            // including rounding MAX up to infinity).
            return Self(sign | (((e_half as u16) << 10) + rounded));
        }

        // Below 2^-25 nothing can round up to the smallest subnormal (2^-24).
        // Values in [2^-25, 2^-24) must still go through the rounding below:
        // exactly 2^-25 is a tie that rounds to zero, anything larger rounds
        // up to 0x0001. (f32 subnormals have e_unbiased == -127 and land here.)
        if e_unbiased < -25 {
            return Self(sign);
        }

        // Subnormal binary16 result: restore the implicit leading one and
        // shift it into place. `shift` is in 14..=24 here.
        let mant_full = mant | 0x0080_0000;
        let shift = (-14 - e_unbiased) as u32 + 13;
        Self(sign | Self::round_shift_nearest_even(mant_full, shift))
    }

    /// Widens the value to `f32`.
    ///
    /// No rounding takes place: the mantissa is only shifted left and the
    /// exponent re-biased, and subnormals are renormalised.
    pub fn to_f32(self) -> f32 {
        let h = u32::from(self.0);
        let sign = (h & 0x8000) << 16;
        let exp = (h >> 10) & 0x1F;
        let mant = h & 0x03FF;

        let magnitude = match (exp, mant) {
            // Signed zero.
            (0, 0) => 0,
            // Subnormal: shift the mantissa left until its top set bit reaches
            // the implicit-one position (bit 10), lowering the exponent by one
            // for every position shifted.
            (0, _) => {
                let shift = mant.leading_zeros() - 21;
                let frac = (mant << shift) & 0x03FF;
                let exp_f32 = 127 - 15 + 1 - shift;
                (exp_f32 << 23) | (frac << 13)
            }
            // Infinity or NaN: exponent all ones, payload carried over.
            (0x1F, _) => 0x7F80_0000 | (mant << 13),
            // Normal number: re-bias the exponent from 15 to 127.
            _ => ((exp + 127 - 15) << 23) | (mant << 13),
        };

        f32::from_bits(sign | magnitude)
    }
}
155
156impl fmt::Debug for Half {
157 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158 write!(f, "Half({})", self.to_f32())
159 }
160}
161
162impl fmt::Display for Half {
163 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164 fmt::Display::fmt(&self.to_f32(), f)
165 }
166}
167
168impl PartialOrd for Half {
169 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
170 self.to_f32().partial_cmp(&other.to_f32())
171 }
172}
173
174impl From<Half> for f32 {
175 #[inline]
176 fn from(h: Half) -> f32 {
177 h.to_f32()
178 }
179}
180
181impl From<Half> for f64 {
182 #[inline]
183 fn from(h: Half) -> f64 {
184 h.to_f32() as f64
185 }
186}
187
188impl From<f32> for Half {
189 #[inline]
190 fn from(f: f32) -> Self {
191 Self::from_f32(f)
192 }
193}
194
/// bfloat16 value stored as its raw bit pattern.
///
/// The 16 bits are the upper half of an IEEE 754 binary32 value: 1 sign bit,
/// 8 exponent bits (bias 127) and 7 mantissa bits.
///
/// The derived `PartialEq`/`Eq`/`Hash` operate on the raw `u16`, so they
/// compare bit patterns: `ZERO != NEG_ZERO`, and a NaN is equal to any NaN
/// with the same bits.
#[derive(Copy, Clone, Default, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct BFloat16(pub u16);

impl BFloat16 {
    /// Positive zero.
    pub const ZERO: Self = Self(0x0000);
    /// Negative zero (only the sign bit set).
    pub const NEG_ZERO: Self = Self(0x8000);
    /// The value `1.0`.
    pub const ONE: Self = Self(0x3F80);
    /// The value `-1.0`.
    pub const NEG_ONE: Self = Self(0xBF80);
    /// Positive infinity (exponent all ones, mantissa zero).
    pub const INFINITY: Self = Self(0x7F80);
    /// Negative infinity.
    pub const NEG_INFINITY: Self = Self(0xFF80);
    /// A quiet NaN with positive sign and empty payload.
    pub const NAN: Self = Self(0x7FC0);
    /// Smallest positive normal value, 2^-126.
    pub const MIN_POSITIVE: Self = Self(0x0080);
    /// Largest finite value.
    pub const MAX: Self = Self(0x7F7F);
    /// Most negative finite value.
    pub const MIN: Self = Self(0xFF7F);
    /// Gap between `1.0` and the next larger representable value, 2^-7.
    pub const EPSILON: Self = Self(0x3C00);

    /// Wraps a raw bfloat16 bit pattern without any validation.
    #[inline]
    pub const fn from_bits(bits: u16) -> Self {
        Self(bits)
    }

    /// Returns the raw bfloat16 bit pattern.
    #[inline]
    pub const fn to_bits(self) -> u16 {
        self.0
    }

    /// Returns `true` if the value is any NaN (exponent all ones, mantissa non-zero).
    #[inline]
    pub const fn is_nan(self) -> bool {
        (self.0 & 0x7FFF) > 0x7F80
    }

    /// Returns `true` if the value is positive or negative infinity.
    #[inline]
    pub const fn is_infinite(self) -> bool {
        (self.0 & 0x7FFF) == 0x7F80
    }

    /// Returns `true` if the value is neither infinite nor NaN.
    #[inline]
    pub const fn is_finite(self) -> bool {
        (self.0 & 0x7F80) != 0x7F80
    }

    /// Returns `true` if the sign bit is set (this includes `-0.0` and
    /// NaNs whose sign bit is set).
    #[inline]
    pub const fn is_sign_negative(self) -> bool {
        (self.0 & 0x8000) != 0
    }

    /// Converts an `f32` to the nearest bfloat16 value (round to nearest,
    /// ties to even).
    ///
    /// Every NaN input maps to the canonical positive quiet NaN `0x7FC0`;
    /// sign and payload are not preserved. Finite values that round past
    /// `MAX` become a signed infinity.
    pub fn from_f32(f: f32) -> Self {
        // NaN is handled up front: the rounding add below could otherwise
        // carry out of a NaN payload and produce a non-NaN pattern.
        if f.is_nan() {
            return Self::NAN;
        }
        let bits = f.to_bits();
        // Adding 0x7FFF plus the lowest kept bit rounds the discarded low
        // half to nearest, with exact ties going to the even result.
        let tie_to_even = (bits >> 16) & 1;
        let rounded = bits.wrapping_add(0x7FFF + tie_to_even);
        Self((rounded >> 16) as u16)
    }

    /// Widens the value to `f32` by placing the 16 bits in the upper half of
    /// the `f32` bit pattern and zero-filling the lower half.
    #[inline]
    pub fn to_f32(self) -> f32 {
        f32::from_bits(u32::from(self.0) << 16)
    }
}
256
257impl fmt::Debug for BFloat16 {
258 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259 write!(f, "BFloat16({})", self.to_f32())
260 }
261}
262
263impl fmt::Display for BFloat16 {
264 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265 fmt::Display::fmt(&self.to_f32(), f)
266 }
267}
268
269impl PartialOrd for BFloat16 {
270 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
271 self.to_f32().partial_cmp(&other.to_f32())
272 }
273}
274
275impl From<BFloat16> for f32 {
276 #[inline]
277 fn from(b: BFloat16) -> f32 {
278 b.to_f32()
279 }
280}
281
282impl From<f32> for BFloat16 {
283 #[inline]
284 fn from(f: f32) -> Self {
285 Self::from_f32(f)
286 }
287}
288
/// Single-precision complex number.
///
/// `#[repr(C)]` with the real part first and the imaginary part second.
#[derive(Copy, Clone, Debug, Default, PartialEq)]
#[repr(C)]
pub struct Complex32 {
    /// Real component.
    pub re: f32,
    /// Imaginary component.
    pub im: f32,
}

impl Complex32 {
    /// The additive identity, `0 + 0i`.
    pub const ZERO: Self = Self::new(0.0, 0.0);
    /// The multiplicative identity, `1 + 0i`.
    pub const ONE: Self = Self::new(1.0, 0.0);
    /// The imaginary unit, `0 + 1i`.
    pub const I: Self = Self::new(0.0, 1.0);

    /// Builds a complex number from its real and imaginary components.
    #[inline]
    pub const fn new(re: f32, im: f32) -> Self {
        Self { re, im }
    }

    /// Squared magnitude, `re² + im²`.
    #[inline]
    pub fn norm_sqr(self) -> f32 {
        let Self { re, im } = self;
        re * re + im * im
    }

    /// Complex conjugate: same real part, negated imaginary part.
    #[inline]
    pub fn conj(self) -> Self {
        Self::new(self.re, -self.im)
    }
}
326
/// Double-precision complex number.
///
/// `#[repr(C)]` with the real part first and the imaginary part second.
#[derive(Copy, Clone, Debug, Default, PartialEq)]
#[repr(C)]
pub struct Complex64 {
    /// Real component.
    pub re: f64,
    /// Imaginary component.
    pub im: f64,
}

impl Complex64 {
    /// The additive identity, `0 + 0i`.
    pub const ZERO: Self = Self::new(0.0, 0.0);
    /// The multiplicative identity, `1 + 0i`.
    pub const ONE: Self = Self::new(1.0, 0.0);
    /// The imaginary unit, `0 + 1i`.
    pub const I: Self = Self::new(0.0, 1.0);

    /// Builds a complex number from its real and imaginary components.
    #[inline]
    pub const fn new(re: f64, im: f64) -> Self {
        Self { re, im }
    }

    /// Squared magnitude, `re² + im²`.
    #[inline]
    pub fn norm_sqr(self) -> f64 {
        let Self { re, im } = self;
        re * re + im * im
    }

    /// Complex conjugate: same real part, negated imaginary part.
    #[inline]
    pub fn conj(self) -> Self {
        Self::new(self.re, -self.im)
    }
}
362
/// Conversions between this module's 16-bit float wrappers and the `half`
/// crate's types. Each conversion passes the raw `u16` bit pattern through
/// unchanged. Compiled only with the `half-crate` feature.
#[cfg(feature = "half-crate")]
mod half_adapters {
    use super::{BFloat16, Half};

    impl From<half::f16> for Half {
        #[inline]
        fn from(v: half::f16) -> Self {
            let raw = v.to_bits();
            Half::from_bits(raw)
        }
    }

    impl From<Half> for half::f16 {
        #[inline]
        fn from(v: Half) -> Self {
            let raw = v.to_bits();
            Self::from_bits(raw)
        }
    }

    impl From<half::bf16> for BFloat16 {
        #[inline]
        fn from(v: half::bf16) -> Self {
            let raw = v.to_bits();
            BFloat16::from_bits(raw)
        }
    }

    impl From<BFloat16> for half::bf16 {
        #[inline]
        fn from(v: BFloat16) -> Self {
            let raw = v.to_bits();
            Self::from_bits(raw)
        }
    }
}
395
/// Conversions between this module's complex types and
/// `num_complex::Complex`. Each conversion copies the real and imaginary
/// components across unchanged. Compiled only with the `num-complex-crate`
/// feature.
#[cfg(feature = "num-complex-crate")]
mod num_complex_adapters {
    use super::{Complex32, Complex64};

    impl From<num_complex::Complex<f32>> for Complex32 {
        #[inline]
        fn from(v: num_complex::Complex<f32>) -> Self {
            Complex32::new(v.re, v.im)
        }
    }

    impl From<Complex32> for num_complex::Complex<f32> {
        #[inline]
        fn from(v: Complex32) -> Self {
            let Complex32 { re, im } = v;
            Self { re, im }
        }
    }

    impl From<num_complex::Complex<f64>> for Complex64 {
        #[inline]
        fn from(v: num_complex::Complex<f64>) -> Self {
            Complex64::new(v.re, v.im)
        }
    }

    impl From<Complex64> for num_complex::Complex<f64> {
        #[inline]
        fn from(v: Complex64) -> Self {
            let Complex64 { re, im } = v;
            Self { re, im }
        }
    }
}
428
/// Unit tests for the 16-bit float conversions and the complex type layouts.
#[cfg(test)]
mod tests {
    use super::*;

    /// The named `Half` constants decode to the expected `f32` values.
    #[test]
    fn half_constants_roundtrip() {
        let exact = [(Half::ZERO, 0.0f32), (Half::ONE, 1.0), (Half::NEG_ONE, -1.0)];
        for (constant, expected) in exact {
            assert_eq!(constant.to_f32(), expected);
        }
        for infinity in [Half::INFINITY, Half::NEG_INFINITY] {
            assert!(infinity.to_f32().is_infinite());
        }
        assert!(Half::NAN.to_f32().is_nan());
    }

    /// `f32 -> Half -> f32` stays within a small relative tolerance.
    #[test]
    fn half_roundtrip_exact_values() {
        let samples = [0.0f32, 1.0, -1.0, 0.5, -0.5, 2.0, 65504.0, -65504.0, 1e-4];
        for v in samples {
            let h = Half::from_f32(v);
            let back = h.to_f32();
            let error = (back - v).abs();
            let tolerance = v.abs() * 1e-3 + 1e-7;
            assert!(
                error < tolerance,
                "{v} -> {back} (half bits = {:#06x})",
                h.to_bits()
            );
        }
    }

    /// Magnitudes far beyond the half range saturate to a signed infinity.
    #[test]
    fn half_overflow_to_infinity() {
        let cases = [(1e30f32, Half::INFINITY), (-1e30, Half::NEG_INFINITY)];
        for (input, expected) in cases {
            assert_eq!(Half::from_f32(input).to_bits(), expected.to_bits());
        }
    }

    /// Magnitudes far below the half range flush to a signed zero.
    #[test]
    fn half_underflow_to_zero() {
        let cases = [(1e-30f32, 0u16), (-1e-30, 0x8000)];
        for (input, expected_bits) in cases {
            assert_eq!(Half::from_f32(input).to_bits(), expected_bits);
        }
    }

    /// NaN survives the narrowing conversion.
    #[test]
    fn half_nan_stays_nan() {
        let narrowed = Half::from_f32(f32::NAN);
        assert!(narrowed.is_nan());
    }

    /// The named `BFloat16` constants decode to the expected `f32` values.
    #[test]
    fn bfloat_constants_roundtrip() {
        let exact = [
            (BFloat16::ZERO, 0.0f32),
            (BFloat16::ONE, 1.0),
            (BFloat16::NEG_ONE, -1.0),
        ];
        for (constant, expected) in exact {
            assert_eq!(constant.to_f32(), expected);
        }
        assert!(BFloat16::INFINITY.to_f32().is_infinite());
        assert!(BFloat16::NAN.to_f32().is_nan());
    }

    /// A value with an empty low half converts to exactly its top 16 bits.
    #[test]
    fn bfloat_truncates_top_16_bits() {
        let b = BFloat16::from_f32(1.5f32);
        assert_eq!(b.to_bits(), 0x3FC0);
        assert_eq!(b.to_f32(), 1.5);
    }

    /// NaN survives the narrowing conversion.
    #[test]
    fn bfloat_nan_stays_nan() {
        let narrowed = BFloat16::from_f32(f32::NAN);
        assert!(narrowed.is_nan());
    }

    /// The complex types are exactly two floats wide and at least float-aligned.
    #[test]
    fn complex_layout_is_two_floats() {
        use core::mem::{align_of, size_of};
        assert_eq!(size_of::<Complex32>(), 8);
        assert_eq!(size_of::<Complex64>(), 16);
        assert!(align_of::<f32>() <= align_of::<Complex32>());
        assert!(align_of::<f64>() <= align_of::<Complex64>());
    }
}