cutile-compiler 0.1.0

Crate for compiling kernels authored in cuTile Rust to executable kernels.
/*
 * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! Compile-only API for producing Tile IR and bytecode without a GPU.
//!
//! This module provides [`KernelCompiler`], a builder for compiling cuTile
//! kernels to IR and bytecode artifacts without requiring a CUDA driver or
//! GPU at runtime. Only the CUDA **headers** are needed at build time.
//!
//! # Example
//!
//! ```rust,ignore
//! use cutile::compile_api::KernelCompiler;
//!
//! let artifacts = KernelCompiler::new(my_module::__module_ast_self, "my_module", "add")
//!     .generics(vec!["32".into()])
//!     .strides(&[("c", &[1])])
//!     .target("sm_80")
//!     .compile()?;
//!
//! println!("{}", artifacts.ir_text());
//! let bc = artifacts.bytecode()?;
//! ```

use crate::compiler::{CUDATileFunctionCompiler, CUDATileModules};
use crate::error::JITError;
use crate::hints::CompileOptions;
use crate::specialization::SpecializationBits;

/// Compiled kernel artifacts produced by [`KernelCompiler::compile`].
///
/// Wraps the compiled `cutile_ir::Module`. From it you can obtain the
/// human-readable Tile IR text ([`ir_text`](Self::ir_text)) and the
/// serialized bytecode ([`bytecode`](Self::bytecode)). No cubin is held
/// here: the artifacts stop at the IR module and its bytecode.
///
/// All methods are pure Rust and do not require a GPU or CUDA driver.
pub struct CompileArtifacts {
    /// The compiled Tile IR module returned by the function compiler.
    module: cutile_ir::Module,
}

impl CompileArtifacts {
    /// Returns the human-readable Tile IR text (MLIR-like syntax).
    ///
    /// The text is rendered from the module on every call; keep the returned
    /// `String` if you need it more than once.
    pub fn ir_text(&self) -> String {
        self.module.to_mlir_text()
    }

    /// Serializes the compiled module to bytecode.
    ///
    /// # Errors
    ///
    /// Returns [`JITError::Generic`] if `cutile_ir::write_bytecode` fails. The
    /// underlying error is included in the message, prefixed with
    /// `"bytecode serialization failed: "`.
    pub fn bytecode(&self) -> Result<Vec<u8>, JITError> {
        cutile_ir::write_bytecode(&self.module)
            .map_err(|e| JITError::Generic(format!("bytecode serialization failed: {e}")))
    }

    /// Returns a reference to the underlying `cutile_ir::Module`.
    pub fn module(&self) -> &cutile_ir::Module {
        &self.module
    }

    /// Consumes the artifacts and returns the underlying `cutile_ir::Module`.
    pub fn into_module(self) -> cutile_ir::Module {
        self.module
    }
}

/// Builder for compiling a cuTile kernel without a GPU.
///
/// Wraps the existing [`CUDATileFunctionCompiler`] with a streamlined API
/// for compile-only workflows. Every configuration method consumes the
/// builder and returns it. Keep the returned value, or chain straight into
/// `.compile()`, otherwise the setting is silently lost.
///
/// # Example
///
/// ```rust,ignore
/// let artifacts = KernelCompiler::new(my_module::__module_ast_self, "my_module", "tile_math")
///     .generics(vec!["32".into()])
///     .strides(&[("output", &[1])])
///     .target("sm_80")
///     .compile()?;
/// ```
#[must_use = "a KernelCompiler does nothing until `.compile()` is called on it"]
pub struct KernelCompiler<F: Fn() -> crate::ast::Module> {
    // Produces the module AST; typically the `__module_ast_self` function
    // generated by `#[cutile::module]`.
    module_ast_fn: F,
    // Name of the module that contains the kernel.
    module_name: String,
    // Name of the entry function to compile.
    function_name: String,
    // Target GPU architecture string (e.g. `"sm_80"`), forwarded unvalidated.
    gpu_name: String,
    // Generic arguments for the kernel (e.g. tile sizes), kept as strings.
    generics: Vec<String>,
    // Stride arguments for tensor parameters: `(parameter name, strides)`.
    stride_args: Vec<(String, Vec<i32>)>,
    // Specialization bits for tensor parameters: `(parameter name, bits)`.
    spec_args: Vec<(String, SpecializationBits)>,
    // Optional constant grid for the launch configuration; `None` if unset.
    const_grid: Option<(u32, u32, u32)>,
    // Options (occupancy hints, etc.) forwarded to the function compiler.
    compile_options: CompileOptions,
}

impl<F: Fn() -> crate::ast::Module> KernelCompiler<F> {
    /// Creates a new compiler for the given kernel.
    ///
    /// - `module_ast_fn`: The `__module_ast_self` function generated by `#[cutile::module]`.
    /// - `module_name`: Name of the module containing the kernel (e.g. `"my_module"`).
    /// - `function_name`: Name of the `#[entry]` function to compile (e.g. `"add"`).
    ///
    /// The builder starts with the following state:
    /// - target `"sm_80"`;
    /// - no generics;
    /// - no stride or specialization arguments;
    /// - no constant grid;
    /// - `CompileOptions::default()`.
    pub fn new(module_ast_fn: F, module_name: &str, function_name: &str) -> Self {
        Self {
            module_ast_fn,
            module_name: module_name.to_string(),
            function_name: function_name.to_string(),
            gpu_name: "sm_80".to_string(),
            generics: Vec::new(),
            stride_args: Vec::new(),
            spec_args: Vec::new(),
            const_grid: None,
            compile_options: CompileOptions::default(),
        }
    }

    /// Sets the target GPU architecture (e.g. `"sm_80"`, `"sm_100"`).
    /// Defaults to `"sm_80"`.
    ///
    /// The string is stored as-is; it is not validated here.
    pub fn target(mut self, gpu_name: &str) -> Self {
        self.gpu_name = gpu_name.to_string();
        self
    }

    /// Sets the generic arguments for the kernel (e.g. tile sizes).
    ///
    /// Replaces any generics set by an earlier call.
    pub fn generics(mut self, generics: Vec<String>) -> Self {
        self.generics = generics;
        self
    }

    /// Sets stride arguments for tensor parameters, given as
    /// `(parameter name, strides)` pairs.
    ///
    /// Replaces (does not append to) any stride arguments set by an earlier
    /// call.
    pub fn strides(mut self, strides: &[(&str, &[i32])]) -> Self {
        self.stride_args = strides
            .iter()
            .map(|(name, s)| (name.to_string(), s.to_vec()))
            .collect();
        self
    }

    /// Sets specialization bits for tensor parameters, given as
    /// `(parameter name, bits)` pairs.
    ///
    /// Replaces (does not append to) any specialization arguments set by an
    /// earlier call. The bits are cloned into the builder.
    pub fn spec_args(mut self, specs: &[(&str, SpecializationBits)]) -> Self {
        self.spec_args = specs
            .iter()
            .map(|(name, s)| (name.to_string(), s.clone()))
            .collect();
        self
    }

    /// Sets a constant grid size for the kernel launch configuration.
    ///
    /// Overwrites any grid set by an earlier call.
    pub fn grid(mut self, grid: (u32, u32, u32)) -> Self {
        self.const_grid = Some(grid);
        self
    }

    /// Sets compile options (occupancy hints, etc.).
    ///
    /// Replaces the options set by `new` or by an earlier call.
    pub fn options(mut self, options: CompileOptions) -> Self {
        self.compile_options = options;
        self
    }

    /// Compiles the kernel and returns the artifacts.
    ///
    /// This is a pure compilation step — no GPU or CUDA driver is needed.
    ///
    /// # Errors
    ///
    /// Propagates any [`JITError`] returned by any of these steps:
    /// - building the [`CUDATileModules`] from the module AST;
    /// - constructing the [`CUDATileFunctionCompiler`] for the configured
    ///   module, function and arguments;
    /// - the final compilation step.
    pub fn compile(self) -> Result<CompileArtifacts, JITError> {
        let module_ast = (self.module_ast_fn)();
        let modules = CUDATileModules::from_kernel(module_ast)?;

        // The function compiler takes borrowed views, so derive them from the
        // owned argument lists stored in the builder.
        let stride_refs: Vec<(&str, &[i32])> = self
            .stride_args
            .iter()
            .map(|(name, s)| (name.as_str(), s.as_slice()))
            .collect();

        let spec_refs: Vec<(&str, &SpecializationBits)> = self
            .spec_args
            .iter()
            .map(|(name, s)| (name.as_str(), s))
            .collect();

        let compiler = CUDATileFunctionCompiler::new(
            &modules,
            &self.module_name,
            &self.function_name,
            &self.generics,
            &stride_refs,
            &spec_refs,
            // NOTE(review): this builder always passes an empty list here; its
            // meaning is defined by `CUDATileFunctionCompiler::new` — confirm.
            &[],
            self.const_grid,
            self.gpu_name,
            &self.compile_options,
        )?;

        let module = compiler.compile()?;
        Ok(CompileArtifacts { module })
    }
}