1use xpile_backend::{
20 Artifact, Backend, BackendConfig, BackendError, EmittedText, HwProfile, MultiEmitterBackend,
21 QuorumPolicy, Target, TargetEmitter,
22};
23use xpile_contracts::ContractId;
24use xpile_meta_hir::Module;
25
26pub struct PtxBackend {
30 inner: MultiEmitterBackend,
31}
32
33impl Default for PtxBackend {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl PtxBackend {
40 pub fn new() -> Self {
41 Self {
42 inner: MultiEmitterBackend::new_single(Target::Ptx, Box::new(ScaffoldPtxEmitter)),
43 }
44 }
45
46 pub fn new_with_matmul_specialist() -> Self {
66 Self {
67 inner: MultiEmitterBackend::new_with_specialist(
68 Target::Ptx,
69 Box::new(ScaffoldPtxEmitter),
70 Box::new(MatmulSpecialistEmitter),
71 QuorumPolicy::PreferSpecialist,
72 ),
73 }
74 }
75}
76
77impl Backend for PtxBackend {
78 fn name(&self) -> &'static str {
79 "ptx"
80 }
81
82 fn targets(&self) -> &[Target] {
83 &[Target::Ptx]
84 }
85
86 fn lower(&self, module: &Module, config: &BackendConfig) -> Result<Artifact, BackendError> {
87 match &config.hardware {
91 Some(HwProfile::Ptx { .. }) => {}
92 _ => return Err(BackendError::MissingHardware(Target::Ptx)),
93 }
94 self.inner.lower(module, config)
95 }
96}
97
98struct ScaffoldPtxEmitter;
102
103impl TargetEmitter for ScaffoldPtxEmitter {
104 fn name(&self) -> &str {
105 "xpile-ptx-codegen-scaffold"
106 }
107
108 fn try_emit(
109 &self,
110 module: &Module,
111 config: &BackendConfig,
112 ) -> Option<Result<EmittedText, BackendError>> {
113 let compute_capability = match &config.hardware {
114 Some(HwProfile::Ptx { compute_capability }) => compute_capability,
115 _ => return Some(Err(BackendError::MissingHardware(Target::Ptx))),
116 };
117 Some(Ok(EmittedText {
118 primary: format!(
119 "// xpile-ptx-codegen scaffold\n// module: {}\n// compute_capability: {}\n// TODO: lower to real PTX via rustc_codegen_nvvm.\n",
120 module.name, compute_capability,
121 ),
122 citations: vec![ContractId::new("C-COMPILE-RUST-TO-PTX-MMA")],
123 }))
124 }
125}
126
127struct MatmulSpecialistEmitter;
143
144impl TargetEmitter for MatmulSpecialistEmitter {
145 fn name(&self) -> &str {
146 "matmul-specialist-mock"
147 }
148
149 fn try_emit(
150 &self,
151 module: &Module,
152 config: &BackendConfig,
153 ) -> Option<Result<EmittedText, BackendError>> {
154 if !module.name.starts_with("matmul_") {
155 return None;
156 }
157 let compute_capability = match &config.hardware {
158 Some(HwProfile::Ptx { compute_capability }) => compute_capability,
159 _ => return Some(Err(BackendError::MissingHardware(Target::Ptx))),
160 };
161 Some(Ok(EmittedText {
162 primary: format!(
163 "// matmul-specialist scaffold\n// module: {}\n// compute_capability: {}\n// TODO: emit mma.sync.aligned via aprender-gpu shape templates.\n",
164 module.name, compute_capability,
165 ),
166 citations: vec![ContractId::new("C-COMPILE-RUST-TO-PTX-MMA")],
167 }))
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use xpile_backend::{Profile, QuorumStatus};
    use xpile_meta_hir::SourceLang;

    /// Name reported by the scaffold emitter.
    const SCAFFOLD_EMITTER: &str = "xpile-ptx-codegen-scaffold";
    /// Name reported by the matmul specialist emitter.
    const SPECIALIST_EMITTER: &str = "matmul-specialist-mock";

    /// Empty Rust-sourced module carrying only the given name.
    fn module_named(name: &'static str) -> Module {
        Module {
            name: name.into(),
            source_lang: SourceLang::Rust,
            items: Vec::new(),
            ffi_boundaries: Vec::new(),
        }
    }

    /// Module whose name does not match the specialist's `matmul_` prefix.
    fn dummy_module() -> Module {
        module_named("test_kernel")
    }

    /// Module whose name matches the specialist's `matmul_` prefix.
    fn matmul_module() -> Module {
        module_named("matmul_gemm_fp16")
    }

    /// PTX / `RustOut` config with the given hardware slot.
    fn config_with_hardware(hardware: Option<HwProfile>) -> BackendConfig {
        BackendConfig {
            target: Target::Ptx,
            profile: Profile::RustOut,
            hardware,
        }
    }

    /// PTX config carrying a hardware profile with compute capability `sm`.
    fn ptx_config(sm: &str) -> BackendConfig {
        config_with_hardware(Some(HwProfile::Ptx {
            compute_capability: sm.to_string(),
        }))
    }

    /// Expected quorum status when exactly one emitter produced the output.
    fn single(emitter: &str) -> QuorumStatus {
        QuorumStatus::Single {
            emitter: emitter.to_string(),
        }
    }

    #[test]
    fn ptx_backend_emits_through_multi_emitter() {
        let module = dummy_module();
        let cfg = ptx_config("sm_80");
        let artifact = PtxBackend::new().lower(&module, &cfg).unwrap();
        assert_eq!(artifact.quorum_status, single(SCAFFOLD_EMITTER));
        assert!(artifact.primary.contains("sm_80"));
        assert!(artifact
            .citations
            .iter()
            .any(|c| c.as_str() == "C-COMPILE-RUST-TO-PTX-MMA"));
    }

    #[test]
    fn ptx_backend_rejects_missing_hardware() {
        let cfg = config_with_hardware(None);
        let err = PtxBackend::new()
            .lower(&dummy_module(), &cfg)
            .unwrap_err();
        assert!(matches!(err, BackendError::MissingHardware(Target::Ptx)));
    }

    #[test]
    fn ptx_backend_targets_only_ptx() {
        let backend = PtxBackend::new();
        assert_eq!(backend.targets(), &[Target::Ptx]);
        assert_eq!(backend.name(), "ptx");
    }

    #[test]
    fn matmul_module_routes_through_specialist_under_multi_emitter() {
        let module = matmul_module();
        let cfg = ptx_config("sm_80");
        let artifact = PtxBackend::new_with_matmul_specialist()
            .lower(&module, &cfg)
            .unwrap();
        assert_eq!(
            artifact.quorum_status,
            single(SPECIALIST_EMITTER),
            "PreferSpecialist with matching specialist should report Single {{ specialist }}"
        );
        assert!(
            artifact.primary.contains("matmul-specialist"),
            "primary should carry the specialist's emission body, got:\n{}",
            artifact.primary,
        );
    }

    #[test]
    fn non_matmul_module_falls_back_to_general_under_multi_emitter() {
        let module = dummy_module();
        let cfg = ptx_config("sm_80");
        let artifact = PtxBackend::new_with_matmul_specialist()
            .lower(&module, &cfg)
            .unwrap();
        assert_eq!(
            artifact.quorum_status,
            single(SCAFFOLD_EMITTER),
            "non-matching specialist should let general emit; QuorumStatus should reflect general"
        );
        assert!(
            artifact.primary.contains("xpile-ptx-codegen scaffold"),
            "primary should carry the general scaffold's emission body, got:\n{}",
            artifact.primary,
        );
    }

    #[test]
    fn multi_emitter_constructor_targets_match_single_emitter() {
        let single = PtxBackend::new();
        let multi = PtxBackend::new_with_matmul_specialist();
        assert_eq!(multi.targets(), single.targets());
        assert_eq!(multi.name(), single.name());
    }

    #[test]
    fn multi_emitter_constructor_rejects_missing_hardware() {
        let cfg = config_with_hardware(None);
        let err = PtxBackend::new_with_matmul_specialist()
            .lower(&matmul_module(), &cfg)
            .unwrap_err();
        assert!(matches!(err, BackendError::MissingHardware(Target::Ptx)));
    }
}