cubecl_cpp/cuda/
convert.rs

1use core::fmt;
2
3use crate::{
4    Dialect,
5    shared::{Component, Elem, FP8Kind, FmtLeft, Instruction, Item, UnaryInstruction, Variable},
6};
7
8/// special cast function for recursive conversion in the case of minifloat to minifloat conversion
9///
10/// Needs to jump through a lot of hoops to deal with CUDA nonsense.
11/// The overview of available conversions is as follows:
12///
13/// | From                     | To             | Extra args                 |
14/// | ------------------------ | -------------- | -------------------------- |
15/// | f16/bf16/f32/f64         | e4m3/e5m2      | Interpretation, saturation |
16/// | f16/bf16/f32/f64         | e3m2/e2m3/e2m1 | Interpretation, rounding   |
17/// | bf16/f32/f64             | e8m0           | saturation, rounding       |
18/// | e4m3/e5m2/e3m2/e2m3/e2m1 | f16            | Interpretation,            |
19/// | e8m0                     | bf16           |                            |
20///
21/// When the input and output don't match these options, we need to do a two-step conversion.
22/// When the input is a minifloat we always need to cast out to `f16`/`bf16`, and then convert to
23/// the actual out type if it differs. Trying to cast ints also requires an extra conversion, and
24/// so does `f16` to `e8m0` (though it's not recommended to do that anyways, you should be using
25/// `e5m2` for that since you don't have 8 bits of exponent in f16).
26///
27/// See also:
28/// <https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH__FP8__MISC.html>
29/// <https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH__FP6__MISC.html>
30/// <https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH__FP4__MISC.html>
31pub(crate) fn special_cast<D: Dialect>(
32    f: &mut std::fmt::Formatter,
33    input: &Variable<D>,
34    out: &Variable<D>,
35) -> fmt::Result {
36    let mut current_in = *input;
37
38    if matches!(
39        input.elem().unpacked(),
40        Elem::FP4(_) | Elem::FP6(_) | Elem::FP8(_)
41    ) {
42        let mut item = out.item();
43        item.elem = match input.elem().unpacked() {
44            Elem::FP8(FP8Kind::UE8M0) => Elem::BF16,
45            _ => Elem::F16,
46        };
47        let out_var = if item == out.item() {
48            *out
49        } else {
50            Variable::tmp(item)
51        };
52        if item.elem == Elem::F16 {
53            cast_minifloat_to_half(f, current_in, out_var)?;
54        } else {
55            cast_scale_to_bfloat(f, current_in, out_var)?;
56        }
57        current_in = out_var;
58    }
59
60    // Broadcast scalars to packing factor
61    if out.item().packing_factor() > 1 && input.item().vectorization == 1 {
62        let tmp = Variable::tmp(Item {
63            elem: input.item().elem,
64            vectorization: out.item().packing_factor(),
65            native: input.item().native,
66        });
67        let assign = Instruction::Assign(UnaryInstruction {
68            input: current_in,
69            out: tmp,
70        });
71        writeln!(f, "{assign}")?;
72        current_in = tmp;
73    }
74
75    if matches!(
76        current_in.elem(),
77        Elem::U8
78            | Elem::U16
79            | Elem::U32
80            | Elem::U64
81            | Elem::I8
82            | Elem::I16
83            | Elem::I32
84            | Elem::I64
85            | Elem::Bool
86    ) {
87        // Precision is irrelevant for int, so use bf16 for the range
88        let tmp = Variable::tmp(Item {
89            elem: Elem::BF16,
90            vectorization: current_in.item().vectorization,
91            native: current_in.item().native,
92        });
93        let assign = Instruction::Assign(UnaryInstruction {
94            input: current_in,
95            out: tmp,
96        });
97        writeln!(f, "{assign}")?;
98        current_in = tmp;
99    }
100
101    if matches!(out.elem().unpacked(), Elem::FP4(_) | Elem::FP6(_)) {
102        return cast_to_fp4_fp6(f, current_in, *out);
103    }
104
105    if matches!(out.elem().unpacked(), Elem::FP8(FP8Kind::UE8M0)) {
106        // Scale can't be converted from half...
107        if matches!(current_in.elem(), Elem::F16) {
108            let mut item = current_in.item();
109            item.elem = Elem::BF16;
110            let tmp = Variable::tmp(item);
111            let assign = Instruction::Assign(UnaryInstruction {
112                input: current_in,
113                out: tmp,
114            });
115            writeln!(f, "{assign}")?;
116            current_in = tmp;
117        }
118        return cast_to_scale(f, current_in, *out);
119    }
120
121    if matches!(out.elem().unpacked(), Elem::FP8(_)) {
122        return cast_to_fp8(f, current_in, *out);
123    }
124
125    if current_in.item() != out.item() {
126        let assign = Instruction::Assign(UnaryInstruction {
127            input: current_in,
128            out: *out,
129        });
130        writeln!(f, "{assign}")?;
131    }
132
133    Ok(())
134}
135
136/// Convert any float to fp4/fp6, with round to nearest
137fn cast_to_fp4_fp6<D: Dialect>(
138    f: &mut fmt::Formatter,
139    input: Variable<D>,
140    out: Variable<D>,
141) -> fmt::Result {
142    let out_opt = out.optimized();
143    let packing = out_opt.item().packing_factor();
144    let packed = packing == 2;
145    let pack_suffix = if packed { "2" } else { "" };
146
147    let (out_ty, interpretation) = match out_opt.elem() {
148        Elem::FP4(kind) => ("fp4", format!("{kind:?}")),
149        Elem::FP4x2(kind) => ("fp4x2", format!("{kind:?}")),
150        Elem::FP6(kind) => ("fp6", format!("{kind:?}")),
151        Elem::FP6x2(kind) => ("fp6x2", format!("{kind:?}")),
152        _ => unreachable!("Must be fp4 or fp6"),
153    };
154
155    let in_ty = match input.elem().unpacked() {
156        Elem::F64 => format!("double{pack_suffix}"),
157        Elem::TF32 | Elem::F32 => format!("float{pack_suffix}"),
158        Elem::F16 => format!("halfraw{pack_suffix}"),
159        Elem::BF16 => format!("bfloat16raw{pack_suffix}"),
160        _ => unreachable!(),
161    };
162
163    let input = input.optimized();
164
165    handle_unroll(f, out, |f, i| {
166        let in_value = float_to_packed(input, i, packing);
167
168        write!(
169            f,
170            "__nv_cvt_{in_ty}_to_{out_ty}({in_value}, __NV_{interpretation}, cudaRoundNearest)",
171        )
172    })
173}
174
175/// Convert any float except f16 to e8m0
176fn cast_to_scale<D: Dialect>(
177    f: &mut fmt::Formatter,
178    input: Variable<D>,
179    out: Variable<D>,
180) -> fmt::Result {
181    let out_opt = out.optimized();
182    let packing = out_opt.item().packing_factor();
183    let packed = packing > 1;
184    let pack_suffix = if packed { "2" } else { "" };
185
186    let out_ty = match out_opt.elem() {
187        Elem::FP8(_) => "e8m0",
188        Elem::FP8x2(_) => "e8m0x2",
189        _ => unreachable!("Must be scale factor"),
190    };
191
192    let in_ty = match input.elem() {
193        Elem::F64 => format!("double{pack_suffix}"),
194        Elem::TF32 | Elem::F32 => format!("float{pack_suffix}"),
195        Elem::BF16 => format!("bfloat16{pack_suffix}raw"),
196        _ => unreachable!(),
197    };
198
199    let input = input.optimized();
200
201    handle_unroll(f, out, |f, i| {
202        let in_value = float_to_packed(input, i, packing);
203
204        write!(
205            f,
206            "__nv_cvt_{in_ty}_to_{out_ty}({in_value}, __NV_NOSAT, cudaRoundPosInf)",
207        )
208    })
209}
210
211/// Convert any float to fp8 (except e8m0)
212fn cast_to_fp8<D: Dialect>(
213    f: &mut fmt::Formatter,
214    input: Variable<D>,
215    out: Variable<D>,
216) -> fmt::Result {
217    let out_opt = out.optimized();
218    let packing = out_opt.item().packing_factor();
219    let packed = packing > 1;
220    let pack_suffix = if packed { "2" } else { "" };
221
222    let (out_ty, interpretation) = match out_opt.elem() {
223        Elem::FP8(kind) => ("fp8", format!("{kind:?}")),
224        Elem::FP8x2(kind) => ("fp8x2", format!("{kind:?}")),
225        _ => unreachable!("Must be fp8"),
226    };
227
228    let in_ty = match input.elem() {
229        Elem::F64 => format!("double{pack_suffix}"),
230        Elem::TF32 | Elem::F32 => format!("float{pack_suffix}"),
231        Elem::BF16 => format!("bfloat16raw{pack_suffix}"),
232        Elem::F16 => format!("halfraw{pack_suffix}"),
233        _ => unreachable!(),
234    };
235
236    let input = input.optimized();
237
238    handle_unroll(f, out, |f, i| {
239        let in_value = float_to_packed(input, i, packing);
240
241        write!(
242            f,
243            "__nv_cvt_{in_ty}_to_{out_ty}({in_value}, __NV_NOSAT, __NV_{interpretation})",
244        )
245    })
246}
247
248/// Pack types that normally wouldn't be optimized into a `vec2` for conversion
249fn float_to_packed<D: Dialect>(input: Variable<D>, i: usize, packing: usize) -> String {
250    match input.elem() {
251        Elem::TF32 | Elem::F32 => {
252            let i = i * packing;
253            if packing > 1 {
254                format!("float2 {{ {}, {} }}", input.index(i), input.index(i + 1))
255            } else {
256                format!("{}", input.index(i))
257            }
258        }
259        Elem::F64 => {
260            let i = i * packing;
261            if packing > 1 {
262                format!("double2 {{ {}, {} }}", input.index(i), input.index(i + 1))
263            } else {
264                format!("{}", input.index(i))
265            }
266        }
267        Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => format!("{}", input.index(i)),
268        _ => unreachable!(),
269    }
270}
271
272/// Convert any FP8/6/4 except e8m0 to half
273fn cast_minifloat_to_half<D: Dialect>(
274    f: &mut fmt::Formatter,
275    input: Variable<D>,
276    out: Variable<D>,
277) -> fmt::Result {
278    let in_opt = input.optimized();
279    let out_opt = out.optimized().item();
280
281    let (in_ty, interpretation) = match in_opt.elem() {
282        Elem::FP4(kind) => ("fp4", format!("{kind:?}")),
283        Elem::FP4x2(kind) => ("fp4x2", format!("{kind:?}")),
284        Elem::FP6(kind) => ("fp6", format!("{kind:?}")),
285        Elem::FP6x2(kind) => ("fp6x2", format!("{kind:?}")),
286        Elem::FP8(kind) => ("fp8", format!("{kind:?}")),
287        Elem::FP8x2(kind) => ("fp8x2", format!("{kind:?}")),
288        _ => unreachable!("can only cast minifloat"),
289    };
290
291    let out_ty = match out_opt.elem() {
292        Elem::F16 => "halfraw",
293        Elem::F16x2 => "halfraw2",
294        _ => unreachable!("out type must be half"),
295    };
296
297    handle_unroll(f, out, |f, i| {
298        let input = in_opt.index(i);
299        write!(
300            f,
301            "{}(__nv_cvt_{in_ty}_to_{out_ty}({input}, __NV_{interpretation}))",
302            out_opt.elem()
303        )
304    })
305}
306
307/// Convert an e8m0 scaling factor to bf16
308fn cast_scale_to_bfloat<D: Dialect>(
309    f: &mut fmt::Formatter,
310    input: Variable<D>,
311    out: Variable<D>,
312) -> fmt::Result {
313    let in_opt = input.optimized();
314    let out_opt = out.optimized().item();
315
316    let in_ty = match in_opt.elem() {
317        Elem::FP8(_) => "e8m0",
318        Elem::FP8x2(_) => "e8m0x2",
319        _ => unreachable!("must be scaling factor in e8m0 format"),
320    };
321
322    let out_ty = match out_opt.elem() {
323        Elem::BF16 => "bf16raw",
324        Elem::BF16x2 => "bf162raw",
325        _ => unreachable!("out type must be half"),
326    };
327
328    handle_unroll(f, out, |f, i| {
329        let input = in_opt.index(i);
330        write!(
331            f,
332            "{}(__nv_cvt_{in_ty}_to_{out_ty}({input}))",
333            out_opt.elem()
334        )
335    })
336}
337
338fn handle_unroll<D: Dialect>(
339    f: &mut fmt::Formatter,
340    out: Variable<D>,
341    mut op: impl FnMut(&mut fmt::Formatter, usize) -> fmt::Result,
342) -> fmt::Result {
343    let out_opt = out.item().optimized();
344    let vec = out_opt.vectorization;
345    let out_var = if out.item() != out_opt {
346        Variable::tmp(out_opt)
347    } else {
348        out
349    };
350    write!(f, "{} = ", out_var.fmt_left())?;
351    if vec > 1 {
352        writeln!(f, "{out_opt} {{")?;
353    }
354    for i in 0..vec {
355        op(f, i)?;
356        if i + 1 < vec {
357            f.write_str(",\n")?;
358        }
359    }
360    if vec > 1 {
361        write!(f, "\n}}")?;
362    }
363    f.write_str(";\n")?;
364
365    if out.item() != out_opt {
366        writeln!(
367            f,
368            "{} = reinterpret_cast<{}&>({out_var});",
369            out.fmt_left(),
370            out.item()
371        )?;
372    }
373    Ok(())
374}