1#![allow(unused)]
3
4use core::fmt;
5
6use crate::{
7 Dialect,
8 shared::{Component, Elem, FP8Kind, FmtLeft, Instruction, Item, UnaryInstruction, Variable},
9};
10
11pub(crate) fn special_cast<D: Dialect>(
35 f: &mut std::fmt::Formatter,
36 input: &Variable<D>,
37 out: &Variable<D>,
38) -> fmt::Result {
39 let mut current_in = *input;
40
41 if matches!(
42 input.elem().unpacked(),
43 Elem::FP4(_) | Elem::FP6(_) | Elem::FP8(_)
44 ) {
45 let mut item = out.item();
46 item.elem = match input.elem().unpacked() {
47 Elem::FP8(FP8Kind::UE8M0) => Elem::BF16,
48 _ => Elem::F16,
49 };
50 let out_var = if item == out.item() {
51 *out
52 } else {
53 Variable::tmp(item)
54 };
55 if item.elem == Elem::F16 {
56 cast_minifloat_to_half(f, current_in, out_var)?;
57 } else {
58 cast_scale_to_bfloat(f, current_in, out_var)?;
59 }
60 current_in = out_var;
61 }
62
63 if out.item().packing_factor() > 1 && input.item().vectorization == 1 {
65 let tmp = Variable::tmp(Item {
66 elem: input.item().elem,
67 vectorization: out.item().packing_factor(),
68 native: input.item().native,
69 });
70 let assign = Instruction::Assign(UnaryInstruction {
71 input: current_in,
72 out: tmp,
73 });
74 writeln!(f, "{assign}")?;
75 current_in = tmp;
76 }
77
78 if matches!(
79 current_in.elem(),
80 Elem::U8
81 | Elem::U16
82 | Elem::U32
83 | Elem::U64
84 | Elem::I8
85 | Elem::I16
86 | Elem::I32
87 | Elem::I64
88 | Elem::Bool
89 ) {
90 let tmp = Variable::tmp(Item {
92 elem: Elem::BF16,
93 vectorization: current_in.item().vectorization,
94 native: current_in.item().native,
95 });
96 let assign = Instruction::Assign(UnaryInstruction {
97 input: current_in,
98 out: tmp,
99 });
100 writeln!(f, "{assign}")?;
101 current_in = tmp;
102 }
103
104 if matches!(out.elem().unpacked(), Elem::FP4(_) | Elem::FP6(_)) {
105 return cast_to_fp4_fp6(f, current_in, *out);
106 }
107
108 if matches!(out.elem().unpacked(), Elem::FP8(FP8Kind::UE8M0)) {
109 if matches!(current_in.elem(), Elem::F16) {
111 let mut item = current_in.item();
112 item.elem = Elem::BF16;
113 let tmp = Variable::tmp(item);
114 let assign = Instruction::Assign(UnaryInstruction {
115 input: current_in,
116 out: tmp,
117 });
118 writeln!(f, "{assign}")?;
119 current_in = tmp;
120 }
121 return cast_to_scale(f, current_in, *out);
122 }
123
124 if matches!(out.elem().unpacked(), Elem::FP8(_)) {
125 return cast_to_fp8(f, current_in, *out);
126 }
127
128 if current_in.item() != out.item() {
129 let assign = Instruction::Assign(UnaryInstruction {
130 input: current_in,
131 out: *out,
132 });
133 writeln!(f, "{assign}")?;
134 }
135
136 Ok(())
137}
138
139fn cast_to_fp4_fp6<D: Dialect>(
141 f: &mut fmt::Formatter,
142 input: Variable<D>,
143 out: Variable<D>,
144) -> fmt::Result {
145 let out_opt = out.optimized();
146 let packing = out_opt.item().packing_factor();
147 let packed = packing == 2;
148 let pack_suffix = if packed { "2" } else { "" };
149
150 let (out_ty, interpretation) = match out_opt.elem() {
151 Elem::FP4(kind) => ("fp4", format!("{kind:?}")),
152 Elem::FP4x2(kind) => ("fp4x2", format!("{kind:?}")),
153 Elem::FP6(kind) => ("fp6", format!("{kind:?}")),
154 Elem::FP6x2(kind) => ("fp6x2", format!("{kind:?}")),
155 _ => unreachable!("Must be fp4 or fp6"),
156 };
157
158 let in_ty = match input.elem().unpacked() {
159 Elem::F64 => format!("double{pack_suffix}"),
160 Elem::TF32 | Elem::F32 => format!("float{pack_suffix}"),
161 Elem::F16 => format!("halfraw{pack_suffix}"),
162 Elem::BF16 => format!("bfloat16raw{pack_suffix}"),
163 _ => unreachable!(),
164 };
165
166 let input = input.optimized();
167
168 handle_unroll(f, out, |f, i| {
169 let in_value = float_to_packed(input, i, packing);
170
171 write!(
172 f,
173 "__nv_cvt_{in_ty}_to_{out_ty}({in_value}, __NV_{interpretation}, cudaRoundNearest)",
174 )
175 })
176}
177
178fn cast_to_scale<D: Dialect>(
180 f: &mut fmt::Formatter,
181 input: Variable<D>,
182 out: Variable<D>,
183) -> fmt::Result {
184 let out_opt = out.optimized();
185 let packing = out_opt.item().packing_factor();
186 let packed = packing > 1;
187 let pack_suffix = if packed { "2" } else { "" };
188
189 let out_ty = match out_opt.elem() {
190 Elem::FP8(_) => "e8m0",
191 Elem::FP8x2(_) => "e8m0x2",
192 _ => unreachable!("Must be scale factor"),
193 };
194
195 let in_ty = match input.elem() {
196 Elem::F64 => format!("double{pack_suffix}"),
197 Elem::TF32 | Elem::F32 => format!("float{pack_suffix}"),
198 Elem::BF16 => format!("bfloat16{pack_suffix}raw"),
199 _ => unreachable!(),
200 };
201
202 let input = input.optimized();
203
204 handle_unroll(f, out, |f, i| {
205 let in_value = float_to_packed(input, i, packing);
206
207 write!(
208 f,
209 "__nv_cvt_{in_ty}_to_{out_ty}({in_value}, __NV_NOSAT, cudaRoundPosInf)",
210 )
211 })
212}
213
214fn cast_to_fp8<D: Dialect>(
216 f: &mut fmt::Formatter,
217 input: Variable<D>,
218 out: Variable<D>,
219) -> fmt::Result {
220 let out_opt = out.optimized();
221 let packing = out_opt.item().packing_factor();
222 let packed = packing > 1;
223 let pack_suffix = if packed { "2" } else { "" };
224
225 let (out_ty, interpretation) = match out_opt.elem() {
226 Elem::FP8(kind) => ("fp8", format!("{kind:?}")),
227 Elem::FP8x2(kind) => ("fp8x2", format!("{kind:?}")),
228 _ => unreachable!("Must be fp8"),
229 };
230
231 let in_ty = match input.elem() {
232 Elem::F64 => format!("double{pack_suffix}"),
233 Elem::TF32 | Elem::F32 => format!("float{pack_suffix}"),
234 Elem::BF16 => format!("bfloat16raw{pack_suffix}"),
235 Elem::F16 => format!("halfraw{pack_suffix}"),
236 _ => unreachable!(),
237 };
238
239 let input = input.optimized();
240
241 handle_unroll(f, out, |f, i| {
242 let in_value = float_to_packed(input, i, packing);
243
244 write!(
245 f,
246 "__nv_cvt_{in_ty}_to_{out_ty}({in_value}, __NV_NOSAT, __NV_{interpretation})",
247 )
248 })
249}
250
251fn float_to_packed<D: Dialect>(input: Variable<D>, i: usize, packing: usize) -> String {
253 match input.elem() {
254 Elem::TF32 | Elem::F32 => {
255 let i = i * packing;
256 if packing > 1 {
257 format!("float2 {{ {}, {} }}", input.index(i), input.index(i + 1))
258 } else {
259 format!("{}", input.index(i))
260 }
261 }
262 Elem::F64 => {
263 let i = i * packing;
264 if packing > 1 {
265 format!("double2 {{ {}, {} }}", input.index(i), input.index(i + 1))
266 } else {
267 format!("{}", input.index(i))
268 }
269 }
270 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => format!("{}", input.index(i)),
271 _ => unreachable!(),
272 }
273}
274
275fn cast_minifloat_to_half<D: Dialect>(
277 f: &mut fmt::Formatter,
278 input: Variable<D>,
279 out: Variable<D>,
280) -> fmt::Result {
281 let in_opt = input.optimized();
282 let out_opt = out.optimized().item();
283
284 let (in_ty, interpretation) = match in_opt.elem() {
285 Elem::FP4(kind) => ("fp4", format!("{kind:?}")),
286 Elem::FP4x2(kind) => ("fp4x2", format!("{kind:?}")),
287 Elem::FP6(kind) => ("fp6", format!("{kind:?}")),
288 Elem::FP6x2(kind) => ("fp6x2", format!("{kind:?}")),
289 Elem::FP8(kind) => ("fp8", format!("{kind:?}")),
290 Elem::FP8x2(kind) => ("fp8x2", format!("{kind:?}")),
291 _ => unreachable!("can only cast minifloat"),
292 };
293
294 let out_ty = match out_opt.elem() {
295 Elem::F16 => "halfraw",
296 Elem::F16x2 => "halfraw2",
297 _ => unreachable!("out type must be half"),
298 };
299
300 handle_unroll(f, out, |f, i| {
301 let input = in_opt.index(i);
302 write!(
303 f,
304 "{}(__nv_cvt_{in_ty}_to_{out_ty}({input}, __NV_{interpretation}))",
305 out_opt.elem()
306 )
307 })
308}
309
310fn cast_scale_to_bfloat<D: Dialect>(
312 f: &mut fmt::Formatter,
313 input: Variable<D>,
314 out: Variable<D>,
315) -> fmt::Result {
316 let in_opt = input.optimized();
317 let out_opt = out.optimized().item();
318
319 let in_ty = match in_opt.elem() {
320 Elem::FP8(_) => "e8m0",
321 Elem::FP8x2(_) => "e8m0x2",
322 _ => unreachable!("must be scaling factor in e8m0 format"),
323 };
324
325 let out_ty = match out_opt.elem() {
326 Elem::BF16 => "bf16raw",
327 Elem::BF16x2 => "bf162raw",
328 _ => unreachable!("out type must be half"),
329 };
330
331 handle_unroll(f, out, |f, i| {
332 let input = in_opt.index(i);
333 write!(
334 f,
335 "{}(__nv_cvt_{in_ty}_to_{out_ty}({input}))",
336 out_opt.elem()
337 )
338 })
339}
340
341fn handle_unroll<D: Dialect>(
342 f: &mut fmt::Formatter,
343 out: Variable<D>,
344 mut op: impl FnMut(&mut fmt::Formatter, usize) -> fmt::Result,
345) -> fmt::Result {
346 let out_opt = out.item().optimized();
347 let vec = out_opt.vectorization;
348 let out_var = if out.item() != out_opt {
349 Variable::tmp(out_opt)
350 } else {
351 out
352 };
353 write!(f, "{} = ", out_var.fmt_left())?;
354 if vec > 1 {
355 writeln!(f, "{out_opt} {{")?;
356 }
357 for i in 0..vec {
358 op(f, i)?;
359 if i + 1 < vec {
360 f.write_str(",\n")?;
361 }
362 }
363 if vec > 1 {
364 write!(f, "\n}}")?;
365 }
366 f.write_str(";\n")?;
367
368 if out.item() != out_opt {
369 writeln!(
370 f,
371 "{} = reinterpret_cast<{}&>({out_var});",
372 out.fmt_left(),
373 out.item()
374 )?;
375 }
376 Ok(())
377}