Skip to main content

vyre_emit_ptx/
lib.rs

1#![allow(
2    clippy::doc_lazy_continuation,
3    clippy::double_must_use,
4    clippy::manual_div_ceil,
5    clippy::needless_range_loop,
6    clippy::collapsible_if,
7    clippy::match_like_matches_macro,
8    clippy::redundant_closure,
9    clippy::too_many_arguments,
10    clippy::nonminimal_bool,
11    clippy::derivable_impls
12)]
13//! PTX text emitter for vyre `KernelDescriptor`.
14//!
15//! Consumes a substrate-neutral `vyre_lower::KernelDescriptor` and
16//! produces NVRTC-compatible PTX assembly text. The emitter owns only
17//! PTX construction; descriptor shaping and substrate-neutral
18//! analyses stay in `vyre-lower`.
19//!
20//! ## Op coverage
21//!
22//! Mirrors `vyre-emit-naga` for parity:
23//! - `Literal` (U32, I32, F32, Bool)
24//! - `LocalInvocationId` / `GlobalInvocationId` / `WorkgroupId` (axis 0/1/2)
25//! - `LoadGlobal` / `StoreGlobal` (scalar U32/I32/F32/Bool, plus packed
26//!   `v2`/`v4` U32/I32/F32 chains when the descriptor presents unit-stride
27//!   adjacent accesses)
28//! - `BinOpKind` for the common arithmetic/logic set
29//! - `UnOpKind` for Negate / LogicalNot / BitNot
30//! - `Cast` between scalar types
31//! - `Select`, `Fma`
32//! - `StructuredIfThen`, `StructuredIfThenElse`, `StructuredBlock`,
33//!   `Region`, `Return`, workgroup-scope `Barrier`
34//!
35//! Out of scope (returns `EmitError::UnsupportedOp` or
36//! `EmitError::InvalidDescriptor`): indirect-dispatch (host concern),
37//! `MemoryOrdering::GridSync` until a native cooperative-grid lowering is
38//! wired, and descriptor forms without a PTX-safe lowering.
39//!
40//! ## PTX output shape
41//!
42//! ```text
43//! //
44//! // Generated by vyre-emit-ptx (target sm_70)
45//! //
46//! .version 7.0
47//! .target sm_70
48//! .address_size 64
49//!
50//! .visible .entry main(
51//!     .param .u64 _arg_<binding_name>
52//! )
53//! {
54//!     .reg .pred  %p<N>;
55//!     .reg .u32   %r<N>;
56//!     .reg .s32   %s<N>;
57//!     .reg .f32   %f<N>;
58//!     .reg .u64   %rd<N>;
59//!
60//!     <body>
61//!
62//!     ret;
63//! }
64//! ```
65
66mod emitter;
67mod error;
68mod index_facts;
69pub mod patterns;
70mod reg;
71mod target;
72
73use vyre_lower::KernelDescriptor;
74
75pub use error::EmitError;
76pub use target::{ComputeCapability, PtxEmitOptions};
77
78pub fn emit(desc: &KernelDescriptor) -> Result<String, EmitError> {
79    emit_with_target(desc, ComputeCapability::default())
80}
81
82pub fn emit_with_target(
83    desc: &KernelDescriptor,
84    target: ComputeCapability,
85) -> Result<String, EmitError> {
86    emit_with_options(desc, PtxEmitOptions::for_target(target))
87}
88
89pub fn emit_with_options(
90    desc: &KernelDescriptor,
91    options: PtxEmitOptions,
92) -> Result<String, EmitError> {
93    if options.subgroup_size == 0
94        || options.subgroup_size > 32
95        || !options.subgroup_size.is_power_of_two()
96    {
97        return Err(EmitError::InvalidDescriptor(format!(
98            "invalid CUDA subgroup size {}. Fix: pass the probed CUDA warp size.",
99            options.subgroup_size
100        )));
101    }
102    emitter::emit_text(desc, options)
103}
104
105/// Emit PTX text from a `KernelDescriptor` after running the full
106/// `vyre_lower::rewrites::run_all` optimization pipeline. Recommended
107/// over [`emit`] for production use  -  fewer dead instructions, fewer
108/// redundant loads, lower register pressure.
109pub fn emit_optimized(desc: &KernelDescriptor) -> Result<String, EmitError> {
110    emit_optimized_with_stats(desc).map(|(s, _)| s)
111}
112
113/// Like [`emit_optimized`] but also returns
114/// [`vyre_lower::rewrites::OptimizationStats`].
115pub fn emit_optimized_with_stats(
116    desc: &KernelDescriptor,
117) -> Result<(String, vyre_lower::rewrites::OptimizationStats), EmitError> {
118    let (optimized, stats) = vyre_lower::rewrites::run_all_with_stats(desc);
119    debug_assert!(
120        vyre_lower::verify::verify(&optimized).is_ok(),
121        "rewrite pipeline produced an invalid descriptor  -  see vyre_lower::verify for the contract"
122    );
123    let ptx = emit(&optimized)?;
124    Ok((ptx, stats))
125}
126
127/// Same as [`emit_with_target`] but runs the optimization pipeline
128/// first.
129pub fn emit_optimized_with_target(
130    desc: &KernelDescriptor,
131    target: ComputeCapability,
132) -> Result<String, EmitError> {
133    emit_optimized_with_target_with_stats(desc, target).map(|(s, _)| s)
134}
135
136/// The full-power variant: optimize first AND target a specific
137/// compute capability AND surface OptimizationStats. Combines
138/// [`emit_optimized_with_target`] and [`emit_optimized_with_stats`].
139pub fn emit_optimized_with_target_with_stats(
140    desc: &KernelDescriptor,
141    target: ComputeCapability,
142) -> Result<(String, vyre_lower::rewrites::OptimizationStats), EmitError> {
143    let (optimized, stats) = vyre_lower::rewrites::run_all_with_stats(desc);
144    debug_assert!(
145        vyre_lower::verify::verify(&optimized).is_ok(),
146        "rewrite pipeline produced an invalid descriptor  -  see vyre_lower::verify for the contract"
147    );
148    let ptx = emit_with_target(&optimized, target)?;
149    Ok((ptx, stats))
150}
151
152#[cfg(test)]
153mod tests;