use xpile_backend::{
Artifact, Backend, BackendConfig, BackendError, EmittedText, HwProfile, MultiEmitterBackend,
QuorumPolicy, Target, TargetEmitter,
};
use xpile_contracts::ContractId;
use xpile_meta_hir::Module;
/// PTX code-generation backend for [`Target::Ptx`].
///
/// All emission is delegated to an inner [`MultiEmitterBackend`]; which emitters are
/// registered (general scaffold only, or scaffold plus matmul specialist) is decided by the
/// constructor used.
pub struct PtxBackend {
    // Multi-emitter dispatcher that performs the actual lowering.
    inner: MultiEmitterBackend,
}
impl Default for PtxBackend {
fn default() -> Self {
Self::new()
}
}
impl PtxBackend {
    /// Builds a backend whose only emitter is the general PTX scaffold emitter.
    pub fn new() -> Self {
        let inner = MultiEmitterBackend::new_single(Target::Ptx, Box::new(ScaffoldPtxEmitter));
        Self { inner }
    }

    /// Builds a backend with the general scaffold emitter plus the matmul specialist,
    /// combined under [`QuorumPolicy::PreferSpecialist`].
    pub fn new_with_matmul_specialist() -> Self {
        let inner = MultiEmitterBackend::new_with_specialist(
            Target::Ptx,
            Box::new(ScaffoldPtxEmitter),
            Box::new(MatmulSpecialistEmitter),
            QuorumPolicy::PreferSpecialist,
        );
        Self { inner }
    }
}
impl Backend for PtxBackend {
    /// Stable identifier of this backend.
    fn name(&self) -> &'static str {
        "ptx"
    }

    /// This backend only produces PTX.
    fn targets(&self) -> &[Target] {
        &[Target::Ptx]
    }

    /// Lowers `module` through the inner multi-emitter backend.
    ///
    /// # Errors
    ///
    /// Returns [`BackendError::MissingHardware`] for [`Target::Ptx`] when `config.hardware`
    /// is not a PTX hardware profile; the inner backend is not invoked in that case.
    /// Otherwise any error from the inner backend is returned unchanged.
    fn lower(&self, module: &Module, config: &BackendConfig) -> Result<Artifact, BackendError> {
        let has_ptx_hardware = matches!(config.hardware, Some(HwProfile::Ptx { .. }));
        if has_ptx_hardware {
            self.inner.lower(module, config)
        } else {
            Err(BackendError::MissingHardware(Target::Ptx))
        }
    }
}
/// General-purpose PTX emitter that currently produces only a comment-only placeholder
/// (it does not yet emit real PTX).
struct ScaffoldPtxEmitter;

impl TargetEmitter for ScaffoldPtxEmitter {
    fn name(&self) -> &str {
        "xpile-ptx-codegen-scaffold"
    }

    /// Always produces an emission (never `None`).
    ///
    /// The placeholder text records the module name and requested compute capability.
    /// If the config carries no PTX hardware profile, the result is an error instead.
    fn try_emit(
        &self,
        module: &Module,
        config: &BackendConfig,
    ) -> Option<Result<EmittedText, BackendError>> {
        let outcome = if let Some(HwProfile::Ptx { compute_capability }) = &config.hardware {
            let primary = format!(
                "// xpile-ptx-codegen scaffold\n// module: {}\n// compute_capability: {}\n// TODO: lower to real PTX via rustc_codegen_nvvm.\n",
                module.name, compute_capability,
            );
            Ok(EmittedText {
                primary,
                citations: vec![ContractId::new("C-COMPILE-RUST-TO-PTX-MMA")],
            })
        } else {
            Err(BackendError::MissingHardware(Target::Ptx))
        };
        Some(outcome)
    }
}
/// Mock specialist emitter for matmul kernels; like the scaffold it emits a comment-only
/// placeholder rather than real PTX.
struct MatmulSpecialistEmitter;

impl TargetEmitter for MatmulSpecialistEmitter {
    fn name(&self) -> &str {
        "matmul-specialist-mock"
    }

    /// Returns `None` (declines) for modules whose name does not start with `matmul_`.
    ///
    /// That check happens before the hardware check, so a declined module never reports
    /// a hardware error. For claimed modules it emits a placeholder, or a
    /// `MissingHardware` error when no PTX hardware profile is configured.
    fn try_emit(
        &self,
        module: &Module,
        config: &BackendConfig,
    ) -> Option<Result<EmittedText, BackendError>> {
        let claims_module = module.name.starts_with("matmul_");
        claims_module.then(|| match &config.hardware {
            Some(HwProfile::Ptx { compute_capability }) => Ok(EmittedText {
                primary: format!(
                    "// matmul-specialist scaffold\n// module: {}\n// compute_capability: {}\n// TODO: emit mma.sync.aligned via aprender-gpu shape templates.\n",
                    module.name, compute_capability,
                ),
                citations: vec![ContractId::new("C-COMPILE-RUST-TO-PTX-MMA")],
            }),
            _ => Err(BackendError::MissingHardware(Target::Ptx)),
        })
    }
}
/// Reasons [`validate_ptx`] can reject a PTX module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PtxValidationError {
    /// No uncommented `.version` directive was found.
    MissingVersion,
    /// No `.target` directive naming an architecture was found.
    MissingTarget,
    /// The first `.target` architecture differs from the requested compute capability.
    TargetMismatch { expected: String, found: String },
    /// No uncommented `.address_size 64` directive was found.
    MissingAddressSize,
    /// No uncommented `.visible .entry` kernel declaration was found.
    MissingEntry,
}

impl std::fmt::Display for PtxValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingVersion => write!(f, "PTX is missing a `.version` directive"),
            Self::MissingTarget => write!(f, "PTX is missing a `.target` directive"),
            Self::TargetMismatch { expected, found } => write!(
                f,
                "PTX `.target {found}` does not match requested compute capability `{expected}`"
            ),
            Self::MissingAddressSize => write!(f, "PTX is missing `.address_size 64`"),
            Self::MissingEntry => write!(f, "PTX has no `.visible .entry` kernel"),
        }
    }
}

impl std::error::Error for PtxValidationError {}

/// Cheap heuristic: `true` when `text` contains an uncommented `.version` directive.
///
/// Comment-only placeholder output (such as the scaffold emitters produce) returns `false`.
pub fn ptx_looks_real(text: &str) -> bool {
    directive_present(text, ".version")
}

/// Builds the `ptxas` architecture flag for `compute_capability`,
/// e.g. `"sm_89"` becomes `"-arch=sm_89"`.
pub fn ptxas_arch(compute_capability: &str) -> String {
    format!("-arch={compute_capability}")
}

/// Performs a structural sanity check of PTX `text` against the requested compute capability.
///
/// Only uncommented code is inspected (`//` line comments are ignored). Directives are
/// matched by whitespace-separated tokens, so any spacing or tabs between tokens is accepted.
///
/// # Errors
///
/// Checks run in this order and the first failure is returned:
/// 1. [`PtxValidationError::MissingVersion`]: no `.version` directive.
/// 2. [`PtxValidationError::MissingTarget`]: no `.target` directive naming an architecture.
/// 3. [`PtxValidationError::TargetMismatch`]: the first `.target` architecture is not
///    exactly `compute_capability`.
/// 4. [`PtxValidationError::MissingAddressSize`]: no `.address_size 64` directive.
/// 5. [`PtxValidationError::MissingEntry`]: no `.visible .entry` declaration.
pub fn validate_ptx(text: &str, compute_capability: &str) -> Result<(), PtxValidationError> {
    if !directive_present(text, ".version") {
        return Err(PtxValidationError::MissingVersion);
    }
    let target = ptx_target_arch(text).ok_or(PtxValidationError::MissingTarget)?;
    if target != compute_capability {
        return Err(PtxValidationError::TargetMismatch {
            expected: compute_capability.to_string(),
            found: target,
        });
    }
    if !directive_present(text, ".address_size 64") {
        return Err(PtxValidationError::MissingAddressSize);
    }
    // Token-based, like the checks above: `.visible\t.entry` is accepted, while a
    // commented-out declaration no longer satisfies the check.
    if !code_lines(text).any(|line| contains_tokens(line, &[".visible", ".entry"])) {
        return Err(PtxValidationError::MissingEntry);
    }
    Ok(())
}

/// Yields the code portion of every line of `text`.
///
/// Anything from the first `//` onward is dropped and the remainder is trimmed, so
/// comment-only lines come out as empty strings.
fn code_lines(text: &str) -> impl Iterator<Item = &str> {
    text.lines()
        .map(|line| line.split_once("//").map_or(line, |(code, _)| code).trim())
}

/// `true` when some code line begins with the whitespace-separated tokens of `directive`.
///
/// For example, `".address_size 64"` matches `.address_size\t64` but not `.address_size 640`.
fn directive_present(text: &str, directive: &str) -> bool {
    let wanted: Vec<&str> = directive.split_whitespace().collect();
    code_lines(text).any(|line| {
        line.split_whitespace()
            .take(wanted.len())
            .eq(wanted.iter().copied())
    })
}

/// `true` when `needle` occurs as a run of consecutive whitespace-separated tokens in `line`.
fn contains_tokens(line: &str, needle: &[&str]) -> bool {
    // `slice::windows(0)` panics, and an empty needle trivially matches.
    if needle.is_empty() {
        return true;
    }
    let tokens: Vec<&str> = line.split_whitespace().collect();
    tokens.windows(needle.len()).any(|window| window == needle)
}

/// Returns the architecture named by the first `.target` directive that has one.
///
/// `.target` takes a comma-separated list (e.g. `.target sm_80, debug`); the first entry
/// is the architecture. Returns `None` when no such directive exists.
fn ptx_target_arch(text: &str) -> Option<String> {
    code_lines(text).find_map(|line| {
        let rest = line.strip_prefix(".target")?;
        // Reject look-alike directives such as `.targets`.
        if !rest.is_empty() && !rest.starts_with(char::is_whitespace) {
            return None;
        }
        let arch = rest
            .trim_start()
            .split(|c: char| c == ',' || c.is_whitespace())
            .next()?;
        (!arch.is_empty()).then(|| arch.to_string())
    })
}
// Unit tests covering two areas:
// - backend wiring: emitter selection, quorum status, hardware checks;
// - the PTX validation helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use xpile_backend::{Profile, QuorumStatus};
    use xpile_meta_hir::SourceLang;

    /// Empty Rust module whose name does not start with `matmul_`, so the matmul
    /// specialist declines it.
    fn dummy_module() -> Module {
        Module {
            name: "test_kernel".into(),
            source_lang: SourceLang::Rust,
            items: Vec::new(),
            ffi_boundaries: Vec::new(),
        }
    }

    /// PTX backend config whose hardware profile carries compute capability `sm`.
    fn ptx_config(sm: &str) -> BackendConfig {
        BackendConfig {
            target: Target::Ptx,
            profile: Profile::RustOut,
            hardware: Some(HwProfile::Ptx {
                compute_capability: sm.to_string(),
            }),
        }
    }

    /// The single-emitter backend routes through the scaffold emitter.
    /// The output echoes the compute capability and carries the MMA contract citation.
    #[test]
    fn ptx_backend_emits_through_multi_emitter() {
        let backend = PtxBackend::new();
        let artifact = backend
            .lower(&dummy_module(), &ptx_config("sm_80"))
            .unwrap();
        assert_eq!(
            artifact.quorum_status,
            QuorumStatus::Single {
                emitter: "xpile-ptx-codegen-scaffold".to_string()
            }
        );
        assert!(artifact.primary.contains("sm_80"));
        assert!(artifact
            .citations
            .iter()
            .any(|c| c.as_str() == "C-COMPILE-RUST-TO-PTX-MMA"));
    }

    /// Lowering without a hardware profile fails with `MissingHardware(Target::Ptx)`.
    #[test]
    fn ptx_backend_rejects_missing_hardware() {
        let backend = PtxBackend::new();
        let cfg = BackendConfig {
            target: Target::Ptx,
            profile: Profile::RustOut,
            hardware: None,
        };
        let err = backend.lower(&dummy_module(), &cfg).unwrap_err();
        assert!(matches!(err, BackendError::MissingHardware(Target::Ptx)));
    }

    /// The backend advertises exactly one target and the name `"ptx"`.
    #[test]
    fn ptx_backend_targets_only_ptx() {
        let backend = PtxBackend::new();
        assert_eq!(backend.targets(), &[Target::Ptx]);
        assert_eq!(backend.name(), "ptx");
    }

    /// Empty Rust module whose `matmul_` name prefix makes the matmul specialist claim it.
    fn matmul_module() -> Module {
        Module {
            name: "matmul_gemm_fp16".into(),
            source_lang: SourceLang::Rust,
            items: Vec::new(),
            ffi_boundaries: Vec::new(),
        }
    }

    /// With `PreferSpecialist`, a matmul module is emitted by the specialist alone.
    #[test]
    fn matmul_module_routes_through_specialist_under_multi_emitter() {
        let backend = PtxBackend::new_with_matmul_specialist();
        let artifact = backend
            .lower(&matmul_module(), &ptx_config("sm_80"))
            .unwrap();
        assert_eq!(
            artifact.quorum_status,
            QuorumStatus::Single {
                emitter: "matmul-specialist-mock".to_string()
            },
            "PreferSpecialist with matching specialist should report Single {{ specialist }}"
        );
        assert!(
            artifact.primary.contains("matmul-specialist"),
            "primary should carry the specialist's emission body, got:\n{}",
            artifact.primary,
        );
    }

    /// When the specialist declines a module, the general scaffold emitter handles it.
    #[test]
    fn non_matmul_module_falls_back_to_general_under_multi_emitter() {
        let backend = PtxBackend::new_with_matmul_specialist();
        let artifact = backend
            .lower(&dummy_module(), &ptx_config("sm_80"))
            .unwrap();
        assert_eq!(
            artifact.quorum_status,
            QuorumStatus::Single {
                emitter: "xpile-ptx-codegen-scaffold".to_string()
            },
            "non-matching specialist should let general emit; QuorumStatus should reflect general"
        );
        assert!(
            artifact.primary.contains("xpile-ptx-codegen scaffold"),
            "primary should carry the general scaffold's emission body, got:\n{}",
            artifact.primary,
        );
    }

    /// Both constructors expose the same backend identity (name and targets).
    #[test]
    fn multi_emitter_constructor_targets_match_single_emitter() {
        let multi = PtxBackend::new_with_matmul_specialist();
        let single = PtxBackend::new();
        assert_eq!(multi.targets(), single.targets());
        assert_eq!(multi.name(), single.name());
    }

    /// The specialist configuration also rejects a missing hardware profile,
    /// even for a module the specialist would claim.
    #[test]
    fn multi_emitter_constructor_rejects_missing_hardware() {
        let backend = PtxBackend::new_with_matmul_specialist();
        let cfg = BackendConfig {
            target: Target::Ptx,
            profile: Profile::RustOut,
            hardware: None,
        };
        let err = backend.lower(&matmul_module(), &cfg).unwrap_err();
        assert!(matches!(err, BackendError::MissingHardware(Target::Ptx)));
    }

    // Minimal well-formed PTX for sm_80 mimicking LLVM NVPTX output, used as a validation
    // fixture. The literal's body lines below must stay at column 0: after the first line,
    // any leading whitespace would become part of the string.
    const GOLDEN_PTX_SM80: &str = "\
//
// Generated by LLVM NVPTX Back-End
//
.version 6.0
.target sm_80
.address_size 64
\t.visible .entry add_one(
\t\t.param .u64 add_one_param_0
\t)
\t{
\t\tret;
\t}
";

    /// The golden fixture passes every validation check.
    #[test]
    fn validate_ptx_accepts_well_formed_kernel() {
        assert_eq!(validate_ptx(GOLDEN_PTX_SM80, "sm_80"), Ok(()));
    }

    /// The heuristic accepts the golden fixture and rejects the scaffold's comment-only output.
    #[test]
    fn ptx_looks_real_classifies_golden_vs_scaffold() {
        assert!(ptx_looks_real(GOLDEN_PTX_SM80));
        let scaffold = PtxBackend::new()
            .lower(&dummy_module(), &ptx_config("sm_80"))
            .unwrap()
            .primary;
        assert!(!ptx_looks_real(&scaffold));
    }

    /// Scaffold output fails at the first check (`.version`).
    #[test]
    fn validate_ptx_rejects_scaffold_placeholder() {
        let scaffold = PtxBackend::new()
            .lower(&dummy_module(), &ptx_config("sm_80"))
            .unwrap()
            .primary;
        assert_eq!(
            validate_ptx(&scaffold, "sm_80"),
            Err(PtxValidationError::MissingVersion)
        );
    }

    /// Requesting a different compute capability reports both the expected and found targets.
    #[test]
    fn validate_ptx_detects_target_mismatch() {
        assert_eq!(
            validate_ptx(GOLDEN_PTX_SM80, "sm_90"),
            Err(PtxValidationError::TargetMismatch {
                expected: "sm_90".into(),
                found: "sm_80".into(),
            })
        );
    }

    /// The `.address_size 64` check and the `.visible .entry` check each fail independently.
    #[test]
    fn validate_ptx_requires_address_size_and_entry() {
        let no_addr = ".version 6.0\n.target sm_80\n.visible .entry k() { ret; }\n";
        assert_eq!(
            validate_ptx(no_addr, "sm_80"),
            Err(PtxValidationError::MissingAddressSize)
        );
        let no_entry = ".version 6.0\n.target sm_80\n.address_size 64\n";
        assert_eq!(
            validate_ptx(no_entry, "sm_80"),
            Err(PtxValidationError::MissingEntry)
        );
    }

    /// The `ptxas` flag follows the requested capability rather than a fixed architecture.
    #[test]
    fn ptxas_arch_derives_from_capability_not_hardcoded() {
        assert_eq!(ptxas_arch("sm_89"), "-arch=sm_89");
        assert_eq!(ptxas_arch("sm_90"), "-arch=sm_90");
    }
}