use crate::compiler::{CUDATileFunctionCompiler, CUDATileModules};
use crate::error::JITError;
use crate::hints::CompileOptions;
use crate::specialization::SpecializationBits;
/// Result of [`KernelCompiler::compile`]: owns the compiled `cutile_ir::Module`
/// and exposes it as text, as serialized bytes, or directly.
pub struct CompileArtifacts {
    /// The compiled IR module.
    module: cutile_ir::Module,
}

impl CompileArtifacts {
    /// Renders the module as text using `cutile_ir::Module::to_mlir_text`.
    pub fn ir_text(&self) -> String {
        let module = &self.module;
        module.to_mlir_text()
    }

    /// Serializes the module to bytes with `cutile_ir::write_bytecode`.
    ///
    /// # Errors
    ///
    /// Returns [`JITError::Generic`] containing the underlying error (formatted
    /// with `Display`) when serialization fails.
    pub fn bytecode(&self) -> Result<Vec<u8>, JITError> {
        let serialized = cutile_ir::write_bytecode(&self.module);
        serialized
            .map_err(|err| JITError::Generic(format!("bytecode serialization failed: {err}")))
    }

    /// Borrows the compiled module without consuming the artifacts.
    pub fn module(&self) -> &cutile_ir::Module {
        let Self { module } = self;
        module
    }

    /// Consumes the artifacts and hands back ownership of the module.
    pub fn into_module(self) -> cutile_ir::Module {
        let Self { module } = self;
        module
    }
}
/// Builder that compiles one kernel function from a module AST into
/// [`CompileArtifacts`].
///
/// Create it with [`KernelCompiler::new`], chain the configuration methods, then
/// call [`KernelCompiler::compile`]. Settings that are never configured keep the
/// defaults chosen in `new`:
/// - target `"sm_80"`
/// - no generics, strides or specialization args
/// - no constant grid
/// - `CompileOptions::default()`
// The `Fn() -> crate::ast::Module` bound lives on the impl block rather than the
// struct: struct bounds are not implied elsewhere and only add noise.
#[must_use = "a KernelCompiler does nothing until `compile` is called"]
pub struct KernelCompiler<F> {
    /// Produces the module AST; invoked once per call to `compile`.
    module_ast_fn: F,
    /// Module name passed to `CUDATileFunctionCompiler::new`.
    module_name: String,
    /// Function name passed to `CUDATileFunctionCompiler::new`.
    function_name: String,
    /// GPU target name passed to the function compiler (default `"sm_80"`).
    gpu_name: String,
    /// Generic arguments passed to the function compiler.
    generics: Vec<String>,
    /// `(name, strides)` pairs passed to the function compiler.
    stride_args: Vec<(String, Vec<i32>)>,
    /// `(name, bits)` specialization pairs passed to the function compiler.
    spec_args: Vec<(String, SpecializationBits)>,
    /// Optional constant grid passed to the function compiler.
    const_grid: Option<(u32, u32, u32)>,
    /// Compile options passed to the function compiler.
    compile_options: CompileOptions,
}

impl<F: Fn() -> crate::ast::Module> KernelCompiler<F> {
    /// Creates a builder for `function_name` within `module_name`.
    ///
    /// `module_ast_fn` is not called here; it runs when [`compile`](Self::compile)
    /// is invoked.
    pub fn new(module_ast_fn: F, module_name: &str, function_name: &str) -> Self {
        Self {
            module_ast_fn,
            module_name: module_name.to_string(),
            function_name: function_name.to_string(),
            gpu_name: "sm_80".to_string(),
            generics: Vec::new(),
            stride_args: Vec::new(),
            spec_args: Vec::new(),
            const_grid: None,
            compile_options: CompileOptions::default(),
        }
    }

    /// Sets the GPU target name (default `"sm_80"`).
    pub fn target(mut self, gpu_name: &str) -> Self {
        self.gpu_name = gpu_name.to_string();
        self
    }

    /// Sets the generic arguments, replacing any previously set list.
    pub fn generics(mut self, generics: Vec<String>) -> Self {
        self.generics = generics;
        self
    }

    /// Sets `(name, strides)` pairs, replacing any previously set strides.
    ///
    /// NOTE(review): names presumably identify kernel parameters; confirm against
    /// `CUDATileFunctionCompiler`.
    pub fn strides(mut self, strides: &[(&str, &[i32])]) -> Self {
        self.stride_args = strides
            .iter()
            .map(|(name, s)| (name.to_string(), s.to_vec()))
            .collect();
        self
    }

    /// Sets `(name, bits)` specialization pairs, replacing any previous ones.
    pub fn spec_args(mut self, specs: &[(&str, SpecializationBits)]) -> Self {
        self.spec_args = specs
            .iter()
            .map(|(name, s)| (name.to_string(), s.clone()))
            .collect();
        self
    }

    /// Supplies a constant grid that is forwarded to the function compiler.
    ///
    /// NOTE(review): the tuple is assumed to be `(x, y, z)` order; confirm.
    pub fn grid(mut self, grid: (u32, u32, u32)) -> Self {
        self.const_grid = Some(grid);
        self
    }

    /// Replaces the compile options (default `CompileOptions::default()`).
    pub fn options(mut self, options: CompileOptions) -> Self {
        self.compile_options = options;
        self
    }

    /// Runs the full pipeline:
    /// 1. builds the AST with `module_ast_fn`;
    /// 2. wraps it in `CUDATileModules`;
    /// 3. compiles `function_name` in `module_name` with every configured setting.
    ///
    /// # Errors
    ///
    /// Returns an error if module construction, compiler setup, or compilation
    /// fails.
    pub fn compile(self) -> Result<CompileArtifacts, JITError> {
        let module_ast = (self.module_ast_fn)();
        let modules = CUDATileModules::from_kernel(module_ast)?;
        // The function compiler takes borrowed views, so re-borrow the owned
        // builder state as slices of `&str` / `&[i32]` / `&SpecializationBits`.
        let stride_refs: Vec<(&str, &[i32])> = self
            .stride_args
            .iter()
            .map(|(name, s)| (name.as_str(), s.as_slice()))
            .collect();
        let spec_refs: Vec<(&str, &SpecializationBits)> = self
            .spec_args
            .iter()
            .map(|(name, s)| (name.as_str(), s))
            .collect();
        let compiler = CUDATileFunctionCompiler::new(
            &modules,
            &self.module_name,
            &self.function_name,
            &self.generics,
            &stride_refs,
            &spec_refs,
            // NOTE(review): this empty slice fills a `CUDATileFunctionCompiler::new`
            // parameter that this builder does not expose; confirm what it controls.
            &[],
            self.const_grid,
            self.gpu_name,
            &self.compile_options,
        )?;
        let module = compiler.compile()?;
        Ok(CompileArtifacts { module })
    }
}