Skip to main content

provable_contracts/kernels/
f16_convert.rs

//! F16 (half-precision) conversion kernel.
//!
//! Matches `f16-conversion-v1.yaml`.
//! IEEE 754 half-precision ↔ single-precision conversion via bit manipulation.
//!
//! Backends:
//! - `fn f16_to_f32_scalar(...)` / `fn f32_to_f16_scalar(...)` -- pure Rust scalar reference
//! - `unsafe fn f16_to_f32_avx2(...)` / `unsafe fn f32_to_f16_avx2(...)` -- AVX2-gated entry
//!   points; they currently run the scalar per-element conversion (no explicit SIMD intrinsics)
//! - `fn f16_convert_ptx() -> &'static str` -- PTX source string for an f16→f32 GPU kernel
10
11// ────────────────────────────────────────────────────────────────────────────
12// Scalar implementation
13// ────────────────────────────────────────────────────────────────────────────
14
/// Decode an IEEE 754 binary16 (f16) bit pattern into the equivalent `f32`.
///
/// Every binary16 value is exactly representable as `f32`, so this is lossless:
/// - exponent 0, fraction 0: signed zero;
/// - exponent 0, fraction != 0: subnormal, value `±fraction × 2^-24`;
/// - exponent 31: signed infinity (fraction 0) or NaN (the 10-bit fraction
///   payload is placed in the top bits of the f32 mantissa);
/// - any other exponent: normal value, re-biased by adding 112 (= 127 − 15)
///   to the exponent and widening the 10-bit fraction to 23 bits.
#[inline]
pub fn f16_to_f32_single(bits: u16) -> f32 {
    let negative = (bits & 0x8000) != 0;
    let sign_bit = u32::from(bits & 0x8000) << 16;
    let biased_exp = u32::from((bits >> 10) & 0x1F);
    let fraction = u32::from(bits & 0x03FF);

    match biased_exp {
        0 if fraction == 0 => f32::from_bits(sign_bit),
        0 => {
            // Subnormal: there is no implicit leading 1, so build the value
            // arithmetically. All factors are exact, so no rounding occurs.
            let direction = if negative { -1.0f32 } else { 1.0f32 };
            direction * (fraction as f32) * (2.0f32).powi(-24)
        }
        // Infinity when fraction == 0, otherwise NaN with its payload preserved.
        31 => f32::from_bits(sign_bit | 0x7F80_0000 | (fraction << 13)),
        _ => f32::from_bits(sign_bit | ((biased_exp + 112) << 23) | (fraction << 13)),
    }
}
48
/// Convert an f32 value to an IEEE 754 binary16 (f16) bit pattern.
///
/// Rounding is by truncation toward zero: mantissa bits that do not fit are
/// discarded (no round-to-nearest-even), so the result never exceeds the
/// input magnitude.
///
/// Special cases:
/// - `±0.0` and f32 subnormals map to signed f16 zero.
/// - Magnitudes below the f16 normal range (`< 2^-14`) are encoded as f16
///   subnormals (`m × 2^-24`); anything below `2^-24` truncates to signed zero.
///   Every f16 subnormal therefore round-trips through `f32` losslessly.
/// - Magnitudes `>= 2^16` overflow to signed infinity.
/// - `±inf` maps to signed f16 infinity. NaN maps to an f16 NaN carrying the
///   top 10 payload bits, forced non-zero so it does not become infinity.
#[inline]
pub fn f32_to_f16_single(val: f32) -> u16 {
    let bits = val.to_bits();
    let sign = ((bits >> 31) & 1) as u16;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let mant = bits & 0x007F_FFFF;

    if exp == 0 {
        // Zero or f32 subnormal (< 2^-126): far below f16 range → signed zero.
        return sign << 15;
    }

    if exp == 0xFF {
        // Inf or NaN
        if mant == 0 {
            return (sign << 15) | 0x7C00;
        }
        // Keep the top payload bits; `.max(1)` guarantees a NaN, not infinity.
        return (sign << 15) | 0x7C00 | ((mant >> 13) as u16 & 0x3FF).max(1);
    }

    // Rebias exponent (f32 bias 127 → f16 bias 15)
    let f16_exp = exp - 112;
    if f16_exp >= 31 {
        // Overflow to infinity
        return (sign << 15) | 0x7C00;
    }
    if f16_exp <= 0 {
        // Below the f16 normal range: encode as an f16 subnormal.
        // The f32 value is (0x80_0000 | mant) × 2^(exp − 150). An f16 subnormal
        // with fraction m has value m × 2^-24, so m = (0x80_0000 | mant) >> (126 − exp),
        // truncating like the normal path does. Here exp <= 112, so shift >= 14
        // and the result fits in 10 bits.
        let shift = (126 - exp) as u32;
        if shift >= 24 {
            // |val| < 2^-24: smaller than the smallest f16 subnormal.
            return sign << 15;
        }
        let sub_mant = ((0x0080_0000 | mant) >> shift) as u16;
        return (sign << 15) | sub_mant;
    }

    let f16_mant = (mant >> 13) as u16;
    (sign << 15) | ((f16_exp as u16) << 10) | f16_mant
}
86
87/// Batch convert f16 bit patterns to f32 (scalar reference).
88///
89/// # Panics
90/// Panics if `input.len() != output.len()`.
91pub fn f16_to_f32_scalar(input: &[u16], output: &mut [f32]) {
92    assert_eq!(input.len(), output.len(), "dimension mismatch");
93    for (bits, out) in input.iter().zip(output.iter_mut()) {
94        *out = f16_to_f32_single(*bits);
95    }
96}
97
98/// Batch convert f32 to f16 bit patterns (scalar reference).
99///
100/// # Panics
101/// Panics if `input.len() != output.len()`.
102pub fn f32_to_f16_scalar(input: &[f32], output: &mut [u16]) {
103    assert_eq!(input.len(), output.len(), "dimension mismatch");
104    for (val, out) in input.iter().zip(output.iter_mut()) {
105        *out = f32_to_f16_single(*val);
106    }
107}
108
109// ────────────────────────────────────────────────────────────────────────────
110// AVX2 implementation
111// ────────────────────────────────────────────────────────────────────────────
112
113/// AVX2 f16→f32 conversion -- delegates to scalar.
114///
115/// # Safety
116/// Requires AVX2 support.
117#[cfg(target_arch = "x86_64")]
118#[target_feature(enable = "avx2")]
119pub unsafe fn f16_to_f32_avx2(input: &[u16], output: &mut [f32]) {
120    f16_to_f32_scalar(input, output);
121}
122
123/// AVX2 f32→f16 conversion -- delegates to scalar.
124///
125/// # Safety
126/// Requires AVX2 support.
127#[cfg(target_arch = "x86_64")]
128#[target_feature(enable = "avx2")]
129pub unsafe fn f32_to_f16_avx2(input: &[f32], output: &mut [u16]) {
130    f32_to_f16_scalar(input, output);
131}
132
133// ────────────────────────────────────────────────────────────────────────────
134// PTX implementation
135// ────────────────────────────────────────────────────────────────────────────
136
/// PTX assembly for f16→f32 conversion.
///
/// Defines the kernel `f16_to_f32_kernel(INPUT, OUTPUT, N)`. It runs one thread
/// per element and uses the hardware `cvt.f32.f16` instruction. `INPUT` is a
/// device pointer to `N` f16 values (2 bytes each) and `OUTPUT` a device
/// pointer to `N` f32 values (4 bytes each).
///
/// The global index is `ctaid.x * ntid.x + tid.x`, so any 1-D block size
/// works, rather than only the previously hard-coded 256. Threads whose index
/// is `>= N` exit without touching memory. Launch with at least
/// `ceil(N / blockDim.x)` blocks.
pub fn f16_convert_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64
.visible .entry f16_to_f32_kernel(
    .param .u64 INPUT,
    .param .u64 OUTPUT,
    .param .u32 N
) {
    .reg .u32 %tx, %bid, %bdim, %n, %idx;
    .reg .u64 %in_ptr, %out_ptr, %addr, %off64;
    .reg .b16 %h_val;
    .reg .f32 %f_val;
    .reg .pred %p_bound;

    mov.u32 %tx, %tid.x;
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;

    ld.param.u32 %n, [N];
    ld.param.u64 %in_ptr, [INPUT];
    ld.param.u64 %out_ptr, [OUTPUT];

    // Global index: ctaid.x * ntid.x + tid.x (independent of block size)
    mad.lo.u32 %idx, %bid, %bdim, %tx;

    setp.ge.u32 %p_bound, %idx, %n;
    @%p_bound bra EXIT;

    // Load f16 value
    mul.wide.u32 %off64, %idx, 2;
    add.u64 %addr, %in_ptr, %off64;
    ld.global.b16 %h_val, [%addr];

    // Convert f16 to f32
    cvt.f32.f16 %f_val, %h_val;

    // Store f32 value
    mul.wide.u32 %off64, %idx, 4;
    add.u64 %addr, %out_ptr, %off64;
    st.global.f32 [%addr], %f_val;

EXIT:
    ret;
}
"#
}
187
188// ────────────────────────────────────────────────────────────────────────────
189// Tests
190// ────────────────────────────────────────────────────────────────────────────
191
#[cfg(test)]
mod tests {
    // Unit and property tests for the f16 <-> f32 conversion backends.
    use super::*;
    use proptest::prelude::*;

    /// Verify f16 zero converts to f32 zero and back
    #[test]
    fn test_f16_zero() {
        assert_eq!(f16_to_f32_single(0x0000), 0.0);
        assert_eq!(f32_to_f16_single(0.0), 0x0000);
    }

    /// Verify f16 negative zero preserves sign bit through conversion
    #[test]
    fn test_f16_negative_zero() {
        let neg_zero = f16_to_f32_single(0x8000);
        assert!(neg_zero.is_sign_negative());
        assert_eq!(neg_zero, -0.0);
    }

    /// Verify f16 bit pattern 0x3C00 converts to f32 1.0
    #[test]
    fn test_f16_one() {
        // f16 1.0 = 0x3C00 (sign=0, exp=15, mant=0)
        let val = f16_to_f32_single(0x3C00);
        assert!((val - 1.0).abs() < 1e-6);
    }

    /// Verify f16 conversion for known values: 0.5, 2.0, and -1.0
    #[test]
    fn test_f16_known_values() {
        // f16 0.5 = 0x3800
        assert!((f16_to_f32_single(0x3800) - 0.5).abs() < 1e-6);
        // f16 2.0 = 0x4000
        assert!((f16_to_f32_single(0x4000) - 2.0).abs() < 1e-6);
        // f16 -1.0 = 0xBC00
        assert!((f16_to_f32_single(0xBC00) + 1.0).abs() < 1e-6);
    }

    /// Verify f16-to-f32-to-f16 roundtrip is lossless for sampled normal values
    #[test]
    fn test_f16_roundtrip_normal() {
        // Test roundtrip for a selection of normal f16 values
        let test_values: Vec<u16> = (0x0400..=0x7BFF).step_by(17).collect();
        for &bits in &test_values {
            let f32_val = f16_to_f32_single(bits);
            let back = f32_to_f16_single(f32_val);
            assert_eq!(
                bits, back,
                "roundtrip failed for 0x{bits:04X}: f32={f32_val}, back=0x{back:04X}"
            );
        }
    }

    /// Verify sign bit is preserved for all normal f16 exponents
    #[test]
    fn test_f16_sign_preservation() {
        // For every normal f16, sign should be preserved
        for exp in 1u16..=30 {
            let pos = (exp << 10) | 0x100; // positive with some mantissa
            let neg = pos | 0x8000; // same with sign bit set
            assert!(f16_to_f32_single(pos) > 0.0);
            assert!(f16_to_f32_single(neg) < 0.0);
        }
    }

    /// Verify f16 positive and negative infinity convert correctly
    #[test]
    fn test_f16_inf() {
        let pos_inf = f16_to_f32_single(0x7C00);
        assert!(pos_inf.is_infinite() && pos_inf > 0.0);
        let neg_inf = f16_to_f32_single(0xFC00);
        assert!(neg_inf.is_infinite() && neg_inf < 0.0);
    }

    /// Verify f16 NaN bit pattern converts to f32 NaN
    #[test]
    fn test_f16_nan() {
        let nan = f16_to_f32_single(0x7C01);
        assert!(nan.is_nan());
    }

    /// Verify batch f16-to-f32 conversion for multiple known values
    #[test]
    fn test_f16_batch_conversion() {
        let input = [0x3C00, 0x4000, 0x3800]; // 1.0, 2.0, 0.5
        let mut output = [0.0f32; 3];
        f16_to_f32_scalar(&input, &mut output);
        assert!((output[0] - 1.0).abs() < 1e-6);
        assert!((output[1] - 2.0).abs() < 1e-6);
        assert!((output[2] - 0.5).abs() < 1e-6);
    }

    /// Verify batch f16-to-f32 conversion rejects mismatched slice lengths
    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_f16_batch_length_mismatch_panics() {
        let mut output = [0.0f32; 2];
        f16_to_f32_scalar(&[0x3C00], &mut output);
    }

    /// Verify batch f32-to-f16 conversion rejects mismatched slice lengths
    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_f32_batch_length_mismatch_panics() {
        let mut output = [0u16; 2];
        f32_to_f16_scalar(&[1.0], &mut output);
    }

    proptest! {
        #[test]
        fn prop_f16_roundtrip_normal(exp in 1u16..31, mant in 0u16..1024) {
            let bits = (exp << 10) | mant;
            let f32_val = f16_to_f32_single(bits);
            let back = f32_to_f16_single(f32_val);
            prop_assert_eq!(bits, back,
                "roundtrip failed for exp={} mant={}: 0x{:04X} → {} → 0x{:04X}", exp, mant, bits, f32_val, back);
        }

        #[test]
        fn prop_f16_sign_preserved(exp in 1u16..31, mant in 0u16..1024) {
            let pos = (exp << 10) | mant;
            let neg = pos | 0x8000;
            let pos_f32 = f16_to_f32_single(pos);
            let neg_f32 = f16_to_f32_single(neg);
            // Normal f16 values are never zero, so the comparisons can be strict.
            prop_assert!(pos_f32 > 0.0, "positive f16 gave non-positive f32");
            prop_assert!(neg_f32 < 0.0, "negative f16 gave non-negative f32");
        }
    }

    /// Verify f16 convert PTX contains entry point and hardware cvt instruction
    #[test]
    fn test_f16_ptx_structure() {
        let ptx = f16_convert_ptx();
        assert!(ptx.contains(".entry f16_to_f32_kernel"));
        assert!(ptx.contains("cvt.f32.f16"));
        assert!(ptx.contains("ret;"));
    }

    /// Verify f32-to-f16 edge cases: infinity, NaN, underflow, overflow
    #[test]
    fn test_f32_to_f16_edge_cases() {
        // +inf → 0x7C00
        assert_eq!(f32_to_f16_single(f32::INFINITY), 0x7C00);
        // -inf → 0xFC00
        assert_eq!(f32_to_f16_single(f32::NEG_INFINITY), 0xFC00);
        // NaN → f16 NaN (sign=0, exp=31, mantissa!=0)
        let nan_bits = f32_to_f16_single(f32::NAN);
        assert_eq!(nan_bits & 0x7C00, 0x7C00);
        assert_ne!(nan_bits & 0x03FF, 0);
        // Very small positive → underflow to zero
        assert_eq!(f32_to_f16_single(1e-10), 0x0000);
        // Very large positive → overflow to inf
        assert_eq!(f32_to_f16_single(1e10), 0x7C00);
        // f32 subnormal → f16 zero
        assert_eq!(f32_to_f16_single(f32::from_bits(0x0000_0001)), 0x0000);
        // -0.0 → 0x8000
        assert_eq!(f32_to_f16_single(-0.0), 0x8000);
    }

    /// Verify batch f32-to-f16 conversion
    #[test]
    fn test_f32_to_f16_batch() {
        let input = [1.0f32, 2.0, 0.5, -1.0];
        let mut output = [0u16; 4];
        f32_to_f16_scalar(&input, &mut output);
        assert_eq!(output[0], 0x3C00); // 1.0
        assert_eq!(output[1], 0x4000); // 2.0
        assert_eq!(output[2], 0x3800); // 0.5
        assert_eq!(output[3], 0xBC00); // -1.0
    }

    /// Verify f16 subnormal conversion produces the exact value `mant * 2^-24`
    #[test]
    fn test_f16_subnormal_conversion() {
        // Smallest positive subnormal: exp=0, mant=1 → 2^-24
        let val = f16_to_f32_single(0x0001);
        assert!(val > 0.0);
        assert!(val < 1e-5);
        assert_eq!(val, 2.0f32.powi(-24));
        // Largest subnormal: exp=0, mant=0x3FF → 1023 * 2^-24
        assert_eq!(f16_to_f32_single(0x03FF), 1023.0 * 2.0f32.powi(-24));
        // Negative subnormal
        let neg_val = f16_to_f32_single(0x8001);
        assert!(neg_val < 0.0);
        assert_eq!(neg_val, -(2.0f32.powi(-24)));
    }

    /// Verify AVX2 f16-to-f32 conversion matches scalar output
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_f16_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input = [0x3C00, 0x4000, 0x3800, 0xBC00];
        let mut scalar_out = [0.0f32; 4];
        let mut avx2_out = [0.0f32; 4];
        f16_to_f32_scalar(&input, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime just above.
        unsafe { f16_to_f32_avx2(&input, &mut avx2_out) };
        assert_eq!(scalar_out, avx2_out);
    }

    /// Verify AVX2 f32-to-f16 conversion matches scalar output
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_f32_to_f16_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input = [1.0f32, 2.0, 0.5, -1.0, 0.0, -0.0, f32::INFINITY, 1e10, 1e-10];
        let mut scalar_out = [0u16; 9];
        let mut avx2_out = [0u16; 9];
        f32_to_f16_scalar(&input, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime just above.
        unsafe { f32_to_f16_avx2(&input, &mut avx2_out) };
        assert_eq!(scalar_out, avx2_out);
    }
}