// cuda_rust_wasm/runtime/half.rs
use std::fmt;
use std::ops::{Add, Div, Mul, Neg, Sub};

/// An IEEE 754 binary16 ("half precision") floating-point number, stored as its raw bits.
///
/// `PartialEq`, `Eq` and `Hash` are derived, so they compare *bit patterns*: `+0.0` and
/// `-0.0` are unequal, and a NaN equals another NaN with identical bits.
/// NOTE(review): the `PartialOrd` impl compares the widened `f32` values instead, so `==`
/// and `partial_cmp` can disagree on signed zeros and NaNs — confirm this is intended.
#[derive(Copy, Clone, Hash, Eq, PartialEq)]
pub struct Half {
    // Raw encoding: 1 sign bit | 5 exponent bits | 10 fraction bits.
    bits: u16,
}

20impl Half {
21 pub const ZERO: Self = Self { bits: 0x0000 };
23 pub const ONE: Self = Self { bits: 0x3C00 };
25 pub const NEG_ONE: Self = Self { bits: 0xBC00 };
27 pub const INFINITY: Self = Self { bits: 0x7C00 };
29 pub const NEG_INFINITY: Self = Self { bits: 0xFC00 };
31 pub const NAN: Self = Self { bits: 0x7E00 };
33 pub const MAX: Self = Self { bits: 0x7BFF };
35 pub const MIN_POSITIVE: Self = Self { bits: 0x0400 };
37 pub const EPSILON: Self = Self { bits: 0x1400 };
39
40 pub const fn from_bits(bits: u16) -> Self {
42 Self { bits }
43 }
44
45 pub const fn to_bits(self) -> u16 {
47 self.bits
48 }
49
50 pub fn from_f32(value: f32) -> Self {
52 Self { bits: f32_to_f16(value) }
53 }
54
55 pub fn to_f32(self) -> f32 {
57 f16_to_f32(self.bits)
58 }
59
60 pub fn from_f64(value: f64) -> Self {
62 Self::from_f32(value as f32)
63 }
64
65 pub fn to_f64(self) -> f64 {
67 self.to_f32() as f64
68 }
69
70 pub fn is_nan(self) -> bool {
72 (self.bits & 0x7C00) == 0x7C00 && (self.bits & 0x03FF) != 0
73 }
74
75 pub fn is_infinite(self) -> bool {
77 (self.bits & 0x7FFF) == 0x7C00
78 }
79
80 pub fn is_finite(self) -> bool {
82 (self.bits & 0x7C00) != 0x7C00
83 }
84
85 pub fn is_normal(self) -> bool {
87 let exp = self.bits & 0x7C00;
88 exp != 0 && exp != 0x7C00
89 }
90
91 pub fn is_zero(self) -> bool {
93 (self.bits & 0x7FFF) == 0
94 }
95
96 pub fn is_sign_negative(self) -> bool {
98 (self.bits & 0x8000) != 0
99 }
100
101 pub fn abs(self) -> Self {
103 Self { bits: self.bits & 0x7FFF }
104 }
105
106 pub fn fma(a: Self, b: Self, c: Self) -> Self {
108 Self::from_f32(a.to_f32().mul_add(b.to_f32(), c.to_f32()))
109 }
110
111 pub fn sqrt(self) -> Self {
113 Self::from_f32(self.to_f32().sqrt())
114 }
115
116 pub fn recip(self) -> Self {
118 Self::from_f32(1.0 / self.to_f32())
119 }
120
121 pub fn min(self, other: Self) -> Self {
123 Self::from_f32(self.to_f32().min(other.to_f32()))
124 }
125
126 pub fn max(self, other: Self) -> Self {
128 Self::from_f32(self.to_f32().max(other.to_f32()))
129 }
130
131 pub fn clamp(self, min: Self, max: Self) -> Self {
133 Self::from_f32(self.to_f32().clamp(min.to_f32(), max.to_f32()))
134 }
135}
136
/// Encodes an `f32` as IEEE 754 binary16 bits, rounding to nearest with ties to even.
///
/// * NaN stays NaN: the top 10 payload bits are kept, forced non-zero so the result
///   cannot turn into an infinity.
/// * Magnitudes that round past the largest finite half (`65504`) become ±infinity;
///   everything with magnitude `>= 65520` overflows.
/// * Magnitudes below `2^-14` are encoded as binary16 subnormals. Magnitudes at or below
///   `2^-25` (half of the smallest subnormal `2^-24`) round to a signed zero.
/// * The sign bit is always preserved, including for zero.
fn f32_to_f16(value: f32) -> u16 {
    /// Shifts `x` right by `shift` (1..=31) bits, rounding to nearest, ties to even.
    fn shr_round_even(x: u32, shift: u32) -> u32 {
        let kept = x >> shift;
        let dropped = x & ((1 << shift) - 1);
        let half = 1 << (shift - 1);
        if dropped > half || (dropped == half && kept & 1 == 1) {
            kept + 1
        } else {
            kept
        }
    }

    let bits = value.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let mantissa = bits & 0x007F_FFFF;

    if exp == 0xFF {
        if mantissa == 0 {
            return sign | 0x7C00;
        }
        return sign | 0x7C00 | ((mantissa >> 13) as u16).max(1);
    }

    let unbiased_exp = exp - 127;

    if unbiased_exp > 15 {
        return sign | 0x7C00;
    }

    // Below 2^-25 every value is under half of the smallest subnormal (2^-24), so it
    // rounds to zero. f32 zeros and f32 subnormals also take this path.
    if unbiased_exp < -25 {
        return sign;
    }

    if unbiased_exp < -14 {
        // Subnormal half: value = m * 2^-24 with m < 2^10. With the implicit bit restored,
        // the 24-bit f32 significand equals value * 2^(23 - unbiased_exp), so
        // m = significand >> (-1 - unbiased_exp), a shift of 14..=24 bits.
        let significand = mantissa | 0x0080_0000;
        let shift = (-1 - unbiased_exp) as u32;
        // A carry up to 0x400 yields exactly the smallest normal encoding, which is correct.
        return sign | shr_round_even(significand, shift) as u16;
    }

    // Normal half: place the rebiased exponent directly above the 23 f32 fraction bits,
    // then round away the low 13 bits in one step. A carry out of the fraction bumps the
    // exponent, and a carry out of exponent 30 produces 0x7C00 (infinity).
    let rebiased = ((unbiased_exp + 15) as u32) << 23;
    sign | shr_round_even(rebiased | mantissa, 13) as u16
}

/// Widens an IEEE 754 binary16 bit pattern to the equivalent `f32`.
///
/// The conversion is exact. Signed zeros keep their sign, and binary16 subnormals are
/// renormalized into f32 normals. Infinities stay infinite, and NaNs keep their 10-bit
/// payload, shifted into the top of the f32 fraction.
fn f16_to_f32(bits: u16) -> f32 {
    let sign = u32::from(bits & 0x8000) << 16;
    let exp_field = u32::from((bits >> 10) & 0x1F);
    let frac = u32::from(bits & 0x03FF);

    let magnitude = match (exp_field, frac) {
        // Infinity or NaN: max f32 exponent, payload carried over.
        (0x1F, _) => 0x7F80_0000 | (frac << 13),
        // Signed zero.
        (0, 0) => 0,
        // Subnormal: shift the leading 1 up to bit 10 (the implicit-bit position),
        // lowering the exponent by one per shift from its base of 2^-14 (f32 field 113).
        (0, _) => {
            let shift = frac.leading_zeros() - 21;
            let frac_bits = (frac << shift) & 0x03FF;
            ((113 - shift) << 23) | (frac_bits << 13)
        }
        // Normal: rebias the exponent from 15 to 127.
        (_, _) => ((exp_field + 112) << 23) | (frac << 13),
    };

    f32::from_bits(sign | magnitude)
}

216impl Add for Half {
219 type Output = Self;
220 fn add(self, rhs: Self) -> Self {
221 Self::from_f32(self.to_f32() + rhs.to_f32())
222 }
223}
224
225impl Sub for Half {
226 type Output = Self;
227 fn sub(self, rhs: Self) -> Self {
228 Self::from_f32(self.to_f32() - rhs.to_f32())
229 }
230}
231
232impl Mul for Half {
233 type Output = Self;
234 fn mul(self, rhs: Self) -> Self {
235 Self::from_f32(self.to_f32() * rhs.to_f32())
236 }
237}
238
239impl Div for Half {
240 type Output = Self;
241 fn div(self, rhs: Self) -> Self {
242 Self::from_f32(self.to_f32() / rhs.to_f32())
243 }
244}
245
246impl Neg for Half {
247 type Output = Self;
248 fn neg(self) -> Self {
249 Self { bits: self.bits ^ 0x8000 }
250 }
251}
252
253impl PartialOrd for Half {
254 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
255 self.to_f32().partial_cmp(&other.to_f32())
256 }
257}
258
259impl Default for Half {
260 fn default() -> Self {
261 Self::ZERO
262 }
263}
264
265impl fmt::Debug for Half {
266 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
267 write!(f, "Half({})", self.to_f32())
268 }
269}
270
271impl fmt::Display for Half {
272 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
273 write!(f, "{}", self.to_f32())
274 }
275}
276
277impl From<f32> for Half {
278 fn from(v: f32) -> Self {
279 Self::from_f32(v)
280 }
281}
282
283impl From<f64> for Half {
284 fn from(v: f64) -> Self {
285 Self::from_f64(v)
286 }
287}
288
289impl From<Half> for f32 {
290 fn from(v: Half) -> Self {
291 v.to_f32()
292 }
293}
294
295impl From<Half> for f64 {
296 fn from(v: Half) -> Self {
297 v.to_f64()
298 }
299}
300
301pub fn f32_to_half_slice(src: &[f32]) -> Vec<Half> {
303 src.iter().map(|&v| Half::from_f32(v)).collect()
304}
305
306pub fn half_to_f32_slice(src: &[Half]) -> Vec<f32> {
308 src.iter().map(|v| v.to_f32()).collect()
309}
310
311pub fn half_dot(a: &[Half], b: &[Half]) -> Half {
313 let acc: f32 = a.iter()
314 .zip(b.iter())
315 .map(|(x, y)| x.to_f32() * y.to_f32())
316 .sum();
317 Half::from_f32(acc)
318}
319
320pub fn half_gemv(
322 m: usize,
323 n: usize,
324 alpha: Half,
325 a: &[Half], x: &[Half], beta: Half,
328 y: &mut [Half], ) {
330 let alpha_f = alpha.to_f32();
331 let beta_f = beta.to_f32();
332
333 for i in 0..m {
334 let mut sum: f32 = 0.0;
335 for j in 0..n {
336 sum += a[i * n + j].to_f32() * x[j].to_f32();
337 }
338 let result = alpha_f * sum + beta_f * y[i].to_f32();
339 y[i] = Half::from_f32(result);
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `Half` from an `f32`.
    fn h(v: f32) -> Half {
        Half::from_f32(v)
    }

    /// `true` when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f32, expected: f32, tol: f32) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn test_half_zero() {
        let zero = Half::ZERO;
        assert!(zero.is_zero());
        assert_eq!(zero.to_f32(), 0.0);
    }

    #[test]
    fn test_half_one() {
        assert_eq!(Half::ONE.to_f32(), 1.0);
    }

    #[test]
    fn test_half_roundtrip() {
        for v in [0.0f32, 1.0, -1.0, 0.5, 100.0, -100.0, 0.001] {
            let back = h(v).to_f32();
            assert!(close(back, v, 0.01), "Roundtrip failed for {}: got {}", v, back);
        }
    }

    #[test]
    fn test_half_infinity() {
        let (pos, neg) = (Half::INFINITY, Half::NEG_INFINITY);
        assert!(pos.is_infinite());
        assert!(!pos.is_finite());
        assert!(neg.is_infinite());
    }

    #[test]
    fn test_half_nan() {
        let nan = Half::NAN;
        assert!(nan.is_nan());
        assert!(!nan.is_finite());
        assert!(!nan.is_normal());
    }

    #[test]
    fn test_half_arithmetic() {
        let (two, three) = (h(2.0), h(3.0));
        assert_eq!((two + three).to_f32(), 5.0);
        assert_eq!((three - two).to_f32(), 1.0);
        assert_eq!((two * three).to_f32(), 6.0);
        assert!(close((three / two).to_f32(), 1.5, 0.01));
    }

    #[test]
    fn test_half_negation() {
        let five = h(5.0);
        let negated = -five;
        assert_eq!(negated.to_f32(), -5.0);
        assert_eq!((-negated).to_f32(), 5.0);
    }

    #[test]
    fn test_half_comparison() {
        let (small, big) = (h(1.0), h(2.0));
        assert!(small < big);
        assert!(big > small);
        assert!(small <= small);
        assert!(small >= small);
    }

    #[test]
    fn test_half_abs() {
        assert!(close(h(-3.5).abs().to_f32(), 3.5, 0.01));
    }

    #[test]
    fn test_half_fma() {
        let result = Half::fma(h(2.0), h(3.0), h(1.0));
        assert!(close(result.to_f32(), 7.0, 0.01));
    }

    #[test]
    fn test_half_sqrt() {
        assert!(close(h(4.0).sqrt().to_f32(), 2.0, 0.01));
    }

    #[test]
    fn test_half_min_max() {
        let (lo, hi) = (h(1.0), h(3.0));
        assert_eq!(lo.min(hi).to_f32(), 1.0);
        assert_eq!(lo.max(hi).to_f32(), 3.0);
    }

    #[test]
    fn test_half_clamp() {
        let clamped = h(5.0).clamp(h(0.0), h(3.0));
        assert_eq!(clamped.to_f32(), 3.0);
    }

    #[test]
    fn test_half_overflow() {
        assert!(h(100000.0).is_infinite());
    }

    #[test]
    fn test_half_underflow() {
        let tiny = h(1e-10);
        assert!(tiny.is_zero() || !tiny.is_normal());
    }

    #[test]
    fn test_f32_to_half_slice() {
        let src = vec![1.0f32, 2.0, 3.0];
        let back = half_to_f32_slice(&f32_to_half_slice(&src));
        assert_eq!(back, src);
    }

    #[test]
    fn test_half_dot_product() {
        let lhs = f32_to_half_slice(&[1.0, 2.0, 3.0]);
        let rhs = f32_to_half_slice(&[4.0, 5.0, 6.0]);
        assert!(close(half_dot(&lhs, &rhs).to_f32(), 32.0, 0.1));
    }

    #[test]
    fn test_half_gemv() {
        let matrix = f32_to_half_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let ones = f32_to_half_slice(&[1.0, 1.0, 1.0]);
        let mut out = f32_to_half_slice(&[0.0, 0.0]);

        half_gemv(2, 3, Half::ONE, &matrix, &ones, Half::ZERO, &mut out);

        assert!(close(out[0].to_f32(), 6.0, 0.1));
        assert!(close(out[1].to_f32(), 15.0, 0.1));
    }

    #[test]
    fn test_half_display() {
        let rendered = h(3.14).to_string();
        assert!(rendered.starts_with("3.1"));
    }

    #[test]
    fn test_half_from_f64() {
        let value = Half::from_f64(2.5).to_f64();
        assert!((value - 2.5).abs() < 0.01);
    }

    #[test]
    fn test_half_recip() {
        assert!(close(h(4.0).recip().to_f32(), 0.25, 0.01));
    }

    #[test]
    fn test_half_max_value() {
        let max = Half::MAX;
        assert!(close(max.to_f32(), 65504.0, 1.0));
        assert!(max.is_finite());
    }

    #[test]
    fn test_half_is_sign_negative() {
        assert!(!h(1.0).is_sign_negative());
        assert!(h(-1.0).is_sign_negative());
        assert!(!Half::ZERO.is_sign_negative());
    }
}