Module krnl::kernel

Kernels.

Kernels are functions dispatched from the host that execute on the device. They are declared within modules, which create a shared scope between host and device. krnlc collects all modules and compiles them. krnl-core is a core library shared by both host and device.

use krnl::{
    macros::module,
    anyhow::Result,
    device::Device,
    buffer::{Buffer, Slice, SliceMut},
};

#[module]
mod kernels {
    #[cfg(not(target_arch = "spirv"))]
    use krnl::krnl_core;
    use krnl_core::macros::kernel;

    pub fn saxpy_impl(alpha: f32, x: f32, y: &mut f32) {
        *y += alpha * x;
    }

    // Item kernels for iterator patterns.
    #[kernel]
    pub fn saxpy(alpha: f32, #[item] x: f32, #[item] y: &mut f32) {
        saxpy_impl(alpha, x, y);
    }

    // General purpose kernels like CUDA / OpenCL.
    #[kernel]
    pub fn saxpy_global(alpha: f32, #[global] x: Slice<f32>, #[global] y: UnsafeSlice<f32>) {
        use krnl_core::buffer::UnsafeIndex;

        let global_id = kernel.global_id();
        if global_id < x.len().min(y.len()) {
            saxpy_impl(alpha, x[global_id], unsafe { y.unsafe_index_mut(global_id) });
        }
    }
}

fn saxpy(alpha: f32, x: Slice<f32>, mut y: SliceMut<f32>) -> Result<()> {
    if let Some((x, y)) = x.as_host_slice().zip(y.as_host_slice_mut()) {
        x.iter()
            .copied()
            .zip(y.iter_mut())
            .for_each(|(x, y)| kernels::saxpy_impl(alpha, x, y));
        return Ok(());
    }
    kernels::saxpy::builder()?
        .build(y.device())?
        .dispatch(alpha, x, y)
    /* or
    kernels::saxpy_global::builder()?
        .build(y.device())?
        .with_global_threads(y.len() as u32)
        .dispatch(alpha, x, y)
    */
}

fn main() -> Result<()> {
    let alpha = 2f32;
    let x = vec![1f32];
    let y = vec![0f32];
    let device = Device::builder().build().ok().unwrap_or(Device::host());
    let x = Buffer::from(x).into_device(device.clone())?;
    let mut y = Buffer::from(y).into_device(device.clone())?;
    saxpy(alpha, x.as_slice(), y.as_slice_mut())?;
    let y = y.into_vec()?;
    println!("{y:?}");
    Ok(())
}

§krnlc

Kernels are compiled with krnlc.

Compile with krnlc, or krnlc -p my-crate to select a package. krnlc:

  1. Runs the equivalent of cargo expand to locate all modules.
  2. Generates a device crate under <target-dir>/krnlc/crates/<my-crate>.
  3. Compiles the device crate with spirv-builder.
  4. Processes the output, validates and optimizes with spirv-tools.
  5. Writes out to “krnl-cache.rs”, which is imported by module and kernel macros.

The cache allows packages to build with stable Rust, without recompiling kernels downstream:

__krnl_cache!("0.0.4", "
abZy8000000@}Rn2yGJu{w.WVIuQ#sT$h4DaGh)Tk%#sdtgN ..
..
");

If the version of krnlc is incompatible with the version of krnl, the module macro will emit a compile error.

§Toolchains

To locate modules, krnlc will use the nightly toolchain. Install it with:

rustup toolchain install nightly

To compile kernels with spirv-builder, a specific nightly is required:

rustup toolchain install nightly-2023-05-27
rustup component add --toolchain nightly-2023-05-27 rust-src rustc-dev llvm-tools-preview

§Installing

With spirv-tools from the LunarG Vulkan SDK installed (will save significant compile time):

cargo +nightly-2023-05-27 install krnlc --locked --no-default-features \
 --features use-installed-tools

Otherwise:

cargo +nightly-2023-05-27 install krnlc --locked

§Metadata

krnlc can read metadata from Cargo.toml:

[package.metadata.krnlc]
# enable default features when locating modules
default-features = false
# features to enable when locating modules
features = ["zoom", "zap"]

[package.metadata.krnlc.dependencies]
# source is inherited from host target
foo = { default-features = false, features = ["foo"] }
# keys are inherited if not provided
bar = {}
# private dependency
baz = { path = "baz" }

krnl-core is automatically included as a dependency.

§Modules

The module macro declares a scope shared between host and device that is visible to krnlc. When compiling modules for the device, krnlc targets the spirv arch, which is why host-only imports (such as use krnl::krnl_core;) are gated with #[cfg(not(target_arch = "spirv"))].

use krnl::macros::module;

#[module]
mod kernels {
    #[cfg(not(target_arch = "spirv"))]
    use krnl::krnl_core;
    use krnl_core::macros::kernel;

    #[kernel]
    pub fn foo() {}
}

Modules must be within a module hierarchy, not within fns or impl blocks.

§Attributes

Additional options can be passed via attributes:

#[module]
// Does not compile the module with krnlc, used for krnl's docs.
#[krnl(no_build)]
// Override path to krnl when it isn't a dependency.
#[krnl(crate=foo::krnl)]
mod kernels {
    /* .. */
}

§Imports

Functions and other items are visible to other modules, and can be imported:

mod foo {
    #[module]
    pub mod bar {
        pub struct Bar;
    }
}

#[module]
mod baz {
    use super::foo::bar::Bar;
}

§Kernels

The kernel macro declares a function that executes on the device, dispatched from the host.

#[kernel]
fn foo<
    // Specialization Constants
    const U: i32,
    const V: f32,
    const W: u32,
>(
    // Kernel
    /* kernel: Kernel or ItemKernel */
    // Global Buffers
    #[global] a: Slice<f32>,
    #[global] b: UnsafeSlice<i32>,
    // Items
    #[item] c: f64,
    #[item] d: &mut u64,
    // Push Constants
    e: u8,
    f: i32,
    // Group Buffers
    #[group] g: UnsafeSlice<f32, 100>,
    #[group] h: UnsafeSlice<i32, { (W * 10 + 1) as usize }>,
) {
    /* .. */
}

§Items

Item kernels are a simple and safe abstraction for iterator patterns. Item kernels have an implicit ItemKernel argument.

Mapping a buffer with a fn (the host-side wrapper shown below is illustrative):

fn scale_to_f32_impl(x: u8) -> f32 {
    x as f32 / 255.
}

#[kernel]
fn scale_to_f32(#[item] x: u8, #[item] y: &mut f32) {
    *y = scale_to_f32_impl(x);
}

// Illustrative host-side wrapper: map on the host, or dispatch the kernel on the device.
fn scale_to_f32_buffer(x: Slice<u8>) -> Result<Buffer<f32>> {
    if let Some(x) = x.as_host_slice() {
        let y: Vec<f32> = x.iter().copied().map(scale_to_f32_impl).collect();
        Ok(Buffer::from(y))
    } else {
        let mut y = Buffer::zeros(x.device(), x.len())?;
        scale_to_f32::builder()?
            .build(x.device())?
            .dispatch(x, y.as_slice_mut())?;
        Ok(y)
    }
}

§Push Constants

Scalar arguments without an attribute. Unlike SpecConstants, they are provided to .dispatch(..), and do not require rebuilding the kernel.

At least 128 bytes of push constants can be used, depending on the device. Each item or global argument requires 8 bytes of push constants.
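
In the saxpy example above, alpha is a push constant. As a minimal sketch (hypothetical fill kernel; this assumes dispatch arguments mirror the kernel's declaration order, as they do for saxpy):

#[kernel]
fn fill(#[item] y: &mut u32, value: u32) {
    *y = value;
}

// Host side: the push constant is just another dispatch argument.
fn fill_with(y: SliceMut<u32>, value: u32) -> Result<()> {
    fill::builder()?
        .build(y.device())?
        .dispatch(y, value)
}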

§Groups, Subgroups, and Threads

Kernels without items have an implicit Kernel argument that uniquely identifies the group, subgroup, and thread.

Kernels are dispatched with groups of threads (CUDA thread blocks). Threads in a group are executed together, typically on the same processor with a shared L1 cache. This is exposed via group buffers.

Thread groups are composed of subgroups of threads (CUDA warps), similar to SIMD vector registers on a CPU. The number of threads per subgroup is a power of 2 between 1 and 128; typical values are 32 for NVIDIA and 64 for AMD. It may range between min_subgroup_threads and max_subgroup_threads. For a given subgroup_threads in that range, each subgroup in a group has subgroup_threads threads, except when the threads per group are not an exact multiple of subgroup_threads, in which case the last subgroup has the remainder.
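
As a minimal sketch of how these indices relate, using only the Kernel accessors that appear in the examples on this page (the kernel name and buffer are illustrative):

#[kernel]
fn first_thread_ids(#[global] y: UnsafeSlice<u32>) {
    use krnl_core::buffer::UnsafeIndex;

    let global_id = kernel.global_id();
    let group_id = kernel.group_id();
    let thread_id = kernel.thread_id();
    // Thread 0 of each group records its global id, so after the dispatch
    // y[group_id] holds the global id of the first thread in that group.
    if thread_id == 0 && group_id < y.len() {
        unsafe {
            *y.unsafe_index_mut(group_id) = global_id as u32;
        }
    }
}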

§Global Buffers

Visible to all threads. A Slice argument binds to a Slice provided to .dispatch(..); an UnsafeSlice argument binds to a SliceMut.

For best performance, consecutive threads should access consecutive elements, allowing loads and stores to be coalesced into fewer memory transactions.
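
For example, a sketch contrasting access patterns (an illustrative copy kernel; the calls are the same ones used in saxpy_global above):

#[kernel]
fn copy(#[global] x: Slice<f32>, #[global] y: UnsafeSlice<f32>) {
    use krnl_core::buffer::UnsafeIndex;

    let global_id = kernel.global_id();
    if global_id < x.len().min(y.len()) {
        // Coalesced: thread 0 reads x[0], thread 1 reads x[1], ...
        // so neighboring loads combine into fewer memory transactions.
        // A strided pattern like x[global_id * 32] would not.
        unsafe {
            *y.unsafe_index_mut(global_id) = x[global_id];
        }
    }
}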

§Group Buffers

Shared by all threads in the group and initialized with zeros. Group buffers can be used to minimize accesses to global buffers.

The maximum amount of memory that can be used for group buffers depends on the device. Kernels exceeding this will fail to build.

Barriers should be used as necessary to synchronize access.

#[kernel]
fn group_sum(
    #[global] x: Slice<f32>,
    #[group] x_group: UnsafeSlice<f32, 64>,
    #[global] y: UnsafeSlice<f32>,
) {
    use krnl_core::{
        buffer::UnsafeIndex,
        spirv_std::arch::workgroup_memory_barrier_with_group_sync as group_barrier
    };

    let global_id = kernel.global_id();
    let group_id = kernel.group_id();
    let thread_id = kernel.thread_id();
    unsafe {
        *x_group.unsafe_index_mut(thread_id) = x[global_id];
        // Barriers are used to synchronize access to group memory.
        // This call must be reached by all active threads in the group!
        group_barrier();
    }
    if thread_id == 0 {
        let mut acc = 0f32;
        for i in 0 .. 64 {
            unsafe {
                acc += *x_group.unsafe_index(i);
            }
        }
        unsafe {
            *y.unsafe_index_mut(group_id) = acc;
        }
    }
}
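
Dispatching group_sum from the host might look like the following sketch. It assumes x.len() is a multiple of 64, matches the group buffer length by setting 64 threads per group, and passes only the global buffers to .dispatch(..) (group buffers are declared in the kernel, not provided by the host):

fn group_sum_host(x: Slice<f32>) -> Result<Buffer<f32>> {
    let groups = x.len() / 64;
    let mut y = Buffer::zeros(x.device(), groups)?;
    group_sum::builder()?
        .with_threads(64)
        .build(x.device())?
        .with_groups(groups as u32)
        .dispatch(x, y.as_slice_mut())?;
    Ok(y)
}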

§KernelBuilder

A kernel declaration is expanded to a mod with a custom KernelBuilder and Kernel.

pub mod saxpy {
    /// Builder for creating a [`Kernel`].
    ///
    /// See [`builder()`](builder).
    pub struct KernelBuilder { /* .. */ }

    /// Creates a builder.
    ///
    /// The builder is lazily created on first call.
    ///
    /// # Errors
    /// - The kernel wasn't compiled (with `#[krnl(no_build)]` applied to `#[module]`).
    pub fn builder() -> Result<KernelBuilder>;

    impl KernelBuilder {
        /// Threads per group.
        ///
        /// Defaults to [`DeviceInfo::default_threads()`](DeviceInfo::default_threads).
        pub fn with_threads(self, threads: u32) -> Self;
        /// Builds the kernel for `device`.
        ///
        /// The kernel is cached, so subsequent calls to `.build()` with identical
        /// builders (ie threads and spec constants) may avoid recompiling.
        ///
        /// # Errors
        /// - `device` doesn't have required features.
        /// - The kernel is not supported on `device`.
        /// - [`DeviceLost`].
        pub fn build(&self, device: Device) -> Result<Kernel>;
    }

    /// Kernel.
    pub struct Kernel<G = WithGroups<false>> { /* .. */ }

    impl<G> Kernel<G> {
        /// Threads per group.
        pub fn threads(&self) -> u32;
        /// Global threads to dispatch.
        ///
        /// Implicitly declares groups by rounding up to the next multiple of threads.
        pub fn with_global_threads(self, global_threads: u32) -> Kernel<WithGroups<true>>;
        /// Groups to dispatch.
        ///
        /// For item kernels, if not provided, is inferred based on item arguments.
        pub fn with_groups(self, groups: u32) -> Kernel<WithGroups<true>>;
    }

    impl Kernel<WithGroups<true>> {
        /// Dispatches the kernel.
        ///
        /// - Waits for immutable access to slice arguments.
        /// - Waits for mutable access to mutable slice arguments.
        /// - Blocks until the kernel is queued.
        ///
        /// # Errors
        /// - [`DeviceLost`].
        /// - The kernel could not be queued.
        pub fn dispatch(&self, alpha: f32, x: Slice<f32>, y: SliceMut<f32>) -> Result<()>;
    }
}

View the generated code and documentation with cargo doc, adding --document-private-items if the item is private.

The builder() method returns a KernelBuilder for creating a Kernel. It returns an error if the kernel wasn't compiled (for example, when #[krnl(no_build)] is applied to the module). The builder is cached so that subsequent calls are trivial.

The number of threads per group can be set via .with_threads(..). It will default to DeviceInfo::default_threads() if not provided.

Building a kernel is an expensive operation, so it is cached within Device. Subsequent calls to .build(..) with identical builders (threads and spec constants) may avoid recompiling.
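
For example, a kernel built once with a custom thread count can be dispatched repeatedly (a sketch reusing the saxpy module and the buffers from the example at the top of this page):

let kernel = kernels::saxpy::builder()?
    .with_threads(256)
    .build(device.clone())?;
for _ in 0..10 {
    // Repeated dispatches reuse the same Kernel; nothing is recompiled.
    kernel.dispatch(alpha, x.as_slice(), y.as_slice_mut())?;
}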

§Features

Kernels implicitly declare Features based on the types and operations used. If the device does not support these features, .build(..) will return an error.

See DeviceInfo::features().

§Specialization

SpecConstants are declared like const generic parameters, but they are not const when the module is compiled in Rust. They may be used to define the length of a group buffer. At runtime, spec constants are provided to the builder via .specialize(..); during .build(..), they are converted to constants.

#[repr(u32)]
enum Op {
    Add = 1,
    Sub = 2,
}

#[kernel]
fn binary<const OP: u32>(
    #[item] a: f32,
    #[item] b: f32,
    #[item] c: &mut f32,
) {
    if OP == Op::Add as u32 {
        *c = a + b
    } else if OP == Op::Sub as u32 {
        *c = a - b
    } else {
        panic!("Invalid op: {OP}");
    }
}

binary::builder()?
    .specialize(Op::Add as u32)
    .build(device)?;
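
A spec constant can also define the length of a group buffer, as in the signature overview above (a sketch; the kernel name is illustrative):

#[kernel]
fn block_reduce<const N: u32>(
    #[global] x: Slice<f32>,
    #[group] x_group: UnsafeSlice<f32, { N as usize }>,
    #[global] y: UnsafeSlice<f32>,
) {
    /* .. */
}

// The group buffer length is fixed at build time:
block_reduce::builder()?
    .specialize(64)
    .build(device)?;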

§Dispatch

Once built, the groups to dispatch may be set via .with_groups(..), or .with_global_threads(..) which rounds up to the next multiple of threads. Item kernels infer the global_threads based on the number of items.

The .dispatch(..) method blocks until the kernel is queued. One kernel can be queued while another is executing.

When a kernel begins executing, the device will begin processing one or more groups in parallel, until all groups have finished.

Synchronization is automatically performed as necessary between kernels and when transferring buffers to and from devices. Device::wait() can be used to explicitly wait for prior operations to complete.
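
For example, queuing two dispatches back to back and then waiting (a sketch based on the saxpy_global example at the top of this page; it assumes Device::wait() returns a Result):

for _ in 0..2 {
    // Building is cached within the device, so this loop does not recompile.
    kernels::saxpy_global::builder()?
        .build(device.clone())?
        .with_global_threads(y.len() as u32)
        .dispatch(alpha, x.as_slice(), y.as_slice_mut())?;
}
// Dispatch only blocks until the kernel is queued; wait for completion explicitly.
device.wait()?;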

§SPIR-V

SPIR-V is a binary intermediate representation for graphics shaders that can be used with Vulkan. Kernels are implemented as compute shaders targeting Vulkan 1.2.

spirv-std is a std library for the spirv arch, for use with rust-gpu.

§Asm

The asm! macro can be used with the spirv arch, see inline-asm.

§DebugPrintf

debug_printf! and debug_printfln! will print formatted output to stderr.

#[kernel]
fn foo(x: f32) {
    use krnl_core::spirv_std; // spirv_std must be in scope
    use spirv_std::macros::debug_printfln;

    unsafe {
        debug_printfln!("Hello World!");
    }
}
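
debug_printf! and debug_printfln! come from spirv-std and use Vulkan debugPrintfEXT-style format specifiers (e.g. %f for floats, %u for unsigned integers). A sketch printing an argument:

#[kernel]
fn bar(x: f32) {
    use krnl_core::spirv_std; // spirv_std must be in scope
    use spirv_std::macros::debug_printfln;

    unsafe {
        // One line per thread; %f formats the push constant.
        debug_printfln!("x = %f", x);
    }
}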

Pass --debug-printf to krnlc to enable. DebugPrintf will disable many optimizations and include debug info, significantly increasing the size of both the cache and kernels at runtime.

The DebugPrintf Validation Layer must be active when the device is created or DebugPrintf instructions will be removed.

[Device(0@7f6f3c9724d0) crate::kernels::foo<threads=1>] Validation Information: [ UNASSIGNED-DEBUG-PRINTF ]
Object 0: handle = 0x7f6f3c9724d0, type = VK_OBJECT_TYPE_DEVICE; | MessageID = 0x92394c89 | Hello World!

§Panics

§Without DebugPrintf

Panics in kernels will abort the thread. This will not stop other threads from continuing, and the panic will not be caught from the host.

§With DebugPrintf

Kernels will block until completion, and return an error on panic. When a kernel thread panics, a message will be printed to stderr, including the device, the kernel name, the panic message, and a backtrace of calls leading to the panic.

[Device(0@7f89289724d0) crate::kernels::foo<threads=2, N=4>] Validation Information: [ UNASSIGNED-DEBUG-PRINTF ] Object 0: handle = 0x7f89289b6070, type = VK_OBJECT_TYPE_QUEUE; | MessageID = 0x92394c89 | Command buffer (0x7f892896d7f0). Compute Dispatch Index 0. Pipeline (0x7f8928a95fb0). Shader Module (0x7f8928a9d500). Shader Instruction Index = 137.  Stage = Compute.  Global invocation ID (x, y, z) = (1, 0, 0 )
[Rust panicked at ~/.cargo/git/checkouts/krnl-699626729fecae20/db00d07/krnl-core/src/buffer.rs:169:20]
 index out of bounds: the len is 1 but the index is 1
      in <krnl_core::buffer::UnsafeSliceRepr<u32> as krnl_core::buffer::UnsafeIndex<usize>>::unsafe_index_mut
        called at ~/.cargo/git/checkouts/krnl-699626729fecae20/db00d07/krnl-core/src/buffer.rs:229:18
      by <krnl_core::buffer::BufferBase<krnl_core::buffer::UnsafeSliceRepr<u32>> as krnl_core::buffer::UnsafeIndex<usize>>::unsafe_index_mut
        called at src/kernels.rs:15:10
      by crate::kernels::foo::foo
        called at src/kernels.rs:11:1
      by crate::kernels::foo
        called at src/kernels.rs:12:8
      by crate::kernels::foo(__krnl_global_id = vec3(1, 0, 0), __krnl_groups = vec3(1, 1, 1), __krnl_group_id = vec3(0, 0, 0), __krnl_subgroups = 1, __krnl_subgroup_id = 0, __krnl_subgroup_threads = 32, __krnl_subgroup_thread_id = 1, __krnl_thread_id = vec3(1, 0, 0))
 Unable to find SPIR-V OpLine for source information.  Build shader with debug info to get source information.
thread 'foo' panicked at src/lib.rs:50:10:
called `Result::unwrap()` on an `Err` value: Kernel `crate::kernels::foo<threads=2, N=4>` panicked!

Note: The validation layer can be configured to redirect messages to stdout. This will prevent krnl from receiving a callback and returning an error in case of a panic.