baracuda-kernels-sys 0.0.1-alpha.68

Compiled bespoke .cu kernel template instantiations for the baracuda ML kernel facade plus C-ABI FFI facades for the library-backed plans (cuDNN conv/pool, cuSOLVER linalg, cuFFT/cuRAND, CUTLASS GEMM re-export). Hosts curated CUDA kernel sources (int8/FP8/int4/bin GEMM RRR, elementwise, reduce, norm, attention, …), builds them via baracuda-forge, exposes extern "C" entry points for the safe baracuda-kernels crate. CUTLASS template kernels live in the sibling baracuda-cutlass-kernels-sys crate and are re-exported here under the unified baracuda_kernels_gemm_* namespace.
Documentation
// baracuda_write_slice.cuh
//
// Templated kernel + INSTANTIATE macros for the WriteSlice op
// (Phase 13.1 — Category N / ShapeLayoutKind::WriteSlice).
//
// Op semantics:
//   write_slice(dest, source, ranges) -> dest
//     dest[start_0..end_0, ..., start_{N-1}..end_{N-1}] = source
// Assign (not accumulate). `dest` is contiguous, zero-offset, mutated
// in place. `source` is contiguous, zero-offset, with shape per axis
// equal to (end_i - start_i). 1 <= rank <= 8.
//
// Driven by Fuel team's persistent KV-cache append workflow during
// autoregressive decoding — the host-side fast path is a single
// cuMemcpyDtoDAsync; the kernel here covers the generic strided case.
//
// Byte-width dispatch:
//   bN ∈ {b1, b2, b4, b8, b16} — one symbol per sizeof(T). These five
//   symbols cover every byte-aligned dtype that baracuda's element bank
//   exposes. One additional nibble-packed symbol handles S4/U4, where a
//   single byte holds two elements.
//
// Status codes mirror the rest of baracuda-kernels-sys:
//   0 success
//   1 misaligned operand
//   2 invalid problem (negative dim, rank out of range, etc.)
//   3 unsupported (e.g. nibble write with odd-aligned innermost axis)
//   4 workspace too small
//   5 internal kernel error (launch failure)

#ifndef BARACUDA_WRITE_SLICE_CUH
#define BARACUDA_WRITE_SLICE_CUH

#include <cstddef>
#include <cstdint>
#include <cuda_runtime.h>

namespace baracuda { namespace write_slice {

// Largest tensor rank the launchers in this header accept.
inline constexpr int MAX_RANK = 8;

// Fixed-capacity array of int32 extents/offsets, passed to the kernels by
// value as a launch parameter. The kernels only read entries [0, rank).
struct DimsI32 { int32_t v[MAX_RANK]; };

// Byte-width payload types — POD blobs the kernel copies verbatim, one
// per supported element size. The nibble-packed variant treats one byte
// as two logical elements, but its copy semantics are byte-level too
// (alignment is constrained so that no read-modify-write straddles a
// byte boundary).
struct Blob1  { uint8_t  bytes[1];  };
struct Blob2  { uint16_t bytes;     };
struct Blob4  { uint32_t bytes;     };
struct Blob8  { uint64_t bytes;     };
struct Blob16 { uint64_t lo, hi;    };

// The Rust side selects a symbol from sizeof(T), so every blob must occupy
// exactly the advertised number of bytes with no padding.
static_assert(sizeof(Blob1)  == 1,  "Blob1 must be exactly 1 byte");
static_assert(sizeof(Blob2)  == 2,  "Blob2 must be exactly 2 bytes");
static_assert(sizeof(Blob4)  == 4,  "Blob4 must be exactly 4 bytes");
static_assert(sizeof(Blob8)  == 8,  "Blob8 must be exactly 8 bytes");
static_assert(sizeof(Blob16) == 16, "Blob16 must be exactly 16 bytes");
static_assert(sizeof(DimsI32) == MAX_RANK * sizeof(int32_t),
              "DimsI32 must be a packed array of MAX_RANK int32 values");

// Per-element copy kernel for the byte-aligned case.
//
// Each loop iteration handles one source element: its flat index is
// decomposed axis by axis (innermost first) under `source_shape`, every
// coordinate is shifted by `range_start`, and the shifted coordinate is
// folded into a row-major offset under `dest_shape`. Source and dest are
// both indexed as contiguous row-major buffers.
//
// Launch layout: 1-D grid of 1-D blocks. Any grid size is valid because a
// thread advances by gridDim.x * blockDim.x until `source_numel` is
// exhausted. No shared memory is used.
template <typename Blob>
__global__ void write_slice_byte_kernel(
    Blob* __restrict__ dest,
    const Blob* __restrict__ source,
    int64_t source_numel,
    int32_t rank,
    DimsI32 dest_shape,
    DimsI32 source_shape,
    DimsI32 range_start)
{
    const int64_t first  = static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x)
                         + static_cast<int64_t>(threadIdx.x);
    const int64_t stride = static_cast<int64_t>(gridDim.x) * static_cast<int64_t>(blockDim.x);

    for (int64_t src_idx = first; src_idx < source_numel; src_idx += stride) {
        int64_t remaining  = src_idx;  // part of the flat index not yet unravelled
        int64_t dst_idx    = 0;        // row-major offset into dest
        int64_t dst_stride = 1;        // dest stride of the axis being processed

        // Walk the axes innermost-first, peeling one source coordinate off
        // `remaining` and immediately accumulating its dest contribution.
        for (int axis = rank - 1; axis >= 0; --axis) {
            const int64_t extent = static_cast<int64_t>(source_shape.v[axis]);
            int64_t within = 0;
            if (extent != 0) {  // a zero extent leaves the coordinate at 0
                within     = remaining % extent;
                remaining /= extent;
            }
            dst_idx    += (within + static_cast<int64_t>(range_start.v[axis])) * dst_stride;
            dst_stride *= static_cast<int64_t>(dest_shape.v[axis]);
        }

        dest[dst_idx] = source[src_idx];
    }
}

// Host-side launcher for write_slice_byte_kernel<Blob>.
//
// Validates the problem on the host, copies the first `rank` entries of
// the three host shape arrays into by-value DimsI32 parameters, and
// enqueues the kernel on `stream`. The launch is asynchronous; the status
// only reflects what cudaGetLastError() reports right after the launch.
//
// Parameters:
//   dest, source       device buffers holding Blob-sized elements
//   source_numel       number of Blob elements to copy
//   rank               number of axes, 1..MAX_RANK
//   dest_shape_host    host array, `rank` dest extents
//   source_shape_host  host array, `rank` source (slab) extents
//   range_start_host   host array, `rank` per-axis start offsets in dest
//   stream             CUDA stream to launch on
//
// Returns:
//   0  success (including the empty `source_numel == 0` case)
//   1  dest or source is not aligned for Blob
//   2  invalid problem: rank out of range, negative numel, null pointer,
//      negative extent/start, slab exceeding dest, or source_numel not
//      matching the product of the source extents
//   5  the launch reported an error
template <typename Blob>
__host__ inline int32_t launch_write_slice_byte(
    void* dest, const void* source,
    int64_t source_numel,
    int32_t rank,
    const int32_t* dest_shape_host,
    const int32_t* source_shape_host,
    const int32_t* range_start_host,
    cudaStream_t stream)
{
    if (rank < 1 || rank > MAX_RANK) return 2;
    if (source_numel < 0) return 2;
    if (source_numel == 0) return 0;  // nothing to copy; skip the launch
    if (dest == nullptr || source == nullptr) return 2;
    if (dest_shape_host == nullptr || source_shape_host == nullptr ||
        range_start_host == nullptr) return 2;

    // The kernel dereferences Blob* directly, so both operands must satisfy
    // Blob's alignment.
    if (reinterpret_cast<uintptr_t>(dest)   % alignof(Blob) != 0 ||
        reinterpret_cast<uintptr_t>(source) % alignof(Blob) != 0) return 1;

    DimsI32 ds = {}, ss = {}, rs = {};
    int64_t slab_numel = 1;
    for (int i = 0; i < rank; ++i) {
        const int64_t dim   = dest_shape_host[i];
        const int64_t ext   = source_shape_host[i];
        const int64_t start = range_start_host[i];
        // Reject anything that would make the kernel write outside dest.
        if (dim < 0 || ext < 0 || start < 0 || start + ext > dim) return 2;
        // Overflow-safe running product of the slab extents.
        if (ext != 0 && slab_numel > INT64_MAX / ext) return 2;
        slab_numel *= ext;
        ds.v[i] = dest_shape_host[i];
        ss.v[i] = source_shape_host[i];
        rs.v[i] = range_start_host[i];
    }
    // A mismatch would make the kernel read past the slab or wrap around it.
    if (slab_numel != source_numel) return 2;

    constexpr int kBlock = 256;
    constexpr int64_t kMaxBlocks = 65535;
    // Ceil-div written so it cannot overflow; source_numel > 0 here. The
    // grid is capped because the kernel loops over any remaining elements.
    const int64_t wanted = (source_numel - 1) / kBlock + 1;
    const int blocks = static_cast<int>(wanted < kMaxBlocks ? wanted : kMaxBlocks);

    write_slice_byte_kernel<Blob><<<blocks, kBlock, 0, stream>>>(
        static_cast<Blob*>(dest),
        static_cast<const Blob*>(source),
        source_numel, rank, ds, ss, rs);
    cudaError_t err = cudaGetLastError();
    return (err == cudaSuccess) ? 0 : 5;
}

// Nibble-packed kernel (S4 / U4).
//
// Contract stated by the callers of this header: the Rust safe layer
// requires `range_start[rank-1]` and `range_end[rank-1]` to both be even,
// so no read-modify-write straddles a byte boundary on the innermost
// axis, and it halves the innermost extent before passing the shape
// arrays. Under that contract every quantity here is counted in *bytes*
// and the copy is the same one-byte-per-element walk as the b1 kernel.
//
// Launch layout: 1-D grid of 1-D blocks; a thread advances by
// gridDim.x * blockDim.x until `source_byte_numel` is exhausted. No shared
// memory is used.
//
// NOTE(review): this is a non-template __global__ definition in a header;
// if more than one translation unit in the same link includes it, that may
// produce duplicate symbols — confirm how the header is consumed.
__global__ void write_slice_nibble_kernel(
    uint8_t* __restrict__ dest,
    const uint8_t* __restrict__ source,
    int64_t source_byte_numel,
    int32_t rank,
    DimsI32 dest_byte_shape,
    DimsI32 source_byte_shape,
    DimsI32 range_start_bytes)
{
    const int64_t first  = static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x)
                         + static_cast<int64_t>(threadIdx.x);
    const int64_t stride = static_cast<int64_t>(gridDim.x) * static_cast<int64_t>(blockDim.x);

    for (int64_t src_byte = first; src_byte < source_byte_numel; src_byte += stride) {
        int64_t remaining  = src_byte;  // part of the flat index not yet unravelled
        int64_t dst_byte   = 0;         // row-major byte offset into dest
        int64_t dst_stride = 1;         // dest stride of the axis being processed

        // Innermost-first: peel one source coordinate and fold its shifted
        // value into the dest offset in the same pass.
        for (int axis = rank - 1; axis >= 0; --axis) {
            const int64_t extent = static_cast<int64_t>(source_byte_shape.v[axis]);
            int64_t within = 0;
            if (extent != 0) {  // a zero extent leaves the coordinate at 0
                within     = remaining % extent;
                remaining /= extent;
            }
            dst_byte   += (within + static_cast<int64_t>(range_start_bytes.v[axis])) * dst_stride;
            dst_stride *= static_cast<int64_t>(dest_byte_shape.v[axis]);
        }

        dest[dst_byte] = source[src_byte];
    }
}

// Host-side launcher for write_slice_nibble_kernel.
//
// All sizes are byte-counted (the caller has already halved the innermost
// axis). Validates the problem on the host, copies the first `rank`
// entries of the host shape arrays into by-value DimsI32 parameters, and
// enqueues the kernel on `stream`. The launch is asynchronous; the status
// only reflects what cudaGetLastError() reports right after the launch.
//
// Parameters:
//   dest, source            device byte buffers
//   source_byte_numel       number of bytes to copy
//   rank                    number of axes, 1..MAX_RANK
//   dest_byte_shape_host    host array, `rank` dest extents in bytes
//   source_byte_shape_host  host array, `rank` slab extents in bytes
//   range_start_bytes_host  host array, `rank` per-axis start offsets
//   stream                  CUDA stream to launch on
//
// Returns:
//   0  success (including the empty `source_byte_numel == 0` case)
//   2  invalid problem: rank out of range, negative numel, null pointer,
//      negative extent/start, slab exceeding dest, or source_byte_numel
//      not matching the product of the source extents
//   5  the launch reported an error
__host__ inline int32_t launch_write_slice_nibble(
    void* dest, const void* source,
    int64_t source_byte_numel,
    int32_t rank,
    const int32_t* dest_byte_shape_host,
    const int32_t* source_byte_shape_host,
    const int32_t* range_start_bytes_host,
    cudaStream_t stream)
{
    if (rank < 1 || rank > MAX_RANK) return 2;
    if (source_byte_numel < 0) return 2;
    if (source_byte_numel == 0) return 0;  // nothing to copy; skip the launch
    if (dest == nullptr || source == nullptr) return 2;
    if (dest_byte_shape_host == nullptr || source_byte_shape_host == nullptr ||
        range_start_bytes_host == nullptr) return 2;

    DimsI32 ds = {}, ss = {}, rs = {};
    int64_t slab_bytes = 1;
    for (int i = 0; i < rank; ++i) {
        const int64_t dim   = dest_byte_shape_host[i];
        const int64_t ext   = source_byte_shape_host[i];
        const int64_t start = range_start_bytes_host[i];
        // Reject anything that would make the kernel write outside dest.
        if (dim < 0 || ext < 0 || start < 0 || start + ext > dim) return 2;
        // Overflow-safe running product of the slab extents.
        if (ext != 0 && slab_bytes > INT64_MAX / ext) return 2;
        slab_bytes *= ext;
        ds.v[i] = dest_byte_shape_host[i];
        ss.v[i] = source_byte_shape_host[i];
        rs.v[i] = range_start_bytes_host[i];
    }
    // A mismatch would make the kernel read past the slab or wrap around it.
    if (slab_bytes != source_byte_numel) return 2;

    constexpr int kBlock = 256;
    constexpr int64_t kMaxBlocks = 65535;
    // Ceil-div written so it cannot overflow; source_byte_numel > 0 here.
    // The grid is capped because the kernel loops over remaining bytes.
    const int64_t wanted = (source_byte_numel - 1) / kBlock + 1;
    const int blocks = static_cast<int>(wanted < kMaxBlocks ? wanted : kMaxBlocks);

    write_slice_nibble_kernel<<<blocks, kBlock, 0, stream>>>(
        static_cast<uint8_t*>(dest),
        static_cast<const uint8_t*>(source),
        source_byte_numel, rank, ds, ss, rs);
    cudaError_t err = cudaGetLastError();
    return (err == cudaSuccess) ? 0 : 5;
}

} } // namespace baracuda::write_slice

// Emit the C-ABI entry points for one byte width. The Rust side picks the
// matching symbol from `sizeof(T)`.
//
//   baracuda_kernels_write_slice_<SUFFIX>_run
//       Validates the arguments and forwards to
//       launch_write_slice_byte<BLOB> on the stream passed as an opaque
//       pointer. The workspace arguments are accepted but unused.
//   baracuda_kernels_write_slice_<SUFFIX>_can_implement
//       Performs the same argument checks without launching anything.
//
// Both return 0 on success, 1 when dest/source is not aligned for BLOB,
// and 2 for an invalid problem (bad rank, negative numel, null pointer).
// `_run` additionally returns whatever the launcher reports. An empty
// problem (`source_numel == 0`) succeeds without touching any pointer.
#define BARACUDA_KERNELS_WRITE_SLICE_INSTANTIATE(SUFFIX, BLOB)                                       \
    extern "C" int32_t baracuda_kernels_write_slice_##SUFFIX##_run(                                   \
        void* dest, const void* source,                                                               \
        int64_t source_numel,                                                                          \
        int32_t rank,                                                                                  \
        const int32_t* dest_shape,                                                                     \
        const int32_t* source_shape,                                                                   \
        const int32_t* range_start,                                                                    \
        void* /*workspace*/, size_t /*workspace_bytes*/,                                              \
        void* stream_ptr)                                                                              \
    {                                                                                                  \
        if (rank < 1 || rank > baracuda::write_slice::MAX_RANK) return 2;                              \
        if (source_numel < 0) return 2;                                                                \
        if (source_numel == 0) return 0;                                                               \
        if (dest == nullptr || source == nullptr) return 2;                                            \
        /* The launcher reads `rank` entries from each shape array on the host. */                     \
        if (dest_shape == nullptr || source_shape == nullptr ||                                        \
            range_start == nullptr) return 2;                                                          \
        /* Status 1: the kernel accesses both operands through BLOB pointers. */                       \
        if (reinterpret_cast<uintptr_t>(dest) % alignof(BLOB) != 0 ||                                  \
            reinterpret_cast<uintptr_t>(source) % alignof(BLOB) != 0) return 1;                        \
        cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);                                   \
        return baracuda::write_slice::launch_write_slice_byte<BLOB>(                                   \
            dest, source, source_numel, rank,                                                          \
            dest_shape, source_shape, range_start, stream);                                            \
    }                                                                                                  \
    extern "C" int32_t baracuda_kernels_write_slice_##SUFFIX##_can_implement(                          \
        const void* dest, const void* source,                                                          \
        int64_t source_numel,                                                                          \
        int32_t rank,                                                                                  \
        const int32_t* dest_shape,                                                                     \
        const int32_t* source_shape,                                                                   \
        const int32_t* range_start)                                                                    \
    {                                                                                                  \
        if (source_numel < 0) return 2;                                                                \
        if (rank < 1 || rank > baracuda::write_slice::MAX_RANK) return 2;                              \
        if (source_numel > 0 && (dest_shape == nullptr || source_shape == nullptr ||                   \
                                  range_start == nullptr)) return 2;                                   \
        /* Mirror `_run`: alignment only matters when there is data to copy. */                        \
        if (source_numel > 0 &&                                                                        \
            (reinterpret_cast<uintptr_t>(dest) % alignof(BLOB) != 0 ||                                 \
             reinterpret_cast<uintptr_t>(source) % alignof(BLOB) != 0)) return 1;                      \
        return 0;                                                                                      \
    }

#endif // BARACUDA_WRITE_SLICE_CUH