use anyhow::{Context, Result};
use candle_core::WithDType;
use cuda_async::error::DeviceError;
use cutile_compiler::ast::Module;
use cutile_compiler::compiler::{CUDATileFunctionCompiler, CUDATileModules};
use cutile_compiler::cuda_tile::ModuleOperation;
use cutile_compiler::cuda_tile_runtime_utils::{compile_module, get_gpu_name};
use cuda_core::{memcpy_dtoh_async, CudaFunction};
use std::alloc::{alloc, Layout};
use std::fs;
use std::future::IntoFuture;
use std::path::PathBuf;
use std::sync::Arc;
use crate::error::*;
use crate::tensor::{IntoPartition, IntoPartitionArc, Partition, Tensor};
pub use cuda_async::{
device_box::*, device_context::*, device_future::*, device_operation::*, launch::*,
scheduling_policies::*,
};
/// Cache key identifying one compiled CUDA tile function: the module/function
/// names plus the generics, stride arguments and optional constant grid it
/// was compiled with. Hashed (see `get_hash_string` usage below) to index the
/// per-device function cache.
#[derive(Debug, Eq, PartialEq, Hash, Clone)]
pub struct TileFunctionKey {
    // Name of the module that contains the function.
    module_name: String,
    // Name of the function within that module.
    function_name: String,
    /// Generic argument values the function was specialized with.
    pub function_generics: Vec<String>,
    /// Stride arguments as (name, strides) pairs — presumably per-argument
    /// stride specializations; confirm against CUDATileFunctionCompiler::new.
    pub stride_args: Vec<(String, Vec<i32>)>,
    /// Launch grid baked in at compile time, if any.
    pub grid: Option<(u32, u32, u32)>,
}
impl TileFunctionKey {
pub fn new(
module_name: String,
function_name: String,
function_generics: Vec<String>,
stride_args: Vec<(String, Vec<i32>)>,
grid: Option<(u32, u32, u32)>,
) -> Self {
Self {
module_name,
function_name,
function_generics,
stride_args,
grid,
}
}
}
impl FunctionKey for TileFunctionKey {}
#[expect(unused)]
/// Reads an IR file into a `String`.
///
/// Uses `fs::read_to_string`, which reports invalid UTF-8 as an
/// `io::Error` (kind `InvalidData`) instead of panicking like the previous
/// `from_utf8(...).expect(...)` did — the signature already promises a
/// `Result`, so failures should flow through it.
fn read_ir(path: String) -> Result<String, std::io::Error> {
    fs::read_to_string(path)
}
/// Writes IR `contents` to `<dir>/<module>_<function>_<hash>.<extension>`
/// and logs the destination path.
///
/// # Panics
/// Panics if the file cannot be written (callers treat this as fatal).
fn write_ir(
    module_name: &str,
    function_name: &str,
    cache_hash_str: &str,
    extension: &str,
    dir: &str,
    contents: &str,
) {
    let filename = format!("{module_name}_{function_name}_{cache_hash_str}.{extension}");
    let path = PathBuf::from(dir).join(filename);
    // unwrap_or_else builds the panic message only on failure (the old
    // expect(format!(...)) allocated it on every call — clippy::expect_fun_call)
    // and borrowing `path` removes the needless clone.
    fs::write(&path, contents).unwrap_or_else(|e| panic!("Failed to write {path:?}: {e}"));
    println!("IR written to {path:?}");
}
/// Compiles (or fetches from the per-device cache) the CUDA function for
/// `function_name` in `module_name`, specialized by `function_generics`,
/// `stride_args` and `const_grid`.
///
/// On a cache miss the module ASTs are produced by `module_asts`, compiled to
/// MLIR and then to a cubin for this device's GPU; the function's entry args
/// ("print_ir", "dump_mlir_dir", "use_debug_mlir") control IR printing and
/// dumping. Returns the loaded `CudaFunction` with its argument `Validator`,
/// and stores both in the cache for subsequent calls.
pub fn compile_from_context<F: Fn() -> Vec<Module>>(
    ctx: &ExecutionContext,
    module_asts: F,
    module_name: &str,
    function_name: &str,
    function_entry: &str,
    function_generics: Vec<String>,
    stride_args: Vec<(String, Vec<i32>)>,
    const_grid: Option<(u32, u32, u32)>,
) -> Result<(Arc<CudaFunction>, Arc<Validator>), Error> {
    let device_id: usize = ctx.get_device_id();
    let key = TileFunctionKey::new(
        module_name.to_string(),
        function_name.to_string(),
        function_generics,
        stride_args,
        const_grid,
    );
    let cache_hash_str = key.get_hash_string();
    // Fast path: this exact specialization was already compiled on this device.
    if contains_cuda_function(device_id, &key) {
        let func = get_cuda_function(device_id, &key)?;
        let validator = get_function_validator(device_id, &key)?;
        return Ok((func, validator));
    }
    let gpu_name = get_gpu_name(device_id);
    let modules = CUDATileModules::new(module_asts())?;
    let debug_mlir_path = modules.get_entry_arg_string_by_function_name(
        module_name,
        function_name,
        "use_debug_mlir",
    )?;
    // Borrowed view of the stride args in the shape the compiler expects.
    let stride_arg_refs = key
        .stride_args
        .iter()
        .map(|(name, strides)| (name.as_str(), strides.as_slice()))
        .collect::<Vec<_>>();
    // Arguments are passed directly instead of bundling them into a tuple and
    // immediately unpacking it field-by-field (which also cloned gpu_name twice).
    let compiler = CUDATileFunctionCompiler::new(
        &modules,
        module_name,
        function_name,
        &key.function_generics,
        &stride_arg_refs,
        const_grid,
        gpu_name.clone(),
    )?;
    let validator: Arc<Validator> = Arc::new(compiler.get_validator());
    let module_op: ModuleOperation = compiler.compile()?;
    let mlir = module_op.as_operation().to_string();
    if modules.get_entry_arg_bool_by_function_name(module_name, function_name, "print_ir")? {
        // A set "use_debug_mlir" entry arg means the MLIR was loaded rather
        // than freshly compiled; label the printout accordingly.
        if debug_mlir_path.is_some() {
            println!("LOADED MLIR: {module_name}::{function_name}\n{}", mlir);
        } else {
            println!("COMPILED MLIR: {module_name}::{function_name}\n{}", mlir);
        }
    }
    if let Some(path) = modules.get_entry_arg_string_by_function_name(
        module_name,
        function_name,
        "dump_mlir_dir",
    )? {
        write_ir(
            module_name,
            function_name,
            cache_hash_str.as_str(),
            "mlir",
            path.as_str(),
            mlir.as_str(),
        );
    }
    let cubin_filename = compile_module(&module_op, &gpu_name);
    let module = load_module_from_file(&cubin_filename, device_id)?;
    let function = Arc::new(
        module
            .load_function(function_entry)
            // Message corrected: this is a load failure, not a compile failure.
            .expect("Failed to load function from compiled module."),
    );
    insert_cuda_function(device_id, &key, (module, function.clone()))?;
    insert_function_validator(device_id, &key, validator.clone())?;
    Ok((function, validator))
}
/// Checks that every grid in `partition_grids` equals `grid`.
///
/// # Errors
/// Returns `Error::KernelLaunch` describing the first mismatching grid.
pub fn validate_grids(
    grid: (u32, u32, u32),
    partition_grids: &[(u32, u32, u32)],
) -> Result<(), Error> {
    // Iterator search instead of an index loop (clippy::needless_range_loop);
    // same first-mismatch error as before.
    if let Some(mismatch) = partition_grids.iter().find(|&&g| g != grid) {
        return Err(Error::KernelLaunch(KernelLaunchError(format!(
            "{:?} != {:?}",
            grid, mismatch
        ))));
    }
    Ok(())
}
/// Resolves the launch grid to use.
///
/// A non-zero `grid` wins (after checking it agrees with any inferred grids);
/// `(0, 0, 0)` means "unspecified", in which case the first inferred grid is
/// used, provided all inferred grids agree.
///
/// # Errors
/// Fails if no grid is available, or if the grids disagree.
pub fn infer_launch_grid(
    grid: (u32, u32, u32),
    inferred_grids: &[(u32, u32, u32)],
) -> Result<(u32, u32, u32), Error> {
    // (0, 0, 0) is the sentinel for "no explicit grid was set".
    if grid != (0, 0, 0) {
        if !inferred_grids.is_empty() {
            validate_grids(grid, inferred_grids).with_context(|| {
                "Specified launch grid does not match inferred tensor partition grid"
            })?;
        }
        return Ok(grid);
    }
    if inferred_grids.is_empty() {
        return kernel_launch_error_result("Launch grid required.");
    }
    let grid = inferred_grids[0];
    validate_grids(grid, inferred_grids)
        .with_context(|| "Inferred tensor partition grids do not match")?;
    Ok(grid)
}
/// A launchable tile kernel: a `DeviceOperation` plus builder-style grid and
/// generics configuration and launch-dimension accessors.
pub trait TileKernel<ARGS: Send, DI>: DeviceOperation<Output = ARGS>
where
    DI: DeviceOperation<Output = ARGS>,
{
    /// Compiles (or fetches from the per-device cache) this kernel's CUDA
    /// function; see [`compile_from_context`].
    fn compile<F: Fn() -> Vec<Module>>(
        &mut self,
        ctx: &ExecutionContext,
        module_asts: F,
        module_name: &str,
        function_name: &str,
        function_entry: &str,
        function_generics: Vec<String>,
        stride_args: Vec<(String, Vec<i32>)>,
        grid: Option<(u32, u32, u32)>,
    ) -> Result<(Arc<CudaFunction>, Arc<Validator>), Error> {
        compile_from_context(
            ctx,
            module_asts,
            module_name,
            function_name,
            function_entry,
            function_generics,
            stride_args,
            grid,
        )
    }
    /// Sets the generic argument values (builder style).
    fn generics(self, generics: Vec<String>) -> Self;
    /// Sets a compile-time constant launch grid (builder style).
    fn const_grid(self, grid: (u32, u32, u32)) -> Self;
    /// Sets the runtime launch grid (builder style).
    fn grid(self, grid: (u32, u32, u32)) -> Self;
    /// Resolves the effective launch grid from the configured grid and the
    /// grids inferred from partitioned tensor arguments.
    fn infer_launch_grid(
        &self,
        inferred_grids: &[(u32, u32, u32)],
    ) -> Result<(u32, u32, u32), Error> {
        let grid = self.get_launch_grid();
        // `inferred_grids` is already a slice; the old `&inferred_grids`
        // created a needless `&&[_]` that deref-coerced back down.
        infer_launch_grid(grid, inferred_grids)
    }
    /// The currently configured launch grid; `(0, 0, 0)` means "unset"
    /// (see [`infer_launch_grid`]).
    fn get_launch_grid(&self) -> (u32, u32, u32);
    /// Dynamic shared memory per block, in bytes (default 0).
    fn get_launch_smem(&self) -> u32 {
        0
    }
    /// Thread-block dimensions (default `(1, 1, 1)`).
    fn get_launch_block(&self) -> (u32, u32, u32) {
        (1, 1, 1)
    }
}
impl<T: WithDType> ArcKernelArgument for Tensor<T> {
    /// Pushes, in order: the tensor's device pointer, then each shape dim,
    /// then each stride, as individual kernel arguments.
    fn push_arg_arc(self: &Arc<Self>, launcher: &mut AsyncKernelLaunch) {
        launcher.push_arg(Box::new(self.cu_deviceptr()));
        self.shape
            .iter()
            .for_each(|d| launcher.push_arg(Box::new(*d)));
        self.strides
            .iter()
            .for_each(|s| launcher.push_arg(Box::new(*s)));
    }
}
impl<T: WithDType> KernelArgument for &Partition<Tensor<T>> {
    /// Pushes, in order: the inner tensor's device pointer, its shape dims and
    /// strides, then the partition's shape dims and strides.
    fn push_arg(self, launcher: &mut AsyncKernelLaunch) {
        let tensor = &self.object;
        launcher.push_arg(Box::new(tensor.cu_deviceptr()));
        tensor
            .shape
            .iter()
            .for_each(|d| launcher.push_arg(Box::new(*d)));
        tensor
            .strides
            .iter()
            .for_each(|s| launcher.push_arg(Box::new(*s)));
        self.partition_shape
            .iter()
            .for_each(|d| launcher.push_arg(Box::new(*d)));
        self.partition_strides
            .iter()
            .for_each(|s| launcher.push_arg(Box::new(*s)));
    }
}
/// Extension trait adding `.partition(shape)` to any `DeviceOperation` whose
/// output can be partitioned.
pub trait IntoDeviceOperationPartition<I, DI>
where
    I: Send + IntoPartition + IntoPartitionArc,
    DI: DeviceOperation<Output = I>,
{
    /// Wraps this operation so its output is partitioned with
    /// `partition_shape` after execution.
    fn partition<const RANK: usize>(
        self,
        partition_shape: [i32; RANK],
    ) -> DeviceOperationPartition<RANK, I, DI>;
}
/// Blanket impl: every compatible `DeviceOperation` gains `.partition(shape)`.
impl<I, DI> IntoDeviceOperationPartition<I, DI> for DI
where
    I: Send + IntoPartition + IntoPartitionArc,
    DI: DeviceOperation<Output = I>,
{
    // Note: the redundant `where Self: Sized` and return-type turbofish were
    // dropped; `DI` is implicitly `Sized` and the type is inferred.
    fn partition<const RANK: usize>(
        self,
        partition_shape: [i32; RANK],
    ) -> DeviceOperationPartition<RANK, I, DI> {
        DeviceOperationPartition {
            op: self,
            partition_shape,
        }
    }
}
/// `DeviceOperation` adapter that partitions the output of the inner
/// operation with a fixed `RANK`-dimensional shape.
pub struct DeviceOperationPartition<const RANK: usize, I, DI>
where
    I: Send + IntoPartition + IntoPartitionArc,
    DI: DeviceOperation<Output = I>,
{
    // Shape used to partition the inner operation's output.
    partition_shape: [i32; RANK],
    // The wrapped operation producing the value to partition.
    op: DI,
}
// SAFETY: NOTE(review): this unconditionally marks the wrapper `Send` even
// though `DI` itself carries no `Send` bound (only its output `I` does). This
// is only sound if every `DeviceOperation` stored here is actually safe to
// move across threads — confirm that invariant holds for all implementors.
unsafe impl<const RANK: usize, I, DI> Send for DeviceOperationPartition<RANK, I, DI>
where
    I: Send + IntoPartition + IntoPartitionArc,
    DI: DeviceOperation<Output = I>,
{
}
impl<const RANK: usize, I, DI> DeviceOperation for DeviceOperationPartition<RANK, I, DI>
where
    I: Send + IntoPartition + IntoPartitionArc,
    DI: DeviceOperation<Output = I>,
{
    type Output = Partition<I>;

    /// Executes the inner operation, then partitions its output.
    unsafe fn execute(
        self,
        context: &ExecutionContext,
    ) -> Result<<Self as DeviceOperation>::Output, DeviceError> {
        Ok(self.op.execute(context)?.partition(self.partition_shape))
    }
}
impl<const RANK: usize, I, DI> IntoFuture for DeviceOperationPartition<RANK, I, DI>
where
    I: Send + IntoPartition + IntoPartitionArc,
    DI: DeviceOperation<Output = I>,
{
    type Output = Result<Partition<I>, DeviceError>;
    type IntoFuture = DeviceFuture<Partition<I>, DeviceOperationPartition<RANK, I, DI>>;

    /// Schedules this operation with the default device policy; any policy or
    /// scheduling error becomes an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|p| p.schedule(self));
        match scheduled {
            Ok(Ok(fut)) => fut,
            Ok(Err(err)) => DeviceFuture::failed(err),
            Err(err) => DeviceFuture::failed(err),
        }
    }
}
/// `DeviceOperation` adapter that flattens the `Partition<I>` produced by the
/// inner operation back into the underlying `I` (via `unpartition`).
pub struct UnwrapPartition<I: Send, DI>
where
    DI: DeviceOperation<Output = Partition<I>>,
{
    // The wrapped operation whose partitioned output will be unwrapped.
    pub(crate) op: DI,
}
// SAFETY: NOTE(review): `Send` is asserted without requiring `DI: Send` —
// same pattern as the other adapters here; sound only if every wrapped
// `DeviceOperation` is safe to move across threads. Confirm.
unsafe impl<I: Send, DI> Send for UnwrapPartition<I, DI> where
    DI: DeviceOperation<Output = Partition<I>>
{
}
impl<I: Send, DI> DeviceOperation for UnwrapPartition<I, DI>
where
    DI: DeviceOperation<Output = Partition<I>>,
{
    type Output = I;

    /// Executes the inner operation, then unwraps its partitioned output.
    unsafe fn execute(
        self,
        context: &ExecutionContext,
    ) -> Result<<Self as DeviceOperation>::Output, DeviceError> {
        Ok(self.op.execute(context)?.unpartition())
    }
}
impl<I: Send, DI> IntoFuture for UnwrapPartition<I, DI>
where
    DI: DeviceOperation<Output = Partition<I>>,
{
    type Output = Result<I, DeviceError>;
    type IntoFuture = DeviceFuture<I, UnwrapPartition<I, DI>>;

    /// Schedules this operation with the default device policy; any policy or
    /// scheduling error becomes an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|p| p.schedule(self));
        match scheduled {
            Ok(Ok(fut)) => fut,
            Ok(Err(err)) => DeviceFuture::failed(err),
            Err(err) => DeviceFuture::failed(err),
        }
    }
}
/// Convenience constructor: wraps `op` so its `Partition<I>` output is
/// unwrapped back into `I` when executed.
pub fn unwrap_partition<I: Send, DI>(op: DI) -> UnwrapPartition<I, DI>
where
    DI: DeviceOperation<Output = Partition<I>>,
{
    UnwrapPartition { op }
}
/// `DeviceOperation` adapter that copies the device `Tensor<T>` produced by
/// the inner operation into a host-side `Vec<T>`.
pub struct TensorToHostVec<T: WithDType, DI>
where
    DI: DeviceOperation<Output = Tensor<T>>,
{
    // The wrapped operation producing the tensor to copy to the host.
    pub(crate) op: DI,
}
// SAFETY: NOTE(review): `Send` is asserted without requiring `DI: Send` —
// same pattern as the other adapters here; sound only if every wrapped
// `DeviceOperation` is safe to move across threads. Confirm.
unsafe impl<T: WithDType, DI> Send for TensorToHostVec<T, DI> where
    DI: DeviceOperation<Output = Tensor<T>>
{
}
impl<T: WithDType, DI> DeviceOperation for TensorToHostVec<T, DI>
where
    DI: DeviceOperation<Output = Tensor<T>>,
{
    type Output = Vec<T>;

    /// Runs the inner operation, then copies the resulting device tensor into
    /// a freshly allocated host `Vec<T>` via an async device-to-host memcpy.
    unsafe fn execute(
        self,
        context: &ExecutionContext,
    ) -> Result<<Self as DeviceOperation>::Output, DeviceError> {
        let tensor = self.op.execute(context)?;
        let cu_deviceptr = tensor.device_box.cu_deviceptr();
        let size = tensor.size();
        // Calling `alloc` with a zero-size layout is UB; an empty tensor
        // simply yields an empty Vec.
        if size == 0 {
            return Ok(Vec::new());
        }
        let layout = Layout::array::<T>(size).expect("overflow cannot happen");
        let host_ptr = unsafe { alloc(layout).cast::<T>() };
        // `alloc` returns null on failure; abort via the standard handler
        // rather than constructing a Vec from a null pointer (UB).
        if host_ptr.is_null() {
            std::alloc::handle_alloc_error(layout);
        }
        // NOTE(review): the copy is asynchronous on `context`'s stream; the
        // returned Vec must not be read before the stream synchronizes —
        // presumably the DeviceFuture machinery guarantees that. Confirm.
        memcpy_dtoh_async(host_ptr, cu_deviceptr, size, context.get_cuda_stream());
        // SAFETY: `host_ptr` is non-null and was allocated through the global
        // allocator with the layout of `size` values of `T`, satisfying
        // `Vec::from_raw_parts`' contract (len == capacity == size).
        Ok(unsafe { Vec::from_raw_parts(host_ptr, size, size) })
    }
}
impl<T: WithDType, DI> IntoFuture for TensorToHostVec<T, DI>
where
    DI: DeviceOperation<Output = Tensor<T>>,
{
    type Output = Result<Vec<T>, DeviceError>;
    type IntoFuture = DeviceFuture<Vec<T>, TensorToHostVec<T, DI>>;

    /// Schedules this operation with the default device policy; any policy or
    /// scheduling error becomes an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|p| p.schedule(self));
        match scheduled {
            Ok(Ok(fut)) => fut,
            Ok(Err(err)) => DeviceFuture::failed(err),
            Err(err) => DeviceFuture::failed(err),
        }
    }
}
/// Extension trait adding `.to_host_vec()` to any `DeviceOperation` that
/// produces a `Tensor<T>`.
pub trait TensorDeviceOpToHostVec<T: WithDType> {
    /// Wraps this operation so the resulting device tensor is copied into a
    /// host `Vec<T>` when executed.
    fn to_host_vec(self) -> impl DeviceOperation<Output = Vec<T>>
    where
        Self: DeviceOperation<Output = Tensor<T>>,
    {
        TensorToHostVec { op: self }
    }
}
// Blanket impl: every tensor-producing DeviceOperation gets `to_host_vec`.
impl<T: WithDType, DI> TensorDeviceOpToHostVec<T> for DI where
    DI: DeviceOperation<Output = Tensor<T>>
{
}