use std::fmt::Write as _;
use rustc_hash::FxHashSet;
use vyre_lower::{KernelDescriptor, MemoryClass, TRAP_SIDECAR_NAME};
use super::names::sanitize_param_name;
use super::sizing::{body_op_count_recursive, estimate_body_text_capacity};
use super::BodyCtx;
use crate::{EmitError, PtxEmitOptions};
/// Accumulates the text of one PTX module.
///
/// The module is a header written by [`ModuleBuilder::write_preamble`], followed by a
/// single `main` kernel written by [`ModuleBuilder::write_entry_point`].
pub(super) struct ModuleBuilder {
    /// Emission options; `options.target` selects the `.version` / `.target` lines.
    options: PtxEmitOptions,
    /// PTX text produced so far.
    text: String,
}

impl ModuleBuilder {
    /// Creates an empty builder whose output buffer is pre-sized to `text_capacity` bytes.
    pub(super) fn new(options: PtxEmitOptions, text_capacity: usize) -> Self {
        let text = String::with_capacity(text_capacity);
        Self { options, text }
    }

    /// Appends the module header.
    ///
    /// The header is a banner comment naming the SM target, followed by the `.version`,
    /// `.target` and `.address_size 64` directives.
    ///
    /// The PTX ISA version is chosen from the code `major * 10 + minor`:
    /// `>= 120` → `8.7`, `100..=119` → `8.6`, `90..=99` → `8.0`, anything else → `8.5`.
    pub(super) fn write_preamble(&mut self) {
        let (major, minor) = (self.options.target.major, self.options.target.minor);
        let sm_code = major * 10 + minor;
        // NOTE(review): the 90..=99 band gets a *lower* ISA version than the fallback
        // used for older targets. This is presumably deliberate; confirm before changing.
        let ptx_version = if sm_code >= 120 {
            "8.7"
        } else if sm_code >= 100 {
            "8.6"
        } else if sm_code >= 90 {
            "8.0"
        } else {
            "8.5"
        };
        let _ = writeln!(
            self.text,
            "//\n// Generated by vyre-emit-ptx (target sm_{major}{minor})\n//\n.version {ptx_version}\n.target sm_{major}{minor}\n.address_size 64\n"
        );
    }

    /// Appends the `.visible .entry main(...)` kernel.
    ///
    /// The kernel consists of:
    /// - one `.u64` parameter per binding, skipping shared-memory bindings and the trap
    ///   sidecar;
    /// - a trailing `params_buf` parameter;
    /// - the register declarations;
    /// - the lowered kernel body.
    ///
    /// # Errors
    ///
    /// Propagates any [`EmitError`] from `BodyCtx::preload_bindings` or
    /// `BodyCtx::emit_body`. When that happens, the kernel header and parameter list
    /// have already been appended to the module text.
    pub(super) fn write_entry_point(&mut self, desc: &KernelDescriptor) -> Result<(), EmitError> {
        self.text.push_str(".visible .entry main(\n");
        for binding in &desc.bindings.slots {
            let is_kernel_param = !matches!(binding.memory_class, MemoryClass::Shared)
                && binding.name != TRAP_SIDECAR_NAME;
            if is_kernel_param {
                // `params_buf` is always the last parameter, so every binding
                // parameter can end with a separating comma.
                let _ = writeln!(
                    self.text,
                    " .param .u64 _arg_{},",
                    sanitize_param_name(&binding.name, binding.slot)
                );
            }
        }
        self.text.push_str(" .param .u64 params_buf\n) {\n");

        // Binding slots that the texture-promote analysis reports as candidates.
        let read_only_cache_slots: FxHashSet<_> =
            vyre_lower::analyses::analyze_texture_promote(desc)
                .candidates
                .into_iter()
                .map(|candidate| candidate.binding_slot)
                .collect();
        let body_text_capacity = estimate_body_text_capacity(&desc.body, &desc.bindings);
        let body_op_count = body_op_count_recursive(&desc.body);
        let mut body = BodyCtx::new(
            &desc.bindings,
            self.options,
            read_only_cache_slots,
            body_text_capacity,
            body_op_count,
        );
        body.preload_bindings(desc)?;
        body.emit_body(&desc.body)?;
        body.finish_with_return();

        // Register counts are only known after the whole body has been emitted. The
        // declarations are therefore written now, but placed ahead of the body text.
        self.text.push_str(" // register declarations\n");
        if body.next_pred > 0 {
            let _ = writeln!(self.text, " .reg .pred %p<{}>;", body.next_pred);
        }
        if body.next_b16 > 0 {
            let _ = writeln!(self.text, " .reg .b16 %h<{}>;", body.next_b16);
        }
        if body.next_u32 > 0 {
            let _ = writeln!(self.text, " .reg .u32 %r<{}>;", body.next_u32);
        }
        if body.next_i32 > 0 {
            let _ = writeln!(self.text, " .reg .s32 %s<{}>;", body.next_i32);
        }
        if body.next_f32 > 0 {
            let _ = writeln!(self.text, " .reg .f32 %f<{}>;", body.next_f32);
        }
        if body.next_u64 > 0 {
            let _ = writeln!(self.text, " .reg .u64 %rd<{}>;", body.next_u64);
        }
        self.text.push('\n');

        self.text.push_str(&body.text);
        self.text.push_str("}\n");
        Ok(())
    }

    /// Consumes the builder and returns the accumulated PTX source.
    pub(super) fn finish(self) -> String {
        let Self { text, .. } = self;
        text
    }
}