baracuda_types/external_impls.rs
1//! Third-party type integrations.
2//!
3//! When a user enables a `*-crate` feature, baracuda automatically teaches
4//! the external crate's numeric types how to speak baracuda's trait
5//! vocabulary. The impls here are intentionally one-way: we implement
6//! baracuda traits *for* the external types, never the other way around.
7//!
8//! # `half-crate`
9//!
//! Implements [`DeviceRepr`] and [`ValidAsZeroBits`] for `half::f16` and
//! `half::bf16`. Both types are `#[repr(transparent)]` wrappers over `u16`
//! in the `half` crate, and their all-zero bit pattern encodes `+0.0`.
13//!
14//! # `f8-crate`
15//!
16//! Implements [`DeviceRepr`] and [`ValidAsZeroBits`] for `float8::F8E4M3`
17//! and `float8::F8E5M2`. The `float8` crate currently ships these two
18//! variants only; Fuel's wider F4/F6/F8E8M0 coverage will require either
19//! a richer upstream crate or baracuda growing its own newtypes in
20//! [`crate::numeric`].
21//!
22//! # KernelArg
23//!
24//! The [`crate::KernelArg`] blanket impl covers `&T` / `&mut T` for any
25//! `T: DeviceRepr`, so adding `DeviceRepr` here is enough to make these
26//! types usable as kernel arguments. Example:
27//!
28//! ```ignore
29//! # #[cfg(feature = "half-crate")]
30//! # fn demo() {
31//! use baracuda_types::KernelArg;
32//! let h: half::f16 = half::f16::from_f32(1.0);
33//! let _arg: *mut core::ffi::c_void = (&h).as_kernel_arg_ptr();
34//! # }
35//! ```
36
/// Marker-trait impls for the 16-bit float types of the `half` crate.
///
/// Impls are grouped by trait, and every `unsafe impl` carries its own
/// SAFETY note.
#[cfg(feature = "half-crate")]
mod half_impls {
    use crate::{DeviceRepr, ValidAsZeroBits};

    // --- DeviceRepr ------------------------------------------------------

    // SAFETY: half::f16 is #[repr(transparent)] over u16.
    unsafe impl DeviceRepr for half::f16 {}

    // SAFETY: half::bf16 is #[repr(transparent)] over u16.
    unsafe impl DeviceRepr for half::bf16 {}

    // --- ValidAsZeroBits -------------------------------------------------

    // SAFETY: all-zero bits are the valid IEEE 754 half-precision
    // representation of +0.0.
    unsafe impl ValidAsZeroBits for half::f16 {}

    // SAFETY: all-zero bits are the valid brain-float-16 representation
    // of +0.0.
    unsafe impl ValidAsZeroBits for half::bf16 {}
}
51
/// Marker-trait impls for the 8-bit float types of the `float8` crate.
///
/// Impls are grouped by trait, and every `unsafe impl` carries its own
/// SAFETY note.
#[cfg(feature = "f8-crate")]
mod f8_impls {
    use crate::{DeviceRepr, ValidAsZeroBits};

    // --- DeviceRepr ------------------------------------------------------

    // SAFETY: float8::F8E4M3 is #[repr(transparent)] over u8.
    unsafe impl DeviceRepr for float8::F8E4M3 {}

    // SAFETY: float8::F8E5M2 is #[repr(transparent)] over u8.
    unsafe impl DeviceRepr for float8::F8E5M2 {}

    // --- ValidAsZeroBits -------------------------------------------------

    // SAFETY: all-zero bits are a valid F8E4M3 value (positive zero).
    unsafe impl ValidAsZeroBits for float8::F8E4M3 {}

    // SAFETY: all-zero bits are a valid F8E5M2 value (positive zero).
    unsafe impl ValidAsZeroBits for float8::F8E5M2 {}
}
66
/// Tests for the `half` crate integrations (only built with `half-crate`).
#[cfg(all(test, feature = "half-crate"))]
mod half_tests {
    use crate::{DeviceRepr, ValidAsZeroBits};

    /// Compile-time check: a call only type-checks when `T` implements both
    /// `DeviceRepr` and `ValidAsZeroBits`. Does nothing at runtime.
    fn assert_markers<T: DeviceRepr + ValidAsZeroBits>() {}

    /// Produces a `T` whose bytes are all zero.
    fn from_zero_bits<T: ValidAsZeroBits>() -> T {
        // SAFETY: ValidAsZeroBits means all-zero bytes decode to a valid T.
        unsafe { core::mem::zeroed() }
    }

    #[test]
    fn half_types_implement_the_trio() {
        assert_markers::<half::f16>();
        assert_markers::<half::bf16>();
    }

    #[test]
    fn half_zero_bits_round_trip() {
        // Zeroed storage must read back as numeric zero for both formats.
        assert_eq!(from_zero_bits::<half::f16>().to_f32(), 0.0);
        assert_eq!(from_zero_bits::<half::bf16>().to_f32(), 0.0);
    }
}
91
/// Tests for the `float8` crate integrations (only built with `f8-crate`).
#[cfg(all(test, feature = "f8-crate"))]
mod f8_tests {
    use crate::{DeviceRepr, ValidAsZeroBits};

    /// Compile-time check: a call only type-checks when `T` implements both
    /// `DeviceRepr` and `ValidAsZeroBits`. Does nothing at runtime.
    fn assert_markers<T: DeviceRepr + ValidAsZeroBits>() {}

    /// Produces a `T` whose bytes are all zero.
    fn from_zero_bits<T: ValidAsZeroBits>() -> T {
        // SAFETY: the `ValidAsZeroBits` bound is the type's promise that the
        // all-zero byte pattern is a valid value.
        unsafe { core::mem::zeroed() }
    }

    #[test]
    fn f8_types_implement_the_trio() {
        assert_markers::<float8::F8E4M3>();
        assert_markers::<float8::F8E5M2>();
    }

    #[test]
    fn f8_zero_bits_round_trip() {
        // Zeroed storage must read back as numeric zero for both formats.
        assert_eq!(from_zero_bits::<float8::F8E4M3>().to_f32(), 0.0);
        assert_eq!(from_zero_bits::<float8::F8E5M2>().to_f32(), 0.0);
    }
}