1use std::fmt;
4
5use super::instruction::PtxInstruction;
6use super::kernel::PtxKernel;
7use crate::instr::MemoryOp;
8
9#[derive(Debug, Clone)]
14pub struct PtxModule {
15 pub version: String,
17 pub target: String,
19 pub address_size: u32,
21 pub kernels: Vec<PtxKernel>,
23}
24
25impl PtxModule {
26 pub fn new(target: &str) -> Self {
30 Self {
31 version: "8.7".to_string(),
32 target: target.to_string(),
33 address_size: 64,
34 kernels: Vec::new(),
35 }
36 }
37
38 pub fn add_kernel(&mut self, kernel: PtxKernel) {
40 self.kernels.push(kernel);
41 }
42
43 fn parse_sm_target(&self) -> Option<u32> {
52 self.target.strip_prefix("sm_").and_then(|s| s.parse().ok())
53 }
54
55 pub fn validate(&self) -> Result<(), ValidationError> {
67 let Some(target_sm) = self.parse_sm_target() else {
68 return Ok(());
70 };
71
72 for kernel in &self.kernels {
73 for instr in &kernel.body {
74 if let Some((required, feature)) = instruction_sm_requirement(instr)
75 && target_sm < required
76 {
77 return Err(ValidationError::SmTooLow {
78 required,
79 actual: target_sm,
80 feature,
81 });
82 }
83 }
84 }
85 Ok(())
86 }
87}
88
89fn instruction_sm_requirement(instr: &PtxInstruction) -> Option<(u32, String)> {
92 match instr {
93 PtxInstruction::TensorCore(op) => Some((op.min_sm(), op.feature_label())),
94 PtxInstruction::Memory(
95 MemoryOp::CpAsyncCaSharedGlobal { .. }
96 | MemoryOp::CpAsyncCommitGroup
97 | MemoryOp::CpAsyncWaitGroup { .. },
98 ) => Some((80, "cp.async".to_string())),
99 _ => None,
100 }
101}
102
/// Error produced by `PtxModule::validate`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ValidationError {
    /// An instruction needs a newer SM version than the module targets.
    SmTooLow {
        /// Minimum SM version the offending instruction requires.
        required: u32,
        /// SM version parsed from the module's target string.
        actual: u32,
        /// Label of the feature that imposed the requirement
        /// (e.g. `"cp.async"`).
        feature: String,
    },
}

impl fmt::Display for ValidationError {
    /// Renders the error as
    /// `"<feature> requires sm_<required>+, target is sm_<actual>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum, so the destructuring is irrefutable; adding a
        // variant turns this into a compile error that forces an update here.
        let Self::SmTooLow {
            required,
            actual,
            feature,
        } = self;
        write!(
            f,
            "{feature} requires sm_{required}+, target is sm_{actual}"
        )
    }
}

// No underlying source error: the default `Error` methods are sufficient.
impl std::error::Error for ValidationError {}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fragment::{alloc_a, alloc_b, alloc_c};
    use crate::instr::{MemoryOp, MmaShape, TensorCoreOp};
    use crate::ir::{PtxInstruction, PtxKernel, Register, RegisterAllocator};
    use crate::types::{PtxType, RegKind};

    /// Shorthand for spelling out a `Register` value.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    /// Builds a kernel named `has_mma` whose body is a single
    /// `mma.sync` m16n8k16 tensor-core instruction.
    fn tc_kernel() -> PtxKernel {
        let mut regs = RegisterAllocator::new();
        let mut kernel = PtxKernel::new("has_mma");

        // Fragments are allocated in operand order: d, a, b, c.
        let d = alloc_c(&mut regs);
        let a = alloc_a(&mut regs);
        let b = alloc_b(&mut regs);
        let c = alloc_c(&mut regs);

        let mma = TensorCoreOp::MmaSync {
            d,
            a,
            b,
            c,
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        };
        kernel.push(PtxInstruction::TensorCore(mma));
        kernel
    }

    /// Builds a module for `target` holding exactly `kernel`.
    fn module_with(target: &str, kernel: PtxKernel) -> PtxModule {
        let mut module = PtxModule::new(target);
        module.add_kernel(kernel);
        module
    }

    #[test]
    fn validate_rejects_mma_on_sm_70() {
        let failure = module_with("sm_70", tc_kernel())
            .validate()
            .unwrap_err();

        let expected = ValidationError::SmTooLow {
            required: 80,
            actual: 70,
            feature: "mma.sync.m16n8k16".to_string(),
        };
        assert_eq!(failure, expected);
        assert_eq!(
            failure.to_string(),
            "mma.sync.m16n8k16 requires sm_80+, target is sm_70"
        );
    }

    #[test]
    fn validate_accepts_mma_on_sm_80() {
        let module = module_with("sm_80", tc_kernel());
        assert!(module.validate().is_ok());
    }

    #[test]
    fn validate_accepts_mma_on_sm_89() {
        let module = module_with("sm_89", tc_kernel());
        assert!(module.validate().is_ok());
    }

    #[test]
    fn validate_rejects_cp_async_on_sm_75() {
        let mut kernel = PtxKernel::new("has_cp_async");
        kernel.push(PtxInstruction::Memory(MemoryOp::new_cp_async_ca(
            reg(RegKind::R, 0, PtxType::U32),
            reg(RegKind::Rd, 0, PtxType::U64),
            16,
        )));

        let failure = module_with("sm_75", kernel).validate().unwrap_err();

        let expected = ValidationError::SmTooLow {
            required: 80,
            actual: 75,
            feature: "cp.async".to_string(),
        };
        assert_eq!(failure, expected);
    }

    #[test]
    fn validate_accepts_scalar_kernel_on_sm_70() {
        // A kernel with nothing pushed into it has no instruction that could
        // impose an SM requirement.
        let module = module_with("sm_70", PtxKernel::new("scalar_only"));
        assert!(module.validate().is_ok());
    }

    #[test]
    fn validate_skips_unparseable_target() {
        // The target is not of the `sm_<number>` form, so validation is
        // skipped even though the kernel contains a tensor-core op.
        let module = module_with("compute_90a", tc_kernel());
        assert!(module.validate().is_ok());
    }

    #[test]
    fn parse_sm_target() {
        let cases = [
            ("sm_89", Some(89)),
            ("sm_80", Some(80)),
            ("compute_90a", None),
        ];
        for (target, expected) in cases {
            assert_eq!(PtxModule::new(target).parse_sm_target(), expected);
        }
    }
}