1use crate::arch::SmVersion;
10use crate::error::PtxGenError;
11use crate::ir::{
12 AtomOp, FenceScope, GridDepAction, Instruction, MemorySpace, Operand, PtxType, ReduxOp,
13 Register, RoundingMode, SetmaxnregAction, StmatrixShape,
14};
15
16use super::BodyBuilder;
17
18impl BodyBuilder<'_> {
19 pub fn atom_global_add_f32(&mut self, addr: Register, val: Register) -> Register {
25 self.atom_typed(MemorySpace::Global, AtomOp::Add, PtxType::F32, addr, val)
26 }
27
28 pub fn atom_global_add_u32(&mut self, addr: Register, val: Register) -> Register {
30 self.atom_typed(MemorySpace::Global, AtomOp::Add, PtxType::U32, addr, val)
31 }
32
33 pub fn atom_global_add_u64(&mut self, addr: Register, val: Register) -> Register {
35 self.atom_typed(MemorySpace::Global, AtomOp::Add, PtxType::U64, addr, val)
36 }
37
38 pub fn atom_global_add_f64(&mut self, addr: Register, val: Register) -> Register {
40 self.atom_typed(MemorySpace::Global, AtomOp::Add, PtxType::F64, addr, val)
41 }
42
43 pub fn atom_global_cas_u32(
47 &mut self,
48 addr: Register,
49 compare: Register,
50 value: Register,
51 ) -> Register {
52 self.atom_cas_typed(MemorySpace::Global, PtxType::U32, addr, compare, value)
53 }
54
55 pub fn atom_global_cas_u64(
59 &mut self,
60 addr: Register,
61 compare: Register,
62 value: Register,
63 ) -> Register {
64 self.atom_cas_typed(MemorySpace::Global, PtxType::U64, addr, compare, value)
65 }
66
67 pub fn atom_global_exch_u32(&mut self, addr: Register, val: Register) -> Register {
69 self.atom_typed(MemorySpace::Global, AtomOp::Exch, PtxType::U32, addr, val)
70 }
71
72 pub fn atom_global_min_u32(&mut self, addr: Register, val: Register) -> Register {
74 self.atom_typed(MemorySpace::Global, AtomOp::Min, PtxType::U32, addr, val)
75 }
76
77 pub fn atom_global_max_u32(&mut self, addr: Register, val: Register) -> Register {
79 self.atom_typed(MemorySpace::Global, AtomOp::Max, PtxType::U32, addr, val)
80 }
81
82 pub fn atom_global_min_s32(&mut self, addr: Register, val: Register) -> Register {
84 self.atom_typed(MemorySpace::Global, AtomOp::Min, PtxType::S32, addr, val)
85 }
86
87 pub fn atom_global_max_s32(&mut self, addr: Register, val: Register) -> Register {
89 self.atom_typed(MemorySpace::Global, AtomOp::Max, PtxType::S32, addr, val)
90 }
91
92 pub fn atom_global_and_b32(&mut self, addr: Register, val: Register) -> Register {
94 self.atom_typed(MemorySpace::Global, AtomOp::And, PtxType::B32, addr, val)
95 }
96
97 pub fn atom_global_or_b32(&mut self, addr: Register, val: Register) -> Register {
99 self.atom_typed(MemorySpace::Global, AtomOp::Or, PtxType::B32, addr, val)
100 }
101
102 pub fn atom_global_xor_b32(&mut self, addr: Register, val: Register) -> Register {
104 self.atom_typed(MemorySpace::Global, AtomOp::Xor, PtxType::B32, addr, val)
105 }
106
107 pub fn atom_shared_add_f32(&mut self, addr: Register, val: Register) -> Register {
109 self.atom_typed(MemorySpace::Shared, AtomOp::Add, PtxType::F32, addr, val)
110 }
111
112 pub fn atom_shared_add_u32(&mut self, addr: Register, val: Register) -> Register {
114 self.atom_typed(MemorySpace::Shared, AtomOp::Add, PtxType::U32, addr, val)
115 }
116
117 pub fn red_global_add_f32(&mut self, addr: Register, val: Register) {
121 self.red_typed(MemorySpace::Global, AtomOp::Add, PtxType::F32, addr, val);
122 }
123
124 pub fn red_global_add_u32(&mut self, addr: Register, val: Register) {
126 self.red_typed(MemorySpace::Global, AtomOp::Add, PtxType::U32, addr, val);
127 }
128
129 fn atom_typed(
131 &mut self,
132 space: MemorySpace,
133 op: AtomOp,
134 ty: PtxType,
135 addr: Register,
136 src: Register,
137 ) -> Register {
138 let dst = self.regs.alloc(ty);
139 self.instructions.push(Instruction::Atom {
140 space,
141 op,
142 ty,
143 dst: dst.clone(),
144 addr: Operand::Register(addr),
145 src: Operand::Register(src),
146 });
147 dst
148 }
149
150 fn atom_cas_typed(
152 &mut self,
153 space: MemorySpace,
154 ty: PtxType,
155 addr: Register,
156 compare: Register,
157 value: Register,
158 ) -> Register {
159 let dst = self.regs.alloc(ty);
160 self.instructions.push(Instruction::AtomCas {
161 space,
162 ty,
163 dst: dst.clone(),
164 addr: Operand::Register(addr),
165 compare: Operand::Register(compare),
166 value: Operand::Register(value),
167 });
168 dst
169 }
170
171 fn red_typed(
173 &mut self,
174 space: MemorySpace,
175 op: AtomOp,
176 ty: PtxType,
177 addr: Register,
178 src: Register,
179 ) {
180 self.instructions.push(Instruction::Red {
181 space,
182 op,
183 ty,
184 addr: Operand::Register(addr),
185 src: Operand::Register(src),
186 });
187 }
188
189 pub fn tex_1d(&mut self, ty: PtxType, tex_ref: &str, coord: Operand) -> Register {
200 let dst = self.regs.alloc(ty);
201 self.instructions.push(Instruction::Tex1d {
202 ty,
203 dst: dst.clone(),
204 tex_ref: tex_ref.to_string(),
205 coord,
206 });
207 dst
208 }
209
210 pub fn tex_2d(
217 &mut self,
218 ty: PtxType,
219 tex_ref: &str,
220 coord_x: Operand,
221 coord_y: Operand,
222 ) -> Register {
223 let dst = self.regs.alloc(ty);
224 self.instructions.push(Instruction::Tex2d {
225 ty,
226 dst: dst.clone(),
227 tex_ref: tex_ref.to_string(),
228 coord_x,
229 coord_y,
230 });
231 dst
232 }
233
234 pub fn tex_3d(
241 &mut self,
242 ty: PtxType,
243 tex_ref: &str,
244 coord_x: Operand,
245 coord_y: Operand,
246 coord_z: Operand,
247 ) -> Register {
248 let dst = self.regs.alloc(ty);
249 self.instructions.push(Instruction::Tex3d {
250 ty,
251 dst: dst.clone(),
252 tex_ref: tex_ref.to_string(),
253 coord_x,
254 coord_y,
255 coord_z,
256 });
257 dst
258 }
259
260 pub fn surf_load(&mut self, ty: PtxType, surf_ref: &str, coord: Operand) -> Register {
267 let dst = self.regs.alloc(ty);
268 self.instructions.push(Instruction::SurfLoad {
269 ty,
270 dst: dst.clone(),
271 surf_ref: surf_ref.to_string(),
272 coord,
273 });
274 dst
275 }
276
277 pub fn surf_store(&mut self, ty: PtxType, surf_ref: &str, coord: Operand, src: Register) {
283 self.instructions.push(Instruction::SurfStore {
284 ty,
285 surf_ref: surf_ref.to_string(),
286 coord,
287 src,
288 });
289 }
290
291 pub fn redux_add_u32(&mut self, src: &str) -> Result<String, PtxGenError> {
297 self.redux_op(ReduxOp::Add, src)
298 }
299
300 pub fn redux_max_u32(&mut self, src: &str) -> Result<String, PtxGenError> {
302 self.redux_op(ReduxOp::Max, src)
303 }
304
305 pub fn redux_min_u32(&mut self, src: &str) -> Result<String, PtxGenError> {
307 self.redux_op(ReduxOp::Min, src)
308 }
309
310 fn redux_op(&mut self, op: ReduxOp, src: &str) -> Result<String, PtxGenError> {
311 if !self.target.capabilities().has_redux {
312 return Err(PtxGenError::GenerationFailed(format!(
313 "redux.sync requires SM >= 80, target is {}",
314 self.target
315 )));
316 }
317 let dst = self.regs.alloc(PtxType::U32);
318 let name = dst.name.clone();
319 self.instructions.push(Instruction::Redux {
320 op,
321 dst,
322 src: Operand::Register(Register {
323 name: src.to_string(),
324 ty: PtxType::U32,
325 }),
326 membership_mask: 0xFFFF_FFFF,
327 });
328 Ok(name)
329 }
330
331 pub fn stmatrix_m8n8x4(&mut self, addr: &str, src: &str) -> Result<(), PtxGenError> {
333 if !self.target.capabilities().has_stmatrix {
334 return Err(PtxGenError::GenerationFailed(format!(
335 "stmatrix requires SM >= 90, target is {}",
336 self.target
337 )));
338 }
339 self.instructions.push(Instruction::Stmatrix {
340 dst_addr: Operand::Register(Register {
341 name: addr.to_string(),
342 ty: PtxType::U32,
343 }),
344 src: Register {
345 name: src.to_string(),
346 ty: PtxType::B32,
347 },
348 shape: StmatrixShape::M8n8x4,
349 trans: false,
350 });
351 Ok(())
352 }
353
354 pub fn elect_sync(&mut self) -> Result<String, PtxGenError> {
356 if !self.target.capabilities().has_elect_one {
357 return Err(PtxGenError::GenerationFailed(format!(
358 "elect.sync requires SM >= 90, target is {}",
359 self.target
360 )));
361 }
362 let dst = self.regs.alloc(PtxType::Pred);
363 let name = dst.name.clone();
364 self.instructions.push(Instruction::ElectSync {
365 dst,
366 membership_mask: 0xFFFF_FFFF,
367 });
368 Ok(name)
369 }
370
371 pub fn setmaxnreg_inc(&mut self, count: u32) -> Result<(), PtxGenError> {
373 self.setmaxnreg_impl(count, SetmaxnregAction::Inc)
374 }
375
376 pub fn setmaxnreg_dec(&mut self, count: u32) -> Result<(), PtxGenError> {
378 self.setmaxnreg_impl(count, SetmaxnregAction::Dec)
379 }
380
381 fn setmaxnreg_impl(&mut self, count: u32, action: SetmaxnregAction) -> Result<(), PtxGenError> {
382 if !self.target.capabilities().has_setmaxnreg {
383 return Err(PtxGenError::GenerationFailed(format!(
384 "setmaxnreg requires SM >= 90, target is {}",
385 self.target
386 )));
387 }
388 self.instructions.push(Instruction::Setmaxnreg {
389 reg_count: count,
390 action,
391 });
392 Ok(())
393 }
394
395 pub fn griddepcontrol_launch_dependents(&mut self) -> Result<(), PtxGenError> {
401 if !self.target.capabilities().has_griddepcontrol {
402 return Err(PtxGenError::GenerationFailed(format!(
403 "griddepcontrol requires SM >= 90, target is {}",
404 self.target
405 )));
406 }
407 self.instructions.push(Instruction::Griddepcontrol {
408 action: GridDepAction::LaunchDependents,
409 });
410 Ok(())
411 }
412
413 pub fn griddepcontrol_wait(&mut self) -> Result<(), PtxGenError> {
415 if !self.target.capabilities().has_griddepcontrol {
416 return Err(PtxGenError::GenerationFailed(format!(
417 "griddepcontrol requires SM >= 90, target is {}",
418 self.target
419 )));
420 }
421 self.instructions.push(Instruction::Griddepcontrol {
422 action: GridDepAction::Wait,
423 });
424 Ok(())
425 }
426
427 pub fn fence_proxy_async(&mut self, scope: &str) -> Result<(), PtxGenError> {
429 let fence_scope = match scope {
430 "cta" => FenceScope::Cta,
431 "gpu" => FenceScope::Gpu,
432 "sys" => FenceScope::Sys,
433 other => {
434 return Err(PtxGenError::GenerationFailed(format!(
435 "unknown fence scope: {other}"
436 )));
437 }
438 };
439 self.instructions.push(Instruction::FenceProxy {
440 scope: fence_scope,
441 space: MemorySpace::Shared,
442 });
443 Ok(())
444 }
445
446 pub fn mbarrier_init(&mut self, addr: &str, count: &str) -> Result<(), PtxGenError> {
448 if !self.target.capabilities().has_cluster_barriers {
449 return Err(PtxGenError::GenerationFailed(format!(
450 "mbarrier requires SM >= 90, target is {}",
451 self.target
452 )));
453 }
454 self.instructions.push(Instruction::MbarrierInit {
455 addr: Operand::Register(Register {
456 name: addr.to_string(),
457 ty: PtxType::U64,
458 }),
459 count: Operand::Register(Register {
460 name: count.to_string(),
461 ty: PtxType::U32,
462 }),
463 });
464 Ok(())
465 }
466
467 pub fn mbarrier_arrive(&mut self, addr: &str) -> Result<(), PtxGenError> {
469 if !self.target.capabilities().has_cluster_barriers {
470 return Err(PtxGenError::GenerationFailed(format!(
471 "mbarrier requires SM >= 90, target is {}",
472 self.target
473 )));
474 }
475 self.instructions.push(Instruction::MbarrierArrive {
476 addr: Operand::Register(Register {
477 name: addr.to_string(),
478 ty: PtxType::U64,
479 }),
480 });
481 Ok(())
482 }
483
484 pub fn mbarrier_wait(&mut self, addr: &str, phase: &str) -> Result<(), PtxGenError> {
486 if !self.target.capabilities().has_cluster_barriers {
487 return Err(PtxGenError::GenerationFailed(format!(
488 "mbarrier requires SM >= 90, target is {}",
489 self.target
490 )));
491 }
492 self.instructions.push(Instruction::MbarrierWait {
493 addr: Operand::Register(Register {
494 name: addr.to_string(),
495 ty: PtxType::U64,
496 }),
497 phase: Operand::Register(Register {
498 name: phase.to_string(),
499 ty: PtxType::U32,
500 }),
501 });
502 Ok(())
503 }
504
505 pub fn cvt_f32_to_e2m1(&mut self, src: Register) -> Result<Register, PtxGenError> {
514 if self.target < SmVersion::Sm100 {
515 return Err(PtxGenError::GenerationFailed(format!(
516 "cvt.e2m1 requires SM >= 100 (Blackwell), target is {}",
517 self.target
518 )));
519 }
520 let dst = self.regs.alloc(PtxType::E2M1);
521 self.instructions.push(Instruction::Cvt {
522 rnd: Some(RoundingMode::Rn),
523 dst_ty: PtxType::E2M1,
524 src_ty: PtxType::F32,
525 dst: dst.clone(),
526 src: Operand::Register(src),
527 });
528 Ok(dst)
529 }
530
531 pub fn cvt_e2m1_to_f32(&mut self, src: Register) -> Result<Register, PtxGenError> {
535 if self.target < SmVersion::Sm100 {
536 return Err(PtxGenError::GenerationFailed(format!(
537 "cvt.f32.e2m1 requires SM >= 100 (Blackwell), target is {}",
538 self.target
539 )));
540 }
541 let dst = self.regs.alloc(PtxType::F32);
542 self.instructions.push(Instruction::Cvt {
543 rnd: None,
544 dst_ty: PtxType::F32,
545 src_ty: PtxType::E2M1,
546 dst: dst.clone(),
547 src: Operand::Register(src),
548 });
549 Ok(dst)
550 }
551
552 pub fn tcgen05_mma_m128n256k256_e2m1(
561 &mut self,
562 a_desc: Register,
563 b_desc: Register,
564 ) -> Result<(), PtxGenError> {
565 if self.target < SmVersion::Sm100 {
566 return Err(PtxGenError::GenerationFailed(format!(
567 "tcgen05.mma requires SM >= 100 (Blackwell), target is {}",
568 self.target
569 )));
570 }
571 self.instructions
572 .push(Instruction::Tcgen05Mma { a_desc, b_desc });
573 Ok(())
574 }
575
576 pub fn barrier_cluster(&mut self) -> Result<(), PtxGenError> {
585 if !self.target.capabilities().has_cluster_barriers {
586 return Err(PtxGenError::GenerationFailed(format!(
587 "barrier.cluster requires SM >= 90, target is {}",
588 self.target
589 )));
590 }
591 self.instructions.push(Instruction::BarrierCluster);
592 Ok(())
593 }
594
595 pub fn fence_cluster(&mut self) -> Result<(), PtxGenError> {
600 if !self.target.capabilities().has_cluster_barriers {
601 return Err(PtxGenError::GenerationFailed(format!(
602 "fence.cluster requires SM >= 90, target is {}",
603 self.target
604 )));
605 }
606 self.instructions.push(Instruction::FenceCluster);
607 Ok(())
608 }
609
610 pub fn cp_async_bulk_tensor_1d(
626 &mut self,
627 dst_smem: Register,
628 src_gmem: Register,
629 desc: Register,
630 ) -> Result<(), PtxGenError> {
631 if !self.target.capabilities().has_bulk_copy {
632 return Err(PtxGenError::GenerationFailed(format!(
633 "cp.async.bulk.tensor requires SM >= 90, target is {}",
634 self.target
635 )));
636 }
637 self.instructions.push(Instruction::CpAsyncBulk {
638 dst_smem,
639 src_gmem,
640 desc,
641 });
642 Ok(())
643 }
644}