use xpile_backend::{
Artifact, Backend, BackendConfig, BackendError, EmittedText, HwProfile, MultiEmitterBackend,
QuorumPolicy, Target, TargetEmitter,
};
use xpile_contracts::ContractId;
use xpile_meta_hir::Module;
/// Backend that lowers modules to PTX.
///
/// This is a thin wrapper around a [`MultiEmitterBackend`] fixed to
/// [`Target::Ptx`]. The constructors choose which emitters take part in
/// lowering.
pub struct PtxBackend {
    inner: MultiEmitterBackend,
}

impl Default for PtxBackend {
    /// Same as [`PtxBackend::new`]: only the general scaffold emitter is used.
    fn default() -> Self {
        PtxBackend::new()
    }
}

impl PtxBackend {
    /// Creates a backend whose single emitter is the general PTX scaffold.
    pub fn new() -> Self {
        let inner = MultiEmitterBackend::new_single(Target::Ptx, Box::new(ScaffoldPtxEmitter));
        PtxBackend { inner }
    }

    /// Creates a backend that pairs the general PTX scaffold with a mock
    /// matmul specialist under [`QuorumPolicy::PreferSpecialist`].
    ///
    /// The specialist only claims modules whose name starts with `matmul_`.
    /// Other modules are handled by the general scaffold.
    pub fn new_with_matmul_specialist() -> Self {
        let general = Box::new(ScaffoldPtxEmitter);
        let specialist = Box::new(MatmulSpecialistEmitter);
        let inner = MultiEmitterBackend::new_with_specialist(
            Target::Ptx,
            general,
            specialist,
            QuorumPolicy::PreferSpecialist,
        );
        PtxBackend { inner }
    }
}
impl Backend for PtxBackend {
    fn name(&self) -> &'static str {
        "ptx"
    }

    fn targets(&self) -> &[Target] {
        &[Target::Ptx]
    }

    /// Lowers `module` to PTX text through the wrapped multi-emitter backend.
    ///
    /// # Errors
    ///
    /// Returns [`BackendError::MissingHardware`] when `config.hardware` is not
    /// a [`HwProfile::Ptx`] profile. Otherwise returns whatever the wrapped
    /// backend's `lower` returns.
    fn lower(&self, module: &Module, config: &BackendConfig) -> Result<Artifact, BackendError> {
        // Validate up front so lowering fails before any emitter is consulted.
        let has_ptx_profile = matches!(config.hardware, Some(HwProfile::Ptx { .. }));
        if !has_ptx_profile {
            return Err(BackendError::MissingHardware(Target::Ptx));
        }
        self.inner.lower(module, config)
    }
}
/// General-purpose PTX emitter.
///
/// It currently emits only a commented placeholder that records the module
/// name and compute capability, not real PTX.
struct ScaffoldPtxEmitter;

impl TargetEmitter for ScaffoldPtxEmitter {
    fn name(&self) -> &str {
        "xpile-ptx-codegen-scaffold"
    }

    /// Always claims the module, so the result is never `None`.
    ///
    /// Yields `Some(Err(BackendError::MissingHardware(Target::Ptx)))` when
    /// `config.hardware` is not a PTX profile.
    fn try_emit(
        &self,
        module: &Module,
        config: &BackendConfig,
    ) -> Option<Result<EmittedText, BackendError>> {
        let Some(HwProfile::Ptx { compute_capability }) = &config.hardware else {
            return Some(Err(BackendError::MissingHardware(Target::Ptx)));
        };

        let primary = format!(
            "// xpile-ptx-codegen scaffold\n// module: {}\n// compute_capability: {}\n// TODO: lower to real PTX via rustc_codegen_nvvm.\n",
            module.name, compute_capability,
        );
        let citations = vec![ContractId::new("C-COMPILE-RUST-TO-PTX-MMA")];

        Some(Ok(EmittedText { primary, citations }))
    }
}
/// Mock matmul specialist used to exercise specialist routing.
///
/// It only claims modules whose name starts with `matmul_`. For those it
/// emits a commented placeholder, not real PTX.
struct MatmulSpecialistEmitter;

impl TargetEmitter for MatmulSpecialistEmitter {
    fn name(&self) -> &str {
        "matmul-specialist-mock"
    }

    /// Returns `None` for modules whose name lacks the `matmul_` prefix, so
    /// another emitter can handle them.
    ///
    /// For claimed modules, yields
    /// `Some(Err(BackendError::MissingHardware(Target::Ptx)))` when
    /// `config.hardware` is not a PTX profile.
    fn try_emit(
        &self,
        module: &Module,
        config: &BackendConfig,
    ) -> Option<Result<EmittedText, BackendError>> {
        // Decline before inspecting the config: a non-matmul module is never
        // this emitter's concern, whatever the hardware profile says.
        if !module.name.starts_with("matmul_") {
            return None;
        }

        let emitted = match &config.hardware {
            Some(HwProfile::Ptx { compute_capability }) => Ok(EmittedText {
                primary: format!(
                    "// matmul-specialist scaffold\n// module: {}\n// compute_capability: {}\n// TODO: emit mma.sync.aligned via aprender-gpu shape templates.\n",
                    module.name, compute_capability,
                ),
                citations: vec![ContractId::new("C-COMPILE-RUST-TO-PTX-MMA")],
            }),
            _ => Err(BackendError::MissingHardware(Target::Ptx)),
        };
        Some(emitted)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use xpile_backend::{Profile, QuorumStatus};
    use xpile_meta_hir::SourceLang;

    /// Empty Rust-sourced module that carries only a name.
    fn module_named(name: &'static str) -> Module {
        Module {
            name: name.into(),
            source_lang: SourceLang::Rust,
            items: Vec::new(),
            ffi_boundaries: Vec::new(),
        }
    }

    /// Module the matmul specialist should decline (no `matmul_` prefix).
    fn dummy_module() -> Module {
        module_named("test_kernel")
    }

    /// Module the matmul specialist should claim (`matmul_` prefix).
    fn matmul_module() -> Module {
        module_named("matmul_gemm_fp16")
    }

    /// PTX-targeted config with the given hardware slot.
    fn config_with(hardware: Option<HwProfile>) -> BackendConfig {
        BackendConfig {
            target: Target::Ptx,
            profile: Profile::RustOut,
            hardware,
        }
    }

    /// PTX-targeted config for compute capability `sm`.
    fn ptx_config(sm: &str) -> BackendConfig {
        config_with(Some(HwProfile::Ptx {
            compute_capability: sm.to_string(),
        }))
    }

    /// Quorum status expected when exactly one named emitter produced output.
    fn single(emitter: &str) -> QuorumStatus {
        QuorumStatus::Single {
            emitter: emitter.to_string(),
        }
    }

    #[test]
    fn ptx_backend_emits_through_multi_emitter() {
        let backend = PtxBackend::new();
        let artifact = backend
            .lower(&dummy_module(), &ptx_config("sm_80"))
            .unwrap();

        assert_eq!(artifact.quorum_status, single("xpile-ptx-codegen-scaffold"));
        assert!(artifact.primary.contains("sm_80"));
        let cites_mma_contract = artifact
            .citations
            .iter()
            .any(|c| c.as_str() == "C-COMPILE-RUST-TO-PTX-MMA");
        assert!(cites_mma_contract);
    }

    #[test]
    fn ptx_backend_rejects_missing_hardware() {
        let backend = PtxBackend::new();
        let err = backend
            .lower(&dummy_module(), &config_with(None))
            .unwrap_err();
        assert!(matches!(err, BackendError::MissingHardware(Target::Ptx)));
    }

    #[test]
    fn ptx_backend_targets_only_ptx() {
        let backend = PtxBackend::new();
        assert_eq!(backend.targets(), &[Target::Ptx]);
        assert_eq!(backend.name(), "ptx");
    }

    #[test]
    fn matmul_module_routes_through_specialist_under_multi_emitter() {
        let backend = PtxBackend::new_with_matmul_specialist();
        let artifact = backend
            .lower(&matmul_module(), &ptx_config("sm_80"))
            .unwrap();

        assert_eq!(
            artifact.quorum_status,
            single("matmul-specialist-mock"),
            "PreferSpecialist with matching specialist should report Single {{ specialist }}"
        );
        assert!(
            artifact.primary.contains("matmul-specialist"),
            "primary should carry the specialist's emission body, got:\n{}",
            artifact.primary,
        );
    }

    #[test]
    fn non_matmul_module_falls_back_to_general_under_multi_emitter() {
        let backend = PtxBackend::new_with_matmul_specialist();
        let artifact = backend
            .lower(&dummy_module(), &ptx_config("sm_80"))
            .unwrap();

        assert_eq!(
            artifact.quorum_status,
            single("xpile-ptx-codegen-scaffold"),
            "non-matching specialist should let general emit; QuorumStatus should reflect general"
        );
        assert!(
            artifact.primary.contains("xpile-ptx-codegen scaffold"),
            "primary should carry the general scaffold's emission body, got:\n{}",
            artifact.primary,
        );
    }

    #[test]
    fn multi_emitter_constructor_targets_match_single_emitter() {
        let multi = PtxBackend::new_with_matmul_specialist();
        let single_backend = PtxBackend::new();
        assert_eq!(multi.targets(), single_backend.targets());
        assert_eq!(multi.name(), single_backend.name());
    }

    #[test]
    fn multi_emitter_constructor_rejects_missing_hardware() {
        let backend = PtxBackend::new_with_matmul_specialist();
        let err = backend
            .lower(&matmul_module(), &config_with(None))
            .unwrap_err();
        assert!(matches!(err, BackendError::MissingHardware(Target::Ptx)));
    }
}