1use std::fmt;
12
13use super::writer::PtxWriter;
14use crate::ir::{PtxInstruction, PtxKernel, PtxModule, Register};
15use crate::types::PtxType;
16
/// Serializes an IR node as PTX source text through a [`PtxWriter`].
///
/// Implementors write their text via the writer's line/instruction helpers
/// and propagate any `fmt::Error` the writer reports.
pub trait Emit {
    /// Writes `self` to `w`, returning the first formatting error encountered.
    fn emit(&self, w: &mut PtxWriter) -> fmt::Result;
}
26
27impl Emit for PtxModule {
30 fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
31 w.raw_line(&format!(".version {}", self.version))?;
32 w.raw_line(&format!(".target {}", self.target))?;
33 w.raw_line(&format!(".address_size {}", self.address_size))?;
34 for kernel in &self.kernels {
35 w.blank()?;
36 kernel.emit(w)?;
37 }
38 Ok(())
39 }
40}
41
42impl Emit for PtxKernel {
45 fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
46 if self.params.is_empty() {
48 w.raw_line(&format!(".visible .entry {}()", self.name))?;
49 } else {
50 w.raw_line(&format!(".visible .entry {}(", self.name))?;
51 w.indent();
52 for (i, param) in self.params.iter().enumerate() {
53 let comma = if i < self.params.len() - 1 { "," } else { "" };
54 w.line(&format!("{}{}", param.ptx_decl(), comma))?;
55 }
56 w.dedent();
57 w.raw_line(")")?;
58 }
59
60 w.raw_line("{")?;
62 w.indent();
63
64 emit_reg_declarations(&self.registers, w)?;
66
67 for decl in &self.shared_decls {
69 w.line(&format!(
70 ".shared .align {} .b8 {}[{}];",
71 decl.align, decl.name, decl.size_bytes
72 ))?;
73 }
74
75 w.blank()?;
77
78 for instr in &self.body {
80 instr.emit(w)?;
81 }
82
83 w.dedent();
85 w.raw_line("}")?;
86 Ok(())
87 }
88}
89
90fn emit_reg_declarations(registers: &[Register], w: &mut PtxWriter) -> fmt::Result {
96 let mut max_idx: [Option<u32>; 7] = [None; 7];
98 let mut decl_types: [&str; 7] = [""; 7];
99
100 for reg in registers {
101 let ci = reg.kind.counter_index();
102 match max_idx[ci] {
103 None => {
104 max_idx[ci] = Some(reg.index);
105 decl_types[ci] = reg.ptx_type.reg_decl_type();
106 }
107 Some(prev) if reg.index > prev => {
108 max_idx[ci] = Some(reg.index);
109 }
110 _ => {}
111 }
112 }
113
114 let prefixes = ["%r", "%rd", "%f", "%fd", "%p", "%h", "%hb"];
116 for i in 0..7 {
117 if let Some(max) = max_idx[i] {
118 let count = max + 1;
119 w.line(&format!(
120 ".reg {} {}<{}>;",
121 decl_types[i], prefixes[i], count
122 ))?;
123 }
124 }
125 Ok(())
126}
127
128impl Emit for PtxInstruction {
131 fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
132 match self {
133 Self::Arith(op) => op.emit(w),
134 Self::Memory(op) => op.emit(w),
135 Self::Control(op) => op.emit(w),
136 Self::TensorCore(op) => op.emit(w),
137 Self::Mov { dst, src, ty } => {
138 let mnemonic = format!("mov{}", ty.ptx_suffix());
139 w.instruction(&mnemonic, &[dst as &dyn fmt::Display, src])
140 }
141 Self::Cvt {
142 dst,
143 src,
144 dst_ty,
145 src_ty,
146 } => {
147 let rounding = match (dst_ty, src_ty) {
152 (
158 PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64,
159 PtxType::S8 | PtxType::S32 | PtxType::U32 | PtxType::S64 | PtxType::U64,
160 ) => ".rn",
161 (
163 PtxType::S8 | PtxType::S32 | PtxType::U32 | PtxType::S64 | PtxType::U64,
164 PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64,
165 ) => ".rzi",
166 (
168 PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64,
169 PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64,
170 ) => ".rn",
171 _ => "",
173 };
174 let mnemonic = format!(
175 "cvt{rounding}{}{}",
176 dst_ty.ptx_suffix(),
177 src_ty.ptx_suffix()
178 );
179 w.instruction(&mnemonic, &[dst as &dyn fmt::Display, src])
180 }
181 Self::MovPack { dst, srcs, ty } => {
182 let joined = srcs
189 .iter()
190 .map(|r| format!("{r}"))
191 .collect::<Vec<_>>()
192 .join(",");
193 let src_list = format!("{{{joined}}}");
194 let bits = ty.size_bytes() * 8;
195 let mnemonic = format!("mov.b{bits}");
196 w.instruction(&mnemonic, &[dst as &dyn fmt::Display, &src_list])
197 }
198 Self::Label(name) => {
199 w.dedent();
202 w.raw_line(&format!("{name}:"))?;
203 w.indent();
204 Ok(())
205 }
206 Self::Comment(text) => w.line(&format!("// {text}")),
207 }
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{Operand, PtxParam, RegisterAllocator, SpecialReg};
    use crate::types::{PtxType, RegKind};

    /// Builds a register directly, bypassing the allocator, so a test can pick
    /// an exact kind, index and type.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    #[test]
    fn emit_mov_special_reg() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Mov {
            dst: reg(RegKind::R, 0, PtxType::U32),
            src: Operand::SpecialReg(SpecialReg::TidX),
            ty: PtxType::U32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " mov.u32 %r0, %tid.x;\n");
    }

    #[test]
    fn emit_mov_reg_to_reg() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Mov {
            dst: reg(RegKind::F, 1, PtxType::F32),
            src: Operand::Reg(reg(RegKind::F, 0, PtxType::F32)),
            ty: PtxType::F32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " mov.f32 %f1, %f0;\n");
    }

    #[test]
    fn emit_mov_shared_addr() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Mov {
            dst: reg(RegKind::R, 0, PtxType::U32),
            src: Operand::SharedAddr("sdata".to_string()),
            ty: PtxType::U32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " mov.u32 %r0, sdata;\n");
    }

    #[test]
    fn emit_cvt() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::F, 0, PtxType::F32),
            src: reg(RegKind::R, 0, PtxType::S32),
            dst_ty: PtxType::F32,
            src_ty: PtxType::S32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " cvt.rn.f32.s32 %f0, %r0;\n");
    }

    #[test]
    fn emit_cvt_float_to_int() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::R, 0, PtxType::U32),
            src: reg(RegKind::F, 0, PtxType::F32),
            dst_ty: PtxType::U32,
            src_ty: PtxType::F32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " cvt.rzi.u32.f32 %r0, %f0;\n");
    }

    #[test]
    fn emit_cvt_int_to_int() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::R, 0, PtxType::S32),
            src: reg(RegKind::R, 1, PtxType::U32),
            dst_ty: PtxType::S32,
            src_ty: PtxType::U32,
        };
        instr.emit(&mut w).unwrap();
        // Integer-to-integer conversions carry no rounding modifier.
        assert_eq!(w.finish(), " cvt.s32.u32 %r0, %r1;\n");
    }

    #[test]
    fn emit_cvt_f32_to_f16() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::H, 0, PtxType::F16),
            src: reg(RegKind::F, 0, PtxType::F32),
            dst_ty: PtxType::F16,
            src_ty: PtxType::F32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " cvt.rn.f16.f32 %h0, %f0;\n");
    }

    #[test]
    fn emit_cvt_f16_to_f32() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::F, 0, PtxType::F32),
            src: reg(RegKind::H, 0, PtxType::F16),
            dst_ty: PtxType::F32,
            src_ty: PtxType::F16,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " cvt.rn.f32.f16 %f0, %h0;\n");
    }

    #[test]
    fn emit_cvt_int_to_f16() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::H, 0, PtxType::F16),
            src: reg(RegKind::R, 0, PtxType::S32),
            dst_ty: PtxType::F16,
            src_ty: PtxType::S32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " cvt.rn.f16.s32 %h0, %r0;\n");
    }

    #[test]
    fn emit_cvt_f16_to_int() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::R, 0, PtxType::U32),
            src: reg(RegKind::H, 0, PtxType::F16),
            dst_ty: PtxType::U32,
            src_ty: PtxType::F16,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " cvt.rzi.u32.f16 %r0, %h0;\n");
    }

    #[test]
    fn emit_cvt_bf16_to_f32() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::F, 0, PtxType::F32),
            src: reg(RegKind::Hb, 0, PtxType::BF16),
            dst_ty: PtxType::F32,
            src_ty: PtxType::BF16,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " cvt.rn.f32.bf16 %f0, %hb0;\n");
    }

    #[test]
    fn emit_reg_declarations_with_f16() {
        let regs = vec![
            reg(RegKind::F, 0, PtxType::F32),
            reg(RegKind::H, 0, PtxType::F16),
            reg(RegKind::H, 1, PtxType::F16),
            reg(RegKind::Hb, 0, PtxType::BF16),
        ];
        let mut w = PtxWriter::new();
        w.indent();
        emit_reg_declarations(&regs, &mut w).unwrap();
        let output = w.finish();
        assert!(output.contains(".reg .f32 %f<1>;"));
        assert!(output.contains(".reg .f16 %h<2>;"));
        assert!(output.contains(".reg .bf16 %hb<1>;"));
    }

    #[test]
    fn emit_label_at_column_zero() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Label("EXIT".to_string());
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "EXIT:\n");
    }

    #[test]
    fn emit_comment() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Comment("bounds check".to_string());
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " // bounds check\n");
    }

    #[test]
    fn emit_mov_pack_two_f16_into_b32() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::MovPack {
            dst: reg(RegKind::R, 7, PtxType::U32),
            srcs: vec![
                reg(RegKind::H, 3, PtxType::F16),
                reg(RegKind::H, 4, PtxType::F16),
            ],
            ty: PtxType::U32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " mov.b32 %r7, {%h3,%h4};\n");
    }

    #[test]
    fn emit_kernel_f16_flow() {
        use crate::instr::{ArithOp, MemoryOp};

        // The assertions below rely on the allocator numbering each register
        // kind sequentially in allocation order (e.g. `rd_addr_out` is %rd4,
        // `f_sum` is %f2, `h_out` is %h1).
        let mut alloc = RegisterAllocator::new();
        let rd_in = alloc.alloc(PtxType::U64);
        let rd_out = alloc.alloc(PtxType::U64);
        let r_tid = alloc.alloc(PtxType::U32);
        let rd_off = alloc.alloc(PtxType::U64);
        let rd_addr_in = alloc.alloc(PtxType::U64);
        let rd_addr_out = alloc.alloc(PtxType::U64);
        let h_val = alloc.alloc(PtxType::F16);
        let f_val = alloc.alloc(PtxType::F32);
        let f_one = alloc.alloc(PtxType::F32);
        let f_sum = alloc.alloc(PtxType::F32);
        let h_out = alloc.alloc(PtxType::F16);

        let mut kernel = PtxKernel::new("f16_add_one");
        kernel.add_param(PtxParam::pointer("in_ptr", PtxType::F16));
        kernel.add_param(PtxParam::pointer("out_ptr", PtxType::F16));

        kernel.push(PtxInstruction::Memory(MemoryOp::LdParam {
            dst: rd_in,
            param_name: "in_ptr".to_string(),
            ty: PtxType::U64,
        }));
        kernel.push(PtxInstruction::Memory(MemoryOp::LdParam {
            dst: rd_out,
            param_name: "out_ptr".to_string(),
            ty: PtxType::U64,
        }));
        kernel.push(PtxInstruction::Mov {
            dst: r_tid,
            src: Operand::SpecialReg(SpecialReg::TidX),
            ty: PtxType::U32,
        });
        // NOTE: the byte offset is the raw thread index and is not scaled by the
        // 2-byte f16 element size. This test checks emitted text only, not
        // the runtime correctness of the kernel.
        kernel.push(PtxInstruction::Cvt {
            dst: rd_off,
            src: r_tid,
            dst_ty: PtxType::U64,
            src_ty: PtxType::U32,
        });
        kernel.push(PtxInstruction::Arith(ArithOp::Add {
            dst: rd_addr_in,
            lhs: Operand::Reg(rd_in),
            rhs: Operand::Reg(rd_off),
            ty: PtxType::U64,
        }));
        kernel.push(PtxInstruction::Memory(MemoryOp::LdGlobal {
            dst: h_val,
            addr: rd_addr_in,
            ty: PtxType::F16,
        }));
        kernel.push(PtxInstruction::Cvt {
            dst: f_val,
            src: h_val,
            dst_ty: PtxType::F32,
            src_ty: PtxType::F16,
        });
        kernel.push(PtxInstruction::Mov {
            dst: f_one,
            src: Operand::ImmF32(1.0),
            ty: PtxType::F32,
        });
        kernel.push(PtxInstruction::Arith(ArithOp::Add {
            dst: f_sum,
            lhs: Operand::Reg(f_val),
            rhs: Operand::Reg(f_one),
            ty: PtxType::F32,
        }));
        kernel.push(PtxInstruction::Cvt {
            dst: h_out,
            src: f_sum,
            dst_ty: PtxType::F16,
            src_ty: PtxType::F32,
        });
        kernel.push(PtxInstruction::Arith(ArithOp::Add {
            dst: rd_addr_out,
            lhs: Operand::Reg(rd_out),
            rhs: Operand::Reg(rd_off),
            ty: PtxType::U64,
        }));
        kernel.push(PtxInstruction::Memory(MemoryOp::StGlobal {
            addr: rd_addr_out,
            src: h_out,
            ty: PtxType::F16,
        }));
        kernel.push(PtxInstruction::Control(crate::instr::ControlOp::Ret));
        kernel.set_registers(alloc.into_allocated());

        let mut w = PtxWriter::new();
        kernel.emit(&mut w).unwrap();
        let output = w.finish();

        assert!(output.contains(".param .u64 in_ptr"));
        assert!(output.contains(".param .u64 out_ptr"));
        assert!(output.contains(".reg .f16 %h<2>;"), "f16 reg declarations");
        assert!(output.contains(".reg .f32 %f<3>;"), "f32 reg declarations");
        assert!(output.contains("ld.global.b16 %h0"));
        assert!(output.contains("cvt.rn.f32.f16 %f0, %h0"));
        assert!(output.contains("cvt.rn.f16.f32 %h1, %f2"));
        assert!(output.contains("st.global.b16 [%rd4], %h1"));
    }

    #[test]
    fn emit_module_header() {
        let module = PtxModule::new("sm_70");
        let mut w = PtxWriter::new();
        module.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            ".version 8.7\n.target sm_70\n.address_size 64\n"
        );
    }

    #[test]
    fn emit_kernel_minimal() {
        let mut alloc = RegisterAllocator::new();
        let r0 = alloc.alloc(PtxType::U32);

        let mut kernel = PtxKernel::new("test_kernel");
        kernel.add_param(PtxParam::scalar("n", PtxType::U32));
        kernel.push(PtxInstruction::Mov {
            dst: r0,
            src: Operand::ImmU32(42),
            ty: PtxType::U32,
        });
        kernel.push(PtxInstruction::Control(crate::instr::ControlOp::Ret));
        kernel.set_registers(alloc.into_allocated());

        let mut w = PtxWriter::new();
        kernel.emit(&mut w).unwrap();
        let output = w.finish();

        assert!(output.contains(".visible .entry test_kernel("));
        assert!(output.contains(".param .u32 n"));
        assert!(output.contains(".reg .b32 %r<1>;"));
        assert!(output.contains("mov.u32 %r0, 42;"));
        assert!(output.contains("ret;"));
        assert!(output.starts_with(".visible .entry"));
        assert!(output.trim_end().ends_with('}'));
    }
}