use cutile::compile_api::KernelCompiler;
use cutile_compiler::ast::Module;
use cutile_compiler::compiler::utils::CompileOptions;
use cutile_compiler::error::JITError;
use cutile_compiler::specialization::{DivHint, SpecializationBits};
/// Stack size, in bytes, for the worker thread spawned by [`with_test_stack`].
///
/// NOTE(review): presumably chosen to exceed the std default stack size for
/// spawned threads because compiling kernels recurses deeply — confirm.
pub const TEST_STACK_SIZE: usize = 8_000_000;

/// Runs `f` on a fresh thread with a [`TEST_STACK_SIZE`]-byte stack and returns its result.
///
/// The worker is a scoped thread, so `f` may borrow data owned by the caller
/// (no `'static` bound is needed). It inherits the calling thread's name, if
/// any, so panic messages name the originating test instead of `<unnamed>`.
///
/// # Panics
/// * If the thread cannot be spawned (`"Failed to spawn test thread"`).
/// * If `f` panics. The worker's original panic payload is re-raised on the
///   calling thread via [`std::panic::resume_unwind`], so the real assertion
///   message is preserved and `#[should_panic(expected = "...")]` can match it.
pub fn with_test_stack<F, R>(f: F) -> R
where
    F: FnOnce() -> R + Send,
    R: Send,
{
    let current = std::thread::current();
    let mut builder = std::thread::Builder::new().stack_size(TEST_STACK_SIZE);
    if let Some(name) = current.name() {
        builder = builder.name(name.to_owned());
    }
    std::thread::scope(|scope| {
        let handle = builder
            .spawn_scoped(scope, f)
            .expect("Failed to spawn test thread");
        match handle.join() {
            Ok(value) => value,
            // Re-raise the worker's own payload rather than panicking anew with a
            // wrapper message that would only print `Any { .. }`.
            Err(payload) => std::panic::resume_unwind(payload),
        }
    })
}
/// Compiles one kernel with [`KernelCompiler`] and returns the resulting IR text.
///
/// Every argument is forwarded to the `KernelCompiler` builder; the target is
/// always `"sm_120"`.
///
/// # Arguments
/// * `module_ast_fn` - closure producing the [`Module`] AST, handed to `KernelCompiler::new`.
/// * `module_name` / `function_name` - passed to `KernelCompiler::new` alongside the AST closure.
/// * `generics` - copied into an owned `Vec` for the builder's `generics`.
/// * `strides` - forwarded unchanged to the builder's `strides`.
/// * `spec_args` - each `(name, &SpecializationBits)` pair is turned into an owned
///   `(name, SpecializationBits)` pair (via `clone`) before reaching `spec_args`.
/// * `scalar_hints` - each `(name, &DivHint)` pair is copied into `(name, DivHint)`
///   before reaching `scalar_hints`.
/// * `const_grid` - when `Some`, passed to the builder's `grid`; when `None`, `grid` is not called.
/// * `options` - cloned into the builder's `options`.
///
/// # Errors
/// Returns the [`JITError`] produced by the builder's `compile`.
// NOTE(review): `dead_code` is presumably allowed because this helper is shared by
// several test crates that do not all call it — confirm.
#[allow(clippy::too_many_arguments, dead_code)]
pub fn compile_to_ir<F>(
    module_ast_fn: F,
    module_name: &str,
    function_name: &str,
    generics: &[String],
    strides: &[(&str, &[i32])],
    spec_args: &[(&str, &SpecializationBits)],
    scalar_hints: &[(&str, &DivHint)],
    const_grid: Option<(u32, u32, u32)>,
    options: &CompileOptions,
) -> Result<String, JITError>
where
    F: Fn() -> Module,
{
    // The builder takes slices of owned values, so detach them from the borrowed inputs.
    let owned_spec_args: Vec<_> = spec_args
        .iter()
        .map(|&(name, bits)| (name, bits.clone()))
        .collect();
    let owned_hints: Vec<_> = scalar_hints
        .iter()
        .map(|&(name, hint)| (name, *hint))
        .collect();

    let builder = KernelCompiler::new(module_ast_fn, module_name, function_name)
        .target("sm_120")
        .generics(generics.to_vec())
        .strides(strides)
        .spec_args(&owned_spec_args)
        .scalar_hints(&owned_hints)
        .options(options.clone());
    let builder = match const_grid {
        Some(grid) => builder.grid(grid),
        None => builder,
    };

    let artifacts = builder.compile()?;
    Ok(artifacts.ir_text())
}