Skip to main content

cubecl_cpp/cuda/
convert.rs

1//! Cuda conversion functions
2#![allow(unused)]
3
4use core::fmt;
5
6use crate::{
7    Dialect,
8    shared::{Component, Elem, FP8Kind, FmtLeft, Instruction, Item, UnaryInstruction, Variable},
9};
10
11/// special cast function for recursive conversion in the case of minifloat to minifloat conversion
12///
13/// Needs to jump through a lot of hoops to deal with CUDA nonsense.
14/// The overview of available conversions is as follows:
15///
16/// | From                     | To             | Extra args                 |
17/// | ------------------------ | -------------- | -------------------------- |
18/// | f16/bf16/f32/f64         | e4m3/e5m2      | Interpretation, saturation |
19/// | f16/bf16/f32/f64         | e3m2/e2m3/e2m1 | Interpretation, rounding   |
20/// | bf16/f32/f64             | e8m0           | saturation, rounding       |
21/// | e4m3/e5m2/e3m2/e2m3/e2m1 | f16            | Interpretation,            |
22/// | e8m0                     | bf16           |                            |
23///
24/// When the input and output don't match these options, we need to do a two-step conversion.
25/// When the input is a minifloat we always need to cast out to `f16`/`bf16`, and then convert to
26/// the actual out type if it differs. Trying to cast ints also requires an extra conversion, and
27/// so does `f16` to `e8m0` (though it's not recommended to do that anyways, you should be using
28/// `e5m2` for that since you don't have 8 bits of exponent in f16).
29///
30/// See also:
31/// <https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH__FP8__MISC.html>
32/// <https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH__FP6__MISC.html>
33/// <https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH__FP4__MISC.html>
34pub(crate) fn special_cast<D: Dialect>(
35    f: &mut std::fmt::Formatter,
36    input: &Variable<D>,
37    out: &Variable<D>,
38) -> fmt::Result {
39    let mut current_in = *input;
40
41    if matches!(
42        input.elem().unpacked(),
43        Elem::FP4(_) | Elem::FP6(_) | Elem::FP8(_)
44    ) {
45        let mut item = out.item();
46        item.elem = match input.elem().unpacked() {
47            Elem::FP8(FP8Kind::UE8M0) => Elem::BF16,
48            _ => Elem::F16,
49        };
50        let out_var = if item == out.item() {
51            *out
52        } else {
53            Variable::tmp(item)
54        };
55        if item.elem == Elem::F16 {
56            cast_minifloat_to_half(f, current_in, out_var)?;
57        } else {
58            cast_scale_to_bfloat(f, current_in, out_var)?;
59        }
60        current_in = out_var;
61    }
62
63    // Broadcast scalars to packing factor
64    if out.item().packing_factor() > 1 && input.item().vectorization == 1 {
65        let tmp = Variable::tmp(Item {
66            elem: input.item().elem,
67            vectorization: out.item().packing_factor(),
68            native: input.item().native,
69        });
70        let assign = Instruction::Assign(UnaryInstruction {
71            input: current_in,
72            out: tmp,
73        });
74        writeln!(f, "{assign}")?;
75        current_in = tmp;
76    }
77
78    if matches!(
79        current_in.elem(),
80        Elem::U8
81            | Elem::U16
82            | Elem::U32
83            | Elem::U64
84            | Elem::I8
85            | Elem::I16
86            | Elem::I32
87            | Elem::I64
88            | Elem::Bool
89    ) {
90        // Precision is irrelevant for int, so use bf16 for the range
91        let tmp = Variable::tmp(Item {
92            elem: Elem::BF16,
93            vectorization: current_in.item().vectorization,
94            native: current_in.item().native,
95        });
96        let assign = Instruction::Assign(UnaryInstruction {
97            input: current_in,
98            out: tmp,
99        });
100        writeln!(f, "{assign}")?;
101        current_in = tmp;
102    }
103
104    if matches!(out.elem().unpacked(), Elem::FP4(_) | Elem::FP6(_)) {
105        return cast_to_fp4_fp6(f, current_in, *out);
106    }
107
108    if matches!(out.elem().unpacked(), Elem::FP8(FP8Kind::UE8M0)) {
109        // Scale can't be converted from half...
110        if matches!(current_in.elem(), Elem::F16) {
111            let mut item = current_in.item();
112            item.elem = Elem::BF16;
113            let tmp = Variable::tmp(item);
114            let assign = Instruction::Assign(UnaryInstruction {
115                input: current_in,
116                out: tmp,
117            });
118            writeln!(f, "{assign}")?;
119            current_in = tmp;
120        }
121        return cast_to_scale(f, current_in, *out);
122    }
123
124    if matches!(out.elem().unpacked(), Elem::FP8(_)) {
125        return cast_to_fp8(f, current_in, *out);
126    }
127
128    if current_in.item() != out.item() {
129        let assign = Instruction::Assign(UnaryInstruction {
130            input: current_in,
131            out: *out,
132        });
133        writeln!(f, "{assign}")?;
134    }
135
136    Ok(())
137}
138
139/// Convert any float to fp4/fp6, with round to nearest
140fn cast_to_fp4_fp6<D: Dialect>(
141    f: &mut fmt::Formatter,
142    input: Variable<D>,
143    out: Variable<D>,
144) -> fmt::Result {
145    let out_opt = out.optimized();
146    let packing = out_opt.item().packing_factor();
147    let packed = packing == 2;
148    let pack_suffix = if packed { "2" } else { "" };
149
150    let (out_ty, interpretation) = match out_opt.elem() {
151        Elem::FP4(kind) => ("fp4", format!("{kind:?}")),
152        Elem::FP4x2(kind) => ("fp4x2", format!("{kind:?}")),
153        Elem::FP6(kind) => ("fp6", format!("{kind:?}")),
154        Elem::FP6x2(kind) => ("fp6x2", format!("{kind:?}")),
155        _ => unreachable!("Must be fp4 or fp6"),
156    };
157
158    let in_ty = match input.elem().unpacked() {
159        Elem::F64 => format!("double{pack_suffix}"),
160        Elem::TF32 | Elem::F32 => format!("float{pack_suffix}"),
161        Elem::F16 => format!("halfraw{pack_suffix}"),
162        Elem::BF16 => format!("bfloat16raw{pack_suffix}"),
163        _ => unreachable!(),
164    };
165
166    let input = input.optimized();
167
168    handle_unroll(f, out, |f, i| {
169        let in_value = float_to_packed(input, i, packing);
170
171        write!(
172            f,
173            "__nv_cvt_{in_ty}_to_{out_ty}({in_value}, __NV_{interpretation}, cudaRoundNearest)",
174        )
175    })
176}
177
178/// Convert any float except f16 to e8m0
179fn cast_to_scale<D: Dialect>(
180    f: &mut fmt::Formatter,
181    input: Variable<D>,
182    out: Variable<D>,
183) -> fmt::Result {
184    let out_opt = out.optimized();
185    let packing = out_opt.item().packing_factor();
186    let packed = packing > 1;
187    let pack_suffix = if packed { "2" } else { "" };
188
189    let out_ty = match out_opt.elem() {
190        Elem::FP8(_) => "e8m0",
191        Elem::FP8x2(_) => "e8m0x2",
192        _ => unreachable!("Must be scale factor"),
193    };
194
195    let in_ty = match input.elem() {
196        Elem::F64 => format!("double{pack_suffix}"),
197        Elem::TF32 | Elem::F32 => format!("float{pack_suffix}"),
198        Elem::BF16 => format!("bfloat16{pack_suffix}raw"),
199        _ => unreachable!(),
200    };
201
202    let input = input.optimized();
203
204    handle_unroll(f, out, |f, i| {
205        let in_value = float_to_packed(input, i, packing);
206
207        write!(
208            f,
209            "__nv_cvt_{in_ty}_to_{out_ty}({in_value}, __NV_NOSAT, cudaRoundPosInf)",
210        )
211    })
212}
213
214/// Convert any float to fp8 (except e8m0)
215fn cast_to_fp8<D: Dialect>(
216    f: &mut fmt::Formatter,
217    input: Variable<D>,
218    out: Variable<D>,
219) -> fmt::Result {
220    let out_opt = out.optimized();
221    let packing = out_opt.item().packing_factor();
222    let packed = packing > 1;
223    let pack_suffix = if packed { "2" } else { "" };
224
225    let (out_ty, interpretation) = match out_opt.elem() {
226        Elem::FP8(kind) => ("fp8", format!("{kind:?}")),
227        Elem::FP8x2(kind) => ("fp8x2", format!("{kind:?}")),
228        _ => unreachable!("Must be fp8"),
229    };
230
231    let in_ty = match input.elem() {
232        Elem::F64 => format!("double{pack_suffix}"),
233        Elem::TF32 | Elem::F32 => format!("float{pack_suffix}"),
234        Elem::BF16 => format!("bfloat16raw{pack_suffix}"),
235        Elem::F16 => format!("halfraw{pack_suffix}"),
236        _ => unreachable!(),
237    };
238
239    let input = input.optimized();
240
241    handle_unroll(f, out, |f, i| {
242        let in_value = float_to_packed(input, i, packing);
243
244        write!(
245            f,
246            "__nv_cvt_{in_ty}_to_{out_ty}({in_value}, __NV_NOSAT, __NV_{interpretation})",
247        )
248    })
249}
250
251/// Pack types that normally wouldn't be optimized into a `vec2` for conversion
252fn float_to_packed<D: Dialect>(input: Variable<D>, i: usize, packing: usize) -> String {
253    match input.elem() {
254        Elem::TF32 | Elem::F32 => {
255            let i = i * packing;
256            if packing > 1 {
257                format!("float2 {{ {}, {} }}", input.index(i), input.index(i + 1))
258            } else {
259                format!("{}", input.index(i))
260            }
261        }
262        Elem::F64 => {
263            let i = i * packing;
264            if packing > 1 {
265                format!("double2 {{ {}, {} }}", input.index(i), input.index(i + 1))
266            } else {
267                format!("{}", input.index(i))
268            }
269        }
270        Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => format!("{}", input.index(i)),
271        _ => unreachable!(),
272    }
273}
274
275/// Convert any FP8/6/4 except e8m0 to half
276fn cast_minifloat_to_half<D: Dialect>(
277    f: &mut fmt::Formatter,
278    input: Variable<D>,
279    out: Variable<D>,
280) -> fmt::Result {
281    let in_opt = input.optimized();
282    let out_opt = out.optimized().item();
283
284    let (in_ty, interpretation) = match in_opt.elem() {
285        Elem::FP4(kind) => ("fp4", format!("{kind:?}")),
286        Elem::FP4x2(kind) => ("fp4x2", format!("{kind:?}")),
287        Elem::FP6(kind) => ("fp6", format!("{kind:?}")),
288        Elem::FP6x2(kind) => ("fp6x2", format!("{kind:?}")),
289        Elem::FP8(kind) => ("fp8", format!("{kind:?}")),
290        Elem::FP8x2(kind) => ("fp8x2", format!("{kind:?}")),
291        _ => unreachable!("can only cast minifloat"),
292    };
293
294    let out_ty = match out_opt.elem() {
295        Elem::F16 => "halfraw",
296        Elem::F16x2 => "halfraw2",
297        _ => unreachable!("out type must be half"),
298    };
299
300    handle_unroll(f, out, |f, i| {
301        let input = in_opt.index(i);
302        write!(
303            f,
304            "{}(__nv_cvt_{in_ty}_to_{out_ty}({input}, __NV_{interpretation}))",
305            out_opt.elem()
306        )
307    })
308}
309
310/// Convert an e8m0 scaling factor to bf16
311fn cast_scale_to_bfloat<D: Dialect>(
312    f: &mut fmt::Formatter,
313    input: Variable<D>,
314    out: Variable<D>,
315) -> fmt::Result {
316    let in_opt = input.optimized();
317    let out_opt = out.optimized().item();
318
319    let in_ty = match in_opt.elem() {
320        Elem::FP8(_) => "e8m0",
321        Elem::FP8x2(_) => "e8m0x2",
322        _ => unreachable!("must be scaling factor in e8m0 format"),
323    };
324
325    let out_ty = match out_opt.elem() {
326        Elem::BF16 => "bf16raw",
327        Elem::BF16x2 => "bf162raw",
328        _ => unreachable!("out type must be half"),
329    };
330
331    handle_unroll(f, out, |f, i| {
332        let input = in_opt.index(i);
333        write!(
334            f,
335            "{}(__nv_cvt_{in_ty}_to_{out_ty}({input}))",
336            out_opt.elem()
337        )
338    })
339}
340
341fn handle_unroll<D: Dialect>(
342    f: &mut fmt::Formatter,
343    out: Variable<D>,
344    mut op: impl FnMut(&mut fmt::Formatter, usize) -> fmt::Result,
345) -> fmt::Result {
346    let out_opt = out.item().optimized();
347    let vec = out_opt.vectorization;
348    let out_var = if out.item() != out_opt {
349        Variable::tmp(out_opt)
350    } else {
351        out
352    };
353    write!(f, "{} = ", out_var.fmt_left())?;
354    if vec > 1 {
355        writeln!(f, "{out_opt} {{")?;
356    }
357    for i in 0..vec {
358        op(f, i)?;
359        if i + 1 < vec {
360            f.write_str(",\n")?;
361        }
362    }
363    if vec > 1 {
364        write!(f, "\n}}")?;
365    }
366    f.write_str(";\n")?;
367
368    if out.item() != out_opt {
369        writeln!(
370            f,
371            "{} = reinterpret_cast<{}&>({out_var});",
372            out.fmt_left(),
373            out.item()
374        )?;
375    }
376    Ok(())
377}