Skip to main content

xpile_ptx_codegen/
lib.rs

//! PTX backend.
//!
//! Lowers Rust meta-HIR (functions annotated `#[gpu_kernel(...)]`) to
//! NVIDIA PTX text targeting `sm_80`+. Layer 5 compile contract:
//! `contracts/compile-rust-to-ptx-mma-v1.yaml`.
//!
//! **Architecture (PMAT-264 / Section 29):** [`PtxBackend`] wraps a
//! [`MultiEmitterBackend`] so emission routes through the same
//! general/specialist quorum framework that will eventually carry
//! `rustc_codegen_nvvm` (general) + `aprender-gpu` (specialist). At
//! v0.1.0 the wrapper holds a single [`ScaffoldPtxEmitter`] in the
//! general slot — the same code path real emitters will plug into.
//!
//! When `rustc_codegen_nvvm` lights up (next phase per
//! `sub/layer5-multi-emitter-quorum.md`), it slots into the `general`
//! position; when `aprender-gpu` ships its bridge, it slots into the
//! `specialist` position; no changes to [`PtxBackend`]'s public API.
18
19use xpile_backend::{
20    Artifact, Backend, BackendConfig, BackendError, EmittedText, HwProfile, MultiEmitterBackend,
21    QuorumPolicy, Target, TargetEmitter,
22};
23use xpile_contracts::ContractId;
24use xpile_meta_hir::Module;
25
26/// PTX backend — `Backend` impl wrapping a [`MultiEmitterBackend`] so
27/// the v0.1.0 scaffold drives through the same routing the future
28/// `rustc_codegen_nvvm` + `aprender-gpu` quorum will use.
29pub struct PtxBackend {
30    inner: MultiEmitterBackend,
31}
32
33impl Default for PtxBackend {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl PtxBackend {
40    pub fn new() -> Self {
41        Self {
42            inner: MultiEmitterBackend::new_single(Target::Ptx, Box::new(ScaffoldPtxEmitter)),
43        }
44    }
45
46    /// PMAT-280 — End-to-end validation constructor for Section 29's
47    /// multi-emitter routing.
48    ///
49    /// Builds a `PtxBackend` whose `MultiEmitterBackend` carries the
50    /// `ScaffoldPtxEmitter` in the `general` slot AND a
51    /// [`MatmulSpecialistEmitter`] in the `specialist` slot under
52    /// `QuorumPolicy::PreferSpecialist`. The specialist matches only
53    /// modules whose name starts with `matmul_` — the shape filter
54    /// real specialists like `aprender-gpu` would use to claim
55    /// GEMM-shaped kernels.
56    ///
57    /// This isn't registered in `default_session()` — production at
58    /// v0.1.0+ still uses [`PtxBackend::new`]. The constructor exists
59    /// so tests + future integrations can exercise the
60    /// `MultiEmitterBackend::new_with_specialist` path against real
61    /// production code (not just mock tests). It's the smallest
62    /// concrete proof that the §29 routing layer is end-to-end
63    /// usable, ahead of the heavy `rustc_codegen_nvvm` / `aprender-gpu`
64    /// integrations that will eventually replace these placeholders.
65    pub fn new_with_matmul_specialist() -> Self {
66        Self {
67            inner: MultiEmitterBackend::new_with_specialist(
68                Target::Ptx,
69                Box::new(ScaffoldPtxEmitter),
70                Box::new(MatmulSpecialistEmitter),
71                QuorumPolicy::PreferSpecialist,
72            ),
73        }
74    }
75}
76
77impl Backend for PtxBackend {
78    fn name(&self) -> &'static str {
79        "ptx"
80    }
81
82    fn targets(&self) -> &[Target] {
83        &[Target::Ptx]
84    }
85
86    fn lower(&self, module: &Module, config: &BackendConfig) -> Result<Artifact, BackendError> {
87        // Reject inputs without an HwProfile::Ptx eagerly — the
88        // scaffold emitter can't synthesize a compute_capability and
89        // the contract requires one.
90        match &config.hardware {
91            Some(HwProfile::Ptx { .. }) => {}
92            _ => return Err(BackendError::MissingHardware(Target::Ptx)),
93        }
94        self.inner.lower(module, config)
95    }
96}
97
98/// Scaffold emitter — produces the placeholder PTX text current users
99/// see at v0.1.0. Will be replaced by `rustc_codegen_nvvm` integration
100/// in the next Section 29 phase.
101struct ScaffoldPtxEmitter;
102
103impl TargetEmitter for ScaffoldPtxEmitter {
104    fn name(&self) -> &str {
105        "xpile-ptx-codegen-scaffold"
106    }
107
108    fn try_emit(
109        &self,
110        module: &Module,
111        config: &BackendConfig,
112    ) -> Option<Result<EmittedText, BackendError>> {
113        let compute_capability = match &config.hardware {
114            Some(HwProfile::Ptx { compute_capability }) => compute_capability,
115            _ => return Some(Err(BackendError::MissingHardware(Target::Ptx))),
116        };
117        Some(Ok(EmittedText {
118            primary: format!(
119                "// xpile-ptx-codegen scaffold\n// module: {}\n// compute_capability: {}\n// TODO: lower to real PTX via rustc_codegen_nvvm.\n",
120                module.name, compute_capability,
121            ),
122            citations: vec![ContractId::new("C-COMPILE-RUST-TO-PTX-MMA")],
123        }))
124    }
125}
126
127/// PMAT-280 — Mock GEMM specialist emitter.
128///
129/// Matches modules whose name starts with `matmul_` — the shape
130/// filter real specialists like `aprender-gpu` would use to claim
131/// the GEMM/MMA kernel domain. Returns `None` from `try_emit` for
132/// non-matching modules, letting the general emitter handle them.
133/// For matching modules, emits a distinct PTX text (different from
134/// the scaffold) so the `QuorumStatus::Multi` path is exercised
135/// under non-trivial divergence.
136///
137/// This is intentionally not a real GEMM emitter — its job is to
138/// prove that the `MultiEmitterBackend::new_with_specialist` routing
139/// layer composes correctly with the existing `PtxBackend`. The
140/// future `aprender-gpu` integration plugs in via the same trait
141/// without touching `PtxBackend`'s public API.
142struct MatmulSpecialistEmitter;
143
144impl TargetEmitter for MatmulSpecialistEmitter {
145    fn name(&self) -> &str {
146        "matmul-specialist-mock"
147    }
148
149    fn try_emit(
150        &self,
151        module: &Module,
152        config: &BackendConfig,
153    ) -> Option<Result<EmittedText, BackendError>> {
154        if !module.name.starts_with("matmul_") {
155            return None;
156        }
157        let compute_capability = match &config.hardware {
158            Some(HwProfile::Ptx { compute_capability }) => compute_capability,
159            _ => return Some(Err(BackendError::MissingHardware(Target::Ptx))),
160        };
161        Some(Ok(EmittedText {
162            primary: format!(
163                "// matmul-specialist scaffold\n// module: {}\n// compute_capability: {}\n// TODO: emit mma.sync.aligned via aprender-gpu shape templates.\n",
164                module.name, compute_capability,
165            ),
166            citations: vec![ContractId::new("C-COMPILE-RUST-TO-PTX-MMA")],
167        }))
168    }
169}
170
// ─── PMAT-481: offline PTX well-formedness gate (§30 Track 4) ────────
//
// A *structural* check on emitted PTX text — it does NOT execute
// anything and is not the model→emission gate (that is the `DiffExec`
// slice, PMAT-488). It exists so that the moment a real emitter lands
// (PMAT-485, the `nvptx64` path) its output is gated for well-formedness
// on FREE CI, and the `ptxas`-assembles step (wired with that emitter)
// derives its `-arch` from the same `compute_capability` checked here —
// never a hard-coded `sm_80`. Callers gate on [`ptx_looks_real`] so the
// v0.1.0 scaffold comment placeholder is never treated as real emission.
181
/// Why a piece of emitted PTX text was rejected by the
/// [`validate_ptx`] well-formedness gate.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PtxValidationError {
    /// The text has no `.version` directive, so it is not PTX at all
    /// (for example the scaffold placeholder).
    MissingVersion,
    /// The text has no `.target` directive.
    MissingTarget,
    /// The `.target` arch differs from the requested compute capability.
    TargetMismatch { expected: String, found: String },
    /// The text has no `.address_size 64` directive.
    MissingAddressSize,
    /// The text has no `.visible .entry` kernel entry point.
    MissingEntry,
}

impl std::fmt::Display for PtxValidationError {
    /// Writes a one-line, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the mismatch variant carries data; every other variant
        // maps to a fixed message written in one place below.
        let fixed_message = match self {
            Self::TargetMismatch { expected, found } => {
                return write!(
                    f,
                    "PTX `.target {found}` does not match requested compute capability `{expected}`"
                );
            }
            Self::MissingVersion => "PTX is missing a `.version` directive",
            Self::MissingTarget => "PTX is missing a `.target` directive",
            Self::MissingAddressSize => "PTX is missing `.address_size 64`",
            Self::MissingEntry => "PTX has no `.visible .entry` kernel",
        };
        f.write_str(fixed_message)
    }
}

impl std::error::Error for PtxValidationError {}
213
214/// `true` when `text` looks like real PTX (carries a `.version`
215/// directive) rather than the v0.1.0 scaffold comment placeholder.
216pub fn ptx_looks_real(text: &str) -> bool {
217    directive_present(text, ".version")
218}
219
/// Builds the `ptxas -arch=<…>` argument for a PTX `.target` compute
/// capability — **derived, never hard-coded** (PMAT-481). The `ptxas`
/// assemble step (free CI, wired with the real emitter in PMAT-485)
/// uses this so the assembled arch always matches the emitted
/// `.target`.
pub fn ptxas_arch(compute_capability: &str) -> String {
    const FLAG: &str = "-arch=";
    let mut argument = String::with_capacity(FLAG.len() + compute_capability.len());
    argument.push_str(FLAG);
    argument.push_str(compute_capability);
    argument
}
227
228/// PMAT-481 — structural well-formedness check on emitted PTX text:
229/// `.version`, `.target` matching `compute_capability`, `.address_size
230/// 64`, and at least one `.visible .entry`. Pure text — no GPU, no
231/// `ptxas`. Gate on [`ptx_looks_real`] first so the scaffold placeholder
232/// is not treated as real emission.
233pub fn validate_ptx(text: &str, compute_capability: &str) -> Result<(), PtxValidationError> {
234    if !directive_present(text, ".version") {
235        return Err(PtxValidationError::MissingVersion);
236    }
237    let target = ptx_target_arch(text).ok_or(PtxValidationError::MissingTarget)?;
238    if target != compute_capability {
239        return Err(PtxValidationError::TargetMismatch {
240            expected: compute_capability.to_string(),
241            found: target,
242        });
243    }
244    if !directive_present(text, ".address_size 64") {
245        return Err(PtxValidationError::MissingAddressSize);
246    }
247    if !text.contains(".visible .entry") {
248        return Err(PtxValidationError::MissingEntry);
249    }
250    Ok(())
251}
252
/// True when some line of `text`, ignoring `//` comments, begins with
/// the whitespace-separated tokens of `directive`.
///
/// Matching is per token rather than by raw string prefix, so
/// `.address_size 64` is not satisfied by `.address_size 640`,
/// `.version` is not satisfied by `.versionfoo`, and tabs or repeated
/// spaces between tokens are accepted. An empty (or all-whitespace)
/// `directive` never matches.
fn directive_present(text: &str, directive: &str) -> bool {
    let wanted: Vec<&str> = directive.split_whitespace().collect();
    if wanted.is_empty() {
        return false;
    }
    text.lines()
        // Keep only the code before any `//` comment on the line.
        .map(|line| line.split("//").next().unwrap_or(line))
        .any(|code| {
            let mut tokens = code.split_whitespace();
            wanted.iter().all(|token| tokens.next() == Some(*token))
        })
}
260
/// Extract the arch token (e.g. `sm_80`) from the first usable
/// `.target` directive in `text`.
///
/// `//` comments (whole-line or trailing) are ignored. The arch is the
/// first token after `.target`, terminated by a comma or by any
/// whitespace (not just a space), so a tab-separated trailing comment
/// or option list cannot leak into the returned token. Returns `None`
/// when no `.target` line carries an arch.
fn ptx_target_arch(text: &str) -> Option<String> {
    text.lines().find_map(|line| {
        // Keep only the code before any `//` comment on the line.
        let code = line.split("//").next().unwrap_or(line).trim();
        let rest = code.strip_prefix(".target")?;
        if !rest.is_empty() && !rest.starts_with(char::is_whitespace) {
            return None; // e.g. `.target_foo` — not the directive
        }
        rest.trim_start()
            .split(|c: char| c == ',' || c.is_whitespace())
            .next()
            .filter(|arch| !arch.is_empty())
            .map(str::to_owned)
    })
}
275
#[cfg(test)]
mod tests {
    use super::*;
    use xpile_backend::{Profile, QuorumStatus};
    use xpile_meta_hir::SourceLang;

    // ─── Shared fixtures ────────────────────────────────────────────

    /// An otherwise empty Rust-sourced module carrying only a name.
    fn module_named(name: &'static str) -> Module {
        Module {
            name: name.into(),
            source_lang: SourceLang::Rust,
            items: Vec::new(),
            ffi_boundaries: Vec::new(),
        }
    }

    /// Module whose name the matmul specialist does not claim.
    fn dummy_module() -> Module {
        module_named("test_kernel")
    }

    /// Module whose name carries the `matmul_` prefix.
    fn matmul_module() -> Module {
        module_named("matmul_gemm_fp16")
    }

    /// PTX-targeted config with the given (possibly absent) hardware.
    fn config_with(hardware: Option<HwProfile>) -> BackendConfig {
        BackendConfig {
            target: Target::Ptx,
            profile: Profile::RustOut,
            hardware,
        }
    }

    fn ptx_config(sm: &str) -> BackendConfig {
        config_with(Some(HwProfile::Ptx {
            compute_capability: sm.to_string(),
        }))
    }

    /// What the default backend emits for the dummy module on `sm_80`.
    fn scaffold_text() -> String {
        PtxBackend::new()
            .lower(&dummy_module(), &ptx_config("sm_80"))
            .unwrap()
            .primary
    }

    // ─── Single-emitter backend ─────────────────────────────────────

    #[test]
    fn ptx_backend_emits_through_multi_emitter() {
        let cfg = ptx_config("sm_80");
        let artifact = PtxBackend::new().lower(&dummy_module(), &cfg).unwrap();
        // The quorum status is produced by the wrapped
        // MultiEmitterBackend, so it carries the scaffold emitter's name.
        let expected_status = QuorumStatus::Single {
            emitter: "xpile-ptx-codegen-scaffold".to_string(),
        };
        assert_eq!(artifact.quorum_status, expected_status);
        assert!(artifact.primary.contains("sm_80"));
        let cites_contract = artifact
            .citations
            .iter()
            .any(|c| c.as_str() == "C-COMPILE-RUST-TO-PTX-MMA");
        assert!(cites_contract);
    }

    #[test]
    fn ptx_backend_rejects_missing_hardware() {
        let cfg = config_with(None);
        let err = PtxBackend::new()
            .lower(&dummy_module(), &cfg)
            .unwrap_err();
        assert!(matches!(err, BackendError::MissingHardware(Target::Ptx)));
    }

    #[test]
    fn ptx_backend_targets_only_ptx() {
        let backend = PtxBackend::new();
        assert_eq!(backend.targets(), &[Target::Ptx]);
        assert_eq!(backend.name(), "ptx");
    }

    // ─── PMAT-280: Multi-emitter validation tests ───────────────────

    /// PMAT-280 — with the multi-emitter constructor, a matmul-named
    /// module is routed to the specialist. Under `PreferSpecialist`
    /// the artifact reports the specialist's name and carries its
    /// emission body rather than the scaffold's.
    #[test]
    fn matmul_module_routes_through_specialist_under_multi_emitter() {
        let cfg = ptx_config("sm_80");
        let artifact = PtxBackend::new_with_matmul_specialist()
            .lower(&matmul_module(), &cfg)
            .unwrap();
        assert_eq!(
            artifact.quorum_status,
            QuorumStatus::Single {
                emitter: "matmul-specialist-mock".to_string()
            },
            "PreferSpecialist with matching specialist should report Single {{ specialist }}"
        );
        assert!(
            artifact.primary.contains("matmul-specialist"),
            "primary should carry the specialist's emission body, got:\n{}",
            artifact.primary,
        );
    }

    /// PMAT-280 — a module without the matmul prefix still goes to the
    /// general (scaffold) emitter under the multi-emitter constructor:
    /// the specialist returns `None` and the `MultiEmitterBackend`
    /// falls through cleanly.
    #[test]
    fn non_matmul_module_falls_back_to_general_under_multi_emitter() {
        let cfg = ptx_config("sm_80");
        let artifact = PtxBackend::new_with_matmul_specialist()
            .lower(&dummy_module(), &cfg)
            .unwrap();
        assert_eq!(
            artifact.quorum_status,
            QuorumStatus::Single {
                emitter: "xpile-ptx-codegen-scaffold".to_string()
            },
            "non-matching specialist should let general emit; QuorumStatus should reflect general"
        );
        assert!(
            artifact.primary.contains("xpile-ptx-codegen scaffold"),
            "primary should carry the general scaffold's emission body, got:\n{}",
            artifact.primary,
        );
    }

    /// PMAT-280 — both constructors advertise the same target and
    /// name; the specialist is internal routing, not a separate
    /// `Backend`.
    #[test]
    fn multi_emitter_constructor_targets_match_single_emitter() {
        let with_specialist = PtxBackend::new_with_matmul_specialist();
        let plain = PtxBackend::new();
        assert_eq!(with_specialist.targets(), plain.targets());
        assert_eq!(with_specialist.name(), plain.name());
    }

    /// PMAT-280 — hardware rejection is just as eager with the
    /// multi-emitter constructor: a `None` hardware profile is refused
    /// by the wrapper before any emitter fires.
    #[test]
    fn multi_emitter_constructor_rejects_missing_hardware() {
        let cfg = config_with(None);
        let err = PtxBackend::new_with_matmul_specialist()
            .lower(&matmul_module(), &cfg)
            .unwrap_err();
        assert!(matches!(err, BackendError::MissingHardware(Target::Ptx)));
    }

    // ─── PMAT-481: offline PTX well-formedness gate ─────────────────

    /// A minimal but real PTX kernel, the shape `nvptx64-nvidia-cuda`
    /// rustc emits (verified on-box) — what PMAT-485 will produce.
    const GOLDEN_PTX_SM80: &str = concat!(
        "//\n",
        "// Generated by LLVM NVPTX Back-End\n",
        "//\n",
        ".version 6.0\n",
        ".target sm_80\n",
        ".address_size 64\n",
        "\n",
        "\t.visible .entry add_one(\n",
        "\t\t.param .u64 add_one_param_0\n",
        "\t)\n",
        "\t{\n",
        "\t\tret;\n",
        "\t}\n",
    );

    #[test]
    fn validate_ptx_accepts_well_formed_kernel() {
        let verdict = validate_ptx(GOLDEN_PTX_SM80, "sm_80");
        assert_eq!(verdict, Ok(()));
    }

    #[test]
    fn ptx_looks_real_classifies_golden_vs_scaffold() {
        assert!(ptx_looks_real(GOLDEN_PTX_SM80));
        // The v0.1.0 scaffold output is comment-only, so it must NOT be
        // classified as real PTX (PMAT-481 must never false-fail on it).
        let scaffold = scaffold_text();
        assert!(!ptx_looks_real(&scaffold));
    }

    #[test]
    fn validate_ptx_rejects_scaffold_placeholder() {
        let scaffold = scaffold_text();
        assert_eq!(
            validate_ptx(&scaffold, "sm_80"),
            Err(PtxValidationError::MissingVersion)
        );
    }

    #[test]
    fn validate_ptx_detects_target_mismatch() {
        // The arch is derived from the requested capability, never pinned.
        let expected_error = PtxValidationError::TargetMismatch {
            expected: "sm_90".into(),
            found: "sm_80".into(),
        };
        assert_eq!(validate_ptx(GOLDEN_PTX_SM80, "sm_90"), Err(expected_error));
    }

    #[test]
    fn validate_ptx_requires_address_size_and_entry() {
        let no_addr = ".version 6.0\n.target sm_80\n.visible .entry k() { ret; }\n";
        let no_entry = ".version 6.0\n.target sm_80\n.address_size 64\n";
        assert_eq!(
            validate_ptx(no_addr, "sm_80"),
            Err(PtxValidationError::MissingAddressSize)
        );
        assert_eq!(
            validate_ptx(no_entry, "sm_80"),
            Err(PtxValidationError::MissingEntry)
        );
    }

    #[test]
    fn ptxas_arch_derives_from_capability_not_hardcoded() {
        for (capability, flag) in [("sm_89", "-arch=sm_89"), ("sm_90", "-arch=sm_90")] {
            assert_eq!(ptxas_arch(capability), flag);
        }
    }
}