// provable_contracts/kernels/f16_convert.rs
/// Decodes an IEEE 754 binary16 (half-precision) bit pattern into an `f32`.
///
/// Every binary16 value is exactly representable in `f32`, so this never
/// rounds:
/// * signed zeros keep their sign;
/// * subnormals are rebuilt as `fraction * 2^-24`;
/// * infinities stay infinite;
/// * NaN payloads are shifted into the top 10 mantissa bits of the `f32`.
///   The f16 quiet bit (bit 9) therefore lands on the f32 quiet bit (bit 22).
#[inline]
pub fn f16_to_f32_single(bits: u16) -> f32 {
    let sign_bit = u32::from(bits >> 15) << 31;
    let biased_exp = u32::from((bits >> 10) & 0x1F);
    let fraction = u32::from(bits & 0x3FF);

    match biased_exp {
        // Signed zero: only the sign bit survives.
        0 if fraction == 0 => f32::from_bits(sign_bit),
        // Subnormal: value is fraction * 2^-24 (exact in f32).
        0 => {
            let scale = if sign_bit == 0 { 1.0f32 } else { -1.0f32 };
            scale * (fraction as f32) * (2.0f32).powi(-24)
        }
        // All-ones exponent: infinity when fraction == 0, otherwise NaN with its payload moved up.
        31 => f32::from_bits(sign_bit | 0x7F80_0000 | (fraction << 13)),
        // Normal: re-bias the exponent (15 -> 127, i.e. +112) and widen the fraction to 23 bits.
        _ => f32::from_bits(sign_bit | ((biased_exp + 112) << 23) | (fraction << 13)),
    }
}
48
/// Encodes an `f32` as an IEEE 754 binary16 (half-precision) bit pattern.
///
/// Finite values are rounded to the nearest representable half. Ties go to
/// the half whose last mantissa bit is even (IEEE 754 round-to-nearest-even).
/// Consequently:
///
/// * magnitudes of at least 65520 become ±infinity, because 65520 is the
///   halfway point past the largest finite half, 65504;
/// * magnitudes below 2^-14 are encoded as binary16 subnormals (gradual
///   underflow) rather than flushed to zero; magnitudes at or below 2^-25
///   (half the smallest subnormal) round to a zero of the same sign;
/// * `f32` zeros and subnormals map to a zero of the same sign.
///
/// Infinities are preserved. A NaN stays NaN: the top 10 payload bits are
/// kept, and a payload that would truncate to zero is forced to 1 so the
/// result cannot be mistaken for infinity.
#[inline]
pub fn f32_to_f16_single(val: f32) -> u16 {
    /// Shifts `value` right by `shift` bits (`1..=31`), rounding to nearest
    /// and breaking exact ties towards an even result.
    fn shr_round_even(value: u32, shift: u32) -> u32 {
        let kept = value >> shift;
        let dropped = value & ((1 << shift) - 1);
        let halfway = 1 << (shift - 1);
        if dropped > halfway || (dropped == halfway && kept & 1 == 1) {
            kept + 1
        } else {
            kept
        }
    }

    let bits = val.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let mant = bits & 0x007F_FFFF;

    if exp == 0xFF {
        if mant == 0 {
            return sign | 0x7C00;
        }
        // NaN: keep the high payload bits, but never let the payload become 0 (that encodes infinity).
        return sign | 0x7C00 | (((mant >> 13) as u16) & 0x3FF).max(1);
    }

    // Re-bias the exponent: f32 bias is 127, f16 bias is 15.
    let f16_exp = exp - 112;

    if f16_exp >= 0x1F {
        // |val| >= 2^16: beyond every finite half even before rounding.
        return sign | 0x7C00;
    }

    if f16_exp <= 0 {
        // |val| < 2^-25 (this includes f32 zeros and subnormals) always rounds to zero.
        if f16_exp < -10 {
            return sign;
        }
        // Subnormal result, counted in units of 2^-24. With the implicit bit restored the
        // value is `full * 2^(exp - 150)`, so shift right by `126 - exp = 14 - f16_exp`
        // (14..=24) bits. A rounding carry to 0x400 is exactly the smallest normal encoding.
        let full = mant | 0x0080_0000;
        let shift = (14 - f16_exp) as u32;
        return sign | shr_round_even(full, shift) as u16;
    }

    // Normal: drop the 13 low mantissa bits with rounding. A carry out of the
    // 10-bit mantissa correctly bumps the exponent, up to 0x7C00 (infinity).
    let magnitude = ((f16_exp as u32) << 10) + shr_round_even(mant, 13);
    sign | magnitude as u16
}
86
87pub fn f16_to_f32_scalar(input: &[u16], output: &mut [f32]) {
92 assert_eq!(input.len(), output.len(), "dimension mismatch");
93 for (bits, out) in input.iter().zip(output.iter_mut()) {
94 *out = f16_to_f32_single(*bits);
95 }
96}
97
98pub fn f32_to_f16_scalar(input: &[f32], output: &mut [u16]) {
103 assert_eq!(input.len(), output.len(), "dimension mismatch");
104 for (val, out) in input.iter().zip(output.iter_mut()) {
105 *out = f32_to_f16_single(*val);
106 }
107}
108
109#[cfg(target_arch = "x86_64")]
118#[target_feature(enable = "avx2")]
119pub unsafe fn f16_to_f32_avx2(input: &[u16], output: &mut [f32]) {
120 f16_to_f32_scalar(input, output);
121}
122
123#[cfg(target_arch = "x86_64")]
128#[target_feature(enable = "avx2")]
129pub unsafe fn f32_to_f16_avx2(input: &[f32], output: &mut [u16]) {
130 f32_to_f16_scalar(input, output);
131}
132
/// Returns PTX source for a CUDA kernel that widens binary16 values to `f32`.
///
/// Signature: `f16_to_f32_kernel(INPUT: u64 ptr to u16, OUTPUT: u64 ptr to f32, N: u32)`.
/// Each thread converts one element at global index `ctaid.x * ntid.x + tid.x`.
/// Threads whose index is `>= N` exit immediately, so a 1-D launch must provide
/// at least `N` threads in total. The block size is read from `%ntid.x` rather
/// than assumed, so any 1-D block size is handled.
///
/// The module declares PTX ISA 8.5 and targets `sm_90`.
pub fn f16_convert_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64

.visible .entry f16_to_f32_kernel(
    .param .u64 INPUT,
    .param .u64 OUTPUT,
    .param .u32 N
) {
    .reg .u32 %r_tid, %r_ntid, %r_ctaid, %n, %idx;
    .reg .u64 %in_ptr, %out_ptr, %addr, %off64;
    .reg .b16 %h_val;
    .reg .f32 %f_val;
    .reg .pred %p_bound;

    // Copy special registers into user registers; names must not shadow %tid/%ntid/%ctaid.
    mov.u32 %r_tid, %tid.x;
    mov.u32 %r_ntid, %ntid.x;
    mov.u32 %r_ctaid, %ctaid.x;

    ld.param.u32 %n, [N];
    ld.param.u64 %in_ptr, [INPUT];
    ld.param.u64 %out_ptr, [OUTPUT];

    // Global index = ctaid.x * ntid.x + tid.x (block size taken from the launch).
    mad.lo.u32 %idx, %r_ctaid, %r_ntid, %r_tid;

    setp.ge.u32 %p_bound, %idx, %n;
    @%p_bound bra EXIT;

    // Load the f16 value (2 bytes per element).
    mul.wide.u32 %off64, %idx, 2;
    add.u64 %addr, %in_ptr, %off64;
    ld.global.b16 %h_val, [%addr];

    // Widen f16 -> f32 (exact, so no rounding modifier is needed).
    cvt.f32.f16 %f_val, %h_val;

    // Store the f32 value (4 bytes per element).
    mul.wide.u32 %off64, %idx, 4;
    add.u64 %addr, %out_ptr, %off64;
    st.global.f32 [%addr], %f_val;

EXIT:
    ret;
}
"#
}
187
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Bit pattern of 2^-24 as an `f32`, i.e. the value of the smallest f16 subnormal (0x0001).
    const F16_MIN_SUBNORMAL_F32_BITS: u32 = 0x3380_0000;

    #[test]
    fn test_f16_zero() {
        // +0 must decode with the sign bit clear, and encode back to 0x0000.
        assert_eq!(f16_to_f32_single(0x0000).to_bits(), 0x0000_0000);
        assert_eq!(f32_to_f16_single(0.0), 0x0000);
    }

    #[test]
    fn test_f16_negative_zero() {
        let neg_zero = f16_to_f32_single(0x8000);
        assert!(neg_zero.is_sign_negative());
        assert_eq!(neg_zero, -0.0);
    }

    #[test]
    fn test_f16_one() {
        // Widening is exact, so compare exactly rather than with a tolerance.
        assert_eq!(f16_to_f32_single(0x3C00), 1.0);
    }

    #[test]
    fn test_f16_known_values() {
        assert_eq!(f16_to_f32_single(0x3800), 0.5);
        assert_eq!(f16_to_f32_single(0x4000), 2.0);
        assert_eq!(f16_to_f32_single(0xBC00), -1.0);
        assert_eq!(f16_to_f32_single(0x7BFF), 65504.0);
    }

    #[test]
    fn test_f16_roundtrip_normal() {
        // Exhaustive over every finite normal half of both signs; each is exact in f32.
        for magnitude in 0x0400u16..=0x7BFF {
            for bits in [magnitude, magnitude | 0x8000] {
                let f32_val = f16_to_f32_single(bits);
                let back = f32_to_f16_single(f32_val);
                assert_eq!(
                    bits, back,
                    "roundtrip failed for 0x{bits:04X}: f32={f32_val}, back=0x{back:04X}"
                );
            }
        }
    }

    #[test]
    fn test_f16_sign_preservation() {
        for exp in 1u16..=30 {
            let pos = (exp << 10) | 0x100;
            let neg = pos | 0x8000;
            assert!(f16_to_f32_single(pos) > 0.0);
            assert!(f16_to_f32_single(neg) < 0.0);
        }
    }

    #[test]
    fn test_f16_sign_bit_exhaustive() {
        // Every non-NaN half (zeros, subnormals, normals, infinities) keeps its sign bit.
        for bits in 0u16..=u16::MAX {
            let is_nan = (bits & 0x7C00) == 0x7C00 && (bits & 0x03FF) != 0;
            if is_nan {
                continue;
            }
            let decoded = f16_to_f32_single(bits);
            assert_eq!(
                decoded.is_sign_negative(),
                (bits & 0x8000) != 0,
                "sign lost for 0x{bits:04X}: f32={decoded}"
            );
        }
    }

    #[test]
    fn test_f16_inf() {
        let pos_inf = f16_to_f32_single(0x7C00);
        assert!(pos_inf.is_infinite() && pos_inf > 0.0);
        let neg_inf = f16_to_f32_single(0xFC00);
        assert!(neg_inf.is_infinite() && neg_inf < 0.0);
    }

    #[test]
    fn test_f16_nan() {
        assert!(f16_to_f32_single(0x7C01).is_nan());
        assert!(f16_to_f32_single(0xFE00).is_nan());
    }

    #[test]
    fn test_f16_batch_conversion() {
        let input = [0x3C00, 0x4000, 0x3800];
        let mut output = [0.0f32; 3];
        f16_to_f32_scalar(&input, &mut output);
        assert_eq!(output, [1.0, 2.0, 0.5]);
    }

    proptest! {
        #[test]
        fn prop_f16_roundtrip_normal(exp in 1u16..31, mant in 0u16..1024) {
            let bits = (exp << 10) | mant;
            let f32_val = f16_to_f32_single(bits);
            let back = f32_to_f16_single(f32_val);
            prop_assert_eq!(bits, back,
                "roundtrip failed for exp={} mant={}: 0x{:04X} → {} → 0x{:04X}", exp, mant, bits, f32_val, back);
        }

        #[test]
        fn prop_f16_sign_preserved(exp in 1u16..31, mant in 0u16..1024) {
            let pos = (exp << 10) | mant;
            let neg = pos | 0x8000;
            let pos_f32 = f16_to_f32_single(pos);
            let neg_f32 = f16_to_f32_single(neg);
            // Normal halves are never zero, so strict comparisons are valid and
            // also reject a zero of the wrong sign.
            prop_assert!(pos_f32 > 0.0, "positive f16 gave non-positive f32");
            prop_assert!(neg_f32 < 0.0, "negative f16 gave non-negative f32");
            prop_assert_eq!(neg_f32, -pos_f32);
        }
    }

    #[test]
    fn test_f16_ptx_structure() {
        let ptx = f16_convert_ptx();
        assert!(ptx.contains(".entry f16_to_f32_kernel"));
        assert!(ptx.contains("cvt.f32.f16"));
        assert!(ptx.contains("ret;"));
        assert_eq!(ptx.matches('{').count(), ptx.matches('}').count());
    }

    #[test]
    fn test_f32_to_f16_edge_cases() {
        assert_eq!(f32_to_f16_single(f32::INFINITY), 0x7C00);
        assert_eq!(f32_to_f16_single(f32::NEG_INFINITY), 0xFC00);
        let nan_bits = f32_to_f16_single(f32::NAN);
        assert_eq!(nan_bits & 0x7C00, 0x7C00);
        assert_ne!(nan_bits & 0x03FF, 0);
        assert_eq!(f32_to_f16_single(1e-10), 0x0000);
        assert_eq!(f32_to_f16_single(1e10), 0x7C00);
        assert_eq!(f32_to_f16_single(f32::from_bits(0x0000_0001)), 0x0000);
        assert_eq!(f32_to_f16_single(-0.0), 0x8000);
    }

    #[test]
    fn test_f32_to_f16_batch() {
        let input = [1.0f32, 2.0, 0.5, -1.0];
        let mut output = [0u16; 4];
        f32_to_f16_scalar(&input, &mut output);
        assert_eq!(output, [0x3C00, 0x4000, 0x3800, 0xBC00]);
    }

    #[test]
    fn test_f16_subnormal_conversion() {
        let min_subnormal = f32::from_bits(F16_MIN_SUBNORMAL_F32_BITS);
        assert_eq!(f16_to_f32_single(0x0001), min_subnormal);
        assert_eq!(f16_to_f32_single(0x8001), -min_subnormal);
        // Largest subnormal: 1023 * 2^-24.
        assert_eq!(f16_to_f32_single(0x03FF), 1023.0 * min_subnormal);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_f16_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input = [0x3C00, 0x4000, 0x3800, 0xBC00];
        let mut scalar_out = [0.0f32; 4];
        let mut avx2_out = [0.0f32; 4];
        f16_to_f32_scalar(&input, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime above.
        unsafe { f16_to_f32_avx2(&input, &mut avx2_out) };
        assert_eq!(scalar_out, avx2_out);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_f32_to_f16_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input = [1.0f32, 2.0, 0.5, -1.0, 0.0, -0.0, 65504.0, f32::INFINITY];
        let mut scalar_out = [0u16; 8];
        let mut avx2_out = [0u16; 8];
        f32_to_f16_scalar(&input, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime above.
        unsafe { f32_to_f16_avx2(&input, &mut avx2_out) };
        assert_eq!(scalar_out, avx2_out);
    }
}