1use core::fmt;
2
3use crate::{
4 Dialect,
5 shared::{Component, Elem, FP8Kind, FmtLeft, Instruction, Item, UnaryInstruction, Variable},
6};
7
8pub(crate) fn special_cast<D: Dialect>(
32 f: &mut std::fmt::Formatter,
33 input: &Variable<D>,
34 out: &Variable<D>,
35) -> fmt::Result {
36 let mut current_in = *input;
37
38 if matches!(
39 input.elem().unpacked(),
40 Elem::FP4(_) | Elem::FP6(_) | Elem::FP8(_)
41 ) {
42 let mut item = out.item();
43 item.elem = match input.elem().unpacked() {
44 Elem::FP8(FP8Kind::UE8M0) => Elem::BF16,
45 _ => Elem::F16,
46 };
47 let out_var = if item == out.item() {
48 *out
49 } else {
50 Variable::tmp(item)
51 };
52 if item.elem == Elem::F16 {
53 cast_minifloat_to_half(f, current_in, out_var)?;
54 } else {
55 cast_scale_to_bfloat(f, current_in, out_var)?;
56 }
57 current_in = out_var;
58 }
59
60 if out.item().packing_factor() > 1 && input.item().vectorization == 1 {
62 let tmp = Variable::tmp(Item {
63 elem: input.item().elem,
64 vectorization: out.item().packing_factor(),
65 native: input.item().native,
66 });
67 let assign = Instruction::Assign(UnaryInstruction {
68 input: current_in,
69 out: tmp,
70 });
71 writeln!(f, "{assign}")?;
72 current_in = tmp;
73 }
74
75 if matches!(
76 current_in.elem(),
77 Elem::U8
78 | Elem::U16
79 | Elem::U32
80 | Elem::U64
81 | Elem::I8
82 | Elem::I16
83 | Elem::I32
84 | Elem::I64
85 | Elem::Bool
86 ) {
87 let tmp = Variable::tmp(Item {
89 elem: Elem::BF16,
90 vectorization: current_in.item().vectorization,
91 native: current_in.item().native,
92 });
93 let assign = Instruction::Assign(UnaryInstruction {
94 input: current_in,
95 out: tmp,
96 });
97 writeln!(f, "{assign}")?;
98 current_in = tmp;
99 }
100
101 if matches!(out.elem().unpacked(), Elem::FP4(_) | Elem::FP6(_)) {
102 return cast_to_fp4_fp6(f, current_in, *out);
103 }
104
105 if matches!(out.elem().unpacked(), Elem::FP8(FP8Kind::UE8M0)) {
106 if matches!(current_in.elem(), Elem::F16) {
108 let mut item = current_in.item();
109 item.elem = Elem::BF16;
110 let tmp = Variable::tmp(item);
111 let assign = Instruction::Assign(UnaryInstruction {
112 input: current_in,
113 out: tmp,
114 });
115 writeln!(f, "{assign}")?;
116 current_in = tmp;
117 }
118 return cast_to_scale(f, current_in, *out);
119 }
120
121 if matches!(out.elem().unpacked(), Elem::FP8(_)) {
122 return cast_to_fp8(f, current_in, *out);
123 }
124
125 if current_in.item() != out.item() {
126 let assign = Instruction::Assign(UnaryInstruction {
127 input: current_in,
128 out: *out,
129 });
130 writeln!(f, "{assign}")?;
131 }
132
133 Ok(())
134}
135
136fn cast_to_fp4_fp6<D: Dialect>(
138 f: &mut fmt::Formatter,
139 input: Variable<D>,
140 out: Variable<D>,
141) -> fmt::Result {
142 let out_opt = out.optimized();
143 let packing = out_opt.item().packing_factor();
144 let packed = packing == 2;
145 let pack_suffix = if packed { "2" } else { "" };
146
147 let (out_ty, interpretation) = match out_opt.elem() {
148 Elem::FP4(kind) => ("fp4", format!("{kind:?}")),
149 Elem::FP4x2(kind) => ("fp4x2", format!("{kind:?}")),
150 Elem::FP6(kind) => ("fp6", format!("{kind:?}")),
151 Elem::FP6x2(kind) => ("fp6x2", format!("{kind:?}")),
152 _ => unreachable!("Must be fp4 or fp6"),
153 };
154
155 let in_ty = match input.elem().unpacked() {
156 Elem::F64 => format!("double{pack_suffix}"),
157 Elem::TF32 | Elem::F32 => format!("float{pack_suffix}"),
158 Elem::F16 => format!("halfraw{pack_suffix}"),
159 Elem::BF16 => format!("bfloat16raw{pack_suffix}"),
160 _ => unreachable!(),
161 };
162
163 let input = input.optimized();
164
165 handle_unroll(f, out, |f, i| {
166 let in_value = float_to_packed(input, i, packing);
167
168 write!(
169 f,
170 "__nv_cvt_{in_ty}_to_{out_ty}({in_value}, __NV_{interpretation}, cudaRoundNearest)",
171 )
172 })
173}
174
175fn cast_to_scale<D: Dialect>(
177 f: &mut fmt::Formatter,
178 input: Variable<D>,
179 out: Variable<D>,
180) -> fmt::Result {
181 let out_opt = out.optimized();
182 let packing = out_opt.item().packing_factor();
183 let packed = packing > 1;
184 let pack_suffix = if packed { "2" } else { "" };
185
186 let out_ty = match out_opt.elem() {
187 Elem::FP8(_) => "e8m0",
188 Elem::FP8x2(_) => "e8m0x2",
189 _ => unreachable!("Must be scale factor"),
190 };
191
192 let in_ty = match input.elem() {
193 Elem::F64 => format!("double{pack_suffix}"),
194 Elem::TF32 | Elem::F32 => format!("float{pack_suffix}"),
195 Elem::BF16 => format!("bfloat16{pack_suffix}raw"),
196 _ => unreachable!(),
197 };
198
199 let input = input.optimized();
200
201 handle_unroll(f, out, |f, i| {
202 let in_value = float_to_packed(input, i, packing);
203
204 write!(
205 f,
206 "__nv_cvt_{in_ty}_to_{out_ty}({in_value}, __NV_NOSAT, cudaRoundPosInf)",
207 )
208 })
209}
210
211fn cast_to_fp8<D: Dialect>(
213 f: &mut fmt::Formatter,
214 input: Variable<D>,
215 out: Variable<D>,
216) -> fmt::Result {
217 let out_opt = out.optimized();
218 let packing = out_opt.item().packing_factor();
219 let packed = packing > 1;
220 let pack_suffix = if packed { "2" } else { "" };
221
222 let (out_ty, interpretation) = match out_opt.elem() {
223 Elem::FP8(kind) => ("fp8", format!("{kind:?}")),
224 Elem::FP8x2(kind) => ("fp8x2", format!("{kind:?}")),
225 _ => unreachable!("Must be fp8"),
226 };
227
228 let in_ty = match input.elem() {
229 Elem::F64 => format!("double{pack_suffix}"),
230 Elem::TF32 | Elem::F32 => format!("float{pack_suffix}"),
231 Elem::BF16 => format!("bfloat16raw{pack_suffix}"),
232 Elem::F16 => format!("halfraw{pack_suffix}"),
233 _ => unreachable!(),
234 };
235
236 let input = input.optimized();
237
238 handle_unroll(f, out, |f, i| {
239 let in_value = float_to_packed(input, i, packing);
240
241 write!(
242 f,
243 "__nv_cvt_{in_ty}_to_{out_ty}({in_value}, __NV_NOSAT, __NV_{interpretation})",
244 )
245 })
246}
247
248fn float_to_packed<D: Dialect>(input: Variable<D>, i: usize, packing: usize) -> String {
250 match input.elem() {
251 Elem::TF32 | Elem::F32 => {
252 let i = i * packing;
253 if packing > 1 {
254 format!("float2 {{ {}, {} }}", input.index(i), input.index(i + 1))
255 } else {
256 format!("{}", input.index(i))
257 }
258 }
259 Elem::F64 => {
260 let i = i * packing;
261 if packing > 1 {
262 format!("double2 {{ {}, {} }}", input.index(i), input.index(i + 1))
263 } else {
264 format!("{}", input.index(i))
265 }
266 }
267 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => format!("{}", input.index(i)),
268 _ => unreachable!(),
269 }
270}
271
272fn cast_minifloat_to_half<D: Dialect>(
274 f: &mut fmt::Formatter,
275 input: Variable<D>,
276 out: Variable<D>,
277) -> fmt::Result {
278 let in_opt = input.optimized();
279 let out_opt = out.optimized().item();
280
281 let (in_ty, interpretation) = match in_opt.elem() {
282 Elem::FP4(kind) => ("fp4", format!("{kind:?}")),
283 Elem::FP4x2(kind) => ("fp4x2", format!("{kind:?}")),
284 Elem::FP6(kind) => ("fp6", format!("{kind:?}")),
285 Elem::FP6x2(kind) => ("fp6x2", format!("{kind:?}")),
286 Elem::FP8(kind) => ("fp8", format!("{kind:?}")),
287 Elem::FP8x2(kind) => ("fp8x2", format!("{kind:?}")),
288 _ => unreachable!("can only cast minifloat"),
289 };
290
291 let out_ty = match out_opt.elem() {
292 Elem::F16 => "halfraw",
293 Elem::F16x2 => "halfraw2",
294 _ => unreachable!("out type must be half"),
295 };
296
297 handle_unroll(f, out, |f, i| {
298 let input = in_opt.index(i);
299 write!(
300 f,
301 "{}(__nv_cvt_{in_ty}_to_{out_ty}({input}, __NV_{interpretation}))",
302 out_opt.elem()
303 )
304 })
305}
306
307fn cast_scale_to_bfloat<D: Dialect>(
309 f: &mut fmt::Formatter,
310 input: Variable<D>,
311 out: Variable<D>,
312) -> fmt::Result {
313 let in_opt = input.optimized();
314 let out_opt = out.optimized().item();
315
316 let in_ty = match in_opt.elem() {
317 Elem::FP8(_) => "e8m0",
318 Elem::FP8x2(_) => "e8m0x2",
319 _ => unreachable!("must be scaling factor in e8m0 format"),
320 };
321
322 let out_ty = match out_opt.elem() {
323 Elem::BF16 => "bf16raw",
324 Elem::BF16x2 => "bf162raw",
325 _ => unreachable!("out type must be half"),
326 };
327
328 handle_unroll(f, out, |f, i| {
329 let input = in_opt.index(i);
330 write!(
331 f,
332 "{}(__nv_cvt_{in_ty}_to_{out_ty}({input}))",
333 out_opt.elem()
334 )
335 })
336}
337
338fn handle_unroll<D: Dialect>(
339 f: &mut fmt::Formatter,
340 out: Variable<D>,
341 mut op: impl FnMut(&mut fmt::Formatter, usize) -> fmt::Result,
342) -> fmt::Result {
343 let out_opt = out.item().optimized();
344 let vec = out_opt.vectorization;
345 let out_var = if out.item() != out_opt {
346 Variable::tmp(out_opt)
347 } else {
348 out
349 };
350 write!(f, "{} = ", out_var.fmt_left())?;
351 if vec > 1 {
352 writeln!(f, "{out_opt} {{")?;
353 }
354 for i in 0..vec {
355 op(f, i)?;
356 if i + 1 < vec {
357 f.write_str(",\n")?;
358 }
359 }
360 if vec > 1 {
361 write!(f, "\n}}")?;
362 }
363 f.write_str(";\n")?;
364
365 if out.item() != out_opt {
366 writeln!(
367 f,
368 "{} = reinterpret_cast<{}&>({out_var});",
369 out.fmt_left(),
370 out.item()
371 )?;
372 }
373 Ok(())
374}