#![allow(
clippy::doc_lazy_continuation,
clippy::double_must_use,
clippy::manual_div_ceil,
clippy::needless_range_loop,
clippy::collapsible_if,
clippy::match_like_matches_macro,
clippy::redundant_closure,
clippy::too_many_arguments,
clippy::nonminimal_bool,
clippy::derivable_impls
)]
mod emitter;
mod error;
mod index_facts;
pub mod patterns;
mod reg;
mod target;
use vyre_lower::KernelDescriptor;
pub use error::EmitError;
pub use target::{ComputeCapability, PtxEmitOptions};
/// Emit PTX text for `desc`, targeting `ComputeCapability::default()`.
///
/// # Errors
///
/// Propagates any [`EmitError`] returned by [`emit_with_target`].
pub fn emit(desc: &KernelDescriptor) -> Result<String, EmitError> {
    let default_target = ComputeCapability::default();
    emit_with_target(desc, default_target)
}
/// Emit PTX text for `desc` for a specific compute capability.
///
/// The emission options are derived from `target` via `PtxEmitOptions::for_target`.
///
/// # Errors
///
/// Propagates any [`EmitError`] returned by [`emit_with_options`].
pub fn emit_with_target(
    desc: &KernelDescriptor,
    target: ComputeCapability,
) -> Result<String, EmitError> {
    let derived_options = PtxEmitOptions::for_target(target);
    emit_with_options(desc, derived_options)
}
/// Emit PTX text for `desc` using fully specified emission options.
///
/// # Errors
///
/// Returns [`EmitError::InvalidDescriptor`] when `options.subgroup_size` is not a
/// power of two in the range `1..=32`. Otherwise, returns whatever the internal
/// `emitter::emit_text` produces.
pub fn emit_with_options(
    desc: &KernelDescriptor,
    options: PtxEmitOptions,
) -> Result<String, EmitError> {
    let warp_width = options.subgroup_size;
    // Accept only power-of-two widths no larger than 32 lanes; zero is rejected by
    // the range check.
    let width_ok = (1..=32).contains(&warp_width) && warp_width.is_power_of_two();
    if !width_ok {
        return Err(EmitError::InvalidDescriptor(format!(
            "invalid CUDA subgroup size {}. Fix: pass the probed CUDA warp size.",
            warp_width
        )));
    }
    emitter::emit_text(desc, options)
}
/// Run the rewrite pipeline on `desc` and emit PTX for the default compute
/// capability, discarding the optimization statistics.
///
/// # Errors
///
/// Propagates any [`EmitError`] returned by [`emit_optimized_with_stats`].
pub fn emit_optimized(desc: &KernelDescriptor) -> Result<String, EmitError> {
    let (ptx_text, _stats) = emit_optimized_with_stats(desc)?;
    Ok(ptx_text)
}
pub fn emit_optimized_with_stats(
desc: &KernelDescriptor,
) -> Result<(String, vyre_lower::rewrites::OptimizationStats), EmitError> {
let (optimized, stats) = vyre_lower::rewrites::run_all_with_stats(desc);
debug_assert!(
vyre_lower::verify::verify(&optimized).is_ok(),
"rewrite pipeline produced an invalid descriptor - see vyre_lower::verify for the contract"
);
let ptx = emit(&optimized)?;
Ok((ptx, stats))
}
/// Run the rewrite pipeline on `desc` and emit PTX for `target`, discarding the
/// optimization statistics.
///
/// # Errors
///
/// Propagates any [`EmitError`] returned by
/// [`emit_optimized_with_target_with_stats`].
pub fn emit_optimized_with_target(
    desc: &KernelDescriptor,
    target: ComputeCapability,
) -> Result<String, EmitError> {
    let (ptx_text, _stats) = emit_optimized_with_target_with_stats(desc, target)?;
    Ok(ptx_text)
}
/// Run the full `vyre_lower` rewrite pipeline on `desc`, then emit PTX for `target`.
///
/// Returns the PTX text together with the statistics reported by
/// `vyre_lower::rewrites::run_all_with_stats`.
///
/// # Panics
///
/// In debug builds, panics if the rewritten descriptor fails
/// `vyre_lower::verify::verify`. Release builds skip this check.
///
/// # Errors
///
/// Propagates any [`EmitError`] returned by [`emit_with_target`].
pub fn emit_optimized_with_target_with_stats(
    desc: &KernelDescriptor,
    target: ComputeCapability,
) -> Result<(String, vyre_lower::rewrites::OptimizationStats), EmitError> {
    let (rewritten, rewrite_stats) = vyre_lower::rewrites::run_all_with_stats(desc);
    // Sanity check the rewrite pipeline's output; compiled out of release builds.
    debug_assert!(
        vyre_lower::verify::verify(&rewritten).is_ok(),
        "rewrite pipeline produced an invalid descriptor - see vyre_lower::verify for the contract"
    );
    emit_with_target(&rewritten, target).map(|ptx_text| (ptx_text, rewrite_stats))
}
#[cfg(test)]
mod tests;