1use std::fmt;
12
13use super::writer::PtxWriter;
14use crate::ir::{PtxInstruction, PtxKernel, PtxModule, Register};
15use crate::types::PtxType;
16
17pub trait Emit {
23 fn emit(&self, w: &mut PtxWriter) -> fmt::Result;
25}
26
27impl Emit for PtxModule {
30 fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
31 w.raw_line(&format!(".version {}", self.version))?;
32 w.raw_line(&format!(".target {}", self.target))?;
33 w.raw_line(&format!(".address_size {}", self.address_size))?;
34 for kernel in &self.kernels {
35 w.blank()?;
36 kernel.emit(w)?;
37 }
38 Ok(())
39 }
40}
41
42impl Emit for PtxKernel {
45 fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
46 if self.params.is_empty() {
48 w.raw_line(&format!(".visible .entry {}()", self.name))?;
49 } else {
50 w.raw_line(&format!(".visible .entry {}(", self.name))?;
51 w.indent();
52 for (i, param) in self.params.iter().enumerate() {
53 let comma = if i < self.params.len() - 1 { "," } else { "" };
54 w.line(&format!("{}{}", param.ptx_decl(), comma))?;
55 }
56 w.dedent();
57 w.raw_line(")")?;
58 }
59
60 w.raw_line("{")?;
62 w.indent();
63
64 emit_reg_declarations(&self.registers, w)?;
66
67 for decl in &self.shared_decls {
69 w.line(&format!(
70 ".shared .align {} .b8 {}[{}];",
71 decl.align, decl.name, decl.size_bytes
72 ))?;
73 }
74
75 w.blank()?;
77
78 for instr in &self.body {
80 instr.emit(w)?;
81 }
82
83 w.dedent();
85 w.raw_line("}")?;
86 Ok(())
87 }
88}
89
90fn emit_reg_declarations(registers: &[Register], w: &mut PtxWriter) -> fmt::Result {
96 let mut max_idx: [Option<u32>; 7] = [None; 7];
98 let mut decl_types: [&str; 7] = [""; 7];
99
100 for reg in registers {
101 let ci = reg.kind.counter_index();
102 match max_idx[ci] {
103 None => {
104 max_idx[ci] = Some(reg.index);
105 decl_types[ci] = reg.ptx_type.reg_decl_type();
106 }
107 Some(prev) if reg.index > prev => {
108 max_idx[ci] = Some(reg.index);
109 }
110 _ => {}
111 }
112 }
113
114 let prefixes = ["%r", "%rd", "%f", "%fd", "%p", "%h", "%hb"];
116 for i in 0..7 {
117 if let Some(max) = max_idx[i] {
118 let count = max + 1;
119 w.line(&format!(
120 ".reg {} {}<{}>;",
121 decl_types[i], prefixes[i], count
122 ))?;
123 }
124 }
125 Ok(())
126}
127
128impl Emit for PtxInstruction {
131 fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
132 match self {
133 Self::Arith(op) => op.emit(w),
134 Self::Memory(op) => op.emit(w),
135 Self::Control(op) => op.emit(w),
136 Self::TensorCore(op) => op.emit(w),
137 Self::Mov { dst, src, ty } => {
138 let mnemonic = format!("mov{}", ty.ptx_suffix());
139 w.instruction(&mnemonic, &[dst as &dyn fmt::Display, src])
140 }
141 Self::Cvt {
142 dst,
143 src,
144 dst_ty,
145 src_ty,
146 } => {
147 let rounding = match (dst_ty, src_ty) {
152 (
154 PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64,
155 PtxType::S32 | PtxType::U32 | PtxType::S64 | PtxType::U64,
156 ) => ".rn",
157 (
159 PtxType::S32 | PtxType::U32 | PtxType::S64 | PtxType::U64,
160 PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64,
161 ) => ".rzi",
162 (
164 PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64,
165 PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64,
166 ) => ".rn",
167 _ => "",
169 };
170 let mnemonic = format!(
171 "cvt{rounding}{}{}",
172 dst_ty.ptx_suffix(),
173 src_ty.ptx_suffix()
174 );
175 w.instruction(&mnemonic, &[dst as &dyn fmt::Display, src])
176 }
177 Self::Label(name) => {
178 w.dedent();
181 w.raw_line(&format!("{name}:"))?;
182 w.indent();
183 Ok(())
184 }
185 Self::Comment(text) => w.line(&format!("// {text}")),
186 }
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{Operand, PtxParam, RegisterAllocator, SpecialReg};
    use crate::types::{PtxType, RegKind};

    /// Builds a `Register` directly from its three fields.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    /// Emits `instr` into a fresh writer that has been indented once and
    /// asserts that the finished text equals `expected`.
    ///
    /// `#[track_caller]` makes a failure point at the calling test rather
    /// than at this helper.
    #[track_caller]
    fn assert_emits(instr: PtxInstruction, expected: &str) {
        let mut w = PtxWriter::new();
        w.indent();
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), expected);
    }

    #[test]
    fn emit_mov_special_reg() {
        assert_emits(
            PtxInstruction::Mov {
                dst: reg(RegKind::R, 0, PtxType::U32),
                src: Operand::SpecialReg(SpecialReg::TidX),
                ty: PtxType::U32,
            },
            " mov.u32 %r0, %tid.x;\n",
        );
    }

    #[test]
    fn emit_mov_reg_to_reg() {
        assert_emits(
            PtxInstruction::Mov {
                dst: reg(RegKind::F, 1, PtxType::F32),
                src: Operand::Reg(reg(RegKind::F, 0, PtxType::F32)),
                ty: PtxType::F32,
            },
            " mov.f32 %f1, %f0;\n",
        );
    }

    #[test]
    fn emit_mov_shared_addr() {
        assert_emits(
            PtxInstruction::Mov {
                dst: reg(RegKind::R, 0, PtxType::U32),
                src: Operand::SharedAddr("sdata".to_string()),
                ty: PtxType::U32,
            },
            " mov.u32 %r0, sdata;\n",
        );
    }

    #[test]
    fn emit_cvt() {
        assert_emits(
            PtxInstruction::Cvt {
                dst: reg(RegKind::F, 0, PtxType::F32),
                src: reg(RegKind::R, 0, PtxType::S32),
                dst_ty: PtxType::F32,
                src_ty: PtxType::S32,
            },
            " cvt.rn.f32.s32 %f0, %r0;\n",
        );
    }

    #[test]
    fn emit_cvt_float_to_int() {
        assert_emits(
            PtxInstruction::Cvt {
                dst: reg(RegKind::R, 0, PtxType::U32),
                src: reg(RegKind::F, 0, PtxType::F32),
                dst_ty: PtxType::U32,
                src_ty: PtxType::F32,
            },
            " cvt.rzi.u32.f32 %r0, %f0;\n",
        );
    }

    #[test]
    fn emit_cvt_int_to_int() {
        // Integer-to-integer conversions carry no rounding modifier.
        assert_emits(
            PtxInstruction::Cvt {
                dst: reg(RegKind::R, 0, PtxType::S32),
                src: reg(RegKind::R, 1, PtxType::U32),
                dst_ty: PtxType::S32,
                src_ty: PtxType::U32,
            },
            " cvt.s32.u32 %r0, %r1;\n",
        );
    }

    #[test]
    fn emit_cvt_f32_to_f16() {
        assert_emits(
            PtxInstruction::Cvt {
                dst: reg(RegKind::H, 0, PtxType::F16),
                src: reg(RegKind::F, 0, PtxType::F32),
                dst_ty: PtxType::F16,
                src_ty: PtxType::F32,
            },
            " cvt.rn.f16.f32 %h0, %f0;\n",
        );
    }

    #[test]
    fn emit_cvt_f16_to_f32() {
        assert_emits(
            PtxInstruction::Cvt {
                dst: reg(RegKind::F, 0, PtxType::F32),
                src: reg(RegKind::H, 0, PtxType::F16),
                dst_ty: PtxType::F32,
                src_ty: PtxType::F16,
            },
            " cvt.rn.f32.f16 %f0, %h0;\n",
        );
    }

    #[test]
    fn emit_cvt_int_to_f16() {
        assert_emits(
            PtxInstruction::Cvt {
                dst: reg(RegKind::H, 0, PtxType::F16),
                src: reg(RegKind::R, 0, PtxType::S32),
                dst_ty: PtxType::F16,
                src_ty: PtxType::S32,
            },
            " cvt.rn.f16.s32 %h0, %r0;\n",
        );
    }

    #[test]
    fn emit_cvt_f16_to_int() {
        assert_emits(
            PtxInstruction::Cvt {
                dst: reg(RegKind::R, 0, PtxType::U32),
                src: reg(RegKind::H, 0, PtxType::F16),
                dst_ty: PtxType::U32,
                src_ty: PtxType::F16,
            },
            " cvt.rzi.u32.f16 %r0, %h0;\n",
        );
    }

    #[test]
    fn emit_cvt_bf16_to_f32() {
        assert_emits(
            PtxInstruction::Cvt {
                dst: reg(RegKind::F, 0, PtxType::F32),
                src: reg(RegKind::Hb, 0, PtxType::BF16),
                dst_ty: PtxType::F32,
                src_ty: PtxType::BF16,
            },
            " cvt.rn.f32.bf16 %f0, %hb0;\n",
        );
    }

    #[test]
    fn emit_reg_declarations_with_f16() {
        let regs = [
            reg(RegKind::F, 0, PtxType::F32),
            reg(RegKind::H, 0, PtxType::F16),
            reg(RegKind::H, 1, PtxType::F16),
            reg(RegKind::Hb, 0, PtxType::BF16),
        ];
        let mut w = PtxWriter::new();
        w.indent();
        emit_reg_declarations(&regs, &mut w).unwrap();
        let output = w.finish();
        assert!(output.contains(".reg .f32 %f<1>;"));
        assert!(output.contains(".reg .f16 %h<2>;"));
        assert!(output.contains(".reg .bf16 %hb<1>;"));
    }

    #[test]
    fn emit_label_at_column_zero() {
        // The writer is indented once before emission; the label must still
        // come out with no leading whitespace.
        assert_emits(PtxInstruction::Label("EXIT".to_string()), "EXIT:\n");
    }

    #[test]
    fn emit_comment() {
        assert_emits(
            PtxInstruction::Comment("bounds check".to_string()),
            " // bounds check\n",
        );
    }

    /// Builds a whole kernel that loads an f16, converts it to f32, adds a
    /// constant, converts back and stores it, then checks selected lines of
    /// the emitted text.
    #[test]
    fn emit_kernel_f16_flow() {
        use crate::instr::{ArithOp, MemoryOp};

        let mut alloc = RegisterAllocator::new();
        let rd_in = alloc.alloc(PtxType::U64);
        let rd_out = alloc.alloc(PtxType::U64);
        let r_tid = alloc.alloc(PtxType::U32);
        let rd_off = alloc.alloc(PtxType::U64);
        let rd_addr_in = alloc.alloc(PtxType::U64);
        let rd_addr_out = alloc.alloc(PtxType::U64);
        let h_val = alloc.alloc(PtxType::F16);
        let f_val = alloc.alloc(PtxType::F32);
        let f_one = alloc.alloc(PtxType::F32);
        let f_sum = alloc.alloc(PtxType::F32);
        let h_out = alloc.alloc(PtxType::F16);

        let mut kernel = PtxKernel::new("f16_add_one");
        kernel.add_param(PtxParam::pointer("in_ptr", PtxType::F16));
        kernel.add_param(PtxParam::pointer("out_ptr", PtxType::F16));

        // Load both pointer parameters.
        kernel.push(PtxInstruction::Memory(MemoryOp::LdParam {
            dst: rd_in,
            param_name: "in_ptr".to_string(),
            ty: PtxType::U64,
        }));
        kernel.push(PtxInstruction::Memory(MemoryOp::LdParam {
            dst: rd_out,
            param_name: "out_ptr".to_string(),
            ty: PtxType::U64,
        }));
        // Read %tid.x and widen it to 64 bits.
        kernel.push(PtxInstruction::Mov {
            dst: r_tid,
            src: Operand::SpecialReg(SpecialReg::TidX),
            ty: PtxType::U32,
        });
        kernel.push(PtxInstruction::Cvt {
            dst: rd_off,
            src: r_tid,
            dst_ty: PtxType::U64,
            src_ty: PtxType::U32,
        });
        // Input address = in_ptr + offset.
        kernel.push(PtxInstruction::Arith(ArithOp::Add {
            dst: rd_addr_in,
            lhs: Operand::Reg(rd_in),
            rhs: Operand::Reg(rd_off),
            ty: PtxType::U64,
        }));
        // Load the f16 value and convert it to f32.
        kernel.push(PtxInstruction::Memory(MemoryOp::LdGlobal {
            dst: h_val,
            addr: rd_addr_in,
            ty: PtxType::F16,
        }));
        kernel.push(PtxInstruction::Cvt {
            dst: f_val,
            src: h_val,
            dst_ty: PtxType::F32,
            src_ty: PtxType::F16,
        });
        // Add 1.0 in f32.
        kernel.push(PtxInstruction::Mov {
            dst: f_one,
            src: Operand::ImmF32(1.0),
            ty: PtxType::F32,
        });
        kernel.push(PtxInstruction::Arith(ArithOp::Add {
            dst: f_sum,
            lhs: Operand::Reg(f_val),
            rhs: Operand::Reg(f_one),
            ty: PtxType::F32,
        }));
        // Convert the sum back to f16.
        kernel.push(PtxInstruction::Cvt {
            dst: h_out,
            src: f_sum,
            dst_ty: PtxType::F16,
            src_ty: PtxType::F32,
        });
        // Output address = out_ptr + offset, then store.
        kernel.push(PtxInstruction::Arith(ArithOp::Add {
            dst: rd_addr_out,
            lhs: Operand::Reg(rd_out),
            rhs: Operand::Reg(rd_off),
            ty: PtxType::U64,
        }));
        kernel.push(PtxInstruction::Memory(MemoryOp::StGlobal {
            addr: rd_addr_out,
            src: h_out,
            ty: PtxType::F16,
        }));
        kernel.push(PtxInstruction::Control(crate::instr::ControlOp::Ret));
        kernel.set_registers(alloc.into_allocated());

        let mut w = PtxWriter::new();
        kernel.emit(&mut w).unwrap();
        let output = w.finish();

        assert!(output.contains(".param .u64 in_ptr"));
        assert!(output.contains(".param .u64 out_ptr"));
        assert!(output.contains(".reg .f16 %h<2>;"), "f16 reg declarations");
        assert!(output.contains(".reg .f32 %f<3>;"), "f32 reg declarations");
        assert!(output.contains("ld.global.b16 %h0"));
        assert!(output.contains("cvt.rn.f32.f16 %f0, %h0"));
        assert!(output.contains("cvt.rn.f16.f32 %h1, %f2"));
        assert!(output.contains("st.global.b16 [%rd4], %h1"));
    }

    #[test]
    fn emit_module_header() {
        // A module with no kernels emits only the three header directives.
        let module = PtxModule::new("sm_70");
        let mut w = PtxWriter::new();
        module.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            ".version 8.7\n.target sm_70\n.address_size 64\n"
        );
    }

    #[test]
    fn emit_kernel_minimal() {
        let mut alloc = RegisterAllocator::new();
        let r0 = alloc.alloc(PtxType::U32);

        let mut kernel = PtxKernel::new("test_kernel");
        kernel.add_param(PtxParam::scalar("n", PtxType::U32));
        kernel.push(PtxInstruction::Mov {
            dst: r0,
            src: Operand::ImmU32(42),
            ty: PtxType::U32,
        });
        kernel.push(PtxInstruction::Control(crate::instr::ControlOp::Ret));
        kernel.set_registers(alloc.into_allocated());

        let mut w = PtxWriter::new();
        kernel.emit(&mut w).unwrap();
        let output = w.finish();

        assert!(output.contains(".visible .entry test_kernel("));
        assert!(output.contains(".param .u32 n"));
        assert!(output.contains(".reg .b32 %r<1>;"));
        assert!(output.contains("mov.u32 %r0, 42;"));
        assert!(output.contains("ret;"));
        assert!(output.starts_with(".visible .entry"));
        assert!(output.trim_end().ends_with('}'));
    }
}